1use std::fmt;
2use std::fs;
3use std::path::Path;
4
5use serde::{Deserialize, Serialize};
6
/// File name of the configuration file, resolved relative to a base directory.
const CONFIG_FILE: &str = ".recall-echo.toml";

/// Value used for `ephemeral.max_entries` when it is missing or outside `1..=50`.
const DEFAULT_MAX_ENTRIES: usize = 5;
9
10#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
14#[serde(rename_all = "kebab-case")]
15pub enum Provider {
16 Anthropic,
17 Openai,
18 ClaudeCode,
19}
20
21impl Provider {
22 pub fn default_model(&self) -> &'static str {
23 match self {
24 Provider::Anthropic => "claude-haiku-4-5-20251001",
25 Provider::Openai => "llama3.2",
26 Provider::ClaudeCode => "",
27 }
28 }
29
30 pub fn default_api_base(&self) -> &'static str {
31 match self {
32 Provider::Anthropic => "https://api.anthropic.com/v1/messages",
33 Provider::Openai => "http://localhost:11434/v1",
34 Provider::ClaudeCode => "",
35 }
36 }
37
38 pub fn from_str_loose(s: &str) -> Result<Self, String> {
39 match s.to_lowercase().as_str() {
40 "anthropic" | "claude" => Ok(Provider::Anthropic),
41 "openai" | "ollama" => Ok(Provider::Openai),
42 "claude-code" | "claudecode" => Ok(Provider::ClaudeCode),
43 other => Err(format!(
44 "unknown provider: {other} (use 'anthropic', 'ollama', or 'claude-code')"
45 )),
46 }
47 }
48}
49
50impl fmt::Display for Provider {
51 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
52 match self {
53 Provider::Anthropic => write!(f, "anthropic"),
54 Provider::Openai => write!(f, "openai"),
55 Provider::ClaudeCode => write!(f, "claude-code"),
56 }
57 }
58}
59
60#[derive(Debug, Default, Serialize, Deserialize)]
63pub struct Config {
64 #[serde(default)]
65 pub ephemeral: EphemeralConfig,
66 #[serde(default)]
67 pub llm: LlmSection,
68 #[serde(default)]
69 pub pipeline: Option<PipelineSection>,
70}
71
72#[derive(Debug, Serialize, Deserialize)]
73pub struct EphemeralConfig {
74 #[serde(default = "default_max_entries")]
75 pub max_entries: usize,
76}
77
78impl Default for EphemeralConfig {
79 fn default() -> Self {
80 Self {
81 max_entries: DEFAULT_MAX_ENTRIES,
82 }
83 }
84}
85
86fn default_max_entries() -> usize {
87 DEFAULT_MAX_ENTRIES
88}
89
90#[derive(Debug, Serialize, Deserialize)]
91pub struct LlmSection {
92 #[serde(default = "default_provider")]
93 pub provider: Provider,
94 #[serde(default)]
95 pub model: String,
96 #[serde(default)]
97 pub api_base: String,
98}
99
100impl Default for LlmSection {
101 fn default() -> Self {
102 Self {
103 provider: Provider::Anthropic,
104 model: String::new(),
105 api_base: String::new(),
106 }
107 }
108}
109
110impl LlmSection {
111 pub fn resolved_model(&self) -> &str {
113 if self.model.is_empty() {
114 self.provider.default_model()
115 } else {
116 &self.model
117 }
118 }
119
120 pub fn resolved_api_base(&self) -> &str {
122 if self.api_base.is_empty() {
123 self.provider.default_api_base()
124 } else {
125 &self.api_base
126 }
127 }
128}
129
130#[derive(Debug, Clone, Serialize, Deserialize)]
131pub struct PipelineSection {
132 #[serde(default)]
134 pub docs_dir: Option<String>,
135 #[serde(default)]
137 pub auto_sync: Option<bool>,
138}
139
140fn default_provider() -> Provider {
141 Provider::Anthropic
142}
143
144pub fn config_path(base: &Path) -> std::path::PathBuf {
148 base.join(CONFIG_FILE)
149}
150
151pub fn load_from_dir(dir: &Path) -> Config {
154 load(dir)
155}
156
157pub fn load(base: &Path) -> Config {
160 let path = config_path(base);
161 if !path.exists() {
162 return Config::default();
163 }
164
165 let content = match fs::read_to_string(&path) {
166 Ok(c) => c,
167 Err(_) => return Config::default(),
168 };
169
170 match toml::from_str(&content) {
171 Ok(cfg) => validate(cfg),
172 Err(_) => Config::default(),
173 }
174}
175
176pub fn save(base: &Path, config: &Config) -> Result<(), String> {
178 let path = config_path(base);
179 let content = toml::to_string_pretty(config).map_err(|e| format!("serialize config: {e}"))?;
180 fs::write(&path, content).map_err(|e| format!("write {}: {e}", path.display()))
181}
182
183pub fn exists(base: &Path) -> bool {
185 config_path(base).exists()
186}
187
188fn validate(mut cfg: Config) -> Config {
189 if !(1..=50).contains(&cfg.ephemeral.max_entries) {
190 cfg.ephemeral.max_entries = DEFAULT_MAX_ENTRIES;
191 }
192 cfg
193}
194
195impl Config {
198 pub fn set_key(&mut self, key: &str, value: &str) -> Result<(), String> {
200 match key {
201 "llm.provider" | "provider" => {
202 let provider = Provider::from_str_loose(value)?;
203 self.llm.model = String::new();
205 self.llm.api_base = String::new();
206 self.llm.provider = provider;
207 Ok(())
208 }
209 "llm.model" | "model" => {
210 self.llm.model = value.to_string();
211 Ok(())
212 }
213 "llm.api_base" | "api_base" => {
214 self.llm.api_base = value.to_string();
215 Ok(())
216 }
217 "ephemeral.max_entries" => {
218 let n: usize = value
219 .parse()
220 .map_err(|_| format!("invalid number: {value}"))?;
221 if !(1..=50).contains(&n) {
222 return Err("max_entries must be between 1 and 50".into());
223 }
224 self.ephemeral.max_entries = n;
225 Ok(())
226 }
227 "pipeline.docs_dir" => {
228 let section = self.pipeline.get_or_insert(PipelineSection {
229 docs_dir: None,
230 auto_sync: None,
231 });
232 section.docs_dir = Some(value.to_string());
233 Ok(())
234 }
235 "pipeline.auto_sync" => {
236 let b: bool = value
237 .parse()
238 .map_err(|_| format!("invalid boolean: {value}"))?;
239 let section = self.pipeline.get_or_insert(PipelineSection {
240 docs_dir: None,
241 auto_sync: None,
242 });
243 section.auto_sync = Some(b);
244 Ok(())
245 }
246 other => Err(format!("unknown config key: {other}")),
247 }
248 }
249}
250
#[cfg(test)]
mod tests {
    //! Unit tests for config defaults, TOML (de)serialisation, `set_key`,
    //! validation, and file round-trips. The file-based tests use the
    //! `tempfile` crate to get an isolated directory.
    use super::*;

    #[test]
    fn default_config() {
        let cfg = Config::default();
        assert_eq!(cfg.ephemeral.max_entries, 5);
        assert_eq!(cfg.llm.provider, Provider::Anthropic);
        assert!(cfg.llm.model.is_empty());
    }

    // A file with only `[ephemeral]` must still yield a default `[llm]` section.
    #[test]
    fn parse_ephemeral_only() {
        let cfg: Config = toml::from_str("[ephemeral]\nmax_entries = 10\n").unwrap();
        assert_eq!(cfg.ephemeral.max_entries, 10);
        assert_eq!(cfg.llm.provider, Provider::Anthropic);
    }

    #[test]
    fn parse_llm_section() {
        let cfg: Config = toml::from_str(
            "[llm]\nprovider = \"openai\"\nmodel = \"llama3.1\"\napi_base = \"http://myhost:11434/v1\"\n",
        )
        .unwrap();
        assert_eq!(cfg.llm.provider, Provider::Openai);
        assert_eq!(cfg.llm.model, "llama3.1");
        assert_eq!(cfg.llm.api_base, "http://myhost:11434/v1");
    }

    // `ClaudeCode` serialises in kebab-case as "claude-code".
    #[test]
    fn parse_claude_code_provider() {
        let cfg: Config = toml::from_str("[llm]\nprovider = \"claude-code\"\n").unwrap();
        assert_eq!(cfg.llm.provider, Provider::ClaudeCode);
    }

    // Empty model / api_base resolve to the provider's built-in defaults.
    #[test]
    fn resolved_defaults() {
        let llm = LlmSection::default();
        assert_eq!(llm.resolved_model(), "claude-haiku-4-5-20251001");
        assert_eq!(
            llm.resolved_api_base(),
            "https://api.anthropic.com/v1/messages"
        );
    }

    // A set field wins; an empty sibling field still falls back to the default.
    #[test]
    fn resolved_custom_overrides_default() {
        let llm = LlmSection {
            provider: Provider::Openai,
            model: "mistral-7b".into(),
            api_base: String::new(),
        };
        assert_eq!(llm.resolved_model(), "mistral-7b");
        assert_eq!(llm.resolved_api_base(), "http://localhost:11434/v1");
    }

    #[test]
    fn round_trip_toml() {
        let cfg = Config {
            ephemeral: EphemeralConfig { max_entries: 3 },
            llm: LlmSection {
                provider: Provider::Openai,
                model: "llama3.2".into(),
                api_base: "http://localhost:11434/v1".into(),
            },
            pipeline: None,
        };
        let s = toml::to_string_pretty(&cfg).unwrap();
        let parsed: Config = toml::from_str(&s).unwrap();
        assert_eq!(parsed.ephemeral.max_entries, 3);
        assert_eq!(parsed.llm.provider, Provider::Openai);
        assert_eq!(parsed.llm.model, "llama3.2");
    }

    // "ollama" is an accepted alias for the OpenAI-compatible provider.
    #[test]
    fn set_key_provider() {
        let mut cfg = Config::default();
        cfg.set_key("llm.provider", "ollama").unwrap();
        assert_eq!(cfg.llm.provider, Provider::Openai);
        assert!(cfg.llm.model.is_empty());
    }

    #[test]
    fn set_key_model() {
        let mut cfg = Config::default();
        cfg.set_key("llm.model", "claude-sonnet-4-6").unwrap();
        assert_eq!(cfg.llm.model, "claude-sonnet-4-6");
    }

    #[test]
    fn set_key_unknown_fails() {
        let mut cfg = Config::default();
        assert!(cfg.set_key("nonexistent.key", "value").is_err());
    }

    #[test]
    fn provider_from_str_loose() {
        assert_eq!(
            Provider::from_str_loose("ollama").unwrap(),
            Provider::Openai
        );
        assert_eq!(
            Provider::from_str_loose("claude").unwrap(),
            Provider::Anthropic
        );
        assert_eq!(
            Provider::from_str_loose("claude-code").unwrap(),
            Provider::ClaudeCode
        );
        assert!(Provider::from_str_loose("unknown").is_err());
    }

    // Writes the file through `save` and reads it back through `load`.
    #[test]
    fn save_and_load() {
        let tmp = tempfile::tempdir().unwrap();
        let cfg = Config {
            ephemeral: EphemeralConfig { max_entries: 7 },
            llm: LlmSection {
                provider: Provider::ClaudeCode,
                model: String::new(),
                api_base: String::new(),
            },
            pipeline: None,
        };
        save(tmp.path(), &cfg).unwrap();
        let loaded = load(tmp.path());
        assert_eq!(loaded.ephemeral.max_entries, 7);
        assert_eq!(loaded.llm.provider, Provider::ClaudeCode);
    }

    // No config file in the directory: `load` returns defaults.
    #[test]
    fn load_nonexistent_file() {
        let tmp = tempfile::tempdir().unwrap();
        let cfg = load(tmp.path());
        assert_eq!(cfg.ephemeral.max_entries, 5);
    }

    // Out-of-range values reset to the default (5), not to the nearest bound (50).
    #[test]
    fn validate_out_of_range() {
        let cfg = validate(Config {
            ephemeral: EphemeralConfig { max_entries: 100 },
            llm: LlmSection::default(),
            pipeline: None,
        });
        assert_eq!(cfg.ephemeral.max_entries, 5);
    }
}