Skip to main content

semantic_diff/
config.rs

1use crate::grouper::llm::LlmBackend;
2use serde::Deserialize;
3use std::path::PathBuf;
4
5/// User configuration loaded from ~/.config/semantic-diff.json (JSONC supported).
6#[derive(Debug, Clone)]
7pub struct Config {
8    pub preferred_ai_cli: Option<AiCli>,
9    pub claude_model: String,
10    pub copilot_model: String,
11}
12
13#[derive(Debug, Clone, Copy, PartialEq, Deserialize)]
14#[serde(rename_all = "lowercase")]
15pub enum AiCli {
16    Claude,
17    Copilot,
18}
19
20/// Raw JSON-serializable config (matches the file format).
21#[derive(Debug, Default, Deserialize)]
22#[serde(default)]
23struct RawConfig {
24    #[serde(rename = "preferred-ai-cli")]
25    preferred_ai_cli: Option<AiCli>,
26    claude: CliConfig,
27    copilot: CliConfig,
28}
29
30#[derive(Debug, Default, Deserialize)]
31#[serde(default)]
32struct CliConfig {
33    model: Option<String>,
34}
35
36
/// Coarse capability tier used to translate a model name between backends.
///
/// Variants are declared in ascending order, so the derived ordering is
/// `Fast < Balanced < Power`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
enum ModelTier {
    /// Fast models, e.g. haiku, gemini-flash.
    Fast,
    /// General-purpose models, e.g. sonnet, gemini-pro.
    Balanced,
    /// Most capable models, e.g. opus.
    Power,
}
44
45impl Config {
46    pub fn default_config() -> Self {
47        Self {
48            preferred_ai_cli: None,
49            claude_model: "sonnet".to_string(),
50            copilot_model: "sonnet".to_string(),
51        }
52    }
53
54    /// Resolve the model string to pass to the given backend's CLI.
55    pub fn model_for_backend(&self, backend: LlmBackend) -> &str {
56        match backend {
57            LlmBackend::Claude => &self.claude_model,
58            LlmBackend::Copilot => &self.copilot_model,
59        }
60    }
61
62    /// Detect the best available backend, respecting the user's preference.
63    pub fn detect_backend(&self) -> Option<LlmBackend> {
64        let claude_ok = which::which("claude").is_ok();
65        let copilot_ok = which::which("copilot").is_ok();
66
67        match self.preferred_ai_cli {
68            Some(AiCli::Claude) => {
69                if claude_ok {
70                    Some(LlmBackend::Claude)
71                } else if copilot_ok {
72                    Some(LlmBackend::Copilot)
73                } else {
74                    None
75                }
76            }
77            Some(AiCli::Copilot) => {
78                if copilot_ok {
79                    Some(LlmBackend::Copilot)
80                } else if claude_ok {
81                    Some(LlmBackend::Claude)
82                } else {
83                    None
84                }
85            }
86            None => {
87                // Default: prefer claude, fallback copilot
88                if claude_ok {
89                    Some(LlmBackend::Claude)
90                } else if copilot_ok {
91                    Some(LlmBackend::Copilot)
92                } else {
93                    None
94                }
95            }
96        }
97    }
98}
99
100/// Config file path: ~/.config/semantic-diff.json
101/// Returns None if home directory cannot be determined (refuses to fall back to cwd).
102fn config_path() -> Option<PathBuf> {
103    // Explicitly refuse to use cwd as home directory fallback.
104    // This prevents a malicious repo from injecting config via .config/semantic-diff.json
105    let home = dirs::home_dir()?;
106    Some(home.join(".config").join("semantic-diff.json"))
107}
108
/// Commented template that `load` writes to the config path when no config
/// file exists there yet.
///
/// The text is JSONC. Its `//` comments explain each option to the user and are
/// removed by `strip_json_comments` before the file is parsed. The
/// `preferred-ai-cli` entry ships commented out, so no preference is set until
/// the user uncomments it. The text is written to disk verbatim, so any edit
/// here changes the file newly created for users.
const DEFAULT_CONFIG: &str = r#"{
  // Which AI CLI to prefer: "claude" or "copilot"
  // Falls back to the other if the preferred one is not installed.
  // If unset, defaults to: claude > copilot
  // "preferred-ai-cli": "claude",

  // Claude CLI settings
  "claude": {
    // Model to use: "sonnet", "opus", "haiku"
    // Cross-backend models are mapped automatically:
    //   gemini-flash -> haiku, gemini-pro -> sonnet
    "model": "sonnet"
  },

  // Copilot CLI settings
  "copilot": {
    // Model to use: "sonnet", "opus", "haiku", "gemini-flash", "gemini-pro"
    "model": "sonnet"
  }
}
"#;
131
132/// Load config from disk. Creates a default commented config if none exists.
133/// Returns default config if home directory cannot be determined.
134pub fn load() -> Config {
135    let path = match config_path() {
136        Some(p) => p,
137        None => {
138            tracing::warn!("Could not determine home directory, using default config");
139            return Config::default_config();
140        }
141    };
142
143    // Create default config if it doesn't exist
144    if !path.exists() {
145        if let Some(parent) = path.parent() {
146            let _ = std::fs::create_dir_all(parent);
147        }
148        let _ = std::fs::write(&path, DEFAULT_CONFIG);
149        tracing::info!("Created default config at {}", path.display());
150    }
151
152    // Read and parse
153    let content = match std::fs::read_to_string(&path) {
154        Ok(c) => c,
155        Err(e) => {
156            tracing::warn!("Failed to read config {}: {}", path.display(), e);
157            return Config::default_config();
158        }
159    };
160
161    let stripped = strip_json_comments(&content);
162    let raw: RawConfig = match serde_json::from_str(&stripped) {
163        Ok(r) => r,
164        Err(e) => {
165            tracing::warn!("Failed to parse config {}: {}", path.display(), e);
166            return Config::default_config();
167        }
168    };
169
170    Config {
171        preferred_ai_cli: raw.preferred_ai_cli,
172        claude_model: resolve_model_for_claude(raw.claude.model.as_deref()),
173        copilot_model: resolve_model_for_copilot(raw.copilot.model.as_deref()),
174    }
175}
176
177/// Map any model name to the closest Claude CLI model.
178fn resolve_model_for_claude(model: Option<&str>) -> String {
179    let tier = model.map(model_tier).unwrap_or(ModelTier::Balanced);
180    match tier {
181        ModelTier::Fast => "haiku",
182        ModelTier::Balanced => "sonnet",
183        ModelTier::Power => "opus",
184    }
185    .to_string()
186}
187
188/// Map any model name to the closest Copilot CLI model.
189/// Copilot passes model names through directly, but we normalize known aliases.
190fn resolve_model_for_copilot(model: Option<&str>) -> String {
191    match model {
192        Some(m) => {
193            let tier = model_tier(m);
194            // If it's already a recognized copilot model, pass through
195            match m {
196                "sonnet" | "opus" | "haiku" | "gemini-flash" | "gemini-pro" => m.to_string(),
197                // Otherwise map by tier
198                _ => match tier {
199                    ModelTier::Fast => "gemini-flash",
200                    ModelTier::Balanced => "sonnet",
201                    ModelTier::Power => "opus",
202                }
203                .to_string(),
204            }
205        }
206        None => "sonnet".to_string(),
207    }
208}
209
210/// Classify a model name into a performance tier.
211fn model_tier(name: &str) -> ModelTier {
212    let n = name.to_lowercase();
213    if n.contains("flash") || n.contains("haiku") || n == "gpt-4o-mini" || n.ends_with("-mini") {
214        ModelTier::Fast
215    } else if n.contains("opus") {
216        ModelTier::Power
217    } else {
218        // sonnet, gemini-pro, gpt-4o, etc. → balanced
219        ModelTier::Balanced
220    }
221}
222
/// Strip `//` and `/* */` comments from JSON text (simple JSONC support).
///
/// Comment markers inside string literals, including after escaped quotes, are
/// left untouched. Newlines inside or ending a comment are kept, so line numbers
/// reported by a later parse error match the original text. Each block comment
/// is replaced by a single space because a comment separates tokens:
/// `[1/* c */2]` becomes `[1 2]`, not the silently different `[12]`. An
/// unterminated block comment swallows the rest of the input.
fn strip_json_comments(input: &str) -> String {
    let mut out = String::with_capacity(input.len());
    let mut chars = input.chars().peekable();
    let mut in_string = false;

    while let Some(c) = chars.next() {
        if in_string {
            out.push(c);
            if c == '\\' {
                // Copy the escaped char verbatim so `\"` does not end the string.
                if let Some(next) = chars.next() {
                    out.push(next);
                }
            } else if c == '"' {
                in_string = false;
            }
            continue;
        }

        match c {
            '"' => {
                in_string = true;
                out.push(c);
            }
            '/' => match chars.peek() {
                Some('/') => {
                    // Line comment: drop everything up to the newline, keep the newline.
                    for rest in chars.by_ref() {
                        if rest == '\n' {
                            out.push('\n');
                            break;
                        }
                    }
                }
                Some('*') => {
                    chars.next(); // consume '*'
                    // The comment acts as whitespace between the surrounding tokens.
                    out.push(' ');
                    let mut prev = ' ';
                    for rest in chars.by_ref() {
                        if prev == '*' && rest == '/' {
                            break;
                        }
                        if rest == '\n' {
                            out.push('\n');
                        }
                        prev = rest;
                    }
                }
                _ => out.push(c),
            },
            _ => out.push(c),
        }
    }
    out
}
279
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::Path;

    #[test]
    fn test_strip_line_comments() {
        let input = r#"{
  // this is a comment
  "key": "value"
}"#;
        let stripped = strip_json_comments(input);
        let parsed: serde_json::Value = serde_json::from_str(&stripped).unwrap();
        assert_eq!(parsed["key"], "value");
    }

    #[test]
    fn test_strip_block_comments() {
        let input = r#"{ /* block */ "key": "value" }"#;
        let stripped = strip_json_comments(input);
        let parsed: serde_json::Value = serde_json::from_str(&stripped).unwrap();
        assert_eq!(parsed["key"], "value");
    }

    #[test]
    fn test_preserves_strings_with_slashes() {
        let input = r#"{ "url": "https://example.com" }"#;
        let stripped = strip_json_comments(input);
        let parsed: serde_json::Value = serde_json::from_str(&stripped).unwrap();
        assert_eq!(parsed["url"], "https://example.com");
    }

    #[test]
    fn test_commented_out_keys_stripped() {
        let input = r#"{
  // "preferred-ai-cli": "claude",
  "claude": { "model": "opus" }
}"#;
        let stripped = strip_json_comments(input);
        let parsed: serde_json::Value = serde_json::from_str(&stripped).unwrap();
        assert!(parsed.get("preferred-ai-cli").is_none());
        assert_eq!(parsed["claude"]["model"], "opus");
    }

    #[test]
    fn test_model_tier_mapping() {
        assert_eq!(model_tier("haiku"), ModelTier::Fast);
        assert_eq!(model_tier("gemini-flash"), ModelTier::Fast);
        assert_eq!(model_tier("gpt-4o-mini"), ModelTier::Fast);
        assert_eq!(model_tier("sonnet"), ModelTier::Balanced);
        assert_eq!(model_tier("gemini-pro"), ModelTier::Balanced);
        assert_eq!(model_tier("opus"), ModelTier::Power);
    }

    #[test]
    fn test_resolve_claude_model() {
        assert_eq!(resolve_model_for_claude(Some("gemini-flash")), "haiku");
        assert_eq!(resolve_model_for_claude(Some("sonnet")), "sonnet");
        assert_eq!(resolve_model_for_claude(Some("opus")), "opus");
        assert_eq!(resolve_model_for_claude(Some("gemini-pro")), "sonnet");
        assert_eq!(resolve_model_for_claude(None), "sonnet");
    }

    #[test]
    fn test_resolve_copilot_model() {
        assert_eq!(resolve_model_for_copilot(Some("gemini-flash")), "gemini-flash");
        assert_eq!(resolve_model_for_copilot(Some("sonnet")), "sonnet");
        assert_eq!(resolve_model_for_copilot(Some("haiku")), "haiku");
        assert_eq!(resolve_model_for_copilot(None), "sonnet");
    }

    #[test]
    fn test_default_config_parses() {
        let stripped = strip_json_comments(DEFAULT_CONFIG);
        let raw: RawConfig = serde_json::from_str(&stripped).unwrap();
        assert!(raw.preferred_ai_cli.is_none());
        assert_eq!(raw.claude.model.as_deref(), Some("sonnet"));
        assert_eq!(raw.copilot.model.as_deref(), Some("sonnet"));
    }

    #[test]
    fn test_config_path_returns_option_not_cwd() {
        // config_path() should return Some with a path under home, never "."
        let path = config_path();
        match path {
            Some(p) => {
                let path_str = p.to_string_lossy();
                assert!(
                    !path_str.starts_with("./"),
                    "config_path should not fall back to cwd, got: {}",
                    path_str
                );
                // Compare path components rather than a '/'-joined string so the
                // check also holds on Windows, where the separator is '\'.
                assert!(
                    p.ends_with(Path::new(".config").join("semantic-diff.json")),
                    "config_path should end with .config/semantic-diff.json, got: {}",
                    path_str
                );
            }
            None => {
                // None is acceptable if HOME is not set
            }
        }
    }

    #[test]
    fn test_config_path_no_dot_fallback() {
        // Verify config_path never returns a path starting with "."
        let path = config_path();
        if let Some(p) = path {
            assert_ne!(
                p.components().next().map(|c| c.as_os_str().to_string_lossy().to_string()),
                Some(".".to_string()),
                "config_path must not use '.' as base directory"
            );
        }
    }
}