Skip to main content

construct/providers/
reliable.rs

1use super::Provider;
2use super::traits::{
3    ChatMessage, ChatRequest, ChatResponse, StreamChunk, StreamEvent, StreamOptions, StreamResult,
4};
5use async_trait::async_trait;
6use futures_util::{StreamExt, stream};
7use std::cell::RefCell;
8use std::collections::HashMap;
9use std::sync::atomic::{AtomicUsize, Ordering};
10use std::time::Duration;
11
12// ── Provider Fallback Notification ──────────────────────────────────────
13// When ReliableProvider uses a fallback (different provider or model than
14// requested), it records the details here so channel code can notify the user.
15// Uses tokio::task_local to avoid cross-request leakage between concurrent
16// users (the old global static had a race window).
17
/// Info about a provider fallback that occurred during a request.
///
/// Written by `record_provider_fallback` when a request only succeeds after a
/// retry, or on a non-primary provider, or on a fallback model; consumed
/// (at most once) via `take_last_provider_fallback` so channel code can tell
/// the user which provider/model actually answered.
#[derive(Debug, Clone)]
pub struct ProviderFallbackInfo {
    /// Provider that was originally requested.
    pub requested_provider: String,
    /// Model that was originally requested.
    pub requested_model: String,
    /// Provider that actually served the request.
    pub actual_provider: String,
    /// Model that actually served the request.
    pub actual_model: String,
}
30
tokio::task_local! {
    // One fallback slot per task tree. Because this is task-local rather than
    // a process-wide static, concurrent requests cannot observe each other's
    // fallback info. Reads/writes outside a `scope_provider_fallback` scope
    // fail with AccessError and are treated as no-ops by the helpers below.
    static PROVIDER_FALLBACK: RefCell<Option<ProviderFallbackInfo>>;
}
34
35/// Take (consume) the last provider fallback info, if any.
36/// Must be called within a `scope_provider_fallback` scope.
37pub fn take_last_provider_fallback() -> Option<ProviderFallbackInfo> {
38    PROVIDER_FALLBACK
39        .try_with(|cell| cell.borrow_mut().take())
40        .ok()
41        .flatten()
42}
43
44/// Run the given future within a provider-fallback scope.
45/// Both `record_provider_fallback` (inside ReliableProvider) and
46/// `take_last_provider_fallback` (post-loop channel code) must execute
47/// within this scope for the data to be visible.
48pub async fn scope_provider_fallback<F: std::future::Future>(future: F) -> F::Output {
49    PROVIDER_FALLBACK.scope(RefCell::new(None), future).await
50}
51
52/// Record a provider fallback event.
53fn record_provider_fallback(
54    requested_provider: &str,
55    requested_model: &str,
56    actual_provider: &str,
57    actual_model: &str,
58) {
59    let _ = PROVIDER_FALLBACK.try_with(|cell| {
60        *cell.borrow_mut() = Some(ProviderFallbackInfo {
61            requested_provider: requested_provider.to_string(),
62            requested_model: requested_model.to_string(),
63            actual_provider: actual_provider.to_string(),
64            actual_model: actual_model.to_string(),
65        });
66    });
67}
68
69// ── Error Classification ─────────────────────────────────────────────────
70// Errors are split into retryable (transient server/network failures) and
71// non-retryable (permanent client errors). This distinction drives whether
72// the retry loop continues, falls back to the next provider, or aborts
73// immediately — avoiding wasted latency on errors that cannot self-heal.
74
75/// Check if an error is non-retryable (client errors that won't resolve with retries).
76pub fn is_non_retryable(err: &anyhow::Error) -> bool {
77    // Context window errors are NOT non-retryable — they can be recovered
78    // by truncating conversation history, so let the retry loop handle them.
79    if is_context_window_exceeded(err) {
80        return false;
81    }
82
83    // Tool schema validation errors are NOT non-retryable — the provider's
84    // built-in fallback in compatible.rs can recover by switching to
85    // prompt-guided tool instructions.
86    if is_tool_schema_error(err) {
87        return false;
88    }
89
90    // 4xx errors are generally non-retryable (bad request, auth failure, etc.),
91    // except 429 (rate-limit — transient) and 408 (timeout — worth retrying).
92    if let Some(reqwest_err) = err.downcast_ref::<reqwest::Error>() {
93        if let Some(status) = reqwest_err.status() {
94            let code = status.as_u16();
95            return status.is_client_error() && code != 429 && code != 408;
96        }
97    }
98    // Fallback: parse status codes from stringified errors (some providers
99    // embed codes in error messages rather than returning typed HTTP errors).
100    let msg = err.to_string();
101    for word in msg.split(|c: char| !c.is_ascii_digit()) {
102        if let Ok(code) = word.parse::<u16>() {
103            if (400..500).contains(&code) {
104                return code != 429 && code != 408;
105            }
106        }
107    }
108
109    // Heuristic: detect auth/model failures by keyword when no HTTP status
110    // is available (e.g. gRPC or custom transport errors).
111    let msg_lower = msg.to_lowercase();
112    let auth_failure_hints = [
113        "invalid api key",
114        "incorrect api key",
115        "missing api key",
116        "api key not set",
117        "authentication failed",
118        "auth failed",
119        "unauthorized",
120        "forbidden",
121        "permission denied",
122        "access denied",
123        "invalid token",
124    ];
125
126    if auth_failure_hints
127        .iter()
128        .any(|hint| msg_lower.contains(hint))
129    {
130        return true;
131    }
132
133    msg_lower.contains("model")
134        && (msg_lower.contains("not found")
135            || msg_lower.contains("unknown")
136            || msg_lower.contains("unsupported")
137            || msg_lower.contains("does not exist")
138            || msg_lower.contains("invalid"))
139}
140
141/// Check if an error is a tool schema validation failure (e.g. Groq returning
142/// "tool call validation failed: attempted to call tool '...' which was not in request").
143/// These errors should NOT be classified as non-retryable because the provider's
144/// built-in fallback logic (`compatible.rs::is_native_tool_schema_unsupported`)
145/// can recover by switching to prompt-guided tool instructions.
146pub fn is_tool_schema_error(err: &anyhow::Error) -> bool {
147    let lower = err.to_string().to_lowercase();
148    let hints = [
149        "tool call validation failed",
150        "was not in request",
151        "not found in tool list",
152        "invalid_tool_call",
153    ];
154    hints.iter().any(|hint| lower.contains(hint))
155}
156
157pub(crate) fn is_context_window_exceeded(err: &anyhow::Error) -> bool {
158    let lower = err.to_string().to_lowercase();
159    let hints = [
160        "exceeds the context window",
161        "exceeds the available context size",
162        "context window of this model",
163        "maximum context length",
164        "context length exceeded",
165        "too many tokens",
166        "token limit exceeded",
167        "prompt is too long",
168        "input is too long",
169        "prompt exceeds max length",
170    ];
171
172    hints.iter().any(|hint| lower.contains(hint))
173}
174
175/// Check if an error is a rate-limit (429) error.
176fn is_rate_limited(err: &anyhow::Error) -> bool {
177    if let Some(reqwest_err) = err.downcast_ref::<reqwest::Error>() {
178        if let Some(status) = reqwest_err.status() {
179            return status.as_u16() == 429;
180        }
181    }
182    let msg = err.to_string();
183    msg.contains("429")
184        && (msg.contains("Too Many") || msg.contains("rate") || msg.contains("limit"))
185}
186
187/// Check if a 429 is a business/quota-plan error that retries cannot fix.
188///
189/// Examples:
190/// - plan does not include requested model
191/// - insufficient balance / package not active
192/// - known provider business codes (e.g. Z.AI: 1311, 1113)
193fn is_non_retryable_rate_limit(err: &anyhow::Error) -> bool {
194    if !is_rate_limited(err) {
195        return false;
196    }
197
198    let msg = err.to_string();
199    let lower = msg.to_lowercase();
200
201    let business_hints = [
202        "plan does not include",
203        "doesn't include",
204        "not include",
205        "insufficient balance",
206        "insufficient_balance",
207        "insufficient quota",
208        "insufficient_quota",
209        "quota exhausted",
210        "out of credits",
211        "no available package",
212        "package not active",
213        "purchase package",
214        "model not available for your plan",
215    ];
216
217    if business_hints.iter().any(|hint| lower.contains(hint)) {
218        return true;
219    }
220
221    // Known provider business codes observed for 429 where retry is futile.
222    for token in lower.split(|c: char| !c.is_ascii_digit()) {
223        if let Ok(code) = token.parse::<u16>() {
224            if matches!(code, 1113 | 1311) {
225                return true;
226            }
227        }
228    }
229
230    false
231}
232
233/// Try to extract a Retry-After value (in milliseconds) from an error message.
234/// Looks for patterns like `Retry-After: 5` or `retry_after: 2.5` in the error string.
235fn parse_retry_after_ms(err: &anyhow::Error) -> Option<u64> {
236    let msg = err.to_string();
237    let lower = msg.to_lowercase();
238
239    // Look for "retry-after: <number>" or "retry_after: <number>"
240    for prefix in &[
241        "retry-after:",
242        "retry_after:",
243        "retry-after ",
244        "retry_after ",
245    ] {
246        if let Some(pos) = lower.find(prefix) {
247            let after = &msg[pos + prefix.len()..];
248            let num_str: String = after
249                .trim()
250                .chars()
251                .take_while(|c| c.is_ascii_digit() || *c == '.')
252                .collect();
253            if let Ok(secs) = num_str.parse::<f64>() {
254                if secs.is_finite() && secs >= 0.0 {
255                    let millis = Duration::from_secs_f64(secs).as_millis();
256                    if let Ok(value) = u64::try_from(millis) {
257                        return Some(value);
258                    }
259                }
260            }
261        }
262    }
263    None
264}
265
/// Map the (rate_limited, non_retryable) classification flags to the stable
/// reason label used in per-attempt failure records and retry logs.
fn failure_reason(rate_limited: bool, non_retryable: bool) -> &'static str {
    match (rate_limited, non_retryable) {
        (true, true) => "rate_limited_non_retryable",
        (true, false) => "rate_limited",
        (false, true) => "non_retryable",
        (false, false) => "retryable",
    }
}
277
278fn compact_error_detail(err: &anyhow::Error) -> String {
279    // Use {:#} to include the full error chain (root cause), not just the top-level message.
280    super::sanitize_api_error(&format!("{:#}", err))
281        .split_whitespace()
282        .collect::<Vec<_>>()
283        .join(" ")
284}
285
286/// Truncate conversation history by dropping the oldest non-system messages.
287/// Returns the number of messages dropped. Keeps at least the system message
288/// (if any) and the most recent user message.
289fn truncate_for_context(messages: &mut Vec<ChatMessage>) -> usize {
290    // Find all non-system message indices
291    let non_system: Vec<usize> = messages
292        .iter()
293        .enumerate()
294        .filter(|(_, m)| m.role != "system")
295        .map(|(i, _)| i)
296        .collect();
297
298    // Keep at least the last non-system message (most recent user turn)
299    if non_system.len() <= 1 {
300        return 0;
301    }
302
303    // Drop the oldest half of non-system messages
304    let drop_count = non_system.len() / 2;
305    let indices_to_remove: Vec<usize> = non_system[..drop_count].to_vec();
306
307    // Remove in reverse order to preserve indices
308    for &idx in indices_to_remove.iter().rev() {
309        messages.remove(idx);
310    }
311
312    drop_count
313}
314
/// Append one formatted attempt record to `failures`. Each failed attempt
/// gets a line so the final aggregated error reads as a chronological
/// diagnostic trail for operators.
fn push_failure(
    failures: &mut Vec<String>,
    provider_name: &str,
    model: &str,
    attempt: u32,
    max_attempts: u32,
    reason: &str,
    error_detail: &str,
) {
    let record = format!(
        "provider={} model={} attempt {}/{}: {}; error={}",
        provider_name, model, attempt, max_attempts, reason, error_detail
    );
    failures.push(record);
}
328
329// ── Resilient Provider Wrapper ────────────────────────────────────────────
330// Three-level failover strategy: model chain → provider chain → retry loop.
331//   Outer loop:  iterate model fallback chain (original model first, then
332//                configured alternatives).
333//   Middle loop: iterate registered providers in priority order.
334//   Inner loop:  retry the same (provider, model) pair with exponential
335//                backoff, rotating API keys on rate-limit errors.
336// Loop invariant: `failures` accumulates every failed attempt so the final
337// error message gives operators a complete diagnostic trail.
338
/// Provider wrapper with retry, fallback, auth rotation, and model failover.
pub struct ReliableProvider {
    /// Registered providers in priority order; index 0 is the primary.
    providers: Vec<(String, Box<dyn Provider>)>,
    /// Retries per (provider, model) pair; total attempts = max_retries + 1.
    max_retries: u32,
    /// Initial backoff in milliseconds (floored at 50 in `new`); doubled per
    /// retry and capped at 10s by the retry loops.
    base_backoff_ms: u64,
    /// Extra API keys for rotation (index tracks round-robin position).
    api_keys: Vec<String>,
    /// Round-robin cursor into `api_keys`; atomic so `&self` methods can
    /// advance it without a lock.
    key_index: AtomicUsize,
    /// Per-model fallback chains: model_name → [fallback_model_1, fallback_model_2, ...]
    model_fallbacks: HashMap<String, Vec<String>>,
}
350
351impl ReliableProvider {
352    pub fn new(
353        providers: Vec<(String, Box<dyn Provider>)>,
354        max_retries: u32,
355        base_backoff_ms: u64,
356    ) -> Self {
357        Self {
358            providers,
359            max_retries,
360            base_backoff_ms: base_backoff_ms.max(50),
361            api_keys: Vec::new(),
362            key_index: AtomicUsize::new(0),
363            model_fallbacks: HashMap::new(),
364        }
365    }
366
367    /// Set additional API keys for round-robin rotation on rate-limit errors.
368    pub fn with_api_keys(mut self, keys: Vec<String>) -> Self {
369        self.api_keys = keys;
370        self
371    }
372
373    /// Set per-model fallback chains.
374    pub fn with_model_fallbacks(mut self, fallbacks: HashMap<String, Vec<String>>) -> Self {
375        self.model_fallbacks = fallbacks;
376        self
377    }
378
379    /// Build the list of models to try: [original, fallback1, fallback2, ...]
380    fn model_chain<'a>(&'a self, model: &'a str) -> Vec<&'a str> {
381        let mut chain = vec![model];
382        if let Some(fallbacks) = self.model_fallbacks.get(model) {
383            chain.extend(fallbacks.iter().map(|s| s.as_str()));
384        }
385        chain
386    }
387
388    /// Advance to the next API key and return it, or None if no extra keys configured.
389    fn rotate_key(&self) -> Option<&str> {
390        if self.api_keys.is_empty() {
391            return None;
392        }
393        let idx = self.key_index.fetch_add(1, Ordering::Relaxed) % self.api_keys.len();
394        Some(&self.api_keys[idx])
395    }
396
397    /// Compute backoff duration, respecting Retry-After if present.
398    fn compute_backoff(&self, base: u64, err: &anyhow::Error) -> u64 {
399        if let Some(retry_after) = parse_retry_after_ms(err) {
400            // Use Retry-After but cap at 30s to avoid indefinite waits
401            retry_after.min(30_000).max(base)
402        } else {
403            base
404        }
405    }
406}
407
408#[async_trait]
409impl Provider for ReliableProvider {
410    async fn warmup(&self) -> anyhow::Result<()> {
411        for (name, provider) in &self.providers {
412            tracing::info!(provider = name, "Warming up provider connection pool");
413            if provider.warmup().await.is_err() {
414                tracing::warn!(provider = name, "Warmup failed (non-fatal)");
415            }
416        }
417        Ok(())
418    }
419
    /// Single-turn chat with an optional system prompt, with full failover.
    ///
    /// Walks the model fallback chain × provider priority list × retry
    /// budget and returns the first success. On total failure, the error
    /// aggregates every attempt recorded in `failures`.
    async fn chat_with_system(
        &self,
        system_prompt: Option<&str>,
        message: &str,
        model: &str,
        temperature: f64,
    ) -> anyhow::Result<String> {
        let models = self.model_chain(model);
        let mut failures = Vec::new();

        // Outer: model fallback chain. Middle: provider priority. Inner: retries.
        // Each iteration: attempt one (provider, model) call. On success, return
        // immediately. On non-retryable error, break to next provider. On
        // retryable error, sleep with exponential backoff and retry.
        for current_model in &models {
            for (provider_name, provider) in &self.providers {
                let mut backoff_ms = self.base_backoff_ms;

                for attempt in 0..=self.max_retries {
                    match provider
                        .chat_with_system(system_prompt, message, current_model, temperature)
                        .await
                    {
                        Ok(resp) => {
                            // Anything other than a clean first-attempt success on the
                            // primary provider with the requested model counts as a
                            // recovery worth logging and surfacing to the user.
                            if attempt > 0
                                || *current_model != model
                                || self.providers.first().map(|(n, _)| n.as_str())
                                    != Some(provider_name)
                            {
                                tracing::info!(
                                    provider = provider_name,
                                    model = *current_model,
                                    attempt,
                                    original_model = model,
                                    "Provider recovered (failover/retry)"
                                );
                                // The "requested" provider reported to the user is the
                                // first registered provider (priority 0).
                                let primary = self
                                    .providers
                                    .first()
                                    .map(|(n, _)| n.as_str())
                                    .unwrap_or("");
                                record_provider_fallback(
                                    primary,
                                    model,
                                    provider_name,
                                    current_model,
                                );
                            }
                            return Ok(resp);
                        }
                        Err(e) => {
                            // Context window exceeded: no history to truncate
                            // in chat_with_system, bail immediately.
                            if is_context_window_exceeded(&e) {
                                let error_detail = compact_error_detail(&e);
                                push_failure(
                                    &mut failures,
                                    provider_name,
                                    current_model,
                                    attempt + 1,
                                    self.max_retries + 1,
                                    "non_retryable",
                                    &error_detail,
                                );
                                anyhow::bail!(
                                    "Request exceeds model context window. Attempts:\n{}",
                                    failures.join("\n")
                                );
                            }

                            // Classify once per failure; the labels drive both the
                            // retry decision and the diagnostic record.
                            let non_retryable_rate_limit = is_non_retryable_rate_limit(&e);
                            let non_retryable = is_non_retryable(&e) || non_retryable_rate_limit;
                            let rate_limited = is_rate_limited(&e);
                            let failure_reason = failure_reason(rate_limited, non_retryable);
                            let error_detail = compact_error_detail(&e);

                            push_failure(
                                &mut failures,
                                provider_name,
                                current_model,
                                attempt + 1,
                                self.max_retries + 1,
                                failure_reason,
                                &error_detail,
                            );

                            // Rate-limit with rotatable keys: cycle to the next API key
                            // so the retry hits a different quota bucket.
                            // NOTE(review): rotate_key() advances the shared cursor even
                            // though the key cannot actually be applied (see log text);
                            // also the suffix slice below is a byte index — assumes keys
                            // are ASCII, a multi-byte char at the boundary would panic.
                            if rate_limited && !non_retryable_rate_limit {
                                if let Some(new_key) = self.rotate_key() {
                                    tracing::warn!(
                                        provider = provider_name,
                                        error = %error_detail,
                                        "Rate limited; key rotation selected key ending ...{} \
                                         but cannot apply (Provider trait has no set_api_key). \
                                         Retrying with original key.",
                                        &new_key[new_key.len().saturating_sub(4)..]
                                    );
                                }
                            }

                            if non_retryable {
                                tracing::warn!(
                                    provider = provider_name,
                                    model = *current_model,
                                    error = %error_detail,
                                    "Non-retryable error, moving on"
                                );
                                break;
                            }

                            // Sleep only when another attempt remains; the final failed
                            // attempt falls through to the next provider immediately.
                            if attempt < self.max_retries {
                                let wait = self.compute_backoff(backoff_ms, &e);
                                tracing::warn!(
                                    provider = provider_name,
                                    model = *current_model,
                                    attempt = attempt + 1,
                                    backoff_ms = wait,
                                    reason = failure_reason,
                                    error = %error_detail,
                                    "Provider call failed, retrying"
                                );
                                tokio::time::sleep(Duration::from_millis(wait)).await;
                                // Exponential growth, capped at 10s per sleep.
                                backoff_ms = (backoff_ms.saturating_mul(2)).min(10_000);
                            }
                        }
                    }
                }

                tracing::warn!(
                    provider = provider_name,
                    model = *current_model,
                    "Exhausted retries, trying next provider/model"
                );
            }

            if *current_model != model {
                tracing::warn!(
                    original_model = model,
                    fallback_model = *current_model,
                    "Model fallback exhausted all providers, trying next fallback model"
                );
            }
        }

        anyhow::bail!(
            "All providers/models failed. Attempts:\n{}",
            failures.join("\n")
        )
    }
570
    /// Multi-turn chat with full failover, plus one-shot context truncation:
    /// if a provider reports the conversation exceeds its context window, the
    /// oldest half of the non-system history is dropped once and the call is
    /// retried with the shortened history.
    async fn chat_with_history(
        &self,
        messages: &[ChatMessage],
        model: &str,
        temperature: f64,
    ) -> anyhow::Result<String> {
        let models = self.model_chain(model);
        let mut failures = Vec::new();
        // Local working copy so truncation never mutates the caller's slice.
        let mut effective_messages = messages.to_vec();
        // Guard: truncation is attempted at most once per request.
        let mut context_truncated = false;

        for current_model in &models {
            for (provider_name, provider) in &self.providers {
                let mut backoff_ms = self.base_backoff_ms;

                for attempt in 0..=self.max_retries {
                    match provider
                        .chat_with_history(&effective_messages, current_model, temperature)
                        .await
                    {
                        Ok(resp) => {
                            // Any deviation from a clean first-attempt primary success
                            // (retry, fallback model/provider, or truncated context)
                            // is logged and recorded for user notification.
                            if attempt > 0
                                || *current_model != model
                                || context_truncated
                                || self.providers.first().map(|(n, _)| n.as_str())
                                    != Some(provider_name)
                            {
                                tracing::info!(
                                    provider = provider_name,
                                    model = *current_model,
                                    attempt,
                                    original_model = model,
                                    context_truncated,
                                    "Provider recovered (failover/retry)"
                                );
                                let primary = self
                                    .providers
                                    .first()
                                    .map(|(n, _)| n.as_str())
                                    .unwrap_or("");
                                record_provider_fallback(
                                    primary,
                                    model,
                                    provider_name,
                                    current_model,
                                );
                            }
                            return Ok(resp);
                        }
                        Err(e) => {
                            // Context window exceeded: truncate history and retry
                            if is_context_window_exceeded(&e) && !context_truncated {
                                let dropped = truncate_for_context(&mut effective_messages);
                                if dropped > 0 {
                                    context_truncated = true;
                                    tracing::warn!(
                                        provider = provider_name,
                                        model = *current_model,
                                        dropped,
                                        remaining = effective_messages.len(),
                                        "Context window exceeded; truncated history and retrying"
                                    );
                                    continue; // Retry with truncated messages (counts as an attempt)
                                }
                                // Nothing to truncate (system prompt alone exceeds
                                // the model's context window) — bail immediately
                                // instead of wasting retry attempts.
                                let error_detail = compact_error_detail(&e);
                                push_failure(
                                    &mut failures,
                                    provider_name,
                                    current_model,
                                    attempt + 1,
                                    self.max_retries + 1,
                                    "non_retryable",
                                    &error_detail,
                                );
                                anyhow::bail!(
                                    "Request exceeds model context window and cannot be reduced further. \
                                     Try using a model with a larger context window, reducing the number \
                                     of tools/skills, or enabling compact_context in config. Attempts:\n{}",
                                    failures.join("\n")
                                );
                            }

                            // Classify once per failure; labels drive the retry
                            // decision and the diagnostic record.
                            let non_retryable_rate_limit = is_non_retryable_rate_limit(&e);
                            let non_retryable = is_non_retryable(&e) || non_retryable_rate_limit;
                            let rate_limited = is_rate_limited(&e);
                            let failure_reason = failure_reason(rate_limited, non_retryable);
                            let error_detail = compact_error_detail(&e);

                            push_failure(
                                &mut failures,
                                provider_name,
                                current_model,
                                attempt + 1,
                                self.max_retries + 1,
                                failure_reason,
                                &error_detail,
                            );

                            // NOTE(review): rotate_key() advances the shared cursor even
                            // though the key cannot actually be applied (see log text);
                            // the suffix slice is a byte index — assumes ASCII keys.
                            if rate_limited && !non_retryable_rate_limit {
                                if let Some(new_key) = self.rotate_key() {
                                    tracing::warn!(
                                        provider = provider_name,
                                        error = %error_detail,
                                        "Rate limited; key rotation selected key ending ...{} \
                                         but cannot apply (Provider trait has no set_api_key). \
                                         Retrying with original key.",
                                        &new_key[new_key.len().saturating_sub(4)..]
                                    );
                                }
                            }

                            if non_retryable {
                                tracing::warn!(
                                    provider = provider_name,
                                    model = *current_model,
                                    error = %error_detail,
                                    "Non-retryable error, moving on"
                                );
                                break;
                            }

                            // Sleep only when another attempt remains.
                            if attempt < self.max_retries {
                                let wait = self.compute_backoff(backoff_ms, &e);
                                tracing::warn!(
                                    provider = provider_name,
                                    model = *current_model,
                                    attempt = attempt + 1,
                                    backoff_ms = wait,
                                    reason = failure_reason,
                                    error = %error_detail,
                                    "Provider call failed, retrying"
                                );
                                tokio::time::sleep(Duration::from_millis(wait)).await;
                                // Exponential growth, capped at 10s per sleep.
                                backoff_ms = (backoff_ms.saturating_mul(2)).min(10_000);
                            }
                        }
                    }
                }

                tracing::warn!(
                    provider = provider_name,
                    model = *current_model,
                    "Exhausted retries, trying next provider/model"
                );
            }
        }

        anyhow::bail!(
            "All providers/models failed. Attempts:\n{}",
            failures.join("\n")
        )
    }
726
727    fn supports_native_tools(&self) -> bool {
728        self.providers
729            .first()
730            .map(|(_, p)| p.supports_native_tools())
731            .unwrap_or(false)
732    }
733
734    fn supports_vision(&self) -> bool {
735        self.providers
736            .iter()
737            .any(|(_, provider)| provider.supports_vision())
738    }
739
    /// Tool-enabled chat with retries, provider failover, and model fallback.
    ///
    /// Walks the model fallback chain (requested model first); for each model,
    /// tries every configured provider up to `max_retries + 1` times with
    /// exponential backoff. A context-window error triggers a one-time history
    /// truncation and retry; if nothing can be dropped the request aborts
    /// immediately. Any successful call that deviated from the happy path
    /// (retry, non-primary provider, fallback model, or truncation) is logged
    /// and reported via `record_provider_fallback` for user notification.
    ///
    /// Errors with an aggregated, per-attempt failure list when every
    /// provider/model combination is exhausted.
    async fn chat_with_tools(
        &self,
        messages: &[ChatMessage],
        tools: &[serde_json::Value],
        model: &str,
        temperature: f64,
    ) -> anyhow::Result<ChatResponse> {
        let models = self.model_chain(model);
        // Human-readable record of every failed attempt, surfaced in the
        // final aggregated error message.
        let mut failures = Vec::new();
        // Owned copy so context-window truncation can shrink it in place.
        let mut effective_messages = messages.to_vec();
        let mut context_truncated = false;

        for current_model in &models {
            for (provider_name, provider) in &self.providers {
                let mut backoff_ms = self.base_backoff_ms;

                for attempt in 0..=self.max_retries {
                    match provider
                        .chat_with_tools(&effective_messages, tools, current_model, temperature)
                        .await
                    {
                        Ok(resp) => {
                            // Any deviation from the happy path (retry, model
                            // fallback, truncation, or a non-primary provider)
                            // is logged and recorded for the caller.
                            if attempt > 0
                                || *current_model != model
                                || context_truncated
                                || self.providers.first().map(|(n, _)| n.as_str())
                                    != Some(provider_name)
                            {
                                tracing::info!(
                                    provider = provider_name,
                                    model = *current_model,
                                    attempt,
                                    original_model = model,
                                    context_truncated,
                                    "Provider recovered (failover/retry)"
                                );
                                let primary = self
                                    .providers
                                    .first()
                                    .map(|(n, _)| n.as_str())
                                    .unwrap_or("");
                                record_provider_fallback(
                                    primary,
                                    model,
                                    provider_name,
                                    current_model,
                                );
                            }
                            return Ok(resp);
                        }
                        Err(e) => {
                            // Context window exceeded: truncate history and retry
                            if is_context_window_exceeded(&e) && !context_truncated {
                                let dropped = truncate_for_context(&mut effective_messages);
                                if dropped > 0 {
                                    context_truncated = true;
                                    tracing::warn!(
                                        provider = provider_name,
                                        model = *current_model,
                                        dropped,
                                        remaining = effective_messages.len(),
                                        "Context window exceeded; truncated history and retrying"
                                    );
                                    continue; // Retry with truncated messages (counts as an attempt)
                                }
                                // Nothing to truncate (system prompt alone exceeds
                                // the model's context window) — bail immediately
                                // instead of wasting retry attempts.
                                let error_detail = compact_error_detail(&e);
                                push_failure(
                                    &mut failures,
                                    provider_name,
                                    current_model,
                                    attempt + 1,
                                    self.max_retries + 1,
                                    "non_retryable",
                                    &error_detail,
                                );
                                anyhow::bail!(
                                    "Request exceeds model context window and cannot be reduced further. \
                                     Try using a model with a larger context window, reducing the number \
                                     of tools/skills, or enabling compact_context in config. Attempts:\n{}",
                                    failures.join("\n")
                                );
                            }

                            // Classify the failure: retryable vs fatal, and
                            // whether it is a rate limit that key rotation
                            // could in principle address.
                            let non_retryable_rate_limit = is_non_retryable_rate_limit(&e);
                            let non_retryable = is_non_retryable(&e) || non_retryable_rate_limit;
                            let rate_limited = is_rate_limited(&e);
                            let failure_reason = failure_reason(rate_limited, non_retryable);
                            let error_detail = compact_error_detail(&e);

                            push_failure(
                                &mut failures,
                                provider_name,
                                current_model,
                                attempt + 1,
                                self.max_retries + 1,
                                failure_reason,
                                &error_detail,
                            );

                            if rate_limited && !non_retryable_rate_limit {
                                // Key rotation is currently log-only: the
                                // selected key cannot be applied to the
                                // provider (see the log message).
                                // NOTE(review): the byte slice of the key tail
                                // assumes ASCII; a multi-byte char at the
                                // boundary would panic — confirm key format.
                                if let Some(new_key) = self.rotate_key() {
                                    tracing::warn!(
                                        provider = provider_name,
                                        error = %error_detail,
                                        "Rate limited; key rotation selected key ending ...{} \
                                         but cannot apply (Provider trait has no set_api_key). \
                                         Retrying with original key.",
                                        &new_key[new_key.len().saturating_sub(4)..]
                                    );
                                }
                            }

                            if non_retryable {
                                // Retrying cannot help (e.g. auth/bad request);
                                // break out to the next provider.
                                tracing::warn!(
                                    provider = provider_name,
                                    model = *current_model,
                                    error = %error_detail,
                                    "Non-retryable error, moving on"
                                );
                                break;
                            }

                            if attempt < self.max_retries {
                                let wait = self.compute_backoff(backoff_ms, &e);
                                tracing::warn!(
                                    provider = provider_name,
                                    model = *current_model,
                                    attempt = attempt + 1,
                                    backoff_ms = wait,
                                    reason = failure_reason,
                                    error = %error_detail,
                                    "Provider call failed, retrying"
                                );
                                tokio::time::sleep(Duration::from_millis(wait)).await;
                                // Exponential backoff, capped at 10 seconds.
                                backoff_ms = (backoff_ms.saturating_mul(2)).min(10_000);
                            }
                        }
                    }
                }

                tracing::warn!(
                    provider = provider_name,
                    model = *current_model,
                    "Exhausted retries, trying next provider/model"
                );
            }
        }

        anyhow::bail!(
            "All providers/models failed. Attempts:\n{}",
            failures.join("\n")
        )
    }
896
    /// Chat with retries, provider failover, and model fallback.
    ///
    /// Same loop structure as `chat_with_tools`: walk the model fallback chain,
    /// try every provider up to `max_retries + 1` times with exponential
    /// backoff, truncate history once on a context-window error, and record
    /// any recovery via `record_provider_fallback`. The `ChatRequest` is
    /// rebuilt on every attempt so it borrows the (possibly truncated)
    /// `effective_messages`.
    async fn chat(
        &self,
        request: ChatRequest<'_>,
        model: &str,
        temperature: f64,
    ) -> anyhow::Result<ChatResponse> {
        let models = self.model_chain(model);
        // Per-attempt failure log, joined into the final aggregated error.
        let mut failures = Vec::new();
        // Owned copy so context-window truncation can shrink it in place.
        let mut effective_messages = request.messages.to_vec();
        let mut context_truncated = false;

        for current_model in &models {
            for (provider_name, provider) in &self.providers {
                let mut backoff_ms = self.base_backoff_ms;

                for attempt in 0..=self.max_retries {
                    // Rebuilt each attempt so it reflects any truncation.
                    let req = ChatRequest {
                        messages: &effective_messages,
                        tools: request.tools,
                    };
                    match provider.chat(req, current_model, temperature).await {
                        Ok(resp) => {
                            // Log + record any deviation from the happy path
                            // (retry, fallback model, truncation, or a
                            // non-primary provider).
                            if attempt > 0
                                || *current_model != model
                                || context_truncated
                                || self.providers.first().map(|(n, _)| n.as_str())
                                    != Some(provider_name)
                            {
                                tracing::info!(
                                    provider = provider_name,
                                    model = *current_model,
                                    attempt,
                                    original_model = model,
                                    context_truncated,
                                    "Provider recovered (failover/retry)"
                                );
                                let primary = self
                                    .providers
                                    .first()
                                    .map(|(n, _)| n.as_str())
                                    .unwrap_or("");
                                record_provider_fallback(
                                    primary,
                                    model,
                                    provider_name,
                                    current_model,
                                );
                            }
                            return Ok(resp);
                        }
                        Err(e) => {
                            // Context window exceeded: truncate history and retry
                            if is_context_window_exceeded(&e) && !context_truncated {
                                let dropped = truncate_for_context(&mut effective_messages);
                                if dropped > 0 {
                                    context_truncated = true;
                                    tracing::warn!(
                                        provider = provider_name,
                                        model = *current_model,
                                        dropped,
                                        remaining = effective_messages.len(),
                                        "Context window exceeded; truncated history and retrying"
                                    );
                                    continue; // Retry with truncated messages (counts as an attempt)
                                }
                                // Nothing to truncate (system prompt alone exceeds
                                // the model's context window) — bail immediately
                                // instead of wasting retry attempts.
                                let error_detail = compact_error_detail(&e);
                                push_failure(
                                    &mut failures,
                                    provider_name,
                                    current_model,
                                    attempt + 1,
                                    self.max_retries + 1,
                                    "non_retryable",
                                    &error_detail,
                                );
                                anyhow::bail!(
                                    "Request exceeds model context window and cannot be reduced further. \
                                     Try using a model with a larger context window, reducing the number \
                                     of tools/skills, or enabling compact_context in config. Attempts:\n{}",
                                    failures.join("\n")
                                );
                            }

                            // Classify the failure: retryable vs fatal, and
                            // whether it is a rotation-worthy rate limit.
                            let non_retryable_rate_limit = is_non_retryable_rate_limit(&e);
                            let non_retryable = is_non_retryable(&e) || non_retryable_rate_limit;
                            let rate_limited = is_rate_limited(&e);
                            let failure_reason = failure_reason(rate_limited, non_retryable);
                            let error_detail = compact_error_detail(&e);

                            push_failure(
                                &mut failures,
                                provider_name,
                                current_model,
                                attempt + 1,
                                self.max_retries + 1,
                                failure_reason,
                                &error_detail,
                            );

                            if rate_limited && !non_retryable_rate_limit {
                                // Key rotation is currently log-only (cannot be
                                // applied to the provider instance).
                                // NOTE(review): the byte slice of the key tail
                                // assumes ASCII keys — a multi-byte char at the
                                // boundary would panic; confirm key format.
                                if let Some(new_key) = self.rotate_key() {
                                    tracing::warn!(
                                        provider = provider_name,
                                        error = %error_detail,
                                        "Rate limited; key rotation selected key ending ...{} \
                                         but cannot apply (Provider trait has no set_api_key). \
                                         Retrying with original key.",
                                        &new_key[new_key.len().saturating_sub(4)..]
                                    );
                                }
                            }

                            if non_retryable {
                                // Retrying cannot help; move to the next provider.
                                tracing::warn!(
                                    provider = provider_name,
                                    model = *current_model,
                                    error = %error_detail,
                                    "Non-retryable error, moving on"
                                );
                                break;
                            }

                            if attempt < self.max_retries {
                                let wait = self.compute_backoff(backoff_ms, &e);
                                tracing::warn!(
                                    provider = provider_name,
                                    model = *current_model,
                                    attempt = attempt + 1,
                                    backoff_ms = wait,
                                    reason = failure_reason,
                                    error = %error_detail,
                                    "Provider call failed, retrying"
                                );
                                tokio::time::sleep(Duration::from_millis(wait)).await;
                                // Exponential backoff, capped at 10 seconds.
                                backoff_ms = (backoff_ms.saturating_mul(2)).min(10_000);
                            }
                        }
                    }
                }

                tracing::warn!(
                    provider = provider_name,
                    model = *current_model,
                    "Exhausted retries, trying next provider/model"
                );
            }

            // Moving on to the next model in the fallback chain (if any).
            if *current_model != model {
                tracing::warn!(
                    original_model = model,
                    fallback_model = *current_model,
                    "Model fallback exhausted all providers, trying next fallback model"
                );
            }
        }

        anyhow::bail!(
            "All providers/models failed. Attempts:\n{}",
            failures.join("\n")
        )
    }
1061
1062    fn supports_streaming(&self) -> bool {
1063        self.providers.iter().any(|(_, p)| p.supports_streaming())
1064    }
1065
1066    fn supports_streaming_tool_events(&self) -> bool {
1067        self.providers
1068            .iter()
1069            .any(|(_, p)| p.supports_streaming_tool_events())
1070    }
1071
1072    fn stream_chat(
1073        &self,
1074        request: ChatRequest<'_>,
1075        model: &str,
1076        temperature: f64,
1077        options: StreamOptions,
1078    ) -> stream::BoxStream<'static, StreamResult<StreamEvent>> {
1079        let needs_tool_events = request.tools.is_some_and(|tools| !tools.is_empty());
1080
1081        for (provider_name, provider) in &self.providers {
1082            if !provider.supports_streaming() || !options.enabled {
1083                continue;
1084            }
1085
1086            if needs_tool_events && !provider.supports_streaming_tool_events() {
1087                continue;
1088            }
1089
1090            let provider_clone = provider_name.clone();
1091
1092            let current_model = self
1093                .model_chain(model)
1094                .first()
1095                .copied()
1096                .unwrap_or(model)
1097                .to_string();
1098
1099            let req = ChatRequest {
1100                messages: request.messages,
1101                tools: request.tools,
1102            };
1103            let stream = provider.stream_chat(req, &current_model, temperature, options);
1104            let (tx, rx) = tokio::sync::mpsc::channel::<StreamResult<StreamEvent>>(100);
1105
1106            tokio::spawn(async move {
1107                let mut stream = stream;
1108                while let Some(event) = stream.next().await {
1109                    if let Err(ref e) = event {
1110                        tracing::warn!(
1111                            provider = provider_clone,
1112                            model = current_model,
1113                            "Streaming error: {e}"
1114                        );
1115                    }
1116                    if tx.send(event).await.is_err() {
1117                        break;
1118                    }
1119                }
1120            });
1121
1122            return stream::unfold(rx, |mut rx| async move {
1123                rx.recv().await.map(|event| (event, rx))
1124            })
1125            .boxed();
1126        }
1127
1128        let message = if needs_tool_events {
1129            "No provider supports streaming tool events".to_string()
1130        } else {
1131            "No provider supports streaming".to_string()
1132        };
1133        stream::once(async move { Err(super::traits::StreamError::Provider(message)) }).boxed()
1134    }
1135
1136    fn stream_chat_with_system(
1137        &self,
1138        system_prompt: Option<&str>,
1139        message: &str,
1140        model: &str,
1141        temperature: f64,
1142        options: StreamOptions,
1143    ) -> stream::BoxStream<'static, StreamResult<StreamChunk>> {
1144        // Try each provider/model combination for streaming
1145        // For streaming, we use the first provider that supports it and has streaming enabled
1146        for (provider_name, provider) in &self.providers {
1147            if !provider.supports_streaming() || !options.enabled {
1148                continue;
1149            }
1150
1151            // Clone provider data for the stream
1152            let provider_clone = provider_name.clone();
1153
1154            // Try the first model in the chain for streaming
1155            let current_model = match self.model_chain(model).first() {
1156                Some(m) => (*m).to_string(),
1157                None => model.to_string(),
1158            };
1159
1160            // For streaming, we attempt once and propagate errors
1161            // The caller can retry the entire request if needed
1162            let stream = provider.stream_chat_with_system(
1163                system_prompt,
1164                message,
1165                &current_model,
1166                temperature,
1167                options,
1168            );
1169
1170            // Use a channel to bridge the stream with logging
1171            let (tx, rx) = tokio::sync::mpsc::channel::<StreamResult<StreamChunk>>(100);
1172
1173            tokio::spawn(async move {
1174                let mut stream = stream;
1175                while let Some(chunk) = stream.next().await {
1176                    if let Err(ref e) = chunk {
1177                        tracing::warn!(
1178                            provider = provider_clone,
1179                            model = current_model,
1180                            "Streaming error: {e}"
1181                        );
1182                    }
1183                    if tx.send(chunk).await.is_err() {
1184                        break; // Receiver dropped
1185                    }
1186                }
1187            });
1188
1189            // Convert channel receiver to stream
1190            return stream::unfold(rx, |mut rx| async move {
1191                rx.recv().await.map(|chunk| (chunk, rx))
1192            })
1193            .boxed();
1194        }
1195
1196        // No streaming support available
1197        stream::once(async move {
1198            Err(super::traits::StreamError::Provider(
1199                "No provider supports streaming".to_string(),
1200            ))
1201        })
1202        .boxed()
1203    }
1204
1205    fn stream_chat_with_history(
1206        &self,
1207        messages: &[ChatMessage],
1208        model: &str,
1209        temperature: f64,
1210        options: StreamOptions,
1211    ) -> stream::BoxStream<'static, StreamResult<StreamChunk>> {
1212        // Try each provider/model combination for streaming with history.
1213        // Mirrors stream_chat_with_system but delegates to the underlying
1214        // provider's stream_chat_with_history, preserving the full conversation.
1215        for (provider_name, provider) in &self.providers {
1216            if !provider.supports_streaming() || !options.enabled {
1217                continue;
1218            }
1219
1220            let provider_clone = provider_name.clone();
1221
1222            let current_model = match self.model_chain(model).first() {
1223                Some(m) => (*m).to_string(),
1224                None => model.to_string(),
1225            };
1226
1227            let stream =
1228                provider.stream_chat_with_history(messages, &current_model, temperature, options);
1229
1230            let (tx, rx) = tokio::sync::mpsc::channel::<StreamResult<StreamChunk>>(100);
1231
1232            tokio::spawn(async move {
1233                let mut stream = stream;
1234                while let Some(chunk) = stream.next().await {
1235                    if let Err(ref e) = chunk {
1236                        tracing::warn!(
1237                            provider = provider_clone,
1238                            model = current_model,
1239                            "Streaming error: {e}"
1240                        );
1241                    }
1242                    if tx.send(chunk).await.is_err() {
1243                        break; // Receiver dropped
1244                    }
1245                }
1246            });
1247
1248            return stream::unfold(rx, |mut rx| async move {
1249                rx.recv().await.map(|chunk| (chunk, rx))
1250            })
1251            .boxed();
1252        }
1253
1254        // No streaming support available
1255        stream::once(async move {
1256            Err(super::traits::StreamError::Provider(
1257                "No provider supports streaming".to_string(),
1258            ))
1259        })
1260        .boxed()
1261    }
1262}
1263
1264#[cfg(test)]
1265mod tests {
1266    use super::*;
1267    use crate::tools::ToolSpec;
1268    use futures_util::StreamExt;
1269    use std::sync::Arc;
1270
    /// Scripted provider used to exercise retry/failover behavior in tests.
    struct MockProvider {
        /// Shared call counter, incremented on every chat_* invocation.
        calls: Arc<AtomicUsize>,
        /// Calls numbered at or below this value fail; later calls succeed
        /// (0 = always succeed, usize::MAX = always fail).
        fail_until_attempt: usize,
        /// Payload returned on a successful call.
        response: &'static str,
        /// Error message produced on a failing call.
        error: &'static str,
    }
1277
1278    #[async_trait]
1279    impl Provider for MockProvider {
1280        async fn chat_with_system(
1281            &self,
1282            _system_prompt: Option<&str>,
1283            _message: &str,
1284            _model: &str,
1285            _temperature: f64,
1286        ) -> anyhow::Result<String> {
1287            let attempt = self.calls.fetch_add(1, Ordering::SeqCst) + 1;
1288            if attempt <= self.fail_until_attempt {
1289                anyhow::bail!(self.error);
1290            }
1291            Ok(self.response.to_string())
1292        }
1293
1294        async fn chat_with_history(
1295            &self,
1296            _messages: &[ChatMessage],
1297            _model: &str,
1298            _temperature: f64,
1299        ) -> anyhow::Result<String> {
1300            let attempt = self.calls.fetch_add(1, Ordering::SeqCst) + 1;
1301            if attempt <= self.fail_until_attempt {
1302                anyhow::bail!(self.error);
1303            }
1304            Ok(self.response.to_string())
1305        }
1306    }
1307
    /// Mock that records which model was used for each call.
    struct ModelAwareMock {
        /// Total number of chat_with_system invocations.
        calls: Arc<AtomicUsize>,
        /// Every model name passed in, in call order.
        models_seen: parking_lot::Mutex<Vec<String>>,
        /// Models for which the mock returns a 500-style error.
        fail_models: Vec<&'static str>,
        /// Payload returned for any model not in `fail_models`.
        response: &'static str,
    }
1315
1316    #[async_trait]
1317    impl Provider for ModelAwareMock {
1318        async fn chat_with_system(
1319            &self,
1320            _system_prompt: Option<&str>,
1321            _message: &str,
1322            model: &str,
1323            _temperature: f64,
1324        ) -> anyhow::Result<String> {
1325            self.calls.fetch_add(1, Ordering::SeqCst);
1326            self.models_seen.lock().push(model.to_string());
1327            if self.fail_models.contains(&model) {
1328                anyhow::bail!("500 model {} unavailable", model);
1329            }
1330            Ok(self.response.to_string())
1331        }
1332    }
1333
1334    // ── Existing tests (preserved) ──
1335
1336    #[tokio::test]
1337    async fn succeeds_without_retry() {
1338        let calls = Arc::new(AtomicUsize::new(0));
1339        let provider = ReliableProvider::new(
1340            vec![(
1341                "primary".into(),
1342                Box::new(MockProvider {
1343                    calls: Arc::clone(&calls),
1344                    fail_until_attempt: 0,
1345                    response: "ok",
1346                    error: "boom",
1347                }),
1348            )],
1349            2,
1350            1,
1351        );
1352
1353        let result = provider.simple_chat("hello", "test", 0.0).await.unwrap();
1354        assert_eq!(result, "ok");
1355        assert_eq!(calls.load(Ordering::SeqCst), 1);
1356    }
1357
1358    #[tokio::test]
1359    async fn retries_then_recovers() {
1360        let calls = Arc::new(AtomicUsize::new(0));
1361        let provider = ReliableProvider::new(
1362            vec![(
1363                "primary".into(),
1364                Box::new(MockProvider {
1365                    calls: Arc::clone(&calls),
1366                    fail_until_attempt: 1,
1367                    response: "recovered",
1368                    error: "temporary",
1369                }),
1370            )],
1371            2,
1372            1,
1373        );
1374
1375        let result = provider.simple_chat("hello", "test", 0.0).await.unwrap();
1376        assert_eq!(result, "recovered");
1377        assert_eq!(calls.load(Ordering::SeqCst), 2);
1378    }
1379
1380    #[tokio::test]
1381    async fn falls_back_after_retries_exhausted() {
1382        let primary_calls = Arc::new(AtomicUsize::new(0));
1383        let fallback_calls = Arc::new(AtomicUsize::new(0));
1384
1385        let provider = ReliableProvider::new(
1386            vec![
1387                (
1388                    "primary".into(),
1389                    Box::new(MockProvider {
1390                        calls: Arc::clone(&primary_calls),
1391                        fail_until_attempt: usize::MAX,
1392                        response: "never",
1393                        error: "primary down",
1394                    }),
1395                ),
1396                (
1397                    "fallback".into(),
1398                    Box::new(MockProvider {
1399                        calls: Arc::clone(&fallback_calls),
1400                        fail_until_attempt: 0,
1401                        response: "from fallback",
1402                        error: "fallback down",
1403                    }),
1404                ),
1405            ],
1406            1,
1407            1,
1408        );
1409
1410        let result = provider.simple_chat("hello", "test", 0.0).await.unwrap();
1411        assert_eq!(result, "from fallback");
1412        assert_eq!(primary_calls.load(Ordering::SeqCst), 2);
1413        assert_eq!(fallback_calls.load(Ordering::SeqCst), 1);
1414    }
1415
1416    #[tokio::test]
1417    async fn returns_aggregated_error_when_all_providers_fail() {
1418        let provider = ReliableProvider::new(
1419            vec![
1420                (
1421                    "p1".into(),
1422                    Box::new(MockProvider {
1423                        calls: Arc::new(AtomicUsize::new(0)),
1424                        fail_until_attempt: usize::MAX,
1425                        response: "never",
1426                        error: "p1 error",
1427                    }),
1428                ),
1429                (
1430                    "p2".into(),
1431                    Box::new(MockProvider {
1432                        calls: Arc::new(AtomicUsize::new(0)),
1433                        fail_until_attempt: usize::MAX,
1434                        response: "never",
1435                        error: "p2 error",
1436                    }),
1437                ),
1438            ],
1439            0,
1440            1,
1441        );
1442
1443        let err = provider
1444            .simple_chat("hello", "test", 0.0)
1445            .await
1446            .expect_err("all providers should fail");
1447        let msg = err.to_string();
1448        assert!(msg.contains("All providers/models failed"));
1449        assert!(msg.contains("provider=p1 model=test"));
1450        assert!(msg.contains("provider=p2 model=test"));
1451        assert!(msg.contains("error=p1 error"));
1452        assert!(msg.contains("error=p2 error"));
1453        assert!(msg.contains("retryable"));
1454    }
1455
1456    #[test]
1457    fn non_retryable_detects_common_patterns() {
1458        assert!(is_non_retryable(&anyhow::anyhow!("400 Bad Request")));
1459        assert!(is_non_retryable(&anyhow::anyhow!("401 Unauthorized")));
1460        assert!(is_non_retryable(&anyhow::anyhow!("403 Forbidden")));
1461        assert!(is_non_retryable(&anyhow::anyhow!("404 Not Found")));
1462        assert!(is_non_retryable(&anyhow::anyhow!(
1463            "invalid api key provided"
1464        )));
1465        assert!(is_non_retryable(&anyhow::anyhow!("authentication failed")));
1466        assert!(is_non_retryable(&anyhow::anyhow!(
1467            "model glm-4.7 not found"
1468        )));
1469        assert!(is_non_retryable(&anyhow::anyhow!(
1470            "unsupported model: glm-4.7"
1471        )));
1472        assert!(!is_non_retryable(&anyhow::anyhow!("429 Too Many Requests")));
1473        assert!(!is_non_retryable(&anyhow::anyhow!("408 Request Timeout")));
1474        assert!(!is_non_retryable(&anyhow::anyhow!(
1475            "500 Internal Server Error"
1476        )));
1477        assert!(!is_non_retryable(&anyhow::anyhow!("502 Bad Gateway")));
1478        assert!(!is_non_retryable(&anyhow::anyhow!("timeout")));
1479        assert!(!is_non_retryable(&anyhow::anyhow!("connection reset")));
1480        assert!(!is_non_retryable(&anyhow::anyhow!(
1481            "model overloaded, try again later"
1482        )));
1483        // Context window errors are now recoverable (not non-retryable)
1484        assert!(!is_non_retryable(&anyhow::anyhow!(
1485            "OpenAI Codex stream error: Your input exceeds the context window of this model."
1486        )));
1487    }
1488
    // A context-window overflow must abort immediately: neither the retry
    // budget (4) nor the configured model fallback may be consulted, because
    // re-sending the same oversized input can never succeed.
    #[tokio::test]
    async fn context_window_error_aborts_retries_and_model_fallbacks() {
        let calls = Arc::new(AtomicUsize::new(0));
        // Fallback model configured on purpose — the test asserts it is NOT used.
        let mut model_fallbacks = std::collections::HashMap::new();
        model_fallbacks.insert(
            "gpt-5.3-codex".to_string(),
            vec!["gpt-5.2-codex".to_string()],
        );

        let provider = ReliableProvider::new(
            vec![(
                "openai-codex".into(),
                Box::new(MockProvider {
                    calls: Arc::clone(&calls),
                    fail_until_attempt: usize::MAX,
                    response: "never",
                    error: "OpenAI Codex stream error: Your input exceeds the context window of this model. Please adjust your input and try again.",
                }),
            )],
            4,
            1,
        )
        .with_model_fallbacks(model_fallbacks);

        let err = provider
            .simple_chat("hello", "gpt-5.3-codex", 0.0)
            .await
            .expect_err("context window overflow should fail fast");
        let msg = err.to_string();

        assert!(msg.contains("context window"));
        // chat_with_system has no history to truncate, so it bails immediately
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }
1523
1524    #[tokio::test]
1525    async fn aggregated_error_marks_non_retryable_model_mismatch_with_details() {
1526        let calls = Arc::new(AtomicUsize::new(0));
1527        let provider = ReliableProvider::new(
1528            vec![(
1529                "custom".into(),
1530                Box::new(MockProvider {
1531                    calls: Arc::clone(&calls),
1532                    fail_until_attempt: usize::MAX,
1533                    response: "never",
1534                    error: "unsupported model: glm-4.7",
1535                }),
1536            )],
1537            3,
1538            1,
1539        );
1540
1541        let err = provider
1542            .simple_chat("hello", "glm-4.7", 0.0)
1543            .await
1544            .expect_err("provider should fail");
1545        let msg = err.to_string();
1546
1547        assert!(msg.contains("non_retryable"));
1548        assert!(msg.contains("error=unsupported model: glm-4.7"));
1549        // Non-retryable errors should not consume retry budget.
1550        assert_eq!(calls.load(Ordering::SeqCst), 1);
1551    }
1552
    // A non-retryable error (401) on the primary must skip its retry budget
    // (3 retries configured) and fail over to the fallback provider after a
    // single attempt.
    #[tokio::test]
    async fn skips_retries_on_non_retryable_error() {
        let primary_calls = Arc::new(AtomicUsize::new(0));
        let fallback_calls = Arc::new(AtomicUsize::new(0));

        let provider = ReliableProvider::new(
            vec![
                (
                    "primary".into(),
                    Box::new(MockProvider {
                        calls: Arc::clone(&primary_calls),
                        fail_until_attempt: usize::MAX,
                        response: "never",
                        error: "401 Unauthorized",
                    }),
                ),
                (
                    "fallback".into(),
                    Box::new(MockProvider {
                        calls: Arc::clone(&fallback_calls),
                        fail_until_attempt: 0,
                        response: "from fallback",
                        error: "fallback err",
                    }),
                ),
            ],
            3,
            1,
        );

        let result = provider.simple_chat("hello", "test", 0.0).await.unwrap();
        assert_eq!(result, "from fallback");
        // Primary should have been called only once (no retries)
        assert_eq!(primary_calls.load(Ordering::SeqCst), 1);
        assert_eq!(fallback_calls.load(Ordering::SeqCst), 1);
    }
1589
1590    #[tokio::test]
1591    async fn chat_with_history_retries_then_recovers() {
1592        let calls = Arc::new(AtomicUsize::new(0));
1593        let provider = ReliableProvider::new(
1594            vec![(
1595                "primary".into(),
1596                Box::new(MockProvider {
1597                    calls: Arc::clone(&calls),
1598                    fail_until_attempt: 1,
1599                    response: "history ok",
1600                    error: "temporary",
1601                }),
1602            )],
1603            2,
1604            1,
1605        );
1606
1607        let messages = vec![ChatMessage::system("system"), ChatMessage::user("hello")];
1608        let result = provider
1609            .chat_with_history(&messages, "test", 0.0)
1610            .await
1611            .unwrap();
1612        assert_eq!(result, "history ok");
1613        assert_eq!(calls.load(Ordering::SeqCst), 2);
1614    }
1615
    // The primary exhausts its retry budget (1 retry → 2 attempts) and the
    // fallback provider then serves the history-based request.
    #[tokio::test]
    async fn chat_with_history_falls_back() {
        let primary_calls = Arc::new(AtomicUsize::new(0));
        let fallback_calls = Arc::new(AtomicUsize::new(0));

        let provider = ReliableProvider::new(
            vec![
                (
                    "primary".into(),
                    Box::new(MockProvider {
                        calls: Arc::clone(&primary_calls),
                        fail_until_attempt: usize::MAX,
                        response: "never",
                        error: "primary down",
                    }),
                ),
                (
                    "fallback".into(),
                    Box::new(MockProvider {
                        calls: Arc::clone(&fallback_calls),
                        fail_until_attempt: 0,
                        response: "fallback ok",
                        error: "fallback err",
                    }),
                ),
            ],
            1,
            1,
        );

        let messages = vec![ChatMessage::user("hello")];
        let result = provider
            .chat_with_history(&messages, "test", 0.0)
            .await
            .unwrap();
        assert_eq!(result, "fallback ok");
        // Initial attempt plus one retry on the primary before failing over.
        assert_eq!(primary_calls.load(Ordering::SeqCst), 2);
        assert_eq!(fallback_calls.load(Ordering::SeqCst), 1);
    }
1655
1656    // ── New tests: model failover ──
1657
1658    #[tokio::test]
1659    async fn model_failover_tries_fallback_model() {
1660        let calls = Arc::new(AtomicUsize::new(0));
1661        let mock = Arc::new(ModelAwareMock {
1662            calls: Arc::clone(&calls),
1663            models_seen: parking_lot::Mutex::new(Vec::new()),
1664            fail_models: vec!["claude-opus"],
1665            response: "ok from sonnet",
1666        });
1667
1668        let mut fallbacks = HashMap::new();
1669        fallbacks.insert("claude-opus".to_string(), vec!["claude-sonnet".to_string()]);
1670
1671        let provider = ReliableProvider::new(
1672            vec![(
1673                "anthropic".into(),
1674                Box::new(mock.clone()) as Box<dyn Provider>,
1675            )],
1676            0, // no retries — force immediate model failover
1677            1,
1678        )
1679        .with_model_fallbacks(fallbacks);
1680
1681        let result = provider
1682            .simple_chat("hello", "claude-opus", 0.0)
1683            .await
1684            .unwrap();
1685        assert_eq!(result, "ok from sonnet");
1686
1687        let seen = mock.models_seen.lock();
1688        assert_eq!(seen.len(), 2);
1689        assert_eq!(seen[0], "claude-opus");
1690        assert_eq!(seen[1], "claude-sonnet");
1691    }
1692
1693    #[tokio::test]
1694    async fn model_failover_all_models_fail() {
1695        let calls = Arc::new(AtomicUsize::new(0));
1696        let mock = Arc::new(ModelAwareMock {
1697            calls: Arc::clone(&calls),
1698            models_seen: parking_lot::Mutex::new(Vec::new()),
1699            fail_models: vec!["model-a", "model-b", "model-c"],
1700            response: "never",
1701        });
1702
1703        let mut fallbacks = HashMap::new();
1704        fallbacks.insert(
1705            "model-a".to_string(),
1706            vec!["model-b".to_string(), "model-c".to_string()],
1707        );
1708
1709        let provider = ReliableProvider::new(
1710            vec![("p1".into(), Box::new(mock.clone()) as Box<dyn Provider>)],
1711            0,
1712            1,
1713        )
1714        .with_model_fallbacks(fallbacks);
1715
1716        let err = provider
1717            .simple_chat("hello", "model-a", 0.0)
1718            .await
1719            .expect_err("all models should fail");
1720        assert!(err.to_string().contains("All providers/models failed"));
1721
1722        let seen = mock.models_seen.lock();
1723        assert_eq!(seen.len(), 3);
1724    }
1725
1726    #[tokio::test]
1727    async fn no_model_fallbacks_behaves_like_before() {
1728        let calls = Arc::new(AtomicUsize::new(0));
1729        let provider = ReliableProvider::new(
1730            vec![(
1731                "primary".into(),
1732                Box::new(MockProvider {
1733                    calls: Arc::clone(&calls),
1734                    fail_until_attempt: 0,
1735                    response: "ok",
1736                    error: "boom",
1737                }),
1738            )],
1739            2,
1740            1,
1741        );
1742        // No model_fallbacks set — should work exactly as before
1743        let result = provider.simple_chat("hello", "test", 0.0).await.unwrap();
1744        assert_eq!(result, "ok");
1745        assert_eq!(calls.load(Ordering::SeqCst), 1);
1746    }
1747
1748    // ── New tests: auth rotation ──
1749
1750    #[tokio::test]
1751    async fn auth_rotation_cycles_keys() {
1752        let provider = ReliableProvider::new(
1753            vec![(
1754                "p".into(),
1755                Box::new(MockProvider {
1756                    calls: Arc::new(AtomicUsize::new(0)),
1757                    fail_until_attempt: 0,
1758                    response: "ok",
1759                    error: "",
1760                }),
1761            )],
1762            0,
1763            1,
1764        )
1765        .with_api_keys(vec!["key-a".into(), "key-b".into(), "key-c".into()]);
1766
1767        // Rotate 5 times, verify round-robin
1768        let keys: Vec<&str> = (0..5).map(|_| provider.rotate_key().unwrap()).collect();
1769        assert_eq!(keys, vec!["key-a", "key-b", "key-c", "key-a", "key-b"]);
1770    }
1771
1772    #[tokio::test]
1773    async fn auth_rotation_returns_none_when_empty() {
1774        let provider = ReliableProvider::new(vec![], 0, 1);
1775        assert!(provider.rotate_key().is_none());
1776    }
1777
1778    // ── New tests: Retry-After parsing ──
1779
1780    #[test]
1781    fn parse_retry_after_integer() {
1782        let err = anyhow::anyhow!("429 Too Many Requests, Retry-After: 5");
1783        assert_eq!(parse_retry_after_ms(&err), Some(5000));
1784    }
1785
1786    #[test]
1787    fn parse_retry_after_float() {
1788        let err = anyhow::anyhow!("Rate limited. retry_after: 2.5 seconds");
1789        assert_eq!(parse_retry_after_ms(&err), Some(2500));
1790    }
1791
1792    #[test]
1793    fn parse_retry_after_missing() {
1794        let err = anyhow::anyhow!("500 Internal Server Error");
1795        assert_eq!(parse_retry_after_ms(&err), None);
1796    }
1797
1798    #[test]
1799    fn rate_limited_detection() {
1800        assert!(is_rate_limited(&anyhow::anyhow!("429 Too Many Requests")));
1801        assert!(is_rate_limited(&anyhow::anyhow!(
1802            "HTTP 429 rate limit exceeded"
1803        )));
1804        assert!(!is_rate_limited(&anyhow::anyhow!("401 Unauthorized")));
1805        assert!(!is_rate_limited(&anyhow::anyhow!(
1806            "500 Internal Server Error"
1807        )));
1808    }
1809
1810    #[test]
1811    fn non_retryable_rate_limit_detects_plan_restricted_model() {
1812        let err = anyhow::anyhow!(
1813            "{}",
1814            "API error (429 Too Many Requests): {\"code\":1311,\"message\":\"the current account plan does not include glm-5\"}"
1815        );
1816        assert!(
1817            is_non_retryable_rate_limit(&err),
1818            "plan-restricted 429 should skip retries"
1819        );
1820    }
1821
1822    #[test]
1823    fn non_retryable_rate_limit_detects_insufficient_balance() {
1824        let err = anyhow::anyhow!(
1825            "{}",
1826            "API error (429 Too Many Requests): {\"code\":1113,\"message\":\"insufficient balance\"}"
1827        );
1828        assert!(
1829            is_non_retryable_rate_limit(&err),
1830            "insufficient-balance 429 should skip retries"
1831        );
1832    }
1833
1834    #[test]
1835    fn non_retryable_rate_limit_does_not_flag_generic_429() {
1836        let err = anyhow::anyhow!("429 Too Many Requests: rate limit exceeded");
1837        assert!(
1838            !is_non_retryable_rate_limit(&err),
1839            "generic rate-limit 429 should remain retryable"
1840        );
1841    }
1842
1843    #[test]
1844    fn compute_backoff_uses_retry_after() {
1845        let provider = ReliableProvider::new(vec![], 0, 500);
1846        let err = anyhow::anyhow!("429 Retry-After: 3");
1847        assert_eq!(provider.compute_backoff(500, &err), 3_000);
1848    }
1849
1850    #[test]
1851    fn compute_backoff_caps_at_30s() {
1852        let provider = ReliableProvider::new(vec![], 0, 500);
1853        let err = anyhow::anyhow!("429 Retry-After: 120");
1854        assert_eq!(provider.compute_backoff(500, &err), 30_000);
1855    }
1856
1857    #[test]
1858    fn compute_backoff_falls_back_to_base() {
1859        let provider = ReliableProvider::new(vec![], 0, 500);
1860        let err = anyhow::anyhow!("500 Server Error");
1861        assert_eq!(provider.compute_backoff(500, &err), 500);
1862    }
1863
1864    // ── §2.1 API auth error (401/403) tests ──────────────────
1865
1866    #[test]
1867    fn non_retryable_detects_401() {
1868        let err = anyhow::anyhow!("API error (401 Unauthorized): invalid api key");
1869        assert!(
1870            is_non_retryable(&err),
1871            "401 errors must be detected as non-retryable"
1872        );
1873    }
1874
1875    #[test]
1876    fn non_retryable_detects_403() {
1877        let err = anyhow::anyhow!("API error (403 Forbidden): access denied");
1878        assert!(
1879            is_non_retryable(&err),
1880            "403 errors must be detected as non-retryable"
1881        );
1882    }
1883
1884    #[test]
1885    fn non_retryable_detects_404() {
1886        let err = anyhow::anyhow!("API error (404 Not Found): model not found");
1887        assert!(
1888            is_non_retryable(&err),
1889            "404 errors must be detected as non-retryable"
1890        );
1891    }
1892
1893    #[test]
1894    fn non_retryable_does_not_flag_429() {
1895        let err = anyhow::anyhow!("429 Too Many Requests");
1896        assert!(
1897            !is_non_retryable(&err),
1898            "429 must NOT be treated as non-retryable (it is retryable with backoff)"
1899        );
1900    }
1901
1902    #[test]
1903    fn non_retryable_does_not_flag_408() {
1904        let err = anyhow::anyhow!("408 Request Timeout");
1905        assert!(
1906            !is_non_retryable(&err),
1907            "408 must NOT be treated as non-retryable (it is retryable)"
1908        );
1909    }
1910
1911    #[test]
1912    fn non_retryable_does_not_flag_500() {
1913        let err = anyhow::anyhow!("500 Internal Server Error");
1914        assert!(
1915            !is_non_retryable(&err),
1916            "500 must NOT be treated as non-retryable (server errors are retryable)"
1917        );
1918    }
1919
1920    #[test]
1921    fn non_retryable_does_not_flag_502() {
1922        let err = anyhow::anyhow!("502 Bad Gateway");
1923        assert!(
1924            !is_non_retryable(&err),
1925            "502 must NOT be treated as non-retryable"
1926        );
1927    }
1928
1929    // ── §2.2 Rate limit Retry-After edge cases ───────────────
1930
1931    #[test]
1932    fn parse_retry_after_zero() {
1933        let err = anyhow::anyhow!("429 Too Many Requests, Retry-After: 0");
1934        assert_eq!(
1935            parse_retry_after_ms(&err),
1936            Some(0),
1937            "Retry-After: 0 should parse as 0ms"
1938        );
1939    }
1940
1941    #[test]
1942    fn parse_retry_after_with_underscore_separator() {
1943        let err = anyhow::anyhow!("rate limited, retry_after: 10");
1944        assert_eq!(
1945            parse_retry_after_ms(&err),
1946            Some(10_000),
1947            "retry_after with underscore must be parsed"
1948        );
1949    }
1950
1951    #[test]
1952    fn parse_retry_after_space_separator() {
1953        let err = anyhow::anyhow!("Retry-After 7");
1954        assert_eq!(
1955            parse_retry_after_ms(&err),
1956            Some(7000),
1957            "Retry-After with space separator must be parsed"
1958        );
1959    }
1960
1961    #[test]
1962    fn rate_limited_false_for_generic_error() {
1963        let err = anyhow::anyhow!("Connection refused");
1964        assert!(
1965            !is_rate_limited(&err),
1966            "generic errors must not be flagged as rate-limited"
1967        );
1968    }
1969
1970    // ── §2.3 Malformed API response error classification ─────
1971
    // End-to-end: a 401 from the only provider must result in exactly one
    // upstream call even though 5 retries are configured.
    #[tokio::test]
    async fn non_retryable_skips_retries_for_401() {
        let calls = Arc::new(AtomicUsize::new(0));
        let provider = ReliableProvider::new(
            vec![(
                "primary".into(),
                Box::new(MockProvider {
                    calls: Arc::clone(&calls),
                    fail_until_attempt: usize::MAX,
                    response: "never",
                    error: "API error (401 Unauthorized): invalid key",
                }),
            )],
            5,
            1,
        );

        let result = provider.simple_chat("hello", "test", 0.0).await;
        assert!(result.is_err(), "401 should fail without retries");
        assert_eq!(
            calls.load(Ordering::SeqCst),
            1,
            "must not retry on 401 — should be exactly 1 call"
        );
    }
1997
    // End-to-end: a 429 carrying a business error (plan restriction) must be
    // treated like a non-retryable failure — one upstream call despite the
    // configured 5 retries.
    #[tokio::test]
    async fn non_retryable_rate_limit_skips_retries_for_plan_errors() {
        let calls = Arc::new(AtomicUsize::new(0));
        let provider = ReliableProvider::new(
            vec![(
                "primary".into(),
                Box::new(MockProvider {
                    calls: Arc::clone(&calls),
                    fail_until_attempt: usize::MAX,
                    response: "never",
                    error: "API error (429 Too Many Requests): {\"code\":1311,\"message\":\"plan does not include glm-5\"}",
                }),
            )],
            5,
            1,
        );

        let result = provider.simple_chat("hello", "test", 0.0).await;
        assert!(
            result.is_err(),
            "plan-restricted 429 should fail quickly without retrying"
        );
        assert_eq!(
            calls.load(Ordering::SeqCst),
            1,
            "must not retry non-retryable 429 business errors"
        );
    }
2026
2027    // ── Arc<ModelAwareMock> Provider impl for test ──
2028
    #[async_trait]
    impl Provider for Arc<ModelAwareMock> {
        // Delegating impl: lets a test keep an Arc handle to the mock (to
        // inspect `models_seen` afterwards) while ReliableProvider owns a
        // boxed clone of the same shared instance.
        async fn chat_with_system(
            &self,
            system_prompt: Option<&str>,
            message: &str,
            model: &str,
            temperature: f64,
        ) -> anyhow::Result<String> {
            self.as_ref()
                .chat_with_system(system_prompt, message, model, temperature)
                .await
        }
    }
2043
    /// Mock provider that implements `chat()` with native tool support.
    struct NativeToolMock {
        /// Shared attempt counter, incremented on every `chat()` call.
        calls: Arc<AtomicUsize>,
        /// `chat()` fails while the 1-based attempt number is <= this value.
        fail_until_attempt: usize,
        /// Text returned once an attempt succeeds.
        response_text: &'static str,
        /// Tool calls attached to a successful response.
        tool_calls: Vec<super::super::traits::ToolCall>,
        /// Error message used while attempts are still failing.
        error: &'static str,
    }
2052
    #[async_trait]
    impl Provider for NativeToolMock {
        // The plain-text path always succeeds; only `chat()` simulates
        // failures for these tests.
        async fn chat_with_system(
            &self,
            _system_prompt: Option<&str>,
            _message: &str,
            _model: &str,
            _temperature: f64,
        ) -> anyhow::Result<String> {
            Ok(self.response_text.to_string())
        }

        // Advertise native tool support so ReliableProvider exercises the
        // `chat()` code path.
        fn supports_native_tools(&self) -> bool {
            true
        }

        async fn chat(
            &self,
            _request: ChatRequest<'_>,
            _model: &str,
            _temperature: f64,
        ) -> anyhow::Result<ChatResponse> {
            // Attempt numbering starts at 1; fail while attempt <= threshold.
            let attempt = self.calls.fetch_add(1, Ordering::SeqCst) + 1;
            if attempt <= self.fail_until_attempt {
                anyhow::bail!(self.error);
            }
            Ok(ChatResponse {
                text: Some(self.response_text.to_string()),
                tool_calls: self.tool_calls.clone(),
                usage: None,
                reasoning_content: None,
            })
        }
    }
2087
2088    #[tokio::test]
2089    async fn chat_delegates_to_inner_provider() {
2090        let calls = Arc::new(AtomicUsize::new(0));
2091        let tool_call = super::super::traits::ToolCall {
2092            id: "call_1".to_string(),
2093            name: "shell".to_string(),
2094            arguments: r#"{"command":"date"}"#.to_string(),
2095        };
2096        let provider = ReliableProvider::new(
2097            vec![(
2098                "primary".into(),
2099                Box::new(NativeToolMock {
2100                    calls: Arc::clone(&calls),
2101                    fail_until_attempt: 0,
2102                    response_text: "ok",
2103                    tool_calls: vec![tool_call.clone()],
2104                    error: "boom",
2105                }) as Box<dyn Provider>,
2106            )],
2107            2,
2108            1,
2109        );
2110
2111        let messages = vec![ChatMessage::user("what time is it?")];
2112        let request = ChatRequest {
2113            messages: &messages,
2114            tools: None,
2115        };
2116        let result = provider.chat(request, "test-model", 0.0).await.unwrap();
2117
2118        assert_eq!(result.text.as_deref(), Some("ok"));
2119        assert_eq!(result.tool_calls.len(), 1);
2120        assert_eq!(result.tool_calls[0].name, "shell");
2121        assert_eq!(calls.load(Ordering::SeqCst), 1);
2122    }
2123
    // The `chat()` path retries transient failures until the mock recovers.
    #[tokio::test]
    async fn chat_retries_and_recovers() {
        let calls = Arc::new(AtomicUsize::new(0));
        let tool_call = super::super::traits::ToolCall {
            id: "call_1".to_string(),
            name: "shell".to_string(),
            arguments: r#"{"command":"date"}"#.to_string(),
        };
        let provider = ReliableProvider::new(
            vec![(
                "primary".into(),
                Box::new(NativeToolMock {
                    calls: Arc::clone(&calls),
                    // Fail the first two attempts, succeed afterwards.
                    fail_until_attempt: 2,
                    response_text: "recovered",
                    tool_calls: vec![tool_call],
                    error: "temporary failure",
                }) as Box<dyn Provider>,
            )],
            3,
            1,
        );

        let messages = vec![ChatMessage::user("test")];
        let request = ChatRequest {
            messages: &messages,
            tools: None,
        };
        let result = provider.chat(request, "test-model", 0.0).await.unwrap();

        assert_eq!(result.text.as_deref(), Some("recovered"));
        assert!(
            calls.load(Ordering::SeqCst) > 1,
            "should have retried at least once"
        );
    }
2160
2161    #[tokio::test]
2162    async fn chat_preserves_native_tools_support() {
2163        let calls = Arc::new(AtomicUsize::new(0));
2164        let provider = ReliableProvider::new(
2165            vec![(
2166                "primary".into(),
2167                Box::new(NativeToolMock {
2168                    calls: Arc::clone(&calls),
2169                    fail_until_attempt: 0,
2170                    response_text: "ok",
2171                    tool_calls: vec![],
2172                    error: "boom",
2173                }) as Box<dyn Provider>,
2174            )],
2175            2,
2176            1,
2177        );
2178
2179        assert!(
2180            provider.supports_native_tools(),
2181            "ReliableProvider must propagate supports_native_tools from inner provider"
2182        );
2183    }
2184
2185    // ── Gap 2-4: Parity tests for chat() ────────────────────────
2186
2187    /// Gap 2: `chat()` returns an aggregated error when all providers fail,
2188    /// matching behavior of `returns_aggregated_error_when_all_providers_fail`.
2189    #[tokio::test]
2190    async fn chat_returns_aggregated_error_when_all_providers_fail() {
2191        let provider = ReliableProvider::new(
2192            vec![
2193                (
2194                    "p1".into(),
2195                    Box::new(NativeToolMock {
2196                        calls: Arc::new(AtomicUsize::new(0)),
2197                        fail_until_attempt: usize::MAX,
2198                        response_text: "never",
2199                        tool_calls: vec![],
2200                        error: "p1 chat error",
2201                    }) as Box<dyn Provider>,
2202                ),
2203                (
2204                    "p2".into(),
2205                    Box::new(NativeToolMock {
2206                        calls: Arc::new(AtomicUsize::new(0)),
2207                        fail_until_attempt: usize::MAX,
2208                        response_text: "never",
2209                        tool_calls: vec![],
2210                        error: "p2 chat error",
2211                    }) as Box<dyn Provider>,
2212                ),
2213            ],
2214            0,
2215            1,
2216        );
2217
2218        let messages = vec![ChatMessage::user("hello")];
2219        let request = ChatRequest {
2220            messages: &messages,
2221            tools: None,
2222        };
2223        let err = provider
2224            .chat(request, "test", 0.0)
2225            .await
2226            .expect_err("all providers should fail");
2227        let msg = err.to_string();
2228        assert!(msg.contains("All providers/models failed"));
2229        assert!(msg.contains("provider=p1 model=test"));
2230        assert!(msg.contains("provider=p2 model=test"));
2231        assert!(msg.contains("error=p1 chat error"));
2232        assert!(msg.contains("error=p2 chat error"));
2233        assert!(msg.contains("retryable"));
2234    }
2235
    /// Mock that records model names and can fail specific models,
    /// implementing `chat()` for native tool calling parity tests.
    struct NativeModelAwareMock {
        /// Total number of `chat()` invocations across all models.
        calls: Arc<AtomicUsize>,
        /// Every model name passed to `chat()`, in call order.
        models_seen: parking_lot::Mutex<Vec<String>>,
        /// Models whose `chat()` calls fail with a 500-style error.
        fail_models: Vec<&'static str>,
        /// Text returned for models not listed in `fail_models`.
        response_text: &'static str,
    }
2244
    #[async_trait]
    impl Provider for NativeModelAwareMock {
        // The plain-text path always succeeds; model-specific failures are
        // simulated only in `chat()`.
        async fn chat_with_system(
            &self,
            _system_prompt: Option<&str>,
            _message: &str,
            _model: &str,
            _temperature: f64,
        ) -> anyhow::Result<String> {
            Ok(self.response_text.to_string())
        }

        fn supports_native_tools(&self) -> bool {
            true
        }

        async fn chat(
            &self,
            _request: ChatRequest<'_>,
            model: &str,
            _temperature: f64,
        ) -> anyhow::Result<ChatResponse> {
            // Record the call and model BEFORE deciding success/failure so
            // tests can assert attempt order even for failing models.
            self.calls.fetch_add(1, Ordering::SeqCst);
            self.models_seen.lock().push(model.to_string());
            if self.fail_models.contains(&model) {
                anyhow::bail!("500 model {} unavailable", model);
            }
            Ok(ChatResponse {
                text: Some(self.response_text.to_string()),
                tool_calls: vec![],
                usage: None,
                reasoning_content: None,
            })
        }
    }
2280
    #[async_trait]
    impl Provider for Arc<NativeModelAwareMock> {
        // Delegating impl: lets a test keep an Arc handle (to inspect
        // `models_seen`) while ReliableProvider owns a boxed clone of the
        // same shared mock.
        async fn chat_with_system(
            &self,
            system_prompt: Option<&str>,
            message: &str,
            model: &str,
            temperature: f64,
        ) -> anyhow::Result<String> {
            self.as_ref()
                .chat_with_system(system_prompt, message, model, temperature)
                .await
        }

        fn supports_native_tools(&self) -> bool {
            true
        }

        async fn chat(
            &self,
            request: ChatRequest<'_>,
            model: &str,
            temperature: f64,
        ) -> anyhow::Result<ChatResponse> {
            self.as_ref().chat(request, model, temperature).await
        }
    }
2308
2309    /// Gap 3: `chat()` tries fallback models on failure,
2310    /// matching behavior of `model_failover_tries_fallback_model`.
2311    #[tokio::test]
2312    async fn chat_tries_model_failover_on_failure() {
2313        let calls = Arc::new(AtomicUsize::new(0));
2314        let mock = Arc::new(NativeModelAwareMock {
2315            calls: Arc::clone(&calls),
2316            models_seen: parking_lot::Mutex::new(Vec::new()),
2317            fail_models: vec!["claude-opus"],
2318            response_text: "ok from sonnet",
2319        });
2320
2321        let mut fallbacks = HashMap::new();
2322        fallbacks.insert("claude-opus".to_string(), vec!["claude-sonnet".to_string()]);
2323
2324        let provider = ReliableProvider::new(
2325            vec![(
2326                "anthropic".into(),
2327                Box::new(mock.clone()) as Box<dyn Provider>,
2328            )],
2329            0, // no retries — force immediate model failover
2330            1,
2331        )
2332        .with_model_fallbacks(fallbacks);
2333
2334        let messages = vec![ChatMessage::user("hello")];
2335        let request = ChatRequest {
2336            messages: &messages,
2337            tools: None,
2338        };
2339        let result = provider.chat(request, "claude-opus", 0.0).await.unwrap();
2340        assert_eq!(result.text.as_deref(), Some("ok from sonnet"));
2341
2342        let seen = mock.models_seen.lock();
2343        assert_eq!(seen.len(), 2);
2344        assert_eq!(seen[0], "claude-opus");
2345        assert_eq!(seen[1], "claude-sonnet");
2346    }
2347
2348    /// Gap 4: `chat()` skips retries on non-retryable errors (401, 403, etc.),
2349    /// matching behavior of `skips_retries_on_non_retryable_error`.
2350    #[tokio::test]
2351    async fn chat_skips_non_retryable_errors() {
2352        let primary_calls = Arc::new(AtomicUsize::new(0));
2353        let fallback_calls = Arc::new(AtomicUsize::new(0));
2354
2355        let provider = ReliableProvider::new(
2356            vec![
2357                (
2358                    "primary".into(),
2359                    Box::new(NativeToolMock {
2360                        calls: Arc::clone(&primary_calls),
2361                        fail_until_attempt: usize::MAX,
2362                        response_text: "never",
2363                        tool_calls: vec![],
2364                        error: "401 Unauthorized",
2365                    }) as Box<dyn Provider>,
2366                ),
2367                (
2368                    "fallback".into(),
2369                    Box::new(NativeToolMock {
2370                        calls: Arc::clone(&fallback_calls),
2371                        fail_until_attempt: 0,
2372                        response_text: "from fallback",
2373                        tool_calls: vec![],
2374                        error: "fallback err",
2375                    }) as Box<dyn Provider>,
2376                ),
2377            ],
2378            3,
2379            1,
2380        );
2381
2382        let messages = vec![ChatMessage::user("hello")];
2383        let request = ChatRequest {
2384            messages: &messages,
2385            tools: None,
2386        };
2387        let result = provider.chat(request, "test", 0.0).await.unwrap();
2388        assert_eq!(result.text.as_deref(), Some("from fallback"));
2389        // Primary should have been called only once (no retries)
2390        assert_eq!(primary_calls.load(Ordering::SeqCst), 1);
2391        assert_eq!(fallback_calls.load(Ordering::SeqCst), 1);
2392    }
2393
2394    // ── Context window truncation tests ─────────────────────────
2395
2396    #[test]
2397    fn context_window_error_is_not_non_retryable() {
2398        // Context window errors should be recoverable via truncation
2399        assert!(!is_non_retryable(&anyhow::anyhow!(
2400            "exceeds the context window"
2401        )));
2402        assert!(!is_non_retryable(&anyhow::anyhow!(
2403            "maximum context length exceeded"
2404        )));
2405        assert!(!is_non_retryable(&anyhow::anyhow!(
2406            "too many tokens in the request"
2407        )));
2408        assert!(!is_non_retryable(&anyhow::anyhow!("token limit exceeded")));
2409    }
2410
2411    #[test]
2412    fn is_context_window_exceeded_detects_llamacpp() {
2413        assert!(is_context_window_exceeded(&anyhow::anyhow!(
2414            "request (8968 tokens) exceeds the available context size (8448 tokens), try increasing it"
2415        )));
2416    }
2417
2418    #[test]
2419    fn truncate_for_context_drops_oldest_non_system() {
2420        let mut messages = vec![
2421            ChatMessage::system("sys"),
2422            ChatMessage::user("msg1"),
2423            ChatMessage::assistant("resp1"),
2424            ChatMessage::user("msg2"),
2425            ChatMessage::assistant("resp2"),
2426            ChatMessage::user("msg3"),
2427        ];
2428
2429        let dropped = truncate_for_context(&mut messages);
2430
2431        // 5 non-system messages, drop oldest half = 2
2432        assert_eq!(dropped, 2);
2433        // System message preserved
2434        assert_eq!(messages[0].role, "system");
2435        // Remaining messages should be the newer ones
2436        assert_eq!(messages.len(), 4); // system + 3 remaining non-system
2437        // The last message should still be the most recent user message
2438        assert_eq!(messages.last().unwrap().content, "msg3");
2439    }
2440
2441    #[test]
2442    fn truncate_for_context_preserves_system_and_last_message() {
2443        // Only one non-system message: nothing to drop
2444        let mut messages = vec![ChatMessage::system("sys"), ChatMessage::user("only")];
2445        let dropped = truncate_for_context(&mut messages);
2446        assert_eq!(dropped, 0);
2447        assert_eq!(messages.len(), 2);
2448
2449        // No system message, only one user message
2450        let mut messages = vec![ChatMessage::user("only")];
2451        let dropped = truncate_for_context(&mut messages);
2452        assert_eq!(dropped, 0);
2453        assert_eq!(messages.len(), 1);
2454    }
2455
    /// Mock that fails with context error on first N calls, then succeeds.
    /// Tracks the number of messages received on each call.
    struct ContextOverflowMock {
        // Total number of `chat_with_history` invocations, shared with the test.
        calls: Arc<AtomicUsize>,
        // Fail every attempt up to and including this one (attempts are 1-based).
        fail_until_attempt: usize,
        // Message-slice length observed on each call, in call order.
        message_counts: parking_lot::Mutex<Vec<usize>>,
    }
2463
2464    #[async_trait]
2465    impl Provider for ContextOverflowMock {
2466        async fn chat_with_system(
2467            &self,
2468            _system_prompt: Option<&str>,
2469            _message: &str,
2470            _model: &str,
2471            _temperature: f64,
2472        ) -> anyhow::Result<String> {
2473            Ok("ok".to_string())
2474        }
2475
2476        async fn chat_with_history(
2477            &self,
2478            messages: &[ChatMessage],
2479            _model: &str,
2480            _temperature: f64,
2481        ) -> anyhow::Result<String> {
2482            let attempt = self.calls.fetch_add(1, Ordering::SeqCst) + 1;
2483            self.message_counts.lock().push(messages.len());
2484            if attempt <= self.fail_until_attempt {
2485                anyhow::bail!(
2486                    "request (8968 tokens) exceeds the available context size (8448 tokens), try increasing it"
2487                );
2488            }
2489            Ok("recovered after truncation".to_string())
2490        }
2491    }
2492
2493    #[tokio::test]
2494    async fn chat_with_history_truncates_on_context_overflow() {
2495        let calls = Arc::new(AtomicUsize::new(0));
2496        let mock = ContextOverflowMock {
2497            calls: Arc::clone(&calls),
2498            fail_until_attempt: 1, // fail first call, succeed after truncation
2499            message_counts: parking_lot::Mutex::new(Vec::new()),
2500        };
2501
2502        let provider = ReliableProvider::new(
2503            vec![("local".into(), Box::new(mock) as Box<dyn Provider>)],
2504            3,
2505            1,
2506        );
2507
2508        let messages = vec![
2509            ChatMessage::system("system prompt"),
2510            ChatMessage::user("old message 1"),
2511            ChatMessage::assistant("old response 1"),
2512            ChatMessage::user("old message 2"),
2513            ChatMessage::assistant("old response 2"),
2514            ChatMessage::user("current question"),
2515        ];
2516
2517        let result = provider
2518            .chat_with_history(&messages, "local-model", 0.0)
2519            .await
2520            .unwrap();
2521        assert_eq!(result, "recovered after truncation");
2522        // Should have been called twice: once with full messages, once with truncated
2523        assert_eq!(calls.load(Ordering::SeqCst), 2);
2524    }
2525
2526    #[tokio::test]
2527    async fn context_overflow_with_no_history_to_truncate_bails_immediately() {
2528        let calls = Arc::new(AtomicUsize::new(0));
2529        let mock = ContextOverflowMock {
2530            calls: Arc::clone(&calls),
2531            fail_until_attempt: 999, // always fail
2532            message_counts: parking_lot::Mutex::new(Vec::new()),
2533        };
2534
2535        let provider = ReliableProvider::new(
2536            vec![("local".into(), Box::new(mock) as Box<dyn Provider>)],
2537            3,
2538            1,
2539        );
2540
2541        // Only system + one user message — nothing to truncate
2542        let messages = vec![
2543            ChatMessage::system("huge system prompt that exceeds context window"),
2544            ChatMessage::user("hello"),
2545        ];
2546
2547        let result = provider
2548            .chat_with_history(&messages, "local-model", 0.0)
2549            .await;
2550        assert!(result.is_err());
2551        let err_msg = result.unwrap_err().to_string();
2552        assert!(
2553            err_msg.contains("cannot be reduced further"),
2554            "Should bail with actionable message, got: {err_msg}"
2555        );
2556        // Should only be called once — no useless retries
2557        assert_eq!(
2558            calls.load(Ordering::SeqCst),
2559            1,
2560            "Should not retry when truncation is impossible"
2561        );
2562    }
2563
2564    // ── Tool schema error detection tests ───────────────────────────────
2565
2566    #[test]
2567    fn tool_schema_error_detects_groq_validation_failure() {
2568        let msg = r#"Groq API error (400 Bad Request): {"error":{"message":"tool call validation failed: attempted to call tool 'memory_recall' which was not in request"}}"#;
2569        let err = anyhow::anyhow!("{}", msg);
2570        assert!(is_tool_schema_error(&err));
2571    }
2572
2573    #[test]
2574    fn tool_schema_error_detects_not_in_request() {
2575        let err = anyhow::anyhow!("tool 'search' was not in request");
2576        assert!(is_tool_schema_error(&err));
2577    }
2578
2579    #[test]
2580    fn tool_schema_error_detects_not_found_in_tool_list() {
2581        let err = anyhow::anyhow!("function 'foo' not found in tool list");
2582        assert!(is_tool_schema_error(&err));
2583    }
2584
2585    #[test]
2586    fn tool_schema_error_detects_invalid_tool_call() {
2587        let err = anyhow::anyhow!("invalid_tool_call: no matching function");
2588        assert!(is_tool_schema_error(&err));
2589    }
2590
2591    #[test]
2592    fn tool_schema_error_ignores_unrelated_errors() {
2593        let err = anyhow::anyhow!("invalid api key");
2594        assert!(!is_tool_schema_error(&err));
2595
2596        let err = anyhow::anyhow!("model not found");
2597        assert!(!is_tool_schema_error(&err));
2598    }
2599
2600    #[test]
2601    fn non_retryable_returns_false_for_tool_schema_400() {
2602        // A 400 error with tool schema validation text should NOT be non-retryable.
2603        let msg = "400 Bad Request: tool call validation failed: attempted to call tool 'x' which was not in request";
2604        let err = anyhow::anyhow!("{}", msg);
2605        assert!(!is_non_retryable(&err));
2606    }
2607
2608    #[test]
2609    fn non_retryable_returns_true_for_other_400_errors() {
2610        // A regular 400 error (e.g. invalid API key) should still be non-retryable.
2611        let err = anyhow::anyhow!("400 Bad Request: invalid api key provided");
2612        assert!(is_non_retryable(&err));
2613    }
2614
    // Streaming mock whose tool-event capability is configurable; used to
    // verify that ReliableProvider routes tool-bearing stream requests only
    // to providers that report tool-event support.
    struct StreamingToolEventMock {
        // Number of `stream_chat` invocations, shared with the test.
        stream_calls: Arc<AtomicUsize>,
        // Value returned from `supports_streaming_tool_events`.
        supports_tool_events: bool,
    }
2619
2620    impl StreamingToolEventMock {
2621        fn new(supports_tool_events: bool) -> Self {
2622            Self {
2623                stream_calls: Arc::new(AtomicUsize::new(0)),
2624                supports_tool_events,
2625            }
2626        }
2627    }
2628
2629    #[async_trait]
2630    impl Provider for StreamingToolEventMock {
2631        async fn chat_with_system(
2632            &self,
2633            _system_prompt: Option<&str>,
2634            _message: &str,
2635            _model: &str,
2636            _temperature: f64,
2637        ) -> anyhow::Result<String> {
2638            Ok("ok".to_string())
2639        }
2640
2641        fn supports_streaming(&self) -> bool {
2642            true
2643        }
2644
2645        fn supports_streaming_tool_events(&self) -> bool {
2646            self.supports_tool_events
2647        }
2648
2649        fn stream_chat(
2650            &self,
2651            _request: ChatRequest<'_>,
2652            _model: &str,
2653            _temperature: f64,
2654            _options: StreamOptions,
2655        ) -> stream::BoxStream<'static, StreamResult<StreamEvent>> {
2656            self.stream_calls.fetch_add(1, Ordering::SeqCst);
2657            stream::iter(vec![
2658                Ok(StreamEvent::ToolCall(super::super::traits::ToolCall {
2659                    id: "call_1".to_string(),
2660                    name: "shell".to_string(),
2661                    arguments: r#"{"command":"date"}"#.to_string(),
2662                })),
2663                Ok(StreamEvent::Final),
2664            ])
2665            .boxed()
2666        }
2667    }
2668
    // Delegating impl so tests can retain an `Arc` handle (to read
    // `stream_calls` afterwards) while `ReliableProvider` owns a boxed clone.
    #[async_trait]
    impl Provider for Arc<StreamingToolEventMock> {
        /// Forwards to the inner mock's `chat_with_system`.
        async fn chat_with_system(
            &self,
            system_prompt: Option<&str>,
            message: &str,
            model: &str,
            temperature: f64,
        ) -> anyhow::Result<String> {
            self.as_ref()
                .chat_with_system(system_prompt, message, model, temperature)
                .await
        }

        /// Forwards to the inner mock's streaming capability flag.
        fn supports_streaming(&self) -> bool {
            self.as_ref().supports_streaming()
        }

        /// Forwards to the inner mock's tool-event capability flag.
        fn supports_streaming_tool_events(&self) -> bool {
            self.as_ref().supports_streaming_tool_events()
        }

        /// Forwards to the inner mock's `stream_chat`.
        fn stream_chat(
            &self,
            request: ChatRequest<'_>,
            model: &str,
            temperature: f64,
            options: StreamOptions,
        ) -> stream::BoxStream<'static, StreamResult<StreamEvent>> {
            self.as_ref()
                .stream_chat(request, model, temperature, options)
        }
    }
2702
2703    #[tokio::test]
2704    async fn stream_chat_prefers_provider_with_tool_event_support() {
2705        let primary = Arc::new(StreamingToolEventMock::new(false));
2706        let fallback = Arc::new(StreamingToolEventMock::new(true));
2707        let provider = ReliableProvider::new(
2708            vec![
2709                (
2710                    "primary".into(),
2711                    Box::new(Arc::clone(&primary)) as Box<dyn Provider>,
2712                ),
2713                (
2714                    "fallback".into(),
2715                    Box::new(Arc::clone(&fallback)) as Box<dyn Provider>,
2716                ),
2717            ],
2718            0,
2719            1,
2720        );
2721
2722        let messages = vec![ChatMessage::user("hello")];
2723        let tools = vec![ToolSpec {
2724            name: "shell".to_string(),
2725            description: "run shell".to_string(),
2726            parameters: serde_json::json!({
2727                "type": "object",
2728                "properties": {
2729                    "command": { "type": "string" }
2730                }
2731            }),
2732        }];
2733        let mut stream = provider.stream_chat(
2734            ChatRequest {
2735                messages: &messages,
2736                tools: Some(&tools),
2737            },
2738            "model",
2739            0.0,
2740            StreamOptions::new(true),
2741        );
2742
2743        let first = stream.next().await.unwrap().unwrap();
2744        let second = stream.next().await.unwrap().unwrap();
2745        assert!(stream.next().await.is_none());
2746
2747        match first {
2748            StreamEvent::ToolCall(call) => assert_eq!(call.name, "shell"),
2749            other => panic!("expected tool-call event, got {other:?}"),
2750        }
2751        assert!(matches!(second, StreamEvent::Final));
2752        assert_eq!(primary.stream_calls.load(Ordering::SeqCst), 0);
2753        assert_eq!(fallback.stream_calls.load(Ordering::SeqCst), 1);
2754    }
2755
2756    #[tokio::test]
2757    async fn stream_chat_errors_when_no_provider_supports_tool_events() {
2758        let primary = Arc::new(StreamingToolEventMock::new(false));
2759        let provider = ReliableProvider::new(
2760            vec![(
2761                "primary".into(),
2762                Box::new(Arc::clone(&primary)) as Box<dyn Provider>,
2763            )],
2764            0,
2765            1,
2766        );
2767
2768        let messages = vec![ChatMessage::user("hello")];
2769        let tools = vec![ToolSpec {
2770            name: "shell".to_string(),
2771            description: "run shell".to_string(),
2772            parameters: serde_json::json!({"type": "object"}),
2773        }];
2774        let mut stream = provider.stream_chat(
2775            ChatRequest {
2776                messages: &messages,
2777                tools: Some(&tools),
2778            },
2779            "model",
2780            0.0,
2781            StreamOptions::new(true),
2782        );
2783
2784        let first = stream.next().await.unwrap();
2785        let err = first.expect_err("stream should fail without tool-event support");
2786        assert!(
2787            err.to_string()
2788                .contains("No provider supports streaming tool events"),
2789            "unexpected stream error: {err}"
2790        );
2791        assert!(stream.next().await.is_none());
2792        assert_eq!(primary.stream_calls.load(Ordering::SeqCst), 0);
2793    }
2794
2795    // ── stream_chat_with_history failover tests ──────────────────────
2796
    /// Mock provider that supports streaming via stream_chat_with_history.
    struct StreamingHistoryMock {
        // Number of `stream_chat_with_history` invocations, shared with the test.
        stream_calls: Arc<AtomicUsize>,
        // Value returned from `supports_streaming`.
        supports: bool,
    }
2802
2803    #[async_trait]
2804    impl Provider for StreamingHistoryMock {
2805        async fn chat_with_system(
2806            &self,
2807            _system_prompt: Option<&str>,
2808            _message: &str,
2809            _model: &str,
2810            _temperature: f64,
2811        ) -> anyhow::Result<String> {
2812            Ok("ok".to_string())
2813        }
2814
2815        fn supports_streaming(&self) -> bool {
2816            self.supports
2817        }
2818
2819        fn stream_chat_with_history(
2820            &self,
2821            messages: &[ChatMessage],
2822            _model: &str,
2823            _temperature: f64,
2824            _options: StreamOptions,
2825        ) -> stream::BoxStream<'static, StreamResult<StreamChunk>> {
2826            self.stream_calls.fetch_add(1, Ordering::SeqCst);
2827            // Echo the number of messages as the delta to verify history was passed through
2828            let msg_count = messages.len().to_string();
2829            stream::iter(vec![
2830                Ok(StreamChunk::delta(msg_count)),
2831                Ok(StreamChunk::final_chunk()),
2832            ])
2833            .boxed()
2834        }
2835    }
2836
2837    #[tokio::test]
2838    async fn stream_chat_with_history_delegates_to_streaming_provider() {
2839        let calls = Arc::new(AtomicUsize::new(0));
2840        let provider = ReliableProvider::new(
2841            vec![(
2842                "primary".into(),
2843                Box::new(StreamingHistoryMock {
2844                    stream_calls: Arc::clone(&calls),
2845                    supports: true,
2846                }) as Box<dyn Provider>,
2847            )],
2848            0,
2849            1,
2850        );
2851
2852        let messages = vec![
2853            ChatMessage::system("system"),
2854            ChatMessage::user("msg1"),
2855            ChatMessage::assistant("resp1"),
2856            ChatMessage::user("msg2"),
2857        ];
2858        let mut stream =
2859            provider.stream_chat_with_history(&messages, "model", 0.0, StreamOptions::new(true));
2860
2861        let first = stream.next().await.unwrap().unwrap();
2862        assert_eq!(first.delta, "4", "should pass all 4 messages to provider");
2863        let second = stream.next().await.unwrap().unwrap();
2864        assert!(second.is_final);
2865        assert!(stream.next().await.is_none());
2866        assert_eq!(calls.load(Ordering::SeqCst), 1);
2867    }
2868
2869    #[tokio::test]
2870    async fn stream_chat_with_history_skips_non_streaming_providers() {
2871        let non_streaming_calls = Arc::new(AtomicUsize::new(0));
2872        let streaming_calls = Arc::new(AtomicUsize::new(0));
2873
2874        let provider = ReliableProvider::new(
2875            vec![
2876                (
2877                    "non-streaming".into(),
2878                    Box::new(StreamingHistoryMock {
2879                        stream_calls: Arc::clone(&non_streaming_calls),
2880                        supports: false,
2881                    }) as Box<dyn Provider>,
2882                ),
2883                (
2884                    "streaming".into(),
2885                    Box::new(StreamingHistoryMock {
2886                        stream_calls: Arc::clone(&streaming_calls),
2887                        supports: true,
2888                    }) as Box<dyn Provider>,
2889                ),
2890            ],
2891            0,
2892            1,
2893        );
2894
2895        let messages = vec![ChatMessage::user("hello")];
2896        let mut stream =
2897            provider.stream_chat_with_history(&messages, "model", 0.0, StreamOptions::new(true));
2898
2899        let first = stream.next().await.unwrap().unwrap();
2900        assert_eq!(first.delta, "1");
2901        assert_eq!(
2902            non_streaming_calls.load(Ordering::SeqCst),
2903            0,
2904            "non-streaming provider should be skipped"
2905        );
2906        assert_eq!(
2907            streaming_calls.load(Ordering::SeqCst),
2908            1,
2909            "streaming provider should be used"
2910        );
2911    }
2912
2913    #[tokio::test]
2914    async fn stream_chat_with_history_errors_when_no_provider_supports_streaming() {
2915        let provider = ReliableProvider::new(
2916            vec![(
2917                "non-streaming".into(),
2918                Box::new(StreamingHistoryMock {
2919                    stream_calls: Arc::new(AtomicUsize::new(0)),
2920                    supports: false,
2921                }) as Box<dyn Provider>,
2922            )],
2923            0,
2924            1,
2925        );
2926
2927        let messages = vec![ChatMessage::user("hello")];
2928        let mut stream =
2929            provider.stream_chat_with_history(&messages, "model", 0.0, StreamOptions::new(true));
2930
2931        let first = stream.next().await.unwrap();
2932        let err = first.expect_err("should fail when no provider supports streaming");
2933        assert!(
2934            err.to_string().contains("No provider supports streaming"),
2935            "unexpected error: {err}"
2936        );
2937        assert!(stream.next().await.is_none());
2938    }
2939
2940    #[tokio::test]
2941    async fn fallback_records_provider_fallback_info() {
2942        scope_provider_fallback(async {
2943            let provider = ReliableProvider::new(
2944                vec![
2945                    (
2946                        "broken".into(),
2947                        Box::new(MockProvider {
2948                            calls: Arc::new(AtomicUsize::new(0)),
2949                            fail_until_attempt: 99, // always fail
2950                            response: "unused",
2951                            error: "401 Unauthorized",
2952                        }),
2953                    ),
2954                    (
2955                        "working".into(),
2956                        Box::new(MockProvider {
2957                            calls: Arc::new(AtomicUsize::new(0)),
2958                            fail_until_attempt: 0,
2959                            response: "hello from working",
2960                            error: "unused",
2961                        }),
2962                    ),
2963                ],
2964                2,
2965                1,
2966            );
2967
2968            let resp = provider.simple_chat("hi", "test-model", 0.0).await.unwrap();
2969            assert_eq!(resp, "hello from working");
2970
2971            let fb = take_last_provider_fallback();
2972            assert!(fb.is_some(), "fallback info should be recorded");
2973            let fb = fb.unwrap();
2974            assert_eq!(fb.requested_provider, "broken");
2975            assert_eq!(fb.actual_provider, "working");
2976            assert_eq!(fb.actual_model, "test-model");
2977
2978            // Second take should be None.
2979            assert!(take_last_provider_fallback().is_none());
2980        })
2981        .await;
2982    }
2983}