1use futures::Stream;
7use reqwest::{Client, Response};
8use serde::{Deserialize, Serialize};
9use std::pin::Pin;
10use std::sync::Arc;
11use std::time::Duration;
12use thiserror::Error;
13use tokio::time::sleep;
14use tracing::instrument;
15
/// One incremental piece of a streamed completion.
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub struct StreamChunk {
    /// Text carried by this piece of the stream (may be empty).
    pub content: String,
    /// Finish reason reported alongside this piece, if any.
    #[allow(dead_code)]
    pub finish_reason: Option<String>,
}
24
25#[allow(dead_code)]
27pub type StreamResult = Pin<Box<dyn Stream<Item = Result<StreamChunk, LLMError>> + Send>>;
28
29use crate::config::{LLMConfig, LLMProvider};
30
/// Backends that are reached through the OpenAI-style chat-completions protocol.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum OpenAICompatibleProvider {
    /// LiteLLM proxy.
    LiteLLM,
    /// OpenAI itself.
    OpenAI,
    /// OpenRouter.
    OpenRouter,
    /// Any other server speaking the same protocol.
    Generic,
}

impl OpenAICompatibleProvider {
    /// Base URL associated with this provider.
    ///
    /// NOTE(review): the OpenRouter value has no `/api` segment; confirm that
    /// the path appended by callers resolves to the intended OpenRouter route.
    pub fn default_endpoint(&self) -> &'static str {
        match *self {
            Self::LiteLLM => "http://localhost:4000",
            Self::OpenAI => "https://api.openai.com",
            Self::OpenRouter => "https://openrouter.ai",
            Self::Generic => "http://localhost:8000",
        }
    }

    /// Short lowercase identifier for this provider.
    pub fn name(&self) -> &'static str {
        match *self {
            Self::LiteLLM => "litellm",
            Self::OpenAI => "openai",
            Self::OpenRouter => "openrouter",
            Self::Generic => "openai-compatible",
        }
    }

    /// Returns `true` only for [`OpenAICompatibleProvider::OpenRouter`].
    pub fn requires_custom_headers(&self) -> bool {
        *self == Self::OpenRouter
    }
}
71
72#[derive(Error, Debug)]
78#[non_exhaustive]
79pub enum LLMError {
80 #[error("Request failed: {0}")]
81 RequestFailed(String),
82
83 #[error("Invalid response: {0}")]
84 InvalidResponse(String),
85
86 #[error("Rate limit exceeded")]
87 RateLimited,
88
89 #[error("Authentication failed")]
90 AuthFailed,
91
92 #[error("Provider not supported: {0}")]
93 #[allow(dead_code)]
94 ProviderNotSupported(String),
95
96 #[error("Token budget exceeded")]
97 TokenBudgetExceeded,
98
99 #[error("All providers failed after retries")]
100 AllProvidersFailed,
101
102 #[error("Circuit breaker open for provider: {0}")]
103 CircuitBreakerOpen(String),
104}
105
/// Tunables for exponential backoff between retries.
#[derive(Debug, Clone)]
pub struct RetryConfig {
    /// Number of retries allowed after the first attempt.
    pub max_retries: u32,
    /// Delay for the first retry, in milliseconds.
    pub base_delay_ms: u64,
    /// Upper bound applied to the exponential delay, in milliseconds.
    pub max_delay_ms: u64,
    /// Jitter expressed as a fraction of the capped delay.
    pub jitter: f64,
}

impl Default for RetryConfig {
    /// Three retries, 100 ms base delay, 10 s cap, 50 % jitter.
    fn default() -> Self {
        RetryConfig {
            jitter: 0.5,
            max_delay_ms: 10000,
            base_delay_ms: 100,
            max_retries: 3,
        }
    }
}
129
130impl RetryConfig {
131 pub fn delay_for_attempt(&self, attempt: u32) -> Duration {
133 use rand::Rng;
134 let exp = 2u64.pow(attempt);
135 let base = self.base_delay_ms * exp;
136 let capped = base.min(self.max_delay_ms);
137 let jitter_range = (capped as f64) * self.jitter;
138 let jitter = rand::thread_rng().gen_range(-jitter_range..=jitter_range) as u64;
139 let delay = capped.saturating_add(jitter).max(self.base_delay_ms);
140 Duration::from_millis(delay)
141 }
142}
143
/// Operating state of a [`CircuitBreaker`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum CircuitState {
    /// Requests are allowed.
    Closed,
    /// Requests are refused until the open duration has elapsed.
    Open,
    /// The open duration has elapsed and requests are allowed again.
    HalfOpen,
}

/// Failure counter that stops requests after repeated failures.
#[derive(Debug)]
pub struct CircuitBreaker {
    /// Current state.
    pub state: CircuitState,
    /// Failures recorded since the last success.
    pub failure_count: u32,
    /// When the most recent failure was recorded.
    pub last_failure_time: Option<std::time::Instant>,
    /// How long the breaker refuses requests after opening.
    pub open_duration: Duration,
}

impl CircuitBreaker {
    /// Creates a closed breaker that stays open for `open_duration_secs` once tripped.
    pub fn new(open_duration_secs: u64) -> Self {
        let open_duration = Duration::from_secs(open_duration_secs);
        CircuitBreaker {
            state: CircuitState::Closed,
            failure_count: 0,
            last_failure_time: None,
            open_duration,
        }
    }

    /// Clears the failure counter and closes the breaker.
    pub fn record_success(&mut self) {
        self.state = CircuitState::Closed;
        self.failure_count = 0;
    }

    /// Records one failure; the breaker opens once five have accumulated.
    pub fn record_failure(&mut self) {
        self.failure_count += 1;
        self.last_failure_time = Some(std::time::Instant::now());
        let tripped = self.failure_count >= 5;
        if tripped {
            self.state = CircuitState::Open;
        }
    }

    /// Reports whether a request may be sent right now.
    ///
    /// An open breaker whose last failure is at least `open_duration` old is
    /// moved to [`CircuitState::HalfOpen`] and allows the request.
    pub fn can_execute(&mut self) -> bool {
        if self.state != CircuitState::Open {
            return true;
        }
        let cooled_down = match self.last_failure_time {
            Some(at) => at.elapsed() >= self.open_duration,
            None => false,
        };
        if cooled_down {
            self.state = CircuitState::HalfOpen;
        }
        cooled_down
    }
}
204
/// Running tally of tokens spent against a fixed allowance.
#[derive(Debug, Clone)]
pub struct TokenBudget {
    /// Total tokens allowed.
    pub max_tokens: u32,
    /// Tokens recorded so far.
    pub used_tokens: u32,
    /// Price per 1000 tokens, used by [`TokenBudget::estimated_cost`].
    pub cost_per_1k: f64,
}

#[allow(dead_code)]
impl TokenBudget {
    /// Creates a budget with nothing spent yet.
    pub fn new(max_tokens: u32, cost_per_1k: f64) -> Self {
        TokenBudget {
            used_tokens: 0,
            max_tokens,
            cost_per_1k,
        }
    }

    /// Tokens still available; zero once usage reaches or exceeds the maximum.
    pub fn remaining(&self) -> u32 {
        if self.used_tokens >= self.max_tokens {
            0
        } else {
            self.max_tokens - self.used_tokens
        }
    }

    /// Whether `tokens` fits into what remains.
    pub fn can_spend(&self, tokens: u32) -> bool {
        tokens <= self.remaining()
    }

    /// Adds `tokens` to the tally, saturating at `u32::MAX`.
    pub fn record_usage(&mut self, tokens: u32) {
        let updated = self.used_tokens.saturating_add(tokens);
        self.used_tokens = updated;
    }

    /// Cost of the tokens used so far at `cost_per_1k` per thousand tokens.
    pub fn estimated_cost(&self) -> f64 {
        let thousands = f64::from(self.used_tokens) / 1000.0;
        thousands * self.cost_per_1k
    }
}
242
243#[derive(Debug, Clone, Serialize, Deserialize)]
244pub struct ChatMessage {
245 pub role: String,
246 pub content: String,
247}
248
249#[derive(Debug, Clone, Serialize)]
250pub struct ChatRequest {
251 pub model: String,
252 pub messages: Vec<ChatMessage>,
253 #[serde(skip_serializing_if = "Option::is_none")]
254 pub temperature: Option<f32>,
255 #[serde(skip_serializing_if = "Option::is_none")]
256 pub max_tokens: Option<u32>,
257 #[serde(skip_serializing_if = "Option::is_none")]
258 pub stream: Option<bool>,
259 #[serde(skip_serializing_if = "Option::is_none")]
260 pub tools: Option<Vec<serde_json::Value>>,
261 #[serde(skip_serializing_if = "Option::is_none")]
262 pub tool_choice: Option<String>,
263}
264
265#[derive(Debug, Clone, Deserialize)]
266pub struct ChatResponse {
267 #[allow(dead_code)]
268 pub id: String,
269 #[allow(dead_code)]
270 pub object: String,
271 #[allow(dead_code)]
272 pub created: u64,
273 #[allow(dead_code)]
274 pub model: String,
275 pub choices: Vec<Choice>,
276 #[allow(dead_code)]
277 pub usage: Option<Usage>,
278}
279
280#[derive(Debug, Clone, Deserialize)]
281pub struct ToolCallResponse {
282 #[allow(dead_code)]
283 pub id: String,
284 #[allow(dead_code)]
285 #[serde(rename = "type")]
286 pub call_type: String,
287 pub function: FunctionCall,
288}
289
290#[derive(Debug, Clone, Deserialize)]
291pub struct FunctionCall {
292 pub name: String,
293 pub arguments: String,
294}
295
296#[derive(Debug, Clone, Deserialize)]
297pub struct Choice {
298 #[allow(dead_code)]
299 pub index: u32,
300 pub message: ChatMessage,
301 #[allow(dead_code)]
302 pub finish_reason: Option<String>,
303 #[serde(default, skip_serializing_if = "Option::is_none")]
304 pub tool_calls: Option<Vec<ToolCallResponse>>,
305}
306
307#[derive(Debug, Clone, Deserialize)]
308pub struct Usage {
309 #[allow(dead_code)]
310 pub prompt_tokens: u32,
311 #[allow(dead_code)]
312 pub completion_tokens: u32,
313 #[allow(dead_code)]
314 pub total_tokens: u32,
315}
316
317#[async_trait::async_trait]
319pub trait LLMProviderTrait: Send + Sync {
320 async fn chat(&self, messages: Vec<ChatMessage>) -> Result<ChatResponse, LLMError>;
321 #[allow(dead_code)]
322 async fn chat_stream(&self, messages: Vec<ChatMessage>) -> Result<StreamResult, LLMError> {
323 let response = self.chat(messages).await?;
325 let content = response
326 .choices
327 .first()
328 .map(|c| c.message.content.clone())
329 .unwrap_or_default();
330 let finish_reason = response
331 .choices
332 .first()
333 .and_then(|c| c.finish_reason.clone());
334
335 let stream = futures::stream::once(async move {
336 Ok(StreamChunk {
337 content,
338 finish_reason,
339 })
340 });
341 Ok(Box::pin(stream))
342 }
343 fn provider_name(&self) -> &str;
344 fn model(&self) -> &str;
345}
346
347async fn handle_openai_response(response: Response) -> Result<ChatResponse, LLMError> {
349 let status = response.status();
350
351 if status.is_success() {
352 response
353 .json::<ChatResponse>()
354 .await
355 .map_err(|e| LLMError::InvalidResponse(e.to_string()))
356 } else if status == reqwest::StatusCode::UNAUTHORIZED {
357 Err(LLMError::AuthFailed)
358 } else if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
359 Err(LLMError::RateLimited)
360 } else {
361 let body = response
362 .text()
363 .await
364 .unwrap_or_else(|_| "Unknown error".to_string());
365 Err(LLMError::RequestFailed(format!("{}: {}", status, body)))
366 }
367}
368
369pub struct OpenAICompatibleClient {
372 client: Client,
373 config: LLMConfig,
374 provider: OpenAICompatibleProvider,
375 retry_config: RetryConfig,
376 circuit_breaker: std::sync::Mutex<CircuitBreaker>,
377}
378
379impl OpenAICompatibleClient {
380 pub fn new(config: &LLMConfig, provider: OpenAICompatibleProvider) -> Result<Self, LLMError> {
381 let client = Client::builder()
382 .timeout(std::time::Duration::from_secs(config.timeout_secs))
383 .build()
384 .map_err(|e| LLMError::RequestFailed(format!("Failed to create HTTP client: {}", e)))?;
385
386 let retry_config = RetryConfig {
387 max_retries: config.retry_max,
388 base_delay_ms: config.retry_base_delay_ms,
389 max_delay_ms: config.retry_max_delay_ms,
390 jitter: 0.5,
391 };
392
393 Ok(Self {
394 client,
395 config: config.clone(),
396 provider,
397 retry_config,
398 circuit_breaker: std::sync::Mutex::new(CircuitBreaker::new(30)),
399 })
400 }
401
402 async fn send_request_with_retry(
404 &self,
405 request: ChatRequest,
406 ) -> Result<ChatResponse, LLMError> {
407 let mut last_error = None;
408
409 for attempt in 0..=self.retry_config.max_retries {
410 {
412 let mut cb = self.circuit_breaker.lock().map_err(|_| {
413 LLMError::RequestFailed("Circuit breaker lock poisoned".to_string())
414 })?;
415 if !cb.can_execute() {
416 return Err(LLMError::CircuitBreakerOpen(
417 self.provider.name().to_string(),
418 ));
419 }
420 }
421
422 let result = self.send_request_inner(request.clone()).await;
423
424 match result {
425 Ok(response) => {
426 {
428 let mut cb = self.circuit_breaker.lock().map_err(|_| {
429 LLMError::RequestFailed("Circuit breaker lock poisoned".to_string())
430 })?;
431 cb.record_success();
432 }
433 return Ok(response);
434 }
435 Err(e) => {
436 {
438 let mut cb = self.circuit_breaker.lock().map_err(|_| {
439 LLMError::RequestFailed("Circuit breaker lock poisoned".to_string())
440 })?;
441 cb.record_failure();
442 }
443
444 last_error = Some(e);
445
446 if matches!(last_error, Some(LLMError::AuthFailed)) {
448 return Err(last_error.unwrap());
449 }
450
451 if attempt < self.retry_config.max_retries {
453 let delay = self.retry_config.delay_for_attempt(attempt);
454 sleep(delay).await;
455 }
456 }
457 }
458 }
459
460 Err(last_error.unwrap_or(LLMError::AllProvidersFailed))
461 }
462
463 async fn send_request_inner(&self, request: ChatRequest) -> Result<ChatResponse, LLMError> {
465 let req = self.apply_headers(self.client.post(self.endpoint()).json(&request));
466
467 let response = req
468 .send()
469 .await
470 .map_err(|e| LLMError::RequestFailed(e.to_string()))?;
471
472 handle_openai_response(response).await
473 }
474
475 fn build_request(&self, messages: Vec<ChatMessage>) -> ChatRequest {
476 ChatRequest {
477 model: self.config.model.clone(),
478 messages,
479 temperature: Some(0.7),
480 max_tokens: Some(2048),
481 stream: None,
482 tools: None,
483 tool_choice: None,
484 }
485 }
486
487 fn endpoint(&self) -> String {
488 let base = if self.config.endpoint.is_empty() {
489 self.provider.default_endpoint()
490 } else {
491 &self.config.endpoint
492 };
493 format!("{}/v1/chat/completions", base.trim_end_matches('/'))
494 }
495
496 fn apply_headers(&self, mut req: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
497 if let Some(ref key) = self.config.api_key {
498 req = req.header("Authorization", format!("Bearer {}", key));
499 }
500
501 if self.provider.requires_custom_headers() {
503 req = req
504 .header("HTTP-Referer", "https://github.com/egkristi/RavenClaws")
505 .header("X-Title", "RavenClaws");
506 }
507
508 req
509 }
510
511 #[allow(dead_code)]
512 async fn send_request(&self, request: ChatRequest) -> Result<ChatResponse, LLMError> {
513 let req = self.apply_headers(self.client.post(self.endpoint()).json(&request));
514
515 let response = req
516 .send()
517 .await
518 .map_err(|e| LLMError::RequestFailed(e.to_string()))?;
519
520 handle_openai_response(response).await
521 }
522}
523
524#[async_trait::async_trait]
525impl LLMProviderTrait for OpenAICompatibleClient {
526 #[instrument(skip(self, messages), fields(provider = self.provider_name(), model = self.model()))]
527 async fn chat(&self, messages: Vec<ChatMessage>) -> Result<ChatResponse, LLMError> {
528 let request = self.build_request(messages);
529 self.send_request_with_retry(request).await
530 }
531
532 async fn chat_stream(&self, messages: Vec<ChatMessage>) -> Result<StreamResult, LLMError> {
533 let request = ChatRequest {
534 model: self.config.model.clone(),
535 messages,
536 temperature: Some(0.7),
537 max_tokens: Some(2048),
538 stream: Some(true),
539 tools: None,
540 tool_choice: None,
541 };
542
543 let req = self.apply_headers(self.client.post(self.endpoint()).json(&request));
544
545 let response = req
546 .send()
547 .await
548 .map_err(|e| LLMError::RequestFailed(e.to_string()))?;
549
550 let status = response.status();
551 if !status.is_success() {
552 if status == reqwest::StatusCode::UNAUTHORIZED {
553 return Err(LLMError::AuthFailed);
554 } else if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
555 return Err(LLMError::RateLimited);
556 } else {
557 let body = response
558 .text()
559 .await
560 .unwrap_or_else(|_| "Unknown error".to_string());
561 return Err(LLMError::RequestFailed(format!("{}: {}", status, body)));
562 }
563 }
564
565 use futures::StreamExt;
567 let stream = response
568 .bytes_stream()
569 .filter_map(|chunk_result| async move {
570 match chunk_result {
571 Err(e) => Some(Err(LLMError::RequestFailed(e.to_string()))),
572 Ok(bytes) => {
573 let text = String::from_utf8_lossy(&bytes);
574 let mut content = String::new();
575 let mut finish_reason = None;
576
577 for line in text.lines() {
578 if let Some(data) = line.strip_prefix("data: ") {
579 if data == "[DONE]" {
580 finish_reason = Some("stop".to_string());
581 continue;
582 }
583 if let Ok(sse_chunk) =
584 serde_json::from_str::<serde_json::Value>(data)
585 {
586 if let Some(choice) =
587 sse_chunk["choices"].as_array().and_then(|c| c.first())
588 {
589 if let Some(delta) = choice["delta"].as_object() {
590 if let Some(c) = delta["content"].as_str() {
591 content.push_str(c);
592 }
593 }
594 if let Some(reason) = choice["finish_reason"].as_str() {
595 if reason != "null" {
596 finish_reason = Some(reason.to_string());
597 }
598 }
599 }
600 }
601 }
602 }
603
604 if content.is_empty() && finish_reason.is_none() {
605 None
606 } else {
607 Some(Ok(StreamChunk {
608 content,
609 finish_reason,
610 }))
611 }
612 }
613 }
614 });
615
616 Ok(Box::pin(stream))
617 }
618
619 fn provider_name(&self) -> &str {
620 self.provider.name()
621 }
622
623 fn model(&self) -> &str {
624 &self.config.model
625 }
626}
627
628pub struct OllamaClient {
630 client: Client,
631 config: LLMConfig,
632}
633
634impl OllamaClient {
635 pub fn new(config: &LLMConfig) -> Result<Self, LLMError> {
636 let client = Client::builder()
637 .timeout(std::time::Duration::from_secs(config.timeout_secs))
638 .build()
639 .map_err(|e| LLMError::RequestFailed(format!("Failed to create HTTP client: {}", e)))?;
640
641 Ok(Self {
642 client,
643 config: config.clone(),
644 })
645 }
646}
647
648#[async_trait::async_trait]
649impl LLMProviderTrait for OllamaClient {
650 #[instrument(skip(self, messages), fields(provider = self.provider_name(), model = self.model()))]
651 async fn chat(&self, messages: Vec<ChatMessage>) -> Result<ChatResponse, LLMError> {
652 #[derive(Serialize)]
654 struct OllamaRequest {
655 model: String,
656 messages: Vec<ChatMessage>,
657 stream: bool,
658 }
659
660 let request = OllamaRequest {
661 model: self.config.model.clone(),
662 messages,
663 stream: false,
664 };
665
666 let response = self
667 .client
668 .post(format!(
669 "{}/api/chat",
670 self.config.endpoint.trim_end_matches('/')
671 ))
672 .json(&request)
673 .send()
674 .await
675 .map_err(|e| LLMError::RequestFailed(e.to_string()))?;
676
677 let status = response.status();
678
679 if status.is_success() {
680 #[derive(Deserialize)]
682 struct OllamaResponse {
683 model: String,
684 message: ChatMessage,
685 done: bool,
686 }
687
688 let ollama_resp = response
689 .json::<OllamaResponse>()
690 .await
691 .map_err(|e| LLMError::InvalidResponse(e.to_string()))?;
692
693 Ok(ChatResponse {
694 id: format!("ollama-{}", uuid::Uuid::new_v4()),
695 object: "chat.completion".to_string(),
696 created: std::time::SystemTime::now()
697 .duration_since(std::time::UNIX_EPOCH)
698 .unwrap()
699 .as_secs(),
700 model: ollama_resp.model,
701 choices: vec![Choice {
702 index: 0,
703 message: ollama_resp.message,
704 finish_reason: if ollama_resp.done {
705 Some("stop".to_string())
706 } else {
707 None
708 },
709 tool_calls: None,
710 }],
711 usage: None, })
713 } else if status == reqwest::StatusCode::UNAUTHORIZED {
714 Err(LLMError::AuthFailed)
715 } else {
716 let body = response
717 .text()
718 .await
719 .unwrap_or_else(|_| "Unknown error".to_string());
720 Err(LLMError::RequestFailed(format!("{}: {}", status, body)))
721 }
722 }
723
724 fn provider_name(&self) -> &str {
725 "ollama"
726 }
727
728 fn model(&self) -> &str {
729 &self.config.model
730 }
731}
732
733pub struct AnthropicClient {
735 client: Client,
736 config: LLMConfig,
737}
738
739impl AnthropicClient {
740 pub fn new(config: &LLMConfig) -> Result<Self, LLMError> {
741 let client = Client::builder()
742 .timeout(std::time::Duration::from_secs(config.timeout_secs))
743 .build()
744 .map_err(|e| LLMError::RequestFailed(format!("Failed to create HTTP client: {}", e)))?;
745
746 Ok(Self {
747 client,
748 config: config.clone(),
749 })
750 }
751}
752
753#[async_trait::async_trait]
754impl LLMProviderTrait for AnthropicClient {
755 #[instrument(skip(self, messages), fields(provider = self.provider_name(), model = self.model()))]
756 async fn chat(&self, messages: Vec<ChatMessage>) -> Result<ChatResponse, LLMError> {
757 #[derive(Serialize)]
759 struct AnthropicRequest {
760 model: String,
761 max_tokens: u32,
762 messages: Vec<AnthropicMessage>,
763 #[serde(skip_serializing_if = "Option::is_none")]
764 system: Option<String>,
765 #[serde(skip_serializing_if = "Option::is_none")]
766 temperature: Option<f32>,
767 }
768
769 #[derive(Serialize)]
770 struct AnthropicMessage {
771 role: String,
772 content: String,
773 }
774
775 let system = messages
777 .iter()
778 .find(|m| m.role == "system")
779 .map(|m| m.content.clone());
780
781 let anthropic_messages: Vec<AnthropicMessage> = messages
782 .into_iter()
783 .filter(|m| m.role != "system")
784 .map(|m| AnthropicMessage {
785 role: if m.role == "user" {
786 "user".to_string()
787 } else {
788 "assistant".to_string()
789 },
790 content: m.content,
791 })
792 .collect();
793
794 let request = AnthropicRequest {
795 model: self.config.model.clone(),
796 max_tokens: 2048,
797 messages: anthropic_messages,
798 system,
799 temperature: Some(0.7),
800 };
801
802 let api_key = self
803 .config
804 .api_key
805 .clone()
806 .ok_or_else(|| LLMError::AuthFailed)?;
807
808 let response = self
809 .client
810 .post("https://api.anthropic.com/v1/messages")
811 .header("x-api-key", api_key)
812 .header("anthropic-version", "2023-06-01")
813 .header("content-type", "application/json")
814 .json(&request)
815 .send()
816 .await
817 .map_err(|e| LLMError::RequestFailed(e.to_string()))?;
818
819 let status = response.status();
820
821 if status.is_success() {
822 #[derive(Deserialize)]
824 #[allow(dead_code)]
825 struct AnthropicResponse {
826 id: String,
827 #[serde(rename = "type")]
828 response_type: String,
829 role: String,
830 content: Vec<AnthropicContentBlock>,
831 model: String,
832 stop_reason: Option<String>,
833 #[serde(default)]
834 usage: Option<AnthropicUsage>,
835 }
836
837 #[derive(Deserialize)]
838 #[serde(tag = "type", rename_all = "lowercase")]
839 enum AnthropicContentBlock {
840 Text {
841 text: String,
842 },
843 ToolUse {
844 id: String,
845 name: String,
846 input: serde_json::Value,
847 },
848 }
849
850 #[derive(Deserialize)]
851 struct AnthropicUsage {
852 input_tokens: u32,
853 output_tokens: u32,
854 }
855
856 let anthropic_resp = response
857 .json::<AnthropicResponse>()
858 .await
859 .map_err(|e| LLMError::InvalidResponse(e.to_string()))?;
860
861 let mut content = String::new();
863 let mut tool_calls = None;
864
865 for block in anthropic_resp.content {
866 match block {
867 AnthropicContentBlock::Text { text } => {
868 content.push_str(&text);
869 }
870 AnthropicContentBlock::ToolUse { id, name, input } => {
871 if tool_calls.is_none() {
872 tool_calls = Some(Vec::new());
873 }
874 if let Some(ref mut calls) = tool_calls {
875 calls.push(ToolCallResponse {
876 id,
877 call_type: "function".to_string(),
878 function: FunctionCall {
879 name,
880 arguments: input.to_string(),
881 },
882 });
883 }
884 }
885 }
886 }
887
888 Ok(ChatResponse {
889 id: anthropic_resp.id,
890 object: "chat.completion".to_string(),
891 created: std::time::SystemTime::now()
892 .duration_since(std::time::UNIX_EPOCH)
893 .unwrap()
894 .as_secs(),
895 model: anthropic_resp.model,
896 choices: vec![Choice {
897 index: 0,
898 message: ChatMessage {
899 role: "assistant".to_string(),
900 content,
901 },
902 finish_reason: anthropic_resp.stop_reason,
903 tool_calls,
904 }],
905 usage: anthropic_resp.usage.map(|u| Usage {
906 prompt_tokens: u.input_tokens,
907 completion_tokens: u.output_tokens,
908 total_tokens: u.input_tokens + u.output_tokens,
909 }),
910 })
911 } else if status == reqwest::StatusCode::UNAUTHORIZED {
912 Err(LLMError::AuthFailed)
913 } else if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
914 Err(LLMError::RateLimited)
915 } else {
916 let body = response
917 .text()
918 .await
919 .unwrap_or_else(|_| "Unknown error".to_string());
920 Err(LLMError::RequestFailed(format!("{}: {}", status, body)))
921 }
922 }
923
924 fn provider_name(&self) -> &str {
925 "anthropic"
926 }
927
928 fn model(&self) -> &str {
929 &self.config.model
930 }
931}
932
933pub fn create_client(config: &LLMConfig) -> Result<Arc<dyn LLMProviderTrait>, LLMError> {
935 match config.provider {
936 LLMProvider::LiteLLM => {
937 let unified = OpenAICompatibleClient::new(config, OpenAICompatibleProvider::LiteLLM)?;
938 Ok(Arc::new(unified))
939 }
940 LLMProvider::OpenRouter => {
941 let unified =
942 OpenAICompatibleClient::new(config, OpenAICompatibleProvider::OpenRouter)?;
943 Ok(Arc::new(unified))
944 }
945 LLMProvider::Ollama => Ok(Arc::new(OllamaClient::new(config)?)),
946 LLMProvider::OpenAI => {
947 let unified = OpenAICompatibleClient::new(config, OpenAICompatibleProvider::OpenAI)?;
948 Ok(Arc::new(unified))
949 }
950 LLMProvider::Anthropic => Ok(Arc::new(AnthropicClient::new(config)?)),
951 LLMProvider::OpenAICompatible => {
952 let unified = OpenAICompatibleClient::new(config, OpenAICompatibleProvider::Generic)?;
953 Ok(Arc::new(unified))
954 }
955 }
956}
957
958#[derive(Clone)]
960pub struct MultiModelManager {
961 clients: Vec<Arc<dyn LLMProviderTrait>>,
962}
963
964impl MultiModelManager {
965 pub fn new(configs: Vec<LLMConfig>) -> Result<Self, LLMError> {
966 let clients: Result<Vec<_>, _> = configs.iter().map(create_client).collect();
967 Ok(Self { clients: clients? })
968 }
969
970 pub fn get_client(&self, index: usize) -> Option<&Arc<dyn LLMProviderTrait>> {
971 self.clients.get(index)
972 }
973
974 pub fn client_count(&self) -> usize {
975 self.clients.len()
976 }
977
978 pub fn next_client(&self, last_index: usize) -> Option<&Arc<dyn LLMProviderTrait>> {
980 if self.clients.is_empty() {
981 return None;
982 }
983 let next = (last_index + 1) % self.clients.len();
984 Some(&self.clients[next])
985 }
986}
987
988#[derive(Debug)]
990pub struct ProviderFallbackChain {
991 pub configs: Vec<LLMConfig>,
993 token_budget: Option<TokenBudget>,
994}
995
996impl ProviderFallbackChain {
997 pub fn new(configs: Vec<LLMConfig>) -> Self {
998 Self {
999 configs,
1000 token_budget: None,
1001 }
1002 }
1003
1004 pub fn with_token_budget(mut self, budget: TokenBudget) -> Self {
1005 self.token_budget = Some(budget);
1006 self
1007 }
1008
1009 #[instrument(skip(self, messages))]
1011 pub async fn chat_with_fallback(
1012 &mut self,
1013 messages: Vec<ChatMessage>,
1014 ) -> Result<ChatResponse, LLMError> {
1015 let mut last_error = None;
1016
1017 for (i, config) in self.configs.iter().enumerate() {
1018 let client = match create_client(config) {
1019 Ok(c) => c,
1020 Err(e) => {
1021 tracing::warn!(
1022 "Failed to create client for provider {:?}: {}",
1023 config.provider,
1024 e
1025 );
1026 last_error = Some(e);
1027 continue;
1028 }
1029 };
1030
1031 if let Some(ref budget) = self.token_budget {
1033 if !budget.can_spend(500) {
1035 return Err(LLMError::TokenBudgetExceeded);
1036 }
1037 }
1038
1039 match client.chat(messages.clone()).await {
1040 Ok(response) => {
1041 if let Some(ref mut budget) = self.token_budget {
1043 if let Some(usage) = &response.usage {
1044 budget.record_usage(usage.total_tokens);
1045 }
1046 }
1047 return Ok(response);
1048 }
1049 Err(e) => {
1050 tracing::warn!("Provider {} failed: {}", i, e);
1051 last_error = Some(e);
1052 }
1054 }
1055 }
1056
1057 Err(last_error.unwrap_or(LLMError::AllProvidersFailed))
1058 }
1059
1060 #[allow(dead_code)]
1062 pub fn provider_names(&self) -> Vec<String> {
1063 self.configs
1064 .iter()
1065 .map(|c| format!("{:?}", c.provider))
1066 .collect()
1067 }
1068}
1069
1070#[cfg(test)]
1071mod tests {
1072 use super::*;
1073 use mockito::Server;
1074
1075 fn make_chat_messages() -> Vec<ChatMessage> {
1078 vec![
1079 ChatMessage {
1080 role: "system".to_string(),
1081 content: "You are helpful.".to_string(),
1082 },
1083 ChatMessage {
1084 role: "user".to_string(),
1085 content: "Hello!".to_string(),
1086 },
1087 ]
1088 }
1089
1090 fn sample_chat_response_json(model: &str) -> String {
1091 format!(
1092 r#"{{
1093 "id": "chat-123",
1094 "object": "chat.completion",
1095 "created": 1717000000,
1096 "model": "{}",
1097 "choices": [
1098 {{
1099 "index": 0,
1100 "message": {{
1101 "role": "assistant",
1102 "content": "Hi there!"
1103 }},
1104 "finish_reason": "stop"
1105 }}
1106 ],
1107 "usage": {{
1108 "prompt_tokens": 10,
1109 "completion_tokens": 5,
1110 "total_tokens": 15
1111 }}
1112 }}"#,
1113 model
1114 )
1115 }
1116
1117 fn sample_ollama_response_json(model: &str) -> String {
1118 format!(
1119 r#"{{
1120 "model": "{}",
1121 "message": {{
1122 "role": "assistant",
1123 "content": "Hi there!"
1124 }},
1125 "done": true
1126 }}"#,
1127 model
1128 )
1129 }
1130
1131 fn with_mockito<F, Fut>(f: F)
1135 where
1136 F: FnOnce(mockito::ServerGuard) -> Fut,
1137 Fut: std::future::Future<Output = ()>,
1138 {
1139 let server = Server::new();
1140 let rt = tokio::runtime::Runtime::new().unwrap();
1141 rt.block_on(f(server));
1142 }
1143
1144 #[test]
1147 fn test_openai_compatible_provider_defaults() {
1148 assert_eq!(
1149 OpenAICompatibleProvider::LiteLLM.default_endpoint(),
1150 "http://localhost:4000"
1151 );
1152 assert_eq!(
1153 OpenAICompatibleProvider::OpenAI.default_endpoint(),
1154 "https://api.openai.com"
1155 );
1156 assert_eq!(
1157 OpenAICompatibleProvider::OpenRouter.default_endpoint(),
1158 "https://openrouter.ai"
1159 );
1160 }
1161
1162 #[test]
1163 fn test_openai_compatible_provider_names() {
1164 assert_eq!(OpenAICompatibleProvider::LiteLLM.name(), "litellm");
1165 assert_eq!(OpenAICompatibleProvider::OpenAI.name(), "openai");
1166 assert_eq!(OpenAICompatibleProvider::OpenRouter.name(), "openrouter");
1167 }
1168
1169 #[test]
1170 fn test_openai_compatible_requires_custom_headers() {
1171 assert!(!OpenAICompatibleProvider::LiteLLM.requires_custom_headers());
1172 assert!(OpenAICompatibleProvider::OpenRouter.requires_custom_headers());
1173 assert!(!OpenAICompatibleProvider::OpenAI.requires_custom_headers());
1174 }
1175
1176 #[test]
1177 fn test_openai_compatible_client_new() {
1178 let config = LLMConfig {
1179 provider: LLMProvider::LiteLLM,
1180 endpoint: "http://localhost:4000".to_string(),
1181 model: "gpt-4o-mini".to_string(),
1182 api_key: Some("test-key".to_string()),
1183 timeout_secs: 30,
1184 system_prompt: crate::config::default_system_prompt(),
1185 token_budget: None,
1186 retry_max: 3,
1187 retry_base_delay_ms: 100,
1188 retry_max_delay_ms: 10000,
1189 };
1190
1191 let client = OpenAICompatibleClient::new(&config, OpenAICompatibleProvider::LiteLLM);
1192 assert!(client.is_ok());
1193 assert_eq!(client.unwrap().provider_name(), "litellm");
1194 }
1195
1196 #[test]
1197 fn test_openai_compatible_client_endpoint() {
1198 let config = LLMConfig {
1200 provider: LLMProvider::OpenAI,
1201 endpoint: "https://custom.api.example.com".to_string(),
1202 model: "gpt-4o".to_string(),
1203 api_key: Some("test-key".to_string()),
1204 timeout_secs: 30,
1205 system_prompt: crate::config::default_system_prompt(),
1206 token_budget: None,
1207 retry_max: 3,
1208 retry_base_delay_ms: 100,
1209 retry_max_delay_ms: 10000,
1210 };
1211
1212 let client =
1213 OpenAICompatibleClient::new(&config, OpenAICompatibleProvider::OpenAI).unwrap();
1214 assert_eq!(client.provider_name(), "openai");
1216 }
1217
1218 #[test]
1219 fn test_openai_compatible_client_chat_success() {
1220 with_mockito(|mut server| async move {
1221 let mock = server
1222 .mock("POST", "/v1/chat/completions")
1223 .with_status(200)
1224 .with_header("content-type", "application/json")
1225 .with_body(sample_chat_response_json("gpt-4o-mini"))
1226 .create();
1227
1228 let config = LLMConfig {
1229 provider: LLMProvider::LiteLLM,
1230 endpoint: server.url(),
1231 model: "gpt-4o-mini".to_string(),
1232 api_key: Some("test-key".to_string()),
1233 timeout_secs: 30,
1234 system_prompt: crate::config::default_system_prompt(),
1235 token_budget: None,
1236 retry_max: 3,
1237 retry_base_delay_ms: 100,
1238 retry_max_delay_ms: 10000,
1239 };
1240
1241 let client =
1242 OpenAICompatibleClient::new(&config, OpenAICompatibleProvider::LiteLLM).unwrap();
1243 let response = client.chat(make_chat_messages()).await.unwrap();
1244
1245 assert_eq!(response.model, "gpt-4o-mini");
1246 assert_eq!(response.choices[0].message.content, "Hi there!");
1247 mock.assert();
1248 });
1249 }
1250
1251 #[test]
1252 fn test_openai_compatible_client_auth_failure() {
1253 with_mockito(|mut server| async move {
1254 let mock = server
1255 .mock("POST", "/v1/chat/completions")
1256 .with_status(401)
1257 .with_body(r#"{"error": "Unauthorized"}"#)
1258 .create();
1259
1260 let config = LLMConfig {
1261 provider: LLMProvider::LiteLLM,
1262 endpoint: server.url(),
1263 model: "gpt-4o-mini".to_string(),
1264 api_key: Some("bad-key".to_string()),
1265 timeout_secs: 30,
1266 system_prompt: crate::config::default_system_prompt(),
1267 token_budget: None,
1268 retry_max: 3,
1269 retry_base_delay_ms: 100,
1270 retry_max_delay_ms: 10000,
1271 };
1272
1273 let client =
1274 OpenAICompatibleClient::new(&config, OpenAICompatibleProvider::LiteLLM).unwrap();
1275 let err = client.chat(make_chat_messages()).await.unwrap_err();
1276
1277 assert!(matches!(err, LLMError::AuthFailed));
1278 mock.assert();
1279 });
1280 }
1281
1282 #[test]
1283 fn test_openai_compatible_client_rate_limit() {
1284 with_mockito(|mut server| async move {
1285 let mock = server
1286 .mock("POST", "/v1/chat/completions")
1287 .with_status(429)
1288 .with_body(r#"{"error": "Rate limited"}"#)
1289 .create();
1290
1291 let config = LLMConfig {
1292 provider: LLMProvider::LiteLLM,
1293 endpoint: server.url(),
1294 model: "gpt-4o-mini".to_string(),
1295 api_key: Some("test-key".to_string()),
1296 timeout_secs: 30,
1297 system_prompt: crate::config::default_system_prompt(),
1298 token_budget: None,
1299 retry_max: 0, retry_base_delay_ms: 100,
1301 retry_max_delay_ms: 10000,
1302 };
1303
1304 let client =
1305 OpenAICompatibleClient::new(&config, OpenAICompatibleProvider::LiteLLM).unwrap();
1306 let err = client.chat(make_chat_messages()).await.unwrap_err();
1307
1308 assert!(matches!(err, LLMError::RateLimited));
1309 mock.assert();
1310 });
1311 }
1312
1313 #[test]
1314 fn test_openrouter_client_uses_custom_headers() {
1315 with_mockito(|mut server| async move {
1316 let mock = server
1317 .mock("POST", "/v1/chat/completions")
1318 .match_header("HTTP-Referer", "https://github.com/egkristi/RavenClaws")
1319 .match_header("X-Title", "RavenClaws")
1320 .with_status(200)
1321 .with_body(sample_chat_response_json("claude-sonnet-4"))
1322 .create();
1323
1324 let config = LLMConfig {
1325 provider: LLMProvider::OpenRouter,
1326 endpoint: server.url(),
1327 model: "claude-sonnet-4".to_string(),
1328 api_key: Some("or-key".to_string()),
1329 timeout_secs: 30,
1330 system_prompt: crate::config::default_system_prompt(),
1331 token_budget: None,
1332 retry_max: 3,
1333 retry_base_delay_ms: 100,
1334 retry_max_delay_ms: 10000,
1335 };
1336
1337 let client =
1338 OpenAICompatibleClient::new(&config, OpenAICompatibleProvider::OpenRouter).unwrap();
1339 let _ = client.chat(make_chat_messages()).await.unwrap();
1340 mock.assert();
1341 });
1342 }
1343
1344 #[test]
1347 fn test_anthropic_client_new() {
1348 let config = LLMConfig {
1349 provider: LLMProvider::Anthropic,
1350 endpoint: String::new(),
1351 model: "claude-sonnet-4-20250514".to_string(),
1352 api_key: Some("sk-ant-test".to_string()),
1353 timeout_secs: 30,
1354 system_prompt: crate::config::default_system_prompt(),
1355 token_budget: None,
1356 retry_max: 3,
1357 retry_base_delay_ms: 100,
1358 retry_max_delay_ms: 10000,
1359 };
1360
1361 let client = AnthropicClient::new(&config);
1362 assert!(client.is_ok());
1363 }
1364
1365 #[test]
1366 fn test_anthropic_client_provider_name() {
1367 let config = LLMConfig {
1368 provider: LLMProvider::Anthropic,
1369 endpoint: String::new(),
1370 model: "claude-sonnet-4-20250514".to_string(),
1371 api_key: Some("sk-ant-test".to_string()),
1372 timeout_secs: 30,
1373 system_prompt: crate::config::default_system_prompt(),
1374 token_budget: None,
1375 retry_max: 3,
1376 retry_base_delay_ms: 100,
1377 retry_max_delay_ms: 10000,
1378 };
1379
1380 let client = AnthropicClient::new(&config).unwrap();
1381 assert_eq!(client.provider_name(), "anthropic");
1382 }
1383
1384 #[test]
1385 fn test_anthropic_client_model() {
1386 let config = LLMConfig {
1387 provider: LLMProvider::Anthropic,
1388 endpoint: String::new(),
1389 model: "claude-opus-4-20250514".to_string(),
1390 api_key: Some("sk-ant-test".to_string()),
1391 timeout_secs: 30,
1392 system_prompt: crate::config::default_system_prompt(),
1393 token_budget: None,
1394 retry_max: 3,
1395 retry_base_delay_ms: 100,
1396 retry_max_delay_ms: 10000,
1397 };
1398
1399 let client = AnthropicClient::new(&config).unwrap();
1400 assert_eq!(client.model(), "claude-opus-4-20250514");
1401 }
1402
1403 #[test]
1404 fn test_create_client_anthropic() {
1405 let config = LLMConfig {
1406 provider: LLMProvider::Anthropic,
1407 endpoint: String::new(),
1408 model: "claude-sonnet-4-20250514".to_string(),
1409 api_key: Some("sk-ant-test".to_string()),
1410 timeout_secs: 30,
1411 system_prompt: crate::config::default_system_prompt(),
1412 token_budget: None,
1413 retry_max: 3,
1414 retry_base_delay_ms: 100,
1415 retry_max_delay_ms: 10000,
1416 };
1417
1418 let client = create_client(&config);
1419 assert!(client.is_ok());
1420 assert_eq!(client.unwrap().provider_name(), "anthropic");
1421 }
1422
1423 #[test]
1426 fn test_retry_config_delay_calculation() {
1427 let config = RetryConfig {
1428 max_retries: 3,
1429 base_delay_ms: 100,
1430 max_delay_ms: 10000,
1431 jitter: 0.0, };
1433
1434 assert_eq!(config.delay_for_attempt(0).as_millis(), 100);
1436 assert_eq!(config.delay_for_attempt(1).as_millis(), 200);
1437 assert_eq!(config.delay_for_attempt(2).as_millis(), 400);
1438 }
1439
1440 #[test]
1441 fn test_retry_config_max_delay_cap() {
1442 let config = RetryConfig {
1443 max_retries: 10,
1444 base_delay_ms: 100,
1445 max_delay_ms: 1000,
1446 jitter: 0.0,
1447 };
1448
1449 assert!(config.delay_for_attempt(10).as_millis() <= 1000);
1451 }
1452
1453 #[test]
1454 fn test_circuit_breaker_state_transitions() {
1455 let mut cb = CircuitBreaker::new(30);
1456
1457 assert_eq!(cb.state, CircuitState::Closed);
1459 assert!(cb.can_execute());
1460
1461 for _ in 0..5 {
1463 cb.record_failure();
1464 }
1465 assert_eq!(cb.state, CircuitState::Open);
1466 assert!(!cb.can_execute());
1467 }
1468
1469 #[test]
1470 fn test_circuit_breaker_success_resets() {
1471 let mut cb = CircuitBreaker::new(30);
1472
1473 for _ in 0..3 {
1475 cb.record_failure();
1476 }
1477 assert_eq!(cb.failure_count, 3);
1478
1479 cb.record_success();
1481 assert_eq!(cb.failure_count, 0);
1482 assert_eq!(cb.state, CircuitState::Closed);
1483 }
1484
1485 #[test]
1486 fn test_token_budget_tracking() {
1487 let mut budget = TokenBudget::new(1000, 0.002); assert_eq!(budget.remaining(), 1000);
1490 assert!(budget.can_spend(500));
1491
1492 budget.record_usage(300);
1493 assert_eq!(budget.remaining(), 700);
1494 assert!(budget.can_spend(500));
1495
1496 budget.record_usage(500);
1497 assert_eq!(budget.remaining(), 200);
1498 assert!(!budget.can_spend(500));
1499
1500 assert!((budget.estimated_cost() - 0.0016).abs() < 0.0001);
1502 }
1503
1504 #[test]
1505 fn test_provider_fallback_chain_creation() {
1506 let configs = vec![
1507 LLMConfig {
1508 provider: LLMProvider::LiteLLM,
1509 endpoint: "http://localhost:4000".to_string(),
1510 model: "gpt-4o".to_string(),
1511 api_key: Some("key1".to_string()),
1512 timeout_secs: 30,
1513 system_prompt: crate::config::default_system_prompt(),
1514 token_budget: None,
1515 retry_max: 3,
1516 retry_base_delay_ms: 100,
1517 retry_max_delay_ms: 10000,
1518 },
1519 LLMConfig {
1520 provider: LLMProvider::Ollama,
1521 endpoint: "http://localhost:11434".to_string(),
1522 model: "llama3.1".to_string(),
1523 api_key: None,
1524 timeout_secs: 30,
1525 system_prompt: crate::config::default_system_prompt(),
1526 token_budget: None,
1527 retry_max: 3,
1528 retry_base_delay_ms: 100,
1529 retry_max_delay_ms: 10000,
1530 },
1531 ];
1532
1533 let chain = ProviderFallbackChain::new(configs);
1534 assert_eq!(chain.provider_names(), vec!["LiteLLM", "Ollama"]);
1535 }
1536
1537 #[test]
1540 fn test_litellm_chat_auth_failure() {
1541 with_mockito(|mut server| async move {
1542 let mock = server
1543 .mock("POST", "/v1/chat/completions")
1544 .with_status(401)
1545 .with_header("content-type", "application/json")
1546 .with_body(r#"{"error": "Unauthorized"}"#)
1547 .create();
1548
1549 let config = LLMConfig {
1550 provider: LLMProvider::LiteLLM,
1551 endpoint: server.url(),
1552 model: "gpt-4o-mini".to_string(),
1553 api_key: Some("bad-key".to_string()),
1554 timeout_secs: 30,
1555 system_prompt: crate::config::default_system_prompt(),
1556 token_budget: None,
1557 retry_max: 3,
1558 retry_base_delay_ms: 100,
1559 retry_max_delay_ms: 10000,
1560 };
1561
1562 let client =
1563 OpenAICompatibleClient::new(&config, OpenAICompatibleProvider::LiteLLM).unwrap();
1564 let err = client.chat(make_chat_messages()).await.unwrap_err();
1565
1566 assert!(matches!(err, LLMError::AuthFailed));
1567 mock.assert();
1568 });
1569 }
1570
1571 #[test]
1572 fn test_litellm_chat_rate_limit() {
1573 with_mockito(|mut server| async move {
1574 let mock = server
1575 .mock("POST", "/v1/chat/completions")
1576 .with_status(429)
1577 .with_header("content-type", "application/json")
1578 .with_body(r#"{"error": "Rate limit exceeded"}"#)
1579 .create();
1580
1581 let config = LLMConfig {
1582 provider: LLMProvider::LiteLLM,
1583 endpoint: server.url(),
1584 model: "gpt-4o-mini".to_string(),
1585 api_key: Some("test-key".to_string()),
1586 timeout_secs: 30,
1587 system_prompt: crate::config::default_system_prompt(),
1588 token_budget: None,
1589 retry_max: 0,
1590 retry_base_delay_ms: 100,
1591 retry_max_delay_ms: 10000,
1592 };
1593
1594 let client =
1595 OpenAICompatibleClient::new(&config, OpenAICompatibleProvider::LiteLLM).unwrap();
1596 let err = client.chat(make_chat_messages()).await.unwrap_err();
1597
1598 assert!(matches!(err, LLMError::RateLimited));
1599 mock.assert();
1600 });
1601 }
1602
1603 #[test]
1604 fn test_litellm_chat_server_error() {
1605 with_mockito(|mut server| async move {
1606 let mock = server
1607 .mock("POST", "/v1/chat/completions")
1608 .with_status(500)
1609 .with_header("content-type", "application/json")
1610 .with_body(r#"{"error": "Internal server error"}"#)
1611 .create();
1612
1613 let config = LLMConfig {
1614 provider: LLMProvider::LiteLLM,
1615 endpoint: server.url(),
1616 model: "gpt-4o-mini".to_string(),
1617 api_key: Some("test-key".to_string()),
1618 timeout_secs: 30,
1619 system_prompt: crate::config::default_system_prompt(),
1620 token_budget: None,
1621 retry_max: 0,
1622 retry_base_delay_ms: 100,
1623 retry_max_delay_ms: 10000,
1624 };
1625
1626 let client =
1627 OpenAICompatibleClient::new(&config, OpenAICompatibleProvider::LiteLLM).unwrap();
1628 let err = client.chat(make_chat_messages()).await.unwrap_err();
1629
1630 assert!(matches!(err, LLMError::RequestFailed(_)));
1631 assert!(format!("{}", err).contains("500"));
1632 mock.assert();
1633 });
1634 }
1635
1636 #[test]
1637 fn test_litellm_chat_invalid_json() {
1638 with_mockito(|mut server| async move {
1639 let mock = server
1640 .mock("POST", "/v1/chat/completions")
1641 .with_status(200)
1642 .with_header("content-type", "application/json")
1643 .with_body("not-json")
1644 .create();
1645
1646 let config = LLMConfig {
1647 provider: LLMProvider::LiteLLM,
1648 endpoint: server.url(),
1649 model: "gpt-4o-mini".to_string(),
1650 api_key: Some("test-key".to_string()),
1651 timeout_secs: 30,
1652 system_prompt: crate::config::default_system_prompt(),
1653 token_budget: None,
1654 retry_max: 0,
1655 retry_base_delay_ms: 100,
1656 retry_max_delay_ms: 10000,
1657 };
1658
1659 let client =
1660 OpenAICompatibleClient::new(&config, OpenAICompatibleProvider::LiteLLM).unwrap();
1661 let err = client.chat(make_chat_messages()).await.unwrap_err();
1662
1663 assert!(matches!(err, LLMError::InvalidResponse(_)));
1664 mock.assert();
1665 });
1666 }
1667
1668 #[test]
1671 fn test_openrouter_chat_success() {
1672 with_mockito(|mut server| async move {
1673 let mock = server
1674 .mock("POST", "/v1/chat/completions")
1675 .with_status(200)
1676 .with_header("content-type", "application/json")
1677 .with_body(sample_chat_response_json(
1678 "anthropic/claude-sonnet-4-20250514",
1679 ))
1680 .create();
1681
1682 let config = LLMConfig {
1683 provider: LLMProvider::OpenRouter,
1684 endpoint: server.url(),
1685 model: "anthropic/claude-sonnet-4-20250514".to_string(),
1686 api_key: Some("or-key".to_string()),
1687 timeout_secs: 30,
1688 system_prompt: crate::config::default_system_prompt(),
1689 token_budget: None,
1690 retry_max: 3,
1691 retry_base_delay_ms: 100,
1692 retry_max_delay_ms: 10000,
1693 };
1694
1695 let client =
1696 OpenAICompatibleClient::new(&config, OpenAICompatibleProvider::OpenRouter).unwrap();
1697 let response = client.chat(make_chat_messages()).await.unwrap();
1698
1699 assert_eq!(response.model, "anthropic/claude-sonnet-4-20250514");
1700 assert_eq!(response.choices[0].message.content, "Hi there!");
1701 mock.assert();
1702 });
1703 }
1704
1705 #[test]
1706 fn test_openrouter_chat_auth_failure() {
1707 with_mockito(|mut server| async move {
1708 let mock = server
1709 .mock("POST", "/v1/chat/completions")
1710 .with_status(401)
1711 .with_header("content-type", "application/json")
1712 .with_body(r#"{"error": "Unauthorized"}"#)
1713 .create();
1714
1715 let config = LLMConfig {
1716 provider: LLMProvider::OpenRouter,
1717 endpoint: server.url(),
1718 model: "anthropic/claude-sonnet-4-20250514".to_string(),
1719 api_key: Some("bad-key".to_string()),
1720 timeout_secs: 30,
1721 system_prompt: crate::config::default_system_prompt(),
1722 token_budget: None,
1723 retry_max: 3,
1724 retry_base_delay_ms: 100,
1725 retry_max_delay_ms: 10000,
1726 };
1727
1728 let client =
1729 OpenAICompatibleClient::new(&config, OpenAICompatibleProvider::OpenRouter).unwrap();
1730 let err = client.chat(make_chat_messages()).await.unwrap_err();
1731
1732 assert!(matches!(err, LLMError::AuthFailed));
1733 mock.assert();
1734 });
1735 }
1736
1737 #[test]
1738 fn test_openrouter_chat_rate_limit() {
1739 with_mockito(|mut server| async move {
1740 let mock = server
1741 .mock("POST", "/v1/chat/completions")
1742 .with_status(429)
1743 .with_header("content-type", "application/json")
1744 .with_body(r#"{"error": "Rate limited"}"#)
1745 .create();
1746
1747 let config = LLMConfig {
1748 provider: LLMProvider::OpenRouter,
1749 endpoint: server.url(),
1750 model: "anthropic/claude-sonnet-4-20250514".to_string(),
1751 api_key: Some("or-key".to_string()),
1752 timeout_secs: 30,
1753 system_prompt: crate::config::default_system_prompt(),
1754 token_budget: None,
1755 retry_max: 0, retry_base_delay_ms: 100,
1757 retry_max_delay_ms: 10000,
1758 };
1759
1760 let client =
1761 OpenAICompatibleClient::new(&config, OpenAICompatibleProvider::OpenRouter).unwrap();
1762 let err = client.chat(make_chat_messages()).await.unwrap_err();
1763
1764 assert!(matches!(err, LLMError::RateLimited));
1765 mock.assert();
1766 });
1767 }
1768
1769 #[test]
1770 fn test_openrouter_chat_server_error() {
1771 with_mockito(|mut server| async move {
1772 let mock = server
1773 .mock("POST", "/v1/chat/completions")
1774 .with_status(500)
1775 .with_header("content-type", "application/json")
1776 .with_body(r#"{"error": "Internal error"}"#)
1777 .create();
1778
1779 let config = LLMConfig {
1780 provider: LLMProvider::OpenRouter,
1781 endpoint: server.url(),
1782 model: "anthropic/claude-sonnet-4-20250514".to_string(),
1783 api_key: Some("or-key".to_string()),
1784 timeout_secs: 30,
1785 system_prompt: crate::config::default_system_prompt(),
1786 token_budget: None,
1787 retry_max: 0, retry_base_delay_ms: 100,
1789 retry_max_delay_ms: 10000,
1790 };
1791
1792 let client =
1793 OpenAICompatibleClient::new(&config, OpenAICompatibleProvider::OpenRouter).unwrap();
1794 let err = client.chat(make_chat_messages()).await.unwrap_err();
1795
1796 assert!(matches!(err, LLMError::RequestFailed(_)));
1797 assert!(format!("{}", err).contains("500"));
1798 mock.assert();
1799 });
1800 }
1801
1802 #[test]
1803 fn test_openrouter_chat_invalid_json() {
1804 with_mockito(|mut server| async move {
1805 let mock = server
1806 .mock("POST", "/v1/chat/completions")
1807 .with_status(200)
1808 .with_header("content-type", "application/json")
1809 .with_body("not-json")
1810 .create();
1811
1812 let config = LLMConfig {
1813 provider: LLMProvider::OpenRouter,
1814 endpoint: server.url(),
1815 model: "anthropic/claude-sonnet-4-20250514".to_string(),
1816 api_key: Some("or-key".to_string()),
1817 timeout_secs: 30,
1818 system_prompt: crate::config::default_system_prompt(),
1819 token_budget: None,
1820 retry_max: 0, retry_base_delay_ms: 100,
1822 retry_max_delay_ms: 10000,
1823 };
1824
1825 let client =
1826 OpenAICompatibleClient::new(&config, OpenAICompatibleProvider::OpenRouter).unwrap();
1827 let err = client.chat(make_chat_messages()).await.unwrap_err();
1828
1829 assert!(matches!(err, LLMError::InvalidResponse(_)));
1830 mock.assert();
1831 });
1832 }
1833
1834 #[test]
1837 fn test_openai_chat_success() {
1838 with_mockito(|mut server| async move {
1839 let mock = server
1840 .mock("POST", "/v1/chat/completions")
1841 .with_status(200)
1842 .with_header("content-type", "application/json")
1843 .with_body(sample_chat_response_json("gpt-4o"))
1844 .create();
1845
1846 let config = LLMConfig {
1847 provider: LLMProvider::OpenAI,
1848 endpoint: server.url(),
1849 model: "gpt-4o".to_string(),
1850 api_key: Some("sk-test".to_string()),
1851 timeout_secs: 60,
1852 system_prompt: crate::config::default_system_prompt(),
1853 token_budget: None,
1854 retry_max: 3,
1855 retry_base_delay_ms: 100,
1856 retry_max_delay_ms: 10000,
1857 };
1858
1859 let client =
1860 OpenAICompatibleClient::new(&config, OpenAICompatibleProvider::OpenAI).unwrap();
1861 let response = client.chat(make_chat_messages()).await.unwrap();
1862
1863 assert_eq!(response.model, "gpt-4o");
1864 assert_eq!(response.choices[0].message.content, "Hi there!");
1865 mock.assert();
1866 });
1867 }
1868
1869 #[test]
1870 fn test_openai_chat_auth_failure() {
1871 with_mockito(|mut server| async move {
1872 let mock = server
1873 .mock("POST", "/v1/chat/completions")
1874 .with_status(401)
1875 .with_header("content-type", "application/json")
1876 .with_body(r#"{"error": "Unauthorized"}"#)
1877 .create();
1878
1879 let config = LLMConfig {
1880 provider: LLMProvider::OpenAI,
1881 endpoint: server.url(),
1882 model: "gpt-4o".to_string(),
1883 api_key: Some("bad-key".to_string()),
1884 timeout_secs: 30,
1885 system_prompt: crate::config::default_system_prompt(),
1886 token_budget: None,
1887 retry_max: 3,
1888 retry_base_delay_ms: 100,
1889 retry_max_delay_ms: 10000,
1890 };
1891
1892 let client =
1893 OpenAICompatibleClient::new(&config, OpenAICompatibleProvider::OpenAI).unwrap();
1894 let err = client.chat(make_chat_messages()).await.unwrap_err();
1895
1896 assert!(matches!(err, LLMError::AuthFailed));
1897 mock.assert();
1898 });
1899 }
1900
1901 #[test]
1902 fn test_openai_chat_rate_limit() {
1903 with_mockito(|mut server| async move {
1904 let mock = server
1905 .mock("POST", "/v1/chat/completions")
1906 .with_status(429)
1907 .with_header("content-type", "application/json")
1908 .with_body(r#"{"error": "Rate limited"}"#)
1909 .create();
1910
1911 let config = LLMConfig {
1912 provider: LLMProvider::OpenAI,
1913 endpoint: server.url(),
1914 model: "gpt-4o".to_string(),
1915 api_key: Some("sk-test".to_string()),
1916 timeout_secs: 30,
1917 system_prompt: crate::config::default_system_prompt(),
1918 token_budget: None,
1919 retry_max: 0, retry_base_delay_ms: 100,
1921 retry_max_delay_ms: 10000,
1922 };
1923
1924 let client =
1925 OpenAICompatibleClient::new(&config, OpenAICompatibleProvider::OpenAI).unwrap();
1926 let err = client.chat(make_chat_messages()).await.unwrap_err();
1927
1928 assert!(matches!(err, LLMError::RateLimited));
1929 mock.assert();
1930 });
1931 }
1932
1933 #[test]
1934 fn test_openai_chat_server_error() {
1935 with_mockito(|mut server| async move {
1936 let mock = server
1937 .mock("POST", "/v1/chat/completions")
1938 .with_status(500)
1939 .with_header("content-type", "application/json")
1940 .with_body(r#"{"error": "Internal error"}"#)
1941 .create();
1942
1943 let config = LLMConfig {
1944 provider: LLMProvider::OpenAI,
1945 endpoint: server.url(),
1946 model: "gpt-4o".to_string(),
1947 api_key: Some("sk-test".to_string()),
1948 timeout_secs: 30,
1949 system_prompt: crate::config::default_system_prompt(),
1950 token_budget: None,
1951 retry_max: 0, retry_base_delay_ms: 100,
1953 retry_max_delay_ms: 10000,
1954 };
1955
1956 let client =
1957 OpenAICompatibleClient::new(&config, OpenAICompatibleProvider::OpenAI).unwrap();
1958 let err = client.chat(make_chat_messages()).await.unwrap_err();
1959
1960 assert!(matches!(err, LLMError::RequestFailed(_)));
1961 assert!(format!("{}", err).contains("500"));
1962 mock.assert();
1963 });
1964 }
1965
1966 #[test]
1967 fn test_openai_chat_invalid_json() {
1968 with_mockito(|mut server| async move {
1969 let mock = server
1970 .mock("POST", "/v1/chat/completions")
1971 .with_status(200)
1972 .with_header("content-type", "application/json")
1973 .with_body("not-json")
1974 .create();
1975
1976 let config = LLMConfig {
1977 provider: LLMProvider::OpenAI,
1978 endpoint: server.url(),
1979 model: "gpt-4o".to_string(),
1980 api_key: Some("sk-test".to_string()),
1981 timeout_secs: 30,
1982 system_prompt: crate::config::default_system_prompt(),
1983 token_budget: None,
1984 retry_max: 0, retry_base_delay_ms: 100,
1986 retry_max_delay_ms: 10000,
1987 };
1988
1989 let client =
1990 OpenAICompatibleClient::new(&config, OpenAICompatibleProvider::OpenAI).unwrap();
1991 let err = client.chat(make_chat_messages()).await.unwrap_err();
1992
1993 assert!(matches!(err, LLMError::InvalidResponse(_)));
1994 mock.assert();
1995 });
1996 }
1997
1998 #[test]
2001 fn test_ollama_chat_success() {
2002 with_mockito(|mut server| async move {
2003 let mock = server
2004 .mock("POST", "/api/chat")
2005 .with_status(200)
2006 .with_header("content-type", "application/json")
2007 .with_body(sample_ollama_response_json("llama3.1"))
2008 .create();
2009
2010 let config = LLMConfig {
2011 provider: LLMProvider::Ollama,
2012 endpoint: server.url(),
2013 model: "llama3.1".to_string(),
2014 api_key: None,
2015 timeout_secs: 30,
2016 system_prompt: crate::config::default_system_prompt(),
2017 token_budget: None,
2018 retry_max: 3,
2019 retry_base_delay_ms: 100,
2020 retry_max_delay_ms: 10000,
2021 };
2022
2023 let client = OllamaClient::new(&config).unwrap();
2024 let response = client.chat(make_chat_messages()).await.unwrap();
2025
2026 assert_eq!(response.model, "llama3.1");
2027 assert_eq!(response.choices[0].message.content, "Hi there!");
2028 assert_eq!(response.choices[0].finish_reason, Some("stop".to_string()));
2029 mock.assert();
2030 });
2031 }
2032
2033 #[test]
2034 fn test_ollama_chat_server_error() {
2035 with_mockito(|mut server| async move {
2036 let mock = server
2037 .mock("POST", "/api/chat")
2038 .with_status(500)
2039 .with_header("content-type", "application/json")
2040 .with_body(r#"{"error": "Model not loaded"}"#)
2041 .create();
2042
2043 let config = LLMConfig {
2044 provider: LLMProvider::Ollama,
2045 endpoint: server.url(),
2046 model: "llama3.1".to_string(),
2047 api_key: None,
2048 timeout_secs: 30,
2049 system_prompt: crate::config::default_system_prompt(),
2050 token_budget: None,
2051 retry_max: 3,
2052 retry_base_delay_ms: 100,
2053 retry_max_delay_ms: 10000,
2054 };
2055
2056 let client = OllamaClient::new(&config).unwrap();
2057 let err = client.chat(make_chat_messages()).await.unwrap_err();
2058
2059 assert!(matches!(err, LLMError::RequestFailed(_)));
2060 mock.assert();
2061 });
2062 }
2063
2064 #[test]
2065 fn test_ollama_chat_invalid_json() {
2066 with_mockito(|mut server| async move {
2067 let mock = server
2068 .mock("POST", "/api/chat")
2069 .with_status(200)
2070 .with_header("content-type", "application/json")
2071 .with_body("not-json")
2072 .create();
2073
2074 let config = LLMConfig {
2075 provider: LLMProvider::Ollama,
2076 endpoint: server.url(),
2077 model: "llama3.1".to_string(),
2078 api_key: None,
2079 timeout_secs: 30,
2080 system_prompt: crate::config::default_system_prompt(),
2081 token_budget: None,
2082 retry_max: 3,
2083 retry_base_delay_ms: 100,
2084 retry_max_delay_ms: 10000,
2085 };
2086
2087 let client = OllamaClient::new(&config).unwrap();
2088 let err = client.chat(make_chat_messages()).await.unwrap_err();
2089
2090 assert!(matches!(err, LLMError::InvalidResponse(_)));
2091 mock.assert();
2092 });
2093 }
2094
2095 #[test]
2096 fn test_ollama_chat_auth_failure() {
2097 with_mockito(|mut server| async move {
2098 let mock = server
2099 .mock("POST", "/api/chat")
2100 .with_status(401)
2101 .with_header("content-type", "application/json")
2102 .with_body(r#"{"error": "Unauthorized"}"#)
2103 .create();
2104
2105 let config = LLMConfig {
2106 provider: LLMProvider::Ollama,
2107 endpoint: server.url(),
2108 model: "llama3.1".to_string(),
2109 api_key: Some("bad-key".to_string()),
2110 timeout_secs: 30,
2111 system_prompt: crate::config::default_system_prompt(),
2112 token_budget: None,
2113 retry_max: 3,
2114 retry_base_delay_ms: 100,
2115 retry_max_delay_ms: 10000,
2116 };
2117
2118 let client = OllamaClient::new(&config).unwrap();
2119 let err = client.chat(make_chat_messages()).await.unwrap_err();
2120
2121 assert!(matches!(err, LLMError::AuthFailed));
2122 mock.assert();
2123 });
2124 }
2125
2126 #[test]
2129 fn test_create_client_factory_litellm() {
2130 let config = LLMConfig {
2131 provider: LLMProvider::LiteLLM,
2132 endpoint: "http://localhost:4000".to_string(),
2133 model: "gpt-4o-mini".to_string(),
2134 api_key: Some("test".to_string()),
2135 timeout_secs: 30,
2136 system_prompt: crate::config::default_system_prompt(),
2137 token_budget: None,
2138 retry_max: 3,
2139 retry_base_delay_ms: 100,
2140 retry_max_delay_ms: 10000,
2141 };
2142
2143 let client = create_client(&config).unwrap();
2144 assert_eq!(client.provider_name(), "litellm");
2145 assert_eq!(client.model(), "gpt-4o-mini");
2146 }
2147
2148 #[test]
2149 fn test_ollama_client_creation() {
2150 let config = LLMConfig {
2151 provider: LLMProvider::Ollama,
2152 endpoint: "http://localhost:11434".to_string(),
2153 model: "llama3.1".to_string(),
2154 api_key: None,
2155 timeout_secs: 30,
2156 system_prompt: crate::config::default_system_prompt(),
2157 token_budget: None,
2158 retry_max: 3,
2159 retry_base_delay_ms: 100,
2160 retry_max_delay_ms: 10000,
2161 };
2162
2163 let client = OllamaClient::new(&config).unwrap();
2164 assert_eq!(client.provider_name(), "ollama");
2165 assert_eq!(client.model(), "llama3.1");
2166 }
2167
2168 #[test]
2169 fn test_openai_client_creation() {
2170 let config = LLMConfig {
2171 provider: LLMProvider::OpenAI,
2172 endpoint: String::new(),
2173 model: "gpt-4o".to_string(),
2174 api_key: Some("sk-test".to_string()),
2175 timeout_secs: 60,
2176 system_prompt: crate::config::default_system_prompt(),
2177 token_budget: None,
2178 retry_max: 3,
2179 retry_base_delay_ms: 100,
2180 retry_max_delay_ms: 10000,
2181 };
2182
2183 let client =
2184 OpenAICompatibleClient::new(&config, OpenAICompatibleProvider::OpenAI).unwrap();
2185 assert_eq!(client.provider_name(), "openai");
2186 assert_eq!(client.model(), "gpt-4o");
2187 }
2188
2189 #[test]
2190 fn test_openrouter_client_creation() {
2191 let config = LLMConfig {
2192 provider: LLMProvider::OpenRouter,
2193 endpoint: String::new(),
2194 model: "anthropic/claude-sonnet-4-20250514".to_string(),
2195 api_key: Some("sk-test".to_string()),
2196 timeout_secs: 30,
2197 system_prompt: crate::config::default_system_prompt(),
2198 token_budget: None,
2199 retry_max: 3,
2200 retry_base_delay_ms: 100,
2201 retry_max_delay_ms: 10000,
2202 };
2203
2204 let client =
2205 OpenAICompatibleClient::new(&config, OpenAICompatibleProvider::OpenRouter).unwrap();
2206 assert_eq!(client.provider_name(), "openrouter");
2207 assert_eq!(client.model(), "anthropic/claude-sonnet-4-20250514");
2208 }
2209
2210 #[test]
2211 fn test_multi_model_manager_empty() {
2212 let manager = MultiModelManager::new(vec![]).unwrap();
2213 assert_eq!(manager.client_count(), 0);
2214 assert!(manager.get_client(0).is_none());
2215 }
2216
2217 #[test]
2218 fn test_multi_model_manager_single() {
2219 let config = LLMConfig {
2220 provider: LLMProvider::LiteLLM,
2221 endpoint: "http://localhost:4000".to_string(),
2222 model: "gpt-4o-mini".to_string(),
2223 api_key: Some("test".to_string()),
2224 timeout_secs: 30,
2225 system_prompt: crate::config::default_system_prompt(),
2226 token_budget: None,
2227 retry_max: 3,
2228 retry_base_delay_ms: 100,
2229 retry_max_delay_ms: 10000,
2230 };
2231
2232 let manager = MultiModelManager::new(vec![config]).unwrap();
2233 assert_eq!(manager.client_count(), 1);
2234 assert!(manager.get_client(0).is_some());
2235 assert_eq!(manager.get_client(0).unwrap().provider_name(), "litellm");
2236 }
2237
2238 #[test]
2239 fn test_multi_model_manager_multiple() {
2240 let configs = vec![
2241 LLMConfig {
2242 provider: LLMProvider::LiteLLM,
2243 endpoint: "http://localhost:4000".to_string(),
2244 model: "gpt-4o-mini".to_string(),
2245 api_key: Some("test".to_string()),
2246 timeout_secs: 30,
2247 system_prompt: crate::config::default_system_prompt(),
2248 token_budget: None,
2249 retry_max: 3,
2250 retry_base_delay_ms: 100,
2251 retry_max_delay_ms: 10000,
2252 },
2253 LLMConfig {
2254 provider: LLMProvider::Ollama,
2255 endpoint: "http://localhost:11434".to_string(),
2256 model: "llama3.1".to_string(),
2257 api_key: None,
2258 timeout_secs: 60,
2259 system_prompt: crate::config::default_system_prompt(),
2260 token_budget: None,
2261 retry_max: 3,
2262 retry_base_delay_ms: 100,
2263 retry_max_delay_ms: 10000,
2264 },
2265 ];
2266
2267 let manager = MultiModelManager::new(configs).unwrap();
2268 assert_eq!(manager.client_count(), 2);
2269 assert_eq!(manager.get_client(0).unwrap().provider_name(), "litellm");
2270 assert_eq!(manager.get_client(1).unwrap().provider_name(), "ollama");
2271 }
2272
2273 #[test]
2274 fn test_multi_model_next_client_round_robin() {
2275 let configs = vec![
2276 LLMConfig {
2277 provider: LLMProvider::LiteLLM,
2278 endpoint: "http://localhost:4000".to_string(),
2279 model: "gpt-4o-mini".to_string(),
2280 api_key: Some("test".to_string()),
2281 timeout_secs: 30,
2282 system_prompt: crate::config::default_system_prompt(),
2283 token_budget: None,
2284 retry_max: 3,
2285 retry_base_delay_ms: 100,
2286 retry_max_delay_ms: 10000,
2287 },
2288 LLMConfig {
2289 provider: LLMProvider::Ollama,
2290 endpoint: "http://localhost:11434".to_string(),
2291 model: "llama3.1".to_string(),
2292 api_key: None,
2293 timeout_secs: 60,
2294 system_prompt: crate::config::default_system_prompt(),
2295 token_budget: None,
2296 retry_max: 3,
2297 retry_base_delay_ms: 100,
2298 retry_max_delay_ms: 10000,
2299 },
2300 ];
2301
2302 let manager = MultiModelManager::new(configs).unwrap();
2303 let next = manager.next_client(0).unwrap();
2305 assert_eq!(next.provider_name(), "ollama");
2306 let next = manager.next_client(1).unwrap();
2308 assert_eq!(next.provider_name(), "litellm");
2309 }
2310
2311 #[test]
2312 fn test_chat_request_serialization() {
2313 let request = ChatRequest {
2314 model: "gpt-4o-mini".to_string(),
2315 messages: vec![
2316 ChatMessage {
2317 role: "system".to_string(),
2318 content: "You are a helpful assistant.".to_string(),
2319 },
2320 ChatMessage {
2321 role: "user".to_string(),
2322 content: "Hello!".to_string(),
2323 },
2324 ],
2325 temperature: Some(0.7),
2326 max_tokens: Some(2048),
2327 stream: None,
2328 tools: None,
2329 tool_choice: None,
2330 };
2331
2332 let json = serde_json::to_string(&request).unwrap();
2333 assert!(json.contains("gpt-4o-mini"));
2334 assert!(json.contains("system"));
2335 assert!(json.contains("user"));
2336 assert!(json.contains("Hello!"));
2337 assert!(json.contains("0.7"));
2338 assert!(!json.contains("stream"));
2340 }
2341
2342 #[test]
2343 fn test_chat_response_deserialization() {
2344 let json = r#"{
2345 "id": "chat-123",
2346 "object": "chat.completion",
2347 "created": 1717000000,
2348 "model": "gpt-4o-mini",
2349 "choices": [
2350 {
2351 "index": 0,
2352 "message": {
2353 "role": "assistant",
2354 "content": "Hello! How can I help you?"
2355 },
2356 "finish_reason": "stop"
2357 }
2358 ],
2359 "usage": {
2360 "prompt_tokens": 10,
2361 "completion_tokens": 20,
2362 "total_tokens": 30
2363 }
2364 }"#;
2365
2366 let response: ChatResponse = serde_json::from_str(json).unwrap();
2367 assert_eq!(response.id, "chat-123");
2368 assert_eq!(response.model, "gpt-4o-mini");
2369 assert_eq!(response.choices.len(), 1);
2370 assert_eq!(response.choices[0].message.role, "assistant");
2371 assert_eq!(
2372 response.choices[0].message.content,
2373 "Hello! How can I help you?"
2374 );
2375 assert_eq!(response.usage.unwrap().total_tokens, 30);
2376 }
2377
/// Builds a `MultiModelManager` from a single LiteLLM config whose endpoint
/// is the empty string and no API key. The test expects construction to
/// succeed and the manager to report exactly one client.
#[test]
fn test_multi_model_manager_new_invalid_config() {
    let empty_endpoint_config = LLMConfig {
        provider: LLMProvider::LiteLLM,
        endpoint: String::new(),
        model: String::from("gpt-4o-mini"),
        api_key: None,
        timeout_secs: 30,
        system_prompt: crate::config::default_system_prompt(),
        token_budget: None,
        retry_max: 3,
        retry_base_delay_ms: 100,
        retry_max_delay_ms: 10000,
    };

    let outcome = MultiModelManager::new(vec![empty_endpoint_config]);
    assert!(outcome.is_ok());

    let manager = outcome.unwrap();
    assert_eq!(manager.client_count(), 1);
}
2401
/// Calls `create_client` once per listed provider and checks that the
/// resulting client reports the expected provider name and the configured
/// model.
#[test]
fn test_create_client_all_providers() {
    // Values shared by every generated config.
    let endpoint = "http://localhost:4000";
    let model = "test-model";

    // (provider, name the created client is expected to report)
    let expectations = [
        (LLMProvider::LiteLLM, "litellm"),
        (LLMProvider::OpenRouter, "openrouter"),
        (LLMProvider::Ollama, "ollama"),
        (LLMProvider::OpenAI, "openai"),
    ];

    for (provider, expected_name) in expectations {
        let config = LLMConfig {
            provider,
            endpoint: String::from(endpoint),
            model: String::from(model),
            api_key: Some(String::from("test-key")),
            timeout_secs: 30,
            system_prompt: crate::config::default_system_prompt(),
            token_budget: None,
            retry_max: 3,
            retry_base_delay_ms: 100,
            retry_max_delay_ms: 10000,
        };

        let client = create_client(&config).unwrap();
        assert_eq!(client.provider_name(), expected_name);
        assert_eq!(client.model(), "test-model");
    }
}
2430
/// Checks the `Display` text of `LLMError` variants against the messages
/// declared in their `#[error(...)]` attributes.
///
/// The cases are table-driven so a further variant can be covered by adding
/// one row. `TokenBudgetExceeded` is included alongside the other variants.
#[test]
fn test_llm_error_display() {
    // (error value, expected rendered message)
    let cases = [
        (
            LLMError::RequestFailed("timeout".to_string()),
            "Request failed: timeout",
        ),
        (LLMError::AuthFailed, "Authentication failed"),
        (LLMError::RateLimited, "Rate limit exceeded"),
        (
            LLMError::InvalidResponse("bad json".to_string()),
            "Invalid response: bad json",
        ),
        (
            LLMError::ProviderNotSupported("custom".to_string()),
            "Provider not supported: custom",
        ),
        (LLMError::TokenBudgetExceeded, "Token budget exceeded"),
    ];

    for (err, expected) in cases {
        // `to_string()` goes through the same `Display` impl as `format!("{}", ..)`.
        assert_eq!(err.to_string(), expected);
    }
}
2448
/// The `Debug` rendering of a `RequestFailed` error must mention the variant
/// name.
#[test]
fn test_llm_error_is_debug() {
    let rendered = format!("{:?}", LLMError::RequestFailed(String::from("test")));
    assert!(rendered.contains("RequestFailed"));
}
2455
/// Compile-time check that `LLMError` satisfies a `Send` bound; the helper
/// has an empty body, so nothing is verified at run time.
#[test]
fn test_llm_error_is_send() {
    fn require_send<E>()
    where
        E: Send,
    {
    }

    require_send::<LLMError>();
}
2461
/// Compile-time check that `LLMError` satisfies a `Sync` bound; the helper
/// has an empty body, so nothing is verified at run time.
#[test]
fn test_llm_error_is_sync() {
    fn require_sync<E>()
    where
        E: Sync,
    {
    }

    require_sync::<LLMError>();
}
2467
/// Constructs a LiteLLM-flavoured `OpenAICompatibleClient` with
/// `timeout_secs` set to `u64::MAX` and asserts that construction returns
/// `Ok`.
#[test]
fn test_litellm_client_new_with_invalid_timeout() {
    let extreme_timeout = u64::MAX;

    let config = LLMConfig {
        provider: LLMProvider::LiteLLM,
        endpoint: String::from("http://localhost:4000"),
        model: String::from("gpt-4o-mini"),
        api_key: Some(String::from("test")),
        timeout_secs: extreme_timeout,
        system_prompt: crate::config::default_system_prompt(),
        token_budget: None,
        retry_max: 3,
        retry_base_delay_ms: 100,
        retry_max_delay_ms: 10000,
    };

    let provider_kind = OpenAICompatibleProvider::LiteLLM;
    assert!(OpenAICompatibleClient::new(&config, provider_kind).is_ok());
}
2487
/// Builds an OpenAI-flavoured `OpenAICompatibleClient` pointed at an
/// endpoint other than the provider's default and checks the reported
/// provider name and model.
#[test]
fn test_openai_client_with_custom_endpoint() {
    let custom_endpoint = String::from("https://custom.openai.example.com");
    let api_key = Some(String::from("sk-test"));

    let config = LLMConfig {
        provider: LLMProvider::OpenAI,
        endpoint: custom_endpoint,
        model: String::from("gpt-4o"),
        api_key,
        timeout_secs: 30,
        system_prompt: crate::config::default_system_prompt(),
        token_budget: None,
        retry_max: 3,
        retry_base_delay_ms: 100,
        retry_max_delay_ms: 10000,
    };

    let provider_kind = OpenAICompatibleProvider::OpenAI;
    let client = OpenAICompatibleClient::new(&config, provider_kind).unwrap();

    assert_eq!(client.provider_name(), "openai");
    assert_eq!(client.model(), "gpt-4o");
}
2508
/// Builds an OpenRouter-flavoured `OpenAICompatibleClient` pointed at an
/// endpoint other than the provider's default and checks the reported
/// provider name and model.
#[test]
fn test_openrouter_client_with_custom_endpoint() {
    let custom_endpoint = String::from("https://custom.openrouter.example.com");
    let api_key = Some(String::from("or-key"));

    let config = LLMConfig {
        provider: LLMProvider::OpenRouter,
        endpoint: custom_endpoint,
        model: String::from("anthropic/claude-sonnet-4-20250514"),
        api_key,
        timeout_secs: 30,
        system_prompt: crate::config::default_system_prompt(),
        token_budget: None,
        retry_max: 3,
        retry_base_delay_ms: 100,
        retry_max_delay_ms: 10000,
    };

    let provider_kind = OpenAICompatibleProvider::OpenRouter;
    let client = OpenAICompatibleClient::new(&config, provider_kind).unwrap();

    assert_eq!(client.provider_name(), "openrouter");
    assert_eq!(client.model(), "anthropic/claude-sonnet-4-20250514");
}
2529
/// Builds an `OllamaClient` from a config that carries an API key and checks
/// the reported provider name and model.
#[test]
fn test_ollama_client_with_auth() {
    let api_key = Some(String::from("some-key"));

    let config = LLMConfig {
        provider: LLMProvider::Ollama,
        endpoint: String::from("http://localhost:11434"),
        model: String::from("llama3.1"),
        api_key,
        timeout_secs: 30,
        system_prompt: crate::config::default_system_prompt(),
        token_budget: None,
        retry_max: 3,
        retry_base_delay_ms: 100,
        retry_max_delay_ms: 10000,
    };

    let ollama = OllamaClient::new(&config).unwrap();
    assert_eq!(ollama.provider_name(), "ollama");
    assert_eq!(ollama.model(), "llama3.1");
}
2550
/// Serializes a `ChatRequest` whose optional fields are all `None` and checks
/// that the model and message text appear in the JSON while the
/// `temperature`, `max_tokens` and `stream` keys do not.
#[test]
fn test_chat_request_no_temperature() {
    let only_message = ChatMessage {
        role: String::from("user"),
        content: String::from("Hello!"),
    };

    let request = ChatRequest {
        model: String::from("gpt-4o-mini"),
        messages: vec![only_message],
        temperature: None,
        max_tokens: None,
        stream: None,
        tools: None,
        tool_choice: None,
    };

    let serialized = serde_json::to_string(&request).unwrap();

    // Required data is present.
    assert!(serialized.contains("gpt-4o-mini"));
    assert!(serialized.contains("Hello!"));

    // Unset optional fields are absent from the output.
    assert!(!serialized.contains("temperature"));
    assert!(!serialized.contains("max_tokens"));
    assert!(!serialized.contains("stream"));
}
2574
/// Parses a chat-completion payload that has no `usage` object and checks
/// that deserialization succeeds with `usage` left as `None`.
#[test]
fn test_chat_response_deserialization_no_usage() {
    let payload = r#"{
        "id": "chat-456",
        "object": "chat.completion",
        "created": 1717000001,
        "model": "gpt-4o",
        "choices": [
            {
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": "Sure!"
                },
                "finish_reason": "stop"
            }
        ]
    }"#;

    let parsed: ChatResponse = serde_json::from_str(payload).unwrap();

    assert_eq!(parsed.id, "chat-456");
    assert_eq!(parsed.model, "gpt-4o");
    assert_eq!(parsed.choices.len(), 1);

    // The payload above carries no "usage" key.
    assert!(parsed.usage.is_none());
}
2600
/// Parses a chat-completion payload with two choices and checks that both
/// are kept, in payload order, with their message contents intact.
#[test]
fn test_chat_response_deserialization_multiple_choices() {
    let payload = r#"{
        "id": "chat-789",
        "object": "chat.completion",
        "created": 1717000002,
        "model": "gpt-4o",
        "choices": [
            {
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": "First choice"
                },
                "finish_reason": "stop"
            },
            {
                "index": 1,
                "message": {
                    "role": "assistant",
                    "content": "Second choice"
                },
                "finish_reason": "stop"
            }
        ]
    }"#;

    let parsed: ChatResponse = serde_json::from_str(payload).unwrap();
    let choices = &parsed.choices;

    assert_eq!(choices.len(), 2);
    assert_eq!(choices[0].message.content, "First choice");
    assert_eq!(choices[1].message.content, "Second choice");
}
2633
/// Boxes an `LLMError` as `Box<dyn std::error::Error>` and checks that the
/// boxed value still renders the authentication-failure message.
#[test]
fn test_llm_error_into_boxed() {
    let as_trait_object: Box<dyn std::error::Error> = Box::new(LLMError::AuthFailed);
    let rendered = as_trait_object.to_string();
    assert!(rendered.contains("Authentication failed"));
}
2640
/// Converts a `RateLimited` error to a `String` via `to_string()` and checks
/// the exact message.
#[test]
fn test_llm_error_into_string() {
    let rendered: String = LLMError::RateLimited.to_string();
    assert_eq!(rendered, "Rate limit exceeded");
}
2647
/// Calls `create_client` with an Ollama config whose `api_key` is `None` and
/// checks that a client reporting the `ollama` provider name is produced.
#[test]
fn test_create_client_with_empty_api_key() {
    let keyless_config = LLMConfig {
        provider: LLMProvider::Ollama,
        endpoint: String::from("http://localhost:11434"),
        model: String::from("llama3.1"),
        api_key: None,
        timeout_secs: 30,
        system_prompt: crate::config::default_system_prompt(),
        token_budget: None,
        retry_max: 3,
        retry_base_delay_ms: 100,
        retry_max_delay_ms: 10000,
    };

    let created = create_client(&keyless_config).unwrap();
    assert_eq!(created.provider_name(), "ollama");
}
2667
/// On a manager built from no configs, `get_client` must return `None` for
/// index 0, a mid-range index, and `usize::MAX`.
#[test]
fn test_multi_model_manager_get_client_out_of_bounds() {
    let empty_manager = MultiModelManager::new(Vec::new()).unwrap();

    // Every index is out of bounds when there are no clients.
    for index in [0, 100, usize::MAX] {
        assert!(empty_manager.get_client(index).is_none());
    }
}
2675
/// On a manager built from no configs, `next_client(0)` must return `None`.
#[test]
fn test_multi_model_next_client_empty() {
    let empty_manager = MultiModelManager::new(Vec::new()).unwrap();
    let successor = empty_manager.next_client(0);
    assert!(successor.is_none());
}
2681
/// With exactly one configured client, `next_client(0)` is expected to yield
/// a client, and that client reports the `litellm` provider name.
#[test]
fn test_multi_model_next_client_single() {
    let sole_config = LLMConfig {
        provider: LLMProvider::LiteLLM,
        endpoint: String::from("http://localhost:4000"),
        model: String::from("gpt-4o-mini"),
        api_key: Some(String::from("test")),
        timeout_secs: 30,
        system_prompt: crate::config::default_system_prompt(),
        token_budget: None,
        retry_max: 3,
        retry_base_delay_ms: 100,
        retry_max_delay_ms: 10000,
    };

    let single_client_manager = MultiModelManager::new(vec![sole_config]).unwrap();

    let successor = single_client_manager.next_client(0).unwrap();
    assert_eq!(successor.provider_name(), "litellm");
}
2702}