1use anyhow::{bail, Context, Result};
2use colored::Colorize;
3use indicatif::{ProgressBar, ProgressStyle};
4use serde_json::Value;
5use std::time::Duration;
6
7use crate::config::AppConfig;
8use crate::interpolation::interpolate;
9
/// Wire format used to build a provider's request body and to read its reply.
///
/// `Eq` is derived alongside `PartialEq` because the enum is fieldless, so
/// equality is total and the type can be used wherever `Eq` is required.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum RequestFormat {
    /// Gemini `generateContent` style: `system_instruction` plus `contents`.
    Gemini,
    /// OpenAI chat-completions style: a `messages` array with system and user roles.
    OpenAiCompat,
    /// Anthropic Messages style: top-level `system` plus a `messages` array.
    Anthropic,
    /// LM Studio chat style: `input` text in, typed items in an `output` array back.
    LmStudio,
}
17
18struct ProviderDef {
19 api_url: &'static str,
20 api_headers: &'static str,
21 default_model: &'static str,
22 format: RequestFormat,
23 response_path: &'static str,
24}
25
26fn get_provider(name: &str) -> Option<ProviderDef> {
28 match name {
29 "gemini" => Some(ProviderDef {
30 api_url: "https://generativelanguage.googleapis.com/v1beta/models/$ACR_MODEL:generateContent?key=$ACR_API_KEY",
31 api_headers: "",
32 default_model: "gemini-2.0-flash",
33 format: RequestFormat::Gemini,
34 response_path: "candidates.0.content.parts.0.text",
35 }),
36 "openai" => Some(ProviderDef {
37 api_url: "https://api.openai.com/v1/chat/completions",
38 api_headers: "Authorization: Bearer $ACR_API_KEY",
39 default_model: "gpt-4o-mini",
40 format: RequestFormat::OpenAiCompat,
41 response_path: "choices.0.message.content",
42 }),
43 "anthropic" => Some(ProviderDef {
44 api_url: "https://api.anthropic.com/v1/messages",
45 api_headers: "x-api-key: $ACR_API_KEY, anthropic-version: 2023-06-01",
46 default_model: "claude-sonnet-4-20250514",
47 format: RequestFormat::Anthropic,
48 response_path: "content.0.text",
49 }),
50 "groq" => Some(ProviderDef {
51 api_url: "https://api.groq.com/openai/v1/chat/completions",
52 api_headers: "Authorization: Bearer $ACR_API_KEY",
53 default_model: "llama-3.3-70b-versatile",
54 format: RequestFormat::OpenAiCompat,
55 response_path: "choices.0.message.content",
56 }),
57 "grok" => Some(ProviderDef {
58 api_url: "https://api.x.ai/v1/chat/completions",
59 api_headers: "Authorization: Bearer $ACR_API_KEY",
60 default_model: "grok-3",
61 format: RequestFormat::OpenAiCompat,
62 response_path: "choices.0.message.content",
63 }),
64 "deepseek" => Some(ProviderDef {
65 api_url: "https://api.deepseek.com/v1/chat/completions",
66 api_headers: "Authorization: Bearer $ACR_API_KEY",
67 default_model: "deepseek-chat",
68 format: RequestFormat::OpenAiCompat,
69 response_path: "choices.0.message.content",
70 }),
71 "openrouter" => Some(ProviderDef {
72 api_url: "https://openrouter.ai/api/v1/chat/completions",
73 api_headers: "Authorization: Bearer $ACR_API_KEY",
74 default_model: "openai/gpt-4o-mini",
75 format: RequestFormat::OpenAiCompat,
76 response_path: "choices.0.message.content",
77 }),
78 "mistral" => Some(ProviderDef {
79 api_url: "https://api.mistral.ai/v1/chat/completions",
80 api_headers: "Authorization: Bearer $ACR_API_KEY",
81 default_model: "mistral-small-latest",
82 format: RequestFormat::OpenAiCompat,
83 response_path: "choices.0.message.content",
84 }),
85 "together" => Some(ProviderDef {
86 api_url: "https://api.together.xyz/v1/chat/completions",
87 api_headers: "Authorization: Bearer $ACR_API_KEY",
88 default_model: "meta-llama/Llama-3.3-70B-Instruct-Turbo",
89 format: RequestFormat::OpenAiCompat,
90 response_path: "choices.0.message.content",
91 }),
92 "fireworks" => Some(ProviderDef {
93 api_url: "https://api.fireworks.ai/inference/v1/chat/completions",
94 api_headers: "Authorization: Bearer $ACR_API_KEY",
95 default_model: "accounts/fireworks/models/llama-v3p3-70b-instruct",
96 format: RequestFormat::OpenAiCompat,
97 response_path: "choices.0.message.content",
98 }),
99 "perplexity" => Some(ProviderDef {
100 api_url: "https://api.perplexity.ai/chat/completions",
101 api_headers: "Authorization: Bearer $ACR_API_KEY",
102 default_model: "sonar",
103 format: RequestFormat::OpenAiCompat,
104 response_path: "choices.0.message.content",
105 }),
106 "lm_studio" => Some(ProviderDef {
107 api_url: "http://localhost:1234/api/v1/chat",
108 api_headers: "Content-Type: application/json",
109 default_model: "qwen/qwen3.5-35b-a3b",
110 format: RequestFormat::LmStudio,
111 response_path: "output",
112 }),
113 _ => None,
114 }
115}
116
117pub fn default_model_for(provider: &str) -> &'static str {
119 get_provider(provider).map_or("", |p| p.default_model)
120}
121
122pub enum LlmCallError {
123 HttpError { code: u16, body: String },
124 TransportError(String),
125 Other(anyhow::Error),
126}
127
128impl std::fmt::Display for LlmCallError {
129 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
130 match self {
131 LlmCallError::HttpError { code, body } => {
132 write!(f, "API returned HTTP {code}: {body}")
133 }
134 LlmCallError::TransportError(msg) => write!(f, "Network error: {msg}"),
135 LlmCallError::Other(e) => write!(f, "{e}"),
136 }
137 }
138}
139
140fn call_llm_inner(
141 cfg: &AppConfig,
142 system_prompt: &str,
143 diff: &str,
144) -> Result<String, LlmCallError> {
145 let (url, headers_raw, format, response_path) =
146 resolve_provider(cfg).map_err(LlmCallError::Other)?;
147
148 let url = interpolate(&url, cfg);
149 let headers_raw = interpolate(&headers_raw, cfg);
150
151 let body = build_request_body(format, &cfg.model, system_prompt, diff);
152 let headers = parse_headers(&headers_raw);
153
154 let spinner = ProgressBar::new_spinner();
155 spinner.set_style(
156 ProgressStyle::default_spinner()
157 .template("{spinner:.cyan} {msg} {elapsed}")
158 .unwrap(),
159 );
160 spinner.set_message("Generating commit message...");
161 spinner.enable_steady_tick(Duration::from_millis(80));
162
163 let mut req = ureq::post(&url);
164 for (key, val) in &headers {
165 req = req.set(key, val);
166 }
167 req = req.set("Content-Type", "application/json");
168
169 let response = req.send_json(&body);
170
171 spinner.finish_and_clear();
172
173 let response = match response {
174 Ok(resp) => resp,
175 Err(ureq::Error::Status(code, resp)) => {
176 let body = resp.into_string().unwrap_or_default();
177 return Err(LlmCallError::HttpError { code, body });
178 }
179 Err(ureq::Error::Transport(t)) => {
180 return Err(LlmCallError::TransportError(t.to_string()));
181 }
182 };
183
184 let json: Value = response.into_json().map_err(|e| {
185 LlmCallError::Other(anyhow::anyhow!("Failed to parse API response as JSON: {e}"))
186 })?;
187
188 let message = extract_message(&json, format, &response_path).map_err(|e| {
189 LlmCallError::Other(anyhow::anyhow!(
190 "Failed to extract message from response at path '{}'. Response:\n{}\nError: {}",
191 response_path,
192 serde_json::to_string_pretty(&json).unwrap_or_default(),
193 e
194 ))
195 })?;
196
197 Ok(message)
198}
199
200pub fn call_llm_with_fallback(
202 cfg: &AppConfig,
203 system_prompt: &str,
204 diff: &str,
205) -> Result<(String, Option<String>)> {
206 match call_llm_inner(cfg, system_prompt, diff) {
207 Ok(msg) => Ok((msg, None)),
208 Err(LlmCallError::TransportError(msg)) => {
209 anyhow::bail!("Network error: {msg}");
210 }
211 Err(LlmCallError::HttpError { code, body }) => {
212 if !cfg.fallback_enabled {
213 anyhow::bail!("API returned HTTP {code}: {body}");
214 }
215
216 let presets_file = match crate::preset::load_presets() {
217 Ok(f) => f,
218 Err(_) => anyhow::bail!("API returned HTTP {code}: {body}"),
219 };
220
221 if presets_file.fallback.order.is_empty() {
222 anyhow::bail!("API returned HTTP {code}: {body}");
223 }
224
225 let current_fields = crate::preset::fields_from_config(cfg);
226 let mut errors = vec![format!("Primary (HTTP {code})")];
227
228 for &preset_id in &presets_file.fallback.order {
229 let preset = match presets_file.presets.iter().find(|p| p.id == preset_id) {
230 Some(p) => p,
231 None => continue,
232 };
233
234 if preset.fields.provider == current_fields.provider
236 && preset.fields.model == current_fields.model
237 && preset.fields.api_key == current_fields.api_key
238 && preset.fields.api_url == current_fields.api_url
239 {
240 continue;
241 }
242
243 eprintln!(
244 "{} Primary failed (HTTP {}), trying: {}...",
245 "fallback:".yellow().bold(),
246 code,
247 preset.name
248 );
249
250 let mut temp_cfg = cfg.clone();
251 crate::preset::apply_preset_to_config(&mut temp_cfg, preset);
252
253 match call_llm_inner(&temp_cfg, system_prompt, diff) {
254 Ok(msg) => return Ok((msg, Some(preset.name.clone()))),
255 Err(LlmCallError::HttpError { code: fc, .. }) => {
256 errors.push(format!("{} (HTTP {fc})", preset.name));
257 continue;
258 }
259 Err(LlmCallError::TransportError(msg)) => {
260 anyhow::bail!("Network error during fallback to '{}': {msg}", preset.name);
261 }
262 Err(LlmCallError::Other(e)) => {
263 errors.push(format!("{} ({})", preset.name, e));
264 continue;
265 }
266 }
267 }
268
269 anyhow::bail!("All LLM providers failed: {}", errors.join(", "));
270 }
271 Err(LlmCallError::Other(e)) => {
272 anyhow::bail!("{e}");
273 }
274 }
275}
276
277pub fn call_llm(cfg: &AppConfig, system_prompt: &str, diff: &str) -> Result<String> {
279 let (msg, _) = call_llm_with_fallback(cfg, system_prompt, diff)?;
280 Ok(msg)
281}
282
283fn resolve_provider(cfg: &AppConfig) -> Result<(String, String, RequestFormat, String)> {
284 if let Some(def) = get_provider(&cfg.provider) {
285 let url = if cfg.api_url.is_empty() {
286 def.api_url.to_string()
287 } else {
288 cfg.api_url.clone()
289 };
290 let headers = if cfg.api_headers.is_empty() {
291 def.api_headers.to_string()
292 } else {
293 cfg.api_headers.clone()
294 };
295 Ok((url, headers, def.format, def.response_path.to_string()))
296 } else {
297 if cfg.api_url.is_empty() {
299 bail!(
300 "Unknown provider '{}'. Set {} for custom providers.",
301 cfg.provider.yellow(),
302 "ACR_API_URL".yellow()
303 );
304 }
305 Ok((
306 cfg.api_url.clone(),
307 cfg.api_headers.clone(),
308 RequestFormat::OpenAiCompat,
309 "choices.0.message.content".to_string(),
310 ))
311 }
312}
313
314fn build_request_body(
315 format: RequestFormat,
316 model: &str,
317 system_prompt: &str,
318 diff: &str,
319) -> Value {
320 match format {
321 RequestFormat::Gemini => {
322 serde_json::json!({
323 "system_instruction": {
324 "parts": [{ "text": system_prompt }]
325 },
326 "contents": [{
327 "role": "user",
328 "parts": [{ "text": diff }]
329 }],
330 "generationConfig": {
331 "temperature": 0
332 }
333 })
334 }
335 RequestFormat::OpenAiCompat => {
336 serde_json::json!({
337 "model": model,
338 "messages": [
339 { "role": "system", "content": system_prompt },
340 { "role": "user", "content": diff }
341 ],
342 "max_tokens": 512,
343 "temperature": 0
344 })
345 }
346 RequestFormat::Anthropic => {
347 serde_json::json!({
348 "model": model,
349 "system": system_prompt,
350 "messages": [
351 { "role": "user", "content": diff }
352 ],
353 "max_tokens": 512
354 })
355 }
356 RequestFormat::LmStudio => {
357 serde_json::json!({
358 "model": model,
359 "input": diff
360 })
361 }
362 }
363}
364
/// Parses a comma-separated list of `Name: value` pairs into header tuples.
///
/// Names and values are trimmed, and only the first `:` splits a pair, so
/// values may contain colons. Pairs without a colon are skipped, as are pairs
/// whose name is empty after trimming, since an empty header name is never
/// valid. Blank input yields an empty list.
///
/// Values cannot contain commas, because `,` always separates pairs.
fn parse_headers(raw: &str) -> Vec<(String, String)> {
    raw.split(',')
        .filter_map(|pair| {
            let (name, value) = pair.trim().split_once(':')?;
            let name = name.trim();
            if name.is_empty() {
                return None;
            }
            Some((name.to_string(), value.trim().to_string()))
        })
        .collect()
}
378
379fn extract_by_path(value: &Value, path: &str) -> Result<String> {
381 let mut current = value;
382 for segment in path.split('.') {
383 current = if let Ok(index) = segment.parse::<usize>() {
384 current
385 .get(index)
386 .with_context(|| format!("Array index {index} out of bounds"))?
387 } else {
388 current
389 .get(segment)
390 .with_context(|| format!("Key '{segment}' not found"))?
391 };
392 }
393 current
394 .as_str()
395 .map(|s| s.to_string())
396 .with_context(|| "Expected string value at path end".to_string())
397}
398
399fn extract_message(value: &Value, format: RequestFormat, response_path: &str) -> Result<String> {
400 match format {
401 RequestFormat::LmStudio => {
402 let output = value
403 .get(response_path)
404 .and_then(Value::as_array)
405 .with_context(|| format!("Key '{response_path}' not found or is not an array"))?;
406
407 let message = output
408 .iter()
409 .find(|item| item.get("type").and_then(Value::as_str) == Some("message"))
410 .with_context(|| "No output item with type 'message' found".to_string())?;
411
412 message
413 .get("content")
414 .and_then(Value::as_str)
415 .map(str::to_string)
416 .with_context(|| "Expected string 'content' in message output item".to_string())
417 }
418 _ => extract_by_path(value, response_path),
419 }
420}
421
#[cfg(test)]
mod tests {
    // Unit tests for the pure helpers in this module:
    // - header parsing;
    // - JSON path and message extraction;
    // - request-body construction;
    // - provider lookup and resolution;
    // - error display.
    // None of these tests call the LLM endpoints, so no network I/O happens.
    use super::*;

    // --- parse_headers ---

    #[test]
    fn test_parse_headers_empty() {
        assert!(parse_headers("").is_empty());
        assert!(parse_headers(" ").is_empty());
    }

    #[test]
    fn test_parse_headers_single() {
        let headers = parse_headers("Authorization: Bearer abc123");
        assert_eq!(headers.len(), 1);
        assert_eq!(headers[0].0, "Authorization");
        assert_eq!(headers[0].1, "Bearer abc123");
    }

    #[test]
    fn test_parse_headers_multiple() {
        let headers = parse_headers("X-Api-Key: key123, Content-Type: application/json");
        assert_eq!(headers.len(), 2);
        assert_eq!(headers[0].0, "X-Api-Key");
        assert_eq!(headers[0].1, "key123");
        assert_eq!(headers[1].0, "Content-Type");
        assert_eq!(headers[1].1, "application/json");
    }

    #[test]
    fn test_parse_headers_trims_whitespace() {
        let headers = parse_headers(" Key : Value ");
        assert_eq!(headers.len(), 1);
        assert_eq!(headers[0].0, "Key");
        assert_eq!(headers[0].1, "Value");
    }

    // A pair without ':' is dropped, and the pairs around it are still kept.
    #[test]
    fn test_parse_headers_skips_invalid() {
        let headers = parse_headers("Valid: Header, InvalidNoColon, Another: One");
        assert_eq!(headers.len(), 2);
        assert_eq!(headers[0].0, "Valid");
        assert_eq!(headers[1].0, "Another");
    }

    // --- extract_by_path: success cases, including each provider's reply shape ---

    #[test]
    fn test_extract_by_path_simple() {
        let json = serde_json::json!({"message": "hello"});
        let result = extract_by_path(&json, "message").unwrap();
        assert_eq!(result, "hello");
    }

    #[test]
    fn test_extract_by_path_nested() {
        let json = serde_json::json!({"content": {"text": "nested"}});
        let result = extract_by_path(&json, "content.text").unwrap();
        assert_eq!(result, "nested");
    }

    #[test]
    fn test_extract_by_path_array_index() {
        let json = serde_json::json!({"items": ["first", "second"]});
        let result = extract_by_path(&json, "items.0").unwrap();
        assert_eq!(result, "first");
    }

    // Same path as the built-in OpenAI-compatible response_path.
    #[test]
    fn test_extract_by_path_complex() {
        let json = serde_json::json!({
            "choices": [{"message": {"content": "generated"}}]
        });
        let result = extract_by_path(&json, "choices.0.message.content").unwrap();
        assert_eq!(result, "generated");
    }

    #[test]
    fn test_extract_by_path_gemini_format() {
        let json = serde_json::json!({
            "candidates": [{"content": {"parts": [{"text": "gemini response"}]}}]
        });
        let result = extract_by_path(&json, "candidates.0.content.parts.0.text").unwrap();
        assert_eq!(result, "gemini response");
    }

    #[test]
    fn test_extract_by_path_anthropic_format() {
        let json = serde_json::json!({
            "content": [{"text": "anthropic response"}]
        });
        let result = extract_by_path(&json, "content.0.text").unwrap();
        assert_eq!(result, "anthropic response");
    }

    // --- extract_by_path: failure cases (matched on error-message fragments) ---

    #[test]
    fn test_extract_by_path_key_not_found() {
        let json = serde_json::json!({"foo": "bar"});
        let result = extract_by_path(&json, "missing");
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("not found"));
    }

    #[test]
    fn test_extract_by_path_index_out_of_bounds() {
        let json = serde_json::json!({"items": ["only"]});
        let result = extract_by_path(&json, "items.5");
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("out of bounds"));
    }

    #[test]
    fn test_extract_by_path_not_string() {
        let json = serde_json::json!({"number": 42});
        let result = extract_by_path(&json, "number");
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Expected string"));
    }

    // --- build_request_body: one test per wire format ---

    #[test]
    fn test_build_request_body_openai_compat() {
        let body = build_request_body(
            RequestFormat::OpenAiCompat,
            "gpt-4o",
            "system prompt",
            "user diff",
        );
        assert_eq!(body["model"], "gpt-4o");
        assert_eq!(body["messages"][0]["role"], "system");
        assert_eq!(body["messages"][0]["content"], "system prompt");
        assert_eq!(body["messages"][1]["role"], "user");
        assert_eq!(body["messages"][1]["content"], "user diff");
        assert_eq!(body["max_tokens"], 512);
        assert_eq!(body["temperature"], 0);
    }

    #[test]
    fn test_build_request_body_gemini() {
        let body = build_request_body(
            RequestFormat::Gemini,
            "gemini-pro",
            "system prompt",
            "user diff",
        );
        assert_eq!(
            body["system_instruction"]["parts"][0]["text"],
            "system prompt"
        );
        assert_eq!(body["contents"][0]["role"], "user");
        assert_eq!(body["contents"][0]["parts"][0]["text"], "user diff");
        assert_eq!(body["generationConfig"]["temperature"], 0);
    }

    #[test]
    fn test_build_request_body_anthropic() {
        let body = build_request_body(
            RequestFormat::Anthropic,
            "claude-3-opus",
            "system prompt",
            "user diff",
        );
        assert_eq!(body["model"], "claude-3-opus");
        assert_eq!(body["system"], "system prompt");
        assert_eq!(body["messages"][0]["role"], "user");
        assert_eq!(body["messages"][0]["content"], "user diff");
        assert_eq!(body["max_tokens"], 512);
    }

    // LM Studio uses `input` rather than a chat `messages` array.
    #[test]
    fn test_build_request_body_lm_studio() {
        let body = build_request_body(
            RequestFormat::LmStudio,
            "qwen/qwen3.5-35b-a3b",
            "system prompt",
            "user diff",
        );
        assert_eq!(body["model"], "qwen/qwen3.5-35b-a3b");
        assert_eq!(body["input"], "user diff");
        assert!(body.get("messages").is_none());
    }

    // --- get_provider / default_model_for: the built-in provider table ---

    #[test]
    fn test_get_provider_known() {
        assert!(get_provider("gemini").is_some());
        assert!(get_provider("openai").is_some());
        assert!(get_provider("anthropic").is_some());
        assert!(get_provider("groq").is_some());
        assert!(get_provider("grok").is_some());
        assert!(get_provider("deepseek").is_some());
        assert!(get_provider("openrouter").is_some());
        assert!(get_provider("mistral").is_some());
        assert!(get_provider("together").is_some());
        assert!(get_provider("fireworks").is_some());
        assert!(get_provider("perplexity").is_some());
        assert!(get_provider("lm_studio").is_some());
    }

    #[test]
    fn test_get_provider_unknown() {
        assert!(get_provider("unknown").is_none());
        assert!(get_provider("custom").is_none());
    }

    #[test]
    fn test_get_provider_gemini_format() {
        let provider = get_provider("gemini").unwrap();
        assert_eq!(provider.format, RequestFormat::Gemini);
        assert!(provider
            .api_url
            .contains("generativelanguage.googleapis.com"));
        assert_eq!(provider.default_model, "gemini-2.0-flash");
    }

    #[test]
    fn test_get_provider_anthropic_format() {
        let provider = get_provider("anthropic").unwrap();
        assert_eq!(provider.format, RequestFormat::Anthropic);
        assert!(provider.api_url.contains("anthropic.com"));
        assert!(provider.api_headers.contains("anthropic-version"));
    }

    #[test]
    fn test_get_provider_openai_compat() {
        for name in &[
            "openai",
            "groq",
            "grok",
            "deepseek",
            "openrouter",
            "mistral",
            "together",
            "fireworks",
            "perplexity",
        ] {
            let provider = get_provider(name).unwrap();
            assert_eq!(
                provider.format,
                RequestFormat::OpenAiCompat,
                "Provider {name} should use OpenAiCompat format"
            );
        }
    }

    #[test]
    fn test_get_provider_lm_studio_format() {
        let provider = get_provider("lm_studio").unwrap();
        assert_eq!(provider.format, RequestFormat::LmStudio);
        assert_eq!(provider.api_url, "http://localhost:1234/api/v1/chat");
        assert_eq!(provider.api_headers, "Content-Type: application/json");
        assert_eq!(provider.default_model, "qwen/qwen3.5-35b-a3b");
    }

    #[test]
    fn test_default_model_for_known() {
        assert_eq!(default_model_for("groq"), "llama-3.3-70b-versatile");
        assert_eq!(default_model_for("openai"), "gpt-4o-mini");
        assert_eq!(default_model_for("anthropic"), "claude-sonnet-4-20250514");
        assert_eq!(default_model_for("lm_studio"), "qwen/qwen3.5-35b-a3b");
    }

    #[test]
    fn test_default_model_for_unknown() {
        assert_eq!(default_model_for("custom"), "");
        assert_eq!(default_model_for("unknown"), "");
    }

    // --- resolve_provider: built-in defaults, overrides, and custom endpoints ---

    #[test]
    fn test_resolve_provider_known() {
        let cfg = AppConfig {
            provider: "groq".into(),
            api_key: "test-key".into(),
            ..Default::default()
        };
        let (url, headers, format, path) = resolve_provider(&cfg).unwrap();
        assert!(url.contains("groq.com"));
        assert!(headers.contains("Bearer"));
        assert_eq!(format, RequestFormat::OpenAiCompat);
        assert_eq!(path, "choices.0.message.content");
    }

    #[test]
    fn test_resolve_provider_known_with_override() {
        let cfg = AppConfig {
            provider: "groq".into(),
            api_url: "https://custom.url/v1".into(),
            api_headers: "X-Custom: value".into(),
            ..Default::default()
        };
        let (url, headers, _, _) = resolve_provider(&cfg).unwrap();
        assert_eq!(url, "https://custom.url/v1");
        assert_eq!(headers, "X-Custom: value");
    }

    #[test]
    fn test_resolve_provider_custom_requires_url() {
        let cfg = AppConfig {
            provider: "custom-provider".into(),
            api_url: "".into(),
            ..Default::default()
        };
        let result = resolve_provider(&cfg);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Unknown provider"));
    }

    #[test]
    fn test_resolve_provider_custom_with_url() {
        let cfg = AppConfig {
            provider: "custom-provider".into(),
            api_url: "https://my-custom-api.com/v1".into(),
            api_headers: "Authorization: custom".into(),
            ..Default::default()
        };
        let (url, headers, format, path) = resolve_provider(&cfg).unwrap();
        assert_eq!(url, "https://my-custom-api.com/v1");
        assert_eq!(headers, "Authorization: custom");
        assert_eq!(format, RequestFormat::OpenAiCompat);
        assert_eq!(path, "choices.0.message.content");
    }

    // --- LlmCallError Display and RequestFormat equality ---

    #[test]
    fn test_llm_call_error_display_http() {
        let err = LlmCallError::HttpError {
            code: 401,
            body: "Unauthorized".into(),
        };
        let display = format!("{err}");
        assert!(display.contains("HTTP 401"));
        assert!(display.contains("Unauthorized"));
    }

    #[test]
    fn test_llm_call_error_display_transport() {
        let err = LlmCallError::TransportError("connection refused".into());
        let display = format!("{err}");
        assert!(display.contains("Network error"));
        assert!(display.contains("connection refused"));
    }

    #[test]
    fn test_llm_call_error_display_other() {
        let err = LlmCallError::Other(anyhow::anyhow!("custom error"));
        let display = format!("{err}");
        assert!(display.contains("custom error"));
    }

    #[test]
    fn test_request_format_equality() {
        assert_eq!(RequestFormat::Gemini, RequestFormat::Gemini);
        assert_eq!(RequestFormat::OpenAiCompat, RequestFormat::OpenAiCompat);
        assert_eq!(RequestFormat::Anthropic, RequestFormat::Anthropic);
        assert_eq!(RequestFormat::LmStudio, RequestFormat::LmStudio);
        assert_ne!(RequestFormat::Gemini, RequestFormat::OpenAiCompat);
    }

    // --- extract_message: LM Studio's typed `output` items ---

    // The "reasoning" item comes first and must be skipped in favour of "message".
    #[test]
    fn test_extract_message_lm_studio_message_item() {
        let json = serde_json::json!({
            "output": [
                { "type": "reasoning", "content": "thinking" },
                { "type": "message", "content": "feat: lm studio response" }
            ]
        });
        let result = extract_message(&json, RequestFormat::LmStudio, "output").unwrap();
        assert_eq!(result, "feat: lm studio response");
    }

    #[test]
    fn test_extract_message_lm_studio_missing_message_item() {
        let json = serde_json::json!({
            "output": [
                { "type": "reasoning", "content": "thinking only" }
            ]
        });
        let result = extract_message(&json, RequestFormat::LmStudio, "output");
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("type 'message'"));
    }
}
797}