Skip to main content

nenjo_models/
reliable.rs

1//! Reliability wrapper providing exponential-backoff retries, provider fallback,
2//! API key rotation, and model failover.
3
4use crate::ModelProvider;
5use crate::native::{
6    NativeMediaJob, NativeMediaRequest, NativeMediaResponse, ProviderNativeCapabilities,
7};
8use async_trait::async_trait;
9use std::collections::HashMap;
10use std::sync::atomic::{AtomicUsize, Ordering};
11use std::time::Duration;
12
13/// Check if an error is non-retryable (client errors that won't resolve with retries).
14fn is_non_retryable(err: &anyhow::Error) -> bool {
15    if let Some(reqwest_err) = err.downcast_ref::<reqwest::Error>()
16        && let Some(status) = reqwest_err.status()
17    {
18        let code = status.as_u16();
19        return status.is_client_error() && code != 429 && code != 408;
20    }
21    let msg = err.to_string();
22    for word in msg.split(|c: char| !c.is_ascii_digit()) {
23        if let Ok(code) = word.parse::<u16>()
24            && (400..500).contains(&code)
25        {
26            return code != 429 && code != 408;
27        }
28    }
29    false
30}
31
32/// Check if an error is a rate-limit (429) error.
33fn is_rate_limited(err: &anyhow::Error) -> bool {
34    if let Some(reqwest_err) = err.downcast_ref::<reqwest::Error>()
35        && let Some(status) = reqwest_err.status()
36    {
37        return status.as_u16() == 429;
38    }
39    let msg = err.to_string();
40    msg.contains("429")
41        && (msg.contains("Too Many") || msg.contains("rate") || msg.contains("limit"))
42}
43
44/// Try to extract a Retry-After value (in milliseconds) from an error message.
45/// Looks for patterns like `Retry-After: 5` or `retry_after: 2.5` in the error string.
46fn parse_retry_after_ms(err: &anyhow::Error) -> Option<u64> {
47    let msg = err.to_string();
48    let lower = msg.to_lowercase();
49
50    // Look for "retry-after: <number>" or "retry_after: <number>"
51    for prefix in &[
52        "retry-after:",
53        "retry_after:",
54        "retry-after ",
55        "retry_after ",
56    ] {
57        if let Some(pos) = lower.find(prefix) {
58            let after = &msg[pos + prefix.len()..];
59            let num_str: String = after
60                .trim()
61                .chars()
62                .take_while(|c| c.is_ascii_digit() || *c == '.')
63                .collect();
64            if let Ok(secs) = num_str.parse::<f64>()
65                && secs.is_finite()
66                && secs >= 0.0
67            {
68                let millis = Duration::from_secs_f64(secs).as_millis();
69                if let Ok(value) = u64::try_from(millis) {
70                    return Some(value);
71                }
72            }
73        }
74    }
75    None
76}
77
78/// Provider wrapper with retry, fallback, auth rotation, and model failover.
79pub struct ReliableProvider {
80    providers: Vec<(String, Box<dyn ModelProvider>)>,
81    max_retries: u32,
82    base_backoff_ms: u64,
83    /// Extra API keys for rotation (index tracks round-robin position).
84    api_keys: Vec<String>,
85    key_index: AtomicUsize,
86    /// Per-model fallback chains: model_name → [fallback_model_1, fallback_model_2, ...]
87    model_fallbacks: HashMap<String, Vec<String>>,
88}
89
90impl ReliableProvider {
91    pub fn new(
92        providers: Vec<(String, Box<dyn ModelProvider>)>,
93        max_retries: u32,
94        base_backoff_ms: u64,
95    ) -> Self {
96        Self {
97            providers,
98            max_retries,
99            base_backoff_ms: base_backoff_ms.max(50),
100            api_keys: Vec::new(),
101            key_index: AtomicUsize::new(0),
102            model_fallbacks: HashMap::new(),
103        }
104    }
105
106    /// Set additional API keys for round-robin rotation on rate-limit errors.
107    pub fn with_api_keys(mut self, keys: Vec<String>) -> Self {
108        self.api_keys = keys;
109        self
110    }
111
112    /// Set per-model fallback chains.
113    pub fn with_model_fallbacks(mut self, fallbacks: HashMap<String, Vec<String>>) -> Self {
114        self.model_fallbacks = fallbacks;
115        self
116    }
117
118    /// Build the list of models to try: [original, fallback1, fallback2, ...]
119    fn model_chain<'a>(&'a self, model: &'a str) -> Vec<&'a str> {
120        let mut chain = vec![model];
121        if let Some(fallbacks) = self.model_fallbacks.get(model) {
122            chain.extend(fallbacks.iter().map(|s| s.as_str()));
123        }
124        chain
125    }
126
127    /// Advance to the next API key and return it, or None if no extra keys configured.
128    fn rotate_key(&self) -> Option<&str> {
129        if self.api_keys.is_empty() {
130            return None;
131        }
132        let idx = self.key_index.fetch_add(1, Ordering::Relaxed) % self.api_keys.len();
133        Some(&self.api_keys[idx])
134    }
135
136    /// Compute backoff duration, respecting Retry-After if present.
137    fn compute_backoff(&self, base: u64, err: &anyhow::Error) -> u64 {
138        if let Some(retry_after) = parse_retry_after_ms(err) {
139            // Use Retry-After but cap at 30s to avoid indefinite waits
140            retry_after.min(30_000).max(base)
141        } else {
142            base
143        }
144    }
145}
146
147#[async_trait]
148impl ModelProvider for ReliableProvider {
149    async fn warmup(&self) -> anyhow::Result<()> {
150        for (name, provider) in &self.providers {
151            tracing::info!(provider = name, "Warming up provider connection pool");
152            if let Err(e) = provider.warmup().await {
153                tracing::warn!(provider = name, "Warmup failed (non-fatal): {e}");
154            }
155        }
156        Ok(())
157    }
158
159    async fn chat(
160        &self,
161        request: super::ChatRequest<'_>,
162        model: &str,
163        temperature: f64,
164    ) -> anyhow::Result<super::ChatResponse> {
165        let models = self.model_chain(model);
166        let mut failures = Vec::new();
167
168        for current_model in &models {
169            for (provider_name, provider) in &self.providers {
170                let mut backoff_ms = self.base_backoff_ms;
171
172                for attempt in 0..=self.max_retries {
173                    match provider.chat(request, current_model, temperature).await {
174                        Ok(resp) => {
175                            if attempt > 0 || *current_model != model {
176                                tracing::info!(
177                                    provider = provider_name,
178                                    model = *current_model,
179                                    attempt,
180                                    original_model = model,
181                                    "Provider recovered (failover/retry)"
182                                );
183                            }
184                            return Ok(resp);
185                        }
186                        Err(e) => {
187                            let non_retryable = is_non_retryable(&e);
188                            let rate_limited = is_rate_limited(&e);
189
190                            failures.push(format!(
191                                "{provider_name}/{current_model} attempt {}/{}: {e}",
192                                attempt + 1,
193                                self.max_retries + 1
194                            ));
195
196                            if rate_limited && let Some(new_key) = self.rotate_key() {
197                                tracing::info!(
198                                    provider = provider_name,
199                                    "Rate limited, rotated API key (key ending ...{})",
200                                    &new_key[new_key.len().saturating_sub(4)..]
201                                );
202                            }
203
204                            if non_retryable {
205                                tracing::warn!(
206                                    provider = provider_name,
207                                    model = *current_model,
208                                    "Non-retryable error, moving on"
209                                );
210                                break;
211                            }
212
213                            if attempt < self.max_retries {
214                                let wait = self.compute_backoff(backoff_ms, &e);
215                                tracing::warn!(
216                                    provider = provider_name,
217                                    model = *current_model,
218                                    attempt = attempt + 1,
219                                    backoff_ms = wait,
220                                    "Provider call failed, retrying"
221                                );
222                                tokio::time::sleep(Duration::from_millis(wait)).await;
223                                backoff_ms = (backoff_ms.saturating_mul(2)).min(10_000);
224                            }
225                        }
226                    }
227                }
228
229                tracing::warn!(
230                    provider = provider_name,
231                    model = *current_model,
232                    "Exhausted retries, trying next provider/model"
233                );
234            }
235        }
236
237        anyhow::bail!(
238            "All providers/models failed. Attempts:\n{}",
239            failures.join("\n")
240        )
241    }
242
243    async fn chat_stream(
244        &self,
245        request: super::ChatRequest<'_>,
246        model: &str,
247        temperature: f64,
248        events: tokio::sync::mpsc::UnboundedSender<super::ProviderStreamEvent>,
249    ) -> anyhow::Result<super::ChatResponse> {
250        let models = self.model_chain(model);
251        let mut failures = Vec::new();
252
253        for current_model in &models {
254            for (provider_name, provider) in &self.providers {
255                let mut backoff_ms = self.base_backoff_ms;
256
257                for attempt in 0..=self.max_retries {
258                    match provider
259                        .chat_stream(request, current_model, temperature, events.clone())
260                        .await
261                    {
262                        Ok(resp) => {
263                            if attempt > 0 || *current_model != model {
264                                tracing::info!(
265                                    provider = provider_name,
266                                    model = *current_model,
267                                    attempt,
268                                    original_model = model,
269                                    "Provider streaming call recovered (failover/retry)"
270                                );
271                            }
272                            return Ok(resp);
273                        }
274                        Err(e) => {
275                            let non_retryable = is_non_retryable(&e);
276                            let rate_limited = is_rate_limited(&e);
277
278                            failures.push(format!(
279                                "{provider_name}/{current_model} streaming attempt {}/{}: {e}",
280                                attempt + 1,
281                                self.max_retries + 1
282                            ));
283
284                            if rate_limited && let Some(new_key) = self.rotate_key() {
285                                tracing::info!(
286                                    provider = provider_name,
287                                    "Rate limited, rotated API key (key ending ...{})",
288                                    &new_key[new_key.len().saturating_sub(4)..]
289                                );
290                            }
291
292                            if non_retryable {
293                                tracing::warn!(
294                                    provider = provider_name,
295                                    model = *current_model,
296                                    "Non-retryable streaming error, moving on"
297                                );
298                                break;
299                            }
300
301                            if attempt < self.max_retries {
302                                let wait = self.compute_backoff(backoff_ms, &e);
303                                tracing::warn!(
304                                    provider = provider_name,
305                                    model = *current_model,
306                                    attempt = attempt + 1,
307                                    backoff_ms = wait,
308                                    "Provider streaming call failed, retrying"
309                                );
310                                tokio::time::sleep(Duration::from_millis(wait)).await;
311                                backoff_ms = (backoff_ms.saturating_mul(2)).min(10_000);
312                            }
313                        }
314                    }
315                }
316
317                tracing::warn!(
318                    provider = provider_name,
319                    model = *current_model,
320                    "Exhausted streaming retries, trying next provider/model"
321                );
322            }
323        }
324
325        anyhow::bail!(
326            "All providers/models failed. Attempts:\n{}",
327            failures.join("\n")
328        )
329    }
330
331    fn context_window(&self, model: &str) -> Option<usize> {
332        self.providers
333            .first()
334            .and_then(|(_, p)| p.context_window(model))
335    }
336
337    fn supports_native_tools(&self) -> bool {
338        self.providers
339            .first()
340            .map(|(_, p)| p.supports_native_tools())
341            .unwrap_or(false)
342    }
343
344    fn supports_developer_role(&self, model: &str) -> bool {
345        self.providers
346            .first()
347            .map(|(_, p)| p.supports_developer_role(model))
348            .unwrap_or(false)
349    }
350
351    fn native_capabilities(&self) -> Option<ProviderNativeCapabilities> {
352        self.providers
353            .first()
354            .and_then(|(_, p)| p.native_capabilities())
355    }
356
357    async fn submit_media(
358        &self,
359        request: NativeMediaRequest,
360    ) -> anyhow::Result<NativeMediaResponse> {
361        let Some((provider_name, provider)) = self.providers.first() else {
362            anyhow::bail!("no provider configured for native media operation");
363        };
364
365        provider
366            .submit_media(request)
367            .await
368            .map_err(|err| anyhow::anyhow!("{provider_name} native media operation failed: {err}"))
369    }
370
371    async fn poll_media_job(&self, job: &NativeMediaJob) -> anyhow::Result<NativeMediaResponse> {
372        if let Some((_, provider)) = self
373            .providers
374            .iter()
375            .find(|(provider_name, _)| provider_name == &job.provider)
376        {
377            return provider.poll_media_job(job).await;
378        }
379
380        let Some((provider_name, provider)) = self.providers.first() else {
381            anyhow::bail!("no provider configured for native media job polling");
382        };
383
384        provider
385            .poll_media_job(job)
386            .await
387            .map_err(|err| anyhow::anyhow!("{provider_name} native media job poll failed: {err}"))
388    }
389}
390
391#[cfg(test)]
392mod tests {
393    use super::*;
394    use crate::traits::{ChatMessage, ChatRequest, ChatResponse, TokenUsage, one_shot};
395    use crate::{ProviderStreamEvent, ProviderToolTrace};
396    use std::sync::Arc;
397
398    struct MockProvider {
399        calls: Arc<AtomicUsize>,
400        fail_until_attempt: usize,
401        response: &'static str,
402        error: &'static str,
403    }
404
405    #[async_trait]
406    impl ModelProvider for MockProvider {
407        async fn chat(
408            &self,
409            _request: ChatRequest<'_>,
410            _model: &str,
411            _temperature: f64,
412        ) -> anyhow::Result<ChatResponse> {
413            let attempt = self.calls.fetch_add(1, Ordering::SeqCst) + 1;
414            if attempt <= self.fail_until_attempt {
415                anyhow::bail!(self.error);
416            }
417            Ok(ChatResponse {
418                text: Some(self.response.to_string()),
419                tool_calls: vec![],
420                provider_tool_calls: vec![],
421                usage: TokenUsage::default(),
422            })
423        }
424    }
425
426    struct StreamingMockProvider {
427        chat_calls: Arc<AtomicUsize>,
428        stream_calls: Arc<AtomicUsize>,
429    }
430
431    #[async_trait]
432    impl ModelProvider for StreamingMockProvider {
433        async fn chat(
434            &self,
435            _request: ChatRequest<'_>,
436            _model: &str,
437            _temperature: f64,
438        ) -> anyhow::Result<ChatResponse> {
439            self.chat_calls.fetch_add(1, Ordering::SeqCst);
440            Ok(ChatResponse {
441                text: Some("non-streaming".to_string()),
442                tool_calls: vec![],
443                provider_tool_calls: vec![],
444                usage: TokenUsage::default(),
445            })
446        }
447
448        async fn chat_stream(
449            &self,
450            _request: ChatRequest<'_>,
451            _model: &str,
452            _temperature: f64,
453            events: tokio::sync::mpsc::UnboundedSender<ProviderStreamEvent>,
454        ) -> anyhow::Result<ChatResponse> {
455            self.stream_calls.fetch_add(1, Ordering::SeqCst);
456            events
457                .send(ProviderStreamEvent::ProviderToolStarted(
458                    ProviderToolTrace {
459                        id: "provider-tool-1".to_string(),
460                        name: "web_search".to_string(),
461                        provider: "xai".to_string(),
462                        input: serde_json::json!({"query": "test"}),
463                        output: None,
464                        citations: vec![],
465                    },
466                ))
467                .ok();
468            Ok(ChatResponse {
469                text: Some("streaming".to_string()),
470                tool_calls: vec![],
471                provider_tool_calls: vec![],
472                usage: TokenUsage::default(),
473            })
474        }
475    }
476
477    /// Mock that records which model was used for each call.
478    struct ModelAwareMock {
479        calls: Arc<AtomicUsize>,
480        models_seen: std::sync::Mutex<Vec<String>>,
481        fail_models: Vec<&'static str>,
482        response: &'static str,
483    }
484
485    #[async_trait]
486    impl ModelProvider for ModelAwareMock {
487        async fn chat(
488            &self,
489            _request: ChatRequest<'_>,
490            model: &str,
491            _temperature: f64,
492        ) -> anyhow::Result<ChatResponse> {
493            self.calls.fetch_add(1, Ordering::SeqCst);
494            self.models_seen.lock().unwrap().push(model.to_string());
495            if self.fail_models.contains(&model) {
496                anyhow::bail!("500 model {} unavailable", model);
497            }
498            Ok(ChatResponse {
499                text: Some(self.response.to_string()),
500                tool_calls: vec![],
501                provider_tool_calls: vec![],
502                usage: TokenUsage::default(),
503            })
504        }
505    }
506
507    // ── Existing tests (preserved) ──
508
509    #[tokio::test]
510    async fn succeeds_without_retry() {
511        let calls = Arc::new(AtomicUsize::new(0));
512        let provider = ReliableProvider::new(
513            vec![(
514                "primary".into(),
515                Box::new(MockProvider {
516                    calls: Arc::clone(&calls),
517                    fail_until_attempt: 0,
518                    response: "ok",
519                    error: "boom",
520                }),
521            )],
522            2,
523            1,
524        );
525
526        let result = one_shot(&provider, None, "hello", "test", 0.0)
527            .await
528            .unwrap();
529        assert_eq!(result, "ok");
530        assert_eq!(calls.load(Ordering::SeqCst), 1);
531    }
532
533    #[tokio::test]
534    async fn retries_then_recovers() {
535        let calls = Arc::new(AtomicUsize::new(0));
536        let provider = ReliableProvider::new(
537            vec![(
538                "primary".into(),
539                Box::new(MockProvider {
540                    calls: Arc::clone(&calls),
541                    fail_until_attempt: 1,
542                    response: "recovered",
543                    error: "temporary",
544                }),
545            )],
546            2,
547            1,
548        );
549
550        let result = one_shot(&provider, None, "hello", "test", 0.0)
551            .await
552            .unwrap();
553        assert_eq!(result, "recovered");
554        assert_eq!(calls.load(Ordering::SeqCst), 2);
555    }
556
557    #[tokio::test]
558    async fn falls_back_after_retries_exhausted() {
559        let primary_calls = Arc::new(AtomicUsize::new(0));
560        let fallback_calls = Arc::new(AtomicUsize::new(0));
561
562        let provider = ReliableProvider::new(
563            vec![
564                (
565                    "primary".into(),
566                    Box::new(MockProvider {
567                        calls: Arc::clone(&primary_calls),
568                        fail_until_attempt: usize::MAX,
569                        response: "never",
570                        error: "primary down",
571                    }),
572                ),
573                (
574                    "fallback".into(),
575                    Box::new(MockProvider {
576                        calls: Arc::clone(&fallback_calls),
577                        fail_until_attempt: 0,
578                        response: "from fallback",
579                        error: "fallback down",
580                    }),
581                ),
582            ],
583            1,
584            1,
585        );
586
587        let result = one_shot(&provider, None, "hello", "test", 0.0)
588            .await
589            .unwrap();
590        assert_eq!(result, "from fallback");
591        assert_eq!(primary_calls.load(Ordering::SeqCst), 2);
592        assert_eq!(fallback_calls.load(Ordering::SeqCst), 1);
593    }
594
595    #[tokio::test]
596    async fn returns_aggregated_error_when_all_providers_fail() {
597        let provider = ReliableProvider::new(
598            vec![
599                (
600                    "p1".into(),
601                    Box::new(MockProvider {
602                        calls: Arc::new(AtomicUsize::new(0)),
603                        fail_until_attempt: usize::MAX,
604                        response: "never",
605                        error: "p1 error",
606                    }),
607                ),
608                (
609                    "p2".into(),
610                    Box::new(MockProvider {
611                        calls: Arc::new(AtomicUsize::new(0)),
612                        fail_until_attempt: usize::MAX,
613                        response: "never",
614                        error: "p2 error",
615                    }),
616                ),
617            ],
618            0,
619            1,
620        );
621
622        let err = one_shot(&provider, None, "hello", "test", 0.0)
623            .await
624            .expect_err("all providers should fail");
625        let msg = err.to_string();
626        assert!(msg.contains("All providers/models failed"));
627        assert!(msg.contains("p1"));
628        assert!(msg.contains("p2"));
629    }
630
631    #[test]
632    fn non_retryable_detects_common_patterns() {
633        assert!(is_non_retryable(&anyhow::anyhow!("400 Bad Request")));
634        assert!(is_non_retryable(&anyhow::anyhow!("401 Unauthorized")));
635        assert!(is_non_retryable(&anyhow::anyhow!("403 Forbidden")));
636        assert!(is_non_retryable(&anyhow::anyhow!("404 Not Found")));
637        assert!(!is_non_retryable(&anyhow::anyhow!("429 Too Many Requests")));
638        assert!(!is_non_retryable(&anyhow::anyhow!("408 Request Timeout")));
639        assert!(!is_non_retryable(&anyhow::anyhow!(
640            "500 Internal Server Error"
641        )));
642        assert!(!is_non_retryable(&anyhow::anyhow!("502 Bad Gateway")));
643        assert!(!is_non_retryable(&anyhow::anyhow!("timeout")));
644        assert!(!is_non_retryable(&anyhow::anyhow!("connection reset")));
645    }
646
647    #[tokio::test]
648    async fn skips_retries_on_non_retryable_error() {
649        let primary_calls = Arc::new(AtomicUsize::new(0));
650        let fallback_calls = Arc::new(AtomicUsize::new(0));
651
652        let provider = ReliableProvider::new(
653            vec![
654                (
655                    "primary".into(),
656                    Box::new(MockProvider {
657                        calls: Arc::clone(&primary_calls),
658                        fail_until_attempt: usize::MAX,
659                        response: "never",
660                        error: "401 Unauthorized",
661                    }),
662                ),
663                (
664                    "fallback".into(),
665                    Box::new(MockProvider {
666                        calls: Arc::clone(&fallback_calls),
667                        fail_until_attempt: 0,
668                        response: "from fallback",
669                        error: "fallback err",
670                    }),
671                ),
672            ],
673            3,
674            1,
675        );
676
677        let result = one_shot(&provider, None, "hello", "test", 0.0)
678            .await
679            .unwrap();
680        assert_eq!(result, "from fallback");
681        // Primary should have been called only once (no retries)
682        assert_eq!(primary_calls.load(Ordering::SeqCst), 1);
683        assert_eq!(fallback_calls.load(Ordering::SeqCst), 1);
684    }
685
686    #[tokio::test]
687    async fn chat_retries_then_recovers() {
688        let calls = Arc::new(AtomicUsize::new(0));
689        let provider = ReliableProvider::new(
690            vec![(
691                "primary".into(),
692                Box::new(MockProvider {
693                    calls: Arc::clone(&calls),
694                    fail_until_attempt: 1,
695                    response: "history ok",
696                    error: "temporary",
697                }),
698            )],
699            2,
700            1,
701        );
702
703        let messages = vec![ChatMessage::system("system"), ChatMessage::user("hello")];
704        let request = ChatRequest {
705            messages: &messages,
706            tools: None,
707            native_tools: None,
708        };
709        let result = provider.chat(request, "test", 0.0).await.unwrap();
710        assert_eq!(result.text_or_empty(), "history ok");
711        assert_eq!(calls.load(Ordering::SeqCst), 2);
712    }
713
714    #[tokio::test]
715    async fn chat_stream_forwards_to_wrapped_provider_streaming_impl() {
716        let chat_calls = Arc::new(AtomicUsize::new(0));
717        let stream_calls = Arc::new(AtomicUsize::new(0));
718        let provider = ReliableProvider::new(
719            vec![(
720                "xai".into(),
721                Box::new(StreamingMockProvider {
722                    chat_calls: Arc::clone(&chat_calls),
723                    stream_calls: Arc::clone(&stream_calls),
724                }) as Box<dyn ModelProvider>,
725            )],
726            0,
727            1,
728        );
729
730        let messages = vec![ChatMessage::user("hello")];
731        let request = ChatRequest {
732            messages: &messages,
733            tools: None,
734            native_tools: None,
735        };
736        let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
737
738        let result = provider
739            .chat_stream(request, "grok-4.3", 0.0, tx)
740            .await
741            .unwrap();
742
743        assert_eq!(result.text_or_empty(), "streaming");
744        assert_eq!(chat_calls.load(Ordering::SeqCst), 0);
745        assert_eq!(stream_calls.load(Ordering::SeqCst), 1);
746
747        let event = rx.recv().await.expect("provider stream event");
748        match event {
749            ProviderStreamEvent::ProviderToolStarted(trace) => {
750                assert_eq!(trace.name, "web_search");
751                assert_eq!(trace.provider, "xai");
752            }
753            other => panic!("unexpected provider stream event: {other:?}"),
754        }
755    }
756
757    #[tokio::test]
758    async fn chat_falls_back() {
759        let primary_calls = Arc::new(AtomicUsize::new(0));
760        let fallback_calls = Arc::new(AtomicUsize::new(0));
761
762        let provider = ReliableProvider::new(
763            vec![
764                (
765                    "primary".into(),
766                    Box::new(MockProvider {
767                        calls: Arc::clone(&primary_calls),
768                        fail_until_attempt: usize::MAX,
769                        response: "never",
770                        error: "primary down",
771                    }),
772                ),
773                (
774                    "fallback".into(),
775                    Box::new(MockProvider {
776                        calls: Arc::clone(&fallback_calls),
777                        fail_until_attempt: 0,
778                        response: "fallback ok",
779                        error: "fallback err",
780                    }),
781                ),
782            ],
783            1,
784            1,
785        );
786
787        let messages = vec![ChatMessage::user("hello")];
788        let request = ChatRequest {
789            messages: &messages,
790            tools: None,
791            native_tools: None,
792        };
793        let result = provider.chat(request, "test", 0.0).await.unwrap();
794        assert_eq!(result.text_or_empty(), "fallback ok");
795        assert_eq!(primary_calls.load(Ordering::SeqCst), 2);
796        assert_eq!(fallback_calls.load(Ordering::SeqCst), 1);
797    }
798
799    // ── New tests: model failover ──
800
801    #[tokio::test]
802    async fn model_failover_tries_fallback_model() {
803        let calls = Arc::new(AtomicUsize::new(0));
804        let mock = Arc::new(ModelAwareMock {
805            calls: Arc::clone(&calls),
806            models_seen: std::sync::Mutex::new(Vec::new()),
807            fail_models: vec!["claude-opus"],
808            response: "ok from sonnet",
809        });
810
811        let mut fallbacks = HashMap::new();
812        fallbacks.insert("claude-opus".to_string(), vec!["claude-sonnet".to_string()]);
813
814        let provider = ReliableProvider::new(
815            vec![(
816                "anthropic".into(),
817                Box::new(mock.clone()) as Box<dyn ModelProvider>,
818            )],
819            0, // no retries — force immediate model failover
820            1,
821        )
822        .with_model_fallbacks(fallbacks);
823
824        let result = one_shot(&provider, None, "hello", "claude-opus", 0.0)
825            .await
826            .unwrap();
827        assert_eq!(result, "ok from sonnet");
828
829        let seen = mock.models_seen.lock().unwrap();
830        assert_eq!(seen.len(), 2);
831        assert_eq!(seen[0], "claude-opus");
832        assert_eq!(seen[1], "claude-sonnet");
833    }
834
835    #[tokio::test]
836    async fn model_failover_all_models_fail() {
837        let calls = Arc::new(AtomicUsize::new(0));
838        let mock = Arc::new(ModelAwareMock {
839            calls: Arc::clone(&calls),
840            models_seen: std::sync::Mutex::new(Vec::new()),
841            fail_models: vec!["model-a", "model-b", "model-c"],
842            response: "never",
843        });
844
845        let mut fallbacks = HashMap::new();
846        fallbacks.insert(
847            "model-a".to_string(),
848            vec!["model-b".to_string(), "model-c".to_string()],
849        );
850
851        let provider = ReliableProvider::new(
852            vec![(
853                "p1".into(),
854                Box::new(mock.clone()) as Box<dyn ModelProvider>,
855            )],
856            0,
857            1,
858        )
859        .with_model_fallbacks(fallbacks);
860
861        let err = one_shot(&provider, None, "hello", "model-a", 0.0)
862            .await
863            .expect_err("all models should fail");
864        assert!(err.to_string().contains("All providers/models failed"));
865
866        let seen = mock.models_seen.lock().unwrap();
867        assert_eq!(seen.len(), 3);
868    }
869
870    #[tokio::test]
871    async fn no_model_fallbacks_behaves_like_before() {
872        let calls = Arc::new(AtomicUsize::new(0));
873        let provider = ReliableProvider::new(
874            vec![(
875                "primary".into(),
876                Box::new(MockProvider {
877                    calls: Arc::clone(&calls),
878                    fail_until_attempt: 0,
879                    response: "ok",
880                    error: "boom",
881                }),
882            )],
883            2,
884            1,
885        );
886        // No model_fallbacks set — should work exactly as before
887        let result = one_shot(&provider, None, "hello", "test", 0.0)
888            .await
889            .unwrap();
890        assert_eq!(result, "ok");
891        assert_eq!(calls.load(Ordering::SeqCst), 1);
892    }
893
    // ── API key rotation ──
895
896    #[tokio::test]
897    async fn auth_rotation_cycles_keys() {
898        let provider = ReliableProvider::new(
899            vec![(
900                "p".into(),
901                Box::new(MockProvider {
902                    calls: Arc::new(AtomicUsize::new(0)),
903                    fail_until_attempt: 0,
904                    response: "ok",
905                    error: "",
906                }),
907            )],
908            0,
909            1,
910        )
911        .with_api_keys(vec!["key-a".into(), "key-b".into(), "key-c".into()]);
912
913        // Rotate 5 times, verify round-robin
914        let keys: Vec<&str> = (0..5).map(|_| provider.rotate_key().unwrap()).collect();
915        assert_eq!(keys, vec!["key-a", "key-b", "key-c", "key-a", "key-b"]);
916    }
917
918    #[tokio::test]
919    async fn auth_rotation_returns_none_when_empty() {
920        let provider = ReliableProvider::new(vec![], 0, 1);
921        assert!(provider.rotate_key().is_none());
922    }
923
    // ── Retry-After parsing, rate-limit detection, and backoff computation ──
925
926    #[test]
927    fn parse_retry_after_integer() {
928        let err = anyhow::anyhow!("429 Too Many Requests, Retry-After: 5");
929        assert_eq!(parse_retry_after_ms(&err), Some(5000));
930    }
931
932    #[test]
933    fn parse_retry_after_float() {
934        let err = anyhow::anyhow!("Rate limited. retry_after: 2.5 seconds");
935        assert_eq!(parse_retry_after_ms(&err), Some(2500));
936    }
937
938    #[test]
939    fn parse_retry_after_missing() {
940        let err = anyhow::anyhow!("500 Internal Server Error");
941        assert_eq!(parse_retry_after_ms(&err), None);
942    }
943
944    #[test]
945    fn rate_limited_detection() {
946        assert!(is_rate_limited(&anyhow::anyhow!("429 Too Many Requests")));
947        assert!(is_rate_limited(&anyhow::anyhow!(
948            "HTTP 429 rate limit exceeded"
949        )));
950        assert!(!is_rate_limited(&anyhow::anyhow!("401 Unauthorized")));
951        assert!(!is_rate_limited(&anyhow::anyhow!(
952            "500 Internal Server Error"
953        )));
954    }
955
956    #[test]
957    fn compute_backoff_uses_retry_after() {
958        let provider = ReliableProvider::new(vec![], 0, 500);
959        let err = anyhow::anyhow!("429 Retry-After: 3");
960        assert_eq!(provider.compute_backoff(500, &err), 3000);
961    }
962
963    #[test]
964    fn compute_backoff_caps_at_30s() {
965        let provider = ReliableProvider::new(vec![], 0, 500);
966        let err = anyhow::anyhow!("429 Retry-After: 120");
967        assert_eq!(provider.compute_backoff(500, &err), 30_000);
968    }
969
970    #[test]
971    fn compute_backoff_falls_back_to_base() {
972        let provider = ReliableProvider::new(vec![], 0, 500);
973        let err = anyhow::anyhow!("500 Server Error");
974        assert_eq!(provider.compute_backoff(500, &err), 500);
975    }
976
    // ── `ModelProvider` for `Arc<ModelAwareMock>` (lets tests share the mock with the provider) ──
978
979    #[async_trait]
980    impl ModelProvider for Arc<ModelAwareMock> {
981        async fn chat(
982            &self,
983            request: ChatRequest<'_>,
984            model: &str,
985            temperature: f64,
986        ) -> anyhow::Result<ChatResponse> {
987            self.as_ref().chat(request, model, temperature).await
988        }
989    }
990}