1use std::future::Future;
13use std::time::{SystemTime, UNIX_EPOCH};
14use tokio::time::sleep;
15
16use reqwest::RequestBuilder;
17
18use crate::error::AgentError;
19
20use super::retry_helpers::{
21 is_connection_error, is_max_tokens_overflow, is_rate_limit_error, is_server_error,
22 is_service_unavailable_error,
23};
24use std::fmt::Debug;
25
/// Default number of retries after the initial attempt; used by `RetryConfig::default`.
pub const DEFAULT_MAX_RETRIES: u32 = 10;

/// Starting backoff delay in milliseconds; `get_retry_delay` doubles it per attempt.
pub const BASE_DELAY_MS: u64 = 500;

/// Default ceiling, in milliseconds, for the exponential part of the backoff.
pub const MAX_DELAY_MS: u64 = 32_000;

/// Output-token floor. Not referenced in this part of the file.
/// NOTE(review): presumably the minimum `max_tokens` kept when shrinking a request after a
/// context overflow — confirm against the caller.
pub const FLOOR_OUTPUT_TOKENS: u32 = 3_000;

/// Number of consecutive 529 (overloaded) errors after which retrying is abandoned.
pub const MAX_529_RETRIES: u32 = 3;

/// 20 seconds, in milliseconds. Not referenced in this part of the file.
/// NOTE(review): presumably the cut-off below which a retry delay counts as "short" — confirm.
pub const SHORT_RETRY_THRESHOLD_MS: u64 = 20_000;

/// 30 minutes, in milliseconds. Not referenced in this part of the file.
/// NOTE(review): presumably how long a fast-mode fallback is held — confirm.
pub const DEFAULT_FAST_MODE_FALLBACK_HOLD_MS: u64 = 30 * 60 * 1_000;

/// 10 minutes, in milliseconds. Not referenced in this part of the file.
/// NOTE(review): presumably a lower bound for a cooldown period — confirm.
pub const MIN_COOLDOWN_MS: u64 = 10 * 60 * 1_000;
53
/// Settings that control how `with_retry` and `retry_post` retry a failing operation.
#[derive(Clone, Debug)]
pub struct RetryConfig {
    /// Retries allowed after the first attempt (so up to `max_retries + 1` attempts in total).
    pub max_retries: u32,
    /// Base backoff delay in milliseconds.
    /// NOTE(review): the retry loops in this file compute delays via `get_retry_delay`, which
    /// uses the `BASE_DELAY_MS` constant rather than this field — confirm whether it is read elsewhere.
    pub base_delay_ms: u64,
    /// Upper bound, in milliseconds, passed to `get_retry_delay` for the exponential backoff.
    pub max_delay_ms: u64,
    /// Whether jitter is requested.
    /// NOTE(review): not consulted by the retry loops in this file — confirm intended use.
    pub jitter: bool,
    /// When `false`, the retry loops give up on the first 529 (overloaded) error.
    pub is_foreground: bool,
    /// Model named in the error produced once `MAX_529_RETRIES` consecutive 529s are seen.
    pub fallback_model: Option<String>,
}
75
76impl Default for RetryConfig {
77 fn default() -> Self {
78 Self {
79 max_retries: DEFAULT_MAX_RETRIES,
80 base_delay_ms: BASE_DELAY_MS,
81 max_delay_ms: MAX_DELAY_MS,
82 jitter: true,
83 is_foreground: true,
84 fallback_model: None,
85 }
86 }
87}
88
89pub fn get_retry_delay(attempt: u32, retry_after_ms: Option<u64>, max_delay_ms: u64) -> u64 {
105 if let Some(retry_after) = retry_after_ms {
107 return retry_after;
108 }
109
110 let base_delay = if attempt == 0 {
112 BASE_DELAY_MS
113 } else {
114 BASE_DELAY_MS * 2u64.saturating_pow(attempt - 1)
115 };
116 let base_delay = base_delay.min(max_delay_ms);
117
118 if attempt > 0 {
120 base_delay + jitter(base_delay)
121 } else {
122 base_delay
123 }
124}
125
/// Returns a cheap pseudo-random fraction in `[0, 1)` derived from the sub-second
/// nanoseconds of the system clock.
///
/// If the clock reports a time before the Unix epoch, the duration defaults to zero and
/// the fraction is `0.0`.
fn jitter_fraction() -> f64 {
    // `subsec_nanos()` is always below one billion, so this is the correct divisor for a
    // value spanning the whole unit interval. Dividing by `u32::MAX` (about 4.29 billion)
    // would cap the result near 0.233.
    const NANOS_PER_SEC: f64 = 1_000_000_000.0;

    let nanos = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default()
        .subsec_nanos();

    f64::from(nanos) / NANOS_PER_SEC
}
134
135fn jitter(base_delay: u64) -> u64 {
137 (base_delay as f64 * 0.25 * jitter_fraction()) as u64
138}
139
/// Reports whether an error is an "overloaded" (529) error.
///
/// True when `status` is exactly 529, or when `message` contains the literal JSON fragment
/// `"type":"overloaded_error"` (no whitespace around the colon).
pub fn is_529_error(status: Option<u16>, message: &str) -> bool {
    const OVERLOADED_MARKER: &str = r#""type":"overloaded_error""#;

    matches!(status, Some(529)) || message.contains(OVERLOADED_MARKER)
}
159
160pub fn should_retry(status: Option<u16>, message: &str) -> bool {
176 let s = status;
177
178 if is_connection_error(message) {
180 return true;
181 }
182
183 if s == Some(529) {
185 return true;
186 }
187
188 if message.contains(r#""type":"overloaded_error""#) {
190 return true;
191 }
192
193 if s == Some(400) && is_max_tokens_overflow(message) {
195 return true;
196 }
197
198 if is_mock_rate_limit_error(message) {
200 return false;
201 }
202
203 let status_code = s.unwrap_or(0);
205
206 if status_code == 408 {
208 return true;
209 }
210
211 if status_code == 409 {
213 return true;
214 }
215
216 if status_code == 401 {
218 return true;
219 }
220
221 if status_code == 403 && message.contains("OAuth token has been revoked") {
223 return true;
224 }
225
226 if status_code == 429 {
229 return true;
230 }
231
232 if status_code >= 500 {
234 return true;
235 }
236
237 if is_rate_limit_error(message) {
239 return true;
240 }
241 if is_service_unavailable_error(message) {
242 return true;
243 }
244 if is_server_error(message) {
245 return true;
246 }
247
248 false
249}
250
/// Reports whether `message` carries one of the mock rate-limit markers
/// (`MOCK_RATE_LIMIT` or `mock rate limit`, both case-sensitive).
fn is_mock_rate_limit_error(message: &str) -> bool {
    ["MOCK_RATE_LIMIT", "mock rate limit"]
        .iter()
        .any(|marker| message.contains(marker))
}
255
/// The three numbers extracted from a "max tokens exceed context limit" error message
/// by `parse_max_tokens_overflow`.
#[derive(Clone, Debug)]
pub struct MaxTokensOverflowData {
    /// First number in the message (the input length, in tokens).
    pub input_tokens: u32,
    /// Second number in the message (the requested `max_tokens`).
    pub max_tokens: u32,
    /// Third number in the message (the context limit that was exceeded).
    pub context_limit: u32,
}
267
268pub fn parse_max_tokens_overflow(message: &str) -> Option<MaxTokensOverflowData> {
273 if !is_max_tokens_overflow(message) {
274 return None;
275 }
276
277 let numbers: Vec<u32> = message
279 .split(&['+', '>', ':', ' '][..])
280 .map(|s| s.trim().parse::<u32>().ok())
281 .filter_map(|n| n)
282 .collect();
283
284 if numbers.len() >= 3 {
285 Some(MaxTokensOverflowData {
286 input_tokens: numbers[0],
287 max_tokens: numbers[1],
288 context_limit: numbers[2],
289 })
290 } else {
291 None
292 }
293}
294
295pub fn extract_retry_after_ms(status: Option<u16>, message: &str) -> Option<u64> {
305 extract_retry_after_from_message(message)
308}
309
/// Looks for a `Retry-After: <seconds>` fragment in `message` (header name matched
/// ASCII-case-insensitively) and returns the delay converted to milliseconds.
///
/// Only the first occurrence of the header name is considered. The value is the first
/// whitespace-delimited token after the colon, with trailing ASCII punctuation (such as
/// `,` or `}`) ignored. Returns `None` when the header is absent or the value is not a
/// whole number of seconds.
fn extract_retry_after_from_message(message: &str) -> Option<u64> {
    const HEADER: &str = "retry-after:";

    // ASCII lowercasing keeps byte offsets identical to `message`. Full Unicode
    // lowercasing can change byte lengths, which would make the offset below point at
    // the wrong place (or at a non-boundary, panicking on the slice).
    let lowered = message.to_ascii_lowercase();
    let value_start = lowered.find(HEADER)? + HEADER.len();

    let token = message[value_start..].split_whitespace().next()?;
    let seconds = token
        .trim_end_matches(|c: char| c.is_ascii_punctuation())
        .parse::<u64>()
        .ok()?;

    // Saturate rather than overflow on absurdly large header values.
    Some(seconds.saturating_mul(1000))
}
330
/// Scans `message` for the first whitespace-separated token that is a bare integer in the
/// HTTP error range `400..=599` and returns it.
///
/// Tokens with attached punctuation (for example `429:`) do not parse and are skipped.
pub fn extract_status_from_message(message: &str) -> Option<u16> {
    message
        .split_whitespace()
        .filter_map(|token| token.parse::<u16>().ok())
        .find(|code| (400..=599).contains(code))
}
346
347pub async fn with_retry<F, Fut, T>(mut operation: F, config: RetryConfig) -> Result<T, AgentError>
368where
369 F: FnMut(u32) -> Fut,
370 Fut: Future<Output = Result<T, AgentError>>,
371{
372 let mut last_message: Option<String> = None;
373 let mut consecutive_529_errors: u32 = 0;
374
375 for attempt in 1..=config.max_retries + 1 {
376 match operation(attempt).await {
380 Ok(result) => {
381 if attempt > 1 {
382 log::debug!(
383 "[retry] Attempt {}/{} succeeded",
384 attempt,
385 config.max_retries + 1
386 );
387 }
388 return Ok(result);
389 }
390 Err(ref error) => {
391 let status = extract_status(error);
392 let message = error_to_message(error);
393
394 last_message = Some(message.clone());
395
396 log::debug!(
397 "[retry] Attempt {}/{}: status={:?} error={}",
398 attempt,
399 config.max_retries + 1,
400 status,
401 message.chars().take(200).collect::<String>()
402 );
403
404 if is_529_error(status, &message) {
406 consecutive_529_errors += 1;
407
408 if !config.is_foreground && consecutive_529_errors >= 1 {
410 log::debug!("[retry] 529 dropped for background request");
411 return Err(AgentError::Api(format!(
412 "Repeated 529 Overloaded errors: {}",
413 message
414 )));
415 }
416
417 if consecutive_529_errors >= MAX_529_RETRIES {
419 if let Some(ref fallback) = config.fallback_model {
420 return Err(AgentError::Api(format!(
421 "Model fallback triggered: exceeded {} consecutive 529s, switching to {}",
422 MAX_529_RETRIES, fallback
423 )));
424 }
425 return Err(AgentError::Api(format!(
426 "Repeated 529 Overloaded errors after {} retries: {}",
427 MAX_529_RETRIES, message
428 )));
429 }
430 } else {
431 consecutive_529_errors = 0;
433 }
434
435 if let Some(overflow) = parse_max_tokens_overflow(&message) {
437 log::debug!(
438 "[retry] Context overflow: input={} + max_tokens={} > limit={}",
439 overflow.input_tokens,
440 overflow.max_tokens,
441 overflow.context_limit
442 );
443 continue;
446 }
447
448 if attempt > config.max_retries {
450 if !should_retry(status, &message) {
452 log::debug!(
453 "[retry] Not retryable: status={:?} error={}",
454 status,
455 message.chars().take(100).collect::<String>()
456 );
457 return Err(AgentError::Api(
458 last_message
459 .take()
460 .unwrap_or_else(|| "Retry exhausted".to_string()),
461 ));
462 }
463 }
464
465 if attempt <= config.max_retries {
467 let retry_after_ms = extract_retry_after_ms(status, &message);
468 let delay = get_retry_delay(attempt, retry_after_ms, config.max_delay_ms);
469
470 log::debug!(
471 "[retry] Waiting {}ms before retry {}/{}",
472 delay,
473 attempt + 1,
474 config.max_retries + 1
475 );
476
477 sleep(std::time::Duration::from_millis(delay)).await;
478 }
479 }
480 }
481 }
482
483 Err(AgentError::Api(
484 last_message.unwrap_or_else(|| "Retry exhausted".to_string()),
485 ))
486}
487
488fn extract_status(error: &AgentError) -> Option<u16> {
490 match error {
491 AgentError::Http(e) => e.status().map(|s| s.as_u16()),
492 _ => extract_status_from_message(&error_to_message(error)),
493 }
494}
495
496fn error_to_message(error: &AgentError) -> String {
498 match error {
499 AgentError::Api(msg) => msg.clone(),
500 AgentError::Http(e) => format!("{}", e),
501 other => other.to_string(),
502 }
503}
504
505pub async fn retry_post(
517 builder: RequestBuilder,
518 config: RetryConfig,
519) -> Result<reqwest::Response, AgentError> {
520 let mut current_builder = builder;
521 let mut last_error_msg = String::new();
522 let mut consecutive_529_errors: u32 = 0;
523
524 for attempt in 1..=config.max_retries + 1 {
525 let send_builder = current_builder.try_clone().ok_or_else(|| {
527 AgentError::Api("Request builder cannot be cloned for retry".to_string())
528 })?;
529
530 match send_builder.send().await {
531 Ok(response) => {
532 if attempt > 1 {
533 log::debug!(
534 "[retry] POST attempt {}/{} succeeded",
535 attempt,
536 config.max_retries + 1
537 );
538 }
539 return Ok(response);
540 }
541 Err(error) => {
542 let status = error.status().map(|s| s.as_u16());
543 let message = format!("{}", error);
544
545 log::debug!(
546 "[retry] POST attempt {}/{}: status={:?} error={}",
547 attempt,
548 config.max_retries + 1,
549 status,
550 message.chars().take(200).collect::<String>()
551 );
552
553 last_error_msg = message.clone();
554
555 if is_529_error(status, &message) {
557 consecutive_529_errors += 1;
558
559 if !config.is_foreground && consecutive_529_errors >= 1 {
560 log::debug!("[retry] 529 dropped for background request");
561 return Err(AgentError::Api(format!(
562 "Repeated 529 Overloaded errors: {}",
563 message
564 )));
565 }
566
567 if consecutive_529_errors >= MAX_529_RETRIES {
568 if let Some(ref fallback) = config.fallback_model {
569 return Err(AgentError::Api(format!(
570 "Model fallback triggered: exceeded {} consecutive 529s, switching to {}",
571 MAX_529_RETRIES, fallback
572 )));
573 }
574 return Err(AgentError::Api(format!(
575 "Repeated 529 Overloaded errors after {} retries: {}",
576 MAX_529_RETRIES, message
577 )));
578 }
579 } else {
580 consecutive_529_errors = 0;
581 }
582
583 if parse_max_tokens_overflow(&message).is_some() {
585 log::debug!(
586 "[retry] Context overflow: input={} + max_tokens={} > limit={}",
587 parse_max_tokens_overflow(&message).unwrap().input_tokens,
588 parse_max_tokens_overflow(&message).unwrap().max_tokens,
589 parse_max_tokens_overflow(&message).unwrap().context_limit
590 );
591 continue;
592 }
593
594 if attempt > config.max_retries && !should_retry(status, &message) {
596 log::debug!("[retry] Not retryable: status={:?}", status);
597 return Err(AgentError::Api(message));
598 }
599
600 if attempt <= config.max_retries {
602 let retry_after_ms = extract_retry_after_ms(status, &message);
603 let delay = get_retry_delay(attempt, retry_after_ms, config.max_delay_ms);
604
605 log::debug!(
606 "[retry] Waiting {}ms before retry {}/{}",
607 delay,
608 attempt + 1,
609 config.max_retries + 1
610 );
611
612 sleep(std::time::Duration::from_millis(delay)).await;
613
614 current_builder = match current_builder.try_clone() {
616 Some(b) => b,
617 None => {
618 return Err(AgentError::Api(
619 "Request builder cannot be cloned for retry".to_string(),
620 ));
621 }
622 };
623 }
624 }
625 }
626 }
627
628 Err(AgentError::Api(last_error_msg))
629}
630
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Drives `future` to completion on a freshly created Tokio runtime.
    fn block_on<F: std::future::Future>(future: F) -> F::Output {
        let runtime = tokio::runtime::Runtime::new().unwrap();
        runtime.block_on(future)
    }

    // ---- should_retry: retryable statuses -------------------------------------------

    #[test]
    fn test_should_retry_401() {
        let retry = should_retry(Some(401), "authentication failed");
        assert!(retry);
    }

    #[test]
    fn test_should_retry_408() {
        let retry = should_retry(Some(408), "request timeout");
        assert!(retry);
    }

    #[test]
    fn test_should_retry_409() {
        let retry = should_retry(Some(409), "conflict");
        assert!(retry);
    }

    #[test]
    fn test_should_retry_429() {
        let retry = should_retry(Some(429), "rate limit exceeded");
        assert!(retry);
    }

    #[test]
    fn test_should_retry_500() {
        let retry = should_retry(Some(500), "internal server error");
        assert!(retry);
    }

    #[test]
    fn test_should_retry_502() {
        let retry = should_retry(Some(502), "bad gateway");
        assert!(retry);
    }

    #[test]
    fn test_should_retry_503() {
        let retry = should_retry(Some(503), "service unavailable");
        assert!(retry);
    }

    #[test]
    fn test_should_retry_529() {
        let retry = should_retry(Some(529), "overloaded");
        assert!(retry);
    }

    // ---- should_retry: message-based detection --------------------------------------

    #[test]
    fn test_should_retry_connection_error() {
        assert!(should_retry(None, "connection refused"));
        assert!(should_retry(None, "ECONNRESET"));
    }

    #[test]
    fn test_should_retry_529_via_message_body() {
        let body = r#"{"error":{"type":"overloaded_error","message":"server overloaded"}}"#;
        assert!(should_retry(None, body));
    }

    #[test]
    fn test_should_retry_rate_limit_via_string() {
        let text = "API error: Streaming API error 429 Too Many Requests";
        assert!(should_retry(None, text));
    }

    // ---- should_retry: non-retryable statuses ---------------------------------------

    #[test]
    fn test_should_not_retry_404() {
        let retry = should_retry(Some(404), "not found");
        assert!(!retry);
    }

    #[test]
    fn test_should_not_retry_400_non_overflow() {
        let retry = should_retry(Some(400), "bad request");
        assert!(!retry);
    }

    #[test]
    fn test_should_not_retry_403_non_revoked() {
        let retry = should_retry(Some(403), "forbidden");
        assert!(!retry);
    }

    // ---- is_529_error ----------------------------------------------------------------

    #[test]
    fn test_is_529_error_by_status() {
        assert!(is_529_error(Some(529), "any message"));
        assert!(!is_529_error(Some(500), "any message"));
        assert!(!is_529_error(None, "any message"));
    }

    #[test]
    fn test_is_529_error_by_message_body() {
        let body = r#"{"error":{"type":"overloaded_error"}}"#;
        assert!(is_529_error(None, body));
        assert!(!is_529_error(None, "normal error"));
    }

    // ---- get_retry_delay -------------------------------------------------------------

    #[test]
    fn test_get_retry_delay_exponential() {
        let cap = MAX_DELAY_MS;

        // First retry: base delay plus at most 25% jitter.
        let first = get_retry_delay(1, None, cap);
        assert!(
            first >= BASE_DELAY_MS
                && first < BASE_DELAY_MS + (BASE_DELAY_MS as f64 * 0.25) as u64 + 1
        );

        let second = get_retry_delay(2, None, cap);
        assert!(second >= BASE_DELAY_MS * 2);

        let fourth = get_retry_delay(4, None, cap);
        assert!(fourth >= BASE_DELAY_MS * 8);
    }

    #[test]
    fn test_get_retry_delay_cap() {
        // Far beyond the cap: the result may exceed it only by the jitter allowance.
        let delay = get_retry_delay(20, None, MAX_DELAY_MS);
        assert!(delay <= MAX_DELAY_MS + (MAX_DELAY_MS as f64 * 0.25) as u64);
    }

    #[test]
    fn test_get_retry_delay_retry_after_override() {
        // A server-provided delay is returned verbatim.
        assert_eq!(get_retry_delay(5, Some(30_000), MAX_DELAY_MS), 30_000);
        assert_eq!(get_retry_delay(1, Some(1_000), MAX_DELAY_MS), 1_000);
    }

    // ---- message parsing helpers -----------------------------------------------------

    #[test]
    fn test_extract_retry_after_from_message() {
        let thirty = extract_retry_after_from_message("error Retry-After: 30");
        assert_eq!(thirty, Some(30_000));

        let sixty = extract_retry_after_from_message("error Retry-After: 60");
        assert_eq!(sixty, Some(60_000));

        let missing = extract_retry_after_from_message("no header here");
        assert_eq!(missing, None);
    }

    #[test]
    fn test_extract_status_from_message() {
        let rate_limited = extract_status_from_message("429 Too Many Requests");
        assert_eq!(rate_limited, Some(429));

        let internal = extract_status_from_message("500 Internal Server Error");
        assert_eq!(internal, Some(500));

        let unavailable = extract_status_from_message("error: 503 service unavailable");
        assert_eq!(unavailable, Some(503));

        let none = extract_status_from_message("no status here");
        assert_eq!(none, None);
    }

    #[test]
    fn test_parse_max_tokens_overflow() {
        let parsed = parse_max_tokens_overflow(
            "input length and `max_tokens` exceed context limit: 188059 + 20000 > 200000",
        );
        assert!(parsed.is_some());

        let overflow = parsed.unwrap();
        assert_eq!(overflow.input_tokens, 188059);
        assert_eq!(overflow.max_tokens, 20000);
        assert_eq!(overflow.context_limit, 200000);
    }

    #[test]
    fn test_parse_max_tokens_overflow_fails() {
        let parsed = parse_max_tokens_overflow("prompt too long");
        assert!(parsed.is_none());
    }

    // ---- with_retry ------------------------------------------------------------------

    #[test]
    fn test_with_retry_success() {
        let calls = AtomicU32::new(0);
        let operation = |_attempt: u32| {
            let calls = &calls;
            async move {
                calls.fetch_add(1, Ordering::SeqCst);
                Ok::<_, AgentError>("success")
            }
        };

        let outcome = block_on(with_retry(operation, RetryConfig::default()));
        assert!(outcome.is_ok());
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_with_retry_success_after_fails() {
        let calls = AtomicU32::new(0);
        let operation = |_attempt: u32| {
            let calls = &calls;
            async move {
                // The first two calls fail, the third succeeds.
                let previous = calls.fetch_add(1, Ordering::SeqCst);
                if previous < 2 {
                    Err(AgentError::Api("temporary error".to_string()))
                } else {
                    Ok::<_, AgentError>("success")
                }
            }
        };

        let outcome = block_on(with_retry(operation, RetryConfig::default()));
        assert!(outcome.is_ok());
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    #[test]
    fn test_with_retry_exhausted() {
        let calls = AtomicU32::new(0);
        let operation = |_attempt: u32| {
            let calls = &calls;
            async move {
                calls.fetch_add(1, Ordering::SeqCst);
                Err::<String, AgentError>(AgentError::Api("persistent error".to_string()))
            }
        };

        let config = RetryConfig {
            max_retries: 2,
            ..Default::default()
        };
        let outcome = block_on(with_retry(operation, config));
        assert!(outcome.is_err());
        // One initial attempt plus two retries.
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    #[test]
    fn test_with_retry_rate_limit_retries() {
        let calls = AtomicU32::new(0);
        let operation = |_attempt: u32| {
            let calls = &calls;
            async move {
                let previous = calls.fetch_add(1, Ordering::SeqCst);
                if previous < 2 {
                    Err(AgentError::Api(
                        "API error: Streaming API error 429 Too Many Requests".to_string(),
                    ))
                } else {
                    Ok::<_, AgentError>("success")
                }
            }
        };

        let config = RetryConfig {
            max_retries: 3,
            ..Default::default()
        };
        let outcome = block_on(with_retry(operation, config));
        assert!(outcome.is_ok());
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }
}