1use crate::ModelProvider;
5use crate::native::{
6 NativeMediaJob, NativeMediaRequest, NativeMediaResponse, ProviderNativeCapabilities,
7};
8use async_trait::async_trait;
9use std::collections::HashMap;
10use std::sync::atomic::{AtomicUsize, Ordering};
11use std::time::Duration;
12
13fn is_non_retryable(err: &anyhow::Error) -> bool {
15 if let Some(reqwest_err) = err.downcast_ref::<reqwest::Error>()
16 && let Some(status) = reqwest_err.status()
17 {
18 let code = status.as_u16();
19 return status.is_client_error() && code != 429 && code != 408;
20 }
21 let msg = err.to_string();
22 for word in msg.split(|c: char| !c.is_ascii_digit()) {
23 if let Ok(code) = word.parse::<u16>()
24 && (400..500).contains(&code)
25 {
26 return code != 429 && code != 408;
27 }
28 }
29 false
30}
31
32fn is_rate_limited(err: &anyhow::Error) -> bool {
34 if let Some(reqwest_err) = err.downcast_ref::<reqwest::Error>()
35 && let Some(status) = reqwest_err.status()
36 {
37 return status.as_u16() == 429;
38 }
39 let msg = err.to_string();
40 msg.contains("429")
41 && (msg.contains("Too Many") || msg.contains("rate") || msg.contains("limit"))
42}
43
44fn parse_retry_after_ms(err: &anyhow::Error) -> Option<u64> {
47 let msg = err.to_string();
48 let lower = msg.to_lowercase();
49
50 for prefix in &[
52 "retry-after:",
53 "retry_after:",
54 "retry-after ",
55 "retry_after ",
56 ] {
57 if let Some(pos) = lower.find(prefix) {
58 let after = &msg[pos + prefix.len()..];
59 let num_str: String = after
60 .trim()
61 .chars()
62 .take_while(|c| c.is_ascii_digit() || *c == '.')
63 .collect();
64 if let Ok(secs) = num_str.parse::<f64>()
65 && secs.is_finite()
66 && secs >= 0.0
67 {
68 let millis = Duration::from_secs_f64(secs).as_millis();
69 if let Ok(value) = u64::try_from(millis) {
70 return Some(value);
71 }
72 }
73 }
74 }
75 None
76}
77
78pub struct ReliableProvider {
80 providers: Vec<(String, Box<dyn ModelProvider>)>,
81 max_retries: u32,
82 base_backoff_ms: u64,
83 api_keys: Vec<String>,
85 key_index: AtomicUsize,
86 model_fallbacks: HashMap<String, Vec<String>>,
88}
89
90impl ReliableProvider {
91 pub fn new(
92 providers: Vec<(String, Box<dyn ModelProvider>)>,
93 max_retries: u32,
94 base_backoff_ms: u64,
95 ) -> Self {
96 Self {
97 providers,
98 max_retries,
99 base_backoff_ms: base_backoff_ms.max(50),
100 api_keys: Vec::new(),
101 key_index: AtomicUsize::new(0),
102 model_fallbacks: HashMap::new(),
103 }
104 }
105
106 pub fn with_api_keys(mut self, keys: Vec<String>) -> Self {
108 self.api_keys = keys;
109 self
110 }
111
112 pub fn with_model_fallbacks(mut self, fallbacks: HashMap<String, Vec<String>>) -> Self {
114 self.model_fallbacks = fallbacks;
115 self
116 }
117
118 fn model_chain<'a>(&'a self, model: &'a str) -> Vec<&'a str> {
120 let mut chain = vec![model];
121 if let Some(fallbacks) = self.model_fallbacks.get(model) {
122 chain.extend(fallbacks.iter().map(|s| s.as_str()));
123 }
124 chain
125 }
126
127 fn rotate_key(&self) -> Option<&str> {
129 if self.api_keys.is_empty() {
130 return None;
131 }
132 let idx = self.key_index.fetch_add(1, Ordering::Relaxed) % self.api_keys.len();
133 Some(&self.api_keys[idx])
134 }
135
136 fn compute_backoff(&self, base: u64, err: &anyhow::Error) -> u64 {
138 if let Some(retry_after) = parse_retry_after_ms(err) {
139 retry_after.min(30_000).max(base)
141 } else {
142 base
143 }
144 }
145}
146
147#[async_trait]
148impl ModelProvider for ReliableProvider {
149 async fn warmup(&self) -> anyhow::Result<()> {
150 for (name, provider) in &self.providers {
151 tracing::info!(provider = name, "Warming up provider connection pool");
152 if let Err(e) = provider.warmup().await {
153 tracing::warn!(provider = name, "Warmup failed (non-fatal): {e}");
154 }
155 }
156 Ok(())
157 }
158
159 async fn chat(
160 &self,
161 request: super::ChatRequest<'_>,
162 model: &str,
163 temperature: f64,
164 ) -> anyhow::Result<super::ChatResponse> {
165 let models = self.model_chain(model);
166 let mut failures = Vec::new();
167
168 for current_model in &models {
169 for (provider_name, provider) in &self.providers {
170 let mut backoff_ms = self.base_backoff_ms;
171
172 for attempt in 0..=self.max_retries {
173 match provider.chat(request, current_model, temperature).await {
174 Ok(resp) => {
175 if attempt > 0 || *current_model != model {
176 tracing::info!(
177 provider = provider_name,
178 model = *current_model,
179 attempt,
180 original_model = model,
181 "Provider recovered (failover/retry)"
182 );
183 }
184 return Ok(resp);
185 }
186 Err(e) => {
187 let non_retryable = is_non_retryable(&e);
188 let rate_limited = is_rate_limited(&e);
189
190 failures.push(format!(
191 "{provider_name}/{current_model} attempt {}/{}: {e}",
192 attempt + 1,
193 self.max_retries + 1
194 ));
195
196 if rate_limited && let Some(new_key) = self.rotate_key() {
197 tracing::info!(
198 provider = provider_name,
199 "Rate limited, rotated API key (key ending ...{})",
200 &new_key[new_key.len().saturating_sub(4)..]
201 );
202 }
203
204 if non_retryable {
205 tracing::warn!(
206 provider = provider_name,
207 model = *current_model,
208 "Non-retryable error, moving on"
209 );
210 break;
211 }
212
213 if attempt < self.max_retries {
214 let wait = self.compute_backoff(backoff_ms, &e);
215 tracing::warn!(
216 provider = provider_name,
217 model = *current_model,
218 attempt = attempt + 1,
219 backoff_ms = wait,
220 "Provider call failed, retrying"
221 );
222 tokio::time::sleep(Duration::from_millis(wait)).await;
223 backoff_ms = (backoff_ms.saturating_mul(2)).min(10_000);
224 }
225 }
226 }
227 }
228
229 tracing::warn!(
230 provider = provider_name,
231 model = *current_model,
232 "Exhausted retries, trying next provider/model"
233 );
234 }
235 }
236
237 anyhow::bail!(
238 "All providers/models failed. Attempts:\n{}",
239 failures.join("\n")
240 )
241 }
242
243 async fn chat_stream(
244 &self,
245 request: super::ChatRequest<'_>,
246 model: &str,
247 temperature: f64,
248 events: tokio::sync::mpsc::UnboundedSender<super::ProviderStreamEvent>,
249 ) -> anyhow::Result<super::ChatResponse> {
250 let models = self.model_chain(model);
251 let mut failures = Vec::new();
252
253 for current_model in &models {
254 for (provider_name, provider) in &self.providers {
255 let mut backoff_ms = self.base_backoff_ms;
256
257 for attempt in 0..=self.max_retries {
258 match provider
259 .chat_stream(request, current_model, temperature, events.clone())
260 .await
261 {
262 Ok(resp) => {
263 if attempt > 0 || *current_model != model {
264 tracing::info!(
265 provider = provider_name,
266 model = *current_model,
267 attempt,
268 original_model = model,
269 "Provider streaming call recovered (failover/retry)"
270 );
271 }
272 return Ok(resp);
273 }
274 Err(e) => {
275 let non_retryable = is_non_retryable(&e);
276 let rate_limited = is_rate_limited(&e);
277
278 failures.push(format!(
279 "{provider_name}/{current_model} streaming attempt {}/{}: {e}",
280 attempt + 1,
281 self.max_retries + 1
282 ));
283
284 if rate_limited && let Some(new_key) = self.rotate_key() {
285 tracing::info!(
286 provider = provider_name,
287 "Rate limited, rotated API key (key ending ...{})",
288 &new_key[new_key.len().saturating_sub(4)..]
289 );
290 }
291
292 if non_retryable {
293 tracing::warn!(
294 provider = provider_name,
295 model = *current_model,
296 "Non-retryable streaming error, moving on"
297 );
298 break;
299 }
300
301 if attempt < self.max_retries {
302 let wait = self.compute_backoff(backoff_ms, &e);
303 tracing::warn!(
304 provider = provider_name,
305 model = *current_model,
306 attempt = attempt + 1,
307 backoff_ms = wait,
308 "Provider streaming call failed, retrying"
309 );
310 tokio::time::sleep(Duration::from_millis(wait)).await;
311 backoff_ms = (backoff_ms.saturating_mul(2)).min(10_000);
312 }
313 }
314 }
315 }
316
317 tracing::warn!(
318 provider = provider_name,
319 model = *current_model,
320 "Exhausted streaming retries, trying next provider/model"
321 );
322 }
323 }
324
325 anyhow::bail!(
326 "All providers/models failed. Attempts:\n{}",
327 failures.join("\n")
328 )
329 }
330
331 fn context_window(&self, model: &str) -> Option<usize> {
332 self.providers
333 .first()
334 .and_then(|(_, p)| p.context_window(model))
335 }
336
337 fn supports_native_tools(&self) -> bool {
338 self.providers
339 .first()
340 .map(|(_, p)| p.supports_native_tools())
341 .unwrap_or(false)
342 }
343
344 fn supports_developer_role(&self, model: &str) -> bool {
345 self.providers
346 .first()
347 .map(|(_, p)| p.supports_developer_role(model))
348 .unwrap_or(false)
349 }
350
351 fn native_capabilities(&self) -> Option<ProviderNativeCapabilities> {
352 self.providers
353 .first()
354 .and_then(|(_, p)| p.native_capabilities())
355 }
356
357 async fn submit_media(
358 &self,
359 request: NativeMediaRequest,
360 ) -> anyhow::Result<NativeMediaResponse> {
361 let Some((provider_name, provider)) = self.providers.first() else {
362 anyhow::bail!("no provider configured for native media operation");
363 };
364
365 provider
366 .submit_media(request)
367 .await
368 .map_err(|err| anyhow::anyhow!("{provider_name} native media operation failed: {err}"))
369 }
370
371 async fn poll_media_job(&self, job: &NativeMediaJob) -> anyhow::Result<NativeMediaResponse> {
372 if let Some((_, provider)) = self
373 .providers
374 .iter()
375 .find(|(provider_name, _)| provider_name == &job.provider)
376 {
377 return provider.poll_media_job(job).await;
378 }
379
380 let Some((provider_name, provider)) = self.providers.first() else {
381 anyhow::bail!("no provider configured for native media job polling");
382 };
383
384 provider
385 .poll_media_job(job)
386 .await
387 .map_err(|err| anyhow::anyhow!("{provider_name} native media job poll failed: {err}"))
388 }
389}
390
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::{ChatMessage, ChatRequest, ChatResponse, TokenUsage, one_shot};
    use crate::{ProviderStreamEvent, ProviderToolTrace};
    use std::sync::Arc;

    /// Builds a text-only response with no tool calls and default usage.
    fn text_response(text: &str) -> ChatResponse {
        ChatResponse {
            text: Some(text.to_string()),
            tool_calls: vec![],
            provider_tool_calls: vec![],
            usage: TokenUsage::default(),
        }
    }

    /// Fails its first `fail_until_attempt` calls with `error`, then answers `response`.
    struct MockProvider {
        calls: Arc<AtomicUsize>,
        fail_until_attempt: usize,
        response: &'static str,
        error: &'static str,
    }

    impl MockProvider {
        /// Builds a named provider entry that shares the `calls` counter with the test.
        fn entry(
            name: &str,
            calls: &Arc<AtomicUsize>,
            fail_until_attempt: usize,
            response: &'static str,
            error: &'static str,
        ) -> (String, Box<dyn ModelProvider>) {
            let provider = MockProvider {
                calls: Arc::clone(calls),
                fail_until_attempt,
                response,
                error,
            };
            (name.to_string(), Box::new(provider))
        }
    }

    #[async_trait]
    impl ModelProvider for MockProvider {
        async fn chat(
            &self,
            _request: ChatRequest<'_>,
            _model: &str,
            _temperature: f64,
        ) -> anyhow::Result<ChatResponse> {
            let earlier_calls = self.calls.fetch_add(1, Ordering::SeqCst);
            if earlier_calls < self.fail_until_attempt {
                anyhow::bail!(self.error);
            }
            Ok(text_response(self.response))
        }
    }

    /// Counts plain and streaming calls separately and emits one tool event when streaming.
    struct StreamingMockProvider {
        chat_calls: Arc<AtomicUsize>,
        stream_calls: Arc<AtomicUsize>,
    }

    #[async_trait]
    impl ModelProvider for StreamingMockProvider {
        async fn chat(
            &self,
            _request: ChatRequest<'_>,
            _model: &str,
            _temperature: f64,
        ) -> anyhow::Result<ChatResponse> {
            self.chat_calls.fetch_add(1, Ordering::SeqCst);
            Ok(text_response("non-streaming"))
        }

        async fn chat_stream(
            &self,
            _request: ChatRequest<'_>,
            _model: &str,
            _temperature: f64,
            events: tokio::sync::mpsc::UnboundedSender<ProviderStreamEvent>,
        ) -> anyhow::Result<ChatResponse> {
            self.stream_calls.fetch_add(1, Ordering::SeqCst);
            let trace = ProviderToolTrace {
                id: "provider-tool-1".to_string(),
                name: "web_search".to_string(),
                provider: "xai".to_string(),
                input: serde_json::json!({"query": "test"}),
                output: None,
                citations: vec![],
            };
            events
                .send(ProviderStreamEvent::ProviderToolStarted(trace))
                .ok();
            Ok(text_response("streaming"))
        }
    }

    /// Records every model it is asked for and fails the ones in `fail_models`.
    struct ModelAwareMock {
        calls: Arc<AtomicUsize>,
        models_seen: std::sync::Mutex<Vec<String>>,
        fail_models: Vec<&'static str>,
        response: &'static str,
    }

    #[async_trait]
    impl ModelProvider for ModelAwareMock {
        async fn chat(
            &self,
            _request: ChatRequest<'_>,
            model: &str,
            _temperature: f64,
        ) -> anyhow::Result<ChatResponse> {
            self.calls.fetch_add(1, Ordering::SeqCst);
            self.models_seen.lock().unwrap().push(model.to_string());
            if self.fail_models.iter().any(|failing| *failing == model) {
                anyhow::bail!("500 model {} unavailable", model);
            }
            Ok(text_response(self.response))
        }
    }

    /// Lets a test keep a handle on the mock while the wrapper owns a boxed clone.
    #[async_trait]
    impl ModelProvider for Arc<ModelAwareMock> {
        async fn chat(
            &self,
            request: ChatRequest<'_>,
            model: &str,
            temperature: f64,
        ) -> anyhow::Result<ChatResponse> {
            self.as_ref().chat(request, model, temperature).await
        }
    }

    #[tokio::test]
    async fn succeeds_without_retry() {
        let calls = Arc::new(AtomicUsize::new(0));
        let provider = ReliableProvider::new(
            vec![MockProvider::entry("primary", &calls, 0, "ok", "boom")],
            2,
            1,
        );

        let result = one_shot(&provider, None, "hello", "test", 0.0)
            .await
            .unwrap();

        assert_eq!(result, "ok");
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn retries_then_recovers() {
        let calls = Arc::new(AtomicUsize::new(0));
        let provider = ReliableProvider::new(
            vec![MockProvider::entry(
                "primary",
                &calls,
                1,
                "recovered",
                "temporary",
            )],
            2,
            1,
        );

        let result = one_shot(&provider, None, "hello", "test", 0.0)
            .await
            .unwrap();

        assert_eq!(result, "recovered");
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn falls_back_after_retries_exhausted() {
        let primary_calls = Arc::new(AtomicUsize::new(0));
        let fallback_calls = Arc::new(AtomicUsize::new(0));
        let provider = ReliableProvider::new(
            vec![
                MockProvider::entry(
                    "primary",
                    &primary_calls,
                    usize::MAX,
                    "never",
                    "primary down",
                ),
                MockProvider::entry(
                    "fallback",
                    &fallback_calls,
                    0,
                    "from fallback",
                    "fallback down",
                ),
            ],
            1,
            1,
        );

        let result = one_shot(&provider, None, "hello", "test", 0.0)
            .await
            .unwrap();

        assert_eq!(result, "from fallback");
        assert_eq!(primary_calls.load(Ordering::SeqCst), 2);
        assert_eq!(fallback_calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn returns_aggregated_error_when_all_providers_fail() {
        let provider = ReliableProvider::new(
            vec![
                MockProvider::entry(
                    "p1",
                    &Arc::new(AtomicUsize::new(0)),
                    usize::MAX,
                    "never",
                    "p1 error",
                ),
                MockProvider::entry(
                    "p2",
                    &Arc::new(AtomicUsize::new(0)),
                    usize::MAX,
                    "never",
                    "p2 error",
                ),
            ],
            0,
            1,
        );

        let err = one_shot(&provider, None, "hello", "test", 0.0)
            .await
            .expect_err("all providers should fail");

        let msg = err.to_string();
        assert!(msg.contains("All providers/models failed"));
        assert!(msg.contains("p1"));
        assert!(msg.contains("p2"));
    }

    #[test]
    fn non_retryable_detects_common_patterns() {
        let permanent = [
            "400 Bad Request",
            "401 Unauthorized",
            "403 Forbidden",
            "404 Not Found",
        ];
        for msg in permanent {
            assert!(is_non_retryable(&anyhow::anyhow!(msg)));
        }

        let transient = [
            "429 Too Many Requests",
            "408 Request Timeout",
            "500 Internal Server Error",
            "502 Bad Gateway",
            "timeout",
            "connection reset",
        ];
        for msg in transient {
            assert!(!is_non_retryable(&anyhow::anyhow!(msg)));
        }
    }

    #[tokio::test]
    async fn skips_retries_on_non_retryable_error() {
        let primary_calls = Arc::new(AtomicUsize::new(0));
        let fallback_calls = Arc::new(AtomicUsize::new(0));
        let provider = ReliableProvider::new(
            vec![
                MockProvider::entry(
                    "primary",
                    &primary_calls,
                    usize::MAX,
                    "never",
                    "401 Unauthorized",
                ),
                MockProvider::entry(
                    "fallback",
                    &fallback_calls,
                    0,
                    "from fallback",
                    "fallback err",
                ),
            ],
            3,
            1,
        );

        let result = one_shot(&provider, None, "hello", "test", 0.0)
            .await
            .unwrap();

        assert_eq!(result, "from fallback");
        // A 401 must not be retried even though three retries are allowed.
        assert_eq!(primary_calls.load(Ordering::SeqCst), 1);
        assert_eq!(fallback_calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn chat_retries_then_recovers() {
        let calls = Arc::new(AtomicUsize::new(0));
        let provider = ReliableProvider::new(
            vec![MockProvider::entry(
                "primary",
                &calls,
                1,
                "history ok",
                "temporary",
            )],
            2,
            1,
        );

        let messages = vec![ChatMessage::system("system"), ChatMessage::user("hello")];
        let request = ChatRequest {
            messages: &messages,
            tools: None,
            native_tools: None,
        };
        let result = provider.chat(request, "test", 0.0).await.unwrap();

        assert_eq!(result.text_or_empty(), "history ok");
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn chat_stream_forwards_to_wrapped_provider_streaming_impl() {
        let chat_calls = Arc::new(AtomicUsize::new(0));
        let stream_calls = Arc::new(AtomicUsize::new(0));
        let provider = ReliableProvider::new(
            vec![(
                "xai".into(),
                Box::new(StreamingMockProvider {
                    chat_calls: Arc::clone(&chat_calls),
                    stream_calls: Arc::clone(&stream_calls),
                }) as Box<dyn ModelProvider>,
            )],
            0,
            1,
        );

        let messages = vec![ChatMessage::user("hello")];
        let request = ChatRequest {
            messages: &messages,
            tools: None,
            native_tools: None,
        };
        let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();

        let result = provider
            .chat_stream(request, "grok-4.3", 0.0, tx)
            .await
            .unwrap();

        assert_eq!(result.text_or_empty(), "streaming");
        assert_eq!(chat_calls.load(Ordering::SeqCst), 0);
        assert_eq!(stream_calls.load(Ordering::SeqCst), 1);

        let event = rx.recv().await.expect("provider stream event");
        match event {
            ProviderStreamEvent::ProviderToolStarted(trace) => {
                assert_eq!(trace.name, "web_search");
                assert_eq!(trace.provider, "xai");
            }
            other => panic!("unexpected provider stream event: {other:?}"),
        }
    }

    #[tokio::test]
    async fn chat_falls_back() {
        let primary_calls = Arc::new(AtomicUsize::new(0));
        let fallback_calls = Arc::new(AtomicUsize::new(0));
        let provider = ReliableProvider::new(
            vec![
                MockProvider::entry(
                    "primary",
                    &primary_calls,
                    usize::MAX,
                    "never",
                    "primary down",
                ),
                MockProvider::entry(
                    "fallback",
                    &fallback_calls,
                    0,
                    "fallback ok",
                    "fallback err",
                ),
            ],
            1,
            1,
        );

        let messages = vec![ChatMessage::user("hello")];
        let request = ChatRequest {
            messages: &messages,
            tools: None,
            native_tools: None,
        };
        let result = provider.chat(request, "test", 0.0).await.unwrap();

        assert_eq!(result.text_or_empty(), "fallback ok");
        assert_eq!(primary_calls.load(Ordering::SeqCst), 2);
        assert_eq!(fallback_calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn model_failover_tries_fallback_model() {
        let calls = Arc::new(AtomicUsize::new(0));
        let mock = Arc::new(ModelAwareMock {
            calls: Arc::clone(&calls),
            models_seen: std::sync::Mutex::new(Vec::new()),
            fail_models: vec!["claude-opus"],
            response: "ok from sonnet",
        });

        let fallbacks = HashMap::from([(
            "claude-opus".to_string(),
            vec!["claude-sonnet".to_string()],
        )]);

        let provider = ReliableProvider::new(
            vec![(
                "anthropic".into(),
                Box::new(Arc::clone(&mock)) as Box<dyn ModelProvider>,
            )],
            0,
            1,
        )
        .with_model_fallbacks(fallbacks);

        let result = one_shot(&provider, None, "hello", "claude-opus", 0.0)
            .await
            .unwrap();
        assert_eq!(result, "ok from sonnet");

        let seen = mock.models_seen.lock().unwrap();
        assert_eq!(*seen, ["claude-opus", "claude-sonnet"]);
    }

    #[tokio::test]
    async fn model_failover_all_models_fail() {
        let calls = Arc::new(AtomicUsize::new(0));
        let mock = Arc::new(ModelAwareMock {
            calls: Arc::clone(&calls),
            models_seen: std::sync::Mutex::new(Vec::new()),
            fail_models: vec!["model-a", "model-b", "model-c"],
            response: "never",
        });

        let fallbacks = HashMap::from([(
            "model-a".to_string(),
            vec!["model-b".to_string(), "model-c".to_string()],
        )]);

        let provider = ReliableProvider::new(
            vec![(
                "p1".into(),
                Box::new(Arc::clone(&mock)) as Box<dyn ModelProvider>,
            )],
            0,
            1,
        )
        .with_model_fallbacks(fallbacks);

        let err = one_shot(&provider, None, "hello", "model-a", 0.0)
            .await
            .expect_err("all models should fail");
        assert!(err.to_string().contains("All providers/models failed"));

        let seen = mock.models_seen.lock().unwrap();
        assert_eq!(seen.len(), 3);
    }

    #[tokio::test]
    async fn no_model_fallbacks_behaves_like_before() {
        let calls = Arc::new(AtomicUsize::new(0));
        let provider = ReliableProvider::new(
            vec![MockProvider::entry("primary", &calls, 0, "ok", "boom")],
            2,
            1,
        );

        let result = one_shot(&provider, None, "hello", "test", 0.0)
            .await
            .unwrap();

        assert_eq!(result, "ok");
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn auth_rotation_cycles_keys() {
        let provider = ReliableProvider::new(
            vec![MockProvider::entry(
                "p",
                &Arc::new(AtomicUsize::new(0)),
                0,
                "ok",
                "",
            )],
            0,
            1,
        )
        .with_api_keys(vec!["key-a".into(), "key-b".into(), "key-c".into()]);

        let keys: Vec<&str> = (0..5).map(|_| provider.rotate_key().unwrap()).collect();

        // Five draws over three keys wrap around to the start.
        assert_eq!(keys, vec!["key-a", "key-b", "key-c", "key-a", "key-b"]);
    }

    #[tokio::test]
    async fn auth_rotation_returns_none_when_empty() {
        let provider = ReliableProvider::new(vec![], 0, 1);
        assert!(provider.rotate_key().is_none());
    }

    #[test]
    fn parse_retry_after_integer() {
        let err = anyhow::anyhow!("429 Too Many Requests, Retry-After: 5");
        assert_eq!(parse_retry_after_ms(&err), Some(5000));
    }

    #[test]
    fn parse_retry_after_float() {
        let err = anyhow::anyhow!("Rate limited. retry_after: 2.5 seconds");
        assert_eq!(parse_retry_after_ms(&err), Some(2500));
    }

    #[test]
    fn parse_retry_after_missing() {
        let err = anyhow::anyhow!("500 Internal Server Error");
        assert_eq!(parse_retry_after_ms(&err), None);
    }

    #[test]
    fn rate_limited_detection() {
        for msg in ["429 Too Many Requests", "HTTP 429 rate limit exceeded"] {
            assert!(is_rate_limited(&anyhow::anyhow!(msg)));
        }
        for msg in ["401 Unauthorized", "500 Internal Server Error"] {
            assert!(!is_rate_limited(&anyhow::anyhow!(msg)));
        }
    }

    #[test]
    fn compute_backoff_uses_retry_after() {
        let provider = ReliableProvider::new(vec![], 0, 500);
        let err = anyhow::anyhow!("429 Retry-After: 3");
        assert_eq!(provider.compute_backoff(500, &err), 3000);
    }

    #[test]
    fn compute_backoff_caps_at_30s() {
        let provider = ReliableProvider::new(vec![], 0, 500);
        let err = anyhow::anyhow!("429 Retry-After: 120");
        assert_eq!(provider.compute_backoff(500, &err), 30_000);
    }

    #[test]
    fn compute_backoff_falls_back_to_base() {
        let provider = ReliableProvider::new(vec![], 0, 500);
        let err = anyhow::anyhow!("500 Server Error");
        assert_eq!(provider.compute_backoff(500, &err), 500);
    }
}