1use crate::grouper::llm::LlmBackend;
2use crate::theme::ThemeMode;
3use serde::Deserialize;
4use std::path::PathBuf;
5
/// Fully-resolved runtime configuration consumed by the rest of the app.
///
/// Produced by `load()` from the on-disk `RawConfig`, or by
/// `Config::default_config()` when no config file is available.
#[derive(Debug, Clone)]
pub struct Config {
    /// Which AI CLI the user prefers; `None` means default order (claude, then copilot).
    pub preferred_ai_cli: Option<AiCli>,
    /// Model name passed to the Claude CLI (already normalized by `resolve_model_for_claude`).
    pub claude_model: String,
    /// Model name passed to the Copilot CLI (already normalized by `resolve_model_for_copilot`).
    pub copilot_model: String,
    /// Color theme selection (light / dark / auto).
    pub theme_mode: ThemeMode,
}
14
/// The AI CLI tools this app can drive, as spelled in the config file
/// (serde lowercase: "claude" / "copilot").
#[derive(Debug, Clone, Copy, PartialEq, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum AiCli {
    Claude,
    Copilot,
}
21
/// Direct deserialization target for the JSONC config file.
///
/// Every field falls back via `#[serde(default)]`, so a partial or empty
/// config object still parses.
#[derive(Debug, Default, Deserialize)]
#[serde(default)]
struct RawConfig {
    /// Maps the kebab-case "preferred-ai-cli" key: "claude" or "copilot".
    #[serde(rename = "preferred-ai-cli")]
    preferred_ai_cli: Option<AiCli>,
    /// Settings under the "claude" object.
    claude: CliConfig,
    /// Settings under the "copilot" object.
    copilot: CliConfig,
    /// Raw theme string; anything other than "light"/"dark" becomes auto.
    theme: Option<String>,
}
32
/// Per-CLI settings block (`"claude": {...}` / `"copilot": {...}`).
#[derive(Debug, Default, Deserialize)]
#[serde(default)]
struct CliConfig {
    /// Model name as written by the user; normalized later by the resolver functions.
    model: Option<String>,
}
38
39
/// Relative capability tier used to translate model names across CLIs.
///
/// Variants are declared cheapest-first, so the derived `Ord` ranks
/// `Fast < Balanced < Power`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
enum ModelTier {
    Fast,
    Balanced,
    Power,
}
47
48impl Config {
49 pub fn default_config() -> Self {
50 Self {
51 preferred_ai_cli: None,
52 claude_model: "sonnet".to_string(),
53 copilot_model: "sonnet".to_string(),
54 theme_mode: ThemeMode::Auto,
55 }
56 }
57
58 pub fn model_for_backend(&self, backend: LlmBackend) -> &str {
60 match backend {
61 LlmBackend::Claude => &self.claude_model,
62 LlmBackend::Copilot => &self.copilot_model,
63 }
64 }
65
66 pub fn detect_backend(&self) -> Option<LlmBackend> {
68 let claude_ok = which::which("claude").is_ok();
69 let copilot_ok = which::which("copilot").is_ok();
70
71 match self.preferred_ai_cli {
72 Some(AiCli::Claude) => {
73 if claude_ok {
74 Some(LlmBackend::Claude)
75 } else if copilot_ok {
76 Some(LlmBackend::Copilot)
77 } else {
78 None
79 }
80 }
81 Some(AiCli::Copilot) => {
82 if copilot_ok {
83 Some(LlmBackend::Copilot)
84 } else if claude_ok {
85 Some(LlmBackend::Claude)
86 } else {
87 None
88 }
89 }
90 None => {
91 if claude_ok {
93 Some(LlmBackend::Claude)
94 } else if copilot_ok {
95 Some(LlmBackend::Copilot)
96 } else {
97 None
98 }
99 }
100 }
101 }
102}
103
104fn config_path() -> Option<PathBuf> {
107 let home = dirs::home_dir()?;
110 Some(home.join(".config").join("semantic-diff.json"))
111}
112
/// Template written to disk on first run. The format is JSONC: `//` comments
/// are allowed and removed by `strip_json_comments` before strict JSON parsing.
/// NOTE: there is deliberately no comma after the "copilot" block — strict JSON
/// forbids trailing commas while "theme" is commented out; a user uncommenting
/// "theme" must add that comma themselves.
const DEFAULT_CONFIG: &str = r#"{
  // Which AI CLI to prefer: "claude" or "copilot"
  // Falls back to the other if the preferred one is not installed.
  // If unset, defaults to: claude > copilot
  // "preferred-ai-cli": "claude",

  // Claude CLI settings
  "claude": {
    // Model to use: "sonnet", "opus", "haiku"
    // Cross-backend models are mapped automatically:
    // gemini-flash -> haiku, gemini-pro -> sonnet
    "model": "sonnet"
  },

  // Copilot CLI settings
  "copilot": {
    // Model to use: "sonnet", "opus", "haiku", "gemini-flash", "gemini-pro"
    "model": "sonnet"
  }

  // Theme: "dark", "light", or "auto" (detects from terminal)
  // "theme": "auto"
}
"#;
138
139pub fn load() -> Config {
142 let path = match config_path() {
143 Some(p) => p,
144 None => {
145 tracing::warn!("Could not determine home directory, using default config");
146 return Config::default_config();
147 }
148 };
149
150 if !path.exists() {
152 if let Some(parent) = path.parent() {
153 let _ = std::fs::create_dir_all(parent);
154 }
155 let _ = std::fs::write(&path, DEFAULT_CONFIG);
156 tracing::info!("Created default config at {}", path.display());
157 }
158
159 let content = match std::fs::read_to_string(&path) {
161 Ok(c) => c,
162 Err(e) => {
163 tracing::warn!("Failed to read config {}: {}", path.display(), e);
164 return Config::default_config();
165 }
166 };
167
168 let stripped = strip_json_comments(&content);
169 let raw: RawConfig = match serde_json::from_str(&stripped) {
170 Ok(r) => r,
171 Err(e) => {
172 tracing::warn!("Failed to parse config {}: {}", path.display(), e);
173 return Config::default_config();
174 }
175 };
176
177 Config {
178 preferred_ai_cli: raw.preferred_ai_cli,
179 claude_model: resolve_model_for_claude(raw.claude.model.as_deref()),
180 copilot_model: resolve_model_for_copilot(raw.copilot.model.as_deref()),
181 theme_mode: match raw.theme.as_deref() {
182 Some("light") => ThemeMode::Light,
183 Some("dark") => ThemeMode::Dark,
184 _ => ThemeMode::Auto,
185 },
186 }
187}
188
189fn resolve_model_for_claude(model: Option<&str>) -> String {
191 let tier = model.map(model_tier).unwrap_or(ModelTier::Balanced);
192 match tier {
193 ModelTier::Fast => "haiku",
194 ModelTier::Balanced => "sonnet",
195 ModelTier::Power => "opus",
196 }
197 .to_string()
198}
199
200fn resolve_model_for_copilot(model: Option<&str>) -> String {
203 match model {
204 Some(m) => {
205 let tier = model_tier(m);
206 match m {
208 "sonnet" | "opus" | "haiku" | "gemini-flash" | "gemini-pro" => m.to_string(),
209 _ => match tier {
211 ModelTier::Fast => "gemini-flash",
212 ModelTier::Balanced => "sonnet",
213 ModelTier::Power => "opus",
214 }
215 .to_string(),
216 }
217 }
218 None => "sonnet".to_string(),
219 }
220}
221
222fn model_tier(name: &str) -> ModelTier {
224 let n = name.to_lowercase();
225 if n.contains("flash") || n.contains("haiku") || n == "gpt-4o-mini" || n.ends_with("-mini") {
226 ModelTier::Fast
227 } else if n.contains("opus") {
228 ModelTier::Power
229 } else {
230 ModelTier::Balanced
232 }
233}
234
/// Removes `//` line comments and `/* */` block comments from JSONC text.
///
/// String literals are copied verbatim (backslash escapes honored, so an
/// escaped quote does not terminate the string). Newlines inside comments
/// are preserved so that parse-error line numbers still point at the
/// original file. An unterminated string or block comment is tolerated:
/// the remainder of the input is consumed without panicking.
fn strip_json_comments(input: &str) -> String {
    let mut result = String::with_capacity(input.len());
    let mut it = input.chars().peekable();

    while let Some(ch) = it.next() {
        match ch {
            // Copy an entire string literal through, escapes and all.
            '"' => {
                result.push(ch);
                while let Some(sc) = it.next() {
                    result.push(sc);
                    match sc {
                        '\\' => {
                            // Keep the escaped character so \" doesn't close the string.
                            if let Some(escaped) = it.next() {
                                result.push(escaped);
                            }
                        }
                        '"' => break,
                        _ => {}
                    }
                }
            }
            // Line comment: discard through end-of-line, keep the newline.
            '/' if it.peek() == Some(&'/') => {
                for c in it.by_ref() {
                    if c == '\n' {
                        result.push('\n');
                        break;
                    }
                }
            }
            // Block comment: discard until `*/`, keeping interior newlines.
            '/' if it.peek() == Some(&'*') => {
                it.next(); // consume the '*'
                let mut last = ' ';
                for c in it.by_ref() {
                    if last == '*' && c == '/' {
                        break;
                    }
                    if c == '\n' {
                        result.push('\n');
                    }
                    last = c;
                }
            }
            // Everything else (including a lone '/') passes through.
            other => result.push(other),
        }
    }
    result
}
291
#[cfg(test)]
mod tests {
    use super::*;

    // Line comments vanish while the JSON structure survives.
    #[test]
    fn test_strip_line_comments() {
        let input = r#"{
    // this is a comment
    "key": "value"
}"#;
        let stripped = strip_json_comments(input);
        let parsed: serde_json::Value = serde_json::from_str(&stripped).unwrap();
        assert_eq!(parsed["key"], "value");
    }

    // Block comments (/* ... */) are removed as well.
    #[test]
    fn test_strip_block_comments() {
        let input = r#"{ /* block */ "key": "value" }"#;
        let stripped = strip_json_comments(input);
        let parsed: serde_json::Value = serde_json::from_str(&stripped).unwrap();
        assert_eq!(parsed["key"], "value");
    }

    // "//" inside a string literal is not a comment and must survive intact.
    #[test]
    fn test_preserves_strings_with_slashes() {
        let input = r#"{ "url": "https://example.com" }"#;
        let stripped = strip_json_comments(input);
        let parsed: serde_json::Value = serde_json::from_str(&stripped).unwrap();
        assert_eq!(parsed["url"], "https://example.com");
    }

    // A commented-out key must not appear in the parsed result.
    #[test]
    fn test_commented_out_keys_stripped() {
        let input = r#"{
    // "preferred-ai-cli": "claude",
    "claude": { "model": "opus" }
}"#;
        let stripped = strip_json_comments(input);
        let parsed: serde_json::Value = serde_json::from_str(&stripped).unwrap();
        assert!(parsed.get("preferred-ai-cli").is_none());
        assert_eq!(parsed["claude"]["model"], "opus");
    }

    // Capability-tier bucketing for known model names across backends.
    #[test]
    fn test_model_tier_mapping() {
        assert_eq!(model_tier("haiku"), ModelTier::Fast);
        assert_eq!(model_tier("gemini-flash"), ModelTier::Fast);
        assert_eq!(model_tier("gpt-4o-mini"), ModelTier::Fast);
        assert_eq!(model_tier("sonnet"), ModelTier::Balanced);
        assert_eq!(model_tier("gemini-pro"), ModelTier::Balanced);
        assert_eq!(model_tier("opus"), ModelTier::Power);
    }

    // Cross-backend names map onto Claude CLI aliases by tier.
    #[test]
    fn test_resolve_claude_model() {
        assert_eq!(resolve_model_for_claude(Some("gemini-flash")), "haiku");
        assert_eq!(resolve_model_for_claude(Some("sonnet")), "sonnet");
        assert_eq!(resolve_model_for_claude(Some("opus")), "opus");
        assert_eq!(resolve_model_for_claude(Some("gemini-pro")), "sonnet");
        assert_eq!(resolve_model_for_claude(None), "sonnet");
    }

    // Names the Copilot CLI understands natively pass through unchanged.
    #[test]
    fn test_resolve_copilot_model() {
        assert_eq!(resolve_model_for_copilot(Some("gemini-flash")), "gemini-flash");
        assert_eq!(resolve_model_for_copilot(Some("sonnet")), "sonnet");
        assert_eq!(resolve_model_for_copilot(Some("haiku")), "haiku");
        assert_eq!(resolve_model_for_copilot(None), "sonnet");
    }

    // The shipped default template must itself parse after comment stripping.
    #[test]
    fn test_default_config_parses() {
        let stripped = strip_json_comments(DEFAULT_CONFIG);
        let raw: RawConfig = serde_json::from_str(&stripped).unwrap();
        assert!(raw.preferred_ai_cli.is_none());
        assert_eq!(raw.claude.model.as_deref(), Some("sonnet"));
        assert_eq!(raw.copilot.model.as_deref(), Some("sonnet"));
    }

    // config_path must be home-based, never a cwd-relative fallback.
    #[test]
    fn test_config_path_returns_option_not_cwd() {
        let path = config_path();
        match path {
            Some(p) => {
                let path_str = p.to_string_lossy();
                assert!(
                    !path_str.starts_with("./"),
                    "config_path should not fall back to cwd, got: {}",
                    path_str
                );
                assert!(
                    path_str.contains(".config/semantic-diff.json"),
                    "config_path should end with .config/semantic-diff.json, got: {}",
                    path_str
                );
            }
            None => {
                // No home directory (e.g. stripped CI env) — nothing to assert.
            }
        }
    }

    // The first path component must never be "." (would mean a cwd fallback).
    #[test]
    fn test_config_path_no_dot_fallback() {
        let path = config_path();
        if let Some(p) = path {
            assert_ne!(
                p.components().next().map(|c| c.as_os_str().to_string_lossy().to_string()),
                Some(".".to_string()),
                "config_path must not use '.' as base directory"
            );
        }
    }
407}