1use futures::Stream;
7use reqwest::{Client, Response};
8use serde::{Deserialize, Serialize};
9use std::pin::Pin;
10use std::sync::Arc;
11use std::time::Duration;
12use thiserror::Error;
13use tokio::time::sleep;
14use tracing::instrument;
15
/// One incremental piece of a streamed chat reply.
#[derive(Clone, Debug)]
#[allow(dead_code)]
pub struct StreamChunk {
    /// Text produced since the previous chunk (may be empty on a terminal chunk).
    pub content: String,
    /// Why generation stopped, when the provider reported it; the OpenAI streaming
    /// parser also sets `"stop"` when it sees the `[DONE]` sentinel.
    #[allow(dead_code)]
    pub finish_reason: Option<String>,
}
24
25#[allow(dead_code)]
27pub type StreamResult = Pin<Box<dyn Stream<Item = Result<StreamChunk, LLMError>> + Send>>;
28
29use crate::config::{LLMConfig, LLMProvider};
30
/// Backends reached through the OpenAI chat-completions wire format.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum OpenAICompatibleProvider {
    LiteLLM,
    OpenAI,
    OpenRouter,
    Generic,
    Azure,
}

impl OpenAICompatibleProvider {
    /// Base URL used when the configuration leaves `endpoint` empty.
    ///
    /// The client appends `/v1/chat/completions` to this value.
    /// NOTE(review): OpenRouter's public API appears to live under `/api/v1`, so the
    /// OpenRouter default here may need an `/api` segment — confirm before relying on it.
    pub fn default_endpoint(&self) -> &'static str {
        match *self {
            Self::LiteLLM => "http://localhost:4000",
            Self::OpenAI => "https://api.openai.com",
            Self::OpenRouter => "https://openrouter.ai",
            Self::Generic => "http://localhost:8000",
            Self::Azure => "https://YOUR_RESOURCE.openai.azure.com",
        }
    }

    /// Short identifier used in tracing fields and error messages.
    pub fn name(&self) -> &'static str {
        match *self {
            Self::LiteLLM => "litellm",
            Self::OpenAI => "openai",
            Self::OpenRouter => "openrouter",
            Self::Generic => "openai-compatible",
            Self::Azure => "azure",
        }
    }

    /// `true` for providers whose requests need headers beyond a bearer token
    /// (Azure's `api-key`, OpenRouter's attribution headers).
    #[allow(dead_code)]
    pub fn requires_custom_headers(&self) -> bool {
        match *self {
            Self::OpenRouter | Self::Azure => true,
            Self::LiteLLM | Self::OpenAI | Self::Generic => false,
        }
    }
}
79
80#[derive(Error, Debug)]
86#[non_exhaustive]
87pub enum LLMError {
88 #[error("Request failed: {0}")]
89 RequestFailed(String),
90
91 #[error("Invalid response: {0}")]
92 InvalidResponse(String),
93
94 #[error("Rate limit exceeded")]
95 RateLimited,
96
97 #[error("Authentication failed")]
98 AuthFailed,
99
100 #[error("Provider not supported: {0}")]
101 #[allow(dead_code)]
102 ProviderNotSupported(String),
103
104 #[error("Token budget exceeded")]
105 TokenBudgetExceeded,
106
107 #[error("All providers failed after retries")]
108 AllProvidersFailed,
109
110 #[error("Circuit breaker open for provider: {0}")]
111 CircuitBreakerOpen(String),
112}
113
114#[derive(Debug, Clone)]
116pub struct RetryConfig {
117 pub max_retries: u32,
119 pub base_delay_ms: u64,
121 pub max_delay_ms: u64,
123 pub jitter: f64,
125}
126
127impl Default for RetryConfig {
128 fn default() -> Self {
129 Self {
130 max_retries: 3,
131 base_delay_ms: 100,
132 max_delay_ms: 10000,
133 jitter: 0.5,
134 }
135 }
136}
137
138impl RetryConfig {
139 pub fn delay_for_attempt(&self, attempt: u32) -> Duration {
141 use rand::Rng;
142 let exp = 2u64.pow(attempt);
143 let base = self.base_delay_ms * exp;
144 let capped = base.min(self.max_delay_ms);
145 let jitter_range = (capped as f64) * self.jitter;
146 let jitter = rand::thread_rng().gen_range(-jitter_range..=jitter_range) as u64;
147 let delay = capped.saturating_add(jitter).max(self.base_delay_ms);
148 Duration::from_millis(delay)
149 }
150}
151
/// State of a [`CircuitBreaker`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum CircuitState {
    /// Requests flow normally.
    Closed,
    /// Requests are rejected until `open_duration` has passed since the last failure.
    Open,
    /// The cool-down elapsed; requests are let through to probe for recovery.
    HalfOpen,
}

/// Circuit breaker that opens after five failures without an intervening success.
///
/// The failure count is not reset when the breaker goes half-open, so a single
/// further failure while half-open re-opens it immediately.
#[derive(Debug)]
pub struct CircuitBreaker {
    pub state: CircuitState,
    pub failure_count: u32,
    pub last_failure_time: Option<std::time::Instant>,
    pub open_duration: Duration,
}

impl CircuitBreaker {
    /// Number of failures (since the last success) that trips the breaker.
    const FAILURE_THRESHOLD: u32 = 5;

    /// Creates a closed breaker that stays open for `open_duration_secs` once tripped.
    pub fn new(open_duration_secs: u64) -> Self {
        Self {
            open_duration: Duration::from_secs(open_duration_secs),
            last_failure_time: None,
            failure_count: 0,
            state: CircuitState::Closed,
        }
    }

    /// Clears the failure count and closes the breaker.
    pub fn record_success(&mut self) {
        self.state = CircuitState::Closed;
        self.failure_count = 0;
    }

    /// Counts a failure, timestamps it, and opens the breaker at the threshold.
    pub fn record_failure(&mut self) {
        self.last_failure_time = Some(std::time::Instant::now());
        self.failure_count += 1;
        if self.failure_count >= Self::FAILURE_THRESHOLD {
            self.state = CircuitState::Open;
        }
    }

    /// Returns whether a request may proceed.
    ///
    /// An open breaker whose cool-down has elapsed transitions to half-open and
    /// allows the request.
    pub fn can_execute(&mut self) -> bool {
        if self.state != CircuitState::Open {
            return true;
        }
        let cooled_down = self
            .last_failure_time
            .map_or(false, |failed_at| failed_at.elapsed() >= self.open_duration);
        if cooled_down {
            self.state = CircuitState::HalfOpen;
        }
        cooled_down
    }
}
212
/// Token allowance with a flat per-1000-token price for cost estimates.
#[derive(Debug, Clone)]
pub struct TokenBudget {
    /// Total tokens that may be spent.
    pub max_tokens: u32,
    /// Tokens recorded so far; saturates at `u32::MAX`.
    pub used_tokens: u32,
    /// Price charged per 1000 tokens.
    pub cost_per_1k: f64,
}

#[allow(dead_code)]
impl TokenBudget {
    /// Creates an unused budget.
    pub fn new(max_tokens: u32, cost_per_1k: f64) -> Self {
        Self {
            cost_per_1k,
            max_tokens,
            used_tokens: 0,
        }
    }

    /// Tokens still available (never negative).
    pub fn remaining(&self) -> u32 {
        self.max_tokens
            .checked_sub(self.used_tokens)
            .unwrap_or(0)
    }

    /// `true` when at least `tokens` tokens remain.
    pub fn can_spend(&self, tokens: u32) -> bool {
        tokens <= self.remaining()
    }

    /// Adds `tokens` to the running total, saturating instead of overflowing.
    pub fn record_usage(&mut self, tokens: u32) {
        self.used_tokens = self.used_tokens.checked_add(tokens).unwrap_or(u32::MAX);
    }

    /// Cost of the tokens used so far at `cost_per_1k`.
    pub fn estimated_cost(&self) -> f64 {
        let thousands = f64::from(self.used_tokens) / 1000.0;
        thousands * self.cost_per_1k
    }
}
250
251#[derive(Debug, Clone, Serialize, Deserialize)]
256#[serde(untagged)]
257#[non_exhaustive]
258pub enum ContentPart {
259 Text { text: String },
261 ImageUrl {
264 #[serde(rename = "image_url")]
265 image_url: ImageUrlContent,
266 },
267}
268
269#[derive(Debug, Clone, Serialize, Deserialize)]
271pub struct ImageUrlContent {
272 pub url: String,
273}
274
275impl ContentPart {
276 pub fn text(text: impl Into<String>) -> Self {
278 Self::Text { text: text.into() }
279 }
280
281 pub fn image(data_uri: impl Into<String>) -> Self {
283 Self::ImageUrl {
284 image_url: ImageUrlContent {
285 url: data_uri.into(),
286 },
287 }
288 }
289
290 pub fn to_openai_value(&self) -> serde_json::Value {
292 match self {
293 ContentPart::Text { text } => serde_json::json!({
294 "type": "text",
295 "text": text
296 }),
297 ContentPart::ImageUrl { image_url } => serde_json::json!({
298 "type": "image_url",
299 "image_url": {
300 "url": image_url.url
301 }
302 }),
303 }
304 }
305
306 pub fn to_anthropic_value(&self) -> serde_json::Value {
308 match self {
309 ContentPart::Text { text } => serde_json::json!({
310 "type": "text",
311 "text": text
312 }),
313 ContentPart::ImageUrl { image_url } => {
314 let url = &image_url.url;
316 if let Some(rest) = url.strip_prefix("data:") {
317 if let Some((media_type, data)) = rest.split_once(";base64,") {
318 return serde_json::json!({
319 "type": "image",
320 "source": {
321 "type": "base64",
322 "media_type": media_type,
323 "data": data
324 }
325 });
326 }
327 }
328 serde_json::json!({
330 "type": "text",
331 "text": format!("[Image: {}]", url.chars().take(100).collect::<String>())
332 })
333 }
334 }
335 }
336}
337
338pub fn load_image(path: &std::path::Path) -> Result<String, crate::error::RavenClawsError> {
345 let data = std::fs::read(path).map_err(crate::error::RavenClawsError::IO)?;
346
347 let mime = match path
348 .extension()
349 .and_then(|e| e.to_str())
350 .map(|e| e.to_lowercase())
351 .as_deref()
352 {
353 Some("png") => "image/png",
354 Some("jpg") | Some("jpeg") => "image/jpeg",
355 Some("gif") => "image/gif",
356 Some("webp") => "image/webp",
357 _ => {
358 return Err(crate::error::RavenClawsError::CommandExecution(format!(
359 "Unsupported image format: '{}'. Supported: png, jpg, jpeg, gif, webp",
360 path.display()
361 )));
362 }
363 };
364
365 let encoded = base64::Engine::encode(&base64::engine::general_purpose::STANDARD, &data);
366 Ok(format!("data:{};base64,{}", mime, encoded))
367}
368
369#[derive(Debug, Clone, Serialize, Deserialize)]
370pub struct ChatMessage {
371 pub role: String,
372 pub content: String,
373 #[serde(skip_serializing_if = "Option::is_none")]
376 pub content_parts: Option<Vec<ContentPart>>,
377}
378
379impl ChatMessage {
380 pub fn new(role: impl Into<String>, content: impl Into<String>) -> Self {
382 Self {
383 role: role.into(),
384 content: content.into(),
385 content_parts: None,
386 }
387 }
388
389 pub fn with_images(
391 role: impl Into<String>,
392 text: impl Into<String>,
393 image_data_uris: Vec<String>,
394 ) -> Self {
395 let text = text.into();
396 let mut parts = Vec::with_capacity(1 + image_data_uris.len());
397 parts.push(ContentPart::text(&text));
398 for uri in image_data_uris {
399 parts.push(ContentPart::image(uri));
400 }
401 Self {
402 role: role.into(),
403 content: text.clone(),
404 content_parts: Some(parts),
405 }
406 }
407
408 pub fn to_openai_message(&self) -> serde_json::Value {
411 match &self.content_parts {
412 Some(parts) => {
413 let content_array: Vec<serde_json::Value> =
414 parts.iter().map(|p| p.to_openai_value()).collect();
415 serde_json::json!({
416 "role": self.role,
417 "content": content_array
418 })
419 }
420 None => {
421 serde_json::json!({
422 "role": self.role,
423 "content": self.content
424 })
425 }
426 }
427 }
428
429 pub fn to_anthropic_message(&self) -> serde_json::Value {
431 match &self.content_parts {
432 Some(parts) => {
433 let content_array: Vec<serde_json::Value> =
434 parts.iter().map(|p| p.to_anthropic_value()).collect();
435 serde_json::json!({
436 "role": self.role,
437 "content": content_array
438 })
439 }
440 None => {
441 serde_json::json!({
442 "role": self.role,
443 "content": self.content
444 })
445 }
446 }
447 }
448
449 pub fn ollama_images(&self) -> Option<Vec<String>> {
452 let parts = self.content_parts.as_ref()?;
453 let images: Vec<String> = parts
454 .iter()
455 .filter_map(|p| match p {
456 ContentPart::ImageUrl { image_url } => {
457 let url = &image_url.url;
458 url.strip_prefix("data:")
459 .and_then(|rest| rest.split_once(";base64,").map(|x| x.1))
460 .map(|s| s.to_string())
461 }
462 _ => None,
463 })
464 .collect();
465 if images.is_empty() {
466 None
467 } else {
468 Some(images)
469 }
470 }
471}
472
473#[derive(Debug, Clone, Serialize)]
474pub struct ChatRequest {
475 pub model: String,
476 #[serde(serialize_with = "serialize_messages_openai")]
477 pub messages: Vec<ChatMessage>,
478 #[serde(skip_serializing_if = "Option::is_none")]
479 pub temperature: Option<f32>,
480 #[serde(skip_serializing_if = "Option::is_none")]
481 pub max_tokens: Option<u32>,
482 #[serde(skip_serializing_if = "Option::is_none")]
483 pub stream: Option<bool>,
484 #[serde(skip_serializing_if = "Option::is_none")]
485 pub tools: Option<Vec<serde_json::Value>>,
486 #[serde(skip_serializing_if = "Option::is_none")]
487 pub tool_choice: Option<String>,
488}
489
490fn serialize_messages_openai<S>(messages: &[ChatMessage], serializer: S) -> Result<S::Ok, S::Error>
492where
493 S: serde::Serializer,
494{
495 use serde::ser::SerializeSeq;
496 let mut seq = serializer.serialize_seq(Some(messages.len()))?;
497 for msg in messages {
498 let value = msg.to_openai_message();
499 seq.serialize_element(&value)?;
500 }
501 seq.end()
502}
503
504#[derive(Debug, Clone, Deserialize)]
505pub struct ChatResponse {
506 #[allow(dead_code)]
507 pub id: String,
508 #[allow(dead_code)]
509 pub object: String,
510 #[allow(dead_code)]
511 pub created: u64,
512 #[allow(dead_code)]
513 pub model: String,
514 pub choices: Vec<Choice>,
515 #[allow(dead_code)]
516 pub usage: Option<Usage>,
517}
518
519#[derive(Debug, Clone, Deserialize)]
520pub struct ToolCallResponse {
521 #[allow(dead_code)]
522 pub id: String,
523 #[allow(dead_code)]
524 #[serde(rename = "type")]
525 pub call_type: String,
526 pub function: FunctionCall,
527}
528
529#[derive(Debug, Clone, Deserialize)]
530pub struct FunctionCall {
531 pub name: String,
532 pub arguments: String,
533}
534
535#[derive(Debug, Clone, Deserialize)]
536pub struct Choice {
537 #[allow(dead_code)]
538 pub index: u32,
539 pub message: ChatMessage,
540 #[allow(dead_code)]
541 pub finish_reason: Option<String>,
542 #[serde(default, skip_serializing_if = "Option::is_none")]
543 pub tool_calls: Option<Vec<ToolCallResponse>>,
544}
545
546#[derive(Debug, Clone, Deserialize)]
547pub struct Usage {
548 #[allow(dead_code)]
549 pub prompt_tokens: u32,
550 #[allow(dead_code)]
551 pub completion_tokens: u32,
552 #[allow(dead_code)]
553 pub total_tokens: u32,
554}
555
556#[async_trait::async_trait]
558pub trait LLMProviderTrait: Send + Sync {
559 async fn chat(&self, messages: Vec<ChatMessage>) -> Result<ChatResponse, LLMError>;
560 #[allow(dead_code)]
561 async fn chat_stream(&self, messages: Vec<ChatMessage>) -> Result<StreamResult, LLMError> {
562 let response = self.chat(messages).await?;
564 let content = response
565 .choices
566 .first()
567 .map(|c| c.message.content.clone())
568 .unwrap_or_default();
569 let finish_reason = response
570 .choices
571 .first()
572 .and_then(|c| c.finish_reason.clone());
573
574 let stream = futures::stream::once(async move {
575 Ok(StreamChunk {
576 content,
577 finish_reason,
578 })
579 });
580 Ok(Box::pin(stream))
581 }
582 fn provider_name(&self) -> &str;
583 fn model(&self) -> &str;
584}
585
586async fn handle_openai_response(response: Response) -> Result<ChatResponse, LLMError> {
588 let status = response.status();
589
590 if status.is_success() {
591 response
592 .json::<ChatResponse>()
593 .await
594 .map_err(|e| LLMError::InvalidResponse(e.to_string()))
595 } else if status == reqwest::StatusCode::UNAUTHORIZED {
596 Err(LLMError::AuthFailed)
597 } else if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
598 Err(LLMError::RateLimited)
599 } else {
600 let body = response
601 .text()
602 .await
603 .unwrap_or_else(|_| "Unknown error".to_string());
604 Err(LLMError::RequestFailed(format!("{}: {}", status, body)))
605 }
606}
607
608pub struct OpenAICompatibleClient {
611 client: Client,
612 config: LLMConfig,
613 provider: OpenAICompatibleProvider,
614 retry_config: RetryConfig,
615 circuit_breaker: std::sync::Mutex<CircuitBreaker>,
616}
617
618impl OpenAICompatibleClient {
619 pub fn new(config: &LLMConfig, provider: OpenAICompatibleProvider) -> Result<Self, LLMError> {
620 let client = Client::builder()
621 .timeout(std::time::Duration::from_secs(config.timeout_secs))
622 .build()
623 .map_err(|e| LLMError::RequestFailed(format!("Failed to create HTTP client: {}", e)))?;
624
625 let retry_config = RetryConfig {
626 max_retries: config.retry_max,
627 base_delay_ms: config.retry_base_delay_ms,
628 max_delay_ms: config.retry_max_delay_ms,
629 jitter: 0.5,
630 };
631
632 Ok(Self {
633 client,
634 config: config.clone(),
635 provider,
636 retry_config,
637 circuit_breaker: std::sync::Mutex::new(CircuitBreaker::new(30)),
638 })
639 }
640
641 async fn send_request_with_retry(
643 &self,
644 request: ChatRequest,
645 ) -> Result<ChatResponse, LLMError> {
646 let mut last_error = None;
647
648 for attempt in 0..=self.retry_config.max_retries {
649 {
651 let mut cb = self.circuit_breaker.lock().map_err(|_| {
652 LLMError::RequestFailed("Circuit breaker lock poisoned".to_string())
653 })?;
654 if !cb.can_execute() {
655 return Err(LLMError::CircuitBreakerOpen(
656 self.provider.name().to_string(),
657 ));
658 }
659 }
660
661 let result = self.send_request_inner(request.clone()).await;
662
663 match result {
664 Ok(response) => {
665 {
667 let mut cb = self.circuit_breaker.lock().map_err(|_| {
668 LLMError::RequestFailed("Circuit breaker lock poisoned".to_string())
669 })?;
670 cb.record_success();
671 }
672 return Ok(response);
673 }
674 Err(e) => {
675 {
677 let mut cb = self.circuit_breaker.lock().map_err(|_| {
678 LLMError::RequestFailed("Circuit breaker lock poisoned".to_string())
679 })?;
680 cb.record_failure();
681 }
682
683 last_error = Some(e);
684
685 if matches!(last_error, Some(LLMError::AuthFailed)) {
687 return Err(last_error.unwrap());
688 }
689
690 if attempt < self.retry_config.max_retries {
692 let delay = self.retry_config.delay_for_attempt(attempt);
693 sleep(delay).await;
694 }
695 }
696 }
697 }
698
699 Err(last_error.unwrap_or(LLMError::AllProvidersFailed))
700 }
701
702 async fn send_request_inner(&self, request: ChatRequest) -> Result<ChatResponse, LLMError> {
704 let req = self.apply_headers(self.client.post(self.endpoint()).json(&request));
705
706 let response = req
707 .send()
708 .await
709 .map_err(|e| LLMError::RequestFailed(e.to_string()))?;
710
711 handle_openai_response(response).await
712 }
713
714 fn build_request(&self, messages: Vec<ChatMessage>) -> ChatRequest {
715 ChatRequest {
716 model: self.config.model.clone(),
717 messages,
718 temperature: Some(0.7),
719 max_tokens: Some(2048),
720 stream: None,
721 tools: None,
722 tool_choice: None,
723 }
724 }
725
726 fn endpoint(&self) -> String {
727 let base = if self.config.endpoint.is_empty() {
728 self.provider.default_endpoint()
729 } else {
730 &self.config.endpoint
731 };
732 let mut url = format!("{}/v1/chat/completions", base.trim_end_matches('/'));
733 if self.provider == OpenAICompatibleProvider::Azure {
734 if !url.contains("api-version") {
737 url = format!("{}?api-version=2024-02-15-preview", url);
738 }
739 }
740 url
741 }
742
743 fn apply_headers(&self, mut req: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
744 if let Some(ref key) = self.config.api_key {
745 if self.provider == OpenAICompatibleProvider::Azure {
746 req = req.header("api-key", key);
748 } else {
749 req = req.header("Authorization", format!("Bearer {}", key));
750 }
751 }
752
753 if self.provider == OpenAICompatibleProvider::OpenRouter {
755 req = req
756 .header("HTTP-Referer", "https://github.com/egkristi/RavenClaws")
757 .header("X-Title", "RavenClaws");
758 }
759
760 req
761 }
762
763 #[allow(dead_code)]
764 async fn send_request(&self, request: ChatRequest) -> Result<ChatResponse, LLMError> {
765 let req = self.apply_headers(self.client.post(self.endpoint()).json(&request));
766
767 let response = req
768 .send()
769 .await
770 .map_err(|e| LLMError::RequestFailed(e.to_string()))?;
771
772 handle_openai_response(response).await
773 }
774}
775
776#[async_trait::async_trait]
777impl LLMProviderTrait for OpenAICompatibleClient {
778 #[instrument(skip(self, messages), fields(provider = self.provider_name(), model = self.model()))]
779 async fn chat(&self, messages: Vec<ChatMessage>) -> Result<ChatResponse, LLMError> {
780 let request = self.build_request(messages);
781 self.send_request_with_retry(request).await
782 }
783
784 async fn chat_stream(&self, messages: Vec<ChatMessage>) -> Result<StreamResult, LLMError> {
785 let request = ChatRequest {
786 model: self.config.model.clone(),
787 messages,
788 temperature: Some(0.7),
789 max_tokens: Some(2048),
790 stream: Some(true),
791 tools: None,
792 tool_choice: None,
793 };
794
795 let req = self.apply_headers(self.client.post(self.endpoint()).json(&request));
796
797 let response = req
798 .send()
799 .await
800 .map_err(|e| LLMError::RequestFailed(e.to_string()))?;
801
802 let status = response.status();
803 if !status.is_success() {
804 if status == reqwest::StatusCode::UNAUTHORIZED {
805 return Err(LLMError::AuthFailed);
806 } else if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
807 return Err(LLMError::RateLimited);
808 } else {
809 let body = response
810 .text()
811 .await
812 .unwrap_or_else(|_| "Unknown error".to_string());
813 return Err(LLMError::RequestFailed(format!("{}: {}", status, body)));
814 }
815 }
816
817 use futures::StreamExt;
819 let stream = response
820 .bytes_stream()
821 .filter_map(|chunk_result| async move {
822 match chunk_result {
823 Err(e) => Some(Err(LLMError::RequestFailed(e.to_string()))),
824 Ok(bytes) => {
825 let text = String::from_utf8_lossy(&bytes);
826 let mut content = String::new();
827 let mut finish_reason = None;
828
829 for line in text.lines() {
830 if let Some(data) = line.strip_prefix("data: ") {
831 if data == "[DONE]" {
832 finish_reason = Some("stop".to_string());
833 continue;
834 }
835 if let Ok(sse_chunk) =
836 serde_json::from_str::<serde_json::Value>(data)
837 {
838 if let Some(choice) =
839 sse_chunk["choices"].as_array().and_then(|c| c.first())
840 {
841 if let Some(delta) = choice["delta"].as_object() {
842 if let Some(c) = delta["content"].as_str() {
843 content.push_str(c);
844 }
845 }
846 if let Some(reason) = choice["finish_reason"].as_str() {
847 if reason != "null" {
848 finish_reason = Some(reason.to_string());
849 }
850 }
851 }
852 }
853 }
854 }
855
856 if content.is_empty() && finish_reason.is_none() {
857 None
858 } else {
859 Some(Ok(StreamChunk {
860 content,
861 finish_reason,
862 }))
863 }
864 }
865 }
866 });
867
868 Ok(Box::pin(stream))
869 }
870
871 fn provider_name(&self) -> &str {
872 self.provider.name()
873 }
874
875 fn model(&self) -> &str {
876 &self.config.model
877 }
878}
879
880pub struct OllamaClient {
882 client: Client,
883 config: LLMConfig,
884}
885
886impl OllamaClient {
887 pub fn new(config: &LLMConfig) -> Result<Self, LLMError> {
888 let client = Client::builder()
889 .timeout(std::time::Duration::from_secs(config.timeout_secs))
890 .build()
891 .map_err(|e| LLMError::RequestFailed(format!("Failed to create HTTP client: {}", e)))?;
892
893 Ok(Self {
894 client,
895 config: config.clone(),
896 })
897 }
898}
899
900#[async_trait::async_trait]
901impl LLMProviderTrait for OllamaClient {
902 #[instrument(skip(self, messages), fields(provider = self.provider_name(), model = self.model()))]
903 async fn chat(&self, messages: Vec<ChatMessage>) -> Result<ChatResponse, LLMError> {
904 #[derive(Serialize)]
908 struct OllamaRequest {
909 model: String,
910 #[serde(serialize_with = "serialize_messages_ollama")]
911 messages: Vec<ChatMessage>,
912 stream: bool,
913 }
914
915 fn serialize_messages_ollama<S>(
917 messages: &[ChatMessage],
918 serializer: S,
919 ) -> Result<S::Ok, S::Error>
920 where
921 S: serde::Serializer,
922 {
923 use serde::ser::SerializeSeq;
924 let mut seq = serializer.serialize_seq(Some(messages.len()))?;
925 for msg in messages {
926 let value = if let Some(images) = msg.ollama_images() {
927 serde_json::json!({
928 "role": msg.role,
929 "content": msg.content,
930 "images": images
931 })
932 } else {
933 serde_json::json!({
934 "role": msg.role,
935 "content": msg.content
936 })
937 };
938 seq.serialize_element(&value)?;
939 }
940 seq.end()
941 }
942
943 let request = OllamaRequest {
944 model: self.config.model.clone(),
945 messages,
946 stream: false,
947 };
948
949 let response = self
950 .client
951 .post(format!(
952 "{}/api/chat",
953 self.config.endpoint.trim_end_matches('/')
954 ))
955 .json(&request)
956 .send()
957 .await
958 .map_err(|e| LLMError::RequestFailed(e.to_string()))?;
959
960 let status = response.status();
961
962 if status.is_success() {
963 #[derive(Deserialize)]
965 struct OllamaResponse {
966 model: String,
967 message: ChatMessage,
968 done: bool,
969 }
970
971 let ollama_resp = response
972 .json::<OllamaResponse>()
973 .await
974 .map_err(|e| LLMError::InvalidResponse(e.to_string()))?;
975
976 Ok(ChatResponse {
977 id: format!("ollama-{}", uuid::Uuid::new_v4()),
978 object: "chat.completion".to_string(),
979 created: std::time::SystemTime::now()
980 .duration_since(std::time::UNIX_EPOCH)
981 .unwrap()
982 .as_secs(),
983 model: ollama_resp.model,
984 choices: vec![Choice {
985 index: 0,
986 message: ollama_resp.message,
987 finish_reason: if ollama_resp.done {
988 Some("stop".to_string())
989 } else {
990 None
991 },
992 tool_calls: None,
993 }],
994 usage: None, })
996 } else if status == reqwest::StatusCode::UNAUTHORIZED {
997 Err(LLMError::AuthFailed)
998 } else {
999 let body = response
1000 .text()
1001 .await
1002 .unwrap_or_else(|_| "Unknown error".to_string());
1003 Err(LLMError::RequestFailed(format!("{}: {}", status, body)))
1004 }
1005 }
1006
1007 fn provider_name(&self) -> &str {
1008 "ollama"
1009 }
1010
1011 fn model(&self) -> &str {
1012 &self.config.model
1013 }
1014}
1015
1016pub struct AnthropicClient {
1018 client: Client,
1019 config: LLMConfig,
1020}
1021
1022impl AnthropicClient {
1023 pub fn new(config: &LLMConfig) -> Result<Self, LLMError> {
1024 let client = Client::builder()
1025 .timeout(std::time::Duration::from_secs(config.timeout_secs))
1026 .build()
1027 .map_err(|e| LLMError::RequestFailed(format!("Failed to create HTTP client: {}", e)))?;
1028
1029 Ok(Self {
1030 client,
1031 config: config.clone(),
1032 })
1033 }
1034}
1035
1036#[async_trait::async_trait]
1037impl LLMProviderTrait for AnthropicClient {
1038 #[instrument(skip(self, messages), fields(provider = self.provider_name(), model = self.model()))]
1039 async fn chat(&self, messages: Vec<ChatMessage>) -> Result<ChatResponse, LLMError> {
1040 #[derive(Serialize)]
1042 struct AnthropicRequest {
1043 model: String,
1044 max_tokens: u32,
1045 #[serde(serialize_with = "serialize_anthropic_messages")]
1046 messages: Vec<ChatMessage>,
1047 #[serde(skip_serializing_if = "Option::is_none")]
1048 system: Option<String>,
1049 #[serde(skip_serializing_if = "Option::is_none")]
1050 temperature: Option<f32>,
1051 }
1052
1053 fn serialize_anthropic_messages<S>(
1055 messages: &[ChatMessage],
1056 serializer: S,
1057 ) -> Result<S::Ok, S::Error>
1058 where
1059 S: serde::Serializer,
1060 {
1061 use serde::ser::SerializeSeq;
1062 let mut seq = serializer.serialize_seq(Some(messages.len()))?;
1063 for msg in messages {
1064 let value = msg.to_anthropic_message();
1065 seq.serialize_element(&value)?;
1066 }
1067 seq.end()
1068 }
1069
1070 let system = messages
1072 .iter()
1073 .find(|m| m.role == "system")
1074 .map(|m| m.content.clone());
1075
1076 let anthropic_messages: Vec<ChatMessage> = messages
1077 .into_iter()
1078 .filter(|m| m.role != "system")
1079 .collect();
1080
1081 let request = AnthropicRequest {
1082 model: self.config.model.clone(),
1083 max_tokens: 2048,
1084 messages: anthropic_messages,
1085 system,
1086 temperature: Some(0.7),
1087 };
1088
1089 let api_key = self
1090 .config
1091 .api_key
1092 .clone()
1093 .ok_or_else(|| LLMError::AuthFailed)?;
1094
1095 let response = self
1096 .client
1097 .post("https://api.anthropic.com/v1/messages")
1098 .header("x-api-key", api_key)
1099 .header("anthropic-version", "2023-06-01")
1100 .header("content-type", "application/json")
1101 .json(&request)
1102 .send()
1103 .await
1104 .map_err(|e| LLMError::RequestFailed(e.to_string()))?;
1105
1106 let status = response.status();
1107
1108 if status.is_success() {
1109 #[derive(Deserialize)]
1111 #[allow(dead_code)]
1112 struct AnthropicResponse {
1113 id: String,
1114 #[serde(rename = "type")]
1115 response_type: String,
1116 role: String,
1117 content: Vec<AnthropicContentBlock>,
1118 model: String,
1119 stop_reason: Option<String>,
1120 #[serde(default)]
1121 usage: Option<AnthropicUsage>,
1122 }
1123
1124 #[derive(Deserialize)]
1125 #[serde(tag = "type", rename_all = "lowercase")]
1126 enum AnthropicContentBlock {
1127 Text {
1128 text: String,
1129 },
1130 ToolUse {
1131 id: String,
1132 name: String,
1133 input: serde_json::Value,
1134 },
1135 }
1136
1137 #[derive(Deserialize)]
1138 struct AnthropicUsage {
1139 input_tokens: u32,
1140 output_tokens: u32,
1141 }
1142
1143 let anthropic_resp = response
1144 .json::<AnthropicResponse>()
1145 .await
1146 .map_err(|e| LLMError::InvalidResponse(e.to_string()))?;
1147
1148 let mut content = String::new();
1150 let mut tool_calls = None;
1151
1152 for block in anthropic_resp.content {
1153 match block {
1154 AnthropicContentBlock::Text { text } => {
1155 content.push_str(&text);
1156 }
1157 AnthropicContentBlock::ToolUse { id, name, input } => {
1158 if tool_calls.is_none() {
1159 tool_calls = Some(Vec::new());
1160 }
1161 if let Some(ref mut calls) = tool_calls {
1162 calls.push(ToolCallResponse {
1163 id,
1164 call_type: "function".to_string(),
1165 function: FunctionCall {
1166 name,
1167 arguments: input.to_string(),
1168 },
1169 });
1170 }
1171 }
1172 }
1173 }
1174
1175 Ok(ChatResponse {
1176 id: anthropic_resp.id,
1177 object: "chat.completion".to_string(),
1178 created: std::time::SystemTime::now()
1179 .duration_since(std::time::UNIX_EPOCH)
1180 .unwrap()
1181 .as_secs(),
1182 model: anthropic_resp.model,
1183 choices: vec![Choice {
1184 index: 0,
1185 message: ChatMessage {
1186 role: "assistant".to_string(),
1187 content,
1188 content_parts: None,
1189 },
1190 finish_reason: anthropic_resp.stop_reason,
1191 tool_calls,
1192 }],
1193 usage: anthropic_resp.usage.map(|u| Usage {
1194 prompt_tokens: u.input_tokens,
1195 completion_tokens: u.output_tokens,
1196 total_tokens: u.input_tokens + u.output_tokens,
1197 }),
1198 })
1199 } else if status == reqwest::StatusCode::UNAUTHORIZED {
1200 Err(LLMError::AuthFailed)
1201 } else if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
1202 Err(LLMError::RateLimited)
1203 } else {
1204 let body = response
1205 .text()
1206 .await
1207 .unwrap_or_else(|_| "Unknown error".to_string());
1208 Err(LLMError::RequestFailed(format!("{}: {}", status, body)))
1209 }
1210 }
1211
1212 fn provider_name(&self) -> &str {
1213 "anthropic"
1214 }
1215
1216 fn model(&self) -> &str {
1217 &self.config.model
1218 }
1219}
1220
1221pub fn create_client(config: &LLMConfig) -> Result<Arc<dyn LLMProviderTrait>, LLMError> {
1223 match config.provider {
1224 LLMProvider::LiteLLM => {
1225 let unified = OpenAICompatibleClient::new(config, OpenAICompatibleProvider::LiteLLM)?;
1226 Ok(Arc::new(unified))
1227 }
1228 LLMProvider::OpenRouter => {
1229 let unified =
1230 OpenAICompatibleClient::new(config, OpenAICompatibleProvider::OpenRouter)?;
1231 Ok(Arc::new(unified))
1232 }
1233 LLMProvider::Ollama => Ok(Arc::new(OllamaClient::new(config)?)),
1234 LLMProvider::OpenAI => {
1235 let unified = OpenAICompatibleClient::new(config, OpenAICompatibleProvider::OpenAI)?;
1236 Ok(Arc::new(unified))
1237 }
1238 LLMProvider::Anthropic => Ok(Arc::new(AnthropicClient::new(config)?)),
1239 LLMProvider::OpenAICompatible => {
1240 let unified = OpenAICompatibleClient::new(config, OpenAICompatibleProvider::Generic)?;
1241 Ok(Arc::new(unified))
1242 }
1243 LLMProvider::Azure => {
1244 let unified = OpenAICompatibleClient::new(config, OpenAICompatibleProvider::Azure)?;
1245 Ok(Arc::new(unified))
1246 }
1247 }
1248}
1249
1250#[derive(Clone)]
1252pub struct MultiModelManager {
1253 clients: Vec<Arc<dyn LLMProviderTrait>>,
1254}
1255
1256impl MultiModelManager {
1257 pub fn new(configs: Vec<LLMConfig>) -> Result<Self, LLMError> {
1258 let clients: Result<Vec<_>, _> = configs.iter().map(create_client).collect();
1259 Ok(Self { clients: clients? })
1260 }
1261
1262 pub fn get_client(&self, index: usize) -> Option<&Arc<dyn LLMProviderTrait>> {
1263 self.clients.get(index)
1264 }
1265
1266 pub fn client_count(&self) -> usize {
1267 self.clients.len()
1268 }
1269
1270 pub fn next_client(&self, last_index: usize) -> Option<&Arc<dyn LLMProviderTrait>> {
1272 if self.clients.is_empty() {
1273 return None;
1274 }
1275 let next = (last_index + 1) % self.clients.len();
1276 Some(&self.clients[next])
1277 }
1278}
1279
1280#[derive(Debug)]
1282pub struct ProviderFallbackChain {
1283 pub configs: Vec<LLMConfig>,
1285 token_budget: Option<TokenBudget>,
1286}
1287
1288impl ProviderFallbackChain {
1289 pub fn new(configs: Vec<LLMConfig>) -> Self {
1290 Self {
1291 configs,
1292 token_budget: None,
1293 }
1294 }
1295
1296 pub fn with_token_budget(mut self, budget: TokenBudget) -> Self {
1297 self.token_budget = Some(budget);
1298 self
1299 }
1300
1301 #[instrument(skip(self, messages))]
1303 pub async fn chat_with_fallback(
1304 &mut self,
1305 messages: Vec<ChatMessage>,
1306 ) -> Result<ChatResponse, LLMError> {
1307 let mut last_error = None;
1308
1309 for (i, config) in self.configs.iter().enumerate() {
1310 let client = match create_client(config) {
1311 Ok(c) => c,
1312 Err(e) => {
1313 tracing::warn!(
1314 "Failed to create client for provider {:?}: {}",
1315 config.provider,
1316 e
1317 );
1318 last_error = Some(e);
1319 continue;
1320 }
1321 };
1322
1323 if let Some(ref budget) = self.token_budget {
1325 if !budget.can_spend(500) {
1327 return Err(LLMError::TokenBudgetExceeded);
1328 }
1329 }
1330
1331 match client.chat(messages.clone()).await {
1332 Ok(response) => {
1333 if let Some(ref mut budget) = self.token_budget {
1335 if let Some(usage) = &response.usage {
1336 budget.record_usage(usage.total_tokens);
1337 }
1338 }
1339 return Ok(response);
1340 }
1341 Err(e) => {
1342 tracing::warn!("Provider {} failed: {}", i, e);
1343 last_error = Some(e);
1344 }
1346 }
1347 }
1348
1349 Err(last_error.unwrap_or(LLMError::AllProvidersFailed))
1350 }
1351
1352 #[allow(dead_code)]
1354 pub fn provider_names(&self) -> Vec<String> {
1355 self.configs
1356 .iter()
1357 .map(|c| format!("{:?}", c.provider))
1358 .collect()
1359 }
1360}
1361
1362#[cfg(test)]
1363mod tests {
1364 use super::*;
1365 use mockito::Server;
1366
1367 fn make_chat_messages() -> Vec<ChatMessage> {
1370 vec![
1371 ChatMessage::new("system", "You are helpful."),
1372 ChatMessage::new("user", "Hello!"),
1373 ]
1374 }
1375
1376 fn sample_chat_response_json(model: &str) -> String {
1377 format!(
1378 r#"{{
1379 "id": "chat-123",
1380 "object": "chat.completion",
1381 "created": 1717000000,
1382 "model": "{}",
1383 "choices": [
1384 {{
1385 "index": 0,
1386 "message": {{
1387 "role": "assistant",
1388 "content": "Hi there!"
1389 }},
1390 "finish_reason": "stop"
1391 }}
1392 ],
1393 "usage": {{
1394 "prompt_tokens": 10,
1395 "completion_tokens": 5,
1396 "total_tokens": 15
1397 }}
1398 }}"#,
1399 model
1400 )
1401 }
1402
1403 fn sample_ollama_response_json(model: &str) -> String {
1404 format!(
1405 r#"{{
1406 "model": "{}",
1407 "message": {{
1408 "role": "assistant",
1409 "content": "Hi there!"
1410 }},
1411 "done": true
1412 }}"#,
1413 model
1414 )
1415 }
1416
1417 fn with_mockito<F, Fut>(f: F)
1421 where
1422 F: FnOnce(mockito::ServerGuard) -> Fut,
1423 Fut: std::future::Future<Output = ()>,
1424 {
1425 let server = Server::new();
1426 let rt = tokio::runtime::Runtime::new().unwrap();
1427 rt.block_on(f(server));
1428 }
1429
1430 #[test]
1433 fn test_openai_compatible_provider_defaults() {
1434 assert_eq!(
1435 OpenAICompatibleProvider::LiteLLM.default_endpoint(),
1436 "http://localhost:4000"
1437 );
1438 assert_eq!(
1439 OpenAICompatibleProvider::OpenAI.default_endpoint(),
1440 "https://api.openai.com"
1441 );
1442 assert_eq!(
1443 OpenAICompatibleProvider::OpenRouter.default_endpoint(),
1444 "https://openrouter.ai"
1445 );
1446 }
1447
1448 #[test]
1449 fn test_openai_compatible_provider_names() {
1450 assert_eq!(OpenAICompatibleProvider::LiteLLM.name(), "litellm");
1451 assert_eq!(OpenAICompatibleProvider::OpenAI.name(), "openai");
1452 assert_eq!(OpenAICompatibleProvider::OpenRouter.name(), "openrouter");
1453 }
1454
1455 #[test]
1456 fn test_openai_compatible_requires_custom_headers() {
1457 assert!(!OpenAICompatibleProvider::LiteLLM.requires_custom_headers());
1458 assert!(OpenAICompatibleProvider::OpenRouter.requires_custom_headers());
1459 assert!(!OpenAICompatibleProvider::OpenAI.requires_custom_headers());
1460 }
1461
1462 #[test]
1463 fn test_openai_compatible_client_new() {
1464 let config = LLMConfig {
1465 provider: LLMProvider::LiteLLM,
1466 endpoint: "http://localhost:4000".to_string(),
1467 model: "gpt-4o-mini".to_string(),
1468 api_key: Some("test-key".to_string()),
1469 timeout_secs: 30,
1470 system_prompt: crate::config::default_system_prompt(),
1471 token_budget: None,
1472 retry_max: 3,
1473 retry_base_delay_ms: 100,
1474 retry_max_delay_ms: 10000,
1475 };
1476
1477 let client = OpenAICompatibleClient::new(&config, OpenAICompatibleProvider::LiteLLM);
1478 assert!(client.is_ok());
1479 assert_eq!(client.unwrap().provider_name(), "litellm");
1480 }
1481
1482 #[test]
1483 fn test_openai_compatible_client_endpoint() {
1484 let config = LLMConfig {
1486 provider: LLMProvider::OpenAI,
1487 endpoint: "https://custom.api.example.com".to_string(),
1488 model: "gpt-4o".to_string(),
1489 api_key: Some("test-key".to_string()),
1490 timeout_secs: 30,
1491 system_prompt: crate::config::default_system_prompt(),
1492 token_budget: None,
1493 retry_max: 3,
1494 retry_base_delay_ms: 100,
1495 retry_max_delay_ms: 10000,
1496 };
1497
1498 let client =
1499 OpenAICompatibleClient::new(&config, OpenAICompatibleProvider::OpenAI).unwrap();
1500 assert_eq!(client.provider_name(), "openai");
1502 }
1503
1504 #[test]
1505 fn test_openai_compatible_client_chat_success() {
1506 with_mockito(|mut server| async move {
1507 let mock = server
1508 .mock("POST", "/v1/chat/completions")
1509 .with_status(200)
1510 .with_header("content-type", "application/json")
1511 .with_body(sample_chat_response_json("gpt-4o-mini"))
1512 .create();
1513
1514 let config = LLMConfig {
1515 provider: LLMProvider::LiteLLM,
1516 endpoint: server.url(),
1517 model: "gpt-4o-mini".to_string(),
1518 api_key: Some("test-key".to_string()),
1519 timeout_secs: 30,
1520 system_prompt: crate::config::default_system_prompt(),
1521 token_budget: None,
1522 retry_max: 3,
1523 retry_base_delay_ms: 100,
1524 retry_max_delay_ms: 10000,
1525 };
1526
1527 let client =
1528 OpenAICompatibleClient::new(&config, OpenAICompatibleProvider::LiteLLM).unwrap();
1529 let response = client.chat(make_chat_messages()).await.unwrap();
1530
1531 assert_eq!(response.model, "gpt-4o-mini");
1532 assert_eq!(response.choices[0].message.content, "Hi there!");
1533 mock.assert();
1534 });
1535 }
1536
1537 #[test]
1538 fn test_openai_compatible_client_auth_failure() {
1539 with_mockito(|mut server| async move {
1540 let mock = server
1541 .mock("POST", "/v1/chat/completions")
1542 .with_status(401)
1543 .with_body(r#"{"error": "Unauthorized"}"#)
1544 .create();
1545
1546 let config = LLMConfig {
1547 provider: LLMProvider::LiteLLM,
1548 endpoint: server.url(),
1549 model: "gpt-4o-mini".to_string(),
1550 api_key: Some("bad-key".to_string()),
1551 timeout_secs: 30,
1552 system_prompt: crate::config::default_system_prompt(),
1553 token_budget: None,
1554 retry_max: 3,
1555 retry_base_delay_ms: 100,
1556 retry_max_delay_ms: 10000,
1557 };
1558
1559 let client =
1560 OpenAICompatibleClient::new(&config, OpenAICompatibleProvider::LiteLLM).unwrap();
1561 let err = client.chat(make_chat_messages()).await.unwrap_err();
1562
1563 assert!(matches!(err, LLMError::AuthFailed));
1564 mock.assert();
1565 });
1566 }
1567
1568 #[test]
1569 fn test_openai_compatible_client_rate_limit() {
1570 with_mockito(|mut server| async move {
1571 let mock = server
1572 .mock("POST", "/v1/chat/completions")
1573 .with_status(429)
1574 .with_body(r#"{"error": "Rate limited"}"#)
1575 .create();
1576
1577 let config = LLMConfig {
1578 provider: LLMProvider::LiteLLM,
1579 endpoint: server.url(),
1580 model: "gpt-4o-mini".to_string(),
1581 api_key: Some("test-key".to_string()),
1582 timeout_secs: 30,
1583 system_prompt: crate::config::default_system_prompt(),
1584 token_budget: None,
1585 retry_max: 0, retry_base_delay_ms: 100,
1587 retry_max_delay_ms: 10000,
1588 };
1589
1590 let client =
1591 OpenAICompatibleClient::new(&config, OpenAICompatibleProvider::LiteLLM).unwrap();
1592 let err = client.chat(make_chat_messages()).await.unwrap_err();
1593
1594 assert!(matches!(err, LLMError::RateLimited));
1595 mock.assert();
1596 });
1597 }
1598
1599 #[test]
1600 fn test_openrouter_client_uses_custom_headers() {
1601 with_mockito(|mut server| async move {
1602 let mock = server
1603 .mock("POST", "/v1/chat/completions")
1604 .match_header("HTTP-Referer", "https://github.com/egkristi/RavenClaws")
1605 .match_header("X-Title", "RavenClaws")
1606 .with_status(200)
1607 .with_body(sample_chat_response_json("claude-sonnet-4"))
1608 .create();
1609
1610 let config = LLMConfig {
1611 provider: LLMProvider::OpenRouter,
1612 endpoint: server.url(),
1613 model: "claude-sonnet-4".to_string(),
1614 api_key: Some("or-key".to_string()),
1615 timeout_secs: 30,
1616 system_prompt: crate::config::default_system_prompt(),
1617 token_budget: None,
1618 retry_max: 3,
1619 retry_base_delay_ms: 100,
1620 retry_max_delay_ms: 10000,
1621 };
1622
1623 let client =
1624 OpenAICompatibleClient::new(&config, OpenAICompatibleProvider::OpenRouter).unwrap();
1625 let _ = client.chat(make_chat_messages()).await.unwrap();
1626 mock.assert();
1627 });
1628 }
1629
1630 #[test]
1633 fn test_anthropic_client_new() {
1634 let config = LLMConfig {
1635 provider: LLMProvider::Anthropic,
1636 endpoint: String::new(),
1637 model: "claude-sonnet-4-20250514".to_string(),
1638 api_key: Some("sk-ant-test".to_string()),
1639 timeout_secs: 30,
1640 system_prompt: crate::config::default_system_prompt(),
1641 token_budget: None,
1642 retry_max: 3,
1643 retry_base_delay_ms: 100,
1644 retry_max_delay_ms: 10000,
1645 };
1646
1647 let client = AnthropicClient::new(&config);
1648 assert!(client.is_ok());
1649 }
1650
1651 #[test]
1652 fn test_anthropic_client_provider_name() {
1653 let config = LLMConfig {
1654 provider: LLMProvider::Anthropic,
1655 endpoint: String::new(),
1656 model: "claude-sonnet-4-20250514".to_string(),
1657 api_key: Some("sk-ant-test".to_string()),
1658 timeout_secs: 30,
1659 system_prompt: crate::config::default_system_prompt(),
1660 token_budget: None,
1661 retry_max: 3,
1662 retry_base_delay_ms: 100,
1663 retry_max_delay_ms: 10000,
1664 };
1665
1666 let client = AnthropicClient::new(&config).unwrap();
1667 assert_eq!(client.provider_name(), "anthropic");
1668 }
1669
1670 #[test]
1671 fn test_anthropic_client_model() {
1672 let config = LLMConfig {
1673 provider: LLMProvider::Anthropic,
1674 endpoint: String::new(),
1675 model: "claude-opus-4-20250514".to_string(),
1676 api_key: Some("sk-ant-test".to_string()),
1677 timeout_secs: 30,
1678 system_prompt: crate::config::default_system_prompt(),
1679 token_budget: None,
1680 retry_max: 3,
1681 retry_base_delay_ms: 100,
1682 retry_max_delay_ms: 10000,
1683 };
1684
1685 let client = AnthropicClient::new(&config).unwrap();
1686 assert_eq!(client.model(), "claude-opus-4-20250514");
1687 }
1688
1689 #[test]
1690 fn test_create_client_anthropic() {
1691 let config = LLMConfig {
1692 provider: LLMProvider::Anthropic,
1693 endpoint: String::new(),
1694 model: "claude-sonnet-4-20250514".to_string(),
1695 api_key: Some("sk-ant-test".to_string()),
1696 timeout_secs: 30,
1697 system_prompt: crate::config::default_system_prompt(),
1698 token_budget: None,
1699 retry_max: 3,
1700 retry_base_delay_ms: 100,
1701 retry_max_delay_ms: 10000,
1702 };
1703
1704 let client = create_client(&config);
1705 assert!(client.is_ok());
1706 assert_eq!(client.unwrap().provider_name(), "anthropic");
1707 }
1708
1709 #[test]
1712 fn test_retry_config_delay_calculation() {
1713 let config = RetryConfig {
1714 max_retries: 3,
1715 base_delay_ms: 100,
1716 max_delay_ms: 10000,
1717 jitter: 0.0, };
1719
1720 assert_eq!(config.delay_for_attempt(0).as_millis(), 100);
1722 assert_eq!(config.delay_for_attempt(1).as_millis(), 200);
1723 assert_eq!(config.delay_for_attempt(2).as_millis(), 400);
1724 }
1725
1726 #[test]
1727 fn test_retry_config_max_delay_cap() {
1728 let config = RetryConfig {
1729 max_retries: 10,
1730 base_delay_ms: 100,
1731 max_delay_ms: 1000,
1732 jitter: 0.0,
1733 };
1734
1735 assert!(config.delay_for_attempt(10).as_millis() <= 1000);
1737 }
1738
1739 #[test]
1740 fn test_circuit_breaker_state_transitions() {
1741 let mut cb = CircuitBreaker::new(30);
1742
1743 assert_eq!(cb.state, CircuitState::Closed);
1745 assert!(cb.can_execute());
1746
1747 for _ in 0..5 {
1749 cb.record_failure();
1750 }
1751 assert_eq!(cb.state, CircuitState::Open);
1752 assert!(!cb.can_execute());
1753 }
1754
1755 #[test]
1756 fn test_circuit_breaker_success_resets() {
1757 let mut cb = CircuitBreaker::new(30);
1758
1759 for _ in 0..3 {
1761 cb.record_failure();
1762 }
1763 assert_eq!(cb.failure_count, 3);
1764
1765 cb.record_success();
1767 assert_eq!(cb.failure_count, 0);
1768 assert_eq!(cb.state, CircuitState::Closed);
1769 }
1770
1771 #[test]
1772 fn test_token_budget_tracking() {
1773 let mut budget = TokenBudget::new(1000, 0.002); assert_eq!(budget.remaining(), 1000);
1776 assert!(budget.can_spend(500));
1777
1778 budget.record_usage(300);
1779 assert_eq!(budget.remaining(), 700);
1780 assert!(budget.can_spend(500));
1781
1782 budget.record_usage(500);
1783 assert_eq!(budget.remaining(), 200);
1784 assert!(!budget.can_spend(500));
1785
1786 assert!((budget.estimated_cost() - 0.0016).abs() < 0.0001);
1788 }
1789
1790 #[test]
1791 fn test_provider_fallback_chain_creation() {
1792 let configs = vec![
1793 LLMConfig {
1794 provider: LLMProvider::LiteLLM,
1795 endpoint: "http://localhost:4000".to_string(),
1796 model: "gpt-4o".to_string(),
1797 api_key: Some("key1".to_string()),
1798 timeout_secs: 30,
1799 system_prompt: crate::config::default_system_prompt(),
1800 token_budget: None,
1801 retry_max: 3,
1802 retry_base_delay_ms: 100,
1803 retry_max_delay_ms: 10000,
1804 },
1805 LLMConfig {
1806 provider: LLMProvider::Ollama,
1807 endpoint: "http://localhost:11434".to_string(),
1808 model: "llama3.1".to_string(),
1809 api_key: None,
1810 timeout_secs: 30,
1811 system_prompt: crate::config::default_system_prompt(),
1812 token_budget: None,
1813 retry_max: 3,
1814 retry_base_delay_ms: 100,
1815 retry_max_delay_ms: 10000,
1816 },
1817 ];
1818
1819 let chain = ProviderFallbackChain::new(configs);
1820 assert_eq!(chain.provider_names(), vec!["LiteLLM", "Ollama"]);
1821 }
1822
1823 #[test]
1826 fn test_litellm_chat_auth_failure() {
1827 with_mockito(|mut server| async move {
1828 let mock = server
1829 .mock("POST", "/v1/chat/completions")
1830 .with_status(401)
1831 .with_header("content-type", "application/json")
1832 .with_body(r#"{"error": "Unauthorized"}"#)
1833 .create();
1834
1835 let config = LLMConfig {
1836 provider: LLMProvider::LiteLLM,
1837 endpoint: server.url(),
1838 model: "gpt-4o-mini".to_string(),
1839 api_key: Some("bad-key".to_string()),
1840 timeout_secs: 30,
1841 system_prompt: crate::config::default_system_prompt(),
1842 token_budget: None,
1843 retry_max: 3,
1844 retry_base_delay_ms: 100,
1845 retry_max_delay_ms: 10000,
1846 };
1847
1848 let client =
1849 OpenAICompatibleClient::new(&config, OpenAICompatibleProvider::LiteLLM).unwrap();
1850 let err = client.chat(make_chat_messages()).await.unwrap_err();
1851
1852 assert!(matches!(err, LLMError::AuthFailed));
1853 mock.assert();
1854 });
1855 }
1856
1857 #[test]
1858 fn test_litellm_chat_rate_limit() {
1859 with_mockito(|mut server| async move {
1860 let mock = server
1861 .mock("POST", "/v1/chat/completions")
1862 .with_status(429)
1863 .with_header("content-type", "application/json")
1864 .with_body(r#"{"error": "Rate limit exceeded"}"#)
1865 .create();
1866
1867 let config = LLMConfig {
1868 provider: LLMProvider::LiteLLM,
1869 endpoint: server.url(),
1870 model: "gpt-4o-mini".to_string(),
1871 api_key: Some("test-key".to_string()),
1872 timeout_secs: 30,
1873 system_prompt: crate::config::default_system_prompt(),
1874 token_budget: None,
1875 retry_max: 0,
1876 retry_base_delay_ms: 100,
1877 retry_max_delay_ms: 10000,
1878 };
1879
1880 let client =
1881 OpenAICompatibleClient::new(&config, OpenAICompatibleProvider::LiteLLM).unwrap();
1882 let err = client.chat(make_chat_messages()).await.unwrap_err();
1883
1884 assert!(matches!(err, LLMError::RateLimited));
1885 mock.assert();
1886 });
1887 }
1888
1889 #[test]
1890 fn test_litellm_chat_server_error() {
1891 with_mockito(|mut server| async move {
1892 let mock = server
1893 .mock("POST", "/v1/chat/completions")
1894 .with_status(500)
1895 .with_header("content-type", "application/json")
1896 .with_body(r#"{"error": "Internal server error"}"#)
1897 .create();
1898
1899 let config = LLMConfig {
1900 provider: LLMProvider::LiteLLM,
1901 endpoint: server.url(),
1902 model: "gpt-4o-mini".to_string(),
1903 api_key: Some("test-key".to_string()),
1904 timeout_secs: 30,
1905 system_prompt: crate::config::default_system_prompt(),
1906 token_budget: None,
1907 retry_max: 0,
1908 retry_base_delay_ms: 100,
1909 retry_max_delay_ms: 10000,
1910 };
1911
1912 let client =
1913 OpenAICompatibleClient::new(&config, OpenAICompatibleProvider::LiteLLM).unwrap();
1914 let err = client.chat(make_chat_messages()).await.unwrap_err();
1915
1916 assert!(matches!(err, LLMError::RequestFailed(_)));
1917 assert!(format!("{}", err).contains("500"));
1918 mock.assert();
1919 });
1920 }
1921
1922 #[test]
1923 fn test_litellm_chat_invalid_json() {
1924 with_mockito(|mut server| async move {
1925 let mock = server
1926 .mock("POST", "/v1/chat/completions")
1927 .with_status(200)
1928 .with_header("content-type", "application/json")
1929 .with_body("not-json")
1930 .create();
1931
1932 let config = LLMConfig {
1933 provider: LLMProvider::LiteLLM,
1934 endpoint: server.url(),
1935 model: "gpt-4o-mini".to_string(),
1936 api_key: Some("test-key".to_string()),
1937 timeout_secs: 30,
1938 system_prompt: crate::config::default_system_prompt(),
1939 token_budget: None,
1940 retry_max: 0,
1941 retry_base_delay_ms: 100,
1942 retry_max_delay_ms: 10000,
1943 };
1944
1945 let client =
1946 OpenAICompatibleClient::new(&config, OpenAICompatibleProvider::LiteLLM).unwrap();
1947 let err = client.chat(make_chat_messages()).await.unwrap_err();
1948
1949 assert!(matches!(err, LLMError::InvalidResponse(_)));
1950 mock.assert();
1951 });
1952 }
1953
1954 #[test]
1957 fn test_openrouter_chat_success() {
1958 with_mockito(|mut server| async move {
1959 let mock = server
1960 .mock("POST", "/v1/chat/completions")
1961 .with_status(200)
1962 .with_header("content-type", "application/json")
1963 .with_body(sample_chat_response_json(
1964 "anthropic/claude-sonnet-4-20250514",
1965 ))
1966 .create();
1967
1968 let config = LLMConfig {
1969 provider: LLMProvider::OpenRouter,
1970 endpoint: server.url(),
1971 model: "anthropic/claude-sonnet-4-20250514".to_string(),
1972 api_key: Some("or-key".to_string()),
1973 timeout_secs: 30,
1974 system_prompt: crate::config::default_system_prompt(),
1975 token_budget: None,
1976 retry_max: 3,
1977 retry_base_delay_ms: 100,
1978 retry_max_delay_ms: 10000,
1979 };
1980
1981 let client =
1982 OpenAICompatibleClient::new(&config, OpenAICompatibleProvider::OpenRouter).unwrap();
1983 let response = client.chat(make_chat_messages()).await.unwrap();
1984
1985 assert_eq!(response.model, "anthropic/claude-sonnet-4-20250514");
1986 assert_eq!(response.choices[0].message.content, "Hi there!");
1987 mock.assert();
1988 });
1989 }
1990
1991 #[test]
1992 fn test_openrouter_chat_auth_failure() {
1993 with_mockito(|mut server| async move {
1994 let mock = server
1995 .mock("POST", "/v1/chat/completions")
1996 .with_status(401)
1997 .with_header("content-type", "application/json")
1998 .with_body(r#"{"error": "Unauthorized"}"#)
1999 .create();
2000
2001 let config = LLMConfig {
2002 provider: LLMProvider::OpenRouter,
2003 endpoint: server.url(),
2004 model: "anthropic/claude-sonnet-4-20250514".to_string(),
2005 api_key: Some("bad-key".to_string()),
2006 timeout_secs: 30,
2007 system_prompt: crate::config::default_system_prompt(),
2008 token_budget: None,
2009 retry_max: 3,
2010 retry_base_delay_ms: 100,
2011 retry_max_delay_ms: 10000,
2012 };
2013
2014 let client =
2015 OpenAICompatibleClient::new(&config, OpenAICompatibleProvider::OpenRouter).unwrap();
2016 let err = client.chat(make_chat_messages()).await.unwrap_err();
2017
2018 assert!(matches!(err, LLMError::AuthFailed));
2019 mock.assert();
2020 });
2021 }
2022
2023 #[test]
2024 fn test_openrouter_chat_rate_limit() {
2025 with_mockito(|mut server| async move {
2026 let mock = server
2027 .mock("POST", "/v1/chat/completions")
2028 .with_status(429)
2029 .with_header("content-type", "application/json")
2030 .with_body(r#"{"error": "Rate limited"}"#)
2031 .create();
2032
2033 let config = LLMConfig {
2034 provider: LLMProvider::OpenRouter,
2035 endpoint: server.url(),
2036 model: "anthropic/claude-sonnet-4-20250514".to_string(),
2037 api_key: Some("or-key".to_string()),
2038 timeout_secs: 30,
2039 system_prompt: crate::config::default_system_prompt(),
2040 token_budget: None,
2041 retry_max: 0, retry_base_delay_ms: 100,
2043 retry_max_delay_ms: 10000,
2044 };
2045
2046 let client =
2047 OpenAICompatibleClient::new(&config, OpenAICompatibleProvider::OpenRouter).unwrap();
2048 let err = client.chat(make_chat_messages()).await.unwrap_err();
2049
2050 assert!(matches!(err, LLMError::RateLimited));
2051 mock.assert();
2052 });
2053 }
2054
2055 #[test]
2056 fn test_openrouter_chat_server_error() {
2057 with_mockito(|mut server| async move {
2058 let mock = server
2059 .mock("POST", "/v1/chat/completions")
2060 .with_status(500)
2061 .with_header("content-type", "application/json")
2062 .with_body(r#"{"error": "Internal error"}"#)
2063 .create();
2064
2065 let config = LLMConfig {
2066 provider: LLMProvider::OpenRouter,
2067 endpoint: server.url(),
2068 model: "anthropic/claude-sonnet-4-20250514".to_string(),
2069 api_key: Some("or-key".to_string()),
2070 timeout_secs: 30,
2071 system_prompt: crate::config::default_system_prompt(),
2072 token_budget: None,
2073 retry_max: 0, retry_base_delay_ms: 100,
2075 retry_max_delay_ms: 10000,
2076 };
2077
2078 let client =
2079 OpenAICompatibleClient::new(&config, OpenAICompatibleProvider::OpenRouter).unwrap();
2080 let err = client.chat(make_chat_messages()).await.unwrap_err();
2081
2082 assert!(matches!(err, LLMError::RequestFailed(_)));
2083 assert!(format!("{}", err).contains("500"));
2084 mock.assert();
2085 });
2086 }
2087
2088 #[test]
2089 fn test_openrouter_chat_invalid_json() {
2090 with_mockito(|mut server| async move {
2091 let mock = server
2092 .mock("POST", "/v1/chat/completions")
2093 .with_status(200)
2094 .with_header("content-type", "application/json")
2095 .with_body("not-json")
2096 .create();
2097
2098 let config = LLMConfig {
2099 provider: LLMProvider::OpenRouter,
2100 endpoint: server.url(),
2101 model: "anthropic/claude-sonnet-4-20250514".to_string(),
2102 api_key: Some("or-key".to_string()),
2103 timeout_secs: 30,
2104 system_prompt: crate::config::default_system_prompt(),
2105 token_budget: None,
2106 retry_max: 0, retry_base_delay_ms: 100,
2108 retry_max_delay_ms: 10000,
2109 };
2110
2111 let client =
2112 OpenAICompatibleClient::new(&config, OpenAICompatibleProvider::OpenRouter).unwrap();
2113 let err = client.chat(make_chat_messages()).await.unwrap_err();
2114
2115 assert!(matches!(err, LLMError::InvalidResponse(_)));
2116 mock.assert();
2117 });
2118 }
2119
2120 #[test]
2123 fn test_openai_chat_success() {
2124 with_mockito(|mut server| async move {
2125 let mock = server
2126 .mock("POST", "/v1/chat/completions")
2127 .with_status(200)
2128 .with_header("content-type", "application/json")
2129 .with_body(sample_chat_response_json("gpt-4o"))
2130 .create();
2131
2132 let config = LLMConfig {
2133 provider: LLMProvider::OpenAI,
2134 endpoint: server.url(),
2135 model: "gpt-4o".to_string(),
2136 api_key: Some("sk-test".to_string()),
2137 timeout_secs: 60,
2138 system_prompt: crate::config::default_system_prompt(),
2139 token_budget: None,
2140 retry_max: 3,
2141 retry_base_delay_ms: 100,
2142 retry_max_delay_ms: 10000,
2143 };
2144
2145 let client =
2146 OpenAICompatibleClient::new(&config, OpenAICompatibleProvider::OpenAI).unwrap();
2147 let response = client.chat(make_chat_messages()).await.unwrap();
2148
2149 assert_eq!(response.model, "gpt-4o");
2150 assert_eq!(response.choices[0].message.content, "Hi there!");
2151 mock.assert();
2152 });
2153 }
2154
2155 #[test]
2156 fn test_openai_chat_auth_failure() {
2157 with_mockito(|mut server| async move {
2158 let mock = server
2159 .mock("POST", "/v1/chat/completions")
2160 .with_status(401)
2161 .with_header("content-type", "application/json")
2162 .with_body(r#"{"error": "Unauthorized"}"#)
2163 .create();
2164
2165 let config = LLMConfig {
2166 provider: LLMProvider::OpenAI,
2167 endpoint: server.url(),
2168 model: "gpt-4o".to_string(),
2169 api_key: Some("bad-key".to_string()),
2170 timeout_secs: 30,
2171 system_prompt: crate::config::default_system_prompt(),
2172 token_budget: None,
2173 retry_max: 3,
2174 retry_base_delay_ms: 100,
2175 retry_max_delay_ms: 10000,
2176 };
2177
2178 let client =
2179 OpenAICompatibleClient::new(&config, OpenAICompatibleProvider::OpenAI).unwrap();
2180 let err = client.chat(make_chat_messages()).await.unwrap_err();
2181
2182 assert!(matches!(err, LLMError::AuthFailed));
2183 mock.assert();
2184 });
2185 }
2186
2187 #[test]
2188 fn test_openai_chat_rate_limit() {
2189 with_mockito(|mut server| async move {
2190 let mock = server
2191 .mock("POST", "/v1/chat/completions")
2192 .with_status(429)
2193 .with_header("content-type", "application/json")
2194 .with_body(r#"{"error": "Rate limited"}"#)
2195 .create();
2196
2197 let config = LLMConfig {
2198 provider: LLMProvider::OpenAI,
2199 endpoint: server.url(),
2200 model: "gpt-4o".to_string(),
2201 api_key: Some("sk-test".to_string()),
2202 timeout_secs: 30,
2203 system_prompt: crate::config::default_system_prompt(),
2204 token_budget: None,
2205 retry_max: 0, retry_base_delay_ms: 100,
2207 retry_max_delay_ms: 10000,
2208 };
2209
2210 let client =
2211 OpenAICompatibleClient::new(&config, OpenAICompatibleProvider::OpenAI).unwrap();
2212 let err = client.chat(make_chat_messages()).await.unwrap_err();
2213
2214 assert!(matches!(err, LLMError::RateLimited));
2215 mock.assert();
2216 });
2217 }
2218
2219 #[test]
2220 fn test_openai_chat_server_error() {
2221 with_mockito(|mut server| async move {
2222 let mock = server
2223 .mock("POST", "/v1/chat/completions")
2224 .with_status(500)
2225 .with_header("content-type", "application/json")
2226 .with_body(r#"{"error": "Internal error"}"#)
2227 .create();
2228
2229 let config = LLMConfig {
2230 provider: LLMProvider::OpenAI,
2231 endpoint: server.url(),
2232 model: "gpt-4o".to_string(),
2233 api_key: Some("sk-test".to_string()),
2234 timeout_secs: 30,
2235 system_prompt: crate::config::default_system_prompt(),
2236 token_budget: None,
2237 retry_max: 0, retry_base_delay_ms: 100,
2239 retry_max_delay_ms: 10000,
2240 };
2241
2242 let client =
2243 OpenAICompatibleClient::new(&config, OpenAICompatibleProvider::OpenAI).unwrap();
2244 let err = client.chat(make_chat_messages()).await.unwrap_err();
2245
2246 assert!(matches!(err, LLMError::RequestFailed(_)));
2247 assert!(format!("{}", err).contains("500"));
2248 mock.assert();
2249 });
2250 }
2251
2252 #[test]
2253 fn test_openai_chat_invalid_json() {
2254 with_mockito(|mut server| async move {
2255 let mock = server
2256 .mock("POST", "/v1/chat/completions")
2257 .with_status(200)
2258 .with_header("content-type", "application/json")
2259 .with_body("not-json")
2260 .create();
2261
2262 let config = LLMConfig {
2263 provider: LLMProvider::OpenAI,
2264 endpoint: server.url(),
2265 model: "gpt-4o".to_string(),
2266 api_key: Some("sk-test".to_string()),
2267 timeout_secs: 30,
2268 system_prompt: crate::config::default_system_prompt(),
2269 token_budget: None,
2270 retry_max: 0, retry_base_delay_ms: 100,
2272 retry_max_delay_ms: 10000,
2273 };
2274
2275 let client =
2276 OpenAICompatibleClient::new(&config, OpenAICompatibleProvider::OpenAI).unwrap();
2277 let err = client.chat(make_chat_messages()).await.unwrap_err();
2278
2279 assert!(matches!(err, LLMError::InvalidResponse(_)));
2280 mock.assert();
2281 });
2282 }
2283
2284 #[test]
2287 fn test_ollama_chat_success() {
2288 with_mockito(|mut server| async move {
2289 let mock = server
2290 .mock("POST", "/api/chat")
2291 .with_status(200)
2292 .with_header("content-type", "application/json")
2293 .with_body(sample_ollama_response_json("llama3.1"))
2294 .create();
2295
2296 let config = LLMConfig {
2297 provider: LLMProvider::Ollama,
2298 endpoint: server.url(),
2299 model: "llama3.1".to_string(),
2300 api_key: None,
2301 timeout_secs: 30,
2302 system_prompt: crate::config::default_system_prompt(),
2303 token_budget: None,
2304 retry_max: 3,
2305 retry_base_delay_ms: 100,
2306 retry_max_delay_ms: 10000,
2307 };
2308
2309 let client = OllamaClient::new(&config).unwrap();
2310 let response = client.chat(make_chat_messages()).await.unwrap();
2311
2312 assert_eq!(response.model, "llama3.1");
2313 assert_eq!(response.choices[0].message.content, "Hi there!");
2314 assert_eq!(response.choices[0].finish_reason, Some("stop".to_string()));
2315 mock.assert();
2316 });
2317 }
2318
2319 #[test]
2320 fn test_ollama_chat_server_error() {
2321 with_mockito(|mut server| async move {
2322 let mock = server
2323 .mock("POST", "/api/chat")
2324 .with_status(500)
2325 .with_header("content-type", "application/json")
2326 .with_body(r#"{"error": "Model not loaded"}"#)
2327 .create();
2328
2329 let config = LLMConfig {
2330 provider: LLMProvider::Ollama,
2331 endpoint: server.url(),
2332 model: "llama3.1".to_string(),
2333 api_key: None,
2334 timeout_secs: 30,
2335 system_prompt: crate::config::default_system_prompt(),
2336 token_budget: None,
2337 retry_max: 3,
2338 retry_base_delay_ms: 100,
2339 retry_max_delay_ms: 10000,
2340 };
2341
2342 let client = OllamaClient::new(&config).unwrap();
2343 let err = client.chat(make_chat_messages()).await.unwrap_err();
2344
2345 assert!(matches!(err, LLMError::RequestFailed(_)));
2346 mock.assert();
2347 });
2348 }
2349
2350 #[test]
2351 fn test_ollama_chat_invalid_json() {
2352 with_mockito(|mut server| async move {
2353 let mock = server
2354 .mock("POST", "/api/chat")
2355 .with_status(200)
2356 .with_header("content-type", "application/json")
2357 .with_body("not-json")
2358 .create();
2359
2360 let config = LLMConfig {
2361 provider: LLMProvider::Ollama,
2362 endpoint: server.url(),
2363 model: "llama3.1".to_string(),
2364 api_key: None,
2365 timeout_secs: 30,
2366 system_prompt: crate::config::default_system_prompt(),
2367 token_budget: None,
2368 retry_max: 3,
2369 retry_base_delay_ms: 100,
2370 retry_max_delay_ms: 10000,
2371 };
2372
2373 let client = OllamaClient::new(&config).unwrap();
2374 let err = client.chat(make_chat_messages()).await.unwrap_err();
2375
2376 assert!(matches!(err, LLMError::InvalidResponse(_)));
2377 mock.assert();
2378 });
2379 }
2380
2381 #[test]
2382 fn test_ollama_chat_auth_failure() {
2383 with_mockito(|mut server| async move {
2384 let mock = server
2385 .mock("POST", "/api/chat")
2386 .with_status(401)
2387 .with_header("content-type", "application/json")
2388 .with_body(r#"{"error": "Unauthorized"}"#)
2389 .create();
2390
2391 let config = LLMConfig {
2392 provider: LLMProvider::Ollama,
2393 endpoint: server.url(),
2394 model: "llama3.1".to_string(),
2395 api_key: Some("bad-key".to_string()),
2396 timeout_secs: 30,
2397 system_prompt: crate::config::default_system_prompt(),
2398 token_budget: None,
2399 retry_max: 3,
2400 retry_base_delay_ms: 100,
2401 retry_max_delay_ms: 10000,
2402 };
2403
2404 let client = OllamaClient::new(&config).unwrap();
2405 let err = client.chat(make_chat_messages()).await.unwrap_err();
2406
2407 assert!(matches!(err, LLMError::AuthFailed));
2408 mock.assert();
2409 });
2410 }
2411
2412 #[test]
2415 fn test_create_client_factory_litellm() {
2416 let config = LLMConfig {
2417 provider: LLMProvider::LiteLLM,
2418 endpoint: "http://localhost:4000".to_string(),
2419 model: "gpt-4o-mini".to_string(),
2420 api_key: Some("test".to_string()),
2421 timeout_secs: 30,
2422 system_prompt: crate::config::default_system_prompt(),
2423 token_budget: None,
2424 retry_max: 3,
2425 retry_base_delay_ms: 100,
2426 retry_max_delay_ms: 10000,
2427 };
2428
2429 let client = create_client(&config).unwrap();
2430 assert_eq!(client.provider_name(), "litellm");
2431 assert_eq!(client.model(), "gpt-4o-mini");
2432 }
2433
#[test]
fn test_ollama_client_creation() {
    // Building an Ollama client directly, without an API key, should succeed
    // and expose the "ollama" provider name plus the configured model.
    let model_name = "llama3.1";
    let cfg = LLMConfig {
        provider: LLMProvider::Ollama,
        endpoint: String::from("http://localhost:11434"),
        model: String::from(model_name),
        api_key: None,
        timeout_secs: 30,
        system_prompt: crate::config::default_system_prompt(),
        token_budget: None,
        retry_max: 3,
        retry_base_delay_ms: 100,
        retry_max_delay_ms: 10000,
    };

    let ollama = OllamaClient::new(&cfg).unwrap();
    assert_eq!(ollama.provider_name(), "ollama");
    assert_eq!(ollama.model(), model_name);
}
2453
#[test]
fn test_openai_client_creation() {
    // An OpenAI client with an empty endpoint string and a 60s timeout
    // should still construct and report its provider name and model.
    let model_name = "gpt-4o";
    let cfg = LLMConfig {
        provider: LLMProvider::OpenAI,
        endpoint: String::default(),
        model: String::from(model_name),
        api_key: Some(String::from("sk-test")),
        timeout_secs: 60,
        system_prompt: crate::config::default_system_prompt(),
        token_budget: None,
        retry_max: 3,
        retry_base_delay_ms: 100,
        retry_max_delay_ms: 10000,
    };

    let kind = OpenAICompatibleProvider::OpenAI;
    let openai = OpenAICompatibleClient::new(&cfg, kind).unwrap();
    assert_eq!(openai.provider_name(), "openai");
    assert_eq!(openai.model(), model_name);
}
2474
#[test]
fn test_openrouter_client_creation() {
    // OpenRouter models use a "vendor/model" slug; the client must keep it
    // verbatim and report the "openrouter" provider name.
    let model_slug = "anthropic/claude-sonnet-4-20250514";
    let cfg = LLMConfig {
        provider: LLMProvider::OpenRouter,
        endpoint: String::default(),
        model: String::from(model_slug),
        api_key: Some(String::from("sk-test")),
        timeout_secs: 30,
        system_prompt: crate::config::default_system_prompt(),
        token_budget: None,
        retry_max: 3,
        retry_base_delay_ms: 100,
        retry_max_delay_ms: 10000,
    };

    let kind = OpenAICompatibleProvider::OpenRouter;
    let router = OpenAICompatibleClient::new(&cfg, kind).unwrap();
    assert_eq!(router.provider_name(), "openrouter");
    assert_eq!(router.model(), model_slug);
}
2495
#[test]
fn test_multi_model_manager_empty() {
    // A manager built from zero configs holds no clients, so index 0 is absent.
    let empty = MultiModelManager::new(Vec::new()).unwrap();
    assert_eq!(empty.client_count(), 0);
    assert!(empty.get_client(0).is_none());
}
2502
#[test]
fn test_multi_model_manager_single() {
    // A manager built from one LiteLLM config exposes exactly that client at
    // index 0, carrying the configured provider and model.
    let config = LLMConfig {
        provider: LLMProvider::LiteLLM,
        endpoint: "http://localhost:4000".to_string(),
        model: "gpt-4o-mini".to_string(),
        api_key: Some("test".to_string()),
        timeout_secs: 30,
        system_prompt: crate::config::default_system_prompt(),
        token_budget: None,
        retry_max: 3,
        retry_base_delay_ms: 100,
        retry_max_delay_ms: 10000,
    };

    let manager = MultiModelManager::new(vec![config]).unwrap();
    assert_eq!(manager.client_count(), 1);

    // A single `expect` replaces the `is_some()` + `unwrap()` pair: one
    // lookup, and a descriptive message if the slot is unexpectedly empty.
    let client = manager
        .get_client(0)
        .expect("index 0 should hold the single configured client");
    assert_eq!(client.provider_name(), "litellm");
    assert_eq!(client.model(), "gpt-4o-mini");
}
2523
#[test]
fn test_multi_model_manager_multiple() {
    // Clients must be stored in the same order as the configs were supplied.
    let litellm_cfg = LLMConfig {
        provider: LLMProvider::LiteLLM,
        endpoint: String::from("http://localhost:4000"),
        model: String::from("gpt-4o-mini"),
        api_key: Some(String::from("test")),
        timeout_secs: 30,
        system_prompt: crate::config::default_system_prompt(),
        token_budget: None,
        retry_max: 3,
        retry_base_delay_ms: 100,
        retry_max_delay_ms: 10000,
    };
    let ollama_cfg = LLMConfig {
        provider: LLMProvider::Ollama,
        endpoint: String::from("http://localhost:11434"),
        model: String::from("llama3.1"),
        api_key: None,
        timeout_secs: 60,
        system_prompt: crate::config::default_system_prompt(),
        token_budget: None,
        retry_max: 3,
        retry_base_delay_ms: 100,
        retry_max_delay_ms: 10000,
    };

    let manager = MultiModelManager::new(vec![litellm_cfg, ollama_cfg]).unwrap();
    assert_eq!(manager.client_count(), 2);

    let expected_order = ["litellm", "ollama"];
    for (slot, name) in expected_order.iter().enumerate() {
        assert_eq!(manager.get_client(slot).unwrap().provider_name(), *name);
    }
}
2558
#[test]
fn test_multi_model_next_client_round_robin() {
    // With two clients, `next_client(i)` yields the entry after `i` and wraps
    // from the last entry back to the first.
    let litellm_cfg = LLMConfig {
        provider: LLMProvider::LiteLLM,
        endpoint: String::from("http://localhost:4000"),
        model: String::from("gpt-4o-mini"),
        api_key: Some(String::from("test")),
        timeout_secs: 30,
        system_prompt: crate::config::default_system_prompt(),
        token_budget: None,
        retry_max: 3,
        retry_base_delay_ms: 100,
        retry_max_delay_ms: 10000,
    };
    let ollama_cfg = LLMConfig {
        provider: LLMProvider::Ollama,
        endpoint: String::from("http://localhost:11434"),
        model: String::from("llama3.1"),
        api_key: None,
        timeout_secs: 60,
        system_prompt: crate::config::default_system_prompt(),
        token_budget: None,
        retry_max: 3,
        retry_base_delay_ms: 100,
        retry_max_delay_ms: 10000,
    };

    let manager = MultiModelManager::new(vec![litellm_cfg, ollama_cfg]).unwrap();

    // index 0 (litellm) -> index 1 (ollama)
    let after_first = manager.next_client(0).unwrap();
    assert_eq!(after_first.provider_name(), "ollama");

    // index 1 (ollama) -> wraps to index 0 (litellm)
    let after_last = manager.next_client(1).unwrap();
    assert_eq!(after_last.provider_name(), "litellm");
}
2596
#[test]
fn test_chat_request_serialization() {
    // Populated request fields must appear in the JSON; the unset `stream`
    // option must be left out entirely.
    let req = ChatRequest {
        model: String::from("gpt-4o-mini"),
        messages: vec![
            ChatMessage::new("system", "You are a helpful assistant."),
            ChatMessage::new("user", "Hello!"),
        ],
        temperature: Some(0.7),
        max_tokens: Some(2048),
        stream: None,
        tools: None,
        tool_choice: None,
    };

    let encoded = serde_json::to_string(&req).unwrap();

    for needle in ["gpt-4o-mini", "system", "user", "Hello!", "0.7"] {
        assert!(
            encoded.contains(needle),
            "expected {:?} in serialized request: {}",
            needle,
            encoded
        );
    }
    assert!(!encoded.contains("stream"));
}
2621
#[test]
fn test_chat_response_deserialization() {
    // A complete chat-completion payload (one choice plus usage) should
    // round-trip into `ChatResponse` with every field populated.
    let payload = r#"{
        "id": "chat-123",
        "object": "chat.completion",
        "created": 1717000000,
        "model": "gpt-4o-mini",
        "choices": [
            {
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": "Hello! How can I help you?"
                },
                "finish_reason": "stop"
            }
        ],
        "usage": {
            "prompt_tokens": 10,
            "completion_tokens": 20,
            "total_tokens": 30
        }
    }"#;

    let parsed: ChatResponse = serde_json::from_str(payload).unwrap();
    assert_eq!(parsed.id, "chat-123");
    assert_eq!(parsed.model, "gpt-4o-mini");
    assert_eq!(parsed.choices.len(), 1);

    let only = &parsed.choices[0];
    assert_eq!(only.message.role, "assistant");
    assert_eq!(only.message.content, "Hello! How can I help you?");

    let total = parsed.usage.as_ref().map(|u| u.total_tokens);
    assert_eq!(total, Some(30));
}
2657
#[test]
fn test_multi_model_manager_new_invalid_config() {
    // Despite the test name, this config is expected to be *accepted*. It is a
    // LiteLLM entry with an empty endpoint and no API key, and
    // `MultiModelManager::new` must succeed and hold exactly one client.
    // NOTE(review): presumably the empty endpoint falls back to the provider's
    // default endpoint; confirm against the client constructor.
    let configs = vec![LLMConfig {
        provider: LLMProvider::LiteLLM,
        endpoint: String::new(),
        model: "gpt-4o-mini".to_string(),
        api_key: None,
        timeout_secs: 30,
        system_prompt: crate::config::default_system_prompt(),
        token_budget: None,
        retry_max: 3,
        retry_base_delay_ms: 100,
        retry_max_delay_ms: 10000,
    }];

    // `expect` replaces `assert!(is_ok())` + `unwrap()`, so a failure reports
    // the actual construction error instead of a bare "assertion failed".
    let manager = MultiModelManager::new(configs)
        .expect("empty endpoint / missing API key should not be rejected");
    assert_eq!(manager.client_count(), 1);
}
2681
#[test]
fn test_create_client_all_providers() {
    // Every provider handled by the `create_client` factory must yield a
    // client reporting the matching provider name and the configured model.
    let cases = [
        (LLMProvider::LiteLLM, "litellm"),
        (LLMProvider::OpenRouter, "openrouter"),
        (LLMProvider::Ollama, "ollama"),
        (LLMProvider::OpenAI, "openai"),
    ];

    for (provider, want_name) in cases {
        let cfg = LLMConfig {
            provider,
            endpoint: String::from("http://localhost:4000"),
            model: String::from("test-model"),
            api_key: Some(String::from("test-key")),
            timeout_secs: 30,
            system_prompt: crate::config::default_system_prompt(),
            token_budget: None,
            retry_max: 3,
            retry_base_delay_ms: 100,
            retry_max_delay_ms: 10000,
        };

        let built = create_client(&cfg).unwrap();
        assert_eq!(built.provider_name(), want_name);
        assert_eq!(built.model(), "test-model");
    }
}
2710
#[test]
fn test_llm_error_display() {
    // Each variant's `Display` output, as produced by its `#[error(...)]` format.
    let cases = [
        (
            LLMError::RequestFailed("timeout".to_string()),
            "Request failed: timeout",
        ),
        (LLMError::AuthFailed, "Authentication failed"),
        (LLMError::RateLimited, "Rate limit exceeded"),
        (
            LLMError::InvalidResponse("bad json".to_string()),
            "Invalid response: bad json",
        ),
        (
            LLMError::ProviderNotSupported("custom".to_string()),
            "Provider not supported: custom",
        ),
    ];

    for (err, expected) in cases {
        assert_eq!(err.to_string(), expected);
    }
}
2728
#[test]
fn test_llm_error_is_debug() {
    // The derived `Debug` output names the variant.
    let rendered = format!("{:?}", LLMError::RequestFailed(String::from("test")));
    assert!(rendered.contains("RequestFailed"));
}
2735
#[test]
fn test_llm_error_is_send() {
    // Compile-time check: this body only type-checks when `LLMError: Send`.
    fn assert_impl_send<E: Send>() {}
    assert_impl_send::<LLMError>();
}
2741
#[test]
fn test_llm_error_is_sync() {
    // Compile-time check: this body only type-checks when `LLMError: Sync`.
    fn assert_impl_sync<E: Sync>() {}
    assert_impl_sync::<LLMError>();
}
2747
#[test]
fn test_litellm_client_new_with_invalid_timeout() {
    // Despite the name, an extreme timeout (`u64::MAX` seconds) is expected to
    // be *accepted* when the LiteLLM client is constructed.
    let config = LLMConfig {
        provider: LLMProvider::LiteLLM,
        endpoint: "http://localhost:4000".to_string(),
        model: "gpt-4o-mini".to_string(),
        api_key: Some("test".to_string()),
        timeout_secs: u64::MAX,
        system_prompt: crate::config::default_system_prompt(),
        token_budget: None,
        retry_max: 3,
        retry_base_delay_ms: 100,
        retry_max_delay_ms: 10000,
    };

    // `expect` instead of `assert!(result.is_ok())`: on failure the panic
    // message includes the `LLMError` that construction returned.
    let _client = OpenAICompatibleClient::new(&config, OpenAICompatibleProvider::LiteLLM)
        .expect("client construction with a u64::MAX timeout should succeed");
}
2767
#[test]
fn test_openai_client_with_custom_endpoint() {
    // A non-default endpoint must not prevent construction or change the
    // reported provider/model.
    let model_name = "gpt-4o";
    let cfg = LLMConfig {
        provider: LLMProvider::OpenAI,
        endpoint: String::from("https://custom.openai.example.com"),
        model: String::from(model_name),
        api_key: Some(String::from("sk-test")),
        timeout_secs: 30,
        system_prompt: crate::config::default_system_prompt(),
        token_budget: None,
        retry_max: 3,
        retry_base_delay_ms: 100,
        retry_max_delay_ms: 10000,
    };

    let kind = OpenAICompatibleProvider::OpenAI;
    let openai = OpenAICompatibleClient::new(&cfg, kind).unwrap();
    assert_eq!(openai.provider_name(), "openai");
    assert_eq!(openai.model(), model_name);
}
2788
#[test]
fn test_openrouter_client_with_custom_endpoint() {
    // A custom OpenRouter endpoint must not prevent construction or change the
    // reported provider/model.
    let model_slug = "anthropic/claude-sonnet-4-20250514";
    let cfg = LLMConfig {
        provider: LLMProvider::OpenRouter,
        endpoint: String::from("https://custom.openrouter.example.com"),
        model: String::from(model_slug),
        api_key: Some(String::from("or-key")),
        timeout_secs: 30,
        system_prompt: crate::config::default_system_prompt(),
        token_budget: None,
        retry_max: 3,
        retry_base_delay_ms: 100,
        retry_max_delay_ms: 10000,
    };

    let kind = OpenAICompatibleProvider::OpenRouter;
    let router = OpenAICompatibleClient::new(&cfg, kind).unwrap();
    assert_eq!(router.provider_name(), "openrouter");
    assert_eq!(router.model(), model_slug);
}
2809
#[test]
fn test_ollama_client_with_auth() {
    // Supplying an API key to the Ollama client must still allow construction.
    let model_name = "llama3.1";
    let cfg = LLMConfig {
        provider: LLMProvider::Ollama,
        endpoint: String::from("http://localhost:11434"),
        model: String::from(model_name),
        api_key: Some(String::from("some-key")),
        timeout_secs: 30,
        system_prompt: crate::config::default_system_prompt(),
        token_budget: None,
        retry_max: 3,
        retry_base_delay_ms: 100,
        retry_max_delay_ms: 10000,
    };

    let ollama = OllamaClient::new(&cfg).unwrap();
    assert_eq!(ollama.provider_name(), "ollama");
    assert_eq!(ollama.model(), model_name);
}
2830
#[test]
fn test_chat_request_no_temperature() {
    // Optional request fields left as `None` must be omitted from the JSON.
    let req = ChatRequest {
        model: String::from("gpt-4o-mini"),
        messages: vec![ChatMessage::new("user", "Hello!")],
        temperature: None,
        max_tokens: None,
        stream: None,
        tools: None,
        tool_choice: None,
    };

    let encoded = serde_json::to_string(&req).unwrap();

    for present in ["gpt-4o-mini", "Hello!"] {
        assert!(encoded.contains(present));
    }
    for absent in ["temperature", "max_tokens", "stream"] {
        assert!(!encoded.contains(absent));
    }
}
2851
#[test]
fn test_chat_response_deserialization_no_usage() {
    // A payload without a "usage" key must still parse, leaving `usage` unset.
    let payload = r#"{
        "id": "chat-456",
        "object": "chat.completion",
        "created": 1717000001,
        "model": "gpt-4o",
        "choices": [
            {
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": "Sure!"
                },
                "finish_reason": "stop"
            }
        ]
    }"#;

    let parsed: ChatResponse = serde_json::from_str(payload).unwrap();
    assert_eq!(parsed.id, "chat-456");
    assert_eq!(parsed.model, "gpt-4o");
    assert_eq!(parsed.choices.len(), 1);
    assert!(parsed.usage.is_none());
}
2877
#[test]
fn test_chat_response_deserialization_multiple_choices() {
    // All choices must be kept, in payload order.
    let payload = r#"{
        "id": "chat-789",
        "object": "chat.completion",
        "created": 1717000002,
        "model": "gpt-4o",
        "choices": [
            {
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": "First choice"
                },
                "finish_reason": "stop"
            },
            {
                "index": 1,
                "message": {
                    "role": "assistant",
                    "content": "Second choice"
                },
                "finish_reason": "stop"
            }
        ]
    }"#;

    let parsed: ChatResponse = serde_json::from_str(payload).unwrap();
    let expected = ["First choice", "Second choice"];
    assert_eq!(parsed.choices.len(), expected.len());
    for (choice, want) in parsed.choices.iter().zip(expected) {
        assert_eq!(choice.message.content, want);
    }
}
2910
#[test]
fn test_llm_error_into_boxed() {
    // `LLMError` converts into a boxed `std::error::Error` trait object and
    // keeps its `Display` text.
    let boxed: Box<dyn std::error::Error> = LLMError::AuthFailed.into();
    assert!(boxed.to_string().contains("Authentication failed"));
}
2917
#[test]
fn test_llm_error_into_string() {
    // `to_string` (via `Display`) yields the human-readable message.
    assert_eq!(LLMError::RateLimited.to_string(), "Rate limit exceeded");
}
2924
#[test]
fn test_create_client_with_empty_api_key() {
    // "Empty" here means no key at all (`api_key: None`). The factory must
    // still build an Ollama client.
    let cfg = LLMConfig {
        provider: LLMProvider::Ollama,
        endpoint: String::from("http://localhost:11434"),
        model: String::from("llama3.1"),
        api_key: None,
        timeout_secs: 30,
        system_prompt: crate::config::default_system_prompt(),
        token_budget: None,
        retry_max: 3,
        retry_base_delay_ms: 100,
        retry_max_delay_ms: 10000,
    };

    let built = create_client(&cfg).unwrap();
    assert_eq!(built.provider_name(), "ollama");
}
2944
#[test]
fn test_multi_model_manager_get_client_out_of_bounds() {
    // Indexing an empty manager must return `None` rather than panic, up to
    // and including `usize::MAX`.
    let empty = MultiModelManager::new(Vec::new()).unwrap();
    for idx in [0, 100, usize::MAX] {
        assert!(empty.get_client(idx).is_none());
    }
}
2952
#[test]
fn test_multi_model_next_client_empty() {
    // With no clients there is nothing to rotate to.
    let empty = MultiModelManager::new(Vec::new()).unwrap();
    let picked = empty.next_client(0);
    assert!(picked.is_none());
}
2958
#[test]
fn test_multi_model_next_client_single() {
    // With a single client, rotating past index 0 lands back on that client.
    let cfg = LLMConfig {
        provider: LLMProvider::LiteLLM,
        endpoint: String::from("http://localhost:4000"),
        model: String::from("gpt-4o-mini"),
        api_key: Some(String::from("test")),
        timeout_secs: 30,
        system_prompt: crate::config::default_system_prompt(),
        token_budget: None,
        retry_max: 3,
        retry_base_delay_ms: 100,
        retry_max_delay_ms: 10000,
    };

    let manager = MultiModelManager::new(vec![cfg]).unwrap();
    let wrapped = manager.next_client(0).unwrap();
    assert_eq!(wrapped.provider_name(), "litellm");
}
2979}