agtrace_runtime/
config.rs

use anyhow::{Context, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::{Path, PathBuf};
5
6#[derive(Debug, Clone, Serialize, Deserialize)]
7pub struct ProviderConfig {
8    pub enabled: bool,
9    pub log_root: PathBuf,
10    #[serde(default)]
11    pub context_window_override: Option<u64>,
12}
13
14#[derive(Debug, Clone, Serialize, Deserialize, Default)]
15pub struct Config {
16    #[serde(default)]
17    pub providers: HashMap<String, ProviderConfig>,
18}
19
20impl Config {
21    pub fn load() -> Result<Self> {
22        let config_path = Self::default_path()?;
23        Self::load_from(&config_path)
24    }
25
26    pub fn load_from(path: &PathBuf) -> Result<Self> {
27        if !path.exists() {
28            return Ok(Self::default());
29        }
30
31        let content = std::fs::read_to_string(path)?;
32        let config: Config = toml::from_str(&content)?;
33        Ok(config)
34    }
35
36    pub fn save(&self) -> Result<()> {
37        let config_path = Self::default_path()?;
38        self.save_to(&config_path)
39    }
40
41    pub fn save_to(&self, path: &PathBuf) -> Result<()> {
42        if let Some(parent) = path.parent() {
43            std::fs::create_dir_all(parent)?;
44        }
45
46        let content = toml::to_string_pretty(self)?;
47        std::fs::write(path, content)?;
48        Ok(())
49    }
50
51    pub fn default_path() -> Result<PathBuf> {
52        let home = std::env::var("HOME")
53            .or_else(|_| std::env::var("USERPROFILE"))
54            .map_err(|_| anyhow::anyhow!("Could not determine home directory"))?;
55
56        Ok(PathBuf::from(home).join(".agtrace").join("config.toml"))
57    }
58
59    pub fn detect_providers() -> Result<Self> {
60        let mut providers = HashMap::new();
61
62        for (name, path) in agtrace_providers::get_default_log_paths() {
63            if path.exists() {
64                providers.insert(
65                    name,
66                    ProviderConfig {
67                        enabled: true,
68                        log_root: path,
69                        context_window_override: None,
70                    },
71                );
72            }
73        }
74
75        Ok(Config { providers })
76    }
77
78    pub fn enabled_providers(&self) -> Vec<(&String, &ProviderConfig)> {
79        self.providers
80            .iter()
81            .filter(|(_, config)| config.enabled)
82            .collect()
83    }
84
85    pub fn set_provider(&mut self, name: String, config: ProviderConfig) {
86        self.providers.insert(name, config);
87    }
88}
89
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Builds a provider entry rooted at `root` with no context-window override.
    fn provider(enabled: bool, root: &str) -> ProviderConfig {
        ProviderConfig {
            enabled,
            log_root: PathBuf::from(root),
            context_window_override: None,
        }
    }

    #[test]
    fn test_config_default() {
        assert!(Config::default().providers.is_empty());
    }

    #[test]
    fn test_config_save_and_load() -> Result<()> {
        let dir = TempDir::new()?;
        let path = dir.path().join("config.toml");

        let mut original = Config::default();
        original.set_provider(
            "claude".to_string(),
            provider(true, "/home/user/.claude/projects"),
        );
        original.save_to(&path)?;
        assert!(path.exists());

        let reloaded = Config::load_from(&path)?;
        assert_eq!(reloaded.providers.len(), 1);
        let claude = reloaded
            .providers
            .get("claude")
            .expect("claude provider should survive a save/load round trip");
        assert!(claude.enabled);

        Ok(())
    }

    #[test]
    fn test_enabled_providers() {
        let mut config = Config::default();
        config.set_provider("claude".to_string(), provider(true, "/test/claude"));
        config.set_provider("codex".to_string(), provider(false, "/test/codex"));

        let enabled = config.enabled_providers();
        let names: Vec<&str> = enabled.iter().map(|(name, _)| name.as_str()).collect();
        assert_eq!(names, ["claude"]);
    }

    #[test]
    fn test_load_nonexistent_returns_default() -> Result<()> {
        let dir = TempDir::new()?;
        let missing = dir.path().join("nonexistent.toml");

        let config = Config::load_from(&missing)?;
        assert!(config.providers.is_empty());

        Ok(())
    }
}