1use anyhow::{bail, Context, Result};
2use colored::Colorize;
3use indicatif::{ProgressBar, ProgressStyle};
4use serde_json::Value;
5use std::time::Duration;
6
7use crate::config::AppConfig;
8use crate::interpolation::interpolate;
9
/// Wire format used to build the request body for a provider.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
enum RequestFormat {
    /// Google Gemini `generateContent` payload.
    Gemini,
    /// OpenAI chat-completions payload (also spoken by many other providers).
    OpenAiCompat,
    /// Anthropic Messages API payload.
    Anthropic,
}
16
17struct ProviderDef {
18 api_url: &'static str,
19 api_headers: &'static str,
20 default_model: &'static str,
21 format: RequestFormat,
22 response_path: &'static str,
23}
24
25fn get_provider(name: &str) -> Option<ProviderDef> {
27 match name {
28 "gemini" => Some(ProviderDef {
29 api_url: "https://generativelanguage.googleapis.com/v1beta/models/$ACR_MODEL:generateContent?key=$ACR_API_KEY",
30 api_headers: "",
31 default_model: "gemini-2.0-flash",
32 format: RequestFormat::Gemini,
33 response_path: "candidates.0.content.parts.0.text",
34 }),
35 "openai" => Some(ProviderDef {
36 api_url: "https://api.openai.com/v1/chat/completions",
37 api_headers: "Authorization: Bearer $ACR_API_KEY",
38 default_model: "gpt-4o-mini",
39 format: RequestFormat::OpenAiCompat,
40 response_path: "choices.0.message.content",
41 }),
42 "anthropic" => Some(ProviderDef {
43 api_url: "https://api.anthropic.com/v1/messages",
44 api_headers: "x-api-key: $ACR_API_KEY, anthropic-version: 2023-06-01",
45 default_model: "claude-sonnet-4-20250514",
46 format: RequestFormat::Anthropic,
47 response_path: "content.0.text",
48 }),
49 "groq" => Some(ProviderDef {
50 api_url: "https://api.groq.com/openai/v1/chat/completions",
51 api_headers: "Authorization: Bearer $ACR_API_KEY",
52 default_model: "llama-3.3-70b-versatile",
53 format: RequestFormat::OpenAiCompat,
54 response_path: "choices.0.message.content",
55 }),
56 "grok" => Some(ProviderDef {
57 api_url: "https://api.x.ai/v1/chat/completions",
58 api_headers: "Authorization: Bearer $ACR_API_KEY",
59 default_model: "grok-3",
60 format: RequestFormat::OpenAiCompat,
61 response_path: "choices.0.message.content",
62 }),
63 "deepseek" => Some(ProviderDef {
64 api_url: "https://api.deepseek.com/v1/chat/completions",
65 api_headers: "Authorization: Bearer $ACR_API_KEY",
66 default_model: "deepseek-chat",
67 format: RequestFormat::OpenAiCompat,
68 response_path: "choices.0.message.content",
69 }),
70 "openrouter" => Some(ProviderDef {
71 api_url: "https://openrouter.ai/api/v1/chat/completions",
72 api_headers: "Authorization: Bearer $ACR_API_KEY",
73 default_model: "openai/gpt-4o-mini",
74 format: RequestFormat::OpenAiCompat,
75 response_path: "choices.0.message.content",
76 }),
77 "mistral" => Some(ProviderDef {
78 api_url: "https://api.mistral.ai/v1/chat/completions",
79 api_headers: "Authorization: Bearer $ACR_API_KEY",
80 default_model: "mistral-small-latest",
81 format: RequestFormat::OpenAiCompat,
82 response_path: "choices.0.message.content",
83 }),
84 "together" => Some(ProviderDef {
85 api_url: "https://api.together.xyz/v1/chat/completions",
86 api_headers: "Authorization: Bearer $ACR_API_KEY",
87 default_model: "meta-llama/Llama-3.3-70B-Instruct-Turbo",
88 format: RequestFormat::OpenAiCompat,
89 response_path: "choices.0.message.content",
90 }),
91 "fireworks" => Some(ProviderDef {
92 api_url: "https://api.fireworks.ai/inference/v1/chat/completions",
93 api_headers: "Authorization: Bearer $ACR_API_KEY",
94 default_model: "accounts/fireworks/models/llama-v3p3-70b-instruct",
95 format: RequestFormat::OpenAiCompat,
96 response_path: "choices.0.message.content",
97 }),
98 "perplexity" => Some(ProviderDef {
99 api_url: "https://api.perplexity.ai/chat/completions",
100 api_headers: "Authorization: Bearer $ACR_API_KEY",
101 default_model: "sonar",
102 format: RequestFormat::OpenAiCompat,
103 response_path: "choices.0.message.content",
104 }),
105 _ => None,
106 }
107}
108
109pub fn default_model_for(provider: &str) -> &'static str {
111 get_provider(provider).map_or("", |p| p.default_model)
112}
113
114pub enum LlmCallError {
115 HttpError { code: u16, body: String },
116 TransportError(String),
117 Other(anyhow::Error),
118}
119
120impl std::fmt::Display for LlmCallError {
121 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
122 match self {
123 LlmCallError::HttpError { code, body } => {
124 write!(f, "API returned HTTP {code}: {body}")
125 }
126 LlmCallError::TransportError(msg) => write!(f, "Network error: {msg}"),
127 LlmCallError::Other(e) => write!(f, "{e}"),
128 }
129 }
130}
131
132fn call_llm_inner(cfg: &AppConfig, system_prompt: &str, diff: &str) -> Result<String, LlmCallError> {
133 let (url, headers_raw, format, response_path) =
134 resolve_provider(cfg).map_err(LlmCallError::Other)?;
135
136 let url = interpolate(&url, cfg);
137 let headers_raw = interpolate(&headers_raw, cfg);
138
139 let body = build_request_body(format, &cfg.model, system_prompt, diff);
140 let headers = parse_headers(&headers_raw);
141
142 let spinner = ProgressBar::new_spinner();
143 spinner.set_style(
144 ProgressStyle::default_spinner()
145 .template("{spinner:.cyan} {msg} {elapsed}")
146 .unwrap(),
147 );
148 spinner.set_message("Generating commit message...");
149 spinner.enable_steady_tick(Duration::from_millis(80));
150
151 let mut req = ureq::post(&url);
152 for (key, val) in &headers {
153 req = req.set(key, val);
154 }
155 req = req.set("Content-Type", "application/json");
156
157 let response = req.send_json(&body);
158
159 spinner.finish_and_clear();
160
161 let response = match response {
162 Ok(resp) => resp,
163 Err(ureq::Error::Status(code, resp)) => {
164 let body = resp.into_string().unwrap_or_default();
165 return Err(LlmCallError::HttpError { code, body });
166 }
167 Err(ureq::Error::Transport(t)) => {
168 return Err(LlmCallError::TransportError(t.to_string()));
169 }
170 };
171
172 let json: Value = response
173 .into_json()
174 .map_err(|e| LlmCallError::Other(anyhow::anyhow!("Failed to parse API response as JSON: {e}")))?;
175
176 let message = extract_by_path(&json, &response_path).map_err(|e| {
177 LlmCallError::Other(anyhow::anyhow!(
178 "Failed to extract message from response at path '{}'. Response:\n{}\nError: {}",
179 response_path,
180 serde_json::to_string_pretty(&json).unwrap_or_default(),
181 e
182 ))
183 })?;
184
185 Ok(message)
186}
187
188pub fn call_llm_with_fallback(
190 cfg: &AppConfig,
191 system_prompt: &str,
192 diff: &str,
193) -> Result<(String, Option<String>)> {
194 match call_llm_inner(cfg, system_prompt, diff) {
195 Ok(msg) => Ok((msg, None)),
196 Err(LlmCallError::TransportError(msg)) => {
197 anyhow::bail!("Network error: {msg}");
198 }
199 Err(LlmCallError::HttpError { code, body }) => {
200 if !cfg.fallback_enabled {
201 anyhow::bail!("API returned HTTP {code}: {body}");
202 }
203
204 let presets_file = match crate::preset::load_presets() {
205 Ok(f) => f,
206 Err(_) => anyhow::bail!("API returned HTTP {code}: {body}"),
207 };
208
209 if presets_file.fallback.order.is_empty() {
210 anyhow::bail!("API returned HTTP {code}: {body}");
211 }
212
213 let current_fields = crate::preset::fields_from_config(cfg);
214 let mut errors = vec![format!("Primary (HTTP {code})")];
215
216 for &preset_id in &presets_file.fallback.order {
217 let preset = match presets_file.presets.iter().find(|p| p.id == preset_id) {
218 Some(p) => p,
219 None => continue,
220 };
221
222 if preset.fields.provider == current_fields.provider
224 && preset.fields.model == current_fields.model
225 && preset.fields.api_key == current_fields.api_key
226 && preset.fields.api_url == current_fields.api_url
227 {
228 continue;
229 }
230
231 eprintln!(
232 "{} Primary failed (HTTP {}), trying: {}...",
233 "fallback:".yellow().bold(),
234 code,
235 preset.name
236 );
237
238 let mut temp_cfg = cfg.clone();
239 crate::preset::apply_preset_to_config(&mut temp_cfg, preset);
240
241 match call_llm_inner(&temp_cfg, system_prompt, diff) {
242 Ok(msg) => return Ok((msg, Some(preset.name.clone()))),
243 Err(LlmCallError::HttpError { code: fc, .. }) => {
244 errors.push(format!("{} (HTTP {fc})", preset.name));
245 continue;
246 }
247 Err(LlmCallError::TransportError(msg)) => {
248 anyhow::bail!("Network error during fallback to '{}': {msg}", preset.name);
249 }
250 Err(LlmCallError::Other(e)) => {
251 errors.push(format!("{} ({})", preset.name, e));
252 continue;
253 }
254 }
255 }
256
257 anyhow::bail!(
258 "All LLM providers failed: {}",
259 errors.join(", ")
260 );
261 }
262 Err(LlmCallError::Other(e)) => {
263 anyhow::bail!("{e}");
264 }
265 }
266}
267
268pub fn call_llm(cfg: &AppConfig, system_prompt: &str, diff: &str) -> Result<String> {
270 let (msg, _) = call_llm_with_fallback(cfg, system_prompt, diff)?;
271 Ok(msg)
272}
273
274fn resolve_provider(cfg: &AppConfig) -> Result<(String, String, RequestFormat, String)> {
275 if let Some(def) = get_provider(&cfg.provider) {
276 let url = if cfg.api_url.is_empty() {
277 def.api_url.to_string()
278 } else {
279 cfg.api_url.clone()
280 };
281 let headers = if cfg.api_headers.is_empty() {
282 def.api_headers.to_string()
283 } else {
284 cfg.api_headers.clone()
285 };
286 Ok((url, headers, def.format, def.response_path.to_string()))
287 } else {
288 if cfg.api_url.is_empty() {
290 bail!(
291 "Unknown provider '{}'. Set {} for custom providers.",
292 cfg.provider.yellow(),
293 "ACR_API_URL".yellow()
294 );
295 }
296 Ok((
297 cfg.api_url.clone(),
298 cfg.api_headers.clone(),
299 RequestFormat::OpenAiCompat,
300 "choices.0.message.content".to_string(),
301 ))
302 }
303}
304
305fn build_request_body(
306 format: RequestFormat,
307 model: &str,
308 system_prompt: &str,
309 diff: &str,
310) -> Value {
311 match format {
312 RequestFormat::Gemini => {
313 serde_json::json!({
314 "system_instruction": {
315 "parts": [{ "text": system_prompt }]
316 },
317 "contents": [{
318 "role": "user",
319 "parts": [{ "text": diff }]
320 }],
321 "generationConfig": {
322 "temperature": 0
323 }
324 })
325 }
326 RequestFormat::OpenAiCompat => {
327 serde_json::json!({
328 "model": model,
329 "messages": [
330 { "role": "system", "content": system_prompt },
331 { "role": "user", "content": diff }
332 ],
333 "max_tokens": 512,
334 "temperature": 0
335 })
336 }
337 RequestFormat::Anthropic => {
338 serde_json::json!({
339 "model": model,
340 "system": system_prompt,
341 "messages": [
342 { "role": "user", "content": diff }
343 ],
344 "max_tokens": 512
345 })
346 }
347 }
348}
349
/// Parses a comma-separated list of `Name: value` headers.
///
/// Example input: `"x-api-key: KEY, anthropic-version: 2023-06-01"`.
///
/// - Names and values are trimmed.
/// - Entries without a `:` are skipped, as are entries with an empty name.
/// - Only the first `:` splits name from value, so values may contain colons.
/// - Values cannot contain commas.
///
/// Empty or whitespace-only input yields no headers.
fn parse_headers(raw: &str) -> Vec<(String, String)> {
    raw.split(',')
        .filter_map(|entry| {
            let (name, value) = entry.split_once(':')?;
            let name = name.trim();
            // An empty field name is not a valid HTTP header; drop it like
            // any other malformed entry instead of sending it.
            (!name.is_empty()).then(|| (name.to_string(), value.trim().to_string()))
        })
        .collect()
}
363
364fn extract_by_path(value: &Value, path: &str) -> Result<String> {
366 let mut current = value;
367 for segment in path.split('.') {
368 current = if let Ok(index) = segment.parse::<usize>() {
369 current
370 .get(index)
371 .with_context(|| format!("Array index {index} out of bounds"))?
372 } else {
373 current
374 .get(segment)
375 .with_context(|| format!("Key '{segment}' not found"))?
376 };
377 }
378 current
379 .as_str()
380 .map(|s| s.to_string())
381 .with_context(|| "Expected string value at path end".to_string())
382}
383
/// Unit tests for this module's pure helpers.
///
/// Covered areas: header parsing, JSON path extraction, request-body
/// construction, the built-in provider table, provider resolution, and
/// `LlmCallError` display text. None of these tests performs an HTTP request.
#[cfg(test)]
mod tests {
    use super::*;

    // ---- parse_headers ----

    #[test]
    fn test_parse_headers_empty() {
        assert!(parse_headers("").is_empty());
        assert!(parse_headers("   ").is_empty());
    }

    #[test]
    fn test_parse_headers_single() {
        let headers = parse_headers("Authorization: Bearer abc123");
        assert_eq!(headers.len(), 1);
        assert_eq!(headers[0].0, "Authorization");
        assert_eq!(headers[0].1, "Bearer abc123");
    }

    #[test]
    fn test_parse_headers_multiple() {
        let headers = parse_headers("X-Api-Key: key123, Content-Type: application/json");
        assert_eq!(headers.len(), 2);
        assert_eq!(headers[0].0, "X-Api-Key");
        assert_eq!(headers[0].1, "key123");
        assert_eq!(headers[1].0, "Content-Type");
        assert_eq!(headers[1].1, "application/json");
    }

    #[test]
    fn test_parse_headers_trims_whitespace() {
        let headers = parse_headers("  Key  :  Value  ");
        assert_eq!(headers.len(), 1);
        assert_eq!(headers[0].0, "Key");
        assert_eq!(headers[0].1, "Value");
    }

    // Entries without a ':' separator are dropped; the remaining ones keep their order.
    #[test]
    fn test_parse_headers_skips_invalid() {
        let headers = parse_headers("Valid: Header, InvalidNoColon, Another: One");
        assert_eq!(headers.len(), 2);
        assert_eq!(headers[0].0, "Valid");
        assert_eq!(headers[1].0, "Another");
    }

    // ---- extract_by_path ----

    #[test]
    fn test_extract_by_path_simple() {
        let json = serde_json::json!({"message": "hello"});
        let result = extract_by_path(&json, "message").unwrap();
        assert_eq!(result, "hello");
    }

    #[test]
    fn test_extract_by_path_nested() {
        let json = serde_json::json!({"content": {"text": "nested"}});
        let result = extract_by_path(&json, "content.text").unwrap();
        assert_eq!(result, "nested");
    }

    #[test]
    fn test_extract_by_path_array_index() {
        let json = serde_json::json!({"items": ["first", "second"]});
        let result = extract_by_path(&json, "items.0").unwrap();
        assert_eq!(result, "first");
    }

    // Shape of an OpenAI-compatible chat-completions response.
    #[test]
    fn test_extract_by_path_complex() {
        let json = serde_json::json!({
            "choices": [{"message": {"content": "generated"}}]
        });
        let result = extract_by_path(&json, "choices.0.message.content").unwrap();
        assert_eq!(result, "generated");
    }

    #[test]
    fn test_extract_by_path_gemini_format() {
        let json = serde_json::json!({
            "candidates": [{"content": {"parts": [{"text": "gemini response"}]}}]
        });
        let result = extract_by_path(&json, "candidates.0.content.parts.0.text").unwrap();
        assert_eq!(result, "gemini response");
    }

    #[test]
    fn test_extract_by_path_anthropic_format() {
        let json = serde_json::json!({
            "content": [{"text": "anthropic response"}]
        });
        let result = extract_by_path(&json, "content.0.text").unwrap();
        assert_eq!(result, "anthropic response");
    }

    // Error messages are matched by substring so wording details can evolve.
    #[test]
    fn test_extract_by_path_key_not_found() {
        let json = serde_json::json!({"foo": "bar"});
        let result = extract_by_path(&json, "missing");
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("not found"));
    }

    #[test]
    fn test_extract_by_path_index_out_of_bounds() {
        let json = serde_json::json!({"items": ["only"]});
        let result = extract_by_path(&json, "items.5");
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("out of bounds"));
    }

    #[test]
    fn test_extract_by_path_not_string() {
        let json = serde_json::json!({"number": 42});
        let result = extract_by_path(&json, "number");
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Expected string"));
    }

    // ---- build_request_body ----

    #[test]
    fn test_build_request_body_openai_compat() {
        let body = build_request_body(
            RequestFormat::OpenAiCompat,
            "gpt-4o",
            "system prompt",
            "user diff",
        );
        assert_eq!(body["model"], "gpt-4o");
        assert_eq!(body["messages"][0]["role"], "system");
        assert_eq!(body["messages"][0]["content"], "system prompt");
        assert_eq!(body["messages"][1]["role"], "user");
        assert_eq!(body["messages"][1]["content"], "user diff");
        assert_eq!(body["max_tokens"], 512);
        assert_eq!(body["temperature"], 0);
    }

    #[test]
    fn test_build_request_body_gemini() {
        let body = build_request_body(
            RequestFormat::Gemini,
            "gemini-pro",
            "system prompt",
            "user diff",
        );
        assert_eq!(body["system_instruction"]["parts"][0]["text"], "system prompt");
        assert_eq!(body["contents"][0]["role"], "user");
        assert_eq!(body["contents"][0]["parts"][0]["text"], "user diff");
        assert_eq!(body["generationConfig"]["temperature"], 0);
    }

    #[test]
    fn test_build_request_body_anthropic() {
        let body = build_request_body(
            RequestFormat::Anthropic,
            "claude-3-opus",
            "system prompt",
            "user diff",
        );
        assert_eq!(body["model"], "claude-3-opus");
        assert_eq!(body["system"], "system prompt");
        assert_eq!(body["messages"][0]["role"], "user");
        assert_eq!(body["messages"][0]["content"], "user diff");
        assert_eq!(body["max_tokens"], 512);
    }

    // ---- built-in provider table ----

    #[test]
    fn test_get_provider_known() {
        assert!(get_provider("gemini").is_some());
        assert!(get_provider("openai").is_some());
        assert!(get_provider("anthropic").is_some());
        assert!(get_provider("groq").is_some());
        assert!(get_provider("grok").is_some());
        assert!(get_provider("deepseek").is_some());
        assert!(get_provider("openrouter").is_some());
        assert!(get_provider("mistral").is_some());
        assert!(get_provider("together").is_some());
        assert!(get_provider("fireworks").is_some());
        assert!(get_provider("perplexity").is_some());
    }

    #[test]
    fn test_get_provider_unknown() {
        assert!(get_provider("unknown").is_none());
        assert!(get_provider("custom").is_none());
    }

    #[test]
    fn test_get_provider_gemini_format() {
        let provider = get_provider("gemini").unwrap();
        assert_eq!(provider.format, RequestFormat::Gemini);
        assert!(provider.api_url.contains("generativelanguage.googleapis.com"));
        assert_eq!(provider.default_model, "gemini-2.0-flash");
    }

    #[test]
    fn test_get_provider_anthropic_format() {
        let provider = get_provider("anthropic").unwrap();
        assert_eq!(provider.format, RequestFormat::Anthropic);
        assert!(provider.api_url.contains("anthropic.com"));
        assert!(provider.api_headers.contains("anthropic-version"));
    }

    // Every built-in provider other than Gemini and Anthropic speaks the OpenAI protocol.
    #[test]
    fn test_get_provider_openai_compat() {
        for name in &["openai", "groq", "grok", "deepseek", "openrouter", "mistral", "together", "fireworks", "perplexity"] {
            let provider = get_provider(name).unwrap();
            assert_eq!(provider.format, RequestFormat::OpenAiCompat, "Provider {name} should use OpenAiCompat format");
        }
    }

    #[test]
    fn test_default_model_for_known() {
        assert_eq!(default_model_for("groq"), "llama-3.3-70b-versatile");
        assert_eq!(default_model_for("openai"), "gpt-4o-mini");
        assert_eq!(default_model_for("anthropic"), "claude-sonnet-4-20250514");
    }

    #[test]
    fn test_default_model_for_unknown() {
        assert_eq!(default_model_for("custom"), "");
        assert_eq!(default_model_for("unknown"), "");
    }

    // ---- resolve_provider ----

    #[test]
    fn test_resolve_provider_known() {
        let cfg = AppConfig {
            provider: "groq".into(),
            api_key: "test-key".into(),
            ..Default::default()
        };
        let (url, headers, format, path) = resolve_provider(&cfg).unwrap();
        assert!(url.contains("groq.com"));
        assert!(headers.contains("Bearer"));
        assert_eq!(format, RequestFormat::OpenAiCompat);
        assert_eq!(path, "choices.0.message.content");
    }

    // Non-empty config values replace the built-in URL and headers.
    #[test]
    fn test_resolve_provider_known_with_override() {
        let cfg = AppConfig {
            provider: "groq".into(),
            api_url: "https://custom.url/v1".into(),
            api_headers: "X-Custom: value".into(),
            ..Default::default()
        };
        let (url, headers, _, _) = resolve_provider(&cfg).unwrap();
        assert_eq!(url, "https://custom.url/v1");
        assert_eq!(headers, "X-Custom: value");
    }

    #[test]
    fn test_resolve_provider_custom_requires_url() {
        let cfg = AppConfig {
            provider: "custom-provider".into(),
            api_url: "".into(),
            ..Default::default()
        };
        let result = resolve_provider(&cfg);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Unknown provider"));
    }

    // Unknown providers with a URL are treated as OpenAI-compatible endpoints.
    #[test]
    fn test_resolve_provider_custom_with_url() {
        let cfg = AppConfig {
            provider: "custom-provider".into(),
            api_url: "https://my-custom-api.com/v1".into(),
            api_headers: "Authorization: custom".into(),
            ..Default::default()
        };
        let (url, headers, format, path) = resolve_provider(&cfg).unwrap();
        assert_eq!(url, "https://my-custom-api.com/v1");
        assert_eq!(headers, "Authorization: custom");
        assert_eq!(format, RequestFormat::OpenAiCompat);
        assert_eq!(path, "choices.0.message.content");
    }

    // ---- LlmCallError display ----

    #[test]
    fn test_llm_call_error_display_http() {
        let err = LlmCallError::HttpError {
            code: 401,
            body: "Unauthorized".into(),
        };
        let display = format!("{err}");
        assert!(display.contains("HTTP 401"));
        assert!(display.contains("Unauthorized"));
    }

    #[test]
    fn test_llm_call_error_display_transport() {
        let err = LlmCallError::TransportError("connection refused".into());
        let display = format!("{err}");
        assert!(display.contains("Network error"));
        assert!(display.contains("connection refused"));
    }

    #[test]
    fn test_llm_call_error_display_other() {
        let err = LlmCallError::Other(anyhow::anyhow!("custom error"));
        let display = format!("{err}");
        assert!(display.contains("custom error"));
    }

    #[test]
    fn test_request_format_equality() {
        assert_eq!(RequestFormat::Gemini, RequestFormat::Gemini);
        assert_eq!(RequestFormat::OpenAiCompat, RequestFormat::OpenAiCompat);
        assert_eq!(RequestFormat::Anthropic, RequestFormat::Anthropic);
        assert_ne!(RequestFormat::Gemini, RequestFormat::OpenAiCompat);
    }
}