Skip to main content

agent_trace/state/
config.rs

1use crate::types::{DocType, StoreId};
2use anyhow::{Context, Result};
3use chrono::{DateTime, Utc};
4use serde::{Deserialize, Serialize};
5use std::path::{Path, PathBuf};
6
7// ── Synthesis Config ─────────────────────────────────────────────────────────
8
9#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
10#[serde(rename_all = "lowercase")]
11pub enum SynthesisMode {
12    #[default]
13    Auto,
14    Remote,
15    Ollama,
16    /// Legacy: configs with mode=embedded are migrated to Auto at load time.
17    #[serde(alias = "embedded")]
18    Embedded,
19}
20
21#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
22#[serde(rename_all = "lowercase")]
23pub enum SynthesisProvider {
24    #[default]
25    Ollama,
26    Openai,
27    Anthropic,
28    Openrouter,
29    Custom,
30    /// Legacy: configs with provider=embedded are treated as Ollama.
31    #[serde(alias = "embedded")]
32    Embedded,
33}
34
35impl SynthesisProvider {
36    pub fn slug(self) -> &'static str {
37        match self {
38            Self::Openai => "openai",
39            Self::Anthropic => "anthropic",
40            Self::Openrouter => "openrouter",
41            Self::Ollama => "ollama",
42            Self::Custom => "custom",
43            Self::Embedded => "ollama", // legacy: treat as ollama
44        }
45    }
46
47    pub fn default_model(self) -> &'static str {
48        match self {
49            Self::Openai => "gpt-4o-mini",
50            Self::Anthropic => "claude-3-5-haiku-latest",
51            Self::Openrouter => "openai/gpt-4o-mini",
52            Self::Ollama => "qwen2.5:1.5b",
53            Self::Custom => "gpt-4o-mini",
54            Self::Embedded => "qwen2.5:1.5b", // legacy: migrate to ollama default
55        }
56    }
57
58    pub fn default_base_url(self) -> &'static str {
59        match self {
60            Self::Openai => "https://api.openai.com/v1",
61            Self::Anthropic => "https://api.anthropic.com/v1",
62            Self::Openrouter => "https://openrouter.ai/api/v1",
63            Self::Ollama => "http://127.0.0.1:11434/v1",
64            Self::Custom => "http://127.0.0.1:11434/v1",
65            Self::Embedded => "http://127.0.0.1:11434/v1", // legacy: migrate to ollama
66        }
67    }
68}
69
70#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
71pub struct SynthesisConfig {
72    #[serde(default)]
73    pub mode: SynthesisMode,
74    #[serde(default)]
75    pub provider: SynthesisProvider,
76    #[serde(default = "default_synthesis_model")]
77    pub model: String,
78    pub base_url: Option<String>,
79    #[serde(default = "default_max_tokens")]
80    pub max_tokens: usize,
81    #[serde(default = "default_synthesis_temperature")]
82    pub temperature: f32,
83    #[serde(default = "default_refresh_every_ops")]
84    pub refresh_every_ops: usize,
85    /// Legacy field — ignored; kept for deserializing old config files.
86    #[serde(skip_serializing, default)]
87    pub fallback: serde_json::Value,
88}
89
90fn default_synthesis_model() -> String {
91    SynthesisProvider::Ollama.default_model().into()
92}
93
/// Serde default for `SynthesisConfig::max_tokens`.
fn default_max_tokens() -> usize {
    4_096
}
97
/// Serde default for `SynthesisConfig::temperature`.
fn default_synthesis_temperature() -> f32 {
    0.3_f32
}
101
/// Serde default for `SynthesisConfig::refresh_every_ops`.
fn default_refresh_every_ops() -> usize {
    10_usize
}
105
106impl Default for SynthesisConfig {
107    fn default() -> Self {
108        Self {
109            mode: SynthesisMode::Auto,
110            provider: SynthesisProvider::Ollama,
111            model: default_synthesis_model(),
112            base_url: None,
113            max_tokens: default_max_tokens(),
114            temperature: default_synthesis_temperature(),
115            refresh_every_ops: default_refresh_every_ops(),
116            fallback: serde_json::Value::Null,
117        }
118    }
119}
120
121impl SynthesisConfig {
122    /// Return the configured model, or the provider's default if blank.
123    pub fn effective_model(&self) -> String {
124        if self.model.trim().is_empty() {
125            self.provider.default_model().into()
126        } else {
127            self.model.clone()
128        }
129    }
130
131    pub fn effective_base_url(&self) -> String {
132        self.base_url
133            .clone()
134            .filter(|u| !u.is_empty())
135            .unwrap_or_else(|| self.provider.default_base_url().to_string())
136    }
137
138    pub fn provider_needs_credentials(provider: SynthesisProvider) -> bool {
139        matches!(
140            provider,
141            SynthesisProvider::Openai
142                | SynthesisProvider::Anthropic
143                | SynthesisProvider::Openrouter
144        )
145    }
146
147    /// Synthesis config that cannot reach any backend (unit tests only).
148    #[cfg(test)]
149    pub fn for_unit_tests_degraded() -> Self {
150        Self {
151            base_url: Some("http://127.0.0.1:1/v1".into()),
152            ..Default::default()
153        }
154    }
155
156    pub fn merge(base: Self, override_cfg: Option<&Self>) -> Self {
157        let Some(ov) = override_cfg else {
158            return base;
159        };
160        Self {
161            mode: ov.mode,
162            provider: ov.provider,
163            model: if ov.model.is_empty() {
164                base.model
165            } else {
166                ov.model.clone()
167            },
168            base_url: ov.base_url.clone().or(base.base_url),
169            max_tokens: if ov.max_tokens == 0 {
170                base.max_tokens
171            } else {
172                ov.max_tokens
173            },
174            temperature: ov.temperature,
175            refresh_every_ops: if ov.refresh_every_ops == 0 {
176                base.refresh_every_ops
177            } else {
178                ov.refresh_every_ops
179            },
180            fallback: serde_json::Value::Null,
181        }
182    }
183}
184
185// ── Credentials ──────────────────────────────────────────────────────────────
186
187#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
188pub struct ProviderCredentials {
189    pub api_key: Option<String>,
190}
191
192#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
193pub struct CredentialsStore {
194    #[serde(default)]
195    pub openai: Option<ProviderCredentials>,
196    #[serde(default)]
197    pub anthropic: Option<ProviderCredentials>,
198    #[serde(default)]
199    pub openrouter: Option<ProviderCredentials>,
200    #[serde(default)]
201    pub custom: Option<ProviderCredentials>,
202}
203
204impl CredentialsStore {
205    pub fn load() -> Result<Self> {
206        let path = credentials_path();
207        if !path.exists() {
208            return Ok(Self::default());
209        }
210        let contents = std::fs::read_to_string(&path)
211            .with_context(|| format!("Reading credentials: {}", path.display()))?;
212        let store: Self = toml::from_str(&contents)
213            .with_context(|| format!("Parsing credentials: {}", path.display()))?;
214        Ok(store)
215    }
216
217    pub fn save(&self) -> Result<()> {
218        let path = credentials_path();
219        if let Some(parent) = path.parent() {
220            std::fs::create_dir_all(parent)?;
221        }
222        let contents = toml::to_string_pretty(self)?;
223        crate::util::atomic_write(&path, &contents)?;
224        #[cfg(unix)]
225        {
226            use std::os::unix::fs::PermissionsExt;
227            std::fs::set_permissions(&path, std::fs::Permissions::from_mode(0o600))?;
228        }
229        Ok(())
230    }
231
232    pub fn api_key_for(&self, provider: SynthesisProvider) -> Option<String> {
233        self.stored_key(provider)
234            .or_else(|| self.env_key(provider))
235            .filter(|k| !k.is_empty())
236    }
237
238    fn stored_key(&self, provider: SynthesisProvider) -> Option<String> {
239        let section = match provider {
240            SynthesisProvider::Openai => &self.openai,
241            SynthesisProvider::Anthropic => &self.anthropic,
242            SynthesisProvider::Openrouter => &self.openrouter,
243            SynthesisProvider::Custom => &self.custom,
244            _ => return None,
245        };
246        section.as_ref().and_then(|s| s.api_key.clone())
247    }
248
249    fn env_key(&self, provider: SynthesisProvider) -> Option<String> {
250        let var = match provider {
251            SynthesisProvider::Openai => "OPENAI_API_KEY",
252            SynthesisProvider::Anthropic => "ANTHROPIC_API_KEY",
253            SynthesisProvider::Openrouter => "OPENROUTER_API_KEY",
254            SynthesisProvider::Custom => "AGENT_TRACE_API_KEY",
255            _ => return None,
256        };
257        std::env::var(var).ok()
258    }
259
260    pub fn set_key(&mut self, provider: SynthesisProvider, key: String) {
261        let entry = ProviderCredentials { api_key: Some(key) };
262        match provider {
263            SynthesisProvider::Openai => self.openai = Some(entry),
264            SynthesisProvider::Anthropic => self.anthropic = Some(entry),
265            SynthesisProvider::Openrouter => self.openrouter = Some(entry),
266            SynthesisProvider::Custom => self.custom = Some(entry),
267            _ => {}
268        }
269    }
270
271    pub fn clear_key(&mut self, provider: SynthesisProvider) {
272        match provider {
273            SynthesisProvider::Openai => self.openai = None,
274            SynthesisProvider::Anthropic => self.anthropic = None,
275            SynthesisProvider::Openrouter => self.openrouter = None,
276            SynthesisProvider::Custom => self.custom = None,
277            _ => {}
278        }
279    }
280
281    pub fn redacted_key(&self, provider: SynthesisProvider) -> Option<String> {
282        self.api_key_for(provider).map(|k| {
283            if k.len() <= 8 {
284                "***".into()
285            } else {
286                format!("{}...{}", &k[..4], &k[k.len() - 4..])
287            }
288        })
289    }
290}
291
292pub fn credentials_path() -> PathBuf {
293    dirs_next::config_dir()
294        .unwrap_or_else(|| PathBuf::from("."))
295        .join("agent-trace")
296        .join("credentials.toml")
297}
298
299/// Legacy LLM config — kept for serde deserialization of old config files.
300/// No longer written or used; the embedded/Candle path was removed.
301#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
302pub struct LlmConfig {
303    pub model_path: Option<PathBuf>,
304    #[serde(default = "default_llm_max_tokens")]
305    pub max_tokens: usize,
306    #[serde(default = "default_llm_temperature")]
307    pub temperature: f32,
308}
309
310fn default_llm_max_tokens() -> usize {
311    4096
312}
313fn default_llm_temperature() -> f32 {
314    0.7
315}
316
317// ── UI Config ────────────────────────────────────────────────────────────────
318
319#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
320pub struct UiConfig {
321    /// Show ASCII art banner on startup.
322    pub show_banner: bool,
323    /// Number of changelog entries to show on startup.
324    pub changelog_limit: usize,
325    /// Use ASCII-only box drawing characters.
326    pub ascii_only: bool,
327}
328
329impl Default for UiConfig {
330    fn default() -> Self {
331        Self {
332            show_banner: true,
333            changelog_limit: 50,
334            ascii_only: false,
335        }
336    }
337}
338
339// ── Defaults Config ──────────────────────────────────────────────────────────
340
341#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
342pub struct DefaultsConfig {
343    /// Default doc type for newly added files.
344    pub default_doc_type: DocType,
345    /// Default agent name when --agent flag is not provided.
346    pub default_agent_name: Option<String>,
347}
348
349impl Default for DefaultsConfig {
350    fn default() -> Self {
351        Self {
352            default_doc_type: DocType::Scratch,
353            default_agent_name: None,
354        }
355    }
356}
357
358// ── Global Config ────────────────────────────────────────────────────────────
359
360#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
361pub struct GlobalConfig {
362    /// Legacy field — ignored; kept for deserializing old global configs.
363    #[serde(skip_serializing, default)]
364    pub llm: LlmConfig,
365    #[serde(default)]
366    pub synthesis: SynthesisConfig,
367    #[serde(default)]
368    pub ui: UiConfig,
369    #[serde(default)]
370    pub defaults: DefaultsConfig,
371}
372
373impl GlobalConfig {
374    /// Load from `~/.config/agent-trace/config.toml`, using defaults if absent.
375    pub fn load() -> Result<Self> {
376        let path = global_config_path();
377        if !path.exists() {
378            return Ok(Self::default());
379        }
380        let contents = std::fs::read_to_string(&path)
381            .with_context(|| format!("Reading global config: {}", path.display()))?;
382        toml::from_str(&contents)
383            .with_context(|| format!("Parsing global config: {}", path.display()))
384    }
385
386    pub fn save(&self) -> Result<()> {
387        let path = global_config_path();
388        if let Some(parent) = path.parent() {
389            std::fs::create_dir_all(parent)?;
390        }
391        let contents = toml::to_string_pretty(self)?;
392        crate::util::atomic_write(&path, &contents)?;
393        Ok(())
394    }
395}
396
397pub fn global_config_path() -> PathBuf {
398    dirs_next::config_dir()
399        .unwrap_or_else(|| PathBuf::from("."))
400        .join("agent-trace")
401        .join("config.toml")
402}
403
404// ── Store Config ─────────────────────────────────────────────────────────────
405
406#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
407pub struct StoreInfo {
408    pub id: StoreId,
409    pub name: String,
410    pub created: DateTime<Utc>,
411    pub agent_trace_version: String,
412}
413
414impl StoreInfo {
415    pub fn new(name: String) -> Self {
416        Self {
417            id: StoreId::new(),
418            name,
419            created: Utc::now(),
420            agent_trace_version: env!("CARGO_PKG_VERSION").to_string(),
421        }
422    }
423}
424
425#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
426pub struct PollingConfig {
427    /// Poll interval in milliseconds.
428    pub interval_ms: u64,
429    /// When false, the background poll loop is not started (manual refresh only).
430    pub enabled: bool,
431}
432
433impl Default for PollingConfig {
434    fn default() -> Self {
435        Self {
436            interval_ms: 1000,
437            enabled: true,
438        }
439    }
440}
441
442#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
443pub struct StoreConfig {
444    pub store: StoreInfo,
445    pub llm: Option<LlmConfig>,
446    #[serde(default)]
447    pub synthesis: Option<SynthesisConfig>,
448    #[serde(default)]
449    pub polling: PollingConfig,
450}
451
452impl StoreConfig {
453    /// Load from `.agent-trace/config.toml` inside the store root.
454    pub fn load(store_root: &Path) -> Result<Self> {
455        let path = store_config_path(store_root);
456        let contents = std::fs::read_to_string(&path)
457            .with_context(|| format!("Reading store config: {}", path.display()))?;
458        toml::from_str(&contents)
459            .with_context(|| format!("Parsing store config: {}", path.display()))
460    }
461
462    pub fn save(&self, store_root: &Path) -> Result<()> {
463        let path = store_config_path(store_root);
464        let contents = toml::to_string_pretty(self)?;
465        crate::util::atomic_write(&path, &contents)?;
466        Ok(())
467    }
468}
469
/// Path of the per-store config file: `<store_root>/.agent-trace/config.toml`.
pub fn store_config_path(store_root: &Path) -> PathBuf {
    let mut path = store_root.to_path_buf();
    path.push(".agent-trace");
    path.push("config.toml");
    path
}
473
474// ── Merged Config ─────────────────────────────────────────────────────────────
475
476/// Resolved configuration: per-store values override global defaults.
477#[derive(Debug, Clone)]
478pub struct MergedConfig {
479    #[allow(dead_code)]
480    pub store: StoreInfo,
481    pub synthesis: SynthesisConfig,
482    pub ui: UiConfig,
483    pub defaults: DefaultsConfig,
484    pub polling: PollingConfig,
485}
486
487impl Default for MergedConfig {
488    fn default() -> Self {
489        let global = GlobalConfig::default();
490        Self {
491            synthesis: global.synthesis,
492            ui: global.ui,
493            defaults: global.defaults,
494            polling: PollingConfig::default(),
495            store: StoreInfo::new("default".into()),
496        }
497    }
498}
499
500impl MergedConfig {
501    pub fn merge(global: GlobalConfig, store: StoreConfig) -> Self {
502        Self {
503            synthesis: SynthesisConfig::merge(global.synthesis, store.synthesis.as_ref()),
504            ui: global.ui,
505            defaults: global.defaults,
506            polling: store.polling,
507            store: store.store,
508        }
509    }
510
511    /// Load and merge configs for a given store root.
512    pub fn load(store_root: &Path) -> Result<Self> {
513        let global = GlobalConfig::load()?;
514        let store = StoreConfig::load(store_root)?;
515        Ok(Self::merge(global, store))
516    }
517}
518
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Write `content` to `path`, creating any missing parent directories first.
    fn write_toml(path: &Path, content: &str) {
        let parent = path.parent().unwrap();
        std::fs::create_dir_all(parent).unwrap();
        std::fs::write(path, content).unwrap();
    }

    #[test]
    fn test_global_config_defaults() {
        let GlobalConfig {
            synthesis,
            ui,
            defaults,
            ..
        } = GlobalConfig::default();
        assert_eq!(synthesis.model, "qwen2.5:1.5b");
        assert_eq!(ui, UiConfig::default());
        assert_eq!(defaults, DefaultsConfig::default());
    }

    #[test]
    fn test_effective_model_returns_default_when_blank() {
        let syn = SynthesisConfig {
            model: String::new(),
            ..Default::default()
        };
        assert_eq!(syn.effective_model(), syn.provider.default_model());
    }

    #[test]
    fn test_effective_model_returns_configured() {
        let syn = SynthesisConfig {
            model: "llama3:8b".into(),
            ..Default::default()
        };
        assert_eq!(syn.effective_model(), "llama3:8b");
    }

    #[test]
    fn test_credentials_roundtrip() {
        // credentials_path() is fixed, so round-trip the TOML through a temp file directly.
        let tmp = TempDir::new().unwrap();
        let cred_path = tmp.path().join("credentials.toml");
        let mut original = CredentialsStore::default();
        original.set_key(SynthesisProvider::Openai, "sk-test-key".into());
        let serialized = toml::to_string_pretty(&original).unwrap();
        std::fs::write(&cred_path, &serialized).unwrap();
        let raw = std::fs::read_to_string(&cred_path).unwrap();
        let loaded: CredentialsStore = toml::from_str(&raw).unwrap();
        assert_eq!(
            loaded.api_key_for(SynthesisProvider::Openai),
            Some("sk-test-key".into())
        );
    }

    #[test]
    fn test_synthesis_merge_store_override() {
        let global = GlobalConfig::default();
        let store_synthesis = SynthesisConfig {
            model: "gpt-4o".into(),
            provider: SynthesisProvider::Openai,
            ..Default::default()
        };
        let store = StoreConfig {
            store: StoreInfo::new("s".into()),
            llm: None,
            synthesis: Some(store_synthesis),
            polling: PollingConfig::default(),
        };
        let merged = MergedConfig::merge(global, store);
        assert_eq!(merged.synthesis.model, "gpt-4o");
        assert_eq!(merged.synthesis.provider, SynthesisProvider::Openai);
    }

    #[test]
    fn test_store_config_roundtrip() {
        let tmp = TempDir::new().unwrap();
        let root = tmp.path();
        std::fs::create_dir_all(root.join(".agent-trace")).unwrap();

        let cfg = StoreConfig {
            store: StoreInfo::new("test-store".into()),
            llm: None,
            synthesis: Some(SynthesisConfig {
                model: "qwen2.5:1.5b".into(),
                ..Default::default()
            }),
            polling: PollingConfig::default(),
        };
        cfg.save(root).unwrap();

        let loaded = StoreConfig::load(root).unwrap();
        assert_eq!(loaded.store.name, "test-store");
        let synthesis = loaded.synthesis.unwrap();
        assert_eq!(synthesis.model, "qwen2.5:1.5b");
    }

    #[test]
    fn test_legacy_llm_config_in_store_still_deserializes() {
        // Old configs with [llm] sections should still parse without error.
        let tmp = TempDir::new().unwrap();
        let store_root = tmp.path();
        let legacy = r#"
[store]
id = "00000000-0000-0000-0000-000000000001"
name = "legacy"
created = "2024-01-01T00:00:00Z"
agent_trace_version = "0.0.1"

[llm]
model_path = "/tmp/model.gguf"
max_tokens = 2048
temperature = 0.5
"#;
        write_toml(&store_config_path(store_root), legacy);
        let loaded = StoreConfig::load(store_root).unwrap();
        assert_eq!(loaded.store.name, "legacy");
    }

    #[test]
    fn test_malformed_toml_error() {
        let tmp = TempDir::new().unwrap();
        let store_root = tmp.path();
        write_toml(&store_config_path(store_root), "this is not [ valid toml }{");
        let result = StoreConfig::load(store_root);
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(msg.contains("config.toml") || msg.contains("Parsing"));
    }

    #[test]
    fn test_store_info_has_uuid() {
        let info = StoreInfo::new("my-store".into());
        let raw_id = &info.id.0;
        assert!(!raw_id.is_empty());
        assert_eq!(info.name, "my-store");
        assert_eq!(info.agent_trace_version, env!("CARGO_PKG_VERSION"));
        // The generated id must be a well-formed UUID.
        assert!(raw_id.parse::<uuid::Uuid>().is_ok());
    }
}