1use anyhow::{anyhow, Context, Result};
11use serde::Deserialize;
12use std::io::BufRead;
13use std::process::{Command, Stdio};
14use ureq::Agent;
15
16#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Deserialize)]
18#[serde(rename_all = "lowercase")]
19pub enum ProviderType {
20 #[default]
21 Claude,
22 Ollama,
23 Openai,
24 Kirocli,
25}
26
27#[derive(Debug, Clone, Default, Deserialize)]
29pub struct ProviderConfig {
30 #[serde(default)]
31 pub ollama: Option<OllamaConfig>,
32 #[serde(default)]
33 pub openai: Option<OpenaiConfig>,
34}
35
36#[derive(Debug, Clone, Deserialize)]
37pub struct OllamaConfig {
38 #[serde(default = "default_ollama_endpoint")]
39 pub endpoint: String,
40 #[serde(default = "default_max_retries")]
42 pub max_retries: u32,
43 #[serde(default = "default_retry_delay_ms")]
45 pub retry_delay_ms: u64,
46}
47
/// Serde default for `OllamaConfig::endpoint`: the local Ollama `/v1` URL.
fn default_ollama_endpoint() -> String {
    String::from("http://localhost:11434/v1")
}
51
/// Serde default for the `max_retries` field of `OllamaConfig` and
/// `OpenaiConfig`.
fn default_max_retries() -> u32 {
    const DEFAULT_MAX_RETRIES: u32 = 3;
    DEFAULT_MAX_RETRIES
}
55
/// Serde default for the `retry_delay_ms` field of `OllamaConfig` and
/// `OpenaiConfig`, in milliseconds.
fn default_retry_delay_ms() -> u64 {
    const DEFAULT_RETRY_DELAY_MS: u64 = 1_000;
    DEFAULT_RETRY_DELAY_MS
}
59
60#[derive(Debug, Clone, Deserialize)]
61pub struct OpenaiConfig {
62 #[serde(default = "default_openai_endpoint")]
63 pub endpoint: String,
64 #[serde(default = "default_max_retries")]
66 pub max_retries: u32,
67 #[serde(default = "default_retry_delay_ms")]
69 pub retry_delay_ms: u64,
70}
71
/// Serde default for `OpenaiConfig::endpoint`: the public OpenAI `/v1` URL.
fn default_openai_endpoint() -> String {
    String::from("https://api.openai.com/v1")
}
75
76pub trait ModelProvider {
78 fn invoke(
79 &self,
80 message: &str,
81 model: &str,
82 callback: &mut dyn FnMut(&str) -> Result<()>,
83 ) -> Result<String>;
84
85 fn name(&self) -> &'static str;
87}
88
89pub struct KiroCliProvider;
91
92impl ModelProvider for KiroCliProvider {
93 fn invoke(
94 &self,
95 message: &str,
96 model: &str,
97 callback: &mut dyn FnMut(&str) -> Result<()>,
98 ) -> Result<String> {
99 let mut cmd = Command::new("kiro-cli-chat");
100 cmd.arg("chat")
101 .arg("--no-interactive")
102 .arg("--trust-all-tools")
103 .arg("--model")
104 .arg(model)
105 .arg(message)
106 .stdout(Stdio::piped())
107 .stderr(Stdio::piped());
108
109 let mut child = cmd
110 .spawn()
111 .context("Failed to invoke kiro-cli-chat. Is it installed and in PATH?")?;
112
113 let stderr_handle = child.stderr.take().map(|stderr| {
115 std::thread::spawn(move || {
116 let reader = std::io::BufReader::new(stderr);
117 let mut stderr_output = String::new();
118 for line in reader.lines().map_while(Result::ok) {
119 stderr_output.push_str(&line);
120 stderr_output.push('\n');
121 }
122 stderr_output
123 })
124 });
125
126 let mut captured_output = String::new();
127 if let Some(stdout) = child.stdout.take() {
128 let reader = std::io::BufReader::new(stdout);
129 for line in reader.lines().map_while(Result::ok) {
130 callback(&line)?;
131 captured_output.push_str(&line);
132 captured_output.push('\n');
133 }
134 }
135
136 let status = child.wait()?;
137
138 let stderr_output = stderr_handle
140 .and_then(|h| h.join().ok())
141 .unwrap_or_default();
142
143 if !status.success() {
144 if !stderr_output.is_empty() {
145 anyhow::bail!(
146 "kiro-cli-chat exited with status: {}\nStderr: {}",
147 status,
148 stderr_output.trim()
149 );
150 }
151 anyhow::bail!("kiro-cli-chat exited with status: {}", status);
152 }
153
154 Ok(captured_output)
155 }
156
157 fn name(&self) -> &'static str {
158 "kirocli"
159 }
160}
161
162pub struct ClaudeCliProvider;
164
165impl ModelProvider for ClaudeCliProvider {
166 fn invoke(
167 &self,
168 message: &str,
169 model: &str,
170 callback: &mut dyn FnMut(&str) -> Result<()>,
171 ) -> Result<String> {
172 let mut cmd = Command::new("claude");
173 cmd.arg("--print")
174 .arg("--output-format")
175 .arg("stream-json")
176 .arg("--verbose")
177 .arg("--model")
178 .arg(model)
179 .arg("--dangerously-skip-permissions")
180 .arg(message)
181 .stdout(Stdio::piped())
182 .stderr(Stdio::piped());
183
184 let mut child = cmd
185 .spawn()
186 .context("Failed to invoke claude CLI. Is it installed and in PATH?")?;
187
188 let mut captured_output = String::new();
189 if let Some(stdout) = child.stdout.take() {
190 let reader = std::io::BufReader::new(stdout);
191 for line in reader.lines().map_while(Result::ok) {
192 for text in extract_text_from_stream_json(&line) {
193 for text_line in text.lines() {
194 callback(text_line)?;
195 captured_output.push_str(text_line);
196 captured_output.push('\n');
197 }
198 }
199 }
200 }
201
202 let status = child.wait()?;
203 if !status.success() {
204 anyhow::bail!("Agent exited with status: {}", status);
205 }
206
207 Ok(captured_output)
208 }
209
210 fn name(&self) -> &'static str {
211 "claude"
212 }
213}
214
215pub struct OllamaProvider {
217 pub endpoint: String,
218 pub max_retries: u32,
219 pub retry_delay_ms: u64,
220}
221
222impl ModelProvider for OllamaProvider {
223 fn invoke(
224 &self,
225 message: &str,
226 model: &str,
227 callback: &mut dyn FnMut(&str) -> Result<()>,
228 ) -> Result<String> {
229 if !self.endpoint.starts_with("http://") && !self.endpoint.starts_with("https://") {
231 return Err(anyhow!("Invalid endpoint URL: {}", self.endpoint));
232 }
233
234 crate::agent::run_agent_with_retries(
235 &self.endpoint,
236 model,
237 "",
238 message,
239 callback,
240 self.max_retries,
241 self.retry_delay_ms,
242 )
243 .map_err(|e| {
244 let err_str = e.to_string();
245 if err_str.contains("Connection") || err_str.contains("connect") {
246 anyhow!("Failed to connect to Ollama at {}\n\nOllama does not appear to be running. To fix:\n\n 1. Install Ollama: https://ollama.ai/download\n 2. Start Ollama: ollama serve\n 3. Pull a model: ollama pull {}\n\nOr switch to Claude CLI by removing 'provider: ollama' from .chant/config.md", self.endpoint, model)
247 } else {
248 e
249 }
250 })
251 }
252
253 fn name(&self) -> &'static str {
254 "ollama"
255 }
256}
257
258pub struct OpenaiProvider {
260 pub endpoint: String,
261 pub api_key: Option<String>,
262 pub max_retries: u32,
263 pub retry_delay_ms: u64,
264}
265
266impl ModelProvider for OpenaiProvider {
267 fn invoke(
268 &self,
269 message: &str,
270 model: &str,
271 callback: &mut dyn FnMut(&str) -> Result<()>,
272 ) -> Result<String> {
273 let api_key = self
274 .api_key
275 .clone()
276 .or_else(|| std::env::var("OPENAI_API_KEY").ok())
277 .ok_or_else(|| anyhow!("OPENAI_API_KEY environment variable not set"))?;
278
279 let url = format!("{}/chat/completions", self.endpoint);
280
281 if !self.endpoint.starts_with("http://") && !self.endpoint.starts_with("https://") {
283 return Err(anyhow!("Invalid endpoint URL: {}", self.endpoint));
284 }
285
286 let request_body = serde_json::json!({
287 "model": model,
288 "messages": [
289 {
290 "role": "user",
291 "content": message
292 }
293 ],
294 "stream": true,
295 });
296
297 let mut attempt = 0;
299 loop {
300 attempt += 1;
301
302 let agent = Agent::new();
304 let response = agent
305 .post(&url)
306 .set("Content-Type", "application/json")
307 .set("Authorization", &format!("Bearer {}", api_key))
308 .send_json(&request_body)
309 .map_err(|e| anyhow!("HTTP request failed: {}", e))?;
310
311 let status = response.status();
312
313 if status == 401 {
315 return Err(anyhow!(
316 "Authentication failed. Check OPENAI_API_KEY env var"
317 ));
318 }
319
320 let is_retryable =
322 status == 429 || status == 500 || status == 502 || status == 503 || status == 504;
323
324 if status == 200 {
325 return self.process_response(response, callback);
327 } else if is_retryable && attempt <= self.max_retries {
328 let delay_ms = self.calculate_backoff(attempt);
330 callback(&format!(
331 "[Retry {}] HTTP {} - waiting {}ms before retry",
332 attempt, status, delay_ms
333 ))?;
334 std::thread::sleep(std::time::Duration::from_millis(delay_ms));
335 continue;
336 } else {
337 return Err(anyhow!(
339 "HTTP {}: {} (after {} attempt{})",
340 status,
341 response.status_text(),
342 attempt,
343 if attempt == 1 { "" } else { "s" }
344 ));
345 }
346 }
347 }
348
349 fn name(&self) -> &'static str {
350 "openai"
351 }
352}
353
354impl OpenaiProvider {
355 fn calculate_backoff(&self, attempt: u32) -> u64 {
357 let base_delay = self.retry_delay_ms;
358 let exponential = 2u64.saturating_pow(attempt - 1);
359 let delay = base_delay.saturating_mul(exponential);
360 let jitter = (delay / 10).saturating_mul(
362 ((attempt as u64).wrapping_mul(7)) % 21 / 10, );
364 if attempt.is_multiple_of(2) {
365 delay.saturating_add(jitter)
366 } else {
367 delay.saturating_sub(jitter)
368 }
369 }
370
371 fn process_response(
373 &self,
374 response: ureq::Response,
375 callback: &mut dyn FnMut(&str) -> Result<()>,
376 ) -> Result<String> {
377 let reader = std::io::BufReader::new(response.into_reader());
378 let mut captured_output = String::new();
379 let mut line_buffer = String::new();
380
381 for line in reader.lines().map_while(Result::ok) {
382 if let Some(json_str) = line.strip_prefix("data: ") {
383 if json_str == "[DONE]" {
384 break;
385 }
386
387 if let Ok(json) = serde_json::from_str::<serde_json::Value>(json_str) {
388 if let Some(choices) = json.get("choices").and_then(|c| c.as_array()) {
389 for choice in choices {
390 if let Some(delta) = choice.get("delta") {
391 if let Some(content) = delta.get("content").and_then(|c| c.as_str())
392 {
393 line_buffer.push_str(content);
394
395 while let Some(newline_pos) = line_buffer.find('\n') {
397 let complete_line = &line_buffer[..newline_pos];
398 callback(complete_line)?;
399 captured_output.push_str(complete_line);
400 captured_output.push('\n');
401 line_buffer = line_buffer[newline_pos + 1..].to_string();
402 }
403 }
404 }
405 }
406 }
407 }
408 }
409 }
410
411 if !line_buffer.is_empty() {
413 callback(&line_buffer)?;
414 captured_output.push_str(&line_buffer);
415 captured_output.push('\n');
416 }
417
418 if captured_output.is_empty() {
419 return Err(anyhow!("Empty response from OpenAI API"));
420 }
421
422 Ok(captured_output)
423 }
424}
425
426fn extract_text_from_stream_json(line: &str) -> Vec<String> {
428 let mut texts = Vec::new();
429
430 if let Ok(json) = serde_json::from_str::<serde_json::Value>(line) {
431 if let Some("assistant") = json.get("type").and_then(|t| t.as_str()) {
432 if let Some(content) = json
433 .get("message")
434 .and_then(|m| m.get("content"))
435 .and_then(|c| c.as_array())
436 {
437 for item in content {
438 if let Some(text) = item.get("text").and_then(|t| t.as_str()) {
439 texts.push(text.to_string());
440 }
441 }
442 }
443 }
444 }
445
446 texts
447}
448
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an Ollama provider with the stock endpoint and retry settings.
    fn sample_ollama() -> OllamaProvider {
        OllamaProvider {
            endpoint: String::from("http://localhost:11434/v1"),
            max_retries: 3,
            retry_delay_ms: 1000,
        }
    }

    /// Builds an OpenAI provider with the stock endpoint and no explicit key.
    fn sample_openai() -> OpenaiProvider {
        OpenaiProvider {
            endpoint: String::from("https://api.openai.com/v1"),
            api_key: None,
            max_retries: 3,
            retry_delay_ms: 1000,
        }
    }

    #[test]
    fn test_default_ollama_endpoint() {
        let endpoint = default_ollama_endpoint();
        assert_eq!(endpoint, "http://localhost:11434/v1");
    }

    #[test]
    fn test_default_openai_endpoint() {
        let endpoint = default_openai_endpoint();
        assert_eq!(endpoint, "https://api.openai.com/v1");
    }

    #[test]
    fn test_claude_provider_name() {
        assert_eq!(ClaudeCliProvider.name(), "claude");
    }

    #[test]
    fn test_ollama_provider_name() {
        assert_eq!(sample_ollama().name(), "ollama");
    }

    #[test]
    fn test_openai_provider_name() {
        assert_eq!(sample_openai().name(), "openai");
    }

    #[test]
    fn test_provider_type_default() {
        assert_eq!(ProviderType::default(), ProviderType::Claude);
    }

    #[test]
    fn test_kirocli_provider_name() {
        assert_eq!(KiroCliProvider.name(), "kirocli");
    }
}