use serde::{Deserialize, Serialize};
use std::path::{Path, PathBuf};
use thiserror::Error;
#[derive(Error, Debug)]
pub enum ConfigError {
#[error("Failed to read config file: {0}")]
ReadError(#[from] std::io::Error),
#[error("Failed to parse config: {0}")]
ParseError(#[from] toml::de::Error),
#[error("Missing required configuration: {0}")]
MissingConfig(String),
}
/// Top-level application configuration, deserialized from TOML
/// (by default `~/.axonml/config.toml`, see `Config::config_path`).
///
/// Every section is optional: a `[section]` table missing from the file is
/// filled in from that section's `Default` impl, so a partial config file
/// still parses. Previously only `[hub]` was optional and omitting any other
/// section was a hard parse error.
#[derive(Debug, Clone, Deserialize, Serialize, Default)]
#[serde(default)]
pub struct Config {
    /// `[server]` section: host, port and data directory.
    pub server: ServerConfig,
    /// `[aegis]` section: connection settings and credentials for the Aegis service.
    pub aegis: AegisConfig,
    /// `[auth]` section: JWT, session and registration settings.
    pub auth: AuthConfig,
    /// `[inference]` section: endpoint port range and limits.
    pub inference: InferenceConfig,
    /// `[dashboard]` section.
    pub dashboard: DashboardConfig,
    /// `[hub]` section: hub URL and local cache directory.
    pub hub: HubConfig,
}
/// `[server]` section of the configuration file.
///
/// Each key is optional; a missing key takes the value of its `default_*` function.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ServerConfig {
    /// Host address for the server (default `"0.0.0.0"`).
    #[serde(default = "default_host")]
    pub host: String,

    /// Server port (default `3000`).
    #[serde(default = "default_port")]
    pub port: u16,

    /// Root data directory as written in the file (default `"~/.axonml"`);
    /// `~` is expanded by `Config::data_dir`.
    #[serde(default = "default_data_dir")]
    pub data_dir: String,
}
/// `[aegis]` section: where the Aegis service lives and the credentials used
/// to talk to it.
///
/// `Debug` is implemented by hand so that `password` is never printed; a
/// derived `Debug` (which `Config`'s own derived `Debug` would call) leaked
/// the credential into any `{:?}` log line.
#[derive(Clone, Deserialize, Serialize)]
pub struct AegisConfig {
    /// Service host (default `"localhost"`).
    #[serde(default = "default_aegis_host")]
    pub host: String,
    /// Service port (default `9090`).
    #[serde(default = "default_aegis_port")]
    pub port: u16,
    /// Username; empty by default and rejected by `Config::validate` when empty.
    #[serde(default = "default_aegis_user")]
    pub username: String,
    /// Password; empty by default and rejected by `Config::validate` when empty.
    #[serde(default = "default_aegis_pass")]
    pub password: String,
}

impl std::fmt::Debug for AegisConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Show only whether a password is set, never its value.
        let password: &str = if self.password.is_empty() {
            ""
        } else {
            "<redacted>"
        };
        f.debug_struct("AegisConfig")
            .field("host", &self.host)
            .field("port", &self.port)
            .field("username", &self.username)
            .field("password", &password)
            .finish()
    }
}
/// `[auth]` section: token, session and account policy settings.
///
/// `Debug` is implemented by hand so that `jwt_secret` is never printed; a
/// derived `Debug` leaked the signing secret into any `{:?}` of the config.
#[derive(Clone, Deserialize, Serialize)]
pub struct AuthConfig {
    /// JWT signing secret; empty by default. `Config::validate` requires it to be
    /// non-empty and at least 32 bytes long.
    #[serde(default = "default_jwt_secret")]
    pub jwt_secret: String,
    /// JWT lifetime in hours (default `24`).
    #[serde(default = "default_jwt_expiry")]
    pub jwt_expiry_hours: u64,
    /// Session timeout in minutes (default `30`).
    #[serde(default = "default_session_timeout")]
    pub session_timeout_minutes: u64,
    /// Whether MFA is required for users (default `false`).
    #[serde(default)]
    pub require_mfa: bool,
    /// Whether public registration is enabled (default `true`).
    #[serde(default = "default_allow_registration")]
    pub allow_public_registration: bool,
}

impl std::fmt::Debug for AuthConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Show only whether a secret is configured, never its value.
        let jwt_secret: &str = if self.jwt_secret.is_empty() {
            ""
        } else {
            "<redacted>"
        };
        f.debug_struct("AuthConfig")
            .field("jwt_secret", &jwt_secret)
            .field("jwt_expiry_hours", &self.jwt_expiry_hours)
            .field("session_timeout_minutes", &self.session_timeout_minutes)
            .field("require_mfa", &self.require_mfa)
            .field("allow_public_registration", &self.allow_public_registration)
            .finish()
    }
}
/// `[inference]` section: port range and capacity for inference endpoints.
///
/// NOTE(review): nothing here checks that `default_port_range_start <=
/// default_port_range_end`; confirm callers handle an inverted range.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct InferenceConfig {
    /// First port of the default range (default `8100`).
    #[serde(default = "default_port_start")]
    pub default_port_range_start: u16,

    /// Last port of the default range (default `8199`).
    #[serde(default = "default_port_end")]
    pub default_port_range_end: u16,

    /// Maximum number of endpoints (default `10`).
    #[serde(default = "default_max_endpoints")]
    pub max_endpoints: u32,
}
/// `[dashboard]` section.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct DashboardConfig {
    /// Dashboard port (default `8080`).
    #[serde(default = "default_dashboard_port")]
    pub port: u16,
}
/// `[hub]` section: remote hub location and local cache directory.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct HubConfig {
    /// Base URL of the hub (default `"https://hub.axonml.dev/v1"`).
    #[serde(default = "default_hub_url")]
    pub hub_url: String,

    /// Cache directory as written in the file (default `"~/.axonml/hub_cache"`);
    /// `~` is expanded by `Config::hub_cache_dir`.
    #[serde(default = "default_hub_cache_dir")]
    pub cache_dir: String,
}
/// Default for `[server].host`.
fn default_host() -> String {
    String::from("0.0.0.0")
}

/// Default for `[server].port`.
fn default_port() -> u16 {
    3_000
}

/// Default for `[server].data_dir`; the `~` is expanded by `Config::data_dir`.
fn default_data_dir() -> String {
    String::from("~/.axonml")
}
/// Default for `[aegis].host`.
fn default_aegis_host() -> String {
    "localhost".to_owned()
}

/// Default for `[aegis].port`.
fn default_aegis_port() -> u16 {
    9_090
}

/// Default for `[aegis].username`: empty, so `Config::validate` forces it to be set.
fn default_aegis_user() -> String {
    String::default()
}

/// Default for `[aegis].password`: empty, so `Config::validate` forces it to be set.
fn default_aegis_pass() -> String {
    String::default()
}
/// Default for `[auth].jwt_secret`: empty, so `Config::validate` rejects it until set.
fn default_jwt_secret() -> String {
    String::default()
}

/// Default for `[auth].jwt_expiry_hours`.
fn default_jwt_expiry() -> u64 {
    24_u64
}

/// Default for `[auth].session_timeout_minutes`.
fn default_session_timeout() -> u64 {
    30_u64
}

/// Default for `[auth].allow_public_registration`.
fn default_allow_registration() -> bool {
    !false
}
/// Default for `[inference].default_port_range_start`.
fn default_port_start() -> u16 {
    8_100
}

/// Default for `[inference].default_port_range_end` (inclusive upper port of the default range).
fn default_port_end() -> u16 {
    8_199
}

/// Default for `[inference].max_endpoints`.
fn default_max_endpoints() -> u32 {
    10_u32
}
/// Default for `[dashboard].port`.
fn default_dashboard_port() -> u16 {
    8_080
}
/// Default for `[hub].hub_url`.
fn default_hub_url() -> String {
    String::from("https://hub.axonml.dev/v1")
}

/// Default for `[hub].cache_dir`; the `~` is expanded by `Config::hub_cache_dir`.
fn default_hub_cache_dir() -> String {
    String::from("~/.axonml/hub_cache")
}
/// Uses the same `default_*` functions as the serde field defaults, so
/// `ServerConfig::default()` equals an empty `[server]` table.
impl Default for ServerConfig {
    fn default() -> Self {
        let host = default_host();
        let port = default_port();
        let data_dir = default_data_dir();
        Self {
            host,
            port,
            data_dir,
        }
    }
}
/// Uses the same `default_*` functions as the serde field defaults, so
/// `AegisConfig::default()` equals an empty `[aegis]` table.
impl Default for AegisConfig {
    fn default() -> Self {
        let (host, port) = (default_aegis_host(), default_aegis_port());
        let (username, password) = (default_aegis_user(), default_aegis_pass());
        Self {
            host,
            port,
            username,
            password,
        }
    }
}
/// Uses the same defaults as the serde field attributes (`require_mfa` uses
/// `bool::default()`, i.e. `false`), so `AuthConfig::default()` equals an empty
/// `[auth]` table.
impl Default for AuthConfig {
    fn default() -> Self {
        let jwt_secret = default_jwt_secret();
        let jwt_expiry_hours = default_jwt_expiry();
        let session_timeout_minutes = default_session_timeout();
        Self {
            jwt_secret,
            jwt_expiry_hours,
            session_timeout_minutes,
            require_mfa: bool::default(),
            allow_public_registration: default_allow_registration(),
        }
    }
}
/// Uses the same `default_*` functions as the serde field defaults, so
/// `InferenceConfig::default()` equals an empty `[inference]` table.
impl Default for InferenceConfig {
    fn default() -> Self {
        let (default_port_range_start, default_port_range_end) =
            (default_port_start(), default_port_end());
        Self {
            default_port_range_start,
            default_port_range_end,
            max_endpoints: default_max_endpoints(),
        }
    }
}
/// Matches the serde default for an empty `[dashboard]` table.
impl Default for DashboardConfig {
    fn default() -> Self {
        let port = default_dashboard_port();
        Self { port }
    }
}
/// Matches the serde defaults for an empty `[hub]` table; also used by the
/// `#[serde(default)]` on `Config::hub` when the table is missing entirely.
impl Default for HubConfig {
    fn default() -> Self {
        let hub_url = default_hub_url();
        let cache_dir = default_hub_cache_dir();
        Self { hub_url, cache_dir }
    }
}
impl Config {
pub fn load() -> Result<Self, ConfigError> {
let config_path = Self::config_path();
if config_path.exists() {
Self::load_from_path(&config_path)
} else {
Ok(Self::default())
}
}
pub fn load_from_path(path: &PathBuf) -> Result<Self, ConfigError> {
let content = std::fs::read_to_string(path)?;
let config: Config = toml::from_str(&content)?;
Ok(config)
}
pub fn config_path() -> PathBuf {
dirs::home_dir()
.unwrap_or_else(|| PathBuf::from("."))
.join(".axonml")
.join("config.toml")
}
pub fn data_dir(&self) -> PathBuf {
let path = self.server.data_dir.replace(
"~",
dirs::home_dir()
.unwrap_or_else(|| PathBuf::from("."))
.to_str()
.unwrap_or("."),
);
PathBuf::from(path)
}
pub fn models_dir(&self) -> PathBuf {
self.data_dir().join("models")
}
pub fn runs_dir(&self) -> PathBuf {
self.data_dir().join("runs")
}
pub fn logs_dir(&self) -> PathBuf {
self.data_dir().join("logs")
}
pub fn checkpoints_dir(&self) -> PathBuf {
self.data_dir().join("checkpoints")
}
pub fn hub_cache_dir(&self) -> PathBuf {
let path = self.hub.cache_dir.replace(
"~",
dirs::home_dir()
.unwrap_or_else(|| PathBuf::from("."))
.to_str()
.unwrap_or("."),
);
PathBuf::from(path)
}
pub fn ensure_directories(&self) -> std::io::Result<()> {
std::fs::create_dir_all(self.data_dir())?;
std::fs::create_dir_all(self.models_dir())?;
std::fs::create_dir_all(self.runs_dir())?;
std::fs::create_dir_all(self.logs_dir())?;
std::fs::create_dir_all(self.checkpoints_dir())?;
std::fs::create_dir_all(self.hub_cache_dir())?;
Ok(())
}
pub fn aegis_url(&self) -> String {
format!("http://{}:{}", self.aegis.host, self.aegis.port)
}
pub fn validate(&self) -> Result<(), ConfigError> {
if self.auth.jwt_secret.is_empty() {
return Err(ConfigError::MissingConfig(
"jwt_secret is required. Set auth.jwt_secret in config.toml or AXONML_JWT_SECRET environment variable.".to_string()
));
}
if self.auth.jwt_secret.len() < 32 {
return Err(ConfigError::MissingConfig(
"jwt_secret must be at least 32 characters long for security.".to_string(),
));
}
if self.aegis.username.is_empty() || self.aegis.password.is_empty() {
return Err(ConfigError::MissingConfig(
"Database credentials are required. Set aegis.username and aegis.password in config.toml.".to_string()
));
}
Ok(())
}
pub fn validate_warnings(&self) -> Vec<String> {
let mut warnings = Vec::new();
if self.auth.allow_public_registration {
warnings.push("INFO: Public registration is enabled.".to_string());
}
if !self.auth.require_mfa {
warnings.push("INFO: MFA is not required for users.".to_string());
}
warnings
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The built-in defaults carry the expected well-known ports.
    #[test]
    fn test_default_config() {
        let Config { server, aegis, .. } = Config::default();
        assert_eq!(server.port, 3000);
        assert_eq!(aegis.port, 9090);
    }

    /// A representative TOML file overrides the defaults it mentions.
    #[test]
    fn test_parse_config() {
        let sample = r#"
[server]
host = "127.0.0.1"
port = 8000
[aegis]
host = "db.example.com"
port = 5432
[auth]
jwt_secret = "test_secret_that_is_at_least_32_characters_long_for_security"
require_mfa = true
[inference]
max_endpoints = 5
[dashboard]
port = 3000
"#;
        let parsed: Config = toml::from_str(sample).unwrap();
        let Config {
            server, aegis, auth, ..
        } = parsed;
        assert_eq!(server.port, 8000);
        assert_eq!(aegis.host, "db.example.com");
        assert!(auth.require_mfa);
    }
}