use std::pin::Pin;

use futures::Stream;
use serde::{Deserialize, Serialize};
use thiserror::Error;

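/// Errors produced while talking to an LLM backend.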
#[derive(Debug, Error)]
pub enum LlmError {
    #[error("HTTP request failed: {0}")]
    Http(#[from] reqwest::Error),

    #[error("API error: {status} - {message}")]
    Api { status: u16, message: String },

    #[error("Stream error: {0}")]
    Stream(String),

    #[error("Invalid response format: {0}")]
    InvalidFormat(String),

    #[error("Provider not available: {0}")]
    ProviderUnavailable(String),

    #[error("Rate limited")]
    RateLimited,

    #[error("Timeout")]
    Timeout,
}

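/// A single chat message exchanged with a model.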
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
    pub role: Role,
    pub content: String,
}

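/// The author of a [`Message`]; serialized in lowercase to match chat APIs.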
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    System,
    User,
    Assistant,
}

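/// One incremental piece of a streamed completion.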
#[derive(Debug, Clone)]
pub struct ResponseChunk {
    pub content: String,
    pub is_done: bool,
}

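/// A complete, non-streamed completion.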
#[derive(Debug, Clone)]
pub struct Response {
    pub content: String,
    pub usage: Option<Usage>,
}

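/// Token counts reported by the provider, when available.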
#[derive(Debug, Clone)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}

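/// Common interface over chat-completion backends (a local Ollama server or
/// any OpenAI-compatible API).
///
/// An illustrative sketch of consuming the streaming side (marked `ignore`
/// because it assumes an async runtime, an already-constructed `provider`,
/// and `messages` built elsewhere):
///
/// ```ignore
/// use futures::StreamExt;
///
/// let mut stream = provider.generate_stream(&messages).await?;
/// while let Some(chunk) = stream.next().await {
///     let chunk = chunk?;
///     print!("{}", chunk.content);
///     if chunk.is_done {
///         break;
///     }
/// }
/// ```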
#[async_trait::async_trait]
pub trait LlmProvider: Send + Sync {
    async fn generate(&self, messages: &[Message]) -> Result<Response, LlmError>;

    async fn generate_stream(
        &self,
        messages: &[Message],
    ) -> Result<Pin<Box<dyn Stream<Item = Result<ResponseChunk, LlmError>> + Send>>, LlmError>;

    async fn health_check(&self) -> bool;

    fn name(&self) -> &str;
}

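// Wire types for Ollama's `/api/chat` endpoint (newline-delimited JSON).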
#[derive(Serialize)]
struct OllamaRequest {
    model: String,
    messages: Vec<OllamaMessage>,
    stream: bool,
    options: Option<OllamaOptions>,
}

#[derive(Serialize, Deserialize)]
struct OllamaMessage {
    role: String,
    content: String,
}

#[derive(Serialize)]
struct OllamaOptions {
    // Ollama's option keys are snake_case (e.g. `num_predict`), so the field
    // names serialize as-is.
    temperature: f64,
    num_predict: i32,
}

#[derive(Deserialize)]
struct OllamaResponse {
    message: Option<OllamaMessage>,
    done: bool,
    #[serde(default)]
    prompt_eval_count: Option<u32>,
    #[serde(default)]
    eval_count: Option<u32>,
}

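/// Provider backed by a local or remote Ollama server.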
pub struct OllamaProvider {
    client: reqwest::Client,
    base_url: String,
    model: String,
    temperature: f64,
    max_tokens: i32,
}

impl OllamaProvider {
    pub fn new(
        base_url: &str,
        model: &str,
        temperature: f64,
        max_tokens: i32,
    ) -> Result<Self, LlmError> {
        let client = reqwest::Client::builder()
            .timeout(std::time::Duration::from_secs(300))
            .build()
            .map_err(|e| {
                LlmError::ProviderUnavailable(format!("Failed to create HTTP client: {e}"))
            })?;

        Ok(Self {
            client,
            base_url: base_url.trim_end_matches('/').to_string(),
            model: model.to_string(),
            temperature,
            max_tokens,
        })
    }

    pub fn default_config() -> Self {
        Self::new("http://localhost:11434", "qwen2.5-coder:7b", 0.7, 4096)
            .expect("Failed to initialise default Ollama HTTP client")
    }

    fn convert_messages(messages: &[Message]) -> Vec<OllamaMessage> {
        messages
            .iter()
            .map(|m| OllamaMessage {
                role: match m.role {
                    Role::System => "system".to_string(),
                    Role::User => "user".to_string(),
                    Role::Assistant => "assistant".to_string(),
                },
                content: m.content.clone(),
            })
            .collect()
    }
}

#[async_trait::async_trait]
impl LlmProvider for OllamaProvider {
    async fn generate(&self, messages: &[Message]) -> Result<Response, LlmError> {
        let url = format!("{}/api/chat", self.base_url);
        let request = OllamaRequest {
            model: self.model.clone(),
            messages: Self::convert_messages(messages),
            stream: false,
            options: Some(OllamaOptions {
                temperature: self.temperature,
                num_predict: self.max_tokens,
            }),
        };

        let resp = self.client.post(&url).json(&request).send().await?;

        if !resp.status().is_success() {
            let status = resp.status();
            let body = resp.text().await.unwrap_or_default();
            return Err(LlmError::Api {
                status: status.as_u16(),
                message: body,
            });
        }

        let data: OllamaResponse = resp.json().await?;

        let content = data.message.map(|m| m.content).unwrap_or_default();

        Ok(Response {
            content,
            usage: Some(Usage {
                prompt_tokens: data.prompt_eval_count.unwrap_or(0),
                completion_tokens: data.eval_count.unwrap_or(0),
                total_tokens: data.prompt_eval_count.unwrap_or(0) + data.eval_count.unwrap_or(0),
            }),
        })
    }

    async fn generate_stream(
        &self,
        messages: &[Message],
    ) -> Result<Pin<Box<dyn Stream<Item = Result<ResponseChunk, LlmError>> + Send>>, LlmError> {
        use futures::stream::try_unfold;

        let url = format!("{}/api/chat", self.base_url);
        let request = OllamaRequest {
            model: self.model.clone(),
            messages: Self::convert_messages(messages),
            stream: true,
            options: Some(OllamaOptions {
                temperature: self.temperature,
                num_predict: self.max_tokens,
            }),
        };

        let resp = self.client.post(&url).json(&request).send().await?;

        if !resp.status().is_success() {
            let status = resp.status();
            let body = resp.text().await.unwrap_or_default();
            return Err(LlmError::Api {
                status: status.as_u16(),
                message: body,
            });
        }

        let byte_stream = resp.bytes_stream();

        // Ollama streams newline-delimited JSON objects: accumulate raw bytes
        // in a string buffer and emit one ResponseChunk per complete line.
        let stream = try_unfold(
            (Box::pin(byte_stream), String::new()),
            |(mut byte_stream, mut buf)| async move {
                use futures::TryStreamExt;

                loop {
                    // Drain any complete line already sitting in the buffer.
                    if let Some(newline_pos) = buf.find('\n') {
                        let line: String = buf[..newline_pos].to_string();
                        buf = buf[newline_pos + 1..].to_string();

                        let line = line.trim();
                        if line.is_empty() {
                            continue;
                        }

                        match serde_json::from_str::<OllamaResponse>(line) {
                            Ok(data) => {
                                let content = data.message.map(|m| m.content).unwrap_or_default();
                                let chunk = ResponseChunk {
                                    content,
                                    is_done: data.done,
                                };
                                return Ok(Some((chunk, (byte_stream, buf))));
                            }
                            Err(e) => {
                                return Err(LlmError::InvalidFormat(format!(
                                    "Failed to parse streaming response: {e}"
                                )));
                            }
                        }
                    }

                    // No complete line yet: pull the next chunk of bytes.
                    match byte_stream.try_next().await {
                        Ok(Some(bytes)) => {
                            buf.push_str(&String::from_utf8_lossy(&bytes));
                        }
                        Ok(None) => {
                            // Connection closed; try to parse whatever is left
                            // in the buffer as a final, unterminated line.
                            let remaining = buf.trim();
                            if !remaining.is_empty() {
                                if let Ok(data) = serde_json::from_str::<OllamaResponse>(remaining)
                                {
                                    let content =
                                        data.message.map(|m| m.content).unwrap_or_default();
                                    return Ok(Some((
                                        ResponseChunk {
                                            content,
                                            is_done: true,
                                        },
                                        (byte_stream, String::new()),
                                    )));
                                }
                            }
                            return Ok(None);
                        }
                        Err(e) => return Err(LlmError::Http(e)),
                    }
                }
            },
        );

        Ok(Box::pin(stream))
    }

    async fn health_check(&self) -> bool {
        let url = format!("{}/api/tags", self.base_url);
        match self.client.get(&url).send().await {
            Ok(resp) => resp.status().is_success(),
            Err(_) => false,
        }
    }

    fn name(&self) -> &str {
        "ollama"
    }
}

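// Wire types for OpenAI-compatible `/chat/completions` endpoints.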
#[derive(Serialize)]
struct OpenAiRequest {
    model: String,
    messages: Vec<OpenAiMessage>,
    temperature: f64,
    // Omit the field entirely when unset rather than sending an explicit null.
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<i32>,
    stream: bool,
}

#[derive(Serialize, Deserialize)]
struct OpenAiMessage {
    role: String,
    content: String,
}

#[derive(Deserialize)]
struct OpenAiResponse {
    choices: Vec<OpenAiChoice>,
    usage: Option<OpenAiUsage>,
}

#[derive(Deserialize)]
struct OpenAiChoice {
    message: OpenAiMessage,
    #[allow(dead_code)]
    finish_reason: Option<String>,
}

#[derive(Deserialize)]
struct OpenAiStreamResponse {
    choices: Vec<OpenAiStreamChoice>,
}

#[derive(Deserialize)]
struct OpenAiStreamChoice {
    delta: OpenAiDelta,
    finish_reason: Option<String>,
}

#[derive(Deserialize)]
struct OpenAiDelta {
    #[serde(default)]
    content: Option<String>,
}

#[derive(Deserialize)]
struct OpenAiUsage {
    prompt_tokens: u32,
    completion_tokens: u32,
    total_tokens: u32,
}

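/// Provider for OpenAI and other OpenAI-compatible APIs (e.g. OpenRouter).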
pub struct OpenAiProvider {
    client: reqwest::Client,
    base_url: String,
    api_key: Option<String>,
    model: String,
    temperature: f64,
    max_tokens: Option<i32>,
}

impl OpenAiProvider {
    pub fn new(
        base_url: &str,
        api_key: Option<&str>,
        model: &str,
        temperature: f64,
        max_tokens: Option<i32>,
    ) -> Result<Self, LlmError> {
        let client = reqwest::Client::builder()
            .timeout(std::time::Duration::from_secs(300))
            .build()
            .map_err(|e| {
                LlmError::ProviderUnavailable(format!("Failed to create HTTP client: {e}"))
            })?;

        Ok(Self {
            client,
            base_url: base_url.trim_end_matches('/').to_string(),
            api_key: api_key.map(|s| s.to_string()),
            model: model.to_string(),
            temperature,
            max_tokens,
        })
    }

    pub fn openai(api_key: &str, model: &str) -> Self {
        Self::new(
            "https://api.openai.com/v1",
            Some(api_key),
            model,
            0.7,
            Some(4096),
        )
        .expect("Failed to initialise OpenAI HTTP client")
    }

    pub fn openrouter(api_key: &str, model: &str) -> Self {
        Self::new(
            "https://openrouter.ai/api/v1",
            Some(api_key),
            model,
            0.7,
            Some(4096),
        )
        .expect("Failed to initialise OpenRouter HTTP client")
    }

    fn convert_messages(messages: &[Message]) -> Vec<OpenAiMessage> {
        messages
            .iter()
            .map(|m| OpenAiMessage {
                role: match m.role {
                    Role::System => "system".to_string(),
                    Role::User => "user".to_string(),
                    Role::Assistant => "assistant".to_string(),
                },
                content: m.content.clone(),
            })
            .collect()
    }

    fn build_request(&self, builder: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
        let mut builder = builder;
        if let Some(key) = &self.api_key {
            builder = builder.header("Authorization", format!("Bearer {}", key));
        }
        builder
    }
}

#[async_trait::async_trait]
impl LlmProvider for OpenAiProvider {
    async fn generate(&self, messages: &[Message]) -> Result<Response, LlmError> {
        let url = format!("{}/chat/completions", self.base_url);
        let request = OpenAiRequest {
            model: self.model.clone(),
            messages: Self::convert_messages(messages),
            temperature: self.temperature,
            max_tokens: self.max_tokens,
            stream: false,
        };

        let resp = self
            .build_request(self.client.post(&url))
            .json(&request)
            .send()
            .await?;

        if !resp.status().is_success() {
            let status = resp.status();
            let body = resp.text().await.unwrap_or_default();
            return Err(LlmError::Api {
                status: status.as_u16(),
                message: body,
            });
        }

        let data: OpenAiResponse = resp.json().await?;
        let content = data
            .choices
            .first()
            .map(|c| c.message.content.clone())
            .unwrap_or_default();

        Ok(Response {
            content,
            usage: data.usage.map(|u| Usage {
                prompt_tokens: u.prompt_tokens,
                completion_tokens: u.completion_tokens,
                total_tokens: u.total_tokens,
            }),
        })
    }

    async fn generate_stream(
        &self,
        messages: &[Message],
    ) -> Result<Pin<Box<dyn Stream<Item = Result<ResponseChunk, LlmError>> + Send>>, LlmError> {
        use futures::stream::try_unfold;

        let url = format!("{}/chat/completions", self.base_url);
        let request = OpenAiRequest {
            model: self.model.clone(),
            messages: Self::convert_messages(messages),
            temperature: self.temperature,
            max_tokens: self.max_tokens,
            stream: true,
        };

        let resp = self
            .build_request(self.client.post(&url))
            .json(&request)
            .send()
            .await?;

        if !resp.status().is_success() {
            let status = resp.status();
            let body = resp.text().await.unwrap_or_default();
            return Err(LlmError::Api {
                status: status.as_u16(),
                message: body,
            });
        }

        let byte_stream = resp.bytes_stream();

        // OpenAI-compatible APIs stream Server-Sent Events: each payload line
        // is prefixed with "data: " and the stream ends with "data: [DONE]".
        let stream = try_unfold(
            (Box::pin(byte_stream), String::new()),
            |(mut byte_stream, mut buf)| async move {
                use futures::TryStreamExt;

                loop {
                    // Drain any complete line already sitting in the buffer.
                    if let Some(newline_pos) = buf.find('\n') {
                        let line: String = buf[..newline_pos].to_string();
                        buf = buf[newline_pos + 1..].to_string();

                        let line = line.trim();
                        if line.is_empty() {
                            continue;
                        }

                        if let Some(data) = line.strip_prefix("data: ") {
                            let data = data.trim();
                            if data == "[DONE]" {
                                return Ok(None);
                            }

                            match serde_json::from_str::<OpenAiStreamResponse>(data) {
                                Ok(resp) => {
                                    if let Some(choice) = resp.choices.first() {
                                        let content =
                                            choice.delta.content.clone().unwrap_or_default();
                                        let is_done = choice.finish_reason.is_some();
                                        let chunk = ResponseChunk { content, is_done };
                                        return Ok(Some((chunk, (byte_stream, buf))));
                                    }
                                    continue;
                                }
                                Err(e) => {
                                    return Err(LlmError::InvalidFormat(format!(
                                        "Failed to parse streaming response: {e}"
                                    )));
                                }
                            }
                        }
                        // Ignore SSE lines that carry no data payload.
                        continue;
                    }

                    // No complete line yet: pull the next chunk of bytes.
                    match byte_stream.try_next().await {
                        Ok(Some(bytes)) => {
                            buf.push_str(&String::from_utf8_lossy(&bytes));
                        }
                        Ok(None) => return Ok(None),
                        Err(e) => return Err(LlmError::Http(e)),
                    }
                }
            },
        );

        Ok(Box::pin(stream))
    }

    async fn health_check(&self) -> bool {
        let url = format!("{}/models", self.base_url);
        match self.build_request(self.client.get(&url)).send().await {
            Ok(resp) => resp.status().is_success(),
            Err(_) => false,
        }
    }

    fn name(&self) -> &str {
        "openai"
    }
}

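/// Configuration used by [`create_provider`] to select and build a backend.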
#[derive(Debug, Clone)]
pub struct ProviderConfig {
    pub provider: String,
    pub base_url: String,
    pub api_key: Option<String>,
    pub model: String,
    pub temperature: f64,
    pub max_tokens: i32,
}

impl Default for ProviderConfig {
    fn default() -> Self {
        Self {
            provider: "ollama".to_string(),
            base_url: "http://localhost:11434".to_string(),
            api_key: None,
            model: "qwen2.5-coder:7b".to_string(),
            temperature: 0.7,
            max_tokens: 4096,
        }
    }
}

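/// Builds a boxed [`LlmProvider`] from a [`ProviderConfig`].
///
/// Unrecognised provider names, and Ollama construction failures, fall back
/// to the default Ollama configuration; OpenAI construction failures panic.
///
/// An illustrative sketch (marked `ignore` because it assumes a Tokio runtime
/// and a reachable backend):
///
/// ```ignore
/// let config = ProviderConfig::default();
/// let provider = create_provider(&config);
///
/// let messages = vec![Message {
///     role: Role::User,
///     content: "Say hello".to_string(),
/// }];
/// let response = provider.generate(&messages).await?;
/// println!("{}", response.content);
/// ```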
pub fn create_provider(config: &ProviderConfig) -> Box<dyn LlmProvider> {
    match config.provider.as_str() {
        "ollama" => Box::new(
            OllamaProvider::new(
                &config.base_url,
                &config.model,
                config.temperature,
                config.max_tokens,
            )
            .unwrap_or_else(|e| {
                tracing::error!(error = %e, "Failed to create Ollama provider, falling back to default");
                OllamaProvider::default_config()
            }),
        ),
        "openai" => Box::new(
            OpenAiProvider::new(
                &config.base_url,
                config.api_key.as_deref(),
                &config.model,
                config.temperature,
                Some(config.max_tokens),
            )
            .expect("Failed to initialise OpenAI HTTP client"),
        ),
        _ => Box::new(OllamaProvider::default_config()),
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_provider_config_default() {
        let config = ProviderConfig::default();
        assert_eq!(config.provider, "ollama");
        assert_eq!(config.model, "qwen2.5-coder:7b");
    }

    #[test]
    fn test_ollama_provider_creation() {
        let provider = OllamaProvider::new("http://localhost:11434", "llama3:8b", 0.5, 2048)
            .expect("OllamaProvider::new should not fail in test");
        assert_eq!(provider.name(), "ollama");
    }

    #[test]
    fn test_openai_provider_creation() {
        let provider = OpenAiProvider::openai("test-key", "gpt-4");
        assert_eq!(provider.name(), "openai");
    }

    #[test]
    fn test_openrouter_provider_creation() {
        let provider = OpenAiProvider::openrouter("test-key", "anthropic/claude-3-opus");
        assert_eq!(provider.name(), "openai");
    }
}