1use futures::Stream;
7use reqwest::{Client, Response};
8use serde::{Deserialize, Serialize};
9use std::pin::Pin;
10use std::sync::Arc;
11use std::time::Duration;
12use thiserror::Error;
13use tokio::time::sleep;
14use tracing::instrument;
15
/// One incremental piece of a streamed chat response.
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub struct StreamChunk {
    /// Text produced since the previous chunk; may be empty when only
    /// `finish_reason` is carried.
    pub content: String,
    /// Why generation stopped, when the provider reported it in this chunk.
    pub finish_reason: Option<String>,
}
24
25#[allow(dead_code)]
27pub type StreamResult = Pin<Box<dyn Stream<Item = Result<StreamChunk, LLMError>> + Send>>;
28
29use crate::config::{LLMConfig, LLMProvider};
30
/// Flavors of backend reachable through the OpenAI chat-completions protocol.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum OpenAICompatibleProvider {
    /// LiteLLM proxy.
    LiteLLM,
    /// OpenAI's hosted API.
    OpenAI,
    /// OpenRouter.
    OpenRouter,
    /// Any other server exposing the OpenAI-compatible API.
    Generic,
    /// Azure OpenAI.
    Azure,
}

impl OpenAICompatibleProvider {
    /// Base URL used when the configuration leaves the endpoint empty.
    ///
    /// The Azure value is a placeholder (`YOUR_RESOURCE`) and must be
    /// overridden in configuration to be usable.
    pub fn default_endpoint(&self) -> &'static str {
        match *self {
            Self::LiteLLM => "http://localhost:4000",
            Self::OpenAI => "https://api.openai.com",
            Self::OpenRouter => "https://openrouter.ai",
            Self::Generic => "http://localhost:8000",
            Self::Azure => "https://YOUR_RESOURCE.openai.azure.com",
        }
    }

    /// Short identifier used in logs, tracing fields, and error messages.
    pub fn name(&self) -> &'static str {
        match *self {
            Self::LiteLLM => "litellm",
            Self::OpenAI => "openai",
            Self::OpenRouter => "openrouter",
            Self::Generic => "openai-compatible",
            Self::Azure => "azure",
        }
    }

    /// Whether this flavor needs headers beyond standard bearer auth.
    #[allow(dead_code)]
    pub fn requires_custom_headers(&self) -> bool {
        matches!(*self, Self::OpenRouter | Self::Azure)
    }
}
79
/// Errors produced by the LLM clients in this module.
#[derive(Error, Debug)]
#[non_exhaustive]
pub enum LLMError {
    /// Transport error, HTTP-client construction failure, poisoned circuit-breaker
    /// lock, or a non-success status without a more specific variant
    /// (in that case the message is `"<status>: <body>"`).
    #[error("Request failed: {0}")]
    RequestFailed(String),

    /// A successful HTTP response whose body could not be decoded.
    #[error("Invalid response: {0}")]
    InvalidResponse(String),

    /// The provider answered HTTP 429.
    #[error("Rate limit exceeded")]
    RateLimited,

    /// The provider answered HTTP 401, or a required API key is missing.
    #[error("Authentication failed")]
    AuthFailed,

    /// Not constructed anywhere in the visible code (hence `dead_code` is allowed).
    #[error("Provider not supported: {0}")]
    #[allow(dead_code)]
    ProviderNotSupported(String),

    /// The fallback chain's token budget cannot cover another attempt.
    #[error("Token budget exceeded")]
    TokenBudgetExceeded,

    /// Fallback used when no attempt recorded a more specific error
    /// (for example, an empty provider list).
    #[error("All providers failed after retries")]
    AllProvidersFailed,

    /// The named provider's circuit breaker rejected the request.
    #[error("Circuit breaker open for provider: {0}")]
    CircuitBreakerOpen(String),
}
113
/// Exponential-backoff settings for retrying failed requests.
#[derive(Debug, Clone)]
pub struct RetryConfig {
    /// Retries after the first attempt; total attempts are `max_retries + 1`.
    pub max_retries: u32,
    /// Backoff unit in milliseconds; attempt `n` starts from `base_delay_ms * 2^n`.
    pub base_delay_ms: u64,
    /// Cap on the exponential part of the delay, in milliseconds.
    pub max_delay_ms: u64,
    /// Random jitter applied to each delay, as a fraction of that delay.
    pub jitter: f64,
}

impl Default for RetryConfig {
    /// Three retries, 100 ms base delay, 10 s cap, 0.5 jitter fraction.
    fn default() -> Self {
        RetryConfig {
            jitter: 0.5,
            max_delay_ms: 10_000,
            base_delay_ms: 100,
            max_retries: 3,
        }
    }
}
137
138impl RetryConfig {
139 pub fn delay_for_attempt(&self, attempt: u32) -> Duration {
141 use rand::Rng;
142 let exp = 2u64.pow(attempt);
143 let base = self.base_delay_ms * exp;
144 let capped = base.min(self.max_delay_ms);
145 let jitter_range = (capped as f64) * self.jitter;
146 let jitter = rand::thread_rng().gen_range(-jitter_range..=jitter_range) as u64;
147 let delay = capped.saturating_add(jitter).max(self.base_delay_ms);
148 Duration::from_millis(delay)
149 }
150}
151
/// State of a [`CircuitBreaker`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum CircuitState {
    /// Requests flow normally.
    Closed,
    /// Requests are rejected until `open_duration` has elapsed since the last failure.
    Open,
    /// The cooldown has elapsed; requests are let through to probe the backend.
    HalfOpen,
}

/// Failure-counting circuit breaker guarding calls to one provider.
#[derive(Debug)]
pub struct CircuitBreaker {
    /// Current breaker state.
    pub state: CircuitState,
    /// Failures recorded since the last success.
    pub failure_count: u32,
    /// When the most recent failure was recorded.
    pub last_failure_time: Option<std::time::Instant>,
    /// How long the breaker stays open before allowing a probe.
    pub open_duration: Duration,
}

impl CircuitBreaker {
    /// Failures since the last success that trip the breaker open.
    const FAILURE_THRESHOLD: u32 = 5;

    /// Creates a closed breaker that, once tripped, stays open for
    /// `open_duration_secs` seconds after the latest failure.
    pub fn new(open_duration_secs: u64) -> Self {
        Self {
            state: CircuitState::Closed,
            failure_count: 0,
            last_failure_time: None,
            open_duration: Duration::from_secs(open_duration_secs),
        }
    }

    /// Resets the failure count and closes the breaker.
    pub fn record_success(&mut self) {
        self.failure_count = 0;
        self.state = CircuitState::Closed;
    }

    /// Records a failure and opens the breaker when the threshold is reached.
    pub fn record_failure(&mut self) {
        // Saturate rather than overflow (debug-build panic) under sustained failure.
        self.failure_count = self.failure_count.saturating_add(1);
        self.last_failure_time = Some(std::time::Instant::now());
        // A failed half-open probe re-opens immediately, regardless of the count.
        if self.state == CircuitState::HalfOpen || self.failure_count >= Self::FAILURE_THRESHOLD
        {
            self.state = CircuitState::Open;
        }
    }

    /// Returns whether a request may proceed.
    ///
    /// An open breaker moves to `HalfOpen` (and returns `true`) once
    /// `open_duration` has elapsed since the last recorded failure.
    pub fn can_execute(&mut self) -> bool {
        match self.state {
            CircuitState::Closed | CircuitState::HalfOpen => true,
            CircuitState::Open => {
                let open_duration = self.open_duration;
                let cooled_down = self
                    .last_failure_time
                    .map_or(false, |last| last.elapsed() >= open_duration);
                if cooled_down {
                    self.state = CircuitState::HalfOpen;
                }
                cooled_down
            }
        }
    }
}
212
/// Token allowance with running usage and a linear cost estimate.
#[derive(Debug, Clone)]
pub struct TokenBudget {
    /// Total tokens that may be consumed.
    pub max_tokens: u32,
    /// Tokens recorded so far (saturates at `u32::MAX`).
    pub used_tokens: u32,
    /// Price per 1,000 tokens, used by [`TokenBudget::estimated_cost`].
    pub cost_per_1k: f64,
}

#[allow(dead_code)]
impl TokenBudget {
    /// Creates an unused budget of `max_tokens` priced at `cost_per_1k`.
    pub fn new(max_tokens: u32, cost_per_1k: f64) -> Self {
        TokenBudget {
            used_tokens: 0,
            max_tokens,
            cost_per_1k,
        }
    }

    /// Tokens still available (never negative).
    pub fn remaining(&self) -> u32 {
        self.max_tokens.saturating_sub(self.used_tokens)
    }

    /// Whether `tokens` more can be consumed without exceeding the budget.
    pub fn can_spend(&self, tokens: u32) -> bool {
        tokens <= self.remaining()
    }

    /// Adds `tokens` to the running total, saturating instead of overflowing.
    pub fn record_usage(&mut self, tokens: u32) {
        let total = self.used_tokens.saturating_add(tokens);
        self.used_tokens = total;
    }

    /// Cost of the tokens used so far: `used_tokens / 1000 * cost_per_1k`.
    pub fn estimated_cost(&self) -> f64 {
        let thousands = f64::from(self.used_tokens) / 1000.0;
        thousands * self.cost_per_1k
    }
}
250
251#[derive(Debug, Clone, Serialize, Deserialize)]
252pub struct ChatMessage {
253 pub role: String,
254 pub content: String,
255}
256
/// Request body for an OpenAI-style chat-completions call.
///
/// Every `Option` field set to `None` is omitted from the serialized JSON.
#[derive(Debug, Clone, Serialize)]
pub struct ChatRequest {
    /// Model identifier sent to the provider.
    pub model: String,
    /// Conversation so far, in order.
    pub messages: Vec<ChatMessage>,
    /// Sampling temperature.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    /// Upper bound on generated tokens.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    /// `Some(true)` requests a server-sent-events stream (see `chat_stream`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,
    /// Tool definitions, passed through as raw JSON values.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<serde_json::Value>>,
    /// Tool-choice directive, passed through as a string.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<String>,
}
272
/// Chat completion in the OpenAI response shape.
///
/// OpenAI-compatible responses are deserialized directly. The Ollama and
/// Anthropic clients build this type by hand from their own formats.
#[derive(Debug, Clone, Deserialize)]
pub struct ChatResponse {
    /// Provider-assigned response id (synthesized as `ollama-<uuid>` for Ollama).
    #[allow(dead_code)]
    pub id: String,
    /// Object type string, e.g. `"chat.completion"`.
    #[allow(dead_code)]
    pub object: String,
    /// Creation time in seconds since the Unix epoch.
    #[allow(dead_code)]
    pub created: u64,
    /// Model that produced the response.
    #[allow(dead_code)]
    pub model: String,
    /// Candidate completions; callers in this module read the first one.
    pub choices: Vec<Choice>,
    /// Token accounting, when the provider reports it.
    #[allow(dead_code)]
    pub usage: Option<Usage>,
}
287
/// A tool invocation requested by the model.
#[derive(Debug, Clone, Deserialize)]
pub struct ToolCallResponse {
    /// Provider-assigned call id.
    #[allow(dead_code)]
    pub id: String,
    /// Call kind, read from the JSON key `type` (the Anthropic adapter sets `"function"`).
    #[allow(dead_code)]
    #[serde(rename = "type")]
    pub call_type: String,
    /// Function name and arguments.
    pub function: FunctionCall,
}

/// Function name plus its arguments for a [`ToolCallResponse`].
#[derive(Debug, Clone, Deserialize)]
pub struct FunctionCall {
    /// Name of the function to invoke.
    pub name: String,
    /// Arguments kept as a string; the Anthropic adapter stores serialized JSON here.
    pub arguments: String,
}
303
/// One candidate completion within a [`ChatResponse`].
#[derive(Debug, Clone, Deserialize)]
pub struct Choice {
    /// Position of this choice in the `choices` array.
    #[allow(dead_code)]
    pub index: u32,
    /// The generated message.
    pub message: ChatMessage,
    /// Why generation stopped (e.g. `"stop"`), if reported.
    #[allow(dead_code)]
    pub finish_reason: Option<String>,
    /// Tool calls requested by the model; `None` when the key is absent.
    /// `skip_serializing_if` has no effect because this type only derives `Deserialize`.
    // NOTE(review): the OpenAI chat-completions schema appears to nest `tool_calls`
    // inside `message` rather than on the choice, so OpenAI-compatible responses may
    // never populate this field (the Anthropic adapter sets it directly) — verify.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCallResponse>>,
}
314
/// Token accounting for one request.
#[derive(Debug, Clone, Deserialize)]
pub struct Usage {
    /// Tokens in the prompt (Anthropic: `input_tokens`).
    #[allow(dead_code)]
    pub prompt_tokens: u32,
    /// Tokens generated (Anthropic: `output_tokens`).
    #[allow(dead_code)]
    pub completion_tokens: u32,
    /// Total tokens; this is what the fallback chain records against its budget.
    #[allow(dead_code)]
    pub total_tokens: u32,
}
324
325#[async_trait::async_trait]
327pub trait LLMProviderTrait: Send + Sync {
328 async fn chat(&self, messages: Vec<ChatMessage>) -> Result<ChatResponse, LLMError>;
329 #[allow(dead_code)]
330 async fn chat_stream(&self, messages: Vec<ChatMessage>) -> Result<StreamResult, LLMError> {
331 let response = self.chat(messages).await?;
333 let content = response
334 .choices
335 .first()
336 .map(|c| c.message.content.clone())
337 .unwrap_or_default();
338 let finish_reason = response
339 .choices
340 .first()
341 .and_then(|c| c.finish_reason.clone());
342
343 let stream = futures::stream::once(async move {
344 Ok(StreamChunk {
345 content,
346 finish_reason,
347 })
348 });
349 Ok(Box::pin(stream))
350 }
351 fn provider_name(&self) -> &str;
352 fn model(&self) -> &str;
353}
354
355async fn handle_openai_response(response: Response) -> Result<ChatResponse, LLMError> {
357 let status = response.status();
358
359 if status.is_success() {
360 response
361 .json::<ChatResponse>()
362 .await
363 .map_err(|e| LLMError::InvalidResponse(e.to_string()))
364 } else if status == reqwest::StatusCode::UNAUTHORIZED {
365 Err(LLMError::AuthFailed)
366 } else if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
367 Err(LLMError::RateLimited)
368 } else {
369 let body = response
370 .text()
371 .await
372 .unwrap_or_else(|_| "Unknown error".to_string());
373 Err(LLMError::RequestFailed(format!("{}: {}", status, body)))
374 }
375}
376
377pub struct OpenAICompatibleClient {
380 client: Client,
381 config: LLMConfig,
382 provider: OpenAICompatibleProvider,
383 retry_config: RetryConfig,
384 circuit_breaker: std::sync::Mutex<CircuitBreaker>,
385}
386
387impl OpenAICompatibleClient {
388 pub fn new(config: &LLMConfig, provider: OpenAICompatibleProvider) -> Result<Self, LLMError> {
389 let client = Client::builder()
390 .timeout(std::time::Duration::from_secs(config.timeout_secs))
391 .build()
392 .map_err(|e| LLMError::RequestFailed(format!("Failed to create HTTP client: {}", e)))?;
393
394 let retry_config = RetryConfig {
395 max_retries: config.retry_max,
396 base_delay_ms: config.retry_base_delay_ms,
397 max_delay_ms: config.retry_max_delay_ms,
398 jitter: 0.5,
399 };
400
401 Ok(Self {
402 client,
403 config: config.clone(),
404 provider,
405 retry_config,
406 circuit_breaker: std::sync::Mutex::new(CircuitBreaker::new(30)),
407 })
408 }
409
410 async fn send_request_with_retry(
412 &self,
413 request: ChatRequest,
414 ) -> Result<ChatResponse, LLMError> {
415 let mut last_error = None;
416
417 for attempt in 0..=self.retry_config.max_retries {
418 {
420 let mut cb = self.circuit_breaker.lock().map_err(|_| {
421 LLMError::RequestFailed("Circuit breaker lock poisoned".to_string())
422 })?;
423 if !cb.can_execute() {
424 return Err(LLMError::CircuitBreakerOpen(
425 self.provider.name().to_string(),
426 ));
427 }
428 }
429
430 let result = self.send_request_inner(request.clone()).await;
431
432 match result {
433 Ok(response) => {
434 {
436 let mut cb = self.circuit_breaker.lock().map_err(|_| {
437 LLMError::RequestFailed("Circuit breaker lock poisoned".to_string())
438 })?;
439 cb.record_success();
440 }
441 return Ok(response);
442 }
443 Err(e) => {
444 {
446 let mut cb = self.circuit_breaker.lock().map_err(|_| {
447 LLMError::RequestFailed("Circuit breaker lock poisoned".to_string())
448 })?;
449 cb.record_failure();
450 }
451
452 last_error = Some(e);
453
454 if matches!(last_error, Some(LLMError::AuthFailed)) {
456 return Err(last_error.unwrap());
457 }
458
459 if attempt < self.retry_config.max_retries {
461 let delay = self.retry_config.delay_for_attempt(attempt);
462 sleep(delay).await;
463 }
464 }
465 }
466 }
467
468 Err(last_error.unwrap_or(LLMError::AllProvidersFailed))
469 }
470
471 async fn send_request_inner(&self, request: ChatRequest) -> Result<ChatResponse, LLMError> {
473 let req = self.apply_headers(self.client.post(self.endpoint()).json(&request));
474
475 let response = req
476 .send()
477 .await
478 .map_err(|e| LLMError::RequestFailed(e.to_string()))?;
479
480 handle_openai_response(response).await
481 }
482
483 fn build_request(&self, messages: Vec<ChatMessage>) -> ChatRequest {
484 ChatRequest {
485 model: self.config.model.clone(),
486 messages,
487 temperature: Some(0.7),
488 max_tokens: Some(2048),
489 stream: None,
490 tools: None,
491 tool_choice: None,
492 }
493 }
494
495 fn endpoint(&self) -> String {
496 let base = if self.config.endpoint.is_empty() {
497 self.provider.default_endpoint()
498 } else {
499 &self.config.endpoint
500 };
501 let mut url = format!("{}/v1/chat/completions", base.trim_end_matches('/'));
502 if self.provider == OpenAICompatibleProvider::Azure {
503 if !url.contains("api-version") {
506 url = format!("{}?api-version=2024-02-15-preview", url);
507 }
508 }
509 url
510 }
511
512 fn apply_headers(&self, mut req: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
513 if let Some(ref key) = self.config.api_key {
514 if self.provider == OpenAICompatibleProvider::Azure {
515 req = req.header("api-key", key);
517 } else {
518 req = req.header("Authorization", format!("Bearer {}", key));
519 }
520 }
521
522 if self.provider == OpenAICompatibleProvider::OpenRouter {
524 req = req
525 .header("HTTP-Referer", "https://github.com/egkristi/RavenClaws")
526 .header("X-Title", "RavenClaws");
527 }
528
529 req
530 }
531
532 #[allow(dead_code)]
533 async fn send_request(&self, request: ChatRequest) -> Result<ChatResponse, LLMError> {
534 let req = self.apply_headers(self.client.post(self.endpoint()).json(&request));
535
536 let response = req
537 .send()
538 .await
539 .map_err(|e| LLMError::RequestFailed(e.to_string()))?;
540
541 handle_openai_response(response).await
542 }
543}
544
545#[async_trait::async_trait]
546impl LLMProviderTrait for OpenAICompatibleClient {
547 #[instrument(skip(self, messages), fields(provider = self.provider_name(), model = self.model()))]
548 async fn chat(&self, messages: Vec<ChatMessage>) -> Result<ChatResponse, LLMError> {
549 let request = self.build_request(messages);
550 self.send_request_with_retry(request).await
551 }
552
553 async fn chat_stream(&self, messages: Vec<ChatMessage>) -> Result<StreamResult, LLMError> {
554 let request = ChatRequest {
555 model: self.config.model.clone(),
556 messages,
557 temperature: Some(0.7),
558 max_tokens: Some(2048),
559 stream: Some(true),
560 tools: None,
561 tool_choice: None,
562 };
563
564 let req = self.apply_headers(self.client.post(self.endpoint()).json(&request));
565
566 let response = req
567 .send()
568 .await
569 .map_err(|e| LLMError::RequestFailed(e.to_string()))?;
570
571 let status = response.status();
572 if !status.is_success() {
573 if status == reqwest::StatusCode::UNAUTHORIZED {
574 return Err(LLMError::AuthFailed);
575 } else if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
576 return Err(LLMError::RateLimited);
577 } else {
578 let body = response
579 .text()
580 .await
581 .unwrap_or_else(|_| "Unknown error".to_string());
582 return Err(LLMError::RequestFailed(format!("{}: {}", status, body)));
583 }
584 }
585
586 use futures::StreamExt;
588 let stream = response
589 .bytes_stream()
590 .filter_map(|chunk_result| async move {
591 match chunk_result {
592 Err(e) => Some(Err(LLMError::RequestFailed(e.to_string()))),
593 Ok(bytes) => {
594 let text = String::from_utf8_lossy(&bytes);
595 let mut content = String::new();
596 let mut finish_reason = None;
597
598 for line in text.lines() {
599 if let Some(data) = line.strip_prefix("data: ") {
600 if data == "[DONE]" {
601 finish_reason = Some("stop".to_string());
602 continue;
603 }
604 if let Ok(sse_chunk) =
605 serde_json::from_str::<serde_json::Value>(data)
606 {
607 if let Some(choice) =
608 sse_chunk["choices"].as_array().and_then(|c| c.first())
609 {
610 if let Some(delta) = choice["delta"].as_object() {
611 if let Some(c) = delta["content"].as_str() {
612 content.push_str(c);
613 }
614 }
615 if let Some(reason) = choice["finish_reason"].as_str() {
616 if reason != "null" {
617 finish_reason = Some(reason.to_string());
618 }
619 }
620 }
621 }
622 }
623 }
624
625 if content.is_empty() && finish_reason.is_none() {
626 None
627 } else {
628 Some(Ok(StreamChunk {
629 content,
630 finish_reason,
631 }))
632 }
633 }
634 }
635 });
636
637 Ok(Box::pin(stream))
638 }
639
640 fn provider_name(&self) -> &str {
641 self.provider.name()
642 }
643
644 fn model(&self) -> &str {
645 &self.config.model
646 }
647}
648
649pub struct OllamaClient {
651 client: Client,
652 config: LLMConfig,
653}
654
655impl OllamaClient {
656 pub fn new(config: &LLMConfig) -> Result<Self, LLMError> {
657 let client = Client::builder()
658 .timeout(std::time::Duration::from_secs(config.timeout_secs))
659 .build()
660 .map_err(|e| LLMError::RequestFailed(format!("Failed to create HTTP client: {}", e)))?;
661
662 Ok(Self {
663 client,
664 config: config.clone(),
665 })
666 }
667}
668
669#[async_trait::async_trait]
670impl LLMProviderTrait for OllamaClient {
671 #[instrument(skip(self, messages), fields(provider = self.provider_name(), model = self.model()))]
672 async fn chat(&self, messages: Vec<ChatMessage>) -> Result<ChatResponse, LLMError> {
673 #[derive(Serialize)]
675 struct OllamaRequest {
676 model: String,
677 messages: Vec<ChatMessage>,
678 stream: bool,
679 }
680
681 let request = OllamaRequest {
682 model: self.config.model.clone(),
683 messages,
684 stream: false,
685 };
686
687 let response = self
688 .client
689 .post(format!(
690 "{}/api/chat",
691 self.config.endpoint.trim_end_matches('/')
692 ))
693 .json(&request)
694 .send()
695 .await
696 .map_err(|e| LLMError::RequestFailed(e.to_string()))?;
697
698 let status = response.status();
699
700 if status.is_success() {
701 #[derive(Deserialize)]
703 struct OllamaResponse {
704 model: String,
705 message: ChatMessage,
706 done: bool,
707 }
708
709 let ollama_resp = response
710 .json::<OllamaResponse>()
711 .await
712 .map_err(|e| LLMError::InvalidResponse(e.to_string()))?;
713
714 Ok(ChatResponse {
715 id: format!("ollama-{}", uuid::Uuid::new_v4()),
716 object: "chat.completion".to_string(),
717 created: std::time::SystemTime::now()
718 .duration_since(std::time::UNIX_EPOCH)
719 .unwrap()
720 .as_secs(),
721 model: ollama_resp.model,
722 choices: vec![Choice {
723 index: 0,
724 message: ollama_resp.message,
725 finish_reason: if ollama_resp.done {
726 Some("stop".to_string())
727 } else {
728 None
729 },
730 tool_calls: None,
731 }],
732 usage: None, })
734 } else if status == reqwest::StatusCode::UNAUTHORIZED {
735 Err(LLMError::AuthFailed)
736 } else {
737 let body = response
738 .text()
739 .await
740 .unwrap_or_else(|_| "Unknown error".to_string());
741 Err(LLMError::RequestFailed(format!("{}: {}", status, body)))
742 }
743 }
744
745 fn provider_name(&self) -> &str {
746 "ollama"
747 }
748
749 fn model(&self) -> &str {
750 &self.config.model
751 }
752}
753
754pub struct AnthropicClient {
756 client: Client,
757 config: LLMConfig,
758}
759
760impl AnthropicClient {
761 pub fn new(config: &LLMConfig) -> Result<Self, LLMError> {
762 let client = Client::builder()
763 .timeout(std::time::Duration::from_secs(config.timeout_secs))
764 .build()
765 .map_err(|e| LLMError::RequestFailed(format!("Failed to create HTTP client: {}", e)))?;
766
767 Ok(Self {
768 client,
769 config: config.clone(),
770 })
771 }
772}
773
774#[async_trait::async_trait]
775impl LLMProviderTrait for AnthropicClient {
776 #[instrument(skip(self, messages), fields(provider = self.provider_name(), model = self.model()))]
777 async fn chat(&self, messages: Vec<ChatMessage>) -> Result<ChatResponse, LLMError> {
778 #[derive(Serialize)]
780 struct AnthropicRequest {
781 model: String,
782 max_tokens: u32,
783 messages: Vec<AnthropicMessage>,
784 #[serde(skip_serializing_if = "Option::is_none")]
785 system: Option<String>,
786 #[serde(skip_serializing_if = "Option::is_none")]
787 temperature: Option<f32>,
788 }
789
790 #[derive(Serialize)]
791 struct AnthropicMessage {
792 role: String,
793 content: String,
794 }
795
796 let system = messages
798 .iter()
799 .find(|m| m.role == "system")
800 .map(|m| m.content.clone());
801
802 let anthropic_messages: Vec<AnthropicMessage> = messages
803 .into_iter()
804 .filter(|m| m.role != "system")
805 .map(|m| AnthropicMessage {
806 role: if m.role == "user" {
807 "user".to_string()
808 } else {
809 "assistant".to_string()
810 },
811 content: m.content,
812 })
813 .collect();
814
815 let request = AnthropicRequest {
816 model: self.config.model.clone(),
817 max_tokens: 2048,
818 messages: anthropic_messages,
819 system,
820 temperature: Some(0.7),
821 };
822
823 let api_key = self
824 .config
825 .api_key
826 .clone()
827 .ok_or_else(|| LLMError::AuthFailed)?;
828
829 let response = self
830 .client
831 .post("https://api.anthropic.com/v1/messages")
832 .header("x-api-key", api_key)
833 .header("anthropic-version", "2023-06-01")
834 .header("content-type", "application/json")
835 .json(&request)
836 .send()
837 .await
838 .map_err(|e| LLMError::RequestFailed(e.to_string()))?;
839
840 let status = response.status();
841
842 if status.is_success() {
843 #[derive(Deserialize)]
845 #[allow(dead_code)]
846 struct AnthropicResponse {
847 id: String,
848 #[serde(rename = "type")]
849 response_type: String,
850 role: String,
851 content: Vec<AnthropicContentBlock>,
852 model: String,
853 stop_reason: Option<String>,
854 #[serde(default)]
855 usage: Option<AnthropicUsage>,
856 }
857
858 #[derive(Deserialize)]
859 #[serde(tag = "type", rename_all = "lowercase")]
860 enum AnthropicContentBlock {
861 Text {
862 text: String,
863 },
864 ToolUse {
865 id: String,
866 name: String,
867 input: serde_json::Value,
868 },
869 }
870
871 #[derive(Deserialize)]
872 struct AnthropicUsage {
873 input_tokens: u32,
874 output_tokens: u32,
875 }
876
877 let anthropic_resp = response
878 .json::<AnthropicResponse>()
879 .await
880 .map_err(|e| LLMError::InvalidResponse(e.to_string()))?;
881
882 let mut content = String::new();
884 let mut tool_calls = None;
885
886 for block in anthropic_resp.content {
887 match block {
888 AnthropicContentBlock::Text { text } => {
889 content.push_str(&text);
890 }
891 AnthropicContentBlock::ToolUse { id, name, input } => {
892 if tool_calls.is_none() {
893 tool_calls = Some(Vec::new());
894 }
895 if let Some(ref mut calls) = tool_calls {
896 calls.push(ToolCallResponse {
897 id,
898 call_type: "function".to_string(),
899 function: FunctionCall {
900 name,
901 arguments: input.to_string(),
902 },
903 });
904 }
905 }
906 }
907 }
908
909 Ok(ChatResponse {
910 id: anthropic_resp.id,
911 object: "chat.completion".to_string(),
912 created: std::time::SystemTime::now()
913 .duration_since(std::time::UNIX_EPOCH)
914 .unwrap()
915 .as_secs(),
916 model: anthropic_resp.model,
917 choices: vec![Choice {
918 index: 0,
919 message: ChatMessage {
920 role: "assistant".to_string(),
921 content,
922 },
923 finish_reason: anthropic_resp.stop_reason,
924 tool_calls,
925 }],
926 usage: anthropic_resp.usage.map(|u| Usage {
927 prompt_tokens: u.input_tokens,
928 completion_tokens: u.output_tokens,
929 total_tokens: u.input_tokens + u.output_tokens,
930 }),
931 })
932 } else if status == reqwest::StatusCode::UNAUTHORIZED {
933 Err(LLMError::AuthFailed)
934 } else if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
935 Err(LLMError::RateLimited)
936 } else {
937 let body = response
938 .text()
939 .await
940 .unwrap_or_else(|_| "Unknown error".to_string());
941 Err(LLMError::RequestFailed(format!("{}: {}", status, body)))
942 }
943 }
944
945 fn provider_name(&self) -> &str {
946 "anthropic"
947 }
948
949 fn model(&self) -> &str {
950 &self.config.model
951 }
952}
953
954pub fn create_client(config: &LLMConfig) -> Result<Arc<dyn LLMProviderTrait>, LLMError> {
956 match config.provider {
957 LLMProvider::LiteLLM => {
958 let unified = OpenAICompatibleClient::new(config, OpenAICompatibleProvider::LiteLLM)?;
959 Ok(Arc::new(unified))
960 }
961 LLMProvider::OpenRouter => {
962 let unified =
963 OpenAICompatibleClient::new(config, OpenAICompatibleProvider::OpenRouter)?;
964 Ok(Arc::new(unified))
965 }
966 LLMProvider::Ollama => Ok(Arc::new(OllamaClient::new(config)?)),
967 LLMProvider::OpenAI => {
968 let unified = OpenAICompatibleClient::new(config, OpenAICompatibleProvider::OpenAI)?;
969 Ok(Arc::new(unified))
970 }
971 LLMProvider::Anthropic => Ok(Arc::new(AnthropicClient::new(config)?)),
972 LLMProvider::OpenAICompatible => {
973 let unified = OpenAICompatibleClient::new(config, OpenAICompatibleProvider::Generic)?;
974 Ok(Arc::new(unified))
975 }
976 LLMProvider::Azure => {
977 let unified = OpenAICompatibleClient::new(config, OpenAICompatibleProvider::Azure)?;
978 Ok(Arc::new(unified))
979 }
980 }
981}
982
983#[derive(Clone)]
985pub struct MultiModelManager {
986 clients: Vec<Arc<dyn LLMProviderTrait>>,
987}
988
989impl MultiModelManager {
990 pub fn new(configs: Vec<LLMConfig>) -> Result<Self, LLMError> {
991 let clients: Result<Vec<_>, _> = configs.iter().map(create_client).collect();
992 Ok(Self { clients: clients? })
993 }
994
995 pub fn get_client(&self, index: usize) -> Option<&Arc<dyn LLMProviderTrait>> {
996 self.clients.get(index)
997 }
998
999 pub fn client_count(&self) -> usize {
1000 self.clients.len()
1001 }
1002
1003 pub fn next_client(&self, last_index: usize) -> Option<&Arc<dyn LLMProviderTrait>> {
1005 if self.clients.is_empty() {
1006 return None;
1007 }
1008 let next = (last_index + 1) % self.clients.len();
1009 Some(&self.clients[next])
1010 }
1011}
1012
1013#[derive(Debug)]
1015pub struct ProviderFallbackChain {
1016 pub configs: Vec<LLMConfig>,
1018 token_budget: Option<TokenBudget>,
1019}
1020
1021impl ProviderFallbackChain {
1022 pub fn new(configs: Vec<LLMConfig>) -> Self {
1023 Self {
1024 configs,
1025 token_budget: None,
1026 }
1027 }
1028
1029 pub fn with_token_budget(mut self, budget: TokenBudget) -> Self {
1030 self.token_budget = Some(budget);
1031 self
1032 }
1033
1034 #[instrument(skip(self, messages))]
1036 pub async fn chat_with_fallback(
1037 &mut self,
1038 messages: Vec<ChatMessage>,
1039 ) -> Result<ChatResponse, LLMError> {
1040 let mut last_error = None;
1041
1042 for (i, config) in self.configs.iter().enumerate() {
1043 let client = match create_client(config) {
1044 Ok(c) => c,
1045 Err(e) => {
1046 tracing::warn!(
1047 "Failed to create client for provider {:?}: {}",
1048 config.provider,
1049 e
1050 );
1051 last_error = Some(e);
1052 continue;
1053 }
1054 };
1055
1056 if let Some(ref budget) = self.token_budget {
1058 if !budget.can_spend(500) {
1060 return Err(LLMError::TokenBudgetExceeded);
1061 }
1062 }
1063
1064 match client.chat(messages.clone()).await {
1065 Ok(response) => {
1066 if let Some(ref mut budget) = self.token_budget {
1068 if let Some(usage) = &response.usage {
1069 budget.record_usage(usage.total_tokens);
1070 }
1071 }
1072 return Ok(response);
1073 }
1074 Err(e) => {
1075 tracing::warn!("Provider {} failed: {}", i, e);
1076 last_error = Some(e);
1077 }
1079 }
1080 }
1081
1082 Err(last_error.unwrap_or(LLMError::AllProvidersFailed))
1083 }
1084
1085 #[allow(dead_code)]
1087 pub fn provider_names(&self) -> Vec<String> {
1088 self.configs
1089 .iter()
1090 .map(|c| format!("{:?}", c.provider))
1091 .collect()
1092 }
1093}
1094
1095#[cfg(test)]
1096mod tests {
1097 use super::*;
1098 use mockito::Server;
1099
1100 fn make_chat_messages() -> Vec<ChatMessage> {
1103 vec![
1104 ChatMessage {
1105 role: "system".to_string(),
1106 content: "You are helpful.".to_string(),
1107 },
1108 ChatMessage {
1109 role: "user".to_string(),
1110 content: "Hello!".to_string(),
1111 },
1112 ]
1113 }
1114
1115 fn sample_chat_response_json(model: &str) -> String {
1116 format!(
1117 r#"{{
1118 "id": "chat-123",
1119 "object": "chat.completion",
1120 "created": 1717000000,
1121 "model": "{}",
1122 "choices": [
1123 {{
1124 "index": 0,
1125 "message": {{
1126 "role": "assistant",
1127 "content": "Hi there!"
1128 }},
1129 "finish_reason": "stop"
1130 }}
1131 ],
1132 "usage": {{
1133 "prompt_tokens": 10,
1134 "completion_tokens": 5,
1135 "total_tokens": 15
1136 }}
1137 }}"#,
1138 model
1139 )
1140 }
1141
1142 fn sample_ollama_response_json(model: &str) -> String {
1143 format!(
1144 r#"{{
1145 "model": "{}",
1146 "message": {{
1147 "role": "assistant",
1148 "content": "Hi there!"
1149 }},
1150 "done": true
1151 }}"#,
1152 model
1153 )
1154 }
1155
1156 fn with_mockito<F, Fut>(f: F)
1160 where
1161 F: FnOnce(mockito::ServerGuard) -> Fut,
1162 Fut: std::future::Future<Output = ()>,
1163 {
1164 let server = Server::new();
1165 let rt = tokio::runtime::Runtime::new().unwrap();
1166 rt.block_on(f(server));
1167 }
1168
1169 #[test]
1172 fn test_openai_compatible_provider_defaults() {
1173 assert_eq!(
1174 OpenAICompatibleProvider::LiteLLM.default_endpoint(),
1175 "http://localhost:4000"
1176 );
1177 assert_eq!(
1178 OpenAICompatibleProvider::OpenAI.default_endpoint(),
1179 "https://api.openai.com"
1180 );
1181 assert_eq!(
1182 OpenAICompatibleProvider::OpenRouter.default_endpoint(),
1183 "https://openrouter.ai"
1184 );
1185 }
1186
1187 #[test]
1188 fn test_openai_compatible_provider_names() {
1189 assert_eq!(OpenAICompatibleProvider::LiteLLM.name(), "litellm");
1190 assert_eq!(OpenAICompatibleProvider::OpenAI.name(), "openai");
1191 assert_eq!(OpenAICompatibleProvider::OpenRouter.name(), "openrouter");
1192 }
1193
1194 #[test]
1195 fn test_openai_compatible_requires_custom_headers() {
1196 assert!(!OpenAICompatibleProvider::LiteLLM.requires_custom_headers());
1197 assert!(OpenAICompatibleProvider::OpenRouter.requires_custom_headers());
1198 assert!(!OpenAICompatibleProvider::OpenAI.requires_custom_headers());
1199 }
1200
1201 #[test]
1202 fn test_openai_compatible_client_new() {
1203 let config = LLMConfig {
1204 provider: LLMProvider::LiteLLM,
1205 endpoint: "http://localhost:4000".to_string(),
1206 model: "gpt-4o-mini".to_string(),
1207 api_key: Some("test-key".to_string()),
1208 timeout_secs: 30,
1209 system_prompt: crate::config::default_system_prompt(),
1210 token_budget: None,
1211 retry_max: 3,
1212 retry_base_delay_ms: 100,
1213 retry_max_delay_ms: 10000,
1214 };
1215
1216 let client = OpenAICompatibleClient::new(&config, OpenAICompatibleProvider::LiteLLM);
1217 assert!(client.is_ok());
1218 assert_eq!(client.unwrap().provider_name(), "litellm");
1219 }
1220
1221 #[test]
1222 fn test_openai_compatible_client_endpoint() {
1223 let config = LLMConfig {
1225 provider: LLMProvider::OpenAI,
1226 endpoint: "https://custom.api.example.com".to_string(),
1227 model: "gpt-4o".to_string(),
1228 api_key: Some("test-key".to_string()),
1229 timeout_secs: 30,
1230 system_prompt: crate::config::default_system_prompt(),
1231 token_budget: None,
1232 retry_max: 3,
1233 retry_base_delay_ms: 100,
1234 retry_max_delay_ms: 10000,
1235 };
1236
1237 let client =
1238 OpenAICompatibleClient::new(&config, OpenAICompatibleProvider::OpenAI).unwrap();
1239 assert_eq!(client.provider_name(), "openai");
1241 }
1242
1243 #[test]
1244 fn test_openai_compatible_client_chat_success() {
1245 with_mockito(|mut server| async move {
1246 let mock = server
1247 .mock("POST", "/v1/chat/completions")
1248 .with_status(200)
1249 .with_header("content-type", "application/json")
1250 .with_body(sample_chat_response_json("gpt-4o-mini"))
1251 .create();
1252
1253 let config = LLMConfig {
1254 provider: LLMProvider::LiteLLM,
1255 endpoint: server.url(),
1256 model: "gpt-4o-mini".to_string(),
1257 api_key: Some("test-key".to_string()),
1258 timeout_secs: 30,
1259 system_prompt: crate::config::default_system_prompt(),
1260 token_budget: None,
1261 retry_max: 3,
1262 retry_base_delay_ms: 100,
1263 retry_max_delay_ms: 10000,
1264 };
1265
1266 let client =
1267 OpenAICompatibleClient::new(&config, OpenAICompatibleProvider::LiteLLM).unwrap();
1268 let response = client.chat(make_chat_messages()).await.unwrap();
1269
1270 assert_eq!(response.model, "gpt-4o-mini");
1271 assert_eq!(response.choices[0].message.content, "Hi there!");
1272 mock.assert();
1273 });
1274 }
1275
/// A 401 from the LiteLLM-flavoured endpoint surfaces as `LLMError::AuthFailed`.
/// Unlike `test_litellm_chat_auth_failure`, the mocked reply here carries no
/// `content-type` header.
#[test]
fn test_openai_compatible_client_auth_failure() {
    with_mockito(|mut server| async move {
        let config = LLMConfig {
            provider: LLMProvider::LiteLLM,
            endpoint: server.url(),
            model: "gpt-4o-mini".to_string(),
            api_key: Some("bad-key".to_string()),
            timeout_secs: 30,
            system_prompt: crate::config::default_system_prompt(),
            token_budget: None,
            retry_max: 3,
            retry_base_delay_ms: 100,
            retry_max_delay_ms: 10000,
        };

        let unauthorized = server
            .mock("POST", "/v1/chat/completions")
            .with_status(401)
            .with_body(r#"{"error": "Unauthorized"}"#)
            .create();

        let client = OpenAICompatibleClient::new(&config, OpenAICompatibleProvider::LiteLLM)
            .expect("client should build from a valid config");
        let failure = client
            .chat(make_chat_messages())
            .await
            .expect_err("a 401 must not produce a ChatResponse");

        assert!(
            matches!(failure, LLMError::AuthFailed),
            "unexpected error: {:?}",
            failure
        );
        unauthorized.assert();
    });
}
1306
/// A 429 from the LiteLLM-flavoured endpoint surfaces as `LLMError::RateLimited`.
#[test]
fn test_openai_compatible_client_rate_limit() {
    with_mockito(|mut server| async move {
        let mock = server
            .mock("POST", "/v1/chat/completions")
            .with_status(429)
            .with_body(r#"{"error": "Rate limited"}"#)
            .create();

        let config = LLMConfig {
            provider: LLMProvider::LiteLLM,
            endpoint: server.url(),
            model: "gpt-4o-mini".to_string(),
            api_key: Some("test-key".to_string()),
            timeout_secs: 30,
            system_prompt: crate::config::default_system_prompt(),
            token_budget: None,
            // Presumably disables retries, so the 429 is reported after a single
            // request and `mock.assert()` sees exactly one hit.
            retry_max: 0,
            retry_base_delay_ms: 100,
            retry_max_delay_ms: 10000,
        };

        let client = OpenAICompatibleClient::new(&config, OpenAICompatibleProvider::LiteLLM)
            .expect("client should build from a valid config");
        let err = client
            .chat(make_chat_messages())
            .await
            .expect_err("a 429 must not produce a ChatResponse");

        assert!(matches!(err, LLMError::RateLimited), "unexpected error: {:?}", err);
        mock.assert();
    });
}
1337
/// The OpenRouter provider must send its attribution headers. The mock only
/// matches requests carrying both `HTTP-Referer` and `X-Title` with the expected
/// values, so a successful call followed by `assert()` proves both were sent.
#[test]
fn test_openrouter_client_uses_custom_headers() {
    with_mockito(|mut server| async move {
        let config = LLMConfig {
            provider: LLMProvider::OpenRouter,
            endpoint: server.url(),
            model: "claude-sonnet-4".to_string(),
            api_key: Some("or-key".to_string()),
            timeout_secs: 30,
            system_prompt: crate::config::default_system_prompt(),
            token_budget: None,
            retry_max: 3,
            retry_base_delay_ms: 100,
            retry_max_delay_ms: 10000,
        };

        let attributed_route = server
            .mock("POST", "/v1/chat/completions")
            .match_header("HTTP-Referer", "https://github.com/egkristi/RavenClaws")
            .match_header("X-Title", "RavenClaws")
            .with_status(200)
            .with_body(sample_chat_response_json("claude-sonnet-4"))
            .create();

        let client = OpenAICompatibleClient::new(&config, OpenAICompatibleProvider::OpenRouter)
            .expect("client should build from a valid config");
        let _reply = client
            .chat(make_chat_messages())
            .await
            .expect("request with the attribution headers should match the mock");

        attributed_route.assert();
    });
}
1368
/// Building an `AnthropicClient` from a config with an API key and an empty
/// endpoint succeeds.
#[test]
fn test_anthropic_client_new() {
    let anthropic_config = LLMConfig {
        provider: LLMProvider::Anthropic,
        endpoint: String::new(),
        model: "claude-sonnet-4-20250514".to_string(),
        api_key: Some("sk-ant-test".to_string()),
        timeout_secs: 30,
        system_prompt: crate::config::default_system_prompt(),
        token_budget: None,
        retry_max: 3,
        retry_base_delay_ms: 100,
        retry_max_delay_ms: 10000,
    };

    assert!(
        AnthropicClient::new(&anthropic_config).is_ok(),
        "AnthropicClient::new should accept this config"
    );
}
1389
/// `AnthropicClient` reports its provider name as `"anthropic"`.
#[test]
fn test_anthropic_client_provider_name() {
    let anthropic_config = LLMConfig {
        provider: LLMProvider::Anthropic,
        endpoint: String::new(),
        model: "claude-sonnet-4-20250514".to_string(),
        api_key: Some("sk-ant-test".to_string()),
        timeout_secs: 30,
        system_prompt: crate::config::default_system_prompt(),
        token_budget: None,
        retry_max: 3,
        retry_base_delay_ms: 100,
        retry_max_delay_ms: 10000,
    };

    let anthropic = AnthropicClient::new(&anthropic_config)
        .expect("AnthropicClient::new should accept this config");
    assert_eq!(anthropic.provider_name(), "anthropic");
}
1408
/// `AnthropicClient::model()` echoes the model name from the config unchanged.
#[test]
fn test_anthropic_client_model() {
    let anthropic_config = LLMConfig {
        provider: LLMProvider::Anthropic,
        endpoint: String::new(),
        model: "claude-opus-4-20250514".to_string(),
        api_key: Some("sk-ant-test".to_string()),
        timeout_secs: 30,
        system_prompt: crate::config::default_system_prompt(),
        token_budget: None,
        retry_max: 3,
        retry_base_delay_ms: 100,
        retry_max_delay_ms: 10000,
    };

    let anthropic = AnthropicClient::new(&anthropic_config)
        .expect("AnthropicClient::new should accept this config");
    assert_eq!(anthropic.model(), "claude-opus-4-20250514");
}
1427
/// The `create_client` factory routes an `LLMProvider::Anthropic` config to a
/// client that identifies itself as `"anthropic"`.
#[test]
fn test_create_client_anthropic() {
    let anthropic_config = LLMConfig {
        provider: LLMProvider::Anthropic,
        endpoint: String::new(),
        model: "claude-sonnet-4-20250514".to_string(),
        api_key: Some("sk-ant-test".to_string()),
        timeout_secs: 30,
        system_prompt: crate::config::default_system_prompt(),
        token_budget: None,
        retry_max: 3,
        retry_base_delay_ms: 100,
        retry_max_delay_ms: 10000,
    };

    // `expect` covers both the original "is_ok" check and the unwrap.
    let client =
        create_client(&anthropic_config).expect("create_client should accept an Anthropic config");
    assert_eq!(client.provider_name(), "anthropic");
}
1447
/// With jitter disabled, `delay_for_attempt` doubles the base delay on each attempt
/// while the result stays below `max_delay_ms`.
#[test]
fn test_retry_config_delay_calculation() {
    let config = RetryConfig {
        max_retries: 3,
        base_delay_ms: 100,
        max_delay_ms: 10000,
        // Zero jitter makes the delays deterministic so they can be compared exactly.
        jitter: 0.0,
    };

    assert_eq!(config.delay_for_attempt(0).as_millis(), 100);
    assert_eq!(config.delay_for_attempt(1).as_millis(), 200);
    assert_eq!(config.delay_for_attempt(2).as_millis(), 400);
}
1464
/// `max_delay_ms` caps the exponential backoff. Every attempt from 0 to 10 stays
/// within the cap. By attempt 10 the uncapped value (100 ms * 2^10) far exceeds
/// 1000 ms, and with jitter disabled the result must be exactly the cap.
#[test]
fn test_retry_config_max_delay_cap() {
    let config = RetryConfig {
        max_retries: 10,
        base_delay_ms: 100,
        max_delay_ms: 1000,
        // Zero jitter so the capped value is deterministic.
        jitter: 0.0,
    };

    for attempt in 0..=10 {
        let delay_ms = config.delay_for_attempt(attempt).as_millis();
        assert!(
            delay_ms <= 1000,
            "attempt {} produced {} ms, above the 1000 ms cap",
            attempt,
            delay_ms
        );
    }
    assert_eq!(config.delay_for_attempt(10).as_millis(), 1000);
}
1477
/// A fresh `CircuitBreaker` starts `Closed` and allows execution. After five
/// recorded failures it is `Open` and `can_execute()` returns false.
#[test]
fn test_circuit_breaker_state_transitions() {
    let mut breaker = CircuitBreaker::new(30);

    // Initial state: closed and accepting calls.
    assert_eq!(breaker.state, CircuitState::Closed);
    assert!(breaker.can_execute());

    (0..5).for_each(|_| {
        breaker.record_failure();
    });

    // Tripped: open and refusing calls.
    assert_eq!(breaker.state, CircuitState::Open);
    assert!(!breaker.can_execute());
}
1493
/// `record_success()` clears the accumulated failure count and leaves the breaker
/// `Closed`.
#[test]
fn test_circuit_breaker_success_resets() {
    let mut breaker = CircuitBreaker::new(30);

    (0..3).for_each(|_| {
        breaker.record_failure();
    });
    assert_eq!(breaker.failure_count, 3);

    // A single success wipes the failure history.
    breaker.record_success();
    assert_eq!(breaker.failure_count, 0);
    assert_eq!(breaker.state, CircuitState::Closed);
}
1509
/// `TokenBudget` tracks remaining tokens, rejects spends larger than what is left,
/// and reports an estimated cost for the tokens used so far.
#[test]
fn test_token_budget_tracking() {
    // NOTE(review): the cost expectation at the end assumes the second argument is
    // the price per 1,000 tokens (800 * 0.002 / 1000 = 0.0016) — confirm against
    // TokenBudget.
    let mut budget = TokenBudget::new(1000, 0.002);
    assert_eq!(budget.remaining(), 1000);
    assert!(budget.can_spend(500));

    budget.record_usage(300);
    assert_eq!(budget.remaining(), 700);
    assert!(budget.can_spend(500));

    budget.record_usage(500);
    assert_eq!(budget.remaining(), 200);
    assert!(!budget.can_spend(500));

    // 300 + 500 = 800 tokens used in total.
    assert!((budget.estimated_cost() - 0.0016).abs() < 0.0001);
}
1528
/// `ProviderFallbackChain` keeps providers in configuration order and reports their
/// names as `"LiteLLM"` and `"Ollama"`.
#[test]
fn test_provider_fallback_chain_creation() {
    // Local factory: only provider, endpoint, model and key differ between entries.
    let entry = |provider, endpoint: &str, model: &str, api_key: Option<&str>| LLMConfig {
        provider,
        endpoint: endpoint.to_string(),
        model: model.to_string(),
        api_key: api_key.map(|key| key.to_string()),
        timeout_secs: 30,
        system_prompt: crate::config::default_system_prompt(),
        token_budget: None,
        retry_max: 3,
        retry_base_delay_ms: 100,
        retry_max_delay_ms: 10000,
    };

    let configs = vec![
        entry(LLMProvider::LiteLLM, "http://localhost:4000", "gpt-4o", Some("key1")),
        entry(LLMProvider::Ollama, "http://localhost:11434", "llama3.1", None),
    ];

    let chain = ProviderFallbackChain::new(configs);
    assert_eq!(chain.provider_names(), vec!["LiteLLM", "Ollama"]);
}
1561
/// A JSON 401 from a LiteLLM endpoint surfaces as `LLMError::AuthFailed`.
#[test]
fn test_litellm_chat_auth_failure() {
    with_mockito(|mut server| async move {
        let config = LLMConfig {
            provider: LLMProvider::LiteLLM,
            endpoint: server.url(),
            model: "gpt-4o-mini".to_string(),
            api_key: Some("bad-key".to_string()),
            timeout_secs: 30,
            system_prompt: crate::config::default_system_prompt(),
            token_budget: None,
            retry_max: 3,
            retry_base_delay_ms: 100,
            retry_max_delay_ms: 10000,
        };

        let unauthorized = server
            .mock("POST", "/v1/chat/completions")
            .with_status(401)
            .with_header("content-type", "application/json")
            .with_body(r#"{"error": "Unauthorized"}"#)
            .create();

        let client = OpenAICompatibleClient::new(&config, OpenAICompatibleProvider::LiteLLM)
            .expect("client should build from a valid config");
        let failure = client
            .chat(make_chat_messages())
            .await
            .expect_err("a 401 must not produce a ChatResponse");

        assert!(
            matches!(failure, LLMError::AuthFailed),
            "unexpected error: {:?}",
            failure
        );
        unauthorized.assert();
    });
}
1595
/// A JSON 429 from a LiteLLM endpoint surfaces as `LLMError::RateLimited`.
#[test]
fn test_litellm_chat_rate_limit() {
    with_mockito(|mut server| async move {
        let config = LLMConfig {
            provider: LLMProvider::LiteLLM,
            endpoint: server.url(),
            model: "gpt-4o-mini".to_string(),
            api_key: Some("test-key".to_string()),
            timeout_secs: 30,
            system_prompt: crate::config::default_system_prompt(),
            token_budget: None,
            // Presumably disables retries, so the mock sees one request.
            retry_max: 0,
            retry_base_delay_ms: 100,
            retry_max_delay_ms: 10000,
        };

        let throttled = server
            .mock("POST", "/v1/chat/completions")
            .with_status(429)
            .with_header("content-type", "application/json")
            .with_body(r#"{"error": "Rate limit exceeded"}"#)
            .create();

        let client = OpenAICompatibleClient::new(&config, OpenAICompatibleProvider::LiteLLM)
            .expect("client should build from a valid config");
        let failure = client
            .chat(make_chat_messages())
            .await
            .expect_err("a 429 must not produce a ChatResponse");

        assert!(
            matches!(failure, LLMError::RateLimited),
            "unexpected error: {:?}",
            failure
        );
        throttled.assert();
    });
}
1627
/// A 500 from a LiteLLM endpoint surfaces as `LLMError::RequestFailed`, and the
/// rendered error mentions the status code.
#[test]
fn test_litellm_chat_server_error() {
    with_mockito(|mut server| async move {
        let config = LLMConfig {
            provider: LLMProvider::LiteLLM,
            endpoint: server.url(),
            model: "gpt-4o-mini".to_string(),
            api_key: Some("test-key".to_string()),
            timeout_secs: 30,
            system_prompt: crate::config::default_system_prompt(),
            token_budget: None,
            // Presumably disables retries, so the mock sees one request.
            retry_max: 0,
            retry_base_delay_ms: 100,
            retry_max_delay_ms: 10000,
        };

        let broken_upstream = server
            .mock("POST", "/v1/chat/completions")
            .with_status(500)
            .with_header("content-type", "application/json")
            .with_body(r#"{"error": "Internal server error"}"#)
            .create();

        let client = OpenAICompatibleClient::new(&config, OpenAICompatibleProvider::LiteLLM)
            .expect("client should build from a valid config");
        let failure = client
            .chat(make_chat_messages())
            .await
            .expect_err("a 500 must not produce a ChatResponse");

        assert!(
            matches!(failure, LLMError::RequestFailed(_)),
            "unexpected error: {:?}",
            failure
        );
        assert!(failure.to_string().contains("500"));
        broken_upstream.assert();
    });
}
1660
/// A 200 whose body is not JSON surfaces as `LLMError::InvalidResponse` rather than
/// a panic or a transport error.
#[test]
fn test_litellm_chat_invalid_json() {
    with_mockito(|mut server| async move {
        let config = LLMConfig {
            provider: LLMProvider::LiteLLM,
            endpoint: server.url(),
            model: "gpt-4o-mini".to_string(),
            api_key: Some("test-key".to_string()),
            timeout_secs: 30,
            system_prompt: crate::config::default_system_prompt(),
            token_budget: None,
            retry_max: 0,
            retry_base_delay_ms: 100,
            retry_max_delay_ms: 10000,
        };

        let garbage = server
            .mock("POST", "/v1/chat/completions")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body("not-json")
            .create();

        let client = OpenAICompatibleClient::new(&config, OpenAICompatibleProvider::LiteLLM)
            .expect("client should build from a valid config");
        let failure = client
            .chat(make_chat_messages())
            .await
            .expect_err("an unparsable body must not produce a ChatResponse");

        assert!(
            matches!(failure, LLMError::InvalidResponse(_)),
            "unexpected error: {:?}",
            failure
        );
        garbage.assert();
    });
}
1692
/// Happy path for OpenRouter: a namespaced model id (`vendor/model`) round-trips
/// through the request and the parsed response.
#[test]
fn test_openrouter_chat_success() {
    with_mockito(|mut server| async move {
        let config = LLMConfig {
            provider: LLMProvider::OpenRouter,
            endpoint: server.url(),
            model: "anthropic/claude-sonnet-4-20250514".to_string(),
            api_key: Some("or-key".to_string()),
            timeout_secs: 30,
            system_prompt: crate::config::default_system_prompt(),
            token_budget: None,
            retry_max: 3,
            retry_base_delay_ms: 100,
            retry_max_delay_ms: 10000,
        };

        let completions = server
            .mock("POST", "/v1/chat/completions")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(sample_chat_response_json(
                "anthropic/claude-sonnet-4-20250514",
            ))
            .create();

        let client = OpenAICompatibleClient::new(&config, OpenAICompatibleProvider::OpenRouter)
            .expect("client should build from a valid config");
        let response = client
            .chat(make_chat_messages())
            .await
            .expect("a 200 response should parse into a ChatResponse");

        let first = &response.choices[0];
        assert_eq!(response.model, "anthropic/claude-sonnet-4-20250514");
        assert_eq!(first.message.content, "Hi there!");
        completions.assert();
    });
}
1729
/// A 401 from OpenRouter surfaces as `LLMError::AuthFailed`.
#[test]
fn test_openrouter_chat_auth_failure() {
    with_mockito(|mut server| async move {
        let config = LLMConfig {
            provider: LLMProvider::OpenRouter,
            endpoint: server.url(),
            model: "anthropic/claude-sonnet-4-20250514".to_string(),
            api_key: Some("bad-key".to_string()),
            timeout_secs: 30,
            system_prompt: crate::config::default_system_prompt(),
            token_budget: None,
            retry_max: 3,
            retry_base_delay_ms: 100,
            retry_max_delay_ms: 10000,
        };

        let unauthorized = server
            .mock("POST", "/v1/chat/completions")
            .with_status(401)
            .with_header("content-type", "application/json")
            .with_body(r#"{"error": "Unauthorized"}"#)
            .create();

        let client = OpenAICompatibleClient::new(&config, OpenAICompatibleProvider::OpenRouter)
            .expect("client should build from a valid config");
        let failure = client
            .chat(make_chat_messages())
            .await
            .expect_err("a 401 must not produce a ChatResponse");

        assert!(
            matches!(failure, LLMError::AuthFailed),
            "unexpected error: {:?}",
            failure
        );
        unauthorized.assert();
    });
}
1761
/// A 429 from OpenRouter surfaces as `LLMError::RateLimited`.
#[test]
fn test_openrouter_chat_rate_limit() {
    with_mockito(|mut server| async move {
        let mock = server
            .mock("POST", "/v1/chat/completions")
            .with_status(429)
            .with_header("content-type", "application/json")
            .with_body(r#"{"error": "Rate limited"}"#)
            .create();

        let config = LLMConfig {
            provider: LLMProvider::OpenRouter,
            endpoint: server.url(),
            model: "anthropic/claude-sonnet-4-20250514".to_string(),
            api_key: Some("or-key".to_string()),
            timeout_secs: 30,
            system_prompt: crate::config::default_system_prompt(),
            token_budget: None,
            // Presumably disables retries, so the mock sees exactly one request.
            retry_max: 0,
            retry_base_delay_ms: 100,
            retry_max_delay_ms: 10000,
        };

        let client = OpenAICompatibleClient::new(&config, OpenAICompatibleProvider::OpenRouter)
            .expect("client should build from a valid config");
        let err = client
            .chat(make_chat_messages())
            .await
            .expect_err("a 429 must not produce a ChatResponse");

        assert!(matches!(err, LLMError::RateLimited), "unexpected error: {:?}", err);
        mock.assert();
    });
}
1793
/// A 500 from OpenRouter surfaces as `LLMError::RequestFailed`, and the rendered
/// error mentions the status code.
#[test]
fn test_openrouter_chat_server_error() {
    with_mockito(|mut server| async move {
        let mock = server
            .mock("POST", "/v1/chat/completions")
            .with_status(500)
            .with_header("content-type", "application/json")
            .with_body(r#"{"error": "Internal error"}"#)
            .create();

        let config = LLMConfig {
            provider: LLMProvider::OpenRouter,
            endpoint: server.url(),
            model: "anthropic/claude-sonnet-4-20250514".to_string(),
            api_key: Some("or-key".to_string()),
            timeout_secs: 30,
            system_prompt: crate::config::default_system_prompt(),
            token_budget: None,
            // Presumably disables retries, so the mock sees exactly one request.
            retry_max: 0,
            retry_base_delay_ms: 100,
            retry_max_delay_ms: 10000,
        };

        let client = OpenAICompatibleClient::new(&config, OpenAICompatibleProvider::OpenRouter)
            .expect("client should build from a valid config");
        let err = client
            .chat(make_chat_messages())
            .await
            .expect_err("a 500 must not produce a ChatResponse");

        assert!(matches!(err, LLMError::RequestFailed(_)), "unexpected error: {:?}", err);
        assert!(err.to_string().contains("500"));
        mock.assert();
    });
}
1826
/// A 200 from OpenRouter whose body is not JSON surfaces as
/// `LLMError::InvalidResponse`.
#[test]
fn test_openrouter_chat_invalid_json() {
    with_mockito(|mut server| async move {
        let mock = server
            .mock("POST", "/v1/chat/completions")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body("not-json")
            .create();

        let config = LLMConfig {
            provider: LLMProvider::OpenRouter,
            endpoint: server.url(),
            model: "anthropic/claude-sonnet-4-20250514".to_string(),
            api_key: Some("or-key".to_string()),
            timeout_secs: 30,
            system_prompt: crate::config::default_system_prompt(),
            token_budget: None,
            retry_max: 0,
            retry_base_delay_ms: 100,
            retry_max_delay_ms: 10000,
        };

        let client = OpenAICompatibleClient::new(&config, OpenAICompatibleProvider::OpenRouter)
            .expect("client should build from a valid config");
        let err = client
            .chat(make_chat_messages())
            .await
            .expect_err("an unparsable body must not produce a ChatResponse");

        assert!(matches!(err, LLMError::InvalidResponse(_)), "unexpected error: {:?}", err);
        mock.assert();
    });
}
1858
/// Happy path for the OpenAI provider, pointed at a mock server: a 200 completion
/// body parses into a `ChatResponse` with the expected model and reply.
#[test]
fn test_openai_chat_success() {
    with_mockito(|mut server| async move {
        let config = LLMConfig {
            provider: LLMProvider::OpenAI,
            endpoint: server.url(),
            model: "gpt-4o".to_string(),
            api_key: Some("sk-test".to_string()),
            timeout_secs: 60,
            system_prompt: crate::config::default_system_prompt(),
            token_budget: None,
            retry_max: 3,
            retry_base_delay_ms: 100,
            retry_max_delay_ms: 10000,
        };

        let completions = server
            .mock("POST", "/v1/chat/completions")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(sample_chat_response_json("gpt-4o"))
            .create();

        let client = OpenAICompatibleClient::new(&config, OpenAICompatibleProvider::OpenAI)
            .expect("client should build from a valid config");
        let response = client
            .chat(make_chat_messages())
            .await
            .expect("a 200 response should parse into a ChatResponse");

        let first = &response.choices[0];
        assert_eq!(response.model, "gpt-4o");
        assert_eq!(first.message.content, "Hi there!");
        completions.assert();
    });
}
1893
/// A 401 from the OpenAI endpoint surfaces as `LLMError::AuthFailed`.
#[test]
fn test_openai_chat_auth_failure() {
    with_mockito(|mut server| async move {
        let config = LLMConfig {
            provider: LLMProvider::OpenAI,
            endpoint: server.url(),
            model: "gpt-4o".to_string(),
            api_key: Some("bad-key".to_string()),
            timeout_secs: 30,
            system_prompt: crate::config::default_system_prompt(),
            token_budget: None,
            retry_max: 3,
            retry_base_delay_ms: 100,
            retry_max_delay_ms: 10000,
        };

        let unauthorized = server
            .mock("POST", "/v1/chat/completions")
            .with_status(401)
            .with_header("content-type", "application/json")
            .with_body(r#"{"error": "Unauthorized"}"#)
            .create();

        let client = OpenAICompatibleClient::new(&config, OpenAICompatibleProvider::OpenAI)
            .expect("client should build from a valid config");
        let failure = client
            .chat(make_chat_messages())
            .await
            .expect_err("a 401 must not produce a ChatResponse");

        assert!(
            matches!(failure, LLMError::AuthFailed),
            "unexpected error: {:?}",
            failure
        );
        unauthorized.assert();
    });
}
1925
/// A 429 from the OpenAI endpoint surfaces as `LLMError::RateLimited`.
#[test]
fn test_openai_chat_rate_limit() {
    with_mockito(|mut server| async move {
        let mock = server
            .mock("POST", "/v1/chat/completions")
            .with_status(429)
            .with_header("content-type", "application/json")
            .with_body(r#"{"error": "Rate limited"}"#)
            .create();

        let config = LLMConfig {
            provider: LLMProvider::OpenAI,
            endpoint: server.url(),
            model: "gpt-4o".to_string(),
            api_key: Some("sk-test".to_string()),
            timeout_secs: 30,
            system_prompt: crate::config::default_system_prompt(),
            token_budget: None,
            // Presumably disables retries, so the mock sees exactly one request.
            retry_max: 0,
            retry_base_delay_ms: 100,
            retry_max_delay_ms: 10000,
        };

        let client = OpenAICompatibleClient::new(&config, OpenAICompatibleProvider::OpenAI)
            .expect("client should build from a valid config");
        let err = client
            .chat(make_chat_messages())
            .await
            .expect_err("a 429 must not produce a ChatResponse");

        assert!(matches!(err, LLMError::RateLimited), "unexpected error: {:?}", err);
        mock.assert();
    });
}
1957
/// A 500 from the OpenAI endpoint surfaces as `LLMError::RequestFailed`, and the
/// rendered error mentions the status code.
#[test]
fn test_openai_chat_server_error() {
    with_mockito(|mut server| async move {
        let mock = server
            .mock("POST", "/v1/chat/completions")
            .with_status(500)
            .with_header("content-type", "application/json")
            .with_body(r#"{"error": "Internal error"}"#)
            .create();

        let config = LLMConfig {
            provider: LLMProvider::OpenAI,
            endpoint: server.url(),
            model: "gpt-4o".to_string(),
            api_key: Some("sk-test".to_string()),
            timeout_secs: 30,
            system_prompt: crate::config::default_system_prompt(),
            token_budget: None,
            // Presumably disables retries, so the mock sees exactly one request.
            retry_max: 0,
            retry_base_delay_ms: 100,
            retry_max_delay_ms: 10000,
        };

        let client = OpenAICompatibleClient::new(&config, OpenAICompatibleProvider::OpenAI)
            .expect("client should build from a valid config");
        let err = client
            .chat(make_chat_messages())
            .await
            .expect_err("a 500 must not produce a ChatResponse");

        assert!(matches!(err, LLMError::RequestFailed(_)), "unexpected error: {:?}", err);
        assert!(err.to_string().contains("500"));
        mock.assert();
    });
}
1990
/// A 200 from the OpenAI endpoint whose body is not JSON surfaces as
/// `LLMError::InvalidResponse`.
#[test]
fn test_openai_chat_invalid_json() {
    with_mockito(|mut server| async move {
        let mock = server
            .mock("POST", "/v1/chat/completions")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body("not-json")
            .create();

        let config = LLMConfig {
            provider: LLMProvider::OpenAI,
            endpoint: server.url(),
            model: "gpt-4o".to_string(),
            api_key: Some("sk-test".to_string()),
            timeout_secs: 30,
            system_prompt: crate::config::default_system_prompt(),
            token_budget: None,
            retry_max: 0,
            retry_base_delay_ms: 100,
            retry_max_delay_ms: 10000,
        };

        let client = OpenAICompatibleClient::new(&config, OpenAICompatibleProvider::OpenAI)
            .expect("client should build from a valid config");
        let err = client
            .chat(make_chat_messages())
            .await
            .expect_err("an unparsable body must not produce a ChatResponse");

        assert!(matches!(err, LLMError::InvalidResponse(_)), "unexpected error: {:?}", err);
        mock.assert();
    });
}
2022
/// Happy path for Ollama's native `/api/chat` route: the Ollama-shaped reply is
/// mapped onto a `ChatResponse` with model, content and a `"stop"` finish reason.
#[test]
fn test_ollama_chat_success() {
    with_mockito(|mut server| async move {
        let config = LLMConfig {
            provider: LLMProvider::Ollama,
            endpoint: server.url(),
            model: "llama3.1".to_string(),
            api_key: None,
            timeout_secs: 30,
            system_prompt: crate::config::default_system_prompt(),
            token_budget: None,
            retry_max: 3,
            retry_base_delay_ms: 100,
            retry_max_delay_ms: 10000,
        };

        let chat_route = server
            .mock("POST", "/api/chat")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(sample_ollama_response_json("llama3.1"))
            .create();

        let client = OllamaClient::new(&config).expect("OllamaClient should build");
        let response = client
            .chat(make_chat_messages())
            .await
            .expect("a 200 Ollama reply should parse into a ChatResponse");

        let first = &response.choices[0];
        assert_eq!(response.model, "llama3.1");
        assert_eq!(first.message.content, "Hi there!");
        assert_eq!(first.finish_reason, Some("stop".to_string()));
        chat_route.assert();
    });
}
2057
/// A 500 from Ollama's `/api/chat` surfaces as `LLMError::RequestFailed`.
///
/// `retry_max` is 0, matching the OpenAI-compatible 5xx tests. This keeps the test
/// independent of whether `OllamaClient` retries server errors: with retries
/// enabled, a retrying client would hit the mock more than once (breaking
/// `mock.assert()`) and sleep through the backoff.
#[test]
fn test_ollama_chat_server_error() {
    with_mockito(|mut server| async move {
        let mock = server
            .mock("POST", "/api/chat")
            .with_status(500)
            .with_header("content-type", "application/json")
            .with_body(r#"{"error": "Model not loaded"}"#)
            .create();

        let config = LLMConfig {
            provider: LLMProvider::Ollama,
            endpoint: server.url(),
            model: "llama3.1".to_string(),
            api_key: None,
            timeout_secs: 30,
            system_prompt: crate::config::default_system_prompt(),
            token_budget: None,
            retry_max: 0,
            retry_base_delay_ms: 100,
            retry_max_delay_ms: 10000,
        };

        let client = OllamaClient::new(&config).expect("OllamaClient should build");
        let err = client
            .chat(make_chat_messages())
            .await
            .expect_err("a 500 must not produce a ChatResponse");

        assert!(matches!(err, LLMError::RequestFailed(_)), "unexpected error: {:?}", err);
        mock.assert();
    });
}
2088
/// A 200 from Ollama whose body is not JSON surfaces as `LLMError::InvalidResponse`.
///
/// `retry_max` is 0 for consistency with the other invalid-JSON tests, so the
/// single-hit `mock.assert()` does not depend on the client's retry policy.
#[test]
fn test_ollama_chat_invalid_json() {
    with_mockito(|mut server| async move {
        let mock = server
            .mock("POST", "/api/chat")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body("not-json")
            .create();

        let config = LLMConfig {
            provider: LLMProvider::Ollama,
            endpoint: server.url(),
            model: "llama3.1".to_string(),
            api_key: None,
            timeout_secs: 30,
            system_prompt: crate::config::default_system_prompt(),
            token_budget: None,
            retry_max: 0,
            retry_base_delay_ms: 100,
            retry_max_delay_ms: 10000,
        };

        let client = OllamaClient::new(&config).expect("OllamaClient should build");
        let err = client
            .chat(make_chat_messages())
            .await
            .expect_err("an unparsable body must not produce a ChatResponse");

        assert!(matches!(err, LLMError::InvalidResponse(_)), "unexpected error: {:?}", err);
        mock.assert();
    });
}
2119
/// A 401 from Ollama's `/api/chat` surfaces as `LLMError::AuthFailed`.
#[test]
fn test_ollama_chat_auth_failure() {
    with_mockito(|mut server| async move {
        let config = LLMConfig {
            provider: LLMProvider::Ollama,
            endpoint: server.url(),
            model: "llama3.1".to_string(),
            api_key: Some("bad-key".to_string()),
            timeout_secs: 30,
            system_prompt: crate::config::default_system_prompt(),
            token_budget: None,
            retry_max: 3,
            retry_base_delay_ms: 100,
            retry_max_delay_ms: 10000,
        };

        let unauthorized = server
            .mock("POST", "/api/chat")
            .with_status(401)
            .with_header("content-type", "application/json")
            .with_body(r#"{"error": "Unauthorized"}"#)
            .create();

        let client = OllamaClient::new(&config).expect("OllamaClient should build");
        let failure = client
            .chat(make_chat_messages())
            .await
            .expect_err("a 401 must not produce a ChatResponse");

        assert!(
            matches!(failure, LLMError::AuthFailed),
            "unexpected error: {:?}",
            failure
        );
        unauthorized.assert();
    });
}
2150
/// `create_client` turns a LiteLLM config into a client named `"litellm"` that keeps
/// the configured model.
#[test]
fn test_create_client_factory_litellm() {
    let litellm_config = LLMConfig {
        provider: LLMProvider::LiteLLM,
        endpoint: "http://localhost:4000".to_string(),
        model: "gpt-4o-mini".to_string(),
        api_key: Some("test".to_string()),
        timeout_secs: 30,
        system_prompt: crate::config::default_system_prompt(),
        token_budget: None,
        retry_max: 3,
        retry_base_delay_ms: 100,
        retry_max_delay_ms: 10000,
    };

    let built = create_client(&litellm_config).expect("create_client should accept a LiteLLM config");
    assert_eq!(built.provider_name(), "litellm");
    assert_eq!(built.model(), "gpt-4o-mini");
}
2172
/// `OllamaClient::new` works without an API key and exposes the provider name and
/// model.
#[test]
fn test_ollama_client_creation() {
    let ollama_config = LLMConfig {
        provider: LLMProvider::Ollama,
        endpoint: "http://localhost:11434".to_string(),
        model: "llama3.1".to_string(),
        api_key: None,
        timeout_secs: 30,
        system_prompt: crate::config::default_system_prompt(),
        token_budget: None,
        retry_max: 3,
        retry_base_delay_ms: 100,
        retry_max_delay_ms: 10000,
    };

    let ollama = OllamaClient::new(&ollama_config).expect("OllamaClient should build");
    assert_eq!(ollama.provider_name(), "ollama");
    assert_eq!(ollama.model(), "llama3.1");
}
2192
/// An OpenAI-provider client can be built with an empty endpoint and reports
/// `"openai"` plus the configured model.
#[test]
fn test_openai_client_creation() {
    let openai_config = LLMConfig {
        provider: LLMProvider::OpenAI,
        endpoint: String::new(),
        model: "gpt-4o".to_string(),
        api_key: Some("sk-test".to_string()),
        timeout_secs: 60,
        system_prompt: crate::config::default_system_prompt(),
        token_budget: None,
        retry_max: 3,
        retry_base_delay_ms: 100,
        retry_max_delay_ms: 10000,
    };

    let openai = OpenAICompatibleClient::new(&openai_config, OpenAICompatibleProvider::OpenAI)
        .expect("client should build from a valid config");
    assert_eq!(openai.provider_name(), "openai");
    assert_eq!(openai.model(), "gpt-4o");
}
2213
/// An OpenRouter-provider client can be built with an empty endpoint and reports
/// `"openrouter"` plus the namespaced model id.
#[test]
fn test_openrouter_client_creation() {
    let openrouter_config = LLMConfig {
        provider: LLMProvider::OpenRouter,
        endpoint: String::new(),
        model: "anthropic/claude-sonnet-4-20250514".to_string(),
        api_key: Some("sk-test".to_string()),
        timeout_secs: 30,
        system_prompt: crate::config::default_system_prompt(),
        token_budget: None,
        retry_max: 3,
        retry_base_delay_ms: 100,
        retry_max_delay_ms: 10000,
    };

    let openrouter =
        OpenAICompatibleClient::new(&openrouter_config, OpenAICompatibleProvider::OpenRouter)
            .expect("client should build from a valid config");
    assert_eq!(openrouter.provider_name(), "openrouter");
    assert_eq!(openrouter.model(), "anthropic/claude-sonnet-4-20250514");
}
2234
/// An empty config list yields a manager with no clients; index lookups return `None`.
#[test]
fn test_multi_model_manager_empty() {
    let manager = MultiModelManager::new(Vec::new()).expect("an empty config list is accepted");
    assert_eq!(manager.client_count(), 0);
    assert!(manager.get_client(0).is_none());
}
2241
/// A single LiteLLM config produces exactly one client, reachable at index 0.
#[test]
fn test_multi_model_manager_single() {
    let only_config = LLMConfig {
        provider: LLMProvider::LiteLLM,
        endpoint: "http://localhost:4000".to_string(),
        model: "gpt-4o-mini".to_string(),
        api_key: Some("test".to_string()),
        timeout_secs: 30,
        system_prompt: crate::config::default_system_prompt(),
        token_budget: None,
        retry_max: 3,
        retry_base_delay_ms: 100,
        retry_max_delay_ms: 10000,
    };

    let manager = MultiModelManager::new(vec![only_config]).expect("manager should build");
    assert_eq!(manager.client_count(), 1);

    // `expect` doubles as the "index 0 is present" check.
    let first = manager.get_client(0).expect("client 0 should exist");
    assert_eq!(first.provider_name(), "litellm");
}
2262
/// Several configs produce one client each, kept in configuration order.
#[test]
fn test_multi_model_manager_multiple() {
    // Local factory: only provider, endpoint, model, key and timeout vary.
    let entry = |provider, endpoint: &str, model: &str, api_key: Option<&str>, timeout_secs| {
        LLMConfig {
            provider,
            endpoint: endpoint.to_string(),
            model: model.to_string(),
            api_key: api_key.map(|key| key.to_string()),
            timeout_secs,
            system_prompt: crate::config::default_system_prompt(),
            token_budget: None,
            retry_max: 3,
            retry_base_delay_ms: 100,
            retry_max_delay_ms: 10000,
        }
    };

    let configs = vec![
        entry(LLMProvider::LiteLLM, "http://localhost:4000", "gpt-4o-mini", Some("test"), 30),
        entry(LLMProvider::Ollama, "http://localhost:11434", "llama3.1", None, 60),
    ];

    let manager = MultiModelManager::new(configs).expect("manager should build");
    assert_eq!(manager.client_count(), 2);
    assert_eq!(manager.get_client(0).expect("client 0").provider_name(), "litellm");
    assert_eq!(manager.get_client(1).expect("client 1").provider_name(), "ollama");
}
2297
/// `next_client(i)` returns the client after index `i`, wrapping from the last
/// client back to the first.
#[test]
fn test_multi_model_next_client_round_robin() {
    // Local factory: only provider, endpoint, model, key and timeout vary.
    let entry = |provider, endpoint: &str, model: &str, api_key: Option<&str>, timeout_secs| {
        LLMConfig {
            provider,
            endpoint: endpoint.to_string(),
            model: model.to_string(),
            api_key: api_key.map(|key| key.to_string()),
            timeout_secs,
            system_prompt: crate::config::default_system_prompt(),
            token_budget: None,
            retry_max: 3,
            retry_base_delay_ms: 100,
            retry_max_delay_ms: 10000,
        }
    };

    let configs = vec![
        entry(LLMProvider::LiteLLM, "http://localhost:4000", "gpt-4o-mini", Some("test"), 30),
        entry(LLMProvider::Ollama, "http://localhost:11434", "llama3.1", None, 60),
    ];
    let manager = MultiModelManager::new(configs).expect("manager should build");

    // After index 0 comes index 1 (Ollama).
    let after_first = manager.next_client(0).expect("successor of client 0");
    assert_eq!(after_first.provider_name(), "ollama");

    // After the last index the rotation wraps to index 0 (LiteLLM).
    let after_last = manager.next_client(1).expect("successor of client 1");
    assert_eq!(after_last.provider_name(), "litellm");
}
2335
/// Serializing a `ChatRequest` emits the model, roles, message text and temperature,
/// and omits the `stream` key when it is `None`.
#[test]
fn test_chat_request_serialization() {
    let messages: Vec<ChatMessage> = [
        ("system", "You are a helpful assistant."),
        ("user", "Hello!"),
    ]
    .iter()
    .map(|&(role, content)| ChatMessage {
        role: role.to_string(),
        content: content.to_string(),
    })
    .collect();

    let request = ChatRequest {
        model: "gpt-4o-mini".to_string(),
        messages,
        temperature: Some(0.7),
        max_tokens: Some(2048),
        stream: None,
        tools: None,
        tool_choice: None,
    };

    let json = serde_json::to_string(&request).expect("ChatRequest should serialize");
    for needle in ["gpt-4o-mini", "system", "user", "Hello!", "0.7"].iter() {
        assert!(json.contains(*needle), "missing {:?} in {}", needle, json);
    }
    // `stream: None` must be skipped entirely, not written as `null`.
    assert!(!json.contains("stream"));
}
2366
/// A standard `chat.completion` payload deserializes into `ChatResponse`,
/// including the nested choice message and the optional usage block.
#[test]
fn test_chat_response_deserialization() {
    let payload = r#"{
        "id": "chat-123",
        "object": "chat.completion",
        "created": 1717000000,
        "model": "gpt-4o-mini",
        "choices": [
            {
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": "Hello! How can I help you?"
                },
                "finish_reason": "stop"
            }
        ],
        "usage": {
            "prompt_tokens": 10,
            "completion_tokens": 20,
            "total_tokens": 30
        }
    }"#;

    let response: ChatResponse =
        serde_json::from_str(payload).expect("sample completion JSON should deserialize");

    assert_eq!(response.id, "chat-123");
    assert_eq!(response.model, "gpt-4o-mini");
    assert_eq!(response.choices.len(), 1);

    let message = &response.choices[0].message;
    assert_eq!(message.role, "assistant");
    assert_eq!(message.content, "Hello! How can I help you?");

    let usage = response.usage.expect("usage block should be present");
    assert_eq!(usage.total_tokens, 30);
}
2402
/// Despite its name, this test pins that `MultiModelManager::new` does NOT
/// reject a config whose `endpoint` is empty. Construction succeeds and the
/// manager holds exactly one client.
///
/// NOTE(review): whether an empty endpoint is later replaced by a provider
/// default cannot be established from this test. Confirm in `create_client`
/// or the client constructors before relying on it.
#[test]
fn test_multi_model_manager_new_invalid_config() {
    let configs = vec![LLMConfig {
        provider: LLMProvider::LiteLLM,
        endpoint: String::new(),
        model: "gpt-4o-mini".to_string(),
        api_key: None,
        timeout_secs: 30,
        system_prompt: crate::config::default_system_prompt(),
        token_budget: None,
        retry_max: 3,
        retry_base_delay_ms: 100,
        retry_max_delay_ms: 10000,
    }];

    // `expect` (instead of `assert!(is_ok())` followed by `unwrap()`) puts the
    // actual construction error into the failure message if this regresses.
    let manager = MultiModelManager::new(configs)
        .expect("an empty endpoint should not be rejected at construction time");
    assert_eq!(manager.client_count(), 1);
}
2426
/// `create_client` must turn each listed `LLMProvider` into a client that
/// reports the matching provider name and echoes the configured model.
#[test]
fn test_create_client_all_providers() {
    // A fixed-size array is sufficient for by-value iteration. `vec!` would
    // only add a heap allocation (clippy: `useless_vec`).
    let test_cases = [
        (LLMProvider::LiteLLM, "litellm"),
        (LLMProvider::OpenRouter, "openrouter"),
        (LLMProvider::Ollama, "ollama"),
        (LLMProvider::OpenAI, "openai"),
    ];

    for (provider, expected_name) in test_cases {
        let config = LLMConfig {
            provider,
            endpoint: "http://localhost:4000".to_string(),
            model: "test-model".to_string(),
            api_key: Some("test-key".to_string()),
            timeout_secs: 30,
            system_prompt: crate::config::default_system_prompt(),
            token_budget: None,
            retry_max: 3,
            retry_base_delay_ms: 100,
            retry_max_delay_ms: 10000,
        };

        // Name the failing case, since all cases share this loop body.
        let client = create_client(&config).unwrap_or_else(|err| {
            panic!("create_client failed for `{}`: {:?}", expected_name, err)
        });
        assert_eq!(client.provider_name(), expected_name);
        assert_eq!(client.model(), "test-model");
    }
}
2455
/// Each `LLMError` variant's `Display` output must match its `#[error(...)]`
/// message, with any payload interpolated.
#[test]
fn test_llm_error_display() {
    let cases = [
        (
            LLMError::RequestFailed("timeout".to_string()),
            "Request failed: timeout",
        ),
        (LLMError::AuthFailed, "Authentication failed"),
        (LLMError::RateLimited, "Rate limit exceeded"),
        (
            LLMError::InvalidResponse("bad json".to_string()),
            "Invalid response: bad json",
        ),
        (
            LLMError::ProviderNotSupported("custom".to_string()),
            "Provider not supported: custom",
        ),
    ];

    for (error, expected) in cases {
        assert_eq!(error.to_string(), expected);
    }
}
2473
/// `LLMError` implements `Debug`, and the derived output names the variant.
#[test]
fn test_llm_error_is_debug() {
    let rendered = format!("{:?}", LLMError::RequestFailed("test".to_string()));
    assert!(rendered.contains("RequestFailed"));
}
2480
/// Compile-time assertion that `LLMError` is `Send`. The test body has no
/// runtime effect; it fails to build if the bound stops holding.
#[test]
fn test_llm_error_is_send() {
    fn assert_send<E: Send>() {}
    assert_send::<LLMError>();
}
2486
/// Compile-time assertion that `LLMError` is `Sync`. The test body has no
/// runtime effect; it fails to build if the bound stops holding.
#[test]
fn test_llm_error_is_sync() {
    fn assert_sync<E: Sync>() {}
    assert_sync::<LLMError>();
}
2492
/// Pins that constructing a LiteLLM client with the largest representable
/// `timeout_secs` (`u64::MAX`) succeeds. Despite the test name, such a timeout
/// is accepted by the constructor, not rejected.
///
/// NOTE(review): this only covers construction. Behaviour when a request is
/// actually sent with this timeout is not exercised here.
#[test]
fn test_litellm_client_new_with_invalid_timeout() {
    let config = LLMConfig {
        provider: LLMProvider::LiteLLM,
        endpoint: "http://localhost:4000".to_string(),
        model: "gpt-4o-mini".to_string(),
        api_key: Some("test".to_string()),
        timeout_secs: u64::MAX,
        system_prompt: crate::config::default_system_prompt(),
        token_budget: None,
        retry_max: 3,
        retry_base_delay_ms: 100,
        retry_max_delay_ms: 10000,
    };

    // `expect` reports the constructor's error if this ever regresses,
    // which a bare `assert!(result.is_ok())` would hide.
    let _client = OpenAICompatibleClient::new(&config, OpenAICompatibleProvider::LiteLLM)
        .expect("a u64::MAX timeout_secs should be accepted at construction time");
}
2512
/// With a non-default OpenAI endpoint configured, the client still reports
/// the `openai` provider name and the configured model.
#[test]
fn test_openai_client_with_custom_endpoint() {
    let cfg = LLMConfig {
        provider: LLMProvider::OpenAI,
        model: "gpt-4o".to_string(),
        endpoint: "https://custom.openai.example.com".to_string(),
        api_key: Some("sk-test".to_string()),
        timeout_secs: 30,
        system_prompt: crate::config::default_system_prompt(),
        token_budget: None,
        retry_max: 3,
        retry_base_delay_ms: 100,
        retry_max_delay_ms: 10000,
    };

    let openai = OpenAICompatibleClient::new(&cfg, OpenAICompatibleProvider::OpenAI)
        .unwrap();
    assert_eq!(openai.provider_name(), "openai");
    assert_eq!(openai.model(), "gpt-4o");
}
2533
/// With a non-default OpenRouter endpoint and a slash-qualified model id, the
/// client reports the `openrouter` provider name and the model id unchanged.
#[test]
fn test_openrouter_client_with_custom_endpoint() {
    let model_id = "anthropic/claude-sonnet-4-20250514";
    let cfg = LLMConfig {
        provider: LLMProvider::OpenRouter,
        model: model_id.to_string(),
        endpoint: "https://custom.openrouter.example.com".to_string(),
        api_key: Some("or-key".to_string()),
        timeout_secs: 30,
        system_prompt: crate::config::default_system_prompt(),
        token_budget: None,
        retry_max: 3,
        retry_base_delay_ms: 100,
        retry_max_delay_ms: 10000,
    };

    let router = OpenAICompatibleClient::new(&cfg, OpenAICompatibleProvider::OpenRouter)
        .unwrap();
    assert_eq!(router.provider_name(), "openrouter");
    assert_eq!(router.model(), model_id);
}
2554
/// Supplying an `api_key` to an Ollama config does not prevent client
/// construction. The client still reports `ollama` and the configured model.
#[test]
fn test_ollama_client_with_auth() {
    let cfg = LLMConfig {
        provider: LLMProvider::Ollama,
        model: "llama3.1".to_string(),
        endpoint: "http://localhost:11434".to_string(),
        api_key: Some("some-key".to_string()),
        timeout_secs: 30,
        system_prompt: crate::config::default_system_prompt(),
        token_budget: None,
        retry_max: 3,
        retry_base_delay_ms: 100,
        retry_max_delay_ms: 10000,
    };

    let ollama = OllamaClient::new(&cfg).unwrap();
    assert_eq!(ollama.provider_name(), "ollama");
    assert_eq!(ollama.model(), "llama3.1");
}
2575
/// A minimal `ChatRequest` (every optional field `None`) serializes the model
/// and the user text. The words `temperature`, `max_tokens` and `stream`
/// must not appear in the output.
#[test]
fn test_chat_request_no_temperature() {
    let messages = vec![ChatMessage {
        role: "user".to_string(),
        content: "Hello!".to_string(),
    }];
    let request = ChatRequest {
        model: "gpt-4o-mini".to_string(),
        messages,
        temperature: None,
        max_tokens: None,
        stream: None,
        tools: None,
        tool_choice: None,
    };

    let body = serde_json::to_string(&request).unwrap();

    for present in ["gpt-4o-mini", "Hello!"] {
        assert!(body.contains(present));
    }
    for absent in ["temperature", "max_tokens", "stream"] {
        assert!(!body.contains(absent));
    }
}
2599
/// A chat-completion payload that omits the `usage` object still
/// deserializes, and `usage` comes back as `None`.
#[test]
fn test_chat_response_deserialization_no_usage() {
    let raw = r#"{
        "id": "chat-456",
        "object": "chat.completion",
        "created": 1717000001,
        "model": "gpt-4o",
        "choices": [
            {
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": "Sure!"
                },
                "finish_reason": "stop"
            }
        ]
    }"#;

    let parsed: ChatResponse = serde_json::from_str(raw).unwrap();
    assert_eq!(parsed.id, "chat-456");
    assert_eq!(parsed.model, "gpt-4o");
    assert_eq!(parsed.choices.len(), 1);
    assert!(parsed.usage.is_none());
}
2625
/// A payload with two choices deserializes both, in order, with their
/// message contents intact.
#[test]
fn test_chat_response_deserialization_multiple_choices() {
    let raw = r#"{
        "id": "chat-789",
        "object": "chat.completion",
        "created": 1717000002,
        "model": "gpt-4o",
        "choices": [
            {
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": "First choice"
                },
                "finish_reason": "stop"
            },
            {
                "index": 1,
                "message": {
                    "role": "assistant",
                    "content": "Second choice"
                },
                "finish_reason": "stop"
            }
        ]
    }"#;

    let parsed: ChatResponse = serde_json::from_str(raw).unwrap();
    assert_eq!(parsed.choices.len(), 2);

    let expected = ["First choice", "Second choice"];
    for (choice, want) in parsed.choices.iter().zip(expected) {
        assert_eq!(choice.message.content, want);
    }
}
2658
/// `LLMError` coerces into `Box<dyn std::error::Error>`, and the boxed value
/// keeps the variant's `Display` message.
#[test]
fn test_llm_error_into_boxed() {
    let boxed: Box<dyn std::error::Error> = Box::new(LLMError::AuthFailed);
    let text = boxed.to_string();
    assert!(text.contains("Authentication failed"));
}
2665
/// `to_string()` on `LLMError::RateLimited` produces its exact display text.
#[test]
fn test_llm_error_into_string() {
    let expected = "Rate limit exceeded";
    let rendered: String = LLMError::RateLimited.to_string();
    assert_eq!(rendered, expected);
}
2672
/// `create_client` builds an Ollama client when no API key is configured
/// (`api_key: None`) and reports the `ollama` provider name.
#[test]
fn test_create_client_with_empty_api_key() {
    let cfg = LLMConfig {
        provider: LLMProvider::Ollama,
        api_key: None,
        endpoint: "http://localhost:11434".to_string(),
        model: "llama3.1".to_string(),
        timeout_secs: 30,
        system_prompt: crate::config::default_system_prompt(),
        token_budget: None,
        retry_max: 3,
        retry_base_delay_ms: 100,
        retry_max_delay_ms: 10000,
    };

    let built = create_client(&cfg).unwrap();
    assert_eq!(built.provider_name(), "ollama");
}
2692
/// On a manager with no clients, `get_client` returns `None` for every index,
/// including `usize::MAX`.
#[test]
fn test_multi_model_manager_get_client_out_of_bounds() {
    let manager = MultiModelManager::new(Vec::new()).unwrap();
    for idx in [0, 100, usize::MAX] {
        assert!(manager.get_client(idx).is_none());
    }
}
2700
/// On a manager with no clients, `next_client` has nothing to return.
#[test]
fn test_multi_model_next_client_empty() {
    let empty = MultiModelManager::new(Vec::new()).unwrap();
    let next = empty.next_client(0);
    assert!(next.is_none());
}
2706
/// With exactly one LiteLLM client, `next_client(0)` returns that same client.
/// Presumably the index wraps around the client list; verify against the
/// `next_client` implementation.
#[test]
fn test_multi_model_next_client_single() {
    let only = LLMConfig {
        provider: LLMProvider::LiteLLM,
        model: "gpt-4o-mini".to_string(),
        endpoint: "http://localhost:4000".to_string(),
        api_key: Some("test".to_string()),
        timeout_secs: 30,
        system_prompt: crate::config::default_system_prompt(),
        token_budget: None,
        retry_max: 3,
        retry_base_delay_ms: 100,
        retry_max_delay_ms: 10000,
    };

    let manager = MultiModelManager::new(vec![only]).unwrap();
    let chosen = manager.next_client(0).unwrap();
    assert_eq!(chosen.provider_name(), "litellm");
}
2727}