1use crate::grouper::llm::LlmBackend;
2use serde::Deserialize;
3use std::path::PathBuf;
4
/// Resolved application configuration, produced by [`load`] (or
/// `Config::default_config` when no config file is usable).
#[derive(Debug, Clone)]
pub struct Config {
    /// Which AI CLI the user prefers; `None` means the default probe
    /// order of claude > copilot (see `detect_backend`).
    pub preferred_ai_cli: Option<AiCli>,
    /// Model name passed to the Claude CLI (default: "sonnet").
    pub claude_model: String,
    /// Model name passed to the Copilot CLI (default: "sonnet").
    pub copilot_model: String,
}
12
/// The AI CLI tools this module can drive. Deserialized from the
/// lowercase strings `"claude"` / `"copilot"` in the config file.
#[derive(Debug, Clone, Copy, PartialEq, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum AiCli {
    Claude,
    Copilot,
}
19
/// On-disk shape of the JSON config file. Every field is optional:
/// `#[serde(default)]` fills in missing sections.
#[derive(Debug, Default, Deserialize)]
#[serde(default)]
struct RawConfig {
    // Maps from the kebab-case JSON key "preferred-ai-cli".
    #[serde(rename = "preferred-ai-cli")]
    preferred_ai_cli: Option<AiCli>,
    // Settings section for the Claude CLI.
    claude: CliConfig,
    // Settings section for the Copilot CLI.
    copilot: CliConfig,
}
29
/// Per-CLI settings section (the "claude" / "copilot" objects in the file).
#[derive(Debug, Default, Deserialize)]
#[serde(default)]
struct CliConfig {
    // Model name for this CLI; `None` resolves to "sonnet" (see resolve_*).
    model: Option<String>,
}
35
36
/// Coarse capability tier used to translate model names between backends.
/// Ordered from cheapest/fastest to most capable.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
enum ModelTier {
    /// Cheap, low-latency models (haiku, gemini-flash, *-mini).
    Fast,
    /// Default general-purpose models (sonnet, gemini-pro).
    Balanced,
    /// Most capable models (opus).
    Power,
}
44
45impl Config {
46 pub fn default_config() -> Self {
47 Self {
48 preferred_ai_cli: None,
49 claude_model: "sonnet".to_string(),
50 copilot_model: "sonnet".to_string(),
51 }
52 }
53
54 pub fn model_for_backend(&self, backend: LlmBackend) -> &str {
56 match backend {
57 LlmBackend::Claude => &self.claude_model,
58 LlmBackend::Copilot => &self.copilot_model,
59 }
60 }
61
62 pub fn detect_backend(&self) -> Option<LlmBackend> {
64 let claude_ok = which::which("claude").is_ok();
65 let copilot_ok = which::which("copilot").is_ok();
66
67 match self.preferred_ai_cli {
68 Some(AiCli::Claude) => {
69 if claude_ok {
70 Some(LlmBackend::Claude)
71 } else if copilot_ok {
72 Some(LlmBackend::Copilot)
73 } else {
74 None
75 }
76 }
77 Some(AiCli::Copilot) => {
78 if copilot_ok {
79 Some(LlmBackend::Copilot)
80 } else if claude_ok {
81 Some(LlmBackend::Claude)
82 } else {
83 None
84 }
85 }
86 None => {
87 if claude_ok {
89 Some(LlmBackend::Claude)
90 } else if copilot_ok {
91 Some(LlmBackend::Copilot)
92 } else {
93 None
94 }
95 }
96 }
97 }
98}
99
100fn config_path() -> Option<PathBuf> {
103 let home = dirs::home_dir()?;
106 Some(home.join(".config").join("semantic-diff.json"))
107}
108
/// Template written to disk on first run. It deliberately contains `//`
/// comments, which [`strip_json_comments`] removes before parsing.
const DEFAULT_CONFIG: &str = r#"{
  // Which AI CLI to prefer: "claude" or "copilot"
  // Falls back to the other if the preferred one is not installed.
  // If unset, defaults to: claude > copilot
  // "preferred-ai-cli": "claude",

  // Claude CLI settings
  "claude": {
    // Model to use: "sonnet", "opus", "haiku"
    // Cross-backend models are mapped automatically:
    // gemini-flash -> haiku, gemini-pro -> sonnet
    "model": "sonnet"
  },

  // Copilot CLI settings
  "copilot": {
    // Model to use: "sonnet", "opus", "haiku", "gemini-flash", "gemini-pro"
    "model": "sonnet"
  }
}
"#;
131
132pub fn load() -> Config {
135 let path = match config_path() {
136 Some(p) => p,
137 None => {
138 tracing::warn!("Could not determine home directory, using default config");
139 return Config::default_config();
140 }
141 };
142
143 if !path.exists() {
145 if let Some(parent) = path.parent() {
146 let _ = std::fs::create_dir_all(parent);
147 }
148 let _ = std::fs::write(&path, DEFAULT_CONFIG);
149 tracing::info!("Created default config at {}", path.display());
150 }
151
152 let content = match std::fs::read_to_string(&path) {
154 Ok(c) => c,
155 Err(e) => {
156 tracing::warn!("Failed to read config {}: {}", path.display(), e);
157 return Config::default_config();
158 }
159 };
160
161 let stripped = strip_json_comments(&content);
162 let raw: RawConfig = match serde_json::from_str(&stripped) {
163 Ok(r) => r,
164 Err(e) => {
165 tracing::warn!("Failed to parse config {}: {}", path.display(), e);
166 return Config::default_config();
167 }
168 };
169
170 Config {
171 preferred_ai_cli: raw.preferred_ai_cli,
172 claude_model: resolve_model_for_claude(raw.claude.model.as_deref()),
173 copilot_model: resolve_model_for_copilot(raw.copilot.model.as_deref()),
174 }
175}
176
177fn resolve_model_for_claude(model: Option<&str>) -> String {
179 let tier = model.map(model_tier).unwrap_or(ModelTier::Balanced);
180 match tier {
181 ModelTier::Fast => "haiku",
182 ModelTier::Balanced => "sonnet",
183 ModelTier::Power => "opus",
184 }
185 .to_string()
186}
187
188fn resolve_model_for_copilot(model: Option<&str>) -> String {
191 match model {
192 Some(m) => {
193 let tier = model_tier(m);
194 match m {
196 "sonnet" | "opus" | "haiku" | "gemini-flash" | "gemini-pro" => m.to_string(),
197 _ => match tier {
199 ModelTier::Fast => "gemini-flash",
200 ModelTier::Balanced => "sonnet",
201 ModelTier::Power => "opus",
202 }
203 .to_string(),
204 }
205 }
206 None => "sonnet".to_string(),
207 }
208}
209
210fn model_tier(name: &str) -> ModelTier {
212 let n = name.to_lowercase();
213 if n.contains("flash") || n.contains("haiku") || n == "gpt-4o-mini" || n.ends_with("-mini") {
214 ModelTier::Fast
215 } else if n.contains("opus") {
216 ModelTier::Power
217 } else {
218 ModelTier::Balanced
220 }
221}
222
/// Removes `//` line comments and `/* ... */` block comments from JSON-ish
/// text while leaving string literals (including backslash escapes) intact.
/// Newlines inside or after comments are preserved so parse-error line
/// numbers still line up with the original file.
fn strip_json_comments(input: &str) -> String {
    let chars: Vec<char> = input.chars().collect();
    let mut out = String::with_capacity(input.len());
    let mut i = 0;

    while i < chars.len() {
        let c = chars[i];
        if c == '"' {
            // Copy the entire string literal verbatim, honoring escapes so a
            // `\"` or a `//` inside the string never terminates or comments it.
            out.push(c);
            i += 1;
            while i < chars.len() {
                let s = chars[i];
                out.push(s);
                i += 1;
                if s == '\\' {
                    // Copy the escaped character blindly (if any remains).
                    if i < chars.len() {
                        out.push(chars[i]);
                        i += 1;
                    }
                } else if s == '"' {
                    break;
                }
            }
        } else if c == '/' && chars.get(i + 1) == Some(&'/') {
            // Line comment: drop everything up to (but not including) the
            // newline; the outer loop emits the newline itself.
            while i < chars.len() && chars[i] != '\n' {
                i += 1;
            }
        } else if c == '/' && chars.get(i + 1) == Some(&'*') {
            // Block comment: drop everything up to `*/`, keeping embedded
            // newlines. An unterminated comment is silently discarded.
            i += 2;
            while i < chars.len() {
                if chars[i] == '*' && chars.get(i + 1) == Some(&'/') {
                    i += 2;
                    break;
                }
                if chars[i] == '\n' {
                    out.push('\n');
                }
                i += 1;
            }
        } else {
            out.push(c);
            i += 1;
        }
    }
    out
}
279
#[cfg(test)]
mod tests {
    use super::*;

    /// Strips comments then parses the result as JSON, panicking on failure.
    fn parse_stripped(input: &str) -> serde_json::Value {
        serde_json::from_str(&strip_json_comments(input)).unwrap()
    }

    #[test]
    fn test_strip_line_comments() {
        let doc = parse_stripped(
            r#"{
    // this is a comment
    "key": "value"
}"#,
        );
        assert_eq!(doc["key"], "value");
    }

    #[test]
    fn test_strip_block_comments() {
        let doc = parse_stripped(r#"{ /* block */ "key": "value" }"#);
        assert_eq!(doc["key"], "value");
    }

    #[test]
    fn test_preserves_strings_with_slashes() {
        let doc = parse_stripped(r#"{ "url": "https://example.com" }"#);
        assert_eq!(doc["url"], "https://example.com");
    }

    #[test]
    fn test_commented_out_keys_stripped() {
        let doc = parse_stripped(
            r#"{
    // "preferred-ai-cli": "claude",
    "claude": { "model": "opus" }
}"#,
        );
        assert!(doc.get("preferred-ai-cli").is_none());
        assert_eq!(doc["claude"]["model"], "opus");
    }

    #[test]
    fn test_model_tier_mapping() {
        for name in ["haiku", "gemini-flash", "gpt-4o-mini"].iter() {
            assert_eq!(model_tier(name), ModelTier::Fast);
        }
        for name in ["sonnet", "gemini-pro"].iter() {
            assert_eq!(model_tier(name), ModelTier::Balanced);
        }
        assert_eq!(model_tier("opus"), ModelTier::Power);
    }

    #[test]
    fn test_resolve_claude_model() {
        let cases = [
            (Some("gemini-flash"), "haiku"),
            (Some("sonnet"), "sonnet"),
            (Some("opus"), "opus"),
            (Some("gemini-pro"), "sonnet"),
            (None, "sonnet"),
        ];
        for (input, expected) in cases.iter() {
            assert_eq!(resolve_model_for_claude(*input), *expected);
        }
    }

    #[test]
    fn test_resolve_copilot_model() {
        let cases = [
            (Some("gemini-flash"), "gemini-flash"),
            (Some("sonnet"), "sonnet"),
            (Some("haiku"), "haiku"),
            (None, "sonnet"),
        ];
        for (input, expected) in cases.iter() {
            assert_eq!(resolve_model_for_copilot(*input), *expected);
        }
    }

    #[test]
    fn test_default_config_parses() {
        // The shipped template must survive its own comment stripper.
        let raw: RawConfig =
            serde_json::from_str(&strip_json_comments(DEFAULT_CONFIG)).unwrap();
        assert!(raw.preferred_ai_cli.is_none());
        assert_eq!(raw.claude.model.as_deref(), Some("sonnet"));
        assert_eq!(raw.copilot.model.as_deref(), Some("sonnet"));
    }

    #[test]
    fn test_config_path_returns_option_not_cwd() {
        // `None` (no detectable home directory) is acceptable; only a
        // cwd-relative fallback would be a regression.
        if let Some(p) = config_path() {
            let rendered = p.to_string_lossy();
            assert!(
                !rendered.starts_with("./"),
                "config_path should not fall back to cwd, got: {}",
                rendered
            );
            assert!(
                rendered.contains(".config/semantic-diff.json"),
                "config_path should end with .config/semantic-diff.json, got: {}",
                rendered
            );
        }
    }

    #[test]
    fn test_config_path_no_dot_fallback() {
        if let Some(p) = config_path() {
            let first_component = p
                .components()
                .next()
                .map(|c| c.as_os_str().to_string_lossy().to_string());
            assert_ne!(
                first_component,
                Some(".".to_string()),
                "config_path must not use '.' as base directory"
            );
        }
    }
}
395}