Skip to main content

semantic_diff/
config.rs

1use crate::grouper::llm::LlmBackend;
2use crate::theme::ThemeMode;
3use serde::Deserialize;
4use std::path::PathBuf;
5
6/// User configuration loaded from ~/.config/semantic-diff.json (JSONC supported).
7#[derive(Debug, Clone)]
8pub struct Config {
9    pub preferred_ai_cli: Option<AiCli>,
10    pub claude_model: String,
11    pub copilot_model: String,
12    pub theme_mode: ThemeMode,
13}
14
15#[derive(Debug, Clone, Copy, PartialEq, Deserialize)]
16#[serde(rename_all = "lowercase")]
17pub enum AiCli {
18    Claude,
19    Copilot,
20}
21
22/// Raw JSON-serializable config (matches the file format).
23#[derive(Debug, Default, Deserialize)]
24#[serde(default)]
25struct RawConfig {
26    #[serde(rename = "preferred-ai-cli")]
27    preferred_ai_cli: Option<AiCli>,
28    claude: CliConfig,
29    copilot: CliConfig,
30    theme: Option<String>,
31}
32
33#[derive(Debug, Default, Deserialize)]
34#[serde(default)]
35struct CliConfig {
36    model: Option<String>,
37}
38
39
/// Performance tier of a model, used to map a model name configured for one
/// backend onto the closest equivalent for another backend.
///
/// Variants are declared from smallest to largest, so the derived ordering is
/// `Fast < Balanced < Power`.
#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd)]
enum ModelTier {
    /// e.g. haiku, gemini-flash
    Fast,
    /// e.g. sonnet, gemini-pro
    Balanced,
    /// e.g. opus
    Power,
}
47
48impl Config {
49    pub fn default_config() -> Self {
50        Self {
51            preferred_ai_cli: None,
52            claude_model: "sonnet".to_string(),
53            copilot_model: "sonnet".to_string(),
54            theme_mode: ThemeMode::Auto,
55        }
56    }
57
58    /// Resolve the model string to pass to the given backend's CLI.
59    pub fn model_for_backend(&self, backend: LlmBackend) -> &str {
60        match backend {
61            LlmBackend::Claude => &self.claude_model,
62            LlmBackend::Copilot => &self.copilot_model,
63        }
64    }
65
66    /// Detect the best available backend, respecting the user's preference.
67    pub fn detect_backend(&self) -> Option<LlmBackend> {
68        let claude_ok = which::which("claude").is_ok();
69        let copilot_ok = which::which("copilot").is_ok();
70
71        match self.preferred_ai_cli {
72            Some(AiCli::Claude) => {
73                if claude_ok {
74                    Some(LlmBackend::Claude)
75                } else if copilot_ok {
76                    Some(LlmBackend::Copilot)
77                } else {
78                    None
79                }
80            }
81            Some(AiCli::Copilot) => {
82                if copilot_ok {
83                    Some(LlmBackend::Copilot)
84                } else if claude_ok {
85                    Some(LlmBackend::Claude)
86                } else {
87                    None
88                }
89            }
90            None => {
91                // Default: prefer claude, fallback copilot
92                if claude_ok {
93                    Some(LlmBackend::Claude)
94                } else if copilot_ok {
95                    Some(LlmBackend::Copilot)
96                } else {
97                    None
98                }
99            }
100        }
101    }
102}
103
104/// Config file path: ~/.config/semantic-diff.json
105/// Returns None if home directory cannot be determined (refuses to fall back to cwd).
106fn config_path() -> Option<PathBuf> {
107    // Explicitly refuse to use cwd as home directory fallback.
108    // This prevents a malicious repo from injecting config via .config/semantic-diff.json
109    let home = dirs::home_dir()?;
110    Some(home.join(".config").join("semantic-diff.json"))
111}
112
/// Default config file content with comments explaining each option.
///
/// `load` writes this only when no config file exists yet; an existing file is
/// never rewritten. Optional keys are shown commented out *before* the
/// mandatory `"claude"` / `"copilot"` sections, each with its own trailing comma.
/// That lets a user uncomment any of them without touching commas elsewhere.
/// After comment stripping the text is parsed by `serde_json`, which rejects
/// both missing and trailing commas.
const DEFAULT_CONFIG: &str = r#"{
  // Which AI CLI to prefer: "claude" or "copilot"
  // Falls back to the other if the preferred one is not installed.
  // If unset, defaults to: claude > copilot
  // "preferred-ai-cli": "claude",

  // Theme: "dark", "light", or "auto" (detects from terminal)
  // "theme": "auto",

  // Claude CLI settings
  "claude": {
    // Model to use: "sonnet", "opus", "haiku"
    // Cross-backend models are mapped automatically:
    //   gemini-flash -> haiku, gemini-pro -> sonnet
    "model": "sonnet"
  },

  // Copilot CLI settings
  "copilot": {
    // Model to use: "sonnet", "opus", "haiku", "gemini-flash", "gemini-pro"
    "model": "sonnet"
  }
}
"#;
138
139/// Load config from disk. Creates a default commented config if none exists.
140/// Returns default config if home directory cannot be determined.
141pub fn load() -> Config {
142    let path = match config_path() {
143        Some(p) => p,
144        None => {
145            tracing::warn!("Could not determine home directory, using default config");
146            return Config::default_config();
147        }
148    };
149
150    // Create default config if it doesn't exist
151    if !path.exists() {
152        if let Some(parent) = path.parent() {
153            let _ = std::fs::create_dir_all(parent);
154        }
155        let _ = std::fs::write(&path, DEFAULT_CONFIG);
156        tracing::info!("Created default config at {}", path.display());
157    }
158
159    // Read and parse
160    let content = match std::fs::read_to_string(&path) {
161        Ok(c) => c,
162        Err(e) => {
163            tracing::warn!("Failed to read config {}: {}", path.display(), e);
164            return Config::default_config();
165        }
166    };
167
168    let stripped = strip_json_comments(&content);
169    let raw: RawConfig = match serde_json::from_str(&stripped) {
170        Ok(r) => r,
171        Err(e) => {
172            tracing::warn!("Failed to parse config {}: {}", path.display(), e);
173            return Config::default_config();
174        }
175    };
176
177    Config {
178        preferred_ai_cli: raw.preferred_ai_cli,
179        claude_model: resolve_model_for_claude(raw.claude.model.as_deref()),
180        copilot_model: resolve_model_for_copilot(raw.copilot.model.as_deref()),
181        theme_mode: match raw.theme.as_deref() {
182            Some("light") => ThemeMode::Light,
183            Some("dark") => ThemeMode::Dark,
184            _ => ThemeMode::Auto,
185        },
186    }
187}
188
189/// Map any model name to the closest Claude CLI model.
190fn resolve_model_for_claude(model: Option<&str>) -> String {
191    let tier = model.map(model_tier).unwrap_or(ModelTier::Balanced);
192    match tier {
193        ModelTier::Fast => "haiku",
194        ModelTier::Balanced => "sonnet",
195        ModelTier::Power => "opus",
196    }
197    .to_string()
198}
199
200/// Map any model name to the closest Copilot CLI model.
201/// Copilot passes model names through directly, but we normalize known aliases.
202fn resolve_model_for_copilot(model: Option<&str>) -> String {
203    match model {
204        Some(m) => {
205            let tier = model_tier(m);
206            // If it's already a recognized copilot model, pass through
207            match m {
208                "sonnet" | "opus" | "haiku" | "gemini-flash" | "gemini-pro" => m.to_string(),
209                // Otherwise map by tier
210                _ => match tier {
211                    ModelTier::Fast => "gemini-flash",
212                    ModelTier::Balanced => "sonnet",
213                    ModelTier::Power => "opus",
214                }
215                .to_string(),
216            }
217        }
218        None => "sonnet".to_string(),
219    }
220}
221
222/// Classify a model name into a performance tier.
223fn model_tier(name: &str) -> ModelTier {
224    let n = name.to_lowercase();
225    if n.contains("flash") || n.contains("haiku") || n == "gpt-4o-mini" || n.ends_with("-mini") {
226        ModelTier::Fast
227    } else if n.contains("opus") {
228        ModelTier::Power
229    } else {
230        // sonnet, gemini-pro, gpt-4o, etc. → balanced
231        ModelTier::Balanced
232    }
233}
234
/// Strip `//` and `/* */` comments from JSON text (simple JSONC support).
///
/// String literals, including escaped quotes, are copied verbatim, so `//`
/// inside a string such as a URL is preserved. Each comment is replaced by
/// whitespace rather than deleted:
/// - A line comment keeps its terminating newline.
/// - A block comment becomes one space plus every newline it spanned.
///
/// This keeps the tokens on either side of a comment apart (`[1/**/2]` must not
/// become `[12]`) and keeps parse-error line numbers aligned with the file. An
/// unterminated block comment swallows the rest of the input.
fn strip_json_comments(input: &str) -> String {
    let mut out = String::with_capacity(input.len());
    let mut chars = input.chars().peekable();
    let mut in_string = false;

    while let Some(c) = chars.next() {
        if in_string {
            out.push(c);
            match c {
                // Copy the escaped character too, so `\"` does not end the string.
                '\\' => {
                    if let Some(escaped) = chars.next() {
                        out.push(escaped);
                    }
                }
                '"' => in_string = false,
                _ => {}
            }
            continue;
        }

        let next = chars.peek().copied();
        match (c, next) {
            ('"', _) => {
                in_string = true;
                out.push(c);
            }
            ('/', Some('/')) => {
                // Line comment: drop everything up to the newline, keep the newline.
                for rest in chars.by_ref() {
                    if rest == '\n' {
                        out.push('\n');
                        break;
                    }
                }
            }
            ('/', Some('*')) => {
                chars.next(); // consume the opening '*'
                // A comment is whitespace: keep neighbouring tokens separated.
                out.push(' ');
                let mut prev = '\0';
                for rest in chars.by_ref() {
                    if prev == '*' && rest == '/' {
                        break;
                    }
                    if rest == '\n' {
                        out.push('\n');
                    }
                    prev = rest;
                }
            }
            _ => out.push(c),
        }
    }
    out
}
291
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::Path;

    #[test]
    fn test_strip_line_comments() {
        let input = r#"{
  // this is a comment
  "key": "value"
}"#;
        let stripped = strip_json_comments(input);
        let parsed: serde_json::Value = serde_json::from_str(&stripped).unwrap();
        assert_eq!(parsed["key"], "value");
    }

    #[test]
    fn test_strip_block_comments() {
        let input = r#"{ /* block */ "key": "value" }"#;
        let stripped = strip_json_comments(input);
        let parsed: serde_json::Value = serde_json::from_str(&stripped).unwrap();
        assert_eq!(parsed["key"], "value");
    }

    #[test]
    fn test_preserves_strings_with_slashes() {
        let input = r#"{ "url": "https://example.com" }"#;
        let stripped = strip_json_comments(input);
        let parsed: serde_json::Value = serde_json::from_str(&stripped).unwrap();
        assert_eq!(parsed["url"], "https://example.com");
    }

    #[test]
    fn test_commented_out_keys_stripped() {
        let input = r#"{
  // "preferred-ai-cli": "claude",
  "claude": { "model": "opus" }
}"#;
        let stripped = strip_json_comments(input);
        let parsed: serde_json::Value = serde_json::from_str(&stripped).unwrap();
        assert!(parsed.get("preferred-ai-cli").is_none());
        assert_eq!(parsed["claude"]["model"], "opus");
    }

    #[test]
    fn test_model_tier_mapping() {
        assert_eq!(model_tier("haiku"), ModelTier::Fast);
        assert_eq!(model_tier("gemini-flash"), ModelTier::Fast);
        assert_eq!(model_tier("gpt-4o-mini"), ModelTier::Fast);
        assert_eq!(model_tier("sonnet"), ModelTier::Balanced);
        assert_eq!(model_tier("gemini-pro"), ModelTier::Balanced);
        assert_eq!(model_tier("opus"), ModelTier::Power);
    }

    #[test]
    fn test_resolve_claude_model() {
        assert_eq!(resolve_model_for_claude(Some("gemini-flash")), "haiku");
        assert_eq!(resolve_model_for_claude(Some("sonnet")), "sonnet");
        assert_eq!(resolve_model_for_claude(Some("opus")), "opus");
        assert_eq!(resolve_model_for_claude(Some("gemini-pro")), "sonnet");
        assert_eq!(resolve_model_for_claude(None), "sonnet");
    }

    #[test]
    fn test_resolve_copilot_model() {
        assert_eq!(resolve_model_for_copilot(Some("gemini-flash")), "gemini-flash");
        assert_eq!(resolve_model_for_copilot(Some("sonnet")), "sonnet");
        assert_eq!(resolve_model_for_copilot(Some("haiku")), "haiku");
        assert_eq!(resolve_model_for_copilot(None), "sonnet");
    }

    #[test]
    fn test_default_config_parses() {
        let stripped = strip_json_comments(DEFAULT_CONFIG);
        let raw: RawConfig = serde_json::from_str(&stripped).unwrap();
        assert!(raw.preferred_ai_cli.is_none());
        assert_eq!(raw.claude.model.as_deref(), Some("sonnet"));
        assert_eq!(raw.copilot.model.as_deref(), Some("sonnet"));
    }

    #[test]
    fn test_config_path_returns_option_not_cwd() {
        // config_path() should return Some with a path under home, never ".".
        // None is acceptable if HOME is not set.
        if let Some(p) = config_path() {
            let path_str = p.to_string_lossy();
            assert!(
                !path_str.starts_with("./"),
                "config_path should not fall back to cwd, got: {}",
                path_str
            );
            // Compare path components rather than a `/`-joined string so the
            // check also holds on Windows, where the separator is `\`.
            let expected_tail = Path::new(".config").join("semantic-diff.json");
            assert!(
                p.ends_with(&expected_tail),
                "config_path should end with .config/semantic-diff.json, got: {}",
                path_str
            );
        }
    }

    #[test]
    fn test_config_path_no_dot_fallback() {
        // Verify config_path never returns a path starting with "."
        let path = config_path();
        if let Some(p) = path {
            assert_ne!(
                p.components().next().map(|c| c.as_os_str().to_string_lossy().to_string()),
                Some(".".to_string()),
                "config_path must not use '.' as base directory"
            );
        }
    }
}