use std::path::{Path, PathBuf};
use serde::{Deserialize, Serialize};
use super::PathResolver;
/// Top-level configuration: one field per TOML section
/// (`[general]`, `[paths]`, `[eval]`, `[gym]`, `[llm]`, `[logging]`, `[desktop]`).
///
/// With `#[serde(default)]`, a section missing from the file falls back to
/// that section's `Default` value.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(default)]
pub struct GlobalConfig {
    /// `[general]` section.
    pub general: GeneralConfig,
    /// `[paths]` section.
    pub paths: PathsConfig,
    /// `[eval]` section.
    pub eval: EvalConfig,
    /// `[gym]` section.
    pub gym: GymConfig,
    /// `[llm]` section.
    pub llm: LlmConfig,
    /// `[logging]` section.
    pub logging: LoggingConfig,
    /// `[desktop]` section.
    pub desktop: DesktopConfig,
}
/// The `[general]` section; omitted keys use the derived `Default`
/// (`ProjectType::default()` and `false`).
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(default)]
pub struct GeneralConfig {
    /// Project type used by default; see [`ProjectType`].
    pub default_project_type: ProjectType,
    /// Telemetry flag; `false` unless set.
    pub telemetry_enabled: bool,
}
/// Project kind, written in TOML as `"eval"`, `"gym"` or `"both"`.
///
/// Defaults to [`ProjectType::Eval`].
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ProjectType {
    /// `"eval"` (the default).
    #[default]
    Eval,
    /// `"gym"`.
    Gym,
    /// `"both"`.
    Both,
}
/// The `[paths]` section. Every field is empty/`None` unless configured.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(default)]
pub struct PathsConfig {
    /// Explicit user-data directory; `None` means "use the resolver's default".
    pub user_data_dir: Option<PathBuf>,
    /// Additional directories to search for scenarios.
    pub scenario_search_paths: Vec<PathBuf>,
    /// Explicit report directory; `None` means "use the resolver's default".
    pub report_output_dir: Option<PathBuf>,
}
/// The `[eval]` section. Its `Default` is implemented by hand below the
/// type definitions (30 runs, no seed, parallelism 1, 10 ms tick).
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
#[serde(default)]
pub struct EvalConfig {
    /// Number of runs used when none is requested explicitly.
    pub default_runs: u32,
    /// Seed used when none is requested explicitly; `None` means unset.
    pub default_seed: Option<u64>,
    /// Parallelism used when none is requested explicitly.
    pub default_parallel: u32,
    /// Target tick duration, in milliseconds.
    pub target_tick_duration_ms: u64,
}
/// The `[gym]` section. Its `Default` (no data dir, 1000 episodes) is
/// implemented by hand further down.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
#[serde(default)]
pub struct GymConfig {
    /// Explicit data directory; `None` when not configured.
    pub data_dir: Option<PathBuf>,
    /// Episode count used when none is requested explicitly.
    pub default_episodes: u32,
}
/// The `[llm]` section. Its `Default` (default provider, cache on,
/// 168-hour TTL) is implemented by hand further down.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
#[serde(default)]
pub struct LlmConfig {
    /// Provider used when none is requested explicitly.
    pub default_provider: LlmProvider,
    /// Whether the response cache is switched on.
    pub cache_enabled: bool,
    /// Cache time-to-live, in hours.
    pub cache_ttl_hours: u32,
}
/// LLM backend, written in TOML as `"openai"`, `"anthropic"` or `"local"`
/// (variant names lower-cased by serde).
///
/// Defaults to [`LlmProvider::OpenAI`].
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum LlmProvider {
    /// `"openai"` (the default).
    #[default]
    OpenAI,
    /// `"anthropic"`.
    Anthropic,
    /// `"local"`.
    Local,
}
/// The `[logging]` section. Its `Default` (default level, file logging on,
/// 100 MB, 5 files) is implemented by hand further down.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
#[serde(default)]
pub struct LoggingConfig {
    /// Minimum log level; see [`LogLevel`].
    pub level: LogLevel,
    /// Whether logging to a file is switched on.
    pub file_enabled: bool,
    /// Maximum log-file size, in megabytes.
    pub max_size_mb: u32,
    /// Maximum number of log files kept.
    pub max_files: u32,
}
/// Log level, written in TOML as `"trace"`, `"debug"`, `"info"`, `"warn"`
/// or `"error"`.
///
/// Defaults to [`LogLevel::Info`].
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum LogLevel {
    /// `"trace"`.
    Trace,
    /// `"debug"`.
    Debug,
    /// `"info"` (the default).
    #[default]
    Info,
    /// `"warn"`.
    Warn,
    /// `"error"`.
    Error,
}
/// The `[desktop]` section. Its `Default` (remember size, 10 recent
/// projects, auto-reload on, default theme) is implemented by hand further down.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
#[serde(default)]
pub struct DesktopConfig {
    /// Whether the window size is remembered.
    pub remember_window_size: bool,
    /// Maximum number of recent projects listed.
    pub recent_projects_limit: u32,
    /// Whether scenarios are reloaded automatically.
    pub auto_reload_scenarios: bool,
    /// UI theme; see [`Theme`].
    pub theme: Theme,
}
/// UI theme, written in TOML as `"light"`, `"dark"` or `"system"`.
///
/// Defaults to [`Theme::System`].
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Theme {
    /// `"light"`.
    Light,
    /// `"dark"`.
    Dark,
    /// `"system"` (the default).
    #[default]
    System,
}
impl Default for EvalConfig {
    /// 30 runs, no fixed seed, parallelism of 1 and a 10 ms target tick.
    fn default() -> Self {
        let default_runs = 30;
        let default_parallel = 1;
        let target_tick_duration_ms = 10;
        Self {
            default_seed: None,
            default_runs,
            default_parallel,
            target_tick_duration_ms,
        }
    }
}
impl Default for GymConfig {
fn default() -> Self {
Self {
data_dir: None,
default_episodes: 1000,
}
}
}
impl Default for LlmConfig {
fn default() -> Self {
Self {
default_provider: LlmProvider::default(),
cache_enabled: true,
cache_ttl_hours: 168, }
}
}
impl Default for LoggingConfig {
fn default() -> Self {
Self {
level: LogLevel::default(),
file_enabled: true,
max_size_mb: 100,
max_files: 5,
}
}
}
impl Default for DesktopConfig {
fn default() -> Self {
Self {
remember_window_size: true,
recent_projects_limit: 10,
auto_reload_scenarios: true,
theme: Theme::default(),
}
}
}
impl GlobalConfig {
pub fn load_from_file(path: &Path) -> Result<Self, ConfigError> {
let content = std::fs::read_to_string(path).map_err(|e| ConfigError::Io {
path: path.to_path_buf(),
source: e,
})?;
toml::from_str(&content).map_err(|e| ConfigError::Parse {
path: path.to_path_buf(),
source: e,
})
}
pub fn load_global() -> Self {
let path = PathResolver::global_config_file();
if path.exists() {
match Self::load_from_file(&path) {
Ok(config) => config,
Err(e) => {
tracing::warn!("Failed to load global config: {}", e);
Self::default()
}
}
} else {
Self::default()
}
}
pub fn load_merged() -> Self {
let mut config = Self::load_global();
if let Some(project_path) = PathResolver::project_config_file() {
if project_path.exists() {
match Self::load_from_file(&project_path) {
Ok(project_config) => {
config.merge(project_config);
}
Err(e) => {
tracing::warn!("Failed to load project config: {}", e);
}
}
}
}
config
}
pub fn merge(&mut self, other: Self) {
self.general.default_project_type = other.general.default_project_type;
self.general.telemetry_enabled = other.general.telemetry_enabled;
if other.paths.user_data_dir.is_some() {
self.paths.user_data_dir = other.paths.user_data_dir;
}
self.paths
.scenario_search_paths
.extend(other.paths.scenario_search_paths);
if other.paths.report_output_dir.is_some() {
self.paths.report_output_dir = other.paths.report_output_dir;
}
self.eval.default_runs = other.eval.default_runs;
if other.eval.default_seed.is_some() {
self.eval.default_seed = other.eval.default_seed;
}
self.eval.default_parallel = other.eval.default_parallel;
self.eval.target_tick_duration_ms = other.eval.target_tick_duration_ms;
if other.gym.data_dir.is_some() {
self.gym.data_dir = other.gym.data_dir;
}
self.gym.default_episodes = other.gym.default_episodes;
self.llm.default_provider = other.llm.default_provider;
self.llm.cache_enabled = other.llm.cache_enabled;
self.llm.cache_ttl_hours = other.llm.cache_ttl_hours;
self.logging.level = other.logging.level;
self.logging.file_enabled = other.logging.file_enabled;
self.logging.max_size_mb = other.logging.max_size_mb;
self.logging.max_files = other.logging.max_files;
self.desktop.remember_window_size = other.desktop.remember_window_size;
self.desktop.recent_projects_limit = other.desktop.recent_projects_limit;
self.desktop.auto_reload_scenarios = other.desktop.auto_reload_scenarios;
self.desktop.theme = other.desktop.theme;
}
pub fn save_to_file(&self, path: &Path) -> Result<(), ConfigError> {
let content = toml::to_string_pretty(self).map_err(ConfigError::Serialize)?;
if let Some(parent) = path.parent() {
std::fs::create_dir_all(parent).map_err(|e| ConfigError::Io {
path: parent.to_path_buf(),
source: e,
})?;
}
std::fs::write(path, content).map_err(|e| ConfigError::Io {
path: path.to_path_buf(),
source: e,
})
}
pub fn save_global(&self) -> Result<(), ConfigError> {
self.save_to_file(&PathResolver::global_config_file())
}
pub fn resolved_user_data_dir(&self) -> PathBuf {
self.paths
.user_data_dir
.clone()
.unwrap_or_else(PathResolver::user_data_dir)
}
pub fn resolved_reports_dir(&self) -> PathBuf {
self.paths
.report_output_dir
.clone()
.unwrap_or_else(PathResolver::reports_dir)
}
}
#[derive(Debug, thiserror::Error)]
pub enum ConfigError {
#[error("Failed to read config file {path}: {source}")]
Io {
path: PathBuf,
source: std::io::Error,
},
#[error("Failed to parse config file {path}: {source}")]
Parse {
path: PathBuf,
source: toml::de::Error,
},
#[error("Failed to serialize config: {0}")]
Serialize(toml::ser::Error),
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    #[test]
    fn test_default_config() {
        let config = GlobalConfig::default();
        assert_eq!(config.eval.default_runs, 30);
        assert_eq!(config.logging.level, LogLevel::Info);
        assert_eq!(config.desktop.theme, Theme::System);
    }

    #[test]
    fn test_save_and_load() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("config.toml");
        let mut config = GlobalConfig::default();
        config.eval.default_runs = 50;
        config.logging.level = LogLevel::Debug;
        config.save_to_file(&path).unwrap();
        let loaded = GlobalConfig::load_from_file(&path).unwrap();
        assert_eq!(loaded.eval.default_runs, 50);
        assert_eq!(loaded.logging.level, LogLevel::Debug);
    }

    #[test]
    fn test_merge_configs() {
        let mut base = GlobalConfig::default();
        base.eval.default_runs = 10;
        base.paths.scenario_search_paths = vec![PathBuf::from("/base/path")];
        let mut override_config = GlobalConfig::default();
        override_config.eval.default_runs = 20;
        override_config.paths.scenario_search_paths = vec![PathBuf::from("/override/path")];
        base.merge(override_config);
        assert_eq!(base.eval.default_runs, 20);
        assert_eq!(base.paths.scenario_search_paths.len(), 2);
        assert_eq!(
            base.paths.scenario_search_paths[0],
            PathBuf::from("/base/path")
        );
        assert_eq!(
            base.paths.scenario_search_paths[1],
            PathBuf::from("/override/path")
        );
    }

    #[test]
    fn test_merge_keeps_base_options_when_other_is_none() {
        let mut base = GlobalConfig::default();
        base.paths.user_data_dir = Some(PathBuf::from("/base/data"));
        base.eval.default_seed = Some(7);
        base.merge(GlobalConfig::default());
        assert_eq!(base.paths.user_data_dir, Some(PathBuf::from("/base/data")));
        assert_eq!(base.eval.default_seed, Some(7));
    }

    #[test]
    fn test_parse_toml() {
        let toml_str = r#"
[general]
default_project_type = "eval"
telemetry_enabled = false
[eval]
default_runs = 100
target_tick_duration_ms = 5
[logging]
level = "debug"
"#;
        let config: GlobalConfig = toml::from_str(toml_str).unwrap();
        assert_eq!(config.eval.default_runs, 100);
        assert_eq!(config.eval.target_tick_duration_ms, 5);
        assert_eq!(config.logging.level, LogLevel::Debug);
    }

    #[test]
    fn test_load_global_missing_file() {
        // `load_global` reads the real per-user config location, so asserting the
        // default is only meaningful when no such file exists on this machine.
        // Otherwise the developer's actual settings would be loaded.
        if PathResolver::global_config_file().exists() {
            return;
        }
        let config = GlobalConfig::load_global();
        assert_eq!(config.eval.default_runs, 30);
    }
}