//! auto_commit_rs — `provider.rs`: built-in LLM provider definitions and
//! the request/response plumbing used to generate commit messages.

1use anyhow::{bail, Context, Result};
2use colored::Colorize;
3use indicatif::{ProgressBar, ProgressStyle};
4use serde_json::Value;
5use std::time::Duration;
6
7use crate::config::AppConfig;
8use crate::interpolation::interpolate;
9
/// Wire format used when talking to an LLM HTTP API.
///
/// Determines both the JSON request body shape (`build_request_body`) and
/// how the reply is parsed (`extract_message`).
#[derive(Debug, Clone, Copy, PartialEq)]
enum RequestFormat {
    /// Google Gemini `generateContent` shape (model and key live in the URL).
    Gemini,
    /// OpenAI-style `/chat/completions` shape, shared by most providers.
    OpenAiCompat,
    /// Anthropic `/v1/messages` shape (system prompt is a top-level field).
    Anthropic,
    /// LM Studio local `/api/v1/chat` shape (typed output-item list).
    LmStudio,
}
17
/// Static description of a built-in provider.
struct ProviderDef {
    /// Endpoint URL; may contain `$ACR_*` placeholders interpolated at call time.
    api_url: &'static str,
    /// Comma-separated `Key: Value` header template; may contain `$ACR_*` placeholders.
    api_headers: &'static str,
    /// Model used when the config does not specify one.
    default_model: &'static str,
    /// Request/response wire format for this provider.
    format: RequestFormat,
    /// Dot-separated JSON path to the generated text in the response.
    response_path: &'static str,
}
25
26/// Built-in provider definitions
27fn get_provider(name: &str) -> Option<ProviderDef> {
28    match name {
29        "gemini" => Some(ProviderDef {
30            api_url: "https://generativelanguage.googleapis.com/v1beta/models/$ACR_MODEL:generateContent?key=$ACR_API_KEY",
31            api_headers: "",
32            default_model: "gemini-2.0-flash",
33            format: RequestFormat::Gemini,
34            response_path: "candidates.0.content.parts.0.text",
35        }),
36        "openai" => Some(ProviderDef {
37            api_url: "https://api.openai.com/v1/chat/completions",
38            api_headers: "Authorization: Bearer $ACR_API_KEY",
39            default_model: "gpt-4o-mini",
40            format: RequestFormat::OpenAiCompat,
41            response_path: "choices.0.message.content",
42        }),
43        "anthropic" => Some(ProviderDef {
44            api_url: "https://api.anthropic.com/v1/messages",
45            api_headers: "x-api-key: $ACR_API_KEY, anthropic-version: 2023-06-01",
46            default_model: "claude-sonnet-4-20250514",
47            format: RequestFormat::Anthropic,
48            response_path: "content.0.text",
49        }),
50        "groq" => Some(ProviderDef {
51            api_url: "https://api.groq.com/openai/v1/chat/completions",
52            api_headers: "Authorization: Bearer $ACR_API_KEY",
53            default_model: "llama-3.3-70b-versatile",
54            format: RequestFormat::OpenAiCompat,
55            response_path: "choices.0.message.content",
56        }),
57        "grok" => Some(ProviderDef {
58            api_url: "https://api.x.ai/v1/chat/completions",
59            api_headers: "Authorization: Bearer $ACR_API_KEY",
60            default_model: "grok-3",
61            format: RequestFormat::OpenAiCompat,
62            response_path: "choices.0.message.content",
63        }),
64        "deepseek" => Some(ProviderDef {
65            api_url: "https://api.deepseek.com/v1/chat/completions",
66            api_headers: "Authorization: Bearer $ACR_API_KEY",
67            default_model: "deepseek-chat",
68            format: RequestFormat::OpenAiCompat,
69            response_path: "choices.0.message.content",
70        }),
71        "openrouter" => Some(ProviderDef {
72            api_url: "https://openrouter.ai/api/v1/chat/completions",
73            api_headers: "Authorization: Bearer $ACR_API_KEY",
74            default_model: "openai/gpt-4o-mini",
75            format: RequestFormat::OpenAiCompat,
76            response_path: "choices.0.message.content",
77        }),
78        "mistral" => Some(ProviderDef {
79            api_url: "https://api.mistral.ai/v1/chat/completions",
80            api_headers: "Authorization: Bearer $ACR_API_KEY",
81            default_model: "mistral-small-latest",
82            format: RequestFormat::OpenAiCompat,
83            response_path: "choices.0.message.content",
84        }),
85        "together" => Some(ProviderDef {
86            api_url: "https://api.together.xyz/v1/chat/completions",
87            api_headers: "Authorization: Bearer $ACR_API_KEY",
88            default_model: "meta-llama/Llama-3.3-70B-Instruct-Turbo",
89            format: RequestFormat::OpenAiCompat,
90            response_path: "choices.0.message.content",
91        }),
92        "fireworks" => Some(ProviderDef {
93            api_url: "https://api.fireworks.ai/inference/v1/chat/completions",
94            api_headers: "Authorization: Bearer $ACR_API_KEY",
95            default_model: "accounts/fireworks/models/llama-v3p3-70b-instruct",
96            format: RequestFormat::OpenAiCompat,
97            response_path: "choices.0.message.content",
98        }),
99        "perplexity" => Some(ProviderDef {
100            api_url: "https://api.perplexity.ai/chat/completions",
101            api_headers: "Authorization: Bearer $ACR_API_KEY",
102            default_model: "sonar",
103            format: RequestFormat::OpenAiCompat,
104            response_path: "choices.0.message.content",
105        }),
106        "lm_studio" => Some(ProviderDef {
107            api_url: "http://localhost:1234/api/v1/chat",
108            api_headers: "Content-Type: application/json",
109            default_model: "qwen/qwen3.5-35b-a3b",
110            format: RequestFormat::LmStudio,
111            response_path: "output",
112        }),
113        _ => None,
114    }
115}
116
117/// Get the default model for a built-in provider, or empty string for unknown providers.
118pub fn default_model_for(provider: &str) -> &'static str {
119    get_provider(provider).map_or("", |p| p.default_model)
120}
121
/// Error from a single LLM call attempt.
///
/// Variants are separated so fallback logic can distinguish HTTP-level
/// failures (eligible for provider fallback) from transport failures
/// (the network itself is down, so fallback is aborted).
pub enum LlmCallError {
    /// The API responded with a non-success HTTP status.
    HttpError { code: u16, body: String },
    /// The request never completed (DNS/connect/TLS/timeout, etc.).
    TransportError(String),
    /// Any other failure: provider resolution, JSON parse, path extraction.
    Other(anyhow::Error),
}
127
128impl std::fmt::Display for LlmCallError {
129    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
130        match self {
131            LlmCallError::HttpError { code, body } => {
132                write!(f, "API returned HTTP {code}: {body}")
133            }
134            LlmCallError::TransportError(msg) => write!(f, "Network error: {msg}"),
135            LlmCallError::Other(e) => write!(f, "{e}"),
136        }
137    }
138}
139
/// Perform one LLM API call with the given config, without fallback.
///
/// Resolves the provider (built-in or custom), interpolates `$ACR_*`
/// placeholders into the URL and headers, shows a spinner while the
/// request is in flight, then extracts the commit message from the JSON
/// response via the provider's response path.
fn call_llm_inner(
    cfg: &AppConfig,
    system_prompt: &str,
    diff: &str,
) -> Result<String, LlmCallError> {
    let (url, headers_raw, format, response_path) =
        resolve_provider(cfg).map_err(LlmCallError::Other)?;

    // Expand $ACR_MODEL / $ACR_API_KEY etc. from the config.
    let url = interpolate(&url, cfg);
    let headers_raw = interpolate(&headers_raw, cfg);

    let body = build_request_body(format, &cfg.model, system_prompt, diff);
    let headers = parse_headers(&headers_raw);

    // Spinner provides feedback during the (potentially slow) network call.
    let spinner = ProgressBar::new_spinner();
    spinner.set_style(
        ProgressStyle::default_spinner()
            .template("{spinner:.cyan} {msg} {elapsed}")
            .unwrap(),
    );
    spinner.set_message("Generating commit message...");
    spinner.enable_steady_tick(Duration::from_millis(80));

    let mut req = ureq::post(&url);
    for (key, val) in &headers {
        req = req.set(key, val);
    }
    // Applied after the user's headers; NOTE(review): assumes ureq `set`
    // replaces an existing Content-Type rather than duplicating — confirm.
    req = req.set("Content-Type", "application/json");

    let response = req.send_json(&body);

    // Clear the spinner before any error is printed.
    spinner.finish_and_clear();

    // Split HTTP-status errors from transport errors so the caller can
    // decide whether provider fallback applies (HTTP errors only).
    let response = match response {
        Ok(resp) => resp,
        Err(ureq::Error::Status(code, resp)) => {
            let body = resp.into_string().unwrap_or_default();
            return Err(LlmCallError::HttpError { code, body });
        }
        Err(ureq::Error::Transport(t)) => {
            return Err(LlmCallError::TransportError(t.to_string()));
        }
    };

    let json: Value = response.into_json().map_err(|e| {
        LlmCallError::Other(anyhow::anyhow!("Failed to parse API response as JSON: {e}"))
    })?;

    // Include the full (pretty-printed) response in the error to make
    // provider-format mismatches diagnosable.
    let message = extract_message(&json, format, &response_path).map_err(|e| {
        LlmCallError::Other(anyhow::anyhow!(
            "Failed to extract message from response at path '{}'. Response:\n{}\nError: {}",
            response_path,
            serde_json::to_string_pretty(&json).unwrap_or_default(),
            e
        ))
    })?;

    Ok(message)
}
199
/// Call LLM with fallback support. Returns (message, fallback_preset_name_if_used).
///
/// On an HTTP error from the primary provider, walks the configured
/// fallback preset order and retries with each preset whose connection
/// fields differ from the current config. Transport errors abort
/// immediately (the network is unreachable, so other remote providers
/// would fail too).
pub fn call_llm_with_fallback(
    cfg: &AppConfig,
    system_prompt: &str,
    diff: &str,
) -> Result<(String, Option<String>)> {
    match call_llm_inner(cfg, system_prompt, diff) {
        Ok(msg) => Ok((msg, None)),
        Err(LlmCallError::TransportError(msg)) => {
            anyhow::bail!("Network error: {msg}");
        }
        Err(LlmCallError::HttpError { code, body }) => {
            // Fallback requires: the feature enabled, a loadable preset
            // file, and a non-empty fallback order. Otherwise surface the
            // original HTTP error unchanged.
            if !cfg.fallback_enabled {
                anyhow::bail!("API returned HTTP {code}: {body}");
            }

            let presets_file = match crate::preset::load_presets() {
                Ok(f) => f,
                Err(_) => anyhow::bail!("API returned HTTP {code}: {body}"),
            };

            if presets_file.fallback.order.is_empty() {
                anyhow::bail!("API returned HTTP {code}: {body}");
            }

            let current_fields = crate::preset::fields_from_config(cfg);
            // Per-attempt failure summaries, reported if every provider fails.
            let mut errors = vec![format!("Primary (HTTP {code})")];

            for &preset_id in &presets_file.fallback.order {
                // Silently skip ids that no longer resolve to a preset.
                let preset = match presets_file.presets.iter().find(|p| p.id == preset_id) {
                    Some(p) => p,
                    None => continue,
                };

                // Skip if this preset matches current config (dedup key comparison)
                if preset.fields.provider == current_fields.provider
                    && preset.fields.model == current_fields.model
                    && preset.fields.api_key == current_fields.api_key
                    && preset.fields.api_url == current_fields.api_url
                {
                    continue;
                }

                eprintln!(
                    "{} Primary failed (HTTP {}), trying: {}...",
                    "fallback:".yellow().bold(),
                    code,
                    preset.name
                );

                // Apply the preset to a scratch copy so the caller's config
                // is left untouched.
                let mut temp_cfg = cfg.clone();
                crate::preset::apply_preset_to_config(&mut temp_cfg, preset);

                match call_llm_inner(&temp_cfg, system_prompt, diff) {
                    Ok(msg) => return Ok((msg, Some(preset.name.clone()))),
                    Err(LlmCallError::HttpError { code: fc, .. }) => {
                        errors.push(format!("{} (HTTP {fc})", preset.name));
                        continue;
                    }
                    // A transport failure mid-chain aborts all fallback too.
                    Err(LlmCallError::TransportError(msg)) => {
                        anyhow::bail!("Network error during fallback to '{}': {msg}", preset.name);
                    }
                    Err(LlmCallError::Other(e)) => {
                        errors.push(format!("{} ({})", preset.name, e));
                        continue;
                    }
                }
            }

            anyhow::bail!("All LLM providers failed: {}", errors.join(", "));
        }
        Err(LlmCallError::Other(e)) => {
            anyhow::bail!("{e}");
        }
    }
}
276
277/// Call the LLM API and return the generated commit message
278pub fn call_llm(cfg: &AppConfig, system_prompt: &str, diff: &str) -> Result<String> {
279    let (msg, _) = call_llm_with_fallback(cfg, system_prompt, diff)?;
280    Ok(msg)
281}
282
283fn resolve_provider(cfg: &AppConfig) -> Result<(String, String, RequestFormat, String)> {
284    if let Some(def) = get_provider(&cfg.provider) {
285        let url = if cfg.api_url.is_empty() {
286            def.api_url.to_string()
287        } else {
288            cfg.api_url.clone()
289        };
290        let headers = if cfg.api_headers.is_empty() {
291            def.api_headers.to_string()
292        } else {
293            cfg.api_headers.clone()
294        };
295        Ok((url, headers, def.format, def.response_path.to_string()))
296    } else {
297        // Custom provider: require API URL, default to OpenAI-compatible format
298        if cfg.api_url.is_empty() {
299            bail!(
300                "Unknown provider '{}'. Set {} for custom providers.",
301                cfg.provider.yellow(),
302                "ACR_API_URL".yellow()
303            );
304        }
305        Ok((
306            cfg.api_url.clone(),
307            cfg.api_headers.clone(),
308            RequestFormat::OpenAiCompat,
309            "choices.0.message.content".to_string(),
310        ))
311    }
312}
313
314fn build_request_body(
315    format: RequestFormat,
316    model: &str,
317    system_prompt: &str,
318    diff: &str,
319) -> Value {
320    match format {
321        RequestFormat::Gemini => {
322            serde_json::json!({
323                "system_instruction": {
324                    "parts": [{ "text": system_prompt }]
325                },
326                "contents": [{
327                    "role": "user",
328                    "parts": [{ "text": diff }]
329                }],
330                "generationConfig": {
331                    "temperature": 0
332                }
333            })
334        }
335        RequestFormat::OpenAiCompat => {
336            serde_json::json!({
337                "model": model,
338                "messages": [
339                    { "role": "system", "content": system_prompt },
340                    { "role": "user", "content": diff }
341                ],
342                "max_tokens": 512,
343                "temperature": 0
344            })
345        }
346        RequestFormat::Anthropic => {
347            serde_json::json!({
348                "model": model,
349                "system": system_prompt,
350                "messages": [
351                    { "role": "user", "content": diff }
352                ],
353                "max_tokens": 512
354            })
355        }
356        RequestFormat::LmStudio => {
357            serde_json::json!({
358                "model": model,
359                "input": diff
360            })
361        }
362    }
363}
364
/// Parse "Key: Value, Key2: Value2" header string into pairs
///
/// Entries without a colon are silently skipped; keys and values are
/// trimmed. Note: values may not contain commas, since `,` is the
/// entry separator.
fn parse_headers(raw: &str) -> Vec<(String, String)> {
    let mut pairs = Vec::new();
    for entry in raw.split(',') {
        // split_once uses the first ':', so values like "Bearer a:b" keep
        // everything after the first colon.
        if let Some((key, value)) = entry.trim().split_once(':') {
            pairs.push((key.trim().to_string(), value.trim().to_string()));
        }
    }
    pairs
}
378
379/// Walk a JSON value by a dot-separated path like "candidates.0.content.parts.0.text"
380fn extract_by_path(value: &Value, path: &str) -> Result<String> {
381    let mut current = value;
382    for segment in path.split('.') {
383        current = if let Ok(index) = segment.parse::<usize>() {
384            current
385                .get(index)
386                .with_context(|| format!("Array index {index} out of bounds"))?
387        } else {
388            current
389                .get(segment)
390                .with_context(|| format!("Key '{segment}' not found"))?
391        };
392    }
393    current
394        .as_str()
395        .map(|s| s.to_string())
396        .with_context(|| "Expected string value at path end".to_string())
397}
398
399fn extract_message(value: &Value, format: RequestFormat, response_path: &str) -> Result<String> {
400    match format {
401        RequestFormat::LmStudio => {
402            let output = value
403                .get(response_path)
404                .and_then(Value::as_array)
405                .with_context(|| format!("Key '{response_path}' not found or is not an array"))?;
406
407            let message = output
408                .iter()
409                .find(|item| item.get("type").and_then(Value::as_str) == Some("message"))
410                .with_context(|| "No output item with type 'message' found".to_string())?;
411
412            message
413                .get("content")
414                .and_then(Value::as_str)
415                .map(str::to_string)
416                .with_context(|| "Expected string 'content' in message output item".to_string())
417        }
418        _ => extract_by_path(value, response_path),
419    }
420}
421
#[cfg(test)]
mod tests {
    use super::*;

    // ---- parse_headers: "Key: Value, ..." template parsing ----

    #[test]
    fn test_parse_headers_empty() {
        assert!(parse_headers("").is_empty());
        assert!(parse_headers("   ").is_empty());
    }

    #[test]
    fn test_parse_headers_single() {
        let headers = parse_headers("Authorization: Bearer abc123");
        assert_eq!(headers.len(), 1);
        assert_eq!(headers[0].0, "Authorization");
        assert_eq!(headers[0].1, "Bearer abc123");
    }

    #[test]
    fn test_parse_headers_multiple() {
        let headers = parse_headers("X-Api-Key: key123, Content-Type: application/json");
        assert_eq!(headers.len(), 2);
        assert_eq!(headers[0].0, "X-Api-Key");
        assert_eq!(headers[0].1, "key123");
        assert_eq!(headers[1].0, "Content-Type");
        assert_eq!(headers[1].1, "application/json");
    }

    #[test]
    fn test_parse_headers_trims_whitespace() {
        let headers = parse_headers("  Key  :  Value  ");
        assert_eq!(headers.len(), 1);
        assert_eq!(headers[0].0, "Key");
        assert_eq!(headers[0].1, "Value");
    }

    #[test]
    fn test_parse_headers_skips_invalid() {
        // Entries without a colon are dropped, not errors.
        let headers = parse_headers("Valid: Header, InvalidNoColon, Another: One");
        assert_eq!(headers.len(), 2);
        assert_eq!(headers[0].0, "Valid");
        assert_eq!(headers[1].0, "Another");
    }

    // ---- extract_by_path: dot-path JSON traversal ----

    #[test]
    fn test_extract_by_path_simple() {
        let json = serde_json::json!({"message": "hello"});
        let result = extract_by_path(&json, "message").unwrap();
        assert_eq!(result, "hello");
    }

    #[test]
    fn test_extract_by_path_nested() {
        let json = serde_json::json!({"content": {"text": "nested"}});
        let result = extract_by_path(&json, "content.text").unwrap();
        assert_eq!(result, "nested");
    }

    #[test]
    fn test_extract_by_path_array_index() {
        let json = serde_json::json!({"items": ["first", "second"]});
        let result = extract_by_path(&json, "items.0").unwrap();
        assert_eq!(result, "first");
    }

    #[test]
    fn test_extract_by_path_complex() {
        let json = serde_json::json!({
            "choices": [{"message": {"content": "generated"}}]
        });
        let result = extract_by_path(&json, "choices.0.message.content").unwrap();
        assert_eq!(result, "generated");
    }

    #[test]
    fn test_extract_by_path_gemini_format() {
        let json = serde_json::json!({
            "candidates": [{"content": {"parts": [{"text": "gemini response"}]}}]
        });
        let result = extract_by_path(&json, "candidates.0.content.parts.0.text").unwrap();
        assert_eq!(result, "gemini response");
    }

    #[test]
    fn test_extract_by_path_anthropic_format() {
        let json = serde_json::json!({
            "content": [{"text": "anthropic response"}]
        });
        let result = extract_by_path(&json, "content.0.text").unwrap();
        assert_eq!(result, "anthropic response");
    }

    #[test]
    fn test_extract_by_path_key_not_found() {
        let json = serde_json::json!({"foo": "bar"});
        let result = extract_by_path(&json, "missing");
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("not found"));
    }

    #[test]
    fn test_extract_by_path_index_out_of_bounds() {
        let json = serde_json::json!({"items": ["only"]});
        let result = extract_by_path(&json, "items.5");
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("out of bounds"));
    }

    #[test]
    fn test_extract_by_path_not_string() {
        // Path must terminate on a JSON string, not a number/object/array.
        let json = serde_json::json!({"number": 42});
        let result = extract_by_path(&json, "number");
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Expected string"));
    }

    // ---- build_request_body: per-format JSON payload shapes ----

    #[test]
    fn test_build_request_body_openai_compat() {
        let body = build_request_body(
            RequestFormat::OpenAiCompat,
            "gpt-4o",
            "system prompt",
            "user diff",
        );
        assert_eq!(body["model"], "gpt-4o");
        assert_eq!(body["messages"][0]["role"], "system");
        assert_eq!(body["messages"][0]["content"], "system prompt");
        assert_eq!(body["messages"][1]["role"], "user");
        assert_eq!(body["messages"][1]["content"], "user diff");
        assert_eq!(body["max_tokens"], 512);
        assert_eq!(body["temperature"], 0);
    }

    #[test]
    fn test_build_request_body_gemini() {
        let body = build_request_body(
            RequestFormat::Gemini,
            "gemini-pro",
            "system prompt",
            "user diff",
        );
        assert_eq!(
            body["system_instruction"]["parts"][0]["text"],
            "system prompt"
        );
        assert_eq!(body["contents"][0]["role"], "user");
        assert_eq!(body["contents"][0]["parts"][0]["text"], "user diff");
        assert_eq!(body["generationConfig"]["temperature"], 0);
    }

    #[test]
    fn test_build_request_body_anthropic() {
        let body = build_request_body(
            RequestFormat::Anthropic,
            "claude-3-opus",
            "system prompt",
            "user diff",
        );
        assert_eq!(body["model"], "claude-3-opus");
        assert_eq!(body["system"], "system prompt");
        assert_eq!(body["messages"][0]["role"], "user");
        assert_eq!(body["messages"][0]["content"], "user diff");
        assert_eq!(body["max_tokens"], 512);
    }

    #[test]
    fn test_build_request_body_lm_studio() {
        // LM Studio payload intentionally has no "messages" array.
        let body = build_request_body(
            RequestFormat::LmStudio,
            "qwen/qwen3.5-35b-a3b",
            "system prompt",
            "user diff",
        );
        assert_eq!(body["model"], "qwen/qwen3.5-35b-a3b");
        assert_eq!(body["input"], "user diff");
        assert!(body.get("messages").is_none());
    }

    // ---- get_provider: built-in provider table ----

    #[test]
    fn test_get_provider_known() {
        assert!(get_provider("gemini").is_some());
        assert!(get_provider("openai").is_some());
        assert!(get_provider("anthropic").is_some());
        assert!(get_provider("groq").is_some());
        assert!(get_provider("grok").is_some());
        assert!(get_provider("deepseek").is_some());
        assert!(get_provider("openrouter").is_some());
        assert!(get_provider("mistral").is_some());
        assert!(get_provider("together").is_some());
        assert!(get_provider("fireworks").is_some());
        assert!(get_provider("perplexity").is_some());
        assert!(get_provider("lm_studio").is_some());
    }

    #[test]
    fn test_get_provider_unknown() {
        assert!(get_provider("unknown").is_none());
        assert!(get_provider("custom").is_none());
    }

    #[test]
    fn test_get_provider_gemini_format() {
        let provider = get_provider("gemini").unwrap();
        assert_eq!(provider.format, RequestFormat::Gemini);
        assert!(provider
            .api_url
            .contains("generativelanguage.googleapis.com"));
        assert_eq!(provider.default_model, "gemini-2.0-flash");
    }

    #[test]
    fn test_get_provider_anthropic_format() {
        let provider = get_provider("anthropic").unwrap();
        assert_eq!(provider.format, RequestFormat::Anthropic);
        assert!(provider.api_url.contains("anthropic.com"));
        assert!(provider.api_headers.contains("anthropic-version"));
    }

    #[test]
    fn test_get_provider_openai_compat() {
        // All of these share the OpenAI-compatible wire format.
        for name in &[
            "openai",
            "groq",
            "grok",
            "deepseek",
            "openrouter",
            "mistral",
            "together",
            "fireworks",
            "perplexity",
        ] {
            let provider = get_provider(name).unwrap();
            assert_eq!(
                provider.format,
                RequestFormat::OpenAiCompat,
                "Provider {name} should use OpenAiCompat format"
            );
        }
    }

    #[test]
    fn test_get_provider_lm_studio_format() {
        let provider = get_provider("lm_studio").unwrap();
        assert_eq!(provider.format, RequestFormat::LmStudio);
        assert_eq!(provider.api_url, "http://localhost:1234/api/v1/chat");
        assert_eq!(provider.api_headers, "Content-Type: application/json");
        assert_eq!(provider.default_model, "qwen/qwen3.5-35b-a3b");
    }

    // ---- default_model_for ----

    #[test]
    fn test_default_model_for_known() {
        assert_eq!(default_model_for("groq"), "llama-3.3-70b-versatile");
        assert_eq!(default_model_for("openai"), "gpt-4o-mini");
        assert_eq!(default_model_for("anthropic"), "claude-sonnet-4-20250514");
        assert_eq!(default_model_for("lm_studio"), "qwen/qwen3.5-35b-a3b");
    }

    #[test]
    fn test_default_model_for_unknown() {
        assert_eq!(default_model_for("custom"), "");
        assert_eq!(default_model_for("unknown"), "");
    }

    // ---- resolve_provider: config overrides and custom providers ----

    #[test]
    fn test_resolve_provider_known() {
        let cfg = AppConfig {
            provider: "groq".into(),
            api_key: "test-key".into(),
            ..Default::default()
        };
        let (url, headers, format, path) = resolve_provider(&cfg).unwrap();
        assert!(url.contains("groq.com"));
        assert!(headers.contains("Bearer"));
        assert_eq!(format, RequestFormat::OpenAiCompat);
        assert_eq!(path, "choices.0.message.content");
    }

    #[test]
    fn test_resolve_provider_known_with_override() {
        // Non-empty config values replace the built-in URL/headers.
        let cfg = AppConfig {
            provider: "groq".into(),
            api_url: "https://custom.url/v1".into(),
            api_headers: "X-Custom: value".into(),
            ..Default::default()
        };
        let (url, headers, _, _) = resolve_provider(&cfg).unwrap();
        assert_eq!(url, "https://custom.url/v1");
        assert_eq!(headers, "X-Custom: value");
    }

    #[test]
    fn test_resolve_provider_custom_requires_url() {
        let cfg = AppConfig {
            provider: "custom-provider".into(),
            api_url: "".into(),
            ..Default::default()
        };
        let result = resolve_provider(&cfg);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Unknown provider"));
    }

    #[test]
    fn test_resolve_provider_custom_with_url() {
        // Custom providers default to the OpenAI-compatible format.
        let cfg = AppConfig {
            provider: "custom-provider".into(),
            api_url: "https://my-custom-api.com/v1".into(),
            api_headers: "Authorization: custom".into(),
            ..Default::default()
        };
        let (url, headers, format, path) = resolve_provider(&cfg).unwrap();
        assert_eq!(url, "https://my-custom-api.com/v1");
        assert_eq!(headers, "Authorization: custom");
        assert_eq!(format, RequestFormat::OpenAiCompat);
        assert_eq!(path, "choices.0.message.content");
    }

    // ---- LlmCallError Display ----

    #[test]
    fn test_llm_call_error_display_http() {
        let err = LlmCallError::HttpError {
            code: 401,
            body: "Unauthorized".into(),
        };
        let display = format!("{err}");
        assert!(display.contains("HTTP 401"));
        assert!(display.contains("Unauthorized"));
    }

    #[test]
    fn test_llm_call_error_display_transport() {
        let err = LlmCallError::TransportError("connection refused".into());
        let display = format!("{err}");
        assert!(display.contains("Network error"));
        assert!(display.contains("connection refused"));
    }

    #[test]
    fn test_llm_call_error_display_other() {
        let err = LlmCallError::Other(anyhow::anyhow!("custom error"));
        let display = format!("{err}");
        assert!(display.contains("custom error"));
    }

    #[test]
    fn test_request_format_equality() {
        assert_eq!(RequestFormat::Gemini, RequestFormat::Gemini);
        assert_eq!(RequestFormat::OpenAiCompat, RequestFormat::OpenAiCompat);
        assert_eq!(RequestFormat::Anthropic, RequestFormat::Anthropic);
        assert_eq!(RequestFormat::LmStudio, RequestFormat::LmStudio);
        assert_ne!(RequestFormat::Gemini, RequestFormat::OpenAiCompat);
    }

    // ---- extract_message: LM Studio typed output items ----

    #[test]
    fn test_extract_message_lm_studio_message_item() {
        // Reasoning items are skipped; the "message" item wins.
        let json = serde_json::json!({
            "output": [
                { "type": "reasoning", "content": "thinking" },
                { "type": "message", "content": "feat: lm studio response" }
            ]
        });
        let result = extract_message(&json, RequestFormat::LmStudio, "output").unwrap();
        assert_eq!(result, "feat: lm studio response");
    }

    #[test]
    fn test_extract_message_lm_studio_missing_message_item() {
        let json = serde_json::json!({
            "output": [
                { "type": "reasoning", "content": "thinking only" }
            ]
        });
        let result = extract_message(&json, RequestFormat::LmStudio, "output");
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("type 'message'"));
    }
}