chant/provider.rs

//! Model provider abstraction for invoking AI agents.
//!
//! Supports multiple providers (Claude, Ollama, OpenAI).
//!
//! # Doc Audit
//! - audited: 2026-01-25
//! - docs: architecture/invoke.md
//! - ignore: false

use anyhow::{anyhow, Context, Result};
use serde::Deserialize;
use std::io::BufRead;
use std::process::{Command, Stdio};
use ureq::Agent;

/// Model provider type
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ProviderType {
    #[default]
    Claude,
    Ollama,
    Openai,
}

/// Provider configuration
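///
/// Missing sections and fields fall back to the serde defaults defined below.
/// A sketch of that behavior, using JSON purely for illustration (the actual
/// config may be read from a different source or format):
///
/// ```ignore
/// let cfg: ProviderConfig = serde_json::from_str(r#"{ "ollama": {} }"#).unwrap();
/// assert_eq!(cfg.ollama.unwrap().endpoint, "http://localhost:11434/v1");
/// assert!(cfg.openai.is_none());
/// ```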
#[derive(Debug, Clone, Default, Deserialize)]
pub struct ProviderConfig {
    #[serde(default)]
    pub ollama: Option<OllamaConfig>,
    #[serde(default)]
    pub openai: Option<OpenaiConfig>,
}

#[derive(Debug, Clone, Deserialize)]
pub struct OllamaConfig {
    #[serde(default = "default_ollama_endpoint")]
    pub endpoint: String,
    /// Maximum number of retry attempts for throttled requests
    #[serde(default = "default_max_retries")]
    pub max_retries: u32,
    /// Initial delay in milliseconds before first retry
    #[serde(default = "default_retry_delay_ms")]
    pub retry_delay_ms: u64,
}

fn default_ollama_endpoint() -> String {
    "http://localhost:11434/v1".to_string()
}

fn default_max_retries() -> u32 {
    3
}

fn default_retry_delay_ms() -> u64 {
    1000 // 1 second
}

#[derive(Debug, Clone, Deserialize)]
pub struct OpenaiConfig {
    #[serde(default = "default_openai_endpoint")]
    pub endpoint: String,
    /// Maximum number of retry attempts for throttled requests
    #[serde(default = "default_max_retries")]
    pub max_retries: u32,
    /// Initial delay in milliseconds before first retry
    #[serde(default = "default_retry_delay_ms")]
    pub retry_delay_ms: u64,
}

fn default_openai_endpoint() -> String {
    "https://api.openai.com/v1".to_string()
}

/// Trait for model providers
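///
/// Every provider streams output through `callback` as it arrives and also
/// returns the full captured text. A sketch of a caller (the prompt and the
/// model name here are placeholders, not values used by this module):
///
/// ```ignore
/// let provider = ClaudeCliProvider;
/// let transcript = provider.invoke("hello", "some-model", &mut |line| {
///     println!("{line}");
///     Ok(())
/// })?;
/// ```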
pub trait ModelProvider {
    fn invoke(
        &self,
        message: &str,
        model: &str,
        callback: &mut dyn FnMut(&str) -> Result<()>,
    ) -> Result<String>;

    /// Returns the provider name. Part of the trait API, used in tests.
    fn name(&self) -> &'static str;
}

/// Claude CLI provider (existing behavior)
pub struct ClaudeCliProvider;

impl ModelProvider for ClaudeCliProvider {
    fn invoke(
        &self,
        message: &str,
        model: &str,
        callback: &mut dyn FnMut(&str) -> Result<()>,
    ) -> Result<String> {
        let mut cmd = Command::new("claude");
        cmd.arg("--print")
            .arg("--output-format")
            .arg("stream-json")
            .arg("--verbose")
            .arg("--model")
            .arg(model)
            .arg("--dangerously-skip-permissions")
            .arg(message)
            .stdout(Stdio::piped())
            .stderr(Stdio::piped());

        let mut child = cmd
            .spawn()
            .context("Failed to invoke claude CLI. Is it installed and in PATH?")?;

        let mut captured_output = String::new();
        if let Some(stdout) = child.stdout.take() {
            let reader = std::io::BufReader::new(stdout);
            for line in reader.lines().map_while(Result::ok) {
                for text in extract_text_from_stream_json(&line) {
                    for text_line in text.lines() {
                        callback(text_line)?;
                        captured_output.push_str(text_line);
                        captured_output.push('\n');
                    }
                }
            }
        }

        let status = child.wait()?;
        if !status.success() {
            anyhow::bail!("Agent exited with status: {}", status);
        }

        Ok(captured_output)
    }

    fn name(&self) -> &'static str {
        "claude"
    }
}

/// Ollama provider (OpenAI-compatible API with agent runtime)
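///
/// The fields mirror [`OllamaConfig`]. A sketch of wiring the two together,
/// where `cfg` is an `OllamaConfig` obtained elsewhere and the model name is
/// a placeholder:
///
/// ```ignore
/// let provider = OllamaProvider {
///     endpoint: cfg.endpoint,
///     max_retries: cfg.max_retries,
///     retry_delay_ms: cfg.retry_delay_ms,
/// };
/// provider.invoke("hello", "llama3.1", &mut |line| {
///     println!("{line}");
///     Ok(())
/// })?;
/// ```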
pub struct OllamaProvider {
    pub endpoint: String,
    pub max_retries: u32,
    pub retry_delay_ms: u64,
}

impl ModelProvider for OllamaProvider {
    fn invoke(
        &self,
        message: &str,
        model: &str,
        callback: &mut dyn FnMut(&str) -> Result<()>,
    ) -> Result<String> {
        // Validate endpoint URL
        if !self.endpoint.starts_with("http://") && !self.endpoint.starts_with("https://") {
            return Err(anyhow!("Invalid endpoint URL: {}", self.endpoint));
        }

        crate::agent::run_agent_with_retries(
            &self.endpoint,
            model,
            "",
            message,
            callback,
            self.max_retries,
            self.retry_delay_ms,
        )
        .map_err(|e| {
            let err_str = e.to_string();
            if err_str.contains("Connection") || err_str.contains("connect") {
                anyhow!("Failed to connect to Ollama at {}\n\nOllama does not appear to be running. To fix:\n\n  1. Install Ollama: https://ollama.ai/download\n  2. Start Ollama: ollama serve\n  3. Pull a model: ollama pull {}\n\nOr switch to Claude CLI by removing 'provider: ollama' from .chant/config.md", self.endpoint, model)
            } else {
                e
            }
        })
    }

    fn name(&self) -> &'static str {
        "ollama"
    }
}

/// OpenAI provider
pub struct OpenaiProvider {
    pub endpoint: String,
    pub api_key: Option<String>,
    pub max_retries: u32,
    pub retry_delay_ms: u64,
}

impl ModelProvider for OpenaiProvider {
    fn invoke(
        &self,
        message: &str,
        model: &str,
        callback: &mut dyn FnMut(&str) -> Result<()>,
    ) -> Result<String> {
        let api_key = self
            .api_key
            .clone()
            .or_else(|| std::env::var("OPENAI_API_KEY").ok())
            .ok_or_else(|| anyhow!("OPENAI_API_KEY environment variable not set"))?;

        let url = format!("{}/chat/completions", self.endpoint);

        // Validate endpoint URL
        if !self.endpoint.starts_with("http://") && !self.endpoint.starts_with("https://") {
            return Err(anyhow!("Invalid endpoint URL: {}", self.endpoint));
        }

        let request_body = serde_json::json!({
            "model": model,
            "messages": [
                {
                    "role": "user",
                    "content": message
                }
            ],
            "stream": true,
        });

        // Retry loop with exponential backoff
        let mut attempt = 0;
        loop {
            attempt += 1;

            // Create HTTP agent and send the request. ureq reports non-2xx
            // statuses as Error::Status, so recover the response here and let
            // the status handling below decide whether to retry.
            let agent = Agent::new();
            let response = match agent
                .post(&url)
                .set("Content-Type", "application/json")
                .set("Authorization", &format!("Bearer {}", api_key))
                .send_json(&request_body)
            {
                Ok(response) => response,
                Err(ureq::Error::Status(_, response)) => response,
                Err(e) => return Err(anyhow!("HTTP request failed: {}", e)),
            };

            let status = response.status();

            // Check response status
            if status == 401 {
                return Err(anyhow!(
                    "Authentication failed. Check OPENAI_API_KEY env var"
                ));
            }

            // Check for retryable conditions (HTTP 429 or transient 5xx errors)
            let is_retryable =
                status == 429 || status == 500 || status == 502 || status == 503 || status == 504;

            if status == 200 {
                // Success - process response
                return self.process_response(response, callback);
            } else if is_retryable && attempt <= self.max_retries {
                // Retryable error - wait and retry
                let delay_ms = self.calculate_backoff(attempt);
                callback(&format!(
                    "[Retry {}] HTTP {} - waiting {}ms before retry",
                    attempt, status, delay_ms
                ))?;
                std::thread::sleep(std::time::Duration::from_millis(delay_ms));
                continue;
            } else {
                // Non-retryable error or max retries exceeded
                return Err(anyhow!(
                    "HTTP {}: {} (after {} attempt{})",
                    status,
                    response.status_text(),
                    attempt,
                    if attempt == 1 { "" } else { "s" }
                ));
            }
        }
    }

    fn name(&self) -> &'static str {
        "openai"
    }
}

impl OpenaiProvider {
    /// Calculate exponential backoff delay with jitter
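    ///
    /// For `retry_delay_ms = 1000` this yields 1000 ms, 2200 ms, 4000 ms, ...
    /// for attempts 1, 2, 3; the jitter term is deterministic (see the backoff
    /// test in the test module below for these exact values).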
    fn calculate_backoff(&self, attempt: u32) -> u64 {
        let base_delay = self.retry_delay_ms;
        let exponential = 2u64.saturating_pow(attempt - 1);
        let delay = base_delay.saturating_mul(exponential);
        // Deterministic jitter: 0 or 10% of the delay, derived from the attempt
        // number, added on even attempts and subtracted on odd ones to avoid a
        // thundering herd.
        let jitter = (delay / 10).saturating_mul(((attempt as u64).wrapping_mul(7)) % 21 / 10);
        if attempt.is_multiple_of(2) {
            delay.saturating_add(jitter)
        } else {
            delay.saturating_sub(jitter)
        }
    }

    /// Process successful API response
    fn process_response(
        &self,
        response: ureq::Response,
        callback: &mut dyn FnMut(&str) -> Result<()>,
    ) -> Result<String> {
        let reader = std::io::BufReader::new(response.into_reader());
        let mut captured_output = String::new();
        let mut line_buffer = String::new();

        for line in reader.lines().map_while(Result::ok) {
            if let Some(json_str) = line.strip_prefix("data: ") {
                if json_str == "[DONE]" {
                    break;
                }

                if let Ok(json) = serde_json::from_str::<serde_json::Value>(json_str) {
                    if let Some(choices) = json.get("choices").and_then(|c| c.as_array()) {
                        for choice in choices {
                            if let Some(delta) = choice.get("delta") {
                                if let Some(content) = delta.get("content").and_then(|c| c.as_str())
                                {
                                    line_buffer.push_str(content);

                                    // Only invoke the callback once we have complete lines
                                    while let Some(newline_pos) = line_buffer.find('\n') {
                                        let complete_line = &line_buffer[..newline_pos];
                                        callback(complete_line)?;
                                        captured_output.push_str(complete_line);
                                        captured_output.push('\n');
                                        line_buffer = line_buffer[newline_pos + 1..].to_string();
                                    }
                                }
                            }
                        }
                    }
                }
            }
        }

        // Flush any remaining buffered content
        if !line_buffer.is_empty() {
            callback(&line_buffer)?;
            captured_output.push_str(&line_buffer);
            captured_output.push('\n');
        }

        if captured_output.is_empty() {
            return Err(anyhow!("Empty response from OpenAI API"));
        }

        Ok(captured_output)
    }
}

/// Helper function to extract text from Claude CLI stream-json format
fn extract_text_from_stream_json(line: &str) -> Vec<String> {
    let mut texts = Vec::new();

    if let Ok(json) = serde_json::from_str::<serde_json::Value>(line) {
        if let Some("assistant") = json.get("type").and_then(|t| t.as_str()) {
            if let Some(content) = json
                .get("message")
                .and_then(|m| m.get("content"))
                .and_then(|c| c.as_array())
            {
                for item in content {
                    if let Some(text) = item.get("text").and_then(|t| t.as_str()) {
                        texts.push(text.to_string());
                    }
                }
            }
        }
    }

    texts
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_ollama_endpoint() {
        assert_eq!(
            default_ollama_endpoint(),
            "http://localhost:11434/v1".to_string()
        );
    }

    #[test]
    fn test_default_openai_endpoint() {
        assert_eq!(
            default_openai_endpoint(),
            "https://api.openai.com/v1".to_string()
        );
    }

    #[test]
    fn test_claude_provider_name() {
        let provider = ClaudeCliProvider;
        assert_eq!(provider.name(), "claude");
    }

    #[test]
    fn test_ollama_provider_name() {
        let provider = OllamaProvider {
            endpoint: "http://localhost:11434/v1".to_string(),
            max_retries: 3,
            retry_delay_ms: 1000,
        };
        assert_eq!(provider.name(), "ollama");
    }

    #[test]
    fn test_openai_provider_name() {
        let provider = OpenaiProvider {
            endpoint: "https://api.openai.com/v1".to_string(),
            api_key: None,
            max_retries: 3,
            retry_delay_ms: 1000,
        };
        assert_eq!(provider.name(), "openai");
    }
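
    // A spot check of calculate_backoff's deterministic schedule; the expected
    // values are worked out directly from the arithmetic in calculate_backoff
    // (base 1000 ms doubled per attempt, jitter of 0 or 10%).
    #[test]
    fn test_openai_backoff_schedule() {
        let provider = OpenaiProvider {
            endpoint: "https://api.openai.com/v1".to_string(),
            api_key: None,
            max_retries: 3,
            retry_delay_ms: 1000,
        };
        assert_eq!(provider.calculate_backoff(1), 1000);
        assert_eq!(provider.calculate_backoff(2), 2200);
        assert_eq!(provider.calculate_backoff(3), 4000);
    }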

    #[test]
    fn test_provider_type_default() {
        let provider_type: ProviderType = Default::default();
        assert_eq!(provider_type, ProviderType::Claude);
    }
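
    // A minimal check of extract_text_from_stream_json, using a hand-written
    // line shaped like the structure the parser above expects (not captured
    // CLI output).
    #[test]
    fn test_extract_text_from_stream_json_assistant_text() {
        let line = r#"{"type":"assistant","message":{"content":[{"type":"text","text":"hello"}]}}"#;
        assert_eq!(
            extract_text_from_stream_json(line),
            vec!["hello".to_string()]
        );
        assert!(extract_text_from_stream_json("not json").is_empty());
    }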
}