1use super::Provider;
2use super::traits::{
3 ChatMessage, ChatRequest, ChatResponse, StreamChunk, StreamEvent, StreamOptions, StreamResult,
4};
5use async_trait::async_trait;
6use futures_util::{StreamExt, stream};
7use std::cell::RefCell;
8use std::collections::HashMap;
9use std::sync::atomic::{AtomicUsize, Ordering};
10use std::time::Duration;
11
/// Which provider/model was requested versus which one actually produced a response.
///
/// Recorded by the retry loops whenever a call did not succeed as a first-try hit on the
/// primary provider with the requested model (retry, provider failover, model fallback,
/// or history truncation).
#[derive(Clone, Debug)]
pub struct ProviderFallbackInfo {
    /// Name of the first (primary) provider in the chain.
    pub requested_provider: String,
    /// Model name the caller asked for.
    pub requested_model: String,
    /// Name of the provider that actually answered.
    pub actual_provider: String,
    /// Model that actually answered.
    pub actual_model: String,
}
30
// Per-task slot holding the most recent provider fallback record.
// The slot only exists inside `scope_provider_fallback`; outside such a scope,
// `record_provider_fallback` silently does nothing and `take_last_provider_fallback`
// returns `None` (both use `try_with` and ignore the access error).
tokio::task_local! {
    static PROVIDER_FALLBACK: RefCell<Option<ProviderFallbackInfo>>;
}
34
35pub fn take_last_provider_fallback() -> Option<ProviderFallbackInfo> {
38 PROVIDER_FALLBACK
39 .try_with(|cell| cell.borrow_mut().take())
40 .ok()
41 .flatten()
42}
43
44pub async fn scope_provider_fallback<F: std::future::Future>(future: F) -> F::Output {
49 PROVIDER_FALLBACK.scope(RefCell::new(None), future).await
50}
51
52fn record_provider_fallback(
54 requested_provider: &str,
55 requested_model: &str,
56 actual_provider: &str,
57 actual_model: &str,
58) {
59 let _ = PROVIDER_FALLBACK.try_with(|cell| {
60 *cell.borrow_mut() = Some(ProviderFallbackInfo {
61 requested_provider: requested_provider.to_string(),
62 requested_model: requested_model.to_string(),
63 actual_provider: actual_provider.to_string(),
64 actual_model: actual_model.to_string(),
65 });
66 });
67}
68
69pub fn is_non_retryable(err: &anyhow::Error) -> bool {
77 if is_context_window_exceeded(err) {
80 return false;
81 }
82
83 if is_tool_schema_error(err) {
87 return false;
88 }
89
90 if let Some(reqwest_err) = err.downcast_ref::<reqwest::Error>() {
93 if let Some(status) = reqwest_err.status() {
94 let code = status.as_u16();
95 return status.is_client_error() && code != 429 && code != 408;
96 }
97 }
98 let msg = err.to_string();
101 for word in msg.split(|c: char| !c.is_ascii_digit()) {
102 if let Ok(code) = word.parse::<u16>() {
103 if (400..500).contains(&code) {
104 return code != 429 && code != 408;
105 }
106 }
107 }
108
109 let msg_lower = msg.to_lowercase();
112 let auth_failure_hints = [
113 "invalid api key",
114 "incorrect api key",
115 "missing api key",
116 "api key not set",
117 "authentication failed",
118 "auth failed",
119 "unauthorized",
120 "forbidden",
121 "permission denied",
122 "access denied",
123 "invalid token",
124 ];
125
126 if auth_failure_hints
127 .iter()
128 .any(|hint| msg_lower.contains(hint))
129 {
130 return true;
131 }
132
133 msg_lower.contains("model")
134 && (msg_lower.contains("not found")
135 || msg_lower.contains("unknown")
136 || msg_lower.contains("unsupported")
137 || msg_lower.contains("does not exist")
138 || msg_lower.contains("invalid"))
139}
140
141pub fn is_tool_schema_error(err: &anyhow::Error) -> bool {
147 let lower = err.to_string().to_lowercase();
148 let hints = [
149 "tool call validation failed",
150 "was not in request",
151 "not found in tool list",
152 "invalid_tool_call",
153 ];
154 hints.iter().any(|hint| lower.contains(hint))
155}
156
157pub(crate) fn is_context_window_exceeded(err: &anyhow::Error) -> bool {
158 let lower = err.to_string().to_lowercase();
159 let hints = [
160 "exceeds the context window",
161 "exceeds the available context size",
162 "context window of this model",
163 "maximum context length",
164 "context length exceeded",
165 "too many tokens",
166 "token limit exceeded",
167 "prompt is too long",
168 "input is too long",
169 "prompt exceeds max length",
170 ];
171
172 hints.iter().any(|hint| lower.contains(hint))
173}
174
175fn is_rate_limited(err: &anyhow::Error) -> bool {
177 if let Some(reqwest_err) = err.downcast_ref::<reqwest::Error>() {
178 if let Some(status) = reqwest_err.status() {
179 return status.as_u16() == 429;
180 }
181 }
182 let msg = err.to_string();
183 msg.contains("429")
184 && (msg.contains("Too Many") || msg.contains("rate") || msg.contains("limit"))
185}
186
187fn is_non_retryable_rate_limit(err: &anyhow::Error) -> bool {
194 if !is_rate_limited(err) {
195 return false;
196 }
197
198 let msg = err.to_string();
199 let lower = msg.to_lowercase();
200
201 let business_hints = [
202 "plan does not include",
203 "doesn't include",
204 "not include",
205 "insufficient balance",
206 "insufficient_balance",
207 "insufficient quota",
208 "insufficient_quota",
209 "quota exhausted",
210 "out of credits",
211 "no available package",
212 "package not active",
213 "purchase package",
214 "model not available for your plan",
215 ];
216
217 if business_hints.iter().any(|hint| lower.contains(hint)) {
218 return true;
219 }
220
221 for token in lower.split(|c: char| !c.is_ascii_digit()) {
223 if let Ok(code) = token.parse::<u16>() {
224 if matches!(code, 1113 | 1311) {
225 return true;
226 }
227 }
228 }
229
230 false
231}
232
233fn parse_retry_after_ms(err: &anyhow::Error) -> Option<u64> {
236 let msg = err.to_string();
237 let lower = msg.to_lowercase();
238
239 for prefix in &[
241 "retry-after:",
242 "retry_after:",
243 "retry-after ",
244 "retry_after ",
245 ] {
246 if let Some(pos) = lower.find(prefix) {
247 let after = &msg[pos + prefix.len()..];
248 let num_str: String = after
249 .trim()
250 .chars()
251 .take_while(|c| c.is_ascii_digit() || *c == '.')
252 .collect();
253 if let Ok(secs) = num_str.parse::<f64>() {
254 if secs.is_finite() && secs >= 0.0 {
255 let millis = Duration::from_secs_f64(secs).as_millis();
256 if let Ok(value) = u64::try_from(millis) {
257 return Some(value);
258 }
259 }
260 }
261 }
262 }
263 None
264}
265
/// Maps the two classification flags of a failed attempt to a short reason label
/// used in failure records and logs.
fn failure_reason(rate_limited: bool, non_retryable: bool) -> &'static str {
    match (rate_limited, non_retryable) {
        (true, true) => "rate_limited_non_retryable",
        (true, false) => "rate_limited",
        (false, true) => "non_retryable",
        (false, false) => "retryable",
    }
}
277
278fn compact_error_detail(err: &anyhow::Error) -> String {
279 super::sanitize_api_error(&format!("{:#}", err))
281 .split_whitespace()
282 .collect::<Vec<_>>()
283 .join(" ")
284}
285
286fn truncate_for_context(messages: &mut Vec<ChatMessage>) -> usize {
290 let non_system: Vec<usize> = messages
292 .iter()
293 .enumerate()
294 .filter(|(_, m)| m.role != "system")
295 .map(|(i, _)| i)
296 .collect();
297
298 if non_system.len() <= 1 {
300 return 0;
301 }
302
303 let drop_count = non_system.len() / 2;
305 let indices_to_remove: Vec<usize> = non_system[..drop_count].to_vec();
306
307 for &idx in indices_to_remove.iter().rev() {
309 messages.remove(idx);
310 }
311
312 drop_count
313}
314
/// Appends one human-readable line describing a failed attempt to `failures`.
///
/// Format: `provider=<p> model=<m> attempt <n>/<max>: <reason>; error=<detail>`.
fn push_failure(
    failures: &mut Vec<String>,
    provider_name: &str,
    model: &str,
    attempt: u32,
    max_attempts: u32,
    reason: &str,
    error_detail: &str,
) {
    let entry = format!(
        "provider={} model={} attempt {}/{}: {}; error={}",
        provider_name, model, attempt, max_attempts, reason, error_detail
    );
    failures.push(entry);
}
328
329pub struct ReliableProvider {
341 providers: Vec<(String, Box<dyn Provider>)>,
342 max_retries: u32,
343 base_backoff_ms: u64,
344 api_keys: Vec<String>,
346 key_index: AtomicUsize,
347 model_fallbacks: HashMap<String, Vec<String>>,
349}
350
351impl ReliableProvider {
352 pub fn new(
353 providers: Vec<(String, Box<dyn Provider>)>,
354 max_retries: u32,
355 base_backoff_ms: u64,
356 ) -> Self {
357 Self {
358 providers,
359 max_retries,
360 base_backoff_ms: base_backoff_ms.max(50),
361 api_keys: Vec::new(),
362 key_index: AtomicUsize::new(0),
363 model_fallbacks: HashMap::new(),
364 }
365 }
366
367 pub fn with_api_keys(mut self, keys: Vec<String>) -> Self {
369 self.api_keys = keys;
370 self
371 }
372
373 pub fn with_model_fallbacks(mut self, fallbacks: HashMap<String, Vec<String>>) -> Self {
375 self.model_fallbacks = fallbacks;
376 self
377 }
378
379 fn model_chain<'a>(&'a self, model: &'a str) -> Vec<&'a str> {
381 let mut chain = vec![model];
382 if let Some(fallbacks) = self.model_fallbacks.get(model) {
383 chain.extend(fallbacks.iter().map(|s| s.as_str()));
384 }
385 chain
386 }
387
388 fn rotate_key(&self) -> Option<&str> {
390 if self.api_keys.is_empty() {
391 return None;
392 }
393 let idx = self.key_index.fetch_add(1, Ordering::Relaxed) % self.api_keys.len();
394 Some(&self.api_keys[idx])
395 }
396
397 fn compute_backoff(&self, base: u64, err: &anyhow::Error) -> u64 {
399 if let Some(retry_after) = parse_retry_after_ms(err) {
400 retry_after.min(30_000).max(base)
402 } else {
403 base
404 }
405 }
406}
407
408#[async_trait]
409impl Provider for ReliableProvider {
410 async fn warmup(&self) -> anyhow::Result<()> {
411 for (name, provider) in &self.providers {
412 tracing::info!(provider = name, "Warming up provider connection pool");
413 if provider.warmup().await.is_err() {
414 tracing::warn!(provider = name, "Warmup failed (non-fatal)");
415 }
416 }
417 Ok(())
418 }
419
420 async fn chat_with_system(
421 &self,
422 system_prompt: Option<&str>,
423 message: &str,
424 model: &str,
425 temperature: f64,
426 ) -> anyhow::Result<String> {
427 let models = self.model_chain(model);
428 let mut failures = Vec::new();
429
430 for current_model in &models {
435 for (provider_name, provider) in &self.providers {
436 let mut backoff_ms = self.base_backoff_ms;
437
438 for attempt in 0..=self.max_retries {
439 match provider
440 .chat_with_system(system_prompt, message, current_model, temperature)
441 .await
442 {
443 Ok(resp) => {
444 if attempt > 0
445 || *current_model != model
446 || self.providers.first().map(|(n, _)| n.as_str())
447 != Some(provider_name)
448 {
449 tracing::info!(
450 provider = provider_name,
451 model = *current_model,
452 attempt,
453 original_model = model,
454 "Provider recovered (failover/retry)"
455 );
456 let primary = self
457 .providers
458 .first()
459 .map(|(n, _)| n.as_str())
460 .unwrap_or("");
461 record_provider_fallback(
462 primary,
463 model,
464 provider_name,
465 current_model,
466 );
467 }
468 return Ok(resp);
469 }
470 Err(e) => {
471 if is_context_window_exceeded(&e) {
474 let error_detail = compact_error_detail(&e);
475 push_failure(
476 &mut failures,
477 provider_name,
478 current_model,
479 attempt + 1,
480 self.max_retries + 1,
481 "non_retryable",
482 &error_detail,
483 );
484 anyhow::bail!(
485 "Request exceeds model context window. Attempts:\n{}",
486 failures.join("\n")
487 );
488 }
489
490 let non_retryable_rate_limit = is_non_retryable_rate_limit(&e);
491 let non_retryable = is_non_retryable(&e) || non_retryable_rate_limit;
492 let rate_limited = is_rate_limited(&e);
493 let failure_reason = failure_reason(rate_limited, non_retryable);
494 let error_detail = compact_error_detail(&e);
495
496 push_failure(
497 &mut failures,
498 provider_name,
499 current_model,
500 attempt + 1,
501 self.max_retries + 1,
502 failure_reason,
503 &error_detail,
504 );
505
506 if rate_limited && !non_retryable_rate_limit {
509 if let Some(new_key) = self.rotate_key() {
510 tracing::warn!(
511 provider = provider_name,
512 error = %error_detail,
513 "Rate limited; key rotation selected key ending ...{} \
514 but cannot apply (Provider trait has no set_api_key). \
515 Retrying with original key.",
516 &new_key[new_key.len().saturating_sub(4)..]
517 );
518 }
519 }
520
521 if non_retryable {
522 tracing::warn!(
523 provider = provider_name,
524 model = *current_model,
525 error = %error_detail,
526 "Non-retryable error, moving on"
527 );
528 break;
529 }
530
531 if attempt < self.max_retries {
532 let wait = self.compute_backoff(backoff_ms, &e);
533 tracing::warn!(
534 provider = provider_name,
535 model = *current_model,
536 attempt = attempt + 1,
537 backoff_ms = wait,
538 reason = failure_reason,
539 error = %error_detail,
540 "Provider call failed, retrying"
541 );
542 tokio::time::sleep(Duration::from_millis(wait)).await;
543 backoff_ms = (backoff_ms.saturating_mul(2)).min(10_000);
544 }
545 }
546 }
547 }
548
549 tracing::warn!(
550 provider = provider_name,
551 model = *current_model,
552 "Exhausted retries, trying next provider/model"
553 );
554 }
555
556 if *current_model != model {
557 tracing::warn!(
558 original_model = model,
559 fallback_model = *current_model,
560 "Model fallback exhausted all providers, trying next fallback model"
561 );
562 }
563 }
564
565 anyhow::bail!(
566 "All providers/models failed. Attempts:\n{}",
567 failures.join("\n")
568 )
569 }
570
571 async fn chat_with_history(
572 &self,
573 messages: &[ChatMessage],
574 model: &str,
575 temperature: f64,
576 ) -> anyhow::Result<String> {
577 let models = self.model_chain(model);
578 let mut failures = Vec::new();
579 let mut effective_messages = messages.to_vec();
580 let mut context_truncated = false;
581
582 for current_model in &models {
583 for (provider_name, provider) in &self.providers {
584 let mut backoff_ms = self.base_backoff_ms;
585
586 for attempt in 0..=self.max_retries {
587 match provider
588 .chat_with_history(&effective_messages, current_model, temperature)
589 .await
590 {
591 Ok(resp) => {
592 if attempt > 0
593 || *current_model != model
594 || context_truncated
595 || self.providers.first().map(|(n, _)| n.as_str())
596 != Some(provider_name)
597 {
598 tracing::info!(
599 provider = provider_name,
600 model = *current_model,
601 attempt,
602 original_model = model,
603 context_truncated,
604 "Provider recovered (failover/retry)"
605 );
606 let primary = self
607 .providers
608 .first()
609 .map(|(n, _)| n.as_str())
610 .unwrap_or("");
611 record_provider_fallback(
612 primary,
613 model,
614 provider_name,
615 current_model,
616 );
617 }
618 return Ok(resp);
619 }
620 Err(e) => {
621 if is_context_window_exceeded(&e) && !context_truncated {
623 let dropped = truncate_for_context(&mut effective_messages);
624 if dropped > 0 {
625 context_truncated = true;
626 tracing::warn!(
627 provider = provider_name,
628 model = *current_model,
629 dropped,
630 remaining = effective_messages.len(),
631 "Context window exceeded; truncated history and retrying"
632 );
633 continue; }
635 let error_detail = compact_error_detail(&e);
639 push_failure(
640 &mut failures,
641 provider_name,
642 current_model,
643 attempt + 1,
644 self.max_retries + 1,
645 "non_retryable",
646 &error_detail,
647 );
648 anyhow::bail!(
649 "Request exceeds model context window and cannot be reduced further. \
650 Try using a model with a larger context window, reducing the number \
651 of tools/skills, or enabling compact_context in config. Attempts:\n{}",
652 failures.join("\n")
653 );
654 }
655
656 let non_retryable_rate_limit = is_non_retryable_rate_limit(&e);
657 let non_retryable = is_non_retryable(&e) || non_retryable_rate_limit;
658 let rate_limited = is_rate_limited(&e);
659 let failure_reason = failure_reason(rate_limited, non_retryable);
660 let error_detail = compact_error_detail(&e);
661
662 push_failure(
663 &mut failures,
664 provider_name,
665 current_model,
666 attempt + 1,
667 self.max_retries + 1,
668 failure_reason,
669 &error_detail,
670 );
671
672 if rate_limited && !non_retryable_rate_limit {
673 if let Some(new_key) = self.rotate_key() {
674 tracing::warn!(
675 provider = provider_name,
676 error = %error_detail,
677 "Rate limited; key rotation selected key ending ...{} \
678 but cannot apply (Provider trait has no set_api_key). \
679 Retrying with original key.",
680 &new_key[new_key.len().saturating_sub(4)..]
681 );
682 }
683 }
684
685 if non_retryable {
686 tracing::warn!(
687 provider = provider_name,
688 model = *current_model,
689 error = %error_detail,
690 "Non-retryable error, moving on"
691 );
692 break;
693 }
694
695 if attempt < self.max_retries {
696 let wait = self.compute_backoff(backoff_ms, &e);
697 tracing::warn!(
698 provider = provider_name,
699 model = *current_model,
700 attempt = attempt + 1,
701 backoff_ms = wait,
702 reason = failure_reason,
703 error = %error_detail,
704 "Provider call failed, retrying"
705 );
706 tokio::time::sleep(Duration::from_millis(wait)).await;
707 backoff_ms = (backoff_ms.saturating_mul(2)).min(10_000);
708 }
709 }
710 }
711 }
712
713 tracing::warn!(
714 provider = provider_name,
715 model = *current_model,
716 "Exhausted retries, trying next provider/model"
717 );
718 }
719 }
720
721 anyhow::bail!(
722 "All providers/models failed. Attempts:\n{}",
723 failures.join("\n")
724 )
725 }
726
727 fn supports_native_tools(&self) -> bool {
728 self.providers
729 .first()
730 .map(|(_, p)| p.supports_native_tools())
731 .unwrap_or(false)
732 }
733
734 fn supports_vision(&self) -> bool {
735 self.providers
736 .iter()
737 .any(|(_, provider)| provider.supports_vision())
738 }
739
740 async fn chat_with_tools(
741 &self,
742 messages: &[ChatMessage],
743 tools: &[serde_json::Value],
744 model: &str,
745 temperature: f64,
746 ) -> anyhow::Result<ChatResponse> {
747 let models = self.model_chain(model);
748 let mut failures = Vec::new();
749 let mut effective_messages = messages.to_vec();
750 let mut context_truncated = false;
751
752 for current_model in &models {
753 for (provider_name, provider) in &self.providers {
754 let mut backoff_ms = self.base_backoff_ms;
755
756 for attempt in 0..=self.max_retries {
757 match provider
758 .chat_with_tools(&effective_messages, tools, current_model, temperature)
759 .await
760 {
761 Ok(resp) => {
762 if attempt > 0
763 || *current_model != model
764 || context_truncated
765 || self.providers.first().map(|(n, _)| n.as_str())
766 != Some(provider_name)
767 {
768 tracing::info!(
769 provider = provider_name,
770 model = *current_model,
771 attempt,
772 original_model = model,
773 context_truncated,
774 "Provider recovered (failover/retry)"
775 );
776 let primary = self
777 .providers
778 .first()
779 .map(|(n, _)| n.as_str())
780 .unwrap_or("");
781 record_provider_fallback(
782 primary,
783 model,
784 provider_name,
785 current_model,
786 );
787 }
788 return Ok(resp);
789 }
790 Err(e) => {
791 if is_context_window_exceeded(&e) && !context_truncated {
793 let dropped = truncate_for_context(&mut effective_messages);
794 if dropped > 0 {
795 context_truncated = true;
796 tracing::warn!(
797 provider = provider_name,
798 model = *current_model,
799 dropped,
800 remaining = effective_messages.len(),
801 "Context window exceeded; truncated history and retrying"
802 );
803 continue; }
805 let error_detail = compact_error_detail(&e);
809 push_failure(
810 &mut failures,
811 provider_name,
812 current_model,
813 attempt + 1,
814 self.max_retries + 1,
815 "non_retryable",
816 &error_detail,
817 );
818 anyhow::bail!(
819 "Request exceeds model context window and cannot be reduced further. \
820 Try using a model with a larger context window, reducing the number \
821 of tools/skills, or enabling compact_context in config. Attempts:\n{}",
822 failures.join("\n")
823 );
824 }
825
826 let non_retryable_rate_limit = is_non_retryable_rate_limit(&e);
827 let non_retryable = is_non_retryable(&e) || non_retryable_rate_limit;
828 let rate_limited = is_rate_limited(&e);
829 let failure_reason = failure_reason(rate_limited, non_retryable);
830 let error_detail = compact_error_detail(&e);
831
832 push_failure(
833 &mut failures,
834 provider_name,
835 current_model,
836 attempt + 1,
837 self.max_retries + 1,
838 failure_reason,
839 &error_detail,
840 );
841
842 if rate_limited && !non_retryable_rate_limit {
843 if let Some(new_key) = self.rotate_key() {
844 tracing::warn!(
845 provider = provider_name,
846 error = %error_detail,
847 "Rate limited; key rotation selected key ending ...{} \
848 but cannot apply (Provider trait has no set_api_key). \
849 Retrying with original key.",
850 &new_key[new_key.len().saturating_sub(4)..]
851 );
852 }
853 }
854
855 if non_retryable {
856 tracing::warn!(
857 provider = provider_name,
858 model = *current_model,
859 error = %error_detail,
860 "Non-retryable error, moving on"
861 );
862 break;
863 }
864
865 if attempt < self.max_retries {
866 let wait = self.compute_backoff(backoff_ms, &e);
867 tracing::warn!(
868 provider = provider_name,
869 model = *current_model,
870 attempt = attempt + 1,
871 backoff_ms = wait,
872 reason = failure_reason,
873 error = %error_detail,
874 "Provider call failed, retrying"
875 );
876 tokio::time::sleep(Duration::from_millis(wait)).await;
877 backoff_ms = (backoff_ms.saturating_mul(2)).min(10_000);
878 }
879 }
880 }
881 }
882
883 tracing::warn!(
884 provider = provider_name,
885 model = *current_model,
886 "Exhausted retries, trying next provider/model"
887 );
888 }
889 }
890
891 anyhow::bail!(
892 "All providers/models failed. Attempts:\n{}",
893 failures.join("\n")
894 )
895 }
896
897 async fn chat(
898 &self,
899 request: ChatRequest<'_>,
900 model: &str,
901 temperature: f64,
902 ) -> anyhow::Result<ChatResponse> {
903 let models = self.model_chain(model);
904 let mut failures = Vec::new();
905 let mut effective_messages = request.messages.to_vec();
906 let mut context_truncated = false;
907
908 for current_model in &models {
909 for (provider_name, provider) in &self.providers {
910 let mut backoff_ms = self.base_backoff_ms;
911
912 for attempt in 0..=self.max_retries {
913 let req = ChatRequest {
914 messages: &effective_messages,
915 tools: request.tools,
916 };
917 match provider.chat(req, current_model, temperature).await {
918 Ok(resp) => {
919 if attempt > 0
920 || *current_model != model
921 || context_truncated
922 || self.providers.first().map(|(n, _)| n.as_str())
923 != Some(provider_name)
924 {
925 tracing::info!(
926 provider = provider_name,
927 model = *current_model,
928 attempt,
929 original_model = model,
930 context_truncated,
931 "Provider recovered (failover/retry)"
932 );
933 let primary = self
934 .providers
935 .first()
936 .map(|(n, _)| n.as_str())
937 .unwrap_or("");
938 record_provider_fallback(
939 primary,
940 model,
941 provider_name,
942 current_model,
943 );
944 }
945 return Ok(resp);
946 }
947 Err(e) => {
948 if is_context_window_exceeded(&e) && !context_truncated {
950 let dropped = truncate_for_context(&mut effective_messages);
951 if dropped > 0 {
952 context_truncated = true;
953 tracing::warn!(
954 provider = provider_name,
955 model = *current_model,
956 dropped,
957 remaining = effective_messages.len(),
958 "Context window exceeded; truncated history and retrying"
959 );
960 continue; }
962 let error_detail = compact_error_detail(&e);
966 push_failure(
967 &mut failures,
968 provider_name,
969 current_model,
970 attempt + 1,
971 self.max_retries + 1,
972 "non_retryable",
973 &error_detail,
974 );
975 anyhow::bail!(
976 "Request exceeds model context window and cannot be reduced further. \
977 Try using a model with a larger context window, reducing the number \
978 of tools/skills, or enabling compact_context in config. Attempts:\n{}",
979 failures.join("\n")
980 );
981 }
982
983 let non_retryable_rate_limit = is_non_retryable_rate_limit(&e);
984 let non_retryable = is_non_retryable(&e) || non_retryable_rate_limit;
985 let rate_limited = is_rate_limited(&e);
986 let failure_reason = failure_reason(rate_limited, non_retryable);
987 let error_detail = compact_error_detail(&e);
988
989 push_failure(
990 &mut failures,
991 provider_name,
992 current_model,
993 attempt + 1,
994 self.max_retries + 1,
995 failure_reason,
996 &error_detail,
997 );
998
999 if rate_limited && !non_retryable_rate_limit {
1000 if let Some(new_key) = self.rotate_key() {
1001 tracing::warn!(
1002 provider = provider_name,
1003 error = %error_detail,
1004 "Rate limited; key rotation selected key ending ...{} \
1005 but cannot apply (Provider trait has no set_api_key). \
1006 Retrying with original key.",
1007 &new_key[new_key.len().saturating_sub(4)..]
1008 );
1009 }
1010 }
1011
1012 if non_retryable {
1013 tracing::warn!(
1014 provider = provider_name,
1015 model = *current_model,
1016 error = %error_detail,
1017 "Non-retryable error, moving on"
1018 );
1019 break;
1020 }
1021
1022 if attempt < self.max_retries {
1023 let wait = self.compute_backoff(backoff_ms, &e);
1024 tracing::warn!(
1025 provider = provider_name,
1026 model = *current_model,
1027 attempt = attempt + 1,
1028 backoff_ms = wait,
1029 reason = failure_reason,
1030 error = %error_detail,
1031 "Provider call failed, retrying"
1032 );
1033 tokio::time::sleep(Duration::from_millis(wait)).await;
1034 backoff_ms = (backoff_ms.saturating_mul(2)).min(10_000);
1035 }
1036 }
1037 }
1038 }
1039
1040 tracing::warn!(
1041 provider = provider_name,
1042 model = *current_model,
1043 "Exhausted retries, trying next provider/model"
1044 );
1045 }
1046
1047 if *current_model != model {
1048 tracing::warn!(
1049 original_model = model,
1050 fallback_model = *current_model,
1051 "Model fallback exhausted all providers, trying next fallback model"
1052 );
1053 }
1054 }
1055
1056 anyhow::bail!(
1057 "All providers/models failed. Attempts:\n{}",
1058 failures.join("\n")
1059 )
1060 }
1061
1062 fn supports_streaming(&self) -> bool {
1063 self.providers.iter().any(|(_, p)| p.supports_streaming())
1064 }
1065
1066 fn supports_streaming_tool_events(&self) -> bool {
1067 self.providers
1068 .iter()
1069 .any(|(_, p)| p.supports_streaming_tool_events())
1070 }
1071
1072 fn stream_chat(
1073 &self,
1074 request: ChatRequest<'_>,
1075 model: &str,
1076 temperature: f64,
1077 options: StreamOptions,
1078 ) -> stream::BoxStream<'static, StreamResult<StreamEvent>> {
1079 let needs_tool_events = request.tools.is_some_and(|tools| !tools.is_empty());
1080
1081 for (provider_name, provider) in &self.providers {
1082 if !provider.supports_streaming() || !options.enabled {
1083 continue;
1084 }
1085
1086 if needs_tool_events && !provider.supports_streaming_tool_events() {
1087 continue;
1088 }
1089
1090 let provider_clone = provider_name.clone();
1091
1092 let current_model = self
1093 .model_chain(model)
1094 .first()
1095 .copied()
1096 .unwrap_or(model)
1097 .to_string();
1098
1099 let req = ChatRequest {
1100 messages: request.messages,
1101 tools: request.tools,
1102 };
1103 let stream = provider.stream_chat(req, ¤t_model, temperature, options);
1104 let (tx, rx) = tokio::sync::mpsc::channel::<StreamResult<StreamEvent>>(100);
1105
1106 tokio::spawn(async move {
1107 let mut stream = stream;
1108 while let Some(event) = stream.next().await {
1109 if let Err(ref e) = event {
1110 tracing::warn!(
1111 provider = provider_clone,
1112 model = current_model,
1113 "Streaming error: {e}"
1114 );
1115 }
1116 if tx.send(event).await.is_err() {
1117 break;
1118 }
1119 }
1120 });
1121
1122 return stream::unfold(rx, |mut rx| async move {
1123 rx.recv().await.map(|event| (event, rx))
1124 })
1125 .boxed();
1126 }
1127
1128 let message = if needs_tool_events {
1129 "No provider supports streaming tool events".to_string()
1130 } else {
1131 "No provider supports streaming".to_string()
1132 };
1133 stream::once(async move { Err(super::traits::StreamError::Provider(message)) }).boxed()
1134 }
1135
1136 fn stream_chat_with_system(
1137 &self,
1138 system_prompt: Option<&str>,
1139 message: &str,
1140 model: &str,
1141 temperature: f64,
1142 options: StreamOptions,
1143 ) -> stream::BoxStream<'static, StreamResult<StreamChunk>> {
1144 for (provider_name, provider) in &self.providers {
1147 if !provider.supports_streaming() || !options.enabled {
1148 continue;
1149 }
1150
1151 let provider_clone = provider_name.clone();
1153
1154 let current_model = match self.model_chain(model).first() {
1156 Some(m) => (*m).to_string(),
1157 None => model.to_string(),
1158 };
1159
1160 let stream = provider.stream_chat_with_system(
1163 system_prompt,
1164 message,
1165 ¤t_model,
1166 temperature,
1167 options,
1168 );
1169
1170 let (tx, rx) = tokio::sync::mpsc::channel::<StreamResult<StreamChunk>>(100);
1172
1173 tokio::spawn(async move {
1174 let mut stream = stream;
1175 while let Some(chunk) = stream.next().await {
1176 if let Err(ref e) = chunk {
1177 tracing::warn!(
1178 provider = provider_clone,
1179 model = current_model,
1180 "Streaming error: {e}"
1181 );
1182 }
1183 if tx.send(chunk).await.is_err() {
1184 break; }
1186 }
1187 });
1188
1189 return stream::unfold(rx, |mut rx| async move {
1191 rx.recv().await.map(|chunk| (chunk, rx))
1192 })
1193 .boxed();
1194 }
1195
1196 stream::once(async move {
1198 Err(super::traits::StreamError::Provider(
1199 "No provider supports streaming".to_string(),
1200 ))
1201 })
1202 .boxed()
1203 }
1204
1205 fn stream_chat_with_history(
1206 &self,
1207 messages: &[ChatMessage],
1208 model: &str,
1209 temperature: f64,
1210 options: StreamOptions,
1211 ) -> stream::BoxStream<'static, StreamResult<StreamChunk>> {
1212 for (provider_name, provider) in &self.providers {
1216 if !provider.supports_streaming() || !options.enabled {
1217 continue;
1218 }
1219
1220 let provider_clone = provider_name.clone();
1221
1222 let current_model = match self.model_chain(model).first() {
1223 Some(m) => (*m).to_string(),
1224 None => model.to_string(),
1225 };
1226
1227 let stream =
1228 provider.stream_chat_with_history(messages, ¤t_model, temperature, options);
1229
1230 let (tx, rx) = tokio::sync::mpsc::channel::<StreamResult<StreamChunk>>(100);
1231
1232 tokio::spawn(async move {
1233 let mut stream = stream;
1234 while let Some(chunk) = stream.next().await {
1235 if let Err(ref e) = chunk {
1236 tracing::warn!(
1237 provider = provider_clone,
1238 model = current_model,
1239 "Streaming error: {e}"
1240 );
1241 }
1242 if tx.send(chunk).await.is_err() {
1243 break; }
1245 }
1246 });
1247
1248 return stream::unfold(rx, |mut rx| async move {
1249 rx.recv().await.map(|chunk| (chunk, rx))
1250 })
1251 .boxed();
1252 }
1253
1254 stream::once(async move {
1256 Err(super::traits::StreamError::Provider(
1257 "No provider supports streaming".to_string(),
1258 ))
1259 })
1260 .boxed()
1261 }
1262}
1263
1264#[cfg(test)]
1265mod tests {
1266 use super::*;
1267 use crate::tools::ToolSpec;
1268 use futures_util::StreamExt;
1269 use std::sync::Arc;
1270
1271 struct MockProvider {
1272 calls: Arc<AtomicUsize>,
1273 fail_until_attempt: usize,
1274 response: &'static str,
1275 error: &'static str,
1276 }
1277
1278 #[async_trait]
1279 impl Provider for MockProvider {
1280 async fn chat_with_system(
1281 &self,
1282 _system_prompt: Option<&str>,
1283 _message: &str,
1284 _model: &str,
1285 _temperature: f64,
1286 ) -> anyhow::Result<String> {
1287 let attempt = self.calls.fetch_add(1, Ordering::SeqCst) + 1;
1288 if attempt <= self.fail_until_attempt {
1289 anyhow::bail!(self.error);
1290 }
1291 Ok(self.response.to_string())
1292 }
1293
1294 async fn chat_with_history(
1295 &self,
1296 _messages: &[ChatMessage],
1297 _model: &str,
1298 _temperature: f64,
1299 ) -> anyhow::Result<String> {
1300 let attempt = self.calls.fetch_add(1, Ordering::SeqCst) + 1;
1301 if attempt <= self.fail_until_attempt {
1302 anyhow::bail!(self.error);
1303 }
1304 Ok(self.response.to_string())
1305 }
1306 }
1307
1308 struct ModelAwareMock {
1310 calls: Arc<AtomicUsize>,
1311 models_seen: parking_lot::Mutex<Vec<String>>,
1312 fail_models: Vec<&'static str>,
1313 response: &'static str,
1314 }
1315
1316 #[async_trait]
1317 impl Provider for ModelAwareMock {
1318 async fn chat_with_system(
1319 &self,
1320 _system_prompt: Option<&str>,
1321 _message: &str,
1322 model: &str,
1323 _temperature: f64,
1324 ) -> anyhow::Result<String> {
1325 self.calls.fetch_add(1, Ordering::SeqCst);
1326 self.models_seen.lock().push(model.to_string());
1327 if self.fail_models.contains(&model) {
1328 anyhow::bail!("500 model {} unavailable", model);
1329 }
1330 Ok(self.response.to_string())
1331 }
1332 }
1333
1334 #[tokio::test]
1337 async fn succeeds_without_retry() {
1338 let calls = Arc::new(AtomicUsize::new(0));
1339 let provider = ReliableProvider::new(
1340 vec![(
1341 "primary".into(),
1342 Box::new(MockProvider {
1343 calls: Arc::clone(&calls),
1344 fail_until_attempt: 0,
1345 response: "ok",
1346 error: "boom",
1347 }),
1348 )],
1349 2,
1350 1,
1351 );
1352
1353 let result = provider.simple_chat("hello", "test", 0.0).await.unwrap();
1354 assert_eq!(result, "ok");
1355 assert_eq!(calls.load(Ordering::SeqCst), 1);
1356 }
1357
1358 #[tokio::test]
1359 async fn retries_then_recovers() {
1360 let calls = Arc::new(AtomicUsize::new(0));
1361 let provider = ReliableProvider::new(
1362 vec![(
1363 "primary".into(),
1364 Box::new(MockProvider {
1365 calls: Arc::clone(&calls),
1366 fail_until_attempt: 1,
1367 response: "recovered",
1368 error: "temporary",
1369 }),
1370 )],
1371 2,
1372 1,
1373 );
1374
1375 let result = provider.simple_chat("hello", "test", 0.0).await.unwrap();
1376 assert_eq!(result, "recovered");
1377 assert_eq!(calls.load(Ordering::SeqCst), 2);
1378 }
1379
1380 #[tokio::test]
1381 async fn falls_back_after_retries_exhausted() {
1382 let primary_calls = Arc::new(AtomicUsize::new(0));
1383 let fallback_calls = Arc::new(AtomicUsize::new(0));
1384
1385 let provider = ReliableProvider::new(
1386 vec![
1387 (
1388 "primary".into(),
1389 Box::new(MockProvider {
1390 calls: Arc::clone(&primary_calls),
1391 fail_until_attempt: usize::MAX,
1392 response: "never",
1393 error: "primary down",
1394 }),
1395 ),
1396 (
1397 "fallback".into(),
1398 Box::new(MockProvider {
1399 calls: Arc::clone(&fallback_calls),
1400 fail_until_attempt: 0,
1401 response: "from fallback",
1402 error: "fallback down",
1403 }),
1404 ),
1405 ],
1406 1,
1407 1,
1408 );
1409
1410 let result = provider.simple_chat("hello", "test", 0.0).await.unwrap();
1411 assert_eq!(result, "from fallback");
1412 assert_eq!(primary_calls.load(Ordering::SeqCst), 2);
1413 assert_eq!(fallback_calls.load(Ordering::SeqCst), 1);
1414 }
1415
1416 #[tokio::test]
1417 async fn returns_aggregated_error_when_all_providers_fail() {
1418 let provider = ReliableProvider::new(
1419 vec![
1420 (
1421 "p1".into(),
1422 Box::new(MockProvider {
1423 calls: Arc::new(AtomicUsize::new(0)),
1424 fail_until_attempt: usize::MAX,
1425 response: "never",
1426 error: "p1 error",
1427 }),
1428 ),
1429 (
1430 "p2".into(),
1431 Box::new(MockProvider {
1432 calls: Arc::new(AtomicUsize::new(0)),
1433 fail_until_attempt: usize::MAX,
1434 response: "never",
1435 error: "p2 error",
1436 }),
1437 ),
1438 ],
1439 0,
1440 1,
1441 );
1442
1443 let err = provider
1444 .simple_chat("hello", "test", 0.0)
1445 .await
1446 .expect_err("all providers should fail");
1447 let msg = err.to_string();
1448 assert!(msg.contains("All providers/models failed"));
1449 assert!(msg.contains("provider=p1 model=test"));
1450 assert!(msg.contains("provider=p2 model=test"));
1451 assert!(msg.contains("error=p1 error"));
1452 assert!(msg.contains("error=p2 error"));
1453 assert!(msg.contains("retryable"));
1454 }
1455
/// Spot-checks the error classifier. Auth, bad-request and unknown-model
/// errors are permanent. 408/429, 5xx, transport failures, overload and
/// context-window overflow messages are not.
#[test]
fn non_retryable_detects_common_patterns() {
    let permanent = [
        "400 Bad Request",
        "401 Unauthorized",
        "403 Forbidden",
        "404 Not Found",
        "invalid api key provided",
        "authentication failed",
        "model glm-4.7 not found",
        "unsupported model: glm-4.7",
    ];
    let transient = [
        "429 Too Many Requests",
        "408 Request Timeout",
        "500 Internal Server Error",
        "502 Bad Gateway",
        "timeout",
        "connection reset",
        "model overloaded, try again later",
        "OpenAI Codex stream error: Your input exceeds the context window of this model.",
    ];

    for text in permanent {
        assert!(is_non_retryable(&anyhow::anyhow!(text)));
    }
    for text in transient {
        assert!(!is_non_retryable(&anyhow::anyhow!(text)));
    }
}
1488
1489 #[tokio::test]
1490 async fn context_window_error_aborts_retries_and_model_fallbacks() {
1491 let calls = Arc::new(AtomicUsize::new(0));
1492 let mut model_fallbacks = std::collections::HashMap::new();
1493 model_fallbacks.insert(
1494 "gpt-5.3-codex".to_string(),
1495 vec!["gpt-5.2-codex".to_string()],
1496 );
1497
1498 let provider = ReliableProvider::new(
1499 vec![(
1500 "openai-codex".into(),
1501 Box::new(MockProvider {
1502 calls: Arc::clone(&calls),
1503 fail_until_attempt: usize::MAX,
1504 response: "never",
1505 error: "OpenAI Codex stream error: Your input exceeds the context window of this model. Please adjust your input and try again.",
1506 }),
1507 )],
1508 4,
1509 1,
1510 )
1511 .with_model_fallbacks(model_fallbacks);
1512
1513 let err = provider
1514 .simple_chat("hello", "gpt-5.3-codex", 0.0)
1515 .await
1516 .expect_err("context window overflow should fail fast");
1517 let msg = err.to_string();
1518
1519 assert!(msg.contains("context window"));
1520 assert_eq!(calls.load(Ordering::SeqCst), 1);
1522 }
1523
1524 #[tokio::test]
1525 async fn aggregated_error_marks_non_retryable_model_mismatch_with_details() {
1526 let calls = Arc::new(AtomicUsize::new(0));
1527 let provider = ReliableProvider::new(
1528 vec![(
1529 "custom".into(),
1530 Box::new(MockProvider {
1531 calls: Arc::clone(&calls),
1532 fail_until_attempt: usize::MAX,
1533 response: "never",
1534 error: "unsupported model: glm-4.7",
1535 }),
1536 )],
1537 3,
1538 1,
1539 );
1540
1541 let err = provider
1542 .simple_chat("hello", "glm-4.7", 0.0)
1543 .await
1544 .expect_err("provider should fail");
1545 let msg = err.to_string();
1546
1547 assert!(msg.contains("non_retryable"));
1548 assert!(msg.contains("error=unsupported model: glm-4.7"));
1549 assert_eq!(calls.load(Ordering::SeqCst), 1);
1551 }
1552
1553 #[tokio::test]
1554 async fn skips_retries_on_non_retryable_error() {
1555 let primary_calls = Arc::new(AtomicUsize::new(0));
1556 let fallback_calls = Arc::new(AtomicUsize::new(0));
1557
1558 let provider = ReliableProvider::new(
1559 vec![
1560 (
1561 "primary".into(),
1562 Box::new(MockProvider {
1563 calls: Arc::clone(&primary_calls),
1564 fail_until_attempt: usize::MAX,
1565 response: "never",
1566 error: "401 Unauthorized",
1567 }),
1568 ),
1569 (
1570 "fallback".into(),
1571 Box::new(MockProvider {
1572 calls: Arc::clone(&fallback_calls),
1573 fail_until_attempt: 0,
1574 response: "from fallback",
1575 error: "fallback err",
1576 }),
1577 ),
1578 ],
1579 3,
1580 1,
1581 );
1582
1583 let result = provider.simple_chat("hello", "test", 0.0).await.unwrap();
1584 assert_eq!(result, "from fallback");
1585 assert_eq!(primary_calls.load(Ordering::SeqCst), 1);
1587 assert_eq!(fallback_calls.load(Ordering::SeqCst), 1);
1588 }
1589
1590 #[tokio::test]
1591 async fn chat_with_history_retries_then_recovers() {
1592 let calls = Arc::new(AtomicUsize::new(0));
1593 let provider = ReliableProvider::new(
1594 vec![(
1595 "primary".into(),
1596 Box::new(MockProvider {
1597 calls: Arc::clone(&calls),
1598 fail_until_attempt: 1,
1599 response: "history ok",
1600 error: "temporary",
1601 }),
1602 )],
1603 2,
1604 1,
1605 );
1606
1607 let messages = vec![ChatMessage::system("system"), ChatMessage::user("hello")];
1608 let result = provider
1609 .chat_with_history(&messages, "test", 0.0)
1610 .await
1611 .unwrap();
1612 assert_eq!(result, "history ok");
1613 assert_eq!(calls.load(Ordering::SeqCst), 2);
1614 }
1615
1616 #[tokio::test]
1617 async fn chat_with_history_falls_back() {
1618 let primary_calls = Arc::new(AtomicUsize::new(0));
1619 let fallback_calls = Arc::new(AtomicUsize::new(0));
1620
1621 let provider = ReliableProvider::new(
1622 vec![
1623 (
1624 "primary".into(),
1625 Box::new(MockProvider {
1626 calls: Arc::clone(&primary_calls),
1627 fail_until_attempt: usize::MAX,
1628 response: "never",
1629 error: "primary down",
1630 }),
1631 ),
1632 (
1633 "fallback".into(),
1634 Box::new(MockProvider {
1635 calls: Arc::clone(&fallback_calls),
1636 fail_until_attempt: 0,
1637 response: "fallback ok",
1638 error: "fallback err",
1639 }),
1640 ),
1641 ],
1642 1,
1643 1,
1644 );
1645
1646 let messages = vec![ChatMessage::user("hello")];
1647 let result = provider
1648 .chat_with_history(&messages, "test", 0.0)
1649 .await
1650 .unwrap();
1651 assert_eq!(result, "fallback ok");
1652 assert_eq!(primary_calls.load(Ordering::SeqCst), 2);
1653 assert_eq!(fallback_calls.load(Ordering::SeqCst), 1);
1654 }
1655
1656 #[tokio::test]
1659 async fn model_failover_tries_fallback_model() {
1660 let calls = Arc::new(AtomicUsize::new(0));
1661 let mock = Arc::new(ModelAwareMock {
1662 calls: Arc::clone(&calls),
1663 models_seen: parking_lot::Mutex::new(Vec::new()),
1664 fail_models: vec!["claude-opus"],
1665 response: "ok from sonnet",
1666 });
1667
1668 let mut fallbacks = HashMap::new();
1669 fallbacks.insert("claude-opus".to_string(), vec!["claude-sonnet".to_string()]);
1670
1671 let provider = ReliableProvider::new(
1672 vec![(
1673 "anthropic".into(),
1674 Box::new(mock.clone()) as Box<dyn Provider>,
1675 )],
1676 0, 1,
1678 )
1679 .with_model_fallbacks(fallbacks);
1680
1681 let result = provider
1682 .simple_chat("hello", "claude-opus", 0.0)
1683 .await
1684 .unwrap();
1685 assert_eq!(result, "ok from sonnet");
1686
1687 let seen = mock.models_seen.lock();
1688 assert_eq!(seen.len(), 2);
1689 assert_eq!(seen[0], "claude-opus");
1690 assert_eq!(seen[1], "claude-sonnet");
1691 }
1692
1693 #[tokio::test]
1694 async fn model_failover_all_models_fail() {
1695 let calls = Arc::new(AtomicUsize::new(0));
1696 let mock = Arc::new(ModelAwareMock {
1697 calls: Arc::clone(&calls),
1698 models_seen: parking_lot::Mutex::new(Vec::new()),
1699 fail_models: vec!["model-a", "model-b", "model-c"],
1700 response: "never",
1701 });
1702
1703 let mut fallbacks = HashMap::new();
1704 fallbacks.insert(
1705 "model-a".to_string(),
1706 vec!["model-b".to_string(), "model-c".to_string()],
1707 );
1708
1709 let provider = ReliableProvider::new(
1710 vec![("p1".into(), Box::new(mock.clone()) as Box<dyn Provider>)],
1711 0,
1712 1,
1713 )
1714 .with_model_fallbacks(fallbacks);
1715
1716 let err = provider
1717 .simple_chat("hello", "model-a", 0.0)
1718 .await
1719 .expect_err("all models should fail");
1720 assert!(err.to_string().contains("All providers/models failed"));
1721
1722 let seen = mock.models_seen.lock();
1723 assert_eq!(seen.len(), 3);
1724 }
1725
1726 #[tokio::test]
1727 async fn no_model_fallbacks_behaves_like_before() {
1728 let calls = Arc::new(AtomicUsize::new(0));
1729 let provider = ReliableProvider::new(
1730 vec![(
1731 "primary".into(),
1732 Box::new(MockProvider {
1733 calls: Arc::clone(&calls),
1734 fail_until_attempt: 0,
1735 response: "ok",
1736 error: "boom",
1737 }),
1738 )],
1739 2,
1740 1,
1741 );
1742 let result = provider.simple_chat("hello", "test", 0.0).await.unwrap();
1744 assert_eq!(result, "ok");
1745 assert_eq!(calls.load(Ordering::SeqCst), 1);
1746 }
1747
1748 #[tokio::test]
1751 async fn auth_rotation_cycles_keys() {
1752 let provider = ReliableProvider::new(
1753 vec![(
1754 "p".into(),
1755 Box::new(MockProvider {
1756 calls: Arc::new(AtomicUsize::new(0)),
1757 fail_until_attempt: 0,
1758 response: "ok",
1759 error: "",
1760 }),
1761 )],
1762 0,
1763 1,
1764 )
1765 .with_api_keys(vec!["key-a".into(), "key-b".into(), "key-c".into()]);
1766
1767 let keys: Vec<&str> = (0..5).map(|_| provider.rotate_key().unwrap()).collect();
1769 assert_eq!(keys, vec!["key-a", "key-b", "key-c", "key-a", "key-b"]);
1770 }
1771
1772 #[tokio::test]
1773 async fn auth_rotation_returns_none_when_empty() {
1774 let provider = ReliableProvider::new(vec![], 0, 1);
1775 assert!(provider.rotate_key().is_none());
1776 }
1777
/// An integer `Retry-After` value in seconds is converted to milliseconds.
#[test]
fn parse_retry_after_integer() {
    let rate_limited = anyhow::anyhow!("429 Too Many Requests, Retry-After: 5");
    let delay_ms = parse_retry_after_ms(&rate_limited);
    assert_eq!(delay_ms, Some(5000));
}
1785
/// Fractional seconds after `retry_after:` are converted to milliseconds.
#[test]
fn parse_retry_after_float() {
    let rate_limited = anyhow::anyhow!("Rate limited. retry_after: 2.5 seconds");
    let delay_ms = parse_retry_after_ms(&rate_limited);
    assert_eq!(delay_ms, Some(2500));
}
1791
/// An error without any retry-after hint yields `None`.
#[test]
fn parse_retry_after_missing() {
    let server_error = anyhow::anyhow!("500 Internal Server Error");
    assert!(parse_retry_after_ms(&server_error) == None);
}
1797
/// 429 messages are recognised as rate limiting. 401 and 500 are not.
#[test]
fn rate_limited_detection() {
    let limited = ["429 Too Many Requests", "HTTP 429 rate limit exceeded"];
    let not_limited = ["401 Unauthorized", "500 Internal Server Error"];

    for text in limited {
        assert!(is_rate_limited(&anyhow::anyhow!(text)));
    }
    for text in not_limited {
        assert!(!is_rate_limited(&anyhow::anyhow!(text)));
    }
}
1809
/// A 429 whose body says the account plan excludes the model is a business
/// error. It should be classified as not worth retrying.
#[test]
fn non_retryable_rate_limit_detects_plan_restricted_model() {
    let body = "API error (429 Too Many Requests): {\"code\":1311,\"message\":\"the current account plan does not include glm-5\"}";
    let plan_error = anyhow::anyhow!("{}", body);
    let skip_retries = is_non_retryable_rate_limit(&plan_error);
    assert!(skip_retries, "plan-restricted 429 should skip retries");
}
1821
/// A 429 reporting insufficient balance will not clear by waiting, so it is
/// classified as a non-retryable rate limit.
#[test]
fn non_retryable_rate_limit_detects_insufficient_balance() {
    let body = "API error (429 Too Many Requests): {\"code\":1113,\"message\":\"insufficient balance\"}";
    let balance_error = anyhow::anyhow!("{}", body);
    let skip_retries = is_non_retryable_rate_limit(&balance_error);
    assert!(skip_retries, "insufficient-balance 429 should skip retries");
}
1833
/// An ordinary rate-limit 429 must remain retryable.
#[test]
fn non_retryable_rate_limit_does_not_flag_generic_429() {
    let throttled = anyhow::anyhow!("429 Too Many Requests: rate limit exceeded");
    let skip_retries = is_non_retryable_rate_limit(&throttled);
    assert!(!skip_retries, "generic rate-limit 429 should remain retryable");
}
1842
/// A server-provided `Retry-After` (3 s) overrides the 500 ms base backoff.
#[test]
fn compute_backoff_uses_retry_after() {
    let provider = ReliableProvider::new(vec![], 0, 500);
    let backoff = provider.compute_backoff(500, &anyhow::anyhow!("429 Retry-After: 3"));
    assert_eq!(backoff, 3_000);
}
1849
/// A very large `Retry-After` (120 s) is clamped to 30 seconds.
#[test]
fn compute_backoff_caps_at_30s() {
    let provider = ReliableProvider::new(vec![], 0, 500);
    let backoff = provider.compute_backoff(500, &anyhow::anyhow!("429 Retry-After: 120"));
    assert_eq!(backoff, 30_000);
}
1856
/// Without a retry-after hint the supplied base backoff is used unchanged.
#[test]
fn compute_backoff_falls_back_to_base() {
    let provider = ReliableProvider::new(vec![], 0, 500);
    let backoff = provider.compute_backoff(500, &anyhow::anyhow!("500 Server Error"));
    assert_eq!(backoff, 500);
}
1863
/// A 401 embedded in a provider error string is a permanent failure.
#[test]
fn non_retryable_detects_401() {
    let unauthorized = anyhow::anyhow!("API error (401 Unauthorized): invalid api key");
    let permanent = is_non_retryable(&unauthorized);
    assert!(permanent, "401 errors must be detected as non-retryable");
}
1874
/// A 403 embedded in a provider error string is a permanent failure.
#[test]
fn non_retryable_detects_403() {
    let forbidden = anyhow::anyhow!("API error (403 Forbidden): access denied");
    let permanent = is_non_retryable(&forbidden);
    assert!(permanent, "403 errors must be detected as non-retryable");
}
1883
/// A 404 embedded in a provider error string is a permanent failure.
#[test]
fn non_retryable_detects_404() {
    let missing = anyhow::anyhow!("API error (404 Not Found): model not found");
    let permanent = is_non_retryable(&missing);
    assert!(permanent, "404 errors must be detected as non-retryable");
}
1892
/// 429 is excluded from the 4xx "permanent" range.
#[test]
fn non_retryable_does_not_flag_429() {
    let throttled = anyhow::anyhow!("429 Too Many Requests");
    let permanent = is_non_retryable(&throttled);
    assert!(
        !permanent,
        "429 must NOT be treated as non-retryable (it is retryable with backoff)"
    );
}
1901
/// 408 (request timeout) is excluded from the 4xx "permanent" range.
#[test]
fn non_retryable_does_not_flag_408() {
    let timed_out = anyhow::anyhow!("408 Request Timeout");
    let permanent = is_non_retryable(&timed_out);
    assert!(
        !permanent,
        "408 must NOT be treated as non-retryable (it is retryable)"
    );
}
1910
/// Server-side 500 errors are not classified as permanent.
#[test]
fn non_retryable_does_not_flag_500() {
    let server_error = anyhow::anyhow!("500 Internal Server Error");
    let permanent = is_non_retryable(&server_error);
    assert!(
        !permanent,
        "500 must NOT be treated as non-retryable (server errors are retryable)"
    );
}
1919
/// Gateway 502 errors are not classified as permanent.
#[test]
fn non_retryable_does_not_flag_502() {
    let bad_gateway = anyhow::anyhow!("502 Bad Gateway");
    let permanent = is_non_retryable(&bad_gateway);
    assert!(!permanent, "502 must NOT be treated as non-retryable");
}
1928
/// `Retry-After: 0` parses to an explicit zero delay rather than `None`.
#[test]
fn parse_retry_after_zero() {
    let rate_limited = anyhow::anyhow!("429 Too Many Requests, Retry-After: 0");
    let delay_ms = parse_retry_after_ms(&rate_limited);
    assert_eq!(delay_ms, Some(0), "Retry-After: 0 should parse as 0ms");
}
1940
/// The snake_case spelling `retry_after` is recognised as well.
#[test]
fn parse_retry_after_with_underscore_separator() {
    let rate_limited = anyhow::anyhow!("rate limited, retry_after: 10");
    let delay_ms = parse_retry_after_ms(&rate_limited);
    assert_eq!(
        delay_ms,
        Some(10_000),
        "retry_after with underscore must be parsed"
    );
}
1950
/// A space instead of a colon between header name and value is accepted.
#[test]
fn parse_retry_after_space_separator() {
    let rate_limited = anyhow::anyhow!("Retry-After 7");
    let delay_ms = parse_retry_after_ms(&rate_limited);
    assert_eq!(
        delay_ms,
        Some(7000),
        "Retry-After with space separator must be parsed"
    );
}
1960
/// A plain transport error is not mistaken for rate limiting.
#[test]
fn rate_limited_false_for_generic_error() {
    let refused = anyhow::anyhow!("Connection refused");
    let limited = is_rate_limited(&refused);
    assert!(!limited, "generic errors must not be flagged as rate-limited");
}
1969
1970 #[tokio::test]
1973 async fn non_retryable_skips_retries_for_401() {
1974 let calls = Arc::new(AtomicUsize::new(0));
1975 let provider = ReliableProvider::new(
1976 vec![(
1977 "primary".into(),
1978 Box::new(MockProvider {
1979 calls: Arc::clone(&calls),
1980 fail_until_attempt: usize::MAX,
1981 response: "never",
1982 error: "API error (401 Unauthorized): invalid key",
1983 }),
1984 )],
1985 5,
1986 1,
1987 );
1988
1989 let result = provider.simple_chat("hello", "test", 0.0).await;
1990 assert!(result.is_err(), "401 should fail without retries");
1991 assert_eq!(
1992 calls.load(Ordering::SeqCst),
1993 1,
1994 "must not retry on 401 — should be exactly 1 call"
1995 );
1996 }
1997
1998 #[tokio::test]
1999 async fn non_retryable_rate_limit_skips_retries_for_plan_errors() {
2000 let calls = Arc::new(AtomicUsize::new(0));
2001 let provider = ReliableProvider::new(
2002 vec![(
2003 "primary".into(),
2004 Box::new(MockProvider {
2005 calls: Arc::clone(&calls),
2006 fail_until_attempt: usize::MAX,
2007 response: "never",
2008 error: "API error (429 Too Many Requests): {\"code\":1311,\"message\":\"plan does not include glm-5\"}",
2009 }),
2010 )],
2011 5,
2012 1,
2013 );
2014
2015 let result = provider.simple_chat("hello", "test", 0.0).await;
2016 assert!(
2017 result.is_err(),
2018 "plan-restricted 429 should fail quickly without retrying"
2019 );
2020 assert_eq!(
2021 calls.load(Ordering::SeqCst),
2022 1,
2023 "must not retry non-retryable 429 business errors"
2024 );
2025 }
2026
2027 #[async_trait]
2030 impl Provider for Arc<ModelAwareMock> {
2031 async fn chat_with_system(
2032 &self,
2033 system_prompt: Option<&str>,
2034 message: &str,
2035 model: &str,
2036 temperature: f64,
2037 ) -> anyhow::Result<String> {
2038 self.as_ref()
2039 .chat_with_system(system_prompt, message, model, temperature)
2040 .await
2041 }
2042 }
2043
2044 struct NativeToolMock {
2046 calls: Arc<AtomicUsize>,
2047 fail_until_attempt: usize,
2048 response_text: &'static str,
2049 tool_calls: Vec<super::super::traits::ToolCall>,
2050 error: &'static str,
2051 }
2052
2053 #[async_trait]
2054 impl Provider for NativeToolMock {
2055 async fn chat_with_system(
2056 &self,
2057 _system_prompt: Option<&str>,
2058 _message: &str,
2059 _model: &str,
2060 _temperature: f64,
2061 ) -> anyhow::Result<String> {
2062 Ok(self.response_text.to_string())
2063 }
2064
2065 fn supports_native_tools(&self) -> bool {
2066 true
2067 }
2068
2069 async fn chat(
2070 &self,
2071 _request: ChatRequest<'_>,
2072 _model: &str,
2073 _temperature: f64,
2074 ) -> anyhow::Result<ChatResponse> {
2075 let attempt = self.calls.fetch_add(1, Ordering::SeqCst) + 1;
2076 if attempt <= self.fail_until_attempt {
2077 anyhow::bail!(self.error);
2078 }
2079 Ok(ChatResponse {
2080 text: Some(self.response_text.to_string()),
2081 tool_calls: self.tool_calls.clone(),
2082 usage: None,
2083 reasoning_content: None,
2084 })
2085 }
2086 }
2087
2088 #[tokio::test]
2089 async fn chat_delegates_to_inner_provider() {
2090 let calls = Arc::new(AtomicUsize::new(0));
2091 let tool_call = super::super::traits::ToolCall {
2092 id: "call_1".to_string(),
2093 name: "shell".to_string(),
2094 arguments: r#"{"command":"date"}"#.to_string(),
2095 };
2096 let provider = ReliableProvider::new(
2097 vec![(
2098 "primary".into(),
2099 Box::new(NativeToolMock {
2100 calls: Arc::clone(&calls),
2101 fail_until_attempt: 0,
2102 response_text: "ok",
2103 tool_calls: vec![tool_call.clone()],
2104 error: "boom",
2105 }) as Box<dyn Provider>,
2106 )],
2107 2,
2108 1,
2109 );
2110
2111 let messages = vec![ChatMessage::user("what time is it?")];
2112 let request = ChatRequest {
2113 messages: &messages,
2114 tools: None,
2115 };
2116 let result = provider.chat(request, "test-model", 0.0).await.unwrap();
2117
2118 assert_eq!(result.text.as_deref(), Some("ok"));
2119 assert_eq!(result.tool_calls.len(), 1);
2120 assert_eq!(result.tool_calls[0].name, "shell");
2121 assert_eq!(calls.load(Ordering::SeqCst), 1);
2122 }
2123
2124 #[tokio::test]
2125 async fn chat_retries_and_recovers() {
2126 let calls = Arc::new(AtomicUsize::new(0));
2127 let tool_call = super::super::traits::ToolCall {
2128 id: "call_1".to_string(),
2129 name: "shell".to_string(),
2130 arguments: r#"{"command":"date"}"#.to_string(),
2131 };
2132 let provider = ReliableProvider::new(
2133 vec![(
2134 "primary".into(),
2135 Box::new(NativeToolMock {
2136 calls: Arc::clone(&calls),
2137 fail_until_attempt: 2,
2138 response_text: "recovered",
2139 tool_calls: vec![tool_call],
2140 error: "temporary failure",
2141 }) as Box<dyn Provider>,
2142 )],
2143 3,
2144 1,
2145 );
2146
2147 let messages = vec![ChatMessage::user("test")];
2148 let request = ChatRequest {
2149 messages: &messages,
2150 tools: None,
2151 };
2152 let result = provider.chat(request, "test-model", 0.0).await.unwrap();
2153
2154 assert_eq!(result.text.as_deref(), Some("recovered"));
2155 assert!(
2156 calls.load(Ordering::SeqCst) > 1,
2157 "should have retried at least once"
2158 );
2159 }
2160
2161 #[tokio::test]
2162 async fn chat_preserves_native_tools_support() {
2163 let calls = Arc::new(AtomicUsize::new(0));
2164 let provider = ReliableProvider::new(
2165 vec![(
2166 "primary".into(),
2167 Box::new(NativeToolMock {
2168 calls: Arc::clone(&calls),
2169 fail_until_attempt: 0,
2170 response_text: "ok",
2171 tool_calls: vec![],
2172 error: "boom",
2173 }) as Box<dyn Provider>,
2174 )],
2175 2,
2176 1,
2177 );
2178
2179 assert!(
2180 provider.supports_native_tools(),
2181 "ReliableProvider must propagate supports_native_tools from inner provider"
2182 );
2183 }
2184
2185 #[tokio::test]
2190 async fn chat_returns_aggregated_error_when_all_providers_fail() {
2191 let provider = ReliableProvider::new(
2192 vec![
2193 (
2194 "p1".into(),
2195 Box::new(NativeToolMock {
2196 calls: Arc::new(AtomicUsize::new(0)),
2197 fail_until_attempt: usize::MAX,
2198 response_text: "never",
2199 tool_calls: vec![],
2200 error: "p1 chat error",
2201 }) as Box<dyn Provider>,
2202 ),
2203 (
2204 "p2".into(),
2205 Box::new(NativeToolMock {
2206 calls: Arc::new(AtomicUsize::new(0)),
2207 fail_until_attempt: usize::MAX,
2208 response_text: "never",
2209 tool_calls: vec![],
2210 error: "p2 chat error",
2211 }) as Box<dyn Provider>,
2212 ),
2213 ],
2214 0,
2215 1,
2216 );
2217
2218 let messages = vec![ChatMessage::user("hello")];
2219 let request = ChatRequest {
2220 messages: &messages,
2221 tools: None,
2222 };
2223 let err = provider
2224 .chat(request, "test", 0.0)
2225 .await
2226 .expect_err("all providers should fail");
2227 let msg = err.to_string();
2228 assert!(msg.contains("All providers/models failed"));
2229 assert!(msg.contains("provider=p1 model=test"));
2230 assert!(msg.contains("provider=p2 model=test"));
2231 assert!(msg.contains("error=p1 chat error"));
2232 assert!(msg.contains("error=p2 chat error"));
2233 assert!(msg.contains("retryable"));
2234 }
2235
2236 struct NativeModelAwareMock {
2239 calls: Arc<AtomicUsize>,
2240 models_seen: parking_lot::Mutex<Vec<String>>,
2241 fail_models: Vec<&'static str>,
2242 response_text: &'static str,
2243 }
2244
2245 #[async_trait]
2246 impl Provider for NativeModelAwareMock {
2247 async fn chat_with_system(
2248 &self,
2249 _system_prompt: Option<&str>,
2250 _message: &str,
2251 _model: &str,
2252 _temperature: f64,
2253 ) -> anyhow::Result<String> {
2254 Ok(self.response_text.to_string())
2255 }
2256
2257 fn supports_native_tools(&self) -> bool {
2258 true
2259 }
2260
2261 async fn chat(
2262 &self,
2263 _request: ChatRequest<'_>,
2264 model: &str,
2265 _temperature: f64,
2266 ) -> anyhow::Result<ChatResponse> {
2267 self.calls.fetch_add(1, Ordering::SeqCst);
2268 self.models_seen.lock().push(model.to_string());
2269 if self.fail_models.contains(&model) {
2270 anyhow::bail!("500 model {} unavailable", model);
2271 }
2272 Ok(ChatResponse {
2273 text: Some(self.response_text.to_string()),
2274 tool_calls: vec![],
2275 usage: None,
2276 reasoning_content: None,
2277 })
2278 }
2279 }
2280
2281 #[async_trait]
2282 impl Provider for Arc<NativeModelAwareMock> {
2283 async fn chat_with_system(
2284 &self,
2285 system_prompt: Option<&str>,
2286 message: &str,
2287 model: &str,
2288 temperature: f64,
2289 ) -> anyhow::Result<String> {
2290 self.as_ref()
2291 .chat_with_system(system_prompt, message, model, temperature)
2292 .await
2293 }
2294
2295 fn supports_native_tools(&self) -> bool {
2296 true
2297 }
2298
2299 async fn chat(
2300 &self,
2301 request: ChatRequest<'_>,
2302 model: &str,
2303 temperature: f64,
2304 ) -> anyhow::Result<ChatResponse> {
2305 self.as_ref().chat(request, model, temperature).await
2306 }
2307 }
2308
2309 #[tokio::test]
2312 async fn chat_tries_model_failover_on_failure() {
2313 let calls = Arc::new(AtomicUsize::new(0));
2314 let mock = Arc::new(NativeModelAwareMock {
2315 calls: Arc::clone(&calls),
2316 models_seen: parking_lot::Mutex::new(Vec::new()),
2317 fail_models: vec!["claude-opus"],
2318 response_text: "ok from sonnet",
2319 });
2320
2321 let mut fallbacks = HashMap::new();
2322 fallbacks.insert("claude-opus".to_string(), vec!["claude-sonnet".to_string()]);
2323
2324 let provider = ReliableProvider::new(
2325 vec![(
2326 "anthropic".into(),
2327 Box::new(mock.clone()) as Box<dyn Provider>,
2328 )],
2329 0, 1,
2331 )
2332 .with_model_fallbacks(fallbacks);
2333
2334 let messages = vec![ChatMessage::user("hello")];
2335 let request = ChatRequest {
2336 messages: &messages,
2337 tools: None,
2338 };
2339 let result = provider.chat(request, "claude-opus", 0.0).await.unwrap();
2340 assert_eq!(result.text.as_deref(), Some("ok from sonnet"));
2341
2342 let seen = mock.models_seen.lock();
2343 assert_eq!(seen.len(), 2);
2344 assert_eq!(seen[0], "claude-opus");
2345 assert_eq!(seen[1], "claude-sonnet");
2346 }
2347
2348 #[tokio::test]
2351 async fn chat_skips_non_retryable_errors() {
2352 let primary_calls = Arc::new(AtomicUsize::new(0));
2353 let fallback_calls = Arc::new(AtomicUsize::new(0));
2354
2355 let provider = ReliableProvider::new(
2356 vec![
2357 (
2358 "primary".into(),
2359 Box::new(NativeToolMock {
2360 calls: Arc::clone(&primary_calls),
2361 fail_until_attempt: usize::MAX,
2362 response_text: "never",
2363 tool_calls: vec![],
2364 error: "401 Unauthorized",
2365 }) as Box<dyn Provider>,
2366 ),
2367 (
2368 "fallback".into(),
2369 Box::new(NativeToolMock {
2370 calls: Arc::clone(&fallback_calls),
2371 fail_until_attempt: 0,
2372 response_text: "from fallback",
2373 tool_calls: vec![],
2374 error: "fallback err",
2375 }) as Box<dyn Provider>,
2376 ),
2377 ],
2378 3,
2379 1,
2380 );
2381
2382 let messages = vec![ChatMessage::user("hello")];
2383 let request = ChatRequest {
2384 messages: &messages,
2385 tools: None,
2386 };
2387 let result = provider.chat(request, "test", 0.0).await.unwrap();
2388 assert_eq!(result.text.as_deref(), Some("from fallback"));
2389 assert_eq!(primary_calls.load(Ordering::SeqCst), 1);
2391 assert_eq!(fallback_calls.load(Ordering::SeqCst), 1);
2392 }
2393
/// `is_non_retryable` returns false for the common phrasings of a
/// context-window / token-limit overflow.
#[test]
fn context_window_error_is_not_non_retryable() {
    let overflow_messages = [
        "exceeds the context window",
        "maximum context length exceeded",
        "too many tokens in the request",
        "token limit exceeded",
    ];
    for text in overflow_messages {
        assert!(!is_non_retryable(&anyhow::anyhow!(text)));
    }
}
2410
/// The llama.cpp "exceeds the available context size" wording is detected.
#[test]
fn is_context_window_exceeded_detects_llamacpp() {
    let llamacpp_overflow = anyhow::anyhow!(
        "request (8968 tokens) exceeds the available context size (8448 tokens), try increasing it"
    );
    assert!(is_context_window_exceeded(&llamacpp_overflow));
}
2417
/// Truncating a six-message conversation drops two messages while keeping the
/// leading system prompt and the newest user message in place.
#[test]
fn truncate_for_context_drops_oldest_non_system() {
    let mut conversation = vec![
        ChatMessage::system("sys"),
        ChatMessage::user("msg1"),
        ChatMessage::assistant("resp1"),
        ChatMessage::user("msg2"),
        ChatMessage::assistant("resp2"),
        ChatMessage::user("msg3"),
    ];

    let removed = truncate_for_context(&mut conversation);

    assert_eq!(removed, 2);
    assert_eq!(conversation[0].role, "system");
    assert_eq!(conversation.len(), 4);
    assert_eq!(conversation.last().unwrap().content, "msg3");
}
2440
/// Conversations with nothing droppable (system + one user message, or one
/// user message alone) are left untouched.
#[test]
fn truncate_for_context_preserves_system_and_last_message() {
    {
        let mut with_system = vec![ChatMessage::system("sys"), ChatMessage::user("only")];
        let removed = truncate_for_context(&mut with_system);
        assert_eq!(removed, 0);
        assert_eq!(with_system.len(), 2);
    }
    {
        let mut user_only = vec![ChatMessage::user("only")];
        let removed = truncate_for_context(&mut user_only);
        assert_eq!(removed, 0);
        assert_eq!(user_only.len(), 1);
    }
}
2455
2456 struct ContextOverflowMock {
2459 calls: Arc<AtomicUsize>,
2460 fail_until_attempt: usize,
2461 message_counts: parking_lot::Mutex<Vec<usize>>,
2462 }
2463
2464 #[async_trait]
2465 impl Provider for ContextOverflowMock {
2466 async fn chat_with_system(
2467 &self,
2468 _system_prompt: Option<&str>,
2469 _message: &str,
2470 _model: &str,
2471 _temperature: f64,
2472 ) -> anyhow::Result<String> {
2473 Ok("ok".to_string())
2474 }
2475
2476 async fn chat_with_history(
2477 &self,
2478 messages: &[ChatMessage],
2479 _model: &str,
2480 _temperature: f64,
2481 ) -> anyhow::Result<String> {
2482 let attempt = self.calls.fetch_add(1, Ordering::SeqCst) + 1;
2483 self.message_counts.lock().push(messages.len());
2484 if attempt <= self.fail_until_attempt {
2485 anyhow::bail!(
2486 "request (8968 tokens) exceeds the available context size (8448 tokens), try increasing it"
2487 );
2488 }
2489 Ok("recovered after truncation".to_string())
2490 }
2491 }
2492
2493 #[tokio::test]
2494 async fn chat_with_history_truncates_on_context_overflow() {
2495 let calls = Arc::new(AtomicUsize::new(0));
2496 let mock = ContextOverflowMock {
2497 calls: Arc::clone(&calls),
2498 fail_until_attempt: 1, message_counts: parking_lot::Mutex::new(Vec::new()),
2500 };
2501
2502 let provider = ReliableProvider::new(
2503 vec![("local".into(), Box::new(mock) as Box<dyn Provider>)],
2504 3,
2505 1,
2506 );
2507
2508 let messages = vec![
2509 ChatMessage::system("system prompt"),
2510 ChatMessage::user("old message 1"),
2511 ChatMessage::assistant("old response 1"),
2512 ChatMessage::user("old message 2"),
2513 ChatMessage::assistant("old response 2"),
2514 ChatMessage::user("current question"),
2515 ];
2516
2517 let result = provider
2518 .chat_with_history(&messages, "local-model", 0.0)
2519 .await
2520 .unwrap();
2521 assert_eq!(result, "recovered after truncation");
2522 assert_eq!(calls.load(Ordering::SeqCst), 2);
2524 }
2525
2526 #[tokio::test]
2527 async fn context_overflow_with_no_history_to_truncate_bails_immediately() {
2528 let calls = Arc::new(AtomicUsize::new(0));
2529 let mock = ContextOverflowMock {
2530 calls: Arc::clone(&calls),
2531 fail_until_attempt: 999, message_counts: parking_lot::Mutex::new(Vec::new()),
2533 };
2534
2535 let provider = ReliableProvider::new(
2536 vec![("local".into(), Box::new(mock) as Box<dyn Provider>)],
2537 3,
2538 1,
2539 );
2540
2541 let messages = vec![
2543 ChatMessage::system("huge system prompt that exceeds context window"),
2544 ChatMessage::user("hello"),
2545 ];
2546
2547 let result = provider
2548 .chat_with_history(&messages, "local-model", 0.0)
2549 .await;
2550 assert!(result.is_err());
2551 let err_msg = result.unwrap_err().to_string();
2552 assert!(
2553 err_msg.contains("cannot be reduced further"),
2554 "Should bail with actionable message, got: {err_msg}"
2555 );
2556 assert_eq!(
2558 calls.load(Ordering::SeqCst),
2559 1,
2560 "Should not retry when truncation is impossible"
2561 );
2562 }
2563
2564 #[test]
2567 fn tool_schema_error_detects_groq_validation_failure() {
2568 let msg = r#"Groq API error (400 Bad Request): {"error":{"message":"tool call validation failed: attempted to call tool 'memory_recall' which was not in request"}}"#;
2569 let err = anyhow::anyhow!("{}", msg);
2570 assert!(is_tool_schema_error(&err));
2571 }
2572
2573 #[test]
2574 fn tool_schema_error_detects_not_in_request() {
2575 let err = anyhow::anyhow!("tool 'search' was not in request");
2576 assert!(is_tool_schema_error(&err));
2577 }
2578
2579 #[test]
2580 fn tool_schema_error_detects_not_found_in_tool_list() {
2581 let err = anyhow::anyhow!("function 'foo' not found in tool list");
2582 assert!(is_tool_schema_error(&err));
2583 }
2584
2585 #[test]
2586 fn tool_schema_error_detects_invalid_tool_call() {
2587 let err = anyhow::anyhow!("invalid_tool_call: no matching function");
2588 assert!(is_tool_schema_error(&err));
2589 }
2590
2591 #[test]
2592 fn tool_schema_error_ignores_unrelated_errors() {
2593 let err = anyhow::anyhow!("invalid api key");
2594 assert!(!is_tool_schema_error(&err));
2595
2596 let err = anyhow::anyhow!("model not found");
2597 assert!(!is_tool_schema_error(&err));
2598 }
2599
2600 #[test]
2601 fn non_retryable_returns_false_for_tool_schema_400() {
2602 let msg = "400 Bad Request: tool call validation failed: attempted to call tool 'x' which was not in request";
2604 let err = anyhow::anyhow!("{}", msg);
2605 assert!(!is_non_retryable(&err));
2606 }
2607
2608 #[test]
2609 fn non_retryable_returns_true_for_other_400_errors() {
2610 let err = anyhow::anyhow!("400 Bad Request: invalid api key provided");
2612 assert!(is_non_retryable(&err));
2613 }
2614
2615 struct StreamingToolEventMock {
2616 stream_calls: Arc<AtomicUsize>,
2617 supports_tool_events: bool,
2618 }
2619
2620 impl StreamingToolEventMock {
2621 fn new(supports_tool_events: bool) -> Self {
2622 Self {
2623 stream_calls: Arc::new(AtomicUsize::new(0)),
2624 supports_tool_events,
2625 }
2626 }
2627 }
2628
2629 #[async_trait]
2630 impl Provider for StreamingToolEventMock {
2631 async fn chat_with_system(
2632 &self,
2633 _system_prompt: Option<&str>,
2634 _message: &str,
2635 _model: &str,
2636 _temperature: f64,
2637 ) -> anyhow::Result<String> {
2638 Ok("ok".to_string())
2639 }
2640
2641 fn supports_streaming(&self) -> bool {
2642 true
2643 }
2644
2645 fn supports_streaming_tool_events(&self) -> bool {
2646 self.supports_tool_events
2647 }
2648
2649 fn stream_chat(
2650 &self,
2651 _request: ChatRequest<'_>,
2652 _model: &str,
2653 _temperature: f64,
2654 _options: StreamOptions,
2655 ) -> stream::BoxStream<'static, StreamResult<StreamEvent>> {
2656 self.stream_calls.fetch_add(1, Ordering::SeqCst);
2657 stream::iter(vec![
2658 Ok(StreamEvent::ToolCall(super::super::traits::ToolCall {
2659 id: "call_1".to_string(),
2660 name: "shell".to_string(),
2661 arguments: r#"{"command":"date"}"#.to_string(),
2662 })),
2663 Ok(StreamEvent::Final),
2664 ])
2665 .boxed()
2666 }
2667 }
2668
2669 #[async_trait]
2670 impl Provider for Arc<StreamingToolEventMock> {
2671 async fn chat_with_system(
2672 &self,
2673 system_prompt: Option<&str>,
2674 message: &str,
2675 model: &str,
2676 temperature: f64,
2677 ) -> anyhow::Result<String> {
2678 self.as_ref()
2679 .chat_with_system(system_prompt, message, model, temperature)
2680 .await
2681 }
2682
2683 fn supports_streaming(&self) -> bool {
2684 self.as_ref().supports_streaming()
2685 }
2686
2687 fn supports_streaming_tool_events(&self) -> bool {
2688 self.as_ref().supports_streaming_tool_events()
2689 }
2690
2691 fn stream_chat(
2692 &self,
2693 request: ChatRequest<'_>,
2694 model: &str,
2695 temperature: f64,
2696 options: StreamOptions,
2697 ) -> stream::BoxStream<'static, StreamResult<StreamEvent>> {
2698 self.as_ref()
2699 .stream_chat(request, model, temperature, options)
2700 }
2701 }
2702
2703 #[tokio::test]
2704 async fn stream_chat_prefers_provider_with_tool_event_support() {
2705 let primary = Arc::new(StreamingToolEventMock::new(false));
2706 let fallback = Arc::new(StreamingToolEventMock::new(true));
2707 let provider = ReliableProvider::new(
2708 vec![
2709 (
2710 "primary".into(),
2711 Box::new(Arc::clone(&primary)) as Box<dyn Provider>,
2712 ),
2713 (
2714 "fallback".into(),
2715 Box::new(Arc::clone(&fallback)) as Box<dyn Provider>,
2716 ),
2717 ],
2718 0,
2719 1,
2720 );
2721
2722 let messages = vec![ChatMessage::user("hello")];
2723 let tools = vec![ToolSpec {
2724 name: "shell".to_string(),
2725 description: "run shell".to_string(),
2726 parameters: serde_json::json!({
2727 "type": "object",
2728 "properties": {
2729 "command": { "type": "string" }
2730 }
2731 }),
2732 }];
2733 let mut stream = provider.stream_chat(
2734 ChatRequest {
2735 messages: &messages,
2736 tools: Some(&tools),
2737 },
2738 "model",
2739 0.0,
2740 StreamOptions::new(true),
2741 );
2742
2743 let first = stream.next().await.unwrap().unwrap();
2744 let second = stream.next().await.unwrap().unwrap();
2745 assert!(stream.next().await.is_none());
2746
2747 match first {
2748 StreamEvent::ToolCall(call) => assert_eq!(call.name, "shell"),
2749 other => panic!("expected tool-call event, got {other:?}"),
2750 }
2751 assert!(matches!(second, StreamEvent::Final));
2752 assert_eq!(primary.stream_calls.load(Ordering::SeqCst), 0);
2753 assert_eq!(fallback.stream_calls.load(Ordering::SeqCst), 1);
2754 }
2755
2756 #[tokio::test]
2757 async fn stream_chat_errors_when_no_provider_supports_tool_events() {
2758 let primary = Arc::new(StreamingToolEventMock::new(false));
2759 let provider = ReliableProvider::new(
2760 vec![(
2761 "primary".into(),
2762 Box::new(Arc::clone(&primary)) as Box<dyn Provider>,
2763 )],
2764 0,
2765 1,
2766 );
2767
2768 let messages = vec![ChatMessage::user("hello")];
2769 let tools = vec![ToolSpec {
2770 name: "shell".to_string(),
2771 description: "run shell".to_string(),
2772 parameters: serde_json::json!({"type": "object"}),
2773 }];
2774 let mut stream = provider.stream_chat(
2775 ChatRequest {
2776 messages: &messages,
2777 tools: Some(&tools),
2778 },
2779 "model",
2780 0.0,
2781 StreamOptions::new(true),
2782 );
2783
2784 let first = stream.next().await.unwrap();
2785 let err = first.expect_err("stream should fail without tool-event support");
2786 assert!(
2787 err.to_string()
2788 .contains("No provider supports streaming tool events"),
2789 "unexpected stream error: {err}"
2790 );
2791 assert!(stream.next().await.is_none());
2792 assert_eq!(primary.stream_calls.load(Ordering::SeqCst), 0);
2793 }
2794
2795 struct StreamingHistoryMock {
2799 stream_calls: Arc<AtomicUsize>,
2800 supports: bool,
2801 }
2802
2803 #[async_trait]
2804 impl Provider for StreamingHistoryMock {
2805 async fn chat_with_system(
2806 &self,
2807 _system_prompt: Option<&str>,
2808 _message: &str,
2809 _model: &str,
2810 _temperature: f64,
2811 ) -> anyhow::Result<String> {
2812 Ok("ok".to_string())
2813 }
2814
2815 fn supports_streaming(&self) -> bool {
2816 self.supports
2817 }
2818
2819 fn stream_chat_with_history(
2820 &self,
2821 messages: &[ChatMessage],
2822 _model: &str,
2823 _temperature: f64,
2824 _options: StreamOptions,
2825 ) -> stream::BoxStream<'static, StreamResult<StreamChunk>> {
2826 self.stream_calls.fetch_add(1, Ordering::SeqCst);
2827 let msg_count = messages.len().to_string();
2829 stream::iter(vec![
2830 Ok(StreamChunk::delta(msg_count)),
2831 Ok(StreamChunk::final_chunk()),
2832 ])
2833 .boxed()
2834 }
2835 }
2836
2837 #[tokio::test]
2838 async fn stream_chat_with_history_delegates_to_streaming_provider() {
2839 let calls = Arc::new(AtomicUsize::new(0));
2840 let provider = ReliableProvider::new(
2841 vec![(
2842 "primary".into(),
2843 Box::new(StreamingHistoryMock {
2844 stream_calls: Arc::clone(&calls),
2845 supports: true,
2846 }) as Box<dyn Provider>,
2847 )],
2848 0,
2849 1,
2850 );
2851
2852 let messages = vec![
2853 ChatMessage::system("system"),
2854 ChatMessage::user("msg1"),
2855 ChatMessage::assistant("resp1"),
2856 ChatMessage::user("msg2"),
2857 ];
2858 let mut stream =
2859 provider.stream_chat_with_history(&messages, "model", 0.0, StreamOptions::new(true));
2860
2861 let first = stream.next().await.unwrap().unwrap();
2862 assert_eq!(first.delta, "4", "should pass all 4 messages to provider");
2863 let second = stream.next().await.unwrap().unwrap();
2864 assert!(second.is_final);
2865 assert!(stream.next().await.is_none());
2866 assert_eq!(calls.load(Ordering::SeqCst), 1);
2867 }
2868
2869 #[tokio::test]
2870 async fn stream_chat_with_history_skips_non_streaming_providers() {
2871 let non_streaming_calls = Arc::new(AtomicUsize::new(0));
2872 let streaming_calls = Arc::new(AtomicUsize::new(0));
2873
2874 let provider = ReliableProvider::new(
2875 vec![
2876 (
2877 "non-streaming".into(),
2878 Box::new(StreamingHistoryMock {
2879 stream_calls: Arc::clone(&non_streaming_calls),
2880 supports: false,
2881 }) as Box<dyn Provider>,
2882 ),
2883 (
2884 "streaming".into(),
2885 Box::new(StreamingHistoryMock {
2886 stream_calls: Arc::clone(&streaming_calls),
2887 supports: true,
2888 }) as Box<dyn Provider>,
2889 ),
2890 ],
2891 0,
2892 1,
2893 );
2894
2895 let messages = vec![ChatMessage::user("hello")];
2896 let mut stream =
2897 provider.stream_chat_with_history(&messages, "model", 0.0, StreamOptions::new(true));
2898
2899 let first = stream.next().await.unwrap().unwrap();
2900 assert_eq!(first.delta, "1");
2901 assert_eq!(
2902 non_streaming_calls.load(Ordering::SeqCst),
2903 0,
2904 "non-streaming provider should be skipped"
2905 );
2906 assert_eq!(
2907 streaming_calls.load(Ordering::SeqCst),
2908 1,
2909 "streaming provider should be used"
2910 );
2911 }
2912
2913 #[tokio::test]
2914 async fn stream_chat_with_history_errors_when_no_provider_supports_streaming() {
2915 let provider = ReliableProvider::new(
2916 vec![(
2917 "non-streaming".into(),
2918 Box::new(StreamingHistoryMock {
2919 stream_calls: Arc::new(AtomicUsize::new(0)),
2920 supports: false,
2921 }) as Box<dyn Provider>,
2922 )],
2923 0,
2924 1,
2925 );
2926
2927 let messages = vec![ChatMessage::user("hello")];
2928 let mut stream =
2929 provider.stream_chat_with_history(&messages, "model", 0.0, StreamOptions::new(true));
2930
2931 let first = stream.next().await.unwrap();
2932 let err = first.expect_err("should fail when no provider supports streaming");
2933 assert!(
2934 err.to_string().contains("No provider supports streaming"),
2935 "unexpected error: {err}"
2936 );
2937 assert!(stream.next().await.is_none());
2938 }
2939
2940 #[tokio::test]
2941 async fn fallback_records_provider_fallback_info() {
2942 scope_provider_fallback(async {
2943 let provider = ReliableProvider::new(
2944 vec![
2945 (
2946 "broken".into(),
2947 Box::new(MockProvider {
2948 calls: Arc::new(AtomicUsize::new(0)),
2949 fail_until_attempt: 99, response: "unused",
2951 error: "401 Unauthorized",
2952 }),
2953 ),
2954 (
2955 "working".into(),
2956 Box::new(MockProvider {
2957 calls: Arc::new(AtomicUsize::new(0)),
2958 fail_until_attempt: 0,
2959 response: "hello from working",
2960 error: "unused",
2961 }),
2962 ),
2963 ],
2964 2,
2965 1,
2966 );
2967
2968 let resp = provider.simple_chat("hi", "test-model", 0.0).await.unwrap();
2969 assert_eq!(resp, "hello from working");
2970
2971 let fb = take_last_provider_fallback();
2972 assert!(fb.is_some(), "fallback info should be recorded");
2973 let fb = fb.unwrap();
2974 assert_eq!(fb.requested_provider, "broken");
2975 assert_eq!(fb.actual_provider, "working");
2976 assert_eq!(fb.actual_model, "test-model");
2977
2978 assert!(take_last_provider_fallback().is_none());
2980 })
2981 .await;
2982 }
2983}