1use std::error::Error as StdError;
25use std::fmt;
26use std::io;
27use std::ops::ControlFlow;
28use std::time::Duration;
29
/// Exponential-backoff schedule used by [`retry_sync`] and [`retry_async`].
#[derive(Debug, Clone, Copy)]
pub struct RetryPolicy {
    /// Total number of attempts, including the first one. The retry helpers
    /// treat values below 1 as 1, so at least one attempt is always made.
    pub max_attempts: u32,
    /// Delay before the second attempt; each later attempt doubles it.
    pub base_delay: Duration,
    /// Upper bound on any single delay returned by [`RetryPolicy::delay_for`].
    pub max_delay: Duration,
}

impl RetryPolicy {
    /// Policy for uploads: up to 10 attempts, starting at 50 ms between
    /// attempts and never waiting more than 30 s.
    pub const UPLOAD: RetryPolicy = RetryPolicy {
        max_attempts: 10,
        base_delay: Duration::from_millis(50),
        max_delay: Duration::from_secs(30),
    };

    /// Returns how long to wait before attempt number `next_attempt` (1-based).
    ///
    /// The delay before attempt 2 is `base_delay`; every further attempt
    /// doubles it. The result saturates instead of overflowing and is capped at
    /// `max_delay`. Attempts 0 and 1 also yield `base_delay` (capped); the
    /// retry helpers never sleep before the first attempt.
    ///
    /// The arithmetic is done in nanoseconds, so sub-millisecond (or
    /// fractional-millisecond) base delays are honoured instead of being
    /// truncated to whole milliseconds.
    pub fn delay_for(&self, next_attempt: u32) -> Duration {
        const NANOS_PER_SEC: u128 = 1_000_000_000;

        // Doubling starts at attempt 2: attempt 2 => 2^0 * base, attempt 3 => 2^1 * base, ...
        let exp = next_attempt.saturating_sub(2);
        // Shifting by 128 or more would overflow; saturate the multiplier instead.
        let mult = 1u128.checked_shl(exp).unwrap_or(u128::MAX);
        let nanos = self
            .base_delay
            .as_nanos()
            .saturating_mul(mult)
            .min(self.max_delay.as_nanos());
        // `nanos <= max_delay.as_nanos()`, so the whole-second part fits in a u64
        // and the sub-second remainder (< 1e9) fits in a u32 without carrying.
        Duration::new((nanos / NANOS_PER_SEC) as u64, (nanos % NANOS_PER_SEC) as u32)
    }
}
72
73pub fn retry_sync<T, E, F>(policy: &RetryPolicy, mut op: F) -> Result<T, E>
82where
83 F: FnMut(u32) -> Result<T, ControlFlow<E, E>>,
84{
85 let max = policy.max_attempts.max(1);
86 let mut attempt: u32 = 1;
87 loop {
88 if attempt > 1 {
89 std::thread::sleep(policy.delay_for(attempt));
90 }
91 match op(attempt) {
92 Ok(v) => return Ok(v),
93 Err(ControlFlow::Break(e)) => return Err(e),
94 Err(ControlFlow::Continue(e)) => {
95 if attempt >= max {
96 return Err(e);
97 }
98 }
99 }
100 attempt += 1;
101 }
102}
103
104pub async fn retry_async<T, E, F, Fut>(policy: &RetryPolicy, mut op: F) -> Result<T, E>
108where
109 F: FnMut(u32) -> Fut,
110 Fut: std::future::Future<Output = Result<T, ControlFlow<E, E>>>,
111{
112 let max = policy.max_attempts.max(1);
113 let mut attempt: u32 = 1;
114 loop {
115 if attempt > 1 {
116 tokio::time::sleep(policy.delay_for(attempt)).await;
117 }
118 match op(attempt).await {
119 Ok(v) => return Ok(v),
120 Err(ControlFlow::Break(e)) => return Err(e),
121 Err(ControlFlow::Continue(e)) => {
122 if attempt >= max {
123 return Err(e);
124 }
125 }
126 }
127 attempt += 1;
128 }
129}
130
/// Which HTTP status codes the retrying HTTP helpers accept as success.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SuccessClass {
    /// Only 2xx statuses count as success.
    Strict,
    /// Both 2xx and 3xx statuses count as success, e.g. when the client is
    /// configured not to follow redirects and the redirect itself is the answer.
    AllowRedirects,
}
144
145pub fn retry_http_blocking<F, M>(
172 label: &str,
173 policy: &RetryPolicy,
174 success_class: SuccessClass,
175 mut send: F,
176 error_msg: M,
177) -> anyhow::Result<(reqwest::StatusCode, String)>
178where
179 F: FnMut(u32) -> Result<reqwest::blocking::Response, reqwest::Error>,
180 M: Fn(reqwest::StatusCode, &str) -> String,
181{
182 use anyhow::Context as _;
183 retry_sync(policy, |attempt| {
184 match send(attempt) {
185 Ok(resp) => {
186 let status = resp.status();
187 let succeeded = match success_class {
188 SuccessClass::Strict => status.is_success(),
189 SuccessClass::AllowRedirects => status.is_success() || status.is_redirection(),
190 };
191 let body = resp
192 .text()
193 .unwrap_or_else(|e| format!("<failed to read body: {e}>"));
194 if succeeded {
195 Ok((status, body))
196 } else {
197 let msg = error_msg(status, &body);
198 let inner = anyhow::anyhow!("{msg}");
199 let wrapped = anyhow::Error::new(HttpError::new(
200 std::io::Error::other(inner.to_string()),
201 status.as_u16(),
202 ))
203 .context(inner);
204 if is_retriable(wrapped.as_ref()) {
210 Err(ControlFlow::Continue(wrapped))
211 } else {
212 Err(ControlFlow::Break(wrapped))
213 }
214 }
215 }
216 Err(e) => {
217 let err = anyhow::Error::new(HttpError::from_response(e, None))
221 .context(format!("{label}: HTTP transport error"));
222 if is_retriable(err.as_ref()) {
223 Err(ControlFlow::Continue(err))
224 } else {
225 Err(ControlFlow::Break(err))
226 }
227 }
228 }
229 })
230 .with_context(|| format!("{label}: exhausted retry attempts"))
231}
232
233pub async fn retry_http_async<F, Fut, M>(
254 label: &str,
255 policy: &RetryPolicy,
256 success_class: SuccessClass,
257 mut send: F,
258 error_msg: M,
259) -> anyhow::Result<reqwest::Response>
260where
261 F: FnMut(u32) -> Fut,
262 Fut: std::future::Future<Output = Result<reqwest::Response, reqwest::Error>>,
263 M: Fn(reqwest::StatusCode, &str) -> String,
264{
265 use anyhow::Context as _;
266 retry_async(policy, |attempt| {
267 let fut = send(attempt);
268 let error_msg = &error_msg;
269 async move {
270 match fut.await {
271 Ok(resp) => {
272 let status = resp.status();
273 let succeeded = match success_class {
274 SuccessClass::Strict => status.is_success(),
275 SuccessClass::AllowRedirects => {
276 status.is_success() || status.is_redirection()
277 }
278 };
279 if succeeded {
280 Ok(resp)
281 } else {
282 let body = resp
283 .text()
284 .await
285 .unwrap_or_else(|e| format!("<failed to read body: {e}>"));
286 let msg = error_msg(status, &body);
287 let inner = anyhow::anyhow!("{msg}");
288 let wrapped = anyhow::Error::new(HttpError::new(
289 std::io::Error::other(inner.to_string()),
290 status.as_u16(),
291 ))
292 .context(inner);
293 if is_retriable(wrapped.as_ref()) {
299 Err(ControlFlow::Continue(wrapped))
300 } else {
301 Err(ControlFlow::Break(wrapped))
302 }
303 }
304 }
305 Err(e) => {
306 let err = anyhow::Error::new(HttpError::from_response(e, None))
310 .context(format!("{label}: HTTP transport error"));
311 if is_retriable(err.as_ref()) {
312 Err(ControlFlow::Continue(err))
313 } else {
314 Err(ControlFlow::Break(err))
315 }
316 }
317 }
318 }
319 })
320 .await
321 .with_context(|| format!("{label}: exhausted retry attempts"))
322}
323
324pub fn classify_http_sync(
332 result: reqwest::Result<reqwest::blocking::Response>,
333) -> Result<reqwest::blocking::Response, ControlFlow<anyhow::Error, anyhow::Error>> {
334 use anyhow::anyhow;
335 match result {
336 Ok(resp) => {
337 let status = resp.status();
338 if status.is_success() || status.is_redirection() {
339 Ok(resp)
340 } else if status.is_server_error() {
341 Err(ControlFlow::Continue(anyhow!(
342 "HTTP {} {}",
343 status.as_u16(),
344 status.canonical_reason().unwrap_or("server error")
345 )))
346 } else {
347 Err(ControlFlow::Break(anyhow!(
349 "HTTP {} {}",
350 status.as_u16(),
351 status.canonical_reason().unwrap_or("client error")
352 )))
353 }
354 }
355 Err(e) => Err(ControlFlow::Continue(anyhow!(e))),
357 }
358}
359
/// An error tagged with the HTTP status code of the response that caused it.
///
/// `status` is `0` when no response was available (for example a transport
/// failure passed to [`HttpError::from_response`] with `None`).
/// [`is_retriable`] treats statuses 429 and >= 500 as retriable.
#[derive(Debug)]
pub struct HttpError {
    /// The wrapped error; its message is what this error displays.
    source: Box<dyn StdError + Send + Sync>,
    /// HTTP status code, or `0` when there was no response.
    pub status: u16,
}
383
384impl HttpError {
385 pub fn new<E>(source: E, status: u16) -> Self
388 where
389 E: StdError + Send + Sync + 'static,
390 {
391 Self {
392 source: Box::new(source),
393 status,
394 }
395 }
396
397 pub fn from_response<E>(err: E, resp: Option<&reqwest::Response>) -> Self
401 where
402 E: StdError + Send + Sync + 'static,
403 {
404 Self::new(err, resp.map(|r| r.status().as_u16()).unwrap_or(0))
405 }
406}
407
408impl fmt::Display for HttpError {
409 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
410 fmt::Display::fmt(&self.source, f)
413 }
414}
415
416impl StdError for HttpError {
417 fn source(&self) -> Option<&(dyn StdError + 'static)> {
418 Some(&*self.source)
419 }
420}
421
/// Marker wrapper that makes [`is_retriable`] classify an error as retriable,
/// whatever it wraps (including an [`HttpError`] with a 4xx status).
#[derive(Debug)]
pub struct Retriable(Box<dyn StdError + Send + Sync>);
429
430impl Retriable {
431 pub fn new<E>(source: E) -> Self
437 where
438 E: StdError + Send + Sync + 'static,
439 {
440 Self(Box::new(source))
441 }
442}
443
444impl fmt::Display for Retriable {
445 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
446 fmt::Display::fmt(&self.0, f)
447 }
448}
449
450impl StdError for Retriable {
451 fn source(&self) -> Option<&(dyn StdError + 'static)> {
452 Some(&*self.0)
453 }
454}
455
456pub fn is_network_error(err: &(dyn StdError + 'static)) -> bool {
485 let mut cur: Option<&(dyn StdError + 'static)> = Some(err);
486 while let Some(e) = cur {
487 if let Some(io_err) = e.downcast_ref::<io::Error>() {
490 match io_err.kind() {
491 io::ErrorKind::UnexpectedEof
492 | io::ErrorKind::TimedOut
493 | io::ErrorKind::ConnectionRefused
494 | io::ErrorKind::ConnectionReset
495 | io::ErrorKind::ConnectionAborted
496 | io::ErrorKind::BrokenPipe => return true,
497 _ => {}
498 }
499 let m = io_err.to_string().to_lowercase();
500 if m == "eof" || m == "unexpected eof" {
501 return true;
502 }
503 }
504
505 let s = e.to_string().to_lowercase();
509 if NETWORK_ERROR_NEEDLES.iter().any(|n| s.contains(n)) {
510 return true;
511 }
512
513 cur = e.source();
514 }
515 false
516}
517
/// Lowercase substrings that mark an error message as a network failure.
///
/// `is_network_error` lowercases each message before searching, so every
/// entry here must itself be lowercase. Order does not affect the result.
const NETWORK_ERROR_NEEDLES: &[&'static str] = &[
    // Connection-level failures.
    "connection reset",
    "network is unreachable",
    "connection closed",
    "connection refused",
    // Timeouts (some phrased like Go net/http messages).
    "tls handshake timeout",
    "i/o timeout",
    "broken pipe",
    "timeout awaiting response headers",
    "context deadline exceeded",
    "operation timed out",
    // Phrasings that appear to come from Windows system messages.
    "the network connection was aborted",
    "an existing connection was forcibly closed",
    // Name-resolution failures.
    "dns error",
    "failed to lookup address",
    "no such host is known",
];
556
557pub fn is_retriable(err: &(dyn StdError + 'static)) -> bool {
569 let mut cur: Option<&(dyn StdError + 'static)> = Some(err);
571 while let Some(e) = cur {
572 if e.is::<Retriable>() {
573 return true;
574 }
575 if let Some(http) = e.downcast_ref::<HttpError>()
576 && (http.status >= 500 || http.status == 429)
577 {
578 return true;
579 }
580 cur = e.source();
581 }
582
583 is_network_error(err)
585}
586
587pub fn is_retriable_opt(err: Option<&(dyn StdError + 'static)>) -> bool {
590 err.is_some_and(is_retriable)
591}
592
#[cfg(test)]
mod tests {
    //! Unit tests for the retry loops, the backoff schedule, the error
    //! classifiers, and the HTTP retry wrappers (driven against canned raw
    //! HTTP responses served from a local TCP listener).

    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Policy with tiny delays so retry-loop tests finish quickly:
    /// 4 attempts, 1 ms base delay, 5 ms cap.
    fn fast_policy() -> RetryPolicy {
        RetryPolicy {
            max_attempts: 4,
            base_delay: Duration::from_millis(1),
            max_delay: Duration::from_millis(5),
        }
    }

    /// The delay doubles from `base_delay` starting at attempt 2 and is
    /// capped at `max_delay`.
    #[test]
    fn delay_progression_caps_at_max() {
        let p = RetryPolicy {
            max_attempts: 10,
            base_delay: Duration::from_millis(100),
            max_delay: Duration::from_millis(500),
        };
        assert_eq!(p.delay_for(2), Duration::from_millis(100));
        assert_eq!(p.delay_for(3), Duration::from_millis(200));
        assert_eq!(p.delay_for(4), Duration::from_millis(400));
        // 800 ms and beyond are clamped to the 500 ms cap.
        assert_eq!(p.delay_for(5), Duration::from_millis(500));
        assert_eq!(p.delay_for(8), Duration::from_millis(500));
    }

    /// A first-attempt success calls `op` exactly once.
    #[test]
    fn sync_succeeds_on_first_attempt() {
        let calls = AtomicU32::new(0);
        let result: Result<&str, ()> = retry_sync(&fast_policy(), |_| {
            calls.fetch_add(1, Ordering::SeqCst);
            Ok("ok")
        });
        assert_eq!(result, Ok("ok"));
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    /// `Continue` errors are retried until `op` succeeds on attempt 3.
    #[test]
    fn sync_retries_until_success() {
        let calls = AtomicU32::new(0);
        let result: Result<u32, &str> = retry_sync(&fast_policy(), |attempt| {
            calls.fetch_add(1, Ordering::SeqCst);
            if attempt < 3 {
                Err(ControlFlow::Continue("transient"))
            } else {
                Ok(attempt)
            }
        });
        assert_eq!(result, Ok(3));
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    /// `Break` aborts the loop after a single call.
    #[test]
    fn sync_break_stops_immediately() {
        let calls = AtomicU32::new(0);
        let result: Result<(), &str> = retry_sync(&fast_policy(), |_| {
            calls.fetch_add(1, Ordering::SeqCst);
            Err(ControlFlow::Break("fatal"))
        });
        assert_eq!(result, Err("fatal"));
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    /// When every attempt fails transiently, the error from the final
    /// (4th) attempt is returned.
    #[test]
    fn sync_returns_last_error_after_exhaustion() {
        let calls = AtomicU32::new(0);
        let result: Result<(), String> = retry_sync(&fast_policy(), |attempt| {
            calls.fetch_add(1, Ordering::SeqCst);
            Err(ControlFlow::Continue(format!("fail {attempt}")))
        });
        assert_eq!(result, Err("fail 4".to_string()));
        assert_eq!(calls.load(Ordering::SeqCst), 4);
    }

    /// Async loop: one transient failure, then success on attempt 2.
    #[tokio::test]
    async fn async_retries_until_success() {
        let calls = std::sync::Arc::new(AtomicU32::new(0));
        let calls_inner = calls.clone();
        let result: Result<u32, &str> = retry_async(&fast_policy(), move |attempt| {
            // Each future gets its own handle to the shared counter.
            let c = calls_inner.clone();
            async move {
                c.fetch_add(1, Ordering::SeqCst);
                if attempt < 2 {
                    Err(ControlFlow::Continue("transient"))
                } else {
                    Ok(attempt)
                }
            }
        })
        .await;
        assert_eq!(result, Ok(2));
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }

    /// Minimal error with a static message and no source.
    #[derive(Debug)]
    struct StrErr(&'static str);
    impl fmt::Display for StrErr {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            f.write_str(self.0)
        }
    }
    impl StdError for StrErr {}

    /// Minimal error with an owned message and no source (not an `io::Error`,
    /// so only the substring rules of `is_network_error` apply).
    #[derive(Debug)]
    struct OwnedErr(String);
    impl fmt::Display for OwnedErr {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            f.write_str(&self.0)
        }
    }
    impl StdError for OwnedErr {}

    /// Messages containing any needle match, case-insensitively.
    #[test]
    fn network_error_substrings_match() {
        for s in [
            "connection reset by peer",
            "network is unreachable",
            "connection closed unexpectedly",
            "connection refused",
            "tls handshake timeout",
            "i/o timeout",
            "CONNECTION RESET",
            "TLS Handshake Timeout",
            "write: broken pipe",
            "net/http: timeout awaiting response headers",
            "context deadline exceeded",
            "client error (Connect): dns error: failed to lookup address information: Name or service not known",
            "dns error: nodename nor servname provided, or not known",
            "dns error: No such host is known. (os error 11001)",
        ] {
            let e = OwnedErr(s.to_string());
            assert!(is_network_error(&e), "expected network error: {s:?}");
        }
    }

    /// EOF detection works both by `io::ErrorKind` and by the bare "EOF" message.
    #[test]
    fn network_error_io_eof_kinds() {
        let e = io::Error::from(io::ErrorKind::UnexpectedEof);
        assert!(is_network_error(&e));

        let e2 = io::Error::other("EOF");
        assert!(is_network_error(&e2));
    }

    /// `TimedOut` io errors are network errors and therefore retriable.
    #[test]
    fn is_network_error_classifies_io_timedout() {
        let e = io::Error::from(io::ErrorKind::TimedOut);
        assert!(is_network_error(&e));
        assert!(is_retriable(&e));
    }

    /// `ConnectionRefused` io errors are network errors and retriable.
    #[test]
    fn is_network_error_classifies_io_connection_refused() {
        let e = io::Error::from(io::ErrorKind::ConnectionRefused);
        assert!(is_network_error(&e));
        assert!(is_retriable(&e));
    }

    /// `ConnectionReset` io errors are network errors and retriable.
    #[test]
    fn is_network_error_classifies_io_connection_reset() {
        let e = io::Error::from(io::ErrorKind::ConnectionReset);
        assert!(is_network_error(&e));
        assert!(is_retriable(&e));
    }

    /// `ConnectionAborted` io errors are network errors and retriable.
    #[test]
    fn is_network_error_classifies_io_connection_aborted() {
        let e = io::Error::from(io::ErrorKind::ConnectionAborted);
        assert!(is_network_error(&e));
        assert!(is_retriable(&e));
    }

    /// `BrokenPipe` io errors are network errors and retriable.
    #[test]
    fn is_network_error_classifies_io_broken_pipe() {
        let e = io::Error::from(io::ErrorKind::BrokenPipe);
        assert!(is_network_error(&e));
        assert!(is_retriable(&e));
    }

    /// "operation timed out" matches through the substring list even with
    /// kind `Other`; the `TimedOut` kind matches on its own.
    #[test]
    fn is_network_error_classifies_operation_timed_out_substring() {
        let other_kind = io::Error::other("operation timed out");
        assert!(is_network_error(&other_kind));
        assert!(is_retriable(&other_kind));

        let kind_only = io::Error::from(io::ErrorKind::TimedOut);
        assert!(is_network_error(&kind_only));
        assert!(is_retriable(&kind_only));
    }

    /// The classifier walks `source()`: a wrapper with an unrelated message
    /// still matches through its inner `UnexpectedEof`.
    #[test]
    fn network_error_wrapped_unexpected_eof() {
        #[derive(Debug)]
        struct Wrap(io::Error);
        impl fmt::Display for Wrap {
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                write!(f, "read failed")
            }
        }
        impl StdError for Wrap {
            fn source(&self) -> Option<&(dyn StdError + 'static)> {
                Some(&self.0)
            }
        }
        let inner = io::Error::from(io::ErrorKind::UnexpectedEof);
        let outer = Wrap(inner);
        assert!(is_network_error(&outer));
    }

    /// Non-network messages, including a bare "no such host" (only the
    /// longer "no such host is known" is a needle), are rejected.
    #[test]
    fn network_error_non_network_strings_reject() {
        for s in [
            "file not found",
            "permission denied",
            "dial tcp: lookup example.com: no such host",
            "",
        ] {
            let e = OwnedErr(s.to_string());
            assert!(!is_network_error(&e), "expected NOT network error: {s:?}");
        }
    }

    /// `None` is never retriable.
    #[test]
    fn retriable_opt_nil_passthrough() {
        assert!(!is_retriable_opt(None));
    }

    /// HTTP 500 is retriable.
    #[test]
    fn http_error_500_retriable() {
        let e = HttpError::new(StrErr("internal server error"), 500);
        assert!(is_retriable(&e));
    }

    /// Other 5xx statuses are retriable too.
    #[test]
    fn http_error_502_503_retriable() {
        for s in [502u16, 503] {
            let e = HttpError::new(StrErr("bad gateway"), s);
            assert!(is_retriable(&e), "status {s} should be retriable");
        }
    }

    /// 429 Too Many Requests is retriable.
    #[test]
    fn http_error_429_retriable() {
        let e = HttpError::new(StrErr("rate limited"), 429);
        assert!(is_retriable(&e));
    }

    /// Ordinary 4xx statuses are not retriable.
    #[test]
    fn http_error_4xx_not_retriable() {
        for s in [400u16, 401, 403, 404, 422] {
            let e = HttpError::new(StrErr("client err"), s);
            assert!(!is_retriable(&e), "status {s} should NOT be retriable");
        }
    }

    /// With status 0 (no response), classification falls back to the message.
    #[test]
    fn http_error_zero_status_routes_via_message() {
        let net = HttpError::new(StrErr("connection reset"), 0);
        assert!(is_retriable(&net));

        let non_net = HttpError::new(StrErr("dial failed"), 0);
        assert!(!is_retriable(&non_net));
    }

    /// `HttpError` exposes its wrapped error via `source()`.
    #[test]
    fn http_error_unwrap_chain_visible() {
        let inner = StrErr("inner");
        let e = HttpError::new(inner, 503);
        assert!(e.source().is_some());
    }

    /// `from_response` with no response records status 0.
    #[test]
    fn from_response_nil_resp_yields_status_zero() {
        let inner = io::Error::other("connect: dial tcp");
        let e = HttpError::from_response(inner, None);
        assert_eq!(e.status, 0);
    }

    /// `from_response` keeps the inner error reachable, so a network-looking
    /// inner message makes the whole error retriable.
    #[test]
    fn from_response_unwrap_chain_visible() {
        let inner = io::Error::other("connection reset by peer");
        let e = HttpError::from_response(inner, None);
        assert!(
            e.source().is_some(),
            "inner error must be reachable via source()"
        );
        assert!(is_retriable(&e));
    }

    /// The `Retriable` marker is retriable on its own.
    #[test]
    fn retriable_wrapper_is_retriable() {
        let e = Retriable::new(StrErr("retry me"));
        assert!(is_retriable(&e));
    }

    /// The `Retriable` marker wins over an inner non-retriable 4xx.
    #[test]
    fn retriable_wrapper_overrides_4xx() {
        let inner = HttpError::new(StrErr("exists"), 422);
        let outer = Retriable::new(inner);
        assert!(is_retriable(&outer));
    }

    /// `Retriable` exposes its wrapped error via `source()`.
    #[test]
    fn retriable_wrapper_unwrap_chain_visible() {
        let inner = StrErr("inner");
        let e = Retriable::new(inner);
        assert!(e.source().is_some());
    }

    /// An unrelated error with no status or marker is not retriable.
    #[test]
    fn plain_error_not_retriable() {
        let e = StrErr("something");
        assert!(!is_retriable(&e));
    }

    /// `anyhow::Error::as_ref()` can be passed straight to `is_retriable`.
    #[test]
    fn anyhow_error_threadable() {
        let e: anyhow::Error = anyhow::anyhow!("connection refused");
        assert!(is_retriable(e.as_ref()));

        let e2: anyhow::Error = anyhow::anyhow!("permission denied");
        assert!(!is_retriable(e2.as_ref()));
    }

    /// An `HttpError` below an anyhow context layer is still found.
    #[test]
    fn is_retriable_chain_walks_to_http_error() {
        let inner = HttpError::new(StrErr("bad gateway"), 503);
        let wrapped: anyhow::Error = anyhow::Error::new(inner).context("publish failed");
        assert!(is_retriable(wrapped.as_ref()));
    }

    /// A 5xx behind anyhow context is detected when classifying via `as_ref()`.
    #[test]
    fn classifier_5xx_via_anyhow_chain_uses_as_ref() {
        let wrapped: anyhow::Error =
            anyhow::Error::new(HttpError::new(std::io::Error::other("503"), 503))
                .context("publish");
        assert!(
            is_retriable(wrapped.as_ref()),
            "5xx HttpError reached via as_ref() must classify retriable"
        );
    }

    /// Drift guard: `root_cause()` is the inner io::Error, below the
    /// `HttpError`, so the status is invisible there and callers must use
    /// `as_ref()` instead.
    #[test]
    fn classifier_root_cause_walks_past_http_error_drift_guard() {
        let wrapped: anyhow::Error =
            anyhow::Error::new(HttpError::new(std::io::Error::other("503"), 503))
                .context("publish");
        assert!(
            !is_retriable(wrapped.root_cause()),
            "root_cause() walks past HttpError; 5xx must NOT be detected via the leaf"
        );
    }

    /// Same `as_ref()` vs `root_cause()` contrast for 429.
    #[test]
    fn classifier_429_via_anyhow_chain_uses_as_ref() {
        let wrapped: anyhow::Error =
            anyhow::Error::new(HttpError::new(std::io::Error::other("429"), 429))
                .context("publish");
        assert!(is_retriable(wrapped.as_ref()));
        assert!(!is_retriable(wrapped.root_cause()));
    }

    /// Starts a background thread listening on an ephemeral 127.0.0.1 port.
    /// Each canned raw HTTP response is served to one accepted connection,
    /// in order, after which the thread exits and the listener is dropped.
    /// Returns the listening address and a counter of accepted connections.
    fn spawn_oneshot_http_responder(
        responses: Vec<&'static str>,
    ) -> (std::net::SocketAddr, std::sync::Arc<AtomicU32>) {
        use std::io::{Read, Write};
        use std::net::TcpListener;

        let listener = TcpListener::bind("127.0.0.1:0").expect("bind ephemeral port");
        let addr = listener.local_addr().expect("local_addr");
        let counter = std::sync::Arc::new(AtomicU32::new(0));
        let counter_inner = counter.clone();
        std::thread::spawn(move || {
            for (i, resp) in responses.iter().enumerate() {
                let (mut stream, _) = match listener.accept() {
                    Ok(pair) => pair,
                    Err(_) => return,
                };
                counter_inner.fetch_add(1, Ordering::SeqCst);
                let mut buf = [0u8; 8192];
                // Consume (part of) the request once; its content is ignored,
                // and the timeout keeps a silent client from hanging the thread.
                let _ = stream.set_read_timeout(Some(Duration::from_millis(500)));
                let _ = stream.read(&mut buf);
                let _ = stream.write_all(resp.as_bytes());
                let _ = stream.flush();
                let _ = stream.shutdown(std::net::Shutdown::Both);
                // Stop after the final canned response (the loop bound would too).
                if i == responses.len() - 1 {
                    break;
                }
            }
        });
        (addr, counter)
    }

    /// Blocking wrapper: a 200 returns status and body after one connection.
    #[test]
    fn retry_http_blocking_success_returns_first_attempt() {
        let (addr, calls) =
            spawn_oneshot_http_responder(vec!["HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok"]);
        let client = reqwest::blocking::Client::builder()
            .timeout(Duration::from_secs(2))
            .build()
            .expect("client");
        let policy = RetryPolicy {
            max_attempts: 3,
            base_delay: Duration::from_millis(1),
            max_delay: Duration::from_millis(2),
        };
        let result = retry_http_blocking(
            "test",
            &policy,
            SuccessClass::Strict,
            |_| client.get(format!("http://{addr}/")).send(),
            |_, _| String::from("should not be called on success"),
        );
        let (status, body) = result.expect("success");
        assert_eq!(status.as_u16(), 200);
        assert_eq!(body, "ok");
        assert_eq!(calls.load(Ordering::SeqCst), 1, "single attempt");
    }

    /// Blocking wrapper: a 503 is retried and the following 200 is returned.
    #[test]
    fn retry_http_blocking_retries_5xx_then_succeeds() {
        let (addr, calls) = spawn_oneshot_http_responder(vec![
            "HTTP/1.1 503 Service Unavailable\r\nContent-Length: 0\r\n\r\n",
            "HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok",
        ]);
        let client = reqwest::blocking::Client::builder()
            .timeout(Duration::from_secs(2))
            .build()
            .expect("client");
        let policy = RetryPolicy {
            max_attempts: 3,
            base_delay: Duration::from_millis(1),
            max_delay: Duration::from_millis(2),
        };
        let result = retry_http_blocking(
            "test",
            &policy,
            SuccessClass::Strict,
            |_| client.get(format!("http://{addr}/")).send(),
            |status, body| format!("{status}: {body}"),
        );
        let (status, _) = result.expect("eventually succeeds");
        assert_eq!(status.as_u16(), 200);
        assert_eq!(calls.load(Ordering::SeqCst), 2, "one retry then success");
    }

    /// Blocking wrapper: a 404 fails after one connection and the error chain
    /// carries the caller-formatted message and the status.
    #[test]
    fn retry_http_blocking_4xx_fast_fails_no_retry() {
        let (addr, calls) = spawn_oneshot_http_responder(vec![
            "HTTP/1.1 404 Not Found\r\nContent-Length: 9\r\n\r\nnot found",
        ]);
        let client = reqwest::blocking::Client::builder()
            .timeout(Duration::from_secs(2))
            .build()
            .expect("client");
        let policy = RetryPolicy {
            max_attempts: 5,
            base_delay: Duration::from_millis(1),
            max_delay: Duration::from_millis(2),
        };
        let result = retry_http_blocking(
            "myscope",
            &policy,
            SuccessClass::Strict,
            |_| client.get(format!("http://{addr}/")).send(),
            |status, body| format!("custom error: {status} body={body}"),
        );
        let err = result.expect_err("4xx must fast-fail");
        // `{:#}` renders the whole anyhow context chain.
        let chain = format!("{err:#}");
        assert!(
            chain.contains("custom error"),
            "error formatter must be invoked on non-success; got: {chain}"
        );
        assert!(chain.contains("404"), "status must be in chain: {chain}");
        assert_eq!(
            calls.load(Ordering::SeqCst),
            1,
            "4xx must NOT retry (only one connection accepted)"
        );
    }

    /// With `AllowRedirects` and redirect-following disabled on the client,
    /// a 307 counts as success.
    #[test]
    fn retry_http_blocking_redirect_class_alters_success_predicate() {
        let (addr, _calls) = spawn_oneshot_http_responder(vec![
            "HTTP/1.1 307 Temporary Redirect\r\nLocation: /next\r\nContent-Length: 0\r\n\r\n",
        ]);
        let client = reqwest::blocking::Client::builder()
            .timeout(Duration::from_secs(2))
            .redirect(reqwest::redirect::Policy::none())
            .build()
            .expect("client");
        let policy = RetryPolicy {
            max_attempts: 3,
            base_delay: Duration::from_millis(1),
            max_delay: Duration::from_millis(2),
        };
        let result = retry_http_blocking(
            "test",
            &policy,
            SuccessClass::AllowRedirects,
            |_| client.get(format!("http://{addr}/")).send(),
            |_, _| String::from("should not be called on 3xx with AllowRedirects"),
        );
        let (status, _) = result.expect("3xx is success under AllowRedirects");
        assert_eq!(status.as_u16(), 307);
    }

    /// Async wrapper: a 200 returns the unread response after one connection.
    #[tokio::test]
    async fn retry_http_async_success_returns_first_attempt() {
        let (addr, calls) =
            spawn_oneshot_http_responder(vec!["HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok"]);
        let client = reqwest::Client::builder()
            .timeout(Duration::from_secs(2))
            .build()
            .expect("client");
        let policy = RetryPolicy {
            max_attempts: 3,
            base_delay: Duration::from_millis(1),
            max_delay: Duration::from_millis(2),
        };
        let result = retry_http_async(
            "test",
            &policy,
            SuccessClass::Strict,
            |_| client.get(format!("http://{addr}/")).send(),
            |_, _| String::from("should not be called on success"),
        )
        .await;
        let resp = result.expect("success");
        assert_eq!(resp.status().as_u16(), 200);
        let body = resp.text().await.expect("body");
        assert_eq!(body, "ok");
        assert_eq!(calls.load(Ordering::SeqCst), 1, "single attempt");
    }

    /// Async wrapper: a 503 is retried and the following 200 is returned.
    #[tokio::test]
    async fn retry_http_async_retries_5xx_then_succeeds() {
        let (addr, calls) = spawn_oneshot_http_responder(vec![
            "HTTP/1.1 503 Service Unavailable\r\nContent-Length: 0\r\n\r\n",
            "HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok",
        ]);
        let client = reqwest::Client::builder()
            .timeout(Duration::from_secs(2))
            .build()
            .expect("client");
        let policy = RetryPolicy {
            max_attempts: 3,
            base_delay: Duration::from_millis(1),
            max_delay: Duration::from_millis(2),
        };
        let result = retry_http_async(
            "test",
            &policy,
            SuccessClass::Strict,
            |_| client.get(format!("http://{addr}/")).send(),
            |status, body| format!("{status}: {body}"),
        )
        .await;
        let resp = result.expect("eventually succeeds");
        assert_eq!(resp.status().as_u16(), 200);
        assert_eq!(calls.load(Ordering::SeqCst), 2, "one retry then success");
    }

    /// Async wrapper: a 404 fails after one connection with the
    /// caller-formatted message and the status in the chain.
    #[tokio::test]
    async fn retry_http_async_4xx_fast_fails_no_retry() {
        let (addr, calls) = spawn_oneshot_http_responder(vec![
            "HTTP/1.1 404 Not Found\r\nContent-Length: 9\r\n\r\nnot found",
        ]);
        let client = reqwest::Client::builder()
            .timeout(Duration::from_secs(2))
            .build()
            .expect("client");
        let policy = RetryPolicy {
            max_attempts: 5,
            base_delay: Duration::from_millis(1),
            max_delay: Duration::from_millis(2),
        };
        let result = retry_http_async(
            "myscope",
            &policy,
            SuccessClass::Strict,
            |_| client.get(format!("http://{addr}/")).send(),
            |status, body| format!("custom error: {status} body={body}"),
        )
        .await;
        let err = result.expect_err("4xx must fast-fail");
        let chain = format!("{err:#}");
        assert!(
            chain.contains("custom error"),
            "error formatter must be invoked on non-success; got: {chain}"
        );
        assert!(chain.contains("404"), "status must be in chain: {chain}");
        assert_eq!(
            calls.load(Ordering::SeqCst),
            1,
            "4xx must NOT retry (only one connection accepted)"
        );
    }

    /// Async wrapper: a 429 is retried like a 5xx.
    #[tokio::test]
    async fn retry_http_async_429_retries_then_succeeds() {
        let (addr, calls) = spawn_oneshot_http_responder(vec![
            "HTTP/1.1 429 Too Many Requests\r\nContent-Length: 0\r\n\r\n",
            "HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok",
        ]);
        let client = reqwest::Client::builder()
            .timeout(Duration::from_secs(2))
            .build()
            .expect("client");
        let policy = RetryPolicy {
            max_attempts: 3,
            base_delay: Duration::from_millis(1),
            max_delay: Duration::from_millis(2),
        };
        let result = retry_http_async(
            "test",
            &policy,
            SuccessClass::Strict,
            |_| client.get(format!("http://{addr}/")).send(),
            |status, body| format!("{status}: {body}"),
        )
        .await;
        let resp = result.expect("429 retried then success");
        assert_eq!(resp.status().as_u16(), 200);
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }

    /// Host under the reserved `.invalid` TLD; requests to it presumably fail
    /// at name resolution, producing a transport (DNS) error.
    const TRANSPORT_FAIL_URL: &str = "http://nonexistent.invalid/";

    /// Blocking wrapper: a transport error is retried more than once, and the
    /// label appears in the final error chain.
    #[test]
    fn retry_http_blocking_transport_error_retries_then_fails() {
        let attempts = std::sync::Arc::new(AtomicU32::new(0));
        let attempts_inner = attempts.clone();
        let client = reqwest::blocking::Client::builder()
            .timeout(Duration::from_millis(500))
            .build()
            .expect("client");
        let policy = RetryPolicy {
            max_attempts: 3,
            base_delay: Duration::from_millis(1),
            max_delay: Duration::from_millis(2),
        };
        let result = retry_http_blocking(
            "test-transport",
            &policy,
            SuccessClass::Strict,
            |_| {
                attempts_inner.fetch_add(1, Ordering::SeqCst);
                client.get(TRANSPORT_FAIL_URL).send()
            },
            |_, _| String::from("non-success branch should not be reached"),
        );
        let err = result.expect_err("transport error must surface as Err");
        let chain = format!("{err:#}");
        assert!(
            attempts.load(Ordering::SeqCst) > 1,
            "transport error must be retried; got {} attempts; chain={chain}",
            attempts.load(Ordering::SeqCst)
        );
        assert!(
            chain.contains("test-transport"),
            "label must surface in error chain; got: {chain}"
        );
    }

    /// Async wrapper: a transport error is retried more than once, and the
    /// label appears in the final error chain.
    #[tokio::test]
    async fn retry_http_async_transport_error_retries_then_fails() {
        let attempts = std::sync::Arc::new(AtomicU32::new(0));
        let attempts_inner = attempts.clone();
        let client = reqwest::Client::builder()
            .timeout(Duration::from_millis(500))
            .build()
            .expect("client");
        let policy = RetryPolicy {
            max_attempts: 3,
            base_delay: Duration::from_millis(1),
            max_delay: Duration::from_millis(2),
        };
        let result = retry_http_async(
            "test-transport-async",
            &policy,
            SuccessClass::Strict,
            |_| {
                attempts_inner.fetch_add(1, Ordering::SeqCst);
                client.get(TRANSPORT_FAIL_URL).send()
            },
            |_, _| String::from("non-success branch should not be reached"),
        )
        .await;
        let err = result.expect_err("transport error must surface as Err");
        assert!(
            attempts.load(Ordering::SeqCst) > 1,
            "transport error must be retried; got {} attempts",
            attempts.load(Ordering::SeqCst)
        );
        let chain = format!("{err:#}");
        assert!(
            chain.contains("test-transport-async"),
            "label must surface in error chain; got: {chain}"
        );
    }
}