Skip to main content

auto_commit_rs/
provider.rs

1use anyhow::{bail, Context, Result};
2use colored::Colorize;
3use indicatif::{ProgressBar, ProgressStyle};
4use serde_json::Value;
5use std::time::Duration;
6
7use crate::config::AppConfig;
8use crate::interpolation::interpolate;
9
/// Wire format of the JSON request body expected by a provider's API.
///
/// Fieldless and trivially copyable; `Eq` is derived alongside `PartialEq`
/// since equality on unit variants is total (clippy:
/// `derive_partial_eq_without_eq`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum RequestFormat {
    /// Google Gemini `generateContent` body (`system_instruction` + `contents`).
    Gemini,
    /// OpenAI-style `/chat/completions` body (`model` + `messages`).
    OpenAiCompat,
    /// Anthropic `/v1/messages` body (top-level `system` + `messages`).
    Anthropic,
}
16
17struct ProviderDef {
18    api_url: &'static str,
19    api_headers: &'static str,
20    default_model: &'static str,
21    format: RequestFormat,
22    response_path: &'static str,
23}
24
25/// Built-in provider definitions
26fn get_provider(name: &str) -> Option<ProviderDef> {
27    match name {
28        "gemini" => Some(ProviderDef {
29            api_url: "https://generativelanguage.googleapis.com/v1beta/models/$ACR_MODEL:generateContent?key=$ACR_API_KEY",
30            api_headers: "",
31            default_model: "gemini-2.0-flash",
32            format: RequestFormat::Gemini,
33            response_path: "candidates.0.content.parts.0.text",
34        }),
35        "openai" => Some(ProviderDef {
36            api_url: "https://api.openai.com/v1/chat/completions",
37            api_headers: "Authorization: Bearer $ACR_API_KEY",
38            default_model: "gpt-4o-mini",
39            format: RequestFormat::OpenAiCompat,
40            response_path: "choices.0.message.content",
41        }),
42        "anthropic" => Some(ProviderDef {
43            api_url: "https://api.anthropic.com/v1/messages",
44            api_headers: "x-api-key: $ACR_API_KEY, anthropic-version: 2023-06-01",
45            default_model: "claude-sonnet-4-20250514",
46            format: RequestFormat::Anthropic,
47            response_path: "content.0.text",
48        }),
49        "groq" => Some(ProviderDef {
50            api_url: "https://api.groq.com/openai/v1/chat/completions",
51            api_headers: "Authorization: Bearer $ACR_API_KEY",
52            default_model: "llama-3.3-70b-versatile",
53            format: RequestFormat::OpenAiCompat,
54            response_path: "choices.0.message.content",
55        }),
56        _ => None,
57    }
58}
59
60/// Get the default model for a built-in provider, or empty string for unknown providers.
61pub fn default_model_for(provider: &str) -> &'static str {
62    get_provider(provider).map_or("", |p| p.default_model)
63}
64
65/// Call the LLM API and return the generated commit message
66pub fn call_llm(cfg: &AppConfig, system_prompt: &str, diff: &str) -> Result<String> {
67    let (url, headers_raw, format, response_path) = resolve_provider(cfg)?;
68
69    let url = interpolate(&url, cfg);
70    let headers_raw = interpolate(&headers_raw, cfg);
71
72    let body = build_request_body(format, &cfg.model, system_prompt, diff);
73
74    let headers = parse_headers(&headers_raw);
75
76    // Spinner
77    let spinner = ProgressBar::new_spinner();
78    spinner.set_style(
79        ProgressStyle::default_spinner()
80            .template("{spinner:.cyan} {msg} {elapsed}")
81            .unwrap(),
82    );
83    spinner.set_message("Generating commit message...");
84    spinner.enable_steady_tick(Duration::from_millis(80));
85
86    // HTTP request
87    let mut req = ureq::post(&url);
88    for (key, val) in &headers {
89        req = req.set(key, val);
90    }
91    req = req.set("Content-Type", "application/json");
92
93    let response = req.send_json(&body);
94
95    spinner.finish_and_clear();
96
97    let response = response.map_err(|e| match e {
98        ureq::Error::Status(code, resp) => {
99            let body = resp.into_string().unwrap_or_default();
100            anyhow::anyhow!("API returned HTTP {code}: {body}")
101        }
102        ureq::Error::Transport(t) => {
103            anyhow::anyhow!("Network error: {t}")
104        }
105    })?;
106
107    let json: Value = response
108        .into_json()
109        .context("Failed to parse API response as JSON")?;
110
111    let message = extract_by_path(&json, &response_path).with_context(|| {
112        format!(
113            "Failed to extract message from response at path '{}'. Response:\n{}",
114            response_path,
115            serde_json::to_string_pretty(&json).unwrap_or_default()
116        )
117    })?;
118
119    Ok(message)
120}
121
122fn resolve_provider(cfg: &AppConfig) -> Result<(String, String, RequestFormat, String)> {
123    if let Some(def) = get_provider(&cfg.provider) {
124        let url = if cfg.api_url.is_empty() {
125            def.api_url.to_string()
126        } else {
127            cfg.api_url.clone()
128        };
129        let headers = if cfg.api_headers.is_empty() {
130            def.api_headers.to_string()
131        } else {
132            cfg.api_headers.clone()
133        };
134        Ok((url, headers, def.format, def.response_path.to_string()))
135    } else {
136        // Custom provider: require API URL, default to OpenAI-compatible format
137        if cfg.api_url.is_empty() {
138            bail!(
139                "Unknown provider '{}'. Set {} for custom providers.",
140                cfg.provider.yellow(),
141                "ACR_API_URL".yellow()
142            );
143        }
144        Ok((
145            cfg.api_url.clone(),
146            cfg.api_headers.clone(),
147            RequestFormat::OpenAiCompat,
148            "choices.0.message.content".to_string(),
149        ))
150    }
151}
152
153fn build_request_body(
154    format: RequestFormat,
155    model: &str,
156    system_prompt: &str,
157    diff: &str,
158) -> Value {
159    match format {
160        RequestFormat::Gemini => {
161            serde_json::json!({
162                "system_instruction": {
163                    "parts": [{ "text": system_prompt }]
164                },
165                "contents": [{
166                    "role": "user",
167                    "parts": [{ "text": diff }]
168                }],
169                "generationConfig": {
170                    "temperature": 0
171                }
172            })
173        }
174        RequestFormat::OpenAiCompat => {
175            serde_json::json!({
176                "model": model,
177                "messages": [
178                    { "role": "system", "content": system_prompt },
179                    { "role": "user", "content": diff }
180                ],
181                "max_tokens": 512,
182                "temperature": 0
183            })
184        }
185        RequestFormat::Anthropic => {
186            serde_json::json!({
187                "model": model,
188                "system": system_prompt,
189                "messages": [
190                    { "role": "user", "content": diff }
191                ],
192                "max_tokens": 512
193            })
194        }
195    }
196}
197
/// Split a `"Key: Value, Key2: Value2"` string into `(key, value)` pairs.
///
/// Whitespace around keys and values is trimmed; pieces without a `:` are
/// silently dropped (an empty or blank input therefore yields no pairs).
/// Note: header values must not themselves contain commas.
fn parse_headers(raw: &str) -> Vec<(String, String)> {
    let mut pairs = Vec::new();
    for piece in raw.split(',') {
        if let Some((key, value)) = piece.trim().split_once(':') {
            pairs.push((key.trim().to_string(), value.trim().to_string()));
        }
    }
    pairs
}
211
212/// Walk a JSON value by a dot-separated path like "candidates.0.content.parts.0.text"
213fn extract_by_path(value: &Value, path: &str) -> Result<String> {
214    let mut current = value;
215    for segment in path.split('.') {
216        current = if let Ok(index) = segment.parse::<usize>() {
217            current
218                .get(index)
219                .with_context(|| format!("Array index {index} out of bounds"))?
220        } else {
221            current
222                .get(segment)
223                .with_context(|| format!("Key '{segment}' not found"))?
224        };
225    }
226    current
227        .as_str()
228        .map(|s| s.to_string())
229        .with_context(|| "Expected string value at path end".to_string())
230}