1use anyhow::{anyhow, Context, Result};
11use serde::Deserialize;
12use std::io::BufRead;
13use std::process::{Command, Stdio};
14use ureq::Agent;
15
16#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Deserialize)]
18#[serde(rename_all = "lowercase")]
19pub enum ProviderType {
20 #[default]
21 Claude,
22 Ollama,
23 Openai,
24}
25
26#[derive(Debug, Clone, Default, Deserialize)]
28pub struct ProviderConfig {
29 #[serde(default)]
30 pub ollama: Option<OllamaConfig>,
31 #[serde(default)]
32 pub openai: Option<OpenaiConfig>,
33}
34
35#[derive(Debug, Clone, Deserialize)]
36pub struct OllamaConfig {
37 #[serde(default = "default_ollama_endpoint")]
38 pub endpoint: String,
39 #[serde(default = "default_max_retries")]
41 pub max_retries: u32,
42 #[serde(default = "default_retry_delay_ms")]
44 pub retry_delay_ms: u64,
45}
46
/// Ollama base URL used when a config omits `endpoint`.
fn default_ollama_endpoint() -> String {
    String::from("http://localhost:11434/v1")
}
50
/// Retry count used when a config omits `max_retries`.
fn default_max_retries() -> u32 {
    const DEFAULT_MAX_RETRIES: u32 = 3;
    DEFAULT_MAX_RETRIES
}
54
/// Base retry delay in milliseconds used when a config omits `retry_delay_ms`.
fn default_retry_delay_ms() -> u64 {
    const DEFAULT_RETRY_DELAY_MS: u64 = 1_000;
    DEFAULT_RETRY_DELAY_MS
}
58
59#[derive(Debug, Clone, Deserialize)]
60pub struct OpenaiConfig {
61 #[serde(default = "default_openai_endpoint")]
62 pub endpoint: String,
63 #[serde(default = "default_max_retries")]
65 pub max_retries: u32,
66 #[serde(default = "default_retry_delay_ms")]
68 pub retry_delay_ms: u64,
69}
70
/// OpenAI base URL used when a config omits `endpoint`.
fn default_openai_endpoint() -> String {
    String::from("https://api.openai.com/v1")
}
74
75pub trait ModelProvider {
77 fn invoke(
78 &self,
79 message: &str,
80 model: &str,
81 callback: &mut dyn FnMut(&str) -> Result<()>,
82 ) -> Result<String>;
83
84 fn name(&self) -> &'static str;
86}
87
88pub struct ClaudeCliProvider;
90
91impl ModelProvider for ClaudeCliProvider {
92 fn invoke(
93 &self,
94 message: &str,
95 model: &str,
96 callback: &mut dyn FnMut(&str) -> Result<()>,
97 ) -> Result<String> {
98 let mut cmd = Command::new("claude");
99 cmd.arg("--print")
100 .arg("--output-format")
101 .arg("stream-json")
102 .arg("--verbose")
103 .arg("--model")
104 .arg(model)
105 .arg("--dangerously-skip-permissions")
106 .arg(message)
107 .stdout(Stdio::piped())
108 .stderr(Stdio::piped());
109
110 let mut child = cmd
111 .spawn()
112 .context("Failed to invoke claude CLI. Is it installed and in PATH?")?;
113
114 let mut captured_output = String::new();
115 if let Some(stdout) = child.stdout.take() {
116 let reader = std::io::BufReader::new(stdout);
117 for line in reader.lines().map_while(Result::ok) {
118 for text in extract_text_from_stream_json(&line) {
119 for text_line in text.lines() {
120 callback(text_line)?;
121 captured_output.push_str(text_line);
122 captured_output.push('\n');
123 }
124 }
125 }
126 }
127
128 let status = child.wait()?;
129 if !status.success() {
130 anyhow::bail!("Agent exited with status: {}", status);
131 }
132
133 Ok(captured_output)
134 }
135
136 fn name(&self) -> &'static str {
137 "claude"
138 }
139}
140
141pub struct OllamaProvider {
143 pub endpoint: String,
144 pub max_retries: u32,
145 pub retry_delay_ms: u64,
146}
147
148impl ModelProvider for OllamaProvider {
149 fn invoke(
150 &self,
151 message: &str,
152 model: &str,
153 callback: &mut dyn FnMut(&str) -> Result<()>,
154 ) -> Result<String> {
155 if !self.endpoint.starts_with("http://") && !self.endpoint.starts_with("https://") {
157 return Err(anyhow!("Invalid endpoint URL: {}", self.endpoint));
158 }
159
160 crate::agent::run_agent_with_retries(
161 &self.endpoint,
162 model,
163 "",
164 message,
165 callback,
166 self.max_retries,
167 self.retry_delay_ms,
168 )
169 .map_err(|e| {
170 let err_str = e.to_string();
171 if err_str.contains("Connection") || err_str.contains("connect") {
172 anyhow!("Failed to connect to Ollama at {}\n\nOllama does not appear to be running. To fix:\n\n 1. Install Ollama: https://ollama.ai/download\n 2. Start Ollama: ollama serve\n 3. Pull a model: ollama pull {}\n\nOr switch to Claude CLI by removing 'provider: ollama' from .chant/config.md", self.endpoint, model)
173 } else {
174 e
175 }
176 })
177 }
178
179 fn name(&self) -> &'static str {
180 "ollama"
181 }
182}
183
184pub struct OpenaiProvider {
186 pub endpoint: String,
187 pub api_key: Option<String>,
188 pub max_retries: u32,
189 pub retry_delay_ms: u64,
190}
191
192impl ModelProvider for OpenaiProvider {
193 fn invoke(
194 &self,
195 message: &str,
196 model: &str,
197 callback: &mut dyn FnMut(&str) -> Result<()>,
198 ) -> Result<String> {
199 let api_key = self
200 .api_key
201 .clone()
202 .or_else(|| std::env::var("OPENAI_API_KEY").ok())
203 .ok_or_else(|| anyhow!("OPENAI_API_KEY environment variable not set"))?;
204
205 let url = format!("{}/chat/completions", self.endpoint);
206
207 if !self.endpoint.starts_with("http://") && !self.endpoint.starts_with("https://") {
209 return Err(anyhow!("Invalid endpoint URL: {}", self.endpoint));
210 }
211
212 let request_body = serde_json::json!({
213 "model": model,
214 "messages": [
215 {
216 "role": "user",
217 "content": message
218 }
219 ],
220 "stream": true,
221 });
222
223 let mut attempt = 0;
225 loop {
226 attempt += 1;
227
228 let agent = Agent::new();
230 let response = agent
231 .post(&url)
232 .set("Content-Type", "application/json")
233 .set("Authorization", &format!("Bearer {}", api_key))
234 .send_json(&request_body)
235 .map_err(|e| anyhow!("HTTP request failed: {}", e))?;
236
237 let status = response.status();
238
239 if status == 401 {
241 return Err(anyhow!(
242 "Authentication failed. Check OPENAI_API_KEY env var"
243 ));
244 }
245
246 let is_retryable =
248 status == 429 || status == 500 || status == 502 || status == 503 || status == 504;
249
250 if status == 200 {
251 return self.process_response(response, callback);
253 } else if is_retryable && attempt <= self.max_retries {
254 let delay_ms = self.calculate_backoff(attempt);
256 callback(&format!(
257 "[Retry {}] HTTP {} - waiting {}ms before retry",
258 attempt, status, delay_ms
259 ))?;
260 std::thread::sleep(std::time::Duration::from_millis(delay_ms));
261 continue;
262 } else {
263 return Err(anyhow!(
265 "HTTP {}: {} (after {} attempt{})",
266 status,
267 response.status_text(),
268 attempt,
269 if attempt == 1 { "" } else { "s" }
270 ));
271 }
272 }
273 }
274
275 fn name(&self) -> &'static str {
276 "openai"
277 }
278}
279
280impl OpenaiProvider {
281 fn calculate_backoff(&self, attempt: u32) -> u64 {
283 let base_delay = self.retry_delay_ms;
284 let exponential = 2u64.saturating_pow(attempt - 1);
285 let delay = base_delay.saturating_mul(exponential);
286 let jitter = (delay / 10).saturating_mul(
288 ((attempt as u64).wrapping_mul(7)) % 21 / 10, );
290 if attempt.is_multiple_of(2) {
291 delay.saturating_add(jitter)
292 } else {
293 delay.saturating_sub(jitter)
294 }
295 }
296
297 fn process_response(
299 &self,
300 response: ureq::Response,
301 callback: &mut dyn FnMut(&str) -> Result<()>,
302 ) -> Result<String> {
303 let reader = std::io::BufReader::new(response.into_reader());
304 let mut captured_output = String::new();
305 let mut line_buffer = String::new();
306
307 for line in reader.lines().map_while(Result::ok) {
308 if let Some(json_str) = line.strip_prefix("data: ") {
309 if json_str == "[DONE]" {
310 break;
311 }
312
313 if let Ok(json) = serde_json::from_str::<serde_json::Value>(json_str) {
314 if let Some(choices) = json.get("choices").and_then(|c| c.as_array()) {
315 for choice in choices {
316 if let Some(delta) = choice.get("delta") {
317 if let Some(content) = delta.get("content").and_then(|c| c.as_str())
318 {
319 line_buffer.push_str(content);
320
321 while let Some(newline_pos) = line_buffer.find('\n') {
323 let complete_line = &line_buffer[..newline_pos];
324 callback(complete_line)?;
325 captured_output.push_str(complete_line);
326 captured_output.push('\n');
327 line_buffer = line_buffer[newline_pos + 1..].to_string();
328 }
329 }
330 }
331 }
332 }
333 }
334 }
335 }
336
337 if !line_buffer.is_empty() {
339 callback(&line_buffer)?;
340 captured_output.push_str(&line_buffer);
341 captured_output.push('\n');
342 }
343
344 if captured_output.is_empty() {
345 return Err(anyhow!("Empty response from OpenAI API"));
346 }
347
348 Ok(captured_output)
349 }
350}
351
352fn extract_text_from_stream_json(line: &str) -> Vec<String> {
354 let mut texts = Vec::new();
355
356 if let Ok(json) = serde_json::from_str::<serde_json::Value>(line) {
357 if let Some("assistant") = json.get("type").and_then(|t| t.as_str()) {
358 if let Some(content) = json
359 .get("message")
360 .and_then(|m| m.get("content"))
361 .and_then(|c| c.as_array())
362 {
363 for item in content {
364 if let Some(text) = item.get("text").and_then(|t| t.as_str()) {
365 texts.push(text.to_string());
366 }
367 }
368 }
369 }
370 }
371
372 texts
373}
374
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_ollama_endpoint() {
        assert_eq!(
            default_ollama_endpoint(),
            "http://localhost:11434/v1".to_string()
        );
    }

    #[test]
    fn test_default_openai_endpoint() {
        assert_eq!(
            default_openai_endpoint(),
            "https://api.openai.com/v1".to_string()
        );
    }

    #[test]
    fn test_claude_provider_name() {
        let provider = ClaudeCliProvider;
        assert_eq!(provider.name(), "claude");
    }

    #[test]
    fn test_ollama_provider_name() {
        let provider = OllamaProvider {
            endpoint: "http://localhost:11434/v1".to_string(),
            max_retries: 3,
            retry_delay_ms: 1000,
        };
        assert_eq!(provider.name(), "ollama");
    }

    #[test]
    fn test_openai_provider_name() {
        let provider = OpenaiProvider {
            endpoint: "https://api.openai.com/v1".to_string(),
            api_key: None,
            max_retries: 3,
            retry_delay_ms: 1000,
        };
        assert_eq!(provider.name(), "openai");
    }

    #[test]
    fn test_provider_type_default() {
        let provider_type: ProviderType = Default::default();
        assert_eq!(provider_type, ProviderType::Claude);
    }

    #[test]
    fn test_provider_type_deserializes_lowercase_names() {
        let parse = |s: &str| serde_json::from_str::<ProviderType>(s).unwrap();
        assert_eq!(parse("\"claude\""), ProviderType::Claude);
        assert_eq!(parse("\"ollama\""), ProviderType::Ollama);
        assert_eq!(parse("\"openai\""), ProviderType::Openai);
    }

    #[test]
    fn test_provider_config_fills_in_defaults() {
        let config: ProviderConfig =
            serde_json::from_str(r#"{"ollama": {}, "openai": {}}"#).unwrap();

        let ollama = config.ollama.expect("ollama section should be present");
        assert_eq!(ollama.endpoint, "http://localhost:11434/v1");
        assert_eq!(ollama.max_retries, 3);
        assert_eq!(ollama.retry_delay_ms, 1000);

        let openai = config.openai.expect("openai section should be present");
        assert_eq!(openai.endpoint, "https://api.openai.com/v1");
        assert_eq!(openai.max_retries, 3);
        assert_eq!(openai.retry_delay_ms, 1000);
    }

    #[test]
    fn test_provider_config_sections_are_optional() {
        let config: ProviderConfig = serde_json::from_str("{}").unwrap();
        assert!(config.ollama.is_none());
        assert!(config.openai.is_none());
    }

    #[test]
    fn test_extract_text_from_stream_json() {
        let line = r#"{"type":"assistant","message":{"content":[{"type":"text","text":"hi"},{"type":"tool_use","id":"t1"}]}}"#;
        assert_eq!(extract_text_from_stream_json(line), vec!["hi".to_string()]);
        assert!(extract_text_from_stream_json(r#"{"type":"system"}"#).is_empty());
        assert!(extract_text_from_stream_json("not json").is_empty());
    }

    #[test]
    fn test_calculate_backoff_doubles_with_deterministic_jitter() {
        let provider = OpenaiProvider {
            endpoint: "https://api.openai.com/v1".to_string(),
            api_key: None,
            max_retries: 3,
            retry_delay_ms: 1000,
        };
        assert_eq!(provider.calculate_backoff(1), 1000);
        assert_eq!(provider.calculate_backoff(2), 2200);
        assert_eq!(provider.calculate_backoff(3), 4000);
        assert_eq!(provider.calculate_backoff(5), 14400);
    }
}