Skip to main content

chant/
provider.rs

1//! Model provider abstraction for invoking AI agents.
2//!
//! Supports multiple providers (Claude CLI, Kiro CLI, Ollama, OpenAI).
4//!
5//! # Doc Audit
6//! - audited: 2026-01-25
7//! - docs: architecture/invoke.md
8//! - ignore: false
9
10use anyhow::{anyhow, Context, Result};
11use serde::Deserialize;
12use std::io::BufRead;
13use std::process::{Command, Stdio};
14use ureq::Agent;
15
16/// Model provider type
17#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Deserialize)]
18#[serde(rename_all = "lowercase")]
19pub enum ProviderType {
20    #[default]
21    Claude,
22    Ollama,
23    Openai,
24    Kirocli,
25}
26
27/// Provider configuration
28#[derive(Debug, Clone, Default, Deserialize)]
29pub struct ProviderConfig {
30    #[serde(default)]
31    pub ollama: Option<OllamaConfig>,
32    #[serde(default)]
33    pub openai: Option<OpenaiConfig>,
34}
35
36#[derive(Debug, Clone, Deserialize)]
37pub struct OllamaConfig {
38    #[serde(default = "default_ollama_endpoint")]
39    pub endpoint: String,
40    /// Maximum number of retry attempts for throttled requests
41    #[serde(default = "default_max_retries")]
42    pub max_retries: u32,
43    /// Initial delay in milliseconds before first retry
44    #[serde(default = "default_retry_delay_ms")]
45    pub retry_delay_ms: u64,
46}
47
/// Default base URL for a locally running Ollama server.
fn default_ollama_endpoint() -> String {
    String::from("http://localhost:11434/v1")
}
51
/// Default number of retry attempts for throttled requests.
fn default_max_retries() -> u32 {
    const MAX_RETRIES: u32 = 3;
    MAX_RETRIES
}
55
/// Default delay before the first retry, in milliseconds.
fn default_retry_delay_ms() -> u64 {
    const ONE_SECOND_MS: u64 = 1_000;
    ONE_SECOND_MS
}
59
60#[derive(Debug, Clone, Deserialize)]
61pub struct OpenaiConfig {
62    #[serde(default = "default_openai_endpoint")]
63    pub endpoint: String,
64    /// Maximum number of retry attempts for throttled requests
65    #[serde(default = "default_max_retries")]
66    pub max_retries: u32,
67    /// Initial delay in milliseconds before first retry
68    #[serde(default = "default_retry_delay_ms")]
69    pub retry_delay_ms: u64,
70}
71
/// Default base URL of the public OpenAI API.
fn default_openai_endpoint() -> String {
    String::from("https://api.openai.com/v1")
}
75
76/// Trait for model providers
77pub trait ModelProvider {
78    fn invoke(
79        &self,
80        message: &str,
81        model: &str,
82        callback: &mut dyn FnMut(&str) -> Result<()>,
83    ) -> Result<String>;
84
85    /// Returns the provider name. Part of the trait API, used in tests.
86    fn name(&self) -> &'static str;
87}
88
89/// Kiro CLI provider (MCP-only)
90pub struct KiroCliProvider;
91
92impl ModelProvider for KiroCliProvider {
93    fn invoke(
94        &self,
95        message: &str,
96        model: &str,
97        callback: &mut dyn FnMut(&str) -> Result<()>,
98    ) -> Result<String> {
99        let mut cmd = Command::new("kiro-cli-chat");
100        cmd.arg("chat")
101            .arg("--no-interactive")
102            .arg("--trust-all-tools")
103            .arg("--model")
104            .arg(model)
105            .arg(message)
106            .stdout(Stdio::piped())
107            .stderr(Stdio::piped());
108
109        let mut child = cmd
110            .spawn()
111            .context("Failed to invoke kiro-cli-chat. Is it installed and in PATH?")?;
112
113        // Capture stderr in a separate thread
114        let stderr_handle = child.stderr.take().map(|stderr| {
115            std::thread::spawn(move || {
116                let reader = std::io::BufReader::new(stderr);
117                let mut stderr_output = String::new();
118                for line in reader.lines().map_while(Result::ok) {
119                    stderr_output.push_str(&line);
120                    stderr_output.push('\n');
121                }
122                stderr_output
123            })
124        });
125
126        let mut captured_output = String::new();
127        if let Some(stdout) = child.stdout.take() {
128            let reader = std::io::BufReader::new(stdout);
129            for line in reader.lines().map_while(Result::ok) {
130                callback(&line)?;
131                captured_output.push_str(&line);
132                captured_output.push('\n');
133            }
134        }
135
136        let status = child.wait()?;
137
138        // Collect stderr
139        let stderr_output = stderr_handle
140            .and_then(|h| h.join().ok())
141            .unwrap_or_default();
142
143        if !status.success() {
144            if !stderr_output.is_empty() {
145                anyhow::bail!(
146                    "kiro-cli-chat exited with status: {}\nStderr: {}",
147                    status,
148                    stderr_output.trim()
149                );
150            }
151            anyhow::bail!("kiro-cli-chat exited with status: {}", status);
152        }
153
154        Ok(captured_output)
155    }
156
157    fn name(&self) -> &'static str {
158        "kirocli"
159    }
160}
161
162/// Claude CLI provider (existing behavior)
163pub struct ClaudeCliProvider;
164
165impl ModelProvider for ClaudeCliProvider {
166    fn invoke(
167        &self,
168        message: &str,
169        model: &str,
170        callback: &mut dyn FnMut(&str) -> Result<()>,
171    ) -> Result<String> {
172        let mut cmd = Command::new("claude");
173        cmd.arg("--print")
174            .arg("--output-format")
175            .arg("stream-json")
176            .arg("--verbose")
177            .arg("--model")
178            .arg(model)
179            .arg("--dangerously-skip-permissions")
180            .arg(message)
181            .stdout(Stdio::piped())
182            .stderr(Stdio::piped());
183
184        let mut child = cmd
185            .spawn()
186            .context("Failed to invoke claude CLI. Is it installed and in PATH?")?;
187
188        let mut captured_output = String::new();
189        if let Some(stdout) = child.stdout.take() {
190            let reader = std::io::BufReader::new(stdout);
191            for line in reader.lines().map_while(Result::ok) {
192                for text in extract_text_from_stream_json(&line) {
193                    for text_line in text.lines() {
194                        callback(text_line)?;
195                        captured_output.push_str(text_line);
196                        captured_output.push('\n');
197                    }
198                }
199            }
200        }
201
202        let status = child.wait()?;
203        if !status.success() {
204            anyhow::bail!("Agent exited with status: {}", status);
205        }
206
207        Ok(captured_output)
208    }
209
210    fn name(&self) -> &'static str {
211        "claude"
212    }
213}
214
215/// Ollama provider (OpenAI-compatible API with agent runtime)
216pub struct OllamaProvider {
217    pub endpoint: String,
218    pub max_retries: u32,
219    pub retry_delay_ms: u64,
220}
221
222impl ModelProvider for OllamaProvider {
223    fn invoke(
224        &self,
225        message: &str,
226        model: &str,
227        callback: &mut dyn FnMut(&str) -> Result<()>,
228    ) -> Result<String> {
229        // Validate endpoint URL
230        if !self.endpoint.starts_with("http://") && !self.endpoint.starts_with("https://") {
231            return Err(anyhow!("Invalid endpoint URL: {}", self.endpoint));
232        }
233
234        crate::agent::run_agent_with_retries(
235            &self.endpoint,
236            model,
237            "",
238            message,
239            callback,
240            self.max_retries,
241            self.retry_delay_ms,
242        )
243        .map_err(|e| {
244            let err_str = e.to_string();
245            if err_str.contains("Connection") || err_str.contains("connect") {
246                anyhow!("Failed to connect to Ollama at {}\n\nOllama does not appear to be running. To fix:\n\n  1. Install Ollama: https://ollama.ai/download\n  2. Start Ollama: ollama serve\n  3. Pull a model: ollama pull {}\n\nOr switch to Claude CLI by removing 'provider: ollama' from .chant/config.md", self.endpoint, model)
247            } else {
248                e
249            }
250        })
251    }
252
253    fn name(&self) -> &'static str {
254        "ollama"
255    }
256}
257
/// OpenAI provider: talks to an OpenAI-compatible `/chat/completions` endpoint.
pub struct OpenaiProvider {
    /// Base URL of the API, e.g. `https://api.openai.com/v1`.
    pub endpoint: String,
    /// Explicit API key; when `None`, the `OPENAI_API_KEY` environment variable is used.
    pub api_key: Option<String>,
    /// Number of retries allowed for throttled or transiently failing requests.
    pub max_retries: u32,
    /// Base delay in milliseconds for the exponential retry backoff.
    pub retry_delay_ms: u64,
}
265
266impl ModelProvider for OpenaiProvider {
267    fn invoke(
268        &self,
269        message: &str,
270        model: &str,
271        callback: &mut dyn FnMut(&str) -> Result<()>,
272    ) -> Result<String> {
273        let api_key = self
274            .api_key
275            .clone()
276            .or_else(|| std::env::var("OPENAI_API_KEY").ok())
277            .ok_or_else(|| anyhow!("OPENAI_API_KEY environment variable not set"))?;
278
279        let url = format!("{}/chat/completions", self.endpoint);
280
281        // Validate endpoint URL
282        if !self.endpoint.starts_with("http://") && !self.endpoint.starts_with("https://") {
283            return Err(anyhow!("Invalid endpoint URL: {}", self.endpoint));
284        }
285
286        let request_body = serde_json::json!({
287            "model": model,
288            "messages": [
289                {
290                    "role": "user",
291                    "content": message
292                }
293            ],
294            "stream": true,
295        });
296
297        // Retry loop with exponential backoff
298        let mut attempt = 0;
299        loop {
300            attempt += 1;
301
302            // Create HTTP agent and send request
303            let agent = Agent::new();
304            let response = agent
305                .post(&url)
306                .set("Content-Type", "application/json")
307                .set("Authorization", &format!("Bearer {}", api_key))
308                .send_json(&request_body)
309                .map_err(|e| anyhow!("HTTP request failed: {}", e))?;
310
311            let status = response.status();
312
313            // Check response status
314            if status == 401 {
315                return Err(anyhow!(
316                    "Authentication failed. Check OPENAI_API_KEY env var"
317                ));
318            }
319
320            // Check for throttle/error conditions (429 or 400+ errors)
321            let is_retryable =
322                status == 429 || status == 500 || status == 502 || status == 503 || status == 504;
323
324            if status == 200 {
325                // Success - process response
326                return self.process_response(response, callback);
327            } else if is_retryable && attempt <= self.max_retries {
328                // Retryable error - wait and retry
329                let delay_ms = self.calculate_backoff(attempt);
330                callback(&format!(
331                    "[Retry {}] HTTP {} - waiting {}ms before retry",
332                    attempt, status, delay_ms
333                ))?;
334                std::thread::sleep(std::time::Duration::from_millis(delay_ms));
335                continue;
336            } else {
337                // Non-retryable error or max retries exceeded
338                return Err(anyhow!(
339                    "HTTP {}: {} (after {} attempt{})",
340                    status,
341                    response.status_text(),
342                    attempt,
343                    if attempt == 1 { "" } else { "s" }
344                ));
345            }
346        }
347    }
348
349    fn name(&self) -> &'static str {
350        "openai"
351    }
352}
353
354impl OpenaiProvider {
355    /// Calculate exponential backoff delay with jitter
356    fn calculate_backoff(&self, attempt: u32) -> u64 {
357        let base_delay = self.retry_delay_ms;
358        let exponential = 2u64.saturating_pow(attempt - 1);
359        let delay = base_delay.saturating_mul(exponential);
360        // Add jitter: ±10% of delay to avoid thundering herd
361        let jitter = (delay / 10).saturating_mul(
362            ((attempt as u64).wrapping_mul(7)) % 21 / 10, // Deterministic pseudo-random jitter
363        );
364        if attempt.is_multiple_of(2) {
365            delay.saturating_add(jitter)
366        } else {
367            delay.saturating_sub(jitter)
368        }
369    }
370
371    /// Process successful API response
372    fn process_response(
373        &self,
374        response: ureq::Response,
375        callback: &mut dyn FnMut(&str) -> Result<()>,
376    ) -> Result<String> {
377        let reader = std::io::BufReader::new(response.into_reader());
378        let mut captured_output = String::new();
379        let mut line_buffer = String::new();
380
381        for line in reader.lines().map_while(Result::ok) {
382            if let Some(json_str) = line.strip_prefix("data: ") {
383                if json_str == "[DONE]" {
384                    break;
385                }
386
387                if let Ok(json) = serde_json::from_str::<serde_json::Value>(json_str) {
388                    if let Some(choices) = json.get("choices").and_then(|c| c.as_array()) {
389                        for choice in choices {
390                            if let Some(delta) = choice.get("delta") {
391                                if let Some(content) = delta.get("content").and_then(|c| c.as_str())
392                                {
393                                    line_buffer.push_str(content);
394
395                                    // Only callback when we have complete lines
396                                    while let Some(newline_pos) = line_buffer.find('\n') {
397                                        let complete_line = &line_buffer[..newline_pos];
398                                        callback(complete_line)?;
399                                        captured_output.push_str(complete_line);
400                                        captured_output.push('\n');
401                                        line_buffer = line_buffer[newline_pos + 1..].to_string();
402                                    }
403                                }
404                            }
405                        }
406                    }
407                }
408            }
409        }
410
411        // Flush any remaining buffered content
412        if !line_buffer.is_empty() {
413            callback(&line_buffer)?;
414            captured_output.push_str(&line_buffer);
415            captured_output.push('\n');
416        }
417
418        if captured_output.is_empty() {
419            return Err(anyhow!("Empty response from OpenAI API"));
420        }
421
422        Ok(captured_output)
423    }
424}
425
426/// Helper function to extract text from Claude CLI stream-json format
427fn extract_text_from_stream_json(line: &str) -> Vec<String> {
428    let mut texts = Vec::new();
429
430    if let Ok(json) = serde_json::from_str::<serde_json::Value>(line) {
431        if let Some("assistant") = json.get("type").and_then(|t| t.as_str()) {
432            if let Some(content) = json
433                .get("message")
434                .and_then(|m| m.get("content"))
435                .and_then(|c| c.as_array())
436            {
437                for item in content {
438                    if let Some(text) = item.get("text").and_then(|t| t.as_str()) {
439                        texts.push(text.to_string());
440                    }
441                }
442            }
443        }
444    }
445
446    texts
447}
448
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `OpenaiProvider` with the given base retry delay.
    fn openai_provider(retry_delay_ms: u64) -> OpenaiProvider {
        OpenaiProvider {
            endpoint: "https://api.openai.com/v1".to_string(),
            api_key: None,
            max_retries: 3,
            retry_delay_ms,
        }
    }

    #[test]
    fn test_default_ollama_endpoint() {
        assert_eq!(
            default_ollama_endpoint(),
            "http://localhost:11434/v1".to_string()
        );
    }

    #[test]
    fn test_default_openai_endpoint() {
        assert_eq!(
            default_openai_endpoint(),
            "https://api.openai.com/v1".to_string()
        );
    }

    #[test]
    fn test_default_retry_settings() {
        assert_eq!(default_max_retries(), 3);
        assert_eq!(default_retry_delay_ms(), 1000);
    }

    #[test]
    fn test_claude_provider_name() {
        let provider = ClaudeCliProvider;
        assert_eq!(provider.name(), "claude");
    }

    #[test]
    fn test_ollama_provider_name() {
        let provider = OllamaProvider {
            endpoint: "http://localhost:11434/v1".to_string(),
            max_retries: 3,
            retry_delay_ms: 1000,
        };
        assert_eq!(provider.name(), "ollama");
    }

    #[test]
    fn test_openai_provider_name() {
        let provider = openai_provider(1000);
        assert_eq!(provider.name(), "openai");
    }

    #[test]
    fn test_provider_type_default() {
        let provider_type: ProviderType = Default::default();
        assert_eq!(provider_type, ProviderType::Claude);
    }

    #[test]
    fn test_kirocli_provider_name() {
        let provider = KiroCliProvider;
        assert_eq!(provider.name(), "kirocli");
    }

    #[test]
    fn test_calculate_backoff_doubles_with_deterministic_jitter() {
        let provider = openai_provider(1000);
        // Jitter is 10% only when attempt % 3 == 2: added on even attempts,
        // subtracted on odd ones.
        assert_eq!(provider.calculate_backoff(1), 1000);
        assert_eq!(provider.calculate_backoff(2), 2200);
        assert_eq!(provider.calculate_backoff(3), 4000);
        assert_eq!(provider.calculate_backoff(4), 8000);
        assert_eq!(provider.calculate_backoff(5), 14400);
    }

    #[test]
    fn test_calculate_backoff_saturates() {
        let provider = openai_provider(u64::MAX);
        assert_eq!(provider.calculate_backoff(64), u64::MAX);
    }

    #[test]
    fn test_extract_text_from_stream_json() {
        let line = r#"{"type":"assistant","message":{"content":[{"type":"text","text":"hello"},{"type":"tool_use","name":"x"},{"type":"text","text":"world"}]}}"#;
        assert_eq!(
            extract_text_from_stream_json(line),
            vec!["hello".to_string(), "world".to_string()]
        );

        let user_line = r#"{"type":"user","message":{"content":[{"text":"hi"}]}}"#;
        assert!(extract_text_from_stream_json(user_line).is_empty());
        assert!(extract_text_from_stream_json("not json").is_empty());
    }
}