// recall_echo/config.rs

1use std::fmt;
2use std::fs;
3use std::path::Path;
4
5use serde::{Deserialize, Serialize};
6
/// File name of the per-directory config file (joined onto a base directory by `config_path`).
const CONFIG_FILE: &str = ".recall-echo.toml";
/// Fallback for `ephemeral.max_entries` when the key is omitted or out of range.
const DEFAULT_MAX_ENTRIES: usize = 5;
9
// ── Provider enum ────────────────────────────────────────────────────────
11
12/// LLM provider for entity extraction.
13#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
14#[serde(rename_all = "kebab-case")]
15pub enum Provider {
16    Anthropic,
17    Openai,
18    ClaudeCode,
19}
20
21impl Provider {
22    pub fn default_model(&self) -> &'static str {
23        match self {
24            Provider::Anthropic => "claude-haiku-4-5-20251001",
25            Provider::Openai => "llama3.2",
26            Provider::ClaudeCode => "",
27        }
28    }
29
30    pub fn default_api_base(&self) -> &'static str {
31        match self {
32            Provider::Anthropic => "https://api.anthropic.com/v1/messages",
33            Provider::Openai => "http://localhost:11434/v1",
34            Provider::ClaudeCode => "",
35        }
36    }
37
38    pub fn from_str_loose(s: &str) -> Result<Self, String> {
39        match s.to_lowercase().as_str() {
40            "anthropic" | "claude" => Ok(Provider::Anthropic),
41            "openai" | "ollama" => Ok(Provider::Openai),
42            "claude-code" | "claudecode" => Ok(Provider::ClaudeCode),
43            other => Err(format!(
44                "unknown provider: {other} (use 'anthropic', 'ollama', or 'claude-code')"
45            )),
46        }
47    }
48}
49
50impl fmt::Display for Provider {
51    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
52        match self {
53            Provider::Anthropic => write!(f, "anthropic"),
54            Provider::Openai => write!(f, "openai"),
55            Provider::ClaudeCode => write!(f, "claude-code"),
56        }
57    }
58}
59
// ── Config structs ───────────────────────────────────────────────────────
61
62#[derive(Debug, Default, Serialize, Deserialize)]
63pub struct Config {
64    #[serde(default)]
65    pub ephemeral: EphemeralConfig,
66    #[serde(default)]
67    pub llm: LlmSection,
68    #[serde(default)]
69    pub pipeline: Option<PipelineSection>,
70}
71
72#[derive(Debug, Serialize, Deserialize)]
73pub struct EphemeralConfig {
74    #[serde(default = "default_max_entries")]
75    pub max_entries: usize,
76}
77
78impl Default for EphemeralConfig {
79    fn default() -> Self {
80        Self {
81            max_entries: DEFAULT_MAX_ENTRIES,
82        }
83    }
84}
85
86fn default_max_entries() -> usize {
87    DEFAULT_MAX_ENTRIES
88}
89
90#[derive(Debug, Serialize, Deserialize)]
91pub struct LlmSection {
92    #[serde(default = "default_provider")]
93    pub provider: Provider,
94    #[serde(default)]
95    pub model: String,
96    #[serde(default)]
97    pub api_base: String,
98}
99
100impl Default for LlmSection {
101    fn default() -> Self {
102        Self {
103            provider: Provider::Anthropic,
104            model: String::new(),
105            api_base: String::new(),
106        }
107    }
108}
109
110impl LlmSection {
111    /// Resolved model — uses configured value or provider default.
112    pub fn resolved_model(&self) -> &str {
113        if self.model.is_empty() {
114            self.provider.default_model()
115        } else {
116            &self.model
117        }
118    }
119
120    /// Resolved API base — uses configured value or provider default.
121    pub fn resolved_api_base(&self) -> &str {
122        if self.api_base.is_empty() {
123            self.provider.default_api_base()
124        } else {
125            &self.api_base
126        }
127    }
128}
129
130#[derive(Debug, Clone, Serialize, Deserialize)]
131pub struct PipelineSection {
132    /// Directory containing pipeline documents (LEARNING.md, THOUGHTS.md, etc.)
133    #[serde(default)]
134    pub docs_dir: Option<String>,
135    /// Auto-sync pipeline on archive (default: false)
136    #[serde(default)]
137    pub auto_sync: Option<bool>,
138}
139
140fn default_provider() -> Provider {
141    Provider::Anthropic
142}
143
// ── Load / Save ──────────────────────────────────────────────────────────
145
146/// Config file path for a given base directory.
147pub fn config_path(base: &Path) -> std::path::PathBuf {
148    base.join(CONFIG_FILE)
149}
150
151/// Load config from .recall-echo.toml in the given directory.
152/// Returns defaults if file doesn't exist or is malformed.
153pub fn load_from_dir(dir: &Path) -> Config {
154    load(dir)
155}
156
157/// Load config from .recall-echo.toml in the base dir.
158/// Returns defaults if file doesn't exist or is malformed.
159pub fn load(base: &Path) -> Config {
160    let path = config_path(base);
161    if !path.exists() {
162        return Config::default();
163    }
164
165    let content = match fs::read_to_string(&path) {
166        Ok(c) => c,
167        Err(_) => return Config::default(),
168    };
169
170    match toml::from_str(&content) {
171        Ok(cfg) => validate(cfg),
172        Err(_) => Config::default(),
173    }
174}
175
176/// Save config to .recall-echo.toml in the base dir.
177pub fn save(base: &Path, config: &Config) -> Result<(), String> {
178    let path = config_path(base);
179    let content = toml::to_string_pretty(config).map_err(|e| format!("serialize config: {e}"))?;
180    fs::write(&path, content).map_err(|e| format!("write {}: {e}", path.display()))
181}
182
183/// Returns true if .recall-echo.toml exists in the directory.
184pub fn exists(base: &Path) -> bool {
185    config_path(base).exists()
186}
187
188fn validate(mut cfg: Config) -> Config {
189    if !(1..=50).contains(&cfg.ephemeral.max_entries) {
190        cfg.ephemeral.max_entries = DEFAULT_MAX_ENTRIES;
191    }
192    cfg
193}
194
// ── Config mutation helpers ──────────────────────────────────────────────
196
197impl Config {
198    /// Set a dotted config key (e.g. "llm.provider", "ephemeral.max_entries").
199    pub fn set_key(&mut self, key: &str, value: &str) -> Result<(), String> {
200        match key {
201            "llm.provider" | "provider" => {
202                let provider = Provider::from_str_loose(value)?;
203                // When switching provider, reset model and api_base to defaults
204                self.llm.model = String::new();
205                self.llm.api_base = String::new();
206                self.llm.provider = provider;
207                Ok(())
208            }
209            "llm.model" | "model" => {
210                self.llm.model = value.to_string();
211                Ok(())
212            }
213            "llm.api_base" | "api_base" => {
214                self.llm.api_base = value.to_string();
215                Ok(())
216            }
217            "ephemeral.max_entries" => {
218                let n: usize = value
219                    .parse()
220                    .map_err(|_| format!("invalid number: {value}"))?;
221                if !(1..=50).contains(&n) {
222                    return Err("max_entries must be between 1 and 50".into());
223                }
224                self.ephemeral.max_entries = n;
225                Ok(())
226            }
227            "pipeline.docs_dir" => {
228                let section = self.pipeline.get_or_insert(PipelineSection {
229                    docs_dir: None,
230                    auto_sync: None,
231                });
232                section.docs_dir = Some(value.to_string());
233                Ok(())
234            }
235            "pipeline.auto_sync" => {
236                let b: bool = value
237                    .parse()
238                    .map_err(|_| format!("invalid boolean: {value}"))?;
239                let section = self.pipeline.get_or_insert(PipelineSection {
240                    docs_dir: None,
241                    auto_sync: None,
242                });
243                section.auto_sync = Some(b);
244                Ok(())
245            }
246            other => Err(format!("unknown config key: {other}")),
247        }
248    }
249}
250
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a config with the given ephemeral limit and LLM section and no
    /// `[pipeline]` section.
    fn config_with(max_entries: usize, llm: LlmSection) -> Config {
        Config {
            ephemeral: EphemeralConfig { max_entries },
            llm,
            pipeline: None,
        }
    }

    /// Parses TOML text into a `Config`, panicking on malformed input.
    fn parse(toml_text: &str) -> Config {
        toml::from_str(toml_text).unwrap()
    }

    #[test]
    fn default_config() {
        let Config { ephemeral, llm, .. } = Config::default();
        assert_eq!(ephemeral.max_entries, 5);
        assert_eq!(llm.provider, Provider::Anthropic);
        assert_eq!(llm.model, "");
    }

    #[test]
    fn parse_ephemeral_only() {
        let cfg = parse("[ephemeral]\nmax_entries = 10\n");
        assert_eq!(cfg.ephemeral.max_entries, 10);
        assert_eq!(cfg.llm.provider, Provider::Anthropic);
    }

    #[test]
    fn parse_llm_section() {
        let llm = parse(concat!(
            "[llm]\n",
            "provider = \"openai\"\n",
            "model = \"llama3.1\"\n",
            "api_base = \"http://myhost:11434/v1\"\n"
        ))
        .llm;
        assert_eq!(llm.provider, Provider::Openai);
        assert_eq!(llm.model, "llama3.1");
        assert_eq!(llm.api_base, "http://myhost:11434/v1");
    }

    #[test]
    fn parse_claude_code_provider() {
        let cfg = parse("[llm]\nprovider = \"claude-code\"\n");
        assert_eq!(cfg.llm.provider, Provider::ClaudeCode);
    }

    #[test]
    fn resolved_defaults() {
        let defaults = LlmSection::default();
        assert_eq!(defaults.resolved_model(), "claude-haiku-4-5-20251001");
        assert_eq!(
            defaults.resolved_api_base(),
            "https://api.anthropic.com/v1/messages"
        );
    }

    #[test]
    fn resolved_custom_overrides_default() {
        let custom = LlmSection {
            provider: Provider::Openai,
            model: "mistral-7b".into(),
            api_base: String::new(),
        };
        assert_eq!(custom.resolved_model(), "mistral-7b");
        assert_eq!(custom.resolved_api_base(), "http://localhost:11434/v1");
    }

    #[test]
    fn round_trip_toml() {
        let original = config_with(
            3,
            LlmSection {
                provider: Provider::Openai,
                model: "llama3.2".into(),
                api_base: "http://localhost:11434/v1".into(),
            },
        );
        let text = toml::to_string_pretty(&original).unwrap();
        let reparsed = parse(&text);
        assert_eq!(reparsed.ephemeral.max_entries, 3);
        assert_eq!(reparsed.llm.provider, Provider::Openai);
        assert_eq!(reparsed.llm.model, "llama3.2");
    }

    #[test]
    fn set_key_provider() {
        let mut cfg = Config::default();
        cfg.set_key("llm.provider", "ollama").unwrap();
        assert_eq!(cfg.llm.provider, Provider::Openai);
        assert!(cfg.llm.model.is_empty());
    }

    #[test]
    fn set_key_model() {
        let mut cfg = Config::default();
        cfg.set_key("llm.model", "claude-sonnet-4-6").unwrap();
        assert_eq!(cfg.llm.model, "claude-sonnet-4-6");
    }

    #[test]
    fn set_key_unknown_fails() {
        let mut cfg = Config::default();
        assert!(cfg.set_key("nonexistent.key", "value").is_err());
    }

    #[test]
    fn provider_from_str_loose() {
        // Table of (alias, expected variant).
        let cases = [
            ("ollama", Provider::Openai),
            ("claude", Provider::Anthropic),
            ("claude-code", Provider::ClaudeCode),
        ];
        for (input, expected) in cases {
            assert_eq!(Provider::from_str_loose(input).unwrap(), expected);
        }
        assert!(Provider::from_str_loose("unknown").is_err());
    }

    #[test]
    fn save_and_load() {
        let dir = tempfile::tempdir().unwrap();
        let saved = config_with(
            7,
            LlmSection {
                provider: Provider::ClaudeCode,
                model: String::new(),
                api_base: String::new(),
            },
        );
        save(dir.path(), &saved).unwrap();
        let loaded = load(dir.path());
        assert_eq!(loaded.ephemeral.max_entries, 7);
        assert_eq!(loaded.llm.provider, Provider::ClaudeCode);
    }

    #[test]
    fn load_nonexistent_file() {
        let dir = tempfile::tempdir().unwrap();
        assert_eq!(load(dir.path()).ephemeral.max_entries, 5);
    }

    #[test]
    fn validate_out_of_range() {
        let checked = validate(config_with(100, LlmSection::default()));
        assert_eq!(checked.ephemeral.max_entries, 5);
    }
}
394            pipeline: None,
395        });
396        assert_eq!(cfg.ephemeral.max_entries, 5);
397    }
398}