1use anyhow::{bail, Context, Result};
2use colored::Colorize;
3use indicatif::{ProgressBar, ProgressStyle};
4use serde_json::Value;
5use std::time::Duration;
6
7use crate::config::AppConfig;
8use crate::interpolation::interpolate;
9
/// Wire format used to build the request body for a provider.
#[derive(Debug, Clone, Copy, PartialEq)]
enum RequestFormat {
    /// Google Gemini `generateContent` request shape.
    Gemini,
    /// OpenAI Chat Completions shape (used by OpenAI, Groq and custom providers).
    OpenAiCompat,
    /// Anthropic Messages API shape.
    Anthropic,
}
16
17struct ProviderDef {
18 api_url: &'static str,
19 api_headers: &'static str,
20 default_model: &'static str,
21 format: RequestFormat,
22 response_path: &'static str,
23}
24
25fn get_provider(name: &str) -> Option<ProviderDef> {
27 match name {
28 "gemini" => Some(ProviderDef {
29 api_url: "https://generativelanguage.googleapis.com/v1beta/models/$ACR_MODEL:generateContent?key=$ACR_API_KEY",
30 api_headers: "",
31 default_model: "gemini-2.0-flash",
32 format: RequestFormat::Gemini,
33 response_path: "candidates.0.content.parts.0.text",
34 }),
35 "openai" => Some(ProviderDef {
36 api_url: "https://api.openai.com/v1/chat/completions",
37 api_headers: "Authorization: Bearer $ACR_API_KEY",
38 default_model: "gpt-4o-mini",
39 format: RequestFormat::OpenAiCompat,
40 response_path: "choices.0.message.content",
41 }),
42 "anthropic" => Some(ProviderDef {
43 api_url: "https://api.anthropic.com/v1/messages",
44 api_headers: "x-api-key: $ACR_API_KEY, anthropic-version: 2023-06-01",
45 default_model: "claude-sonnet-4-20250514",
46 format: RequestFormat::Anthropic,
47 response_path: "content.0.text",
48 }),
49 "groq" => Some(ProviderDef {
50 api_url: "https://api.groq.com/openai/v1/chat/completions",
51 api_headers: "Authorization: Bearer $ACR_API_KEY",
52 default_model: "llama-3.3-70b-versatile",
53 format: RequestFormat::OpenAiCompat,
54 response_path: "choices.0.message.content",
55 }),
56 _ => None,
57 }
58}
59
60pub fn default_model_for(provider: &str) -> &'static str {
62 get_provider(provider).map_or("", |p| p.default_model)
63}
64
65pub fn call_llm(cfg: &AppConfig, system_prompt: &str, diff: &str) -> Result<String> {
67 let (url, headers_raw, format, response_path) = resolve_provider(cfg)?;
68
69 let url = interpolate(&url, cfg);
70 let headers_raw = interpolate(&headers_raw, cfg);
71
72 let body = build_request_body(format, &cfg.model, system_prompt, diff);
73
74 let headers = parse_headers(&headers_raw);
75
76 let spinner = ProgressBar::new_spinner();
78 spinner.set_style(
79 ProgressStyle::default_spinner()
80 .template("{spinner:.cyan} {msg} {elapsed}")
81 .unwrap(),
82 );
83 spinner.set_message("Generating commit message...");
84 spinner.enable_steady_tick(Duration::from_millis(80));
85
86 let mut req = ureq::post(&url);
88 for (key, val) in &headers {
89 req = req.set(key, val);
90 }
91 req = req.set("Content-Type", "application/json");
92
93 let response = req.send_json(&body);
94
95 spinner.finish_and_clear();
96
97 let response = response.map_err(|e| match e {
98 ureq::Error::Status(code, resp) => {
99 let body = resp.into_string().unwrap_or_default();
100 anyhow::anyhow!("API returned HTTP {code}: {body}")
101 }
102 ureq::Error::Transport(t) => {
103 anyhow::anyhow!("Network error: {t}")
104 }
105 })?;
106
107 let json: Value = response
108 .into_json()
109 .context("Failed to parse API response as JSON")?;
110
111 let message = extract_by_path(&json, &response_path).with_context(|| {
112 format!(
113 "Failed to extract message from response at path '{}'. Response:\n{}",
114 response_path,
115 serde_json::to_string_pretty(&json).unwrap_or_default()
116 )
117 })?;
118
119 Ok(message)
120}
121
122fn resolve_provider(cfg: &AppConfig) -> Result<(String, String, RequestFormat, String)> {
123 if let Some(def) = get_provider(&cfg.provider) {
124 let url = if cfg.api_url.is_empty() {
125 def.api_url.to_string()
126 } else {
127 cfg.api_url.clone()
128 };
129 let headers = if cfg.api_headers.is_empty() {
130 def.api_headers.to_string()
131 } else {
132 cfg.api_headers.clone()
133 };
134 Ok((url, headers, def.format, def.response_path.to_string()))
135 } else {
136 if cfg.api_url.is_empty() {
138 bail!(
139 "Unknown provider '{}'. Set {} for custom providers.",
140 cfg.provider.yellow(),
141 "ACR_API_URL".yellow()
142 );
143 }
144 Ok((
145 cfg.api_url.clone(),
146 cfg.api_headers.clone(),
147 RequestFormat::OpenAiCompat,
148 "choices.0.message.content".to_string(),
149 ))
150 }
151}
152
153fn build_request_body(
154 format: RequestFormat,
155 model: &str,
156 system_prompt: &str,
157 diff: &str,
158) -> Value {
159 match format {
160 RequestFormat::Gemini => {
161 serde_json::json!({
162 "system_instruction": {
163 "parts": [{ "text": system_prompt }]
164 },
165 "contents": [{
166 "role": "user",
167 "parts": [{ "text": diff }]
168 }],
169 "generationConfig": {
170 "temperature": 0
171 }
172 })
173 }
174 RequestFormat::OpenAiCompat => {
175 serde_json::json!({
176 "model": model,
177 "messages": [
178 { "role": "system", "content": system_prompt },
179 { "role": "user", "content": diff }
180 ],
181 "max_tokens": 512,
182 "temperature": 0
183 })
184 }
185 RequestFormat::Anthropic => {
186 serde_json::json!({
187 "model": model,
188 "system": system_prompt,
189 "messages": [
190 { "role": "user", "content": diff }
191 ],
192 "max_tokens": 512
193 })
194 }
195 }
196}
197
/// Parses a comma-separated `Name: value` list into `(name, value)` pairs.
///
/// Names and values are trimmed, and a value may itself contain `:` (only the
/// first colon separates name from value). A segment with no `:` is treated as
/// a continuation of the previous header's value, so `Accept: a, b` yields one
/// `Accept` header with value `a, b` instead of truncating it to `a`.
///
/// Ignored inputs:
/// - empty segments;
/// - segments with an empty name (they would be invalid header names);
/// - continuations with no valid header before them.
fn parse_headers(raw: &str) -> Vec<(String, String)> {
    let mut headers: Vec<(String, String)> = Vec::new();
    // Whether a colon-less segment may extend the last pushed header: false at
    // the start and right after a skipped, nameless header.
    let mut can_continue = false;

    for segment in raw.split(',').map(str::trim) {
        if segment.is_empty() {
            continue;
        }
        match segment.split_once(':') {
            Some((name, value)) => {
                let name = name.trim();
                can_continue = !name.is_empty();
                if can_continue {
                    headers.push((name.to_string(), value.trim().to_string()));
                }
            }
            None if can_continue => {
                if let Some((_, value)) = headers.last_mut() {
                    value.push_str(", ");
                    value.push_str(segment);
                }
            }
            None => {}
        }
    }
    headers
}
211
212fn extract_by_path(value: &Value, path: &str) -> Result<String> {
214 let mut current = value;
215 for segment in path.split('.') {
216 current = if let Ok(index) = segment.parse::<usize>() {
217 current
218 .get(index)
219 .with_context(|| format!("Array index {index} out of bounds"))?
220 } else {
221 current
222 .get(segment)
223 .with_context(|| format!("Key '{segment}' not found"))?
224 };
225 }
226 current
227 .as_str()
228 .map(|s| s.to_string())
229 .with_context(|| "Expected string value at path end".to_string())
230}