1use crate::config::{ExponentialBackoff, RetryConfig, RetryTrigger};
2use crate::error::HttpError;
3use crate::response::{ResponseBody, parse_retry_after};
4use bytes::Bytes;
5use http::{HeaderValue, Request, Response};
6use http_body_util::{BodyExt, Full};
7use rand::Rng;
8use std::future::Future;
9use std::pin::Pin;
10use std::task::{Context, Poll};
11use std::time::Duration;
12use tower::{Layer, Service, ServiceExt};
13
/// Header inserted on retried requests (not the first attempt) carrying the
/// retry attempt number, starting at `1` for the first retry.
pub const RETRY_ATTEMPT_HEADER: &str = "X-Retry-Attempt";
17
/// Tower [`Layer`] that wraps a service with retry behaviour driven by a
/// [`RetryConfig`], optionally bounded by a total deadline across attempts.
#[derive(Clone)]
pub struct RetryLayer {
    // Retry policy: triggers, max attempts, backoff, drain settings.
    config: RetryConfig,
    // Overall deadline applied across all attempts; `None` means unbounded.
    total_timeout: Option<Duration>,
}
27
28impl RetryLayer {
29 #[must_use]
31 pub fn new(config: RetryConfig) -> Self {
32 Self {
33 config,
34 total_timeout: None,
35 }
36 }
37
38 #[must_use]
40 pub fn with_total_timeout(config: RetryConfig, total_timeout: Option<Duration>) -> Self {
41 Self {
42 config,
43 total_timeout,
44 }
45 }
46}
47
impl<S> Layer<S> for RetryLayer {
    type Service = RetryService<S>;

    /// Wraps `inner` in a [`RetryService`] sharing this layer's configuration.
    fn layer(&self, inner: S) -> Self::Service {
        RetryService {
            inner,
            config: self.config.clone(),
            total_timeout: self.total_timeout,
        }
    }
}
59
/// Service produced by [`RetryLayer`]: re-issues failed requests to the
/// wrapped service according to the captured retry configuration.
#[derive(Clone)]
pub struct RetryService<S> {
    // The wrapped inner service that actually performs the request.
    inner: S,
    // Retry policy (triggers, limits, backoff, drain behaviour).
    config: RetryConfig,
    // Optional deadline spanning all attempts; `None` means unbounded.
    total_timeout: Option<Duration>,
}
80
impl<S> Service<Request<Full<Bytes>>> for RetryService<S>
where
    S: Service<Request<Full<Bytes>>, Response = Response<ResponseBody>, Error = HttpError>
        + Clone
        + Send
        + 'static,
    S::Future: Send,
{
    type Response = S::Response;
    type Error = HttpError;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    fn call(&mut self, req: Request<Full<Bytes>>) -> Self::Future {
        // Standard tower pattern: take the service that `poll_ready` readied
        // and leave a fresh clone behind for the next `call`.
        let clone = self.inner.clone();
        let inner = std::mem::replace(&mut self.inner, clone);
        let config = self.config.clone();
        let total_timeout = self.total_timeout;

        let (parts, body_bytes) = req.into_parts();

        // Preserve the HTTP version and request extensions so every rebuilt
        // attempt matches the original request.
        let http_version = parts.version;

        let extensions = parts.extensions.clone();

        // An idempotency key header (when configured and present) can make
        // otherwise non-retryable methods eligible for retry.
        let has_idempotency_key = config
            .idempotency_key_header
            .as_ref()
            .is_some_and(|name| parts.headers.contains_key(name));

        let parts = std::sync::Arc::new(parts);

        Box::pin(async move {
            let method = parts.method.clone();

            // Host and request id are captured once, purely for retry logs.
            let url_host = parts
                .uri
                .authority()
                .map(ToString::to_string)
                .or_else(|| parts.uri.host().map(ToOwned::to_owned))
                .unwrap_or_else(|| "unknown".to_owned());
            let request_id = parts
                .headers
                .get("x-request-id")
                .or_else(|| parts.headers.get("x-correlation-id"))
                .and_then(|v| v.to_str().ok())
                .map(String::from);

            // Absolute deadline shared across all attempts and backoff sleeps.
            let deadline_info = total_timeout.map(|t| (tokio::time::Instant::now() + t, t));

            let mut attempt = 0usize;
            loop {
                // Bail out before issuing another attempt if the deadline has
                // already passed.
                if let Some((deadline, timeout_duration)) = deadline_info
                    && tokio::time::Instant::now() >= deadline
                {
                    return Err(HttpError::DeadlineExceeded(timeout_duration));
                }

                // Rebuild the request from the shared parts and buffered body.
                let mut req = Request::from_parts((*parts).clone(), body_bytes.clone());

                *req.version_mut() = http_version;

                *req.extensions_mut() = extensions.clone();

                // Tag retries (but not the first attempt) so the server can
                // observe the attempt number.
                if attempt > 0 {
                    if let Ok(value) = HeaderValue::try_from(attempt.to_string()) {
                        req.headers_mut().insert(RETRY_ATTEMPT_HEADER, value);
                    }
                }

                let mut svc = inner.clone();
                svc.ready().await?;

                match svc.call(req).await {
                    Ok(resp) => {
                        let status_code = resp.status().as_u16();
                        let trigger = RetryTrigger::Status(status_code);

                        if config.max_retries > 0
                            && attempt < config.max_retries
                            && config.should_retry(trigger, &method, has_idempotency_key)
                        {
                            // Server-provided Retry-After is honoured but
                            // clamped so it can never exceed backoff.max.
                            let retry_after = parse_retry_after(resp.headers())
                                .map(|d| d.min(config.backoff.max));
                            let backoff_duration = if config.ignore_retry_after {
                                calculate_backoff(&config.backoff, attempt)
                            } else {
                                retry_after
                                    .unwrap_or_else(|| calculate_backoff(&config.backoff, attempt))
                            };

                            // Drain the discarded response body (bounded by
                            // the drain limit) so the connection can be
                            // reused, unless configured or too large to drain.
                            let drain_limit = config.retry_response_drain_limit;
                            let should_drain = if config.skip_drain_on_retry {
                                tracing::trace!("Skipping drain: skip_drain_on_retry enabled");
                                false
                            } else if let Some(content_length) = resp
                                .headers()
                                .get(http::header::CONTENT_LENGTH)
                                .and_then(|v| v.to_str().ok())
                                .and_then(|s| s.parse::<u64>().ok())
                            {
                                if content_length > drain_limit as u64 {
                                    tracing::debug!(
                                        content_length,
                                        drain_limit,
                                        "Skipping drain: Content-Length exceeds limit"
                                    );
                                    false
                                } else {
                                    true
                                }
                            } else {
                                true
                            };

                            if should_drain
                                && let Err(e) = drain_response_body(resp, drain_limit).await
                            {
                                tracing::debug!(
                                    error = %e,
                                    "Failed to drain response body before retry; connection may not be reused"
                                );
                            }

                            // Never sleep past the overall deadline; fail fast
                            // if no time remains.
                            let effective_backoff =
                                if let Some((deadline, timeout_duration)) = deadline_info {
                                    let remaining = deadline
                                        .saturating_duration_since(tokio::time::Instant::now());
                                    if remaining.is_zero() {
                                        return Err(HttpError::DeadlineExceeded(timeout_duration));
                                    }
                                    backoff_duration.min(remaining)
                                } else {
                                    backoff_duration
                                };

                            tracing::debug!(
                                retry = attempt + 1,
                                max_retries = config.max_retries,
                                status = status_code,
                                trigger = ?trigger,
                                method = %method,
                                host = %url_host,
                                request_id = ?request_id,
                                backoff_ms = effective_backoff.as_millis(),
                                retry_after_used = retry_after.is_some() && !config.ignore_retry_after,
                                "Retrying request after status code"
                            );
                            tokio::time::sleep(effective_backoff).await;
                            attempt += 1;
                            continue;
                        }

                        // Not retryable (or retries exhausted): surface the
                        // response, whatever its status, to the caller.
                        return Ok(resp);
                    }
                    Err(err) => {
                        if config.max_retries == 0 || attempt >= config.max_retries {
                            return Err(err);
                        }

                        let trigger = get_retry_trigger(&err);
                        if !config.should_retry(trigger, &method, has_idempotency_key) {
                            return Err(err);
                        }

                        // Errors carry no Retry-After; always use backoff.
                        let backoff_duration = calculate_backoff(&config.backoff, attempt);

                        let effective_backoff =
                            if let Some((deadline, timeout_duration)) = deadline_info {
                                let remaining =
                                    deadline.saturating_duration_since(tokio::time::Instant::now());
                                if remaining.is_zero() {
                                    return Err(HttpError::DeadlineExceeded(timeout_duration));
                                }
                                backoff_duration.min(remaining)
                            } else {
                                backoff_duration
                            };

                        tracing::debug!(
                            retry = attempt + 1,
                            max_retries = config.max_retries,
                            error = %err,
                            trigger = ?trigger,
                            method = %method,
                            host = %url_host,
                            request_id = ?request_id,
                            backoff_ms = effective_backoff.as_millis(),
                            "Retrying request after error"
                        );
                        tokio::time::sleep(effective_backoff).await;
                        attempt += 1;
                    }
                }
            }
        })
    }
}
315
316async fn drain_response_body(
340 response: Response<ResponseBody>,
341 limit: usize,
342) -> Result<(), HttpError> {
343 let (_parts, body) = response.into_parts();
344 let mut body = std::pin::pin!(body);
345 let mut drained = 0usize;
346
347 while let Some(frame) = body.frame().await {
348 let frame = frame.map_err(HttpError::Transport)?;
349 if let Some(chunk) = frame.data_ref() {
350 drained += chunk.len();
351 if drained >= limit {
352 break;
354 }
355 }
356 }
357
358 Ok(())
359}
360
361fn get_retry_trigger(err: &HttpError) -> RetryTrigger {
363 match err {
364 HttpError::Transport(_) => RetryTrigger::TransportError,
365 HttpError::Timeout(_) => RetryTrigger::Timeout,
366 _ => RetryTrigger::NonRetryable,
368 }
369}
370
371pub fn calculate_backoff(backoff: &ExponentialBackoff, attempt: usize) -> Duration {
375 const MAX_BACKOFF_SECS: f64 = 86400.0;
377
378 let attempt_i32 = i32::try_from(attempt).unwrap_or(i32::MAX);
381
382 let multiplier = if backoff.multiplier.is_finite() && backoff.multiplier >= 0.0 {
384 backoff.multiplier
385 } else {
386 1.0
387 };
388
389 let initial_secs = backoff.initial.as_secs_f64();
391 let initial_secs = if initial_secs.is_finite() && initial_secs >= 0.0 {
392 initial_secs
393 } else {
394 0.0
395 };
396
397 let max_secs = backoff.max.as_secs_f64();
399 let max_secs = if max_secs.is_finite() && max_secs >= 0.0 {
400 max_secs.min(MAX_BACKOFF_SECS)
401 } else {
402 MAX_BACKOFF_SECS
403 };
404
405 let base_duration = initial_secs * multiplier.powi(attempt_i32);
407
408 let clamped = if base_duration.is_finite() {
410 base_duration.min(max_secs).max(0.0)
411 } else {
412 max_secs
413 };
414 let duration = Duration::from_secs_f64(clamped);
415
416 let duration = if backoff.jitter {
418 let mut rng = rand::rng();
419 let jitter_factor = rng.random_range(0.0..=0.25);
420 let jitter = duration.mul_f64(jitter_factor);
421 duration + jitter
422 } else {
423 duration
424 };
425
426 let max_duration = Duration::from_secs_f64(max_secs);
428 duration.min(max_duration)
429}
430
431#[cfg(test)]
432#[cfg_attr(coverage_nightly, coverage(off))]
433mod tests {
434 use super::*;
435 use crate::config::IDEMPOTENCY_KEY_HEADER;
436 use bytes::Bytes;
437 use http::{Method, Request, Response, StatusCode};
438 use http_body_util::Full;
439
440 fn make_response_body(data: &[u8]) -> ResponseBody {
442 let body = Full::new(Bytes::from(data.to_vec()));
443 body.map_err(|e| -> Box<dyn std::error::Error + Send + Sync> { Box::new(e) })
444 .boxed()
445 }
446
447 #[tokio::test]
448 async fn test_retry_layer_successful_request() {
449 use std::sync::{Arc, Mutex};
450
451 #[derive(Clone)]
452 struct CountingService {
453 call_count: Arc<Mutex<usize>>,
454 }
455
456 impl Service<Request<Full<Bytes>>> for CountingService {
457 type Response = Response<ResponseBody>;
458 type Error = HttpError;
459 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
460
461 fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
462 Poll::Ready(Ok(()))
463 }
464
465 fn call(&mut self, _req: Request<Full<Bytes>>) -> Self::Future {
466 let count = self.call_count.clone();
467 Box::pin(async move {
468 *count.lock().unwrap() += 1;
469 let response = Response::builder()
470 .status(StatusCode::OK)
471 .body(make_response_body(b""))
472 .unwrap();
473 Ok(response)
474 })
475 }
476 }
477
478 let call_count = Arc::new(Mutex::new(0));
479 let service = CountingService {
480 call_count: call_count.clone(),
481 };
482
483 let retry_config = RetryConfig::default();
484 let layer = RetryLayer::new(retry_config);
485 let mut retry_service = layer.layer(service);
486
487 let req = Request::builder()
488 .method(Method::GET)
489 .uri("http://example.com")
490 .body(Full::new(Bytes::new()))
491 .unwrap();
492
493 let result = retry_service.call(req).await;
494 assert!(result.is_ok());
495 assert_eq!(*call_count.lock().unwrap(), 1); }
497
498 #[tokio::test]
501 async fn test_retry_layer_post_not_retried_on_5xx() {
502 use std::sync::{Arc, Mutex};
503
504 #[derive(Clone)]
505 struct ServerErrorService {
506 call_count: Arc<Mutex<usize>>,
507 }
508
509 impl Service<Request<Full<Bytes>>> for ServerErrorService {
510 type Response = Response<ResponseBody>;
511 type Error = HttpError;
512 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
513
514 fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
515 Poll::Ready(Ok(()))
516 }
517
518 fn call(&mut self, _req: Request<Full<Bytes>>) -> Self::Future {
519 let count = self.call_count.clone();
520 Box::pin(async move {
521 *count.lock().unwrap() += 1;
522 Ok(Response::builder()
524 .status(StatusCode::INTERNAL_SERVER_ERROR)
525 .body(make_response_body(b"Internal Server Error"))
526 .unwrap())
527 })
528 }
529 }
530
531 let call_count = Arc::new(Mutex::new(0));
532 let service = ServerErrorService {
533 call_count: call_count.clone(),
534 };
535
536 let retry_config = RetryConfig {
537 backoff: ExponentialBackoff::fast(),
538 ..RetryConfig::default()
539 };
540 let layer = RetryLayer::new(retry_config);
541 let mut retry_service = layer.layer(service);
542
543 let req = Request::builder()
544 .method(Method::POST)
545 .uri("http://example.com")
546 .body(Full::new(Bytes::new()))
547 .unwrap();
548
549 let result = retry_service.call(req).await;
550 assert!(result.is_ok());
552 let resp = result.unwrap();
553 assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR);
554 assert_eq!(*call_count.lock().unwrap(), 1); }
556
557 #[tokio::test]
560 async fn test_retry_layer_get_retried_on_5xx() {
561 use std::sync::{Arc, Mutex};
562
563 #[derive(Clone)]
564 struct FailThenSucceedService {
565 call_count: Arc<Mutex<usize>>,
566 }
567
568 impl Service<Request<Full<Bytes>>> for FailThenSucceedService {
569 type Response = Response<ResponseBody>;
570 type Error = HttpError;
571 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
572
573 fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
574 Poll::Ready(Ok(()))
575 }
576
577 fn call(&mut self, _req: Request<Full<Bytes>>) -> Self::Future {
578 let count = self.call_count.clone();
579 Box::pin(async move {
580 let mut c = count.lock().unwrap();
581 *c += 1;
582 if *c < 3 {
583 Ok(Response::builder()
585 .status(StatusCode::INTERNAL_SERVER_ERROR)
586 .body(make_response_body(b"Internal Server Error"))
587 .unwrap())
588 } else {
589 Ok(Response::builder()
590 .status(StatusCode::OK)
591 .body(make_response_body(b""))
592 .unwrap())
593 }
594 })
595 }
596 }
597
598 let call_count = Arc::new(Mutex::new(0));
599 let service = FailThenSucceedService {
600 call_count: call_count.clone(),
601 };
602
603 let retry_config = RetryConfig {
604 backoff: ExponentialBackoff::fast(),
605 ..RetryConfig::default()
606 };
607 let layer = RetryLayer::new(retry_config);
608 let mut retry_service = layer.layer(service);
609
610 let req = Request::builder()
611 .method(Method::GET)
612 .uri("http://example.com")
613 .body(Full::new(Bytes::new()))
614 .unwrap();
615
616 let result = retry_service.call(req).await;
617 assert!(result.is_ok());
618 assert_eq!(result.unwrap().status(), StatusCode::OK);
619 assert_eq!(*call_count.lock().unwrap(), 3); }
621
622 #[tokio::test]
624 async fn test_retry_layer_always_retries_429() {
625 use std::sync::{Arc, Mutex};
626
627 #[derive(Clone)]
628 struct RateLimitThenSucceedService {
629 call_count: Arc<Mutex<usize>>,
630 }
631
632 impl Service<Request<Full<Bytes>>> for RateLimitThenSucceedService {
633 type Response = Response<ResponseBody>;
634 type Error = HttpError;
635 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
636
637 fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
638 Poll::Ready(Ok(()))
639 }
640
641 fn call(&mut self, _req: Request<Full<Bytes>>) -> Self::Future {
642 let count = self.call_count.clone();
643 Box::pin(async move {
644 let mut c = count.lock().unwrap();
645 *c += 1;
646 if *c < 2 {
647 Ok(Response::builder()
649 .status(StatusCode::TOO_MANY_REQUESTS)
650 .body(make_response_body(b"Rate limited"))
651 .unwrap())
652 } else {
653 Ok(Response::builder()
654 .status(StatusCode::OK)
655 .body(make_response_body(b""))
656 .unwrap())
657 }
658 })
659 }
660 }
661
662 let call_count = Arc::new(Mutex::new(0));
663 let service = RateLimitThenSucceedService {
664 call_count: call_count.clone(),
665 };
666
667 let retry_config = RetryConfig {
668 backoff: ExponentialBackoff::fast(),
669 ..RetryConfig::default()
670 };
671 let layer = RetryLayer::new(retry_config);
672 let mut retry_service = layer.layer(service);
673
674 let req = Request::builder()
676 .method(Method::POST)
677 .uri("http://example.com")
678 .body(Full::new(Bytes::new()))
679 .unwrap();
680
681 let result = retry_service.call(req).await;
682 assert!(result.is_ok());
683 assert_eq!(result.unwrap().status(), StatusCode::OK);
684 assert_eq!(*call_count.lock().unwrap(), 2); }
686
687 #[tokio::test]
688 async fn test_retry_layer_retries_transport_errors() {
689 use std::sync::{Arc, Mutex};
690
691 #[derive(Clone)]
692 struct FailThenSucceedService {
693 call_count: Arc<Mutex<usize>>,
694 }
695
696 impl Service<Request<Full<Bytes>>> for FailThenSucceedService {
697 type Response = Response<ResponseBody>;
698 type Error = HttpError;
699 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
700
701 fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
702 Poll::Ready(Ok(()))
703 }
704
705 fn call(&mut self, _req: Request<Full<Bytes>>) -> Self::Future {
706 let count = self.call_count.clone();
707 Box::pin(async move {
708 let mut c = count.lock().unwrap();
709 *c += 1;
710 if *c < 3 {
711 Err(HttpError::Transport(Box::new(std::io::Error::new(
712 std::io::ErrorKind::ConnectionReset,
713 "connection reset",
714 ))))
715 } else {
716 Ok(Response::builder()
717 .status(StatusCode::OK)
718 .body(make_response_body(b""))
719 .unwrap())
720 }
721 })
722 }
723 }
724
725 let call_count = Arc::new(Mutex::new(0));
726 let service = FailThenSucceedService {
727 call_count: call_count.clone(),
728 };
729
730 let retry_config = RetryConfig {
731 backoff: ExponentialBackoff::fast(),
732 ..RetryConfig::default()
733 };
734 let layer = RetryLayer::new(retry_config);
735 let mut retry_service = layer.layer(service);
736
737 let req = Request::builder()
738 .method(Method::GET)
739 .uri("http://example.com")
740 .body(Full::new(Bytes::new()))
741 .unwrap();
742
743 let result = retry_service.call(req).await;
744 assert!(result.is_ok());
745 assert_eq!(*call_count.lock().unwrap(), 3); }
747
748 #[tokio::test]
750 async fn test_retry_layer_post_not_retried_on_transport_error() {
751 use std::sync::{Arc, Mutex};
752
753 #[derive(Clone)]
754 struct TransportErrorService {
755 call_count: Arc<Mutex<usize>>,
756 }
757
758 impl Service<Request<Full<Bytes>>> for TransportErrorService {
759 type Response = Response<ResponseBody>;
760 type Error = HttpError;
761 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
762
763 fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
764 Poll::Ready(Ok(()))
765 }
766
767 fn call(&mut self, _req: Request<Full<Bytes>>) -> Self::Future {
768 let count = self.call_count.clone();
769 Box::pin(async move {
770 *count.lock().unwrap() += 1;
771 Err(HttpError::Transport(Box::new(std::io::Error::new(
772 std::io::ErrorKind::ConnectionReset,
773 "connection reset",
774 ))))
775 })
776 }
777 }
778
779 let call_count = Arc::new(Mutex::new(0));
780 let service = TransportErrorService {
781 call_count: call_count.clone(),
782 };
783
784 let retry_config = RetryConfig {
785 backoff: ExponentialBackoff::fast(),
786 ..RetryConfig::default()
787 };
788 let layer = RetryLayer::new(retry_config);
789 let mut retry_service = layer.layer(service);
790
791 let req = Request::builder()
793 .method(Method::POST)
794 .uri("http://example.com")
795 .body(Full::new(Bytes::new()))
796 .unwrap();
797
798 let result = retry_service.call(req).await;
799 assert!(result.is_err()); assert_eq!(*call_count.lock().unwrap(), 1); }
802
803 #[tokio::test]
805 async fn test_retry_layer_post_with_idempotency_key_retried() {
806 use std::sync::{Arc, Mutex};
807
808 #[derive(Clone)]
809 struct FailThenSucceedService {
810 call_count: Arc<Mutex<usize>>,
811 }
812
813 impl Service<Request<Full<Bytes>>> for FailThenSucceedService {
814 type Response = Response<ResponseBody>;
815 type Error = HttpError;
816 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
817
818 fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
819 Poll::Ready(Ok(()))
820 }
821
822 fn call(&mut self, _req: Request<Full<Bytes>>) -> Self::Future {
823 let count = self.call_count.clone();
824 Box::pin(async move {
825 let mut c = count.lock().unwrap();
826 *c += 1;
827 if *c < 3 {
828 Err(HttpError::Transport(Box::new(std::io::Error::new(
829 std::io::ErrorKind::ConnectionReset,
830 "connection reset",
831 ))))
832 } else {
833 Ok(Response::builder()
834 .status(StatusCode::OK)
835 .body(make_response_body(b""))
836 .unwrap())
837 }
838 })
839 }
840 }
841
842 let call_count = Arc::new(Mutex::new(0));
843 let service = FailThenSucceedService {
844 call_count: call_count.clone(),
845 };
846
847 let retry_config = RetryConfig {
848 backoff: ExponentialBackoff::fast(),
849 ..RetryConfig::default()
850 };
851 let layer = RetryLayer::new(retry_config);
852 let mut retry_service = layer.layer(service);
853
854 let req = Request::builder()
856 .method(Method::POST)
857 .uri("http://example.com")
858 .header(IDEMPOTENCY_KEY_HEADER, "unique-key-123")
859 .body(Full::new(Bytes::new()))
860 .unwrap();
861
862 let result = retry_service.call(req).await;
863 assert!(result.is_ok()); assert_eq!(*call_count.lock().unwrap(), 3); }
866
867 #[tokio::test]
868 async fn test_retry_layer_does_not_retry_json_errors() {
869 use std::sync::{Arc, Mutex};
870
871 #[derive(Clone)]
872 struct JsonErrorService {
873 call_count: Arc<Mutex<usize>>,
874 }
875
876 impl Service<Request<Full<Bytes>>> for JsonErrorService {
877 type Response = Response<ResponseBody>;
878 type Error = HttpError;
879 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
880
881 fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
882 Poll::Ready(Ok(()))
883 }
884
885 fn call(&mut self, _req: Request<Full<Bytes>>) -> Self::Future {
886 let count = self.call_count.clone();
887 Box::pin(async move {
888 *count.lock().unwrap() += 1;
889 let err: serde_json::Error =
891 serde_json::from_str::<serde_json::Value>("invalid").unwrap_err();
892 Err(HttpError::Json(err))
893 })
894 }
895 }
896
897 let call_count = Arc::new(Mutex::new(0));
898 let service = JsonErrorService {
899 call_count: call_count.clone(),
900 };
901
902 let retry_config = RetryConfig::default();
903 let layer = RetryLayer::new(retry_config);
904 let mut retry_service = layer.layer(service);
905
906 let req = Request::builder()
907 .method(Method::GET)
908 .uri("http://example.com")
909 .body(Full::new(Bytes::new()))
910 .unwrap();
911
912 let result = retry_service.call(req).await;
913 assert!(result.is_err());
914 assert_eq!(*call_count.lock().unwrap(), 1); }
916
917 #[test]
918 fn test_calculate_backoff_no_jitter() {
919 let backoff = ExponentialBackoff {
920 initial: Duration::from_millis(100),
921 max: Duration::from_secs(10),
922 multiplier: 2.0,
923 jitter: false,
924 };
925
926 let backoff0 = calculate_backoff(&backoff, 0);
927 assert_eq!(backoff0, Duration::from_millis(100));
928
929 let backoff1 = calculate_backoff(&backoff, 1);
930 assert_eq!(backoff1, Duration::from_millis(200));
931
932 let backoff2 = calculate_backoff(&backoff, 2);
933 assert_eq!(backoff2, Duration::from_millis(400));
934
935 let backoff_capped = calculate_backoff(&backoff, 10);
937 assert_eq!(backoff_capped, Duration::from_secs(10));
938 }
939
940 #[test]
941 fn test_calculate_backoff_with_jitter() {
942 let backoff = ExponentialBackoff {
943 initial: Duration::from_millis(100),
944 max: Duration::from_secs(10),
945 multiplier: 2.0,
946 jitter: true,
947 };
948
949 let backoff0 = calculate_backoff(&backoff, 0);
950 assert!(backoff0 >= Duration::from_millis(100));
952 assert!(backoff0 <= Duration::from_millis(125));
953 }
954
955 #[test]
956 fn test_calculate_backoff_with_nan_multiplier() {
957 let backoff = ExponentialBackoff {
959 initial: Duration::from_millis(100),
960 max: Duration::from_secs(10),
961 multiplier: f64::NAN,
962 jitter: false,
963 };
964
965 let result = calculate_backoff(&backoff, 0);
967 assert_eq!(result, Duration::from_millis(100));
968
969 let result1 = calculate_backoff(&backoff, 1);
970 assert_eq!(result1, Duration::from_millis(100));
972 }
973
974 #[test]
975 fn test_calculate_backoff_with_infinity_multiplier() {
976 let backoff = ExponentialBackoff {
978 initial: Duration::from_millis(100),
979 max: Duration::from_secs(10),
980 multiplier: f64::INFINITY,
981 jitter: false,
982 };
983
984 let result = calculate_backoff(&backoff, 0);
986 assert_eq!(result, Duration::from_millis(100));
987 }
988
989 #[test]
990 fn test_calculate_backoff_with_negative_multiplier() {
991 let backoff = ExponentialBackoff {
993 initial: Duration::from_millis(100),
994 max: Duration::from_secs(10),
995 multiplier: -2.0,
996 jitter: false,
997 };
998
999 let result = calculate_backoff(&backoff, 0);
1001 assert_eq!(result, Duration::from_millis(100));
1002 }
1003
1004 #[test]
1005 fn test_calculate_backoff_with_huge_attempt() {
1006 let backoff = ExponentialBackoff {
1008 initial: Duration::from_millis(100),
1009 max: Duration::from_secs(10),
1010 multiplier: 2.0,
1011 jitter: false,
1012 };
1013
1014 let result = calculate_backoff(&backoff, usize::MAX);
1016 assert_eq!(result, Duration::from_secs(10));
1018 }
1019
1020 #[tokio::test]
1022 async fn test_retry_layer_uses_retry_after_header() {
1023 use std::sync::{Arc, Mutex};
1024
1025 #[derive(Clone)]
1026 struct RetryAfterService {
1027 call_count: Arc<Mutex<usize>>,
1028 }
1029
1030 impl Service<Request<Full<Bytes>>> for RetryAfterService {
1031 type Response = Response<ResponseBody>;
1032 type Error = HttpError;
1033 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
1034
1035 fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
1036 Poll::Ready(Ok(()))
1037 }
1038
1039 fn call(&mut self, _req: Request<Full<Bytes>>) -> Self::Future {
1040 let count = self.call_count.clone();
1041 Box::pin(async move {
1042 let mut c = count.lock().unwrap();
1043 *c += 1;
1044 if *c < 2 {
1045 Ok(Response::builder()
1047 .status(StatusCode::TOO_MANY_REQUESTS)
1048 .header(http::header::RETRY_AFTER, "0")
1049 .body(make_response_body(b"Rate limited"))
1050 .unwrap())
1051 } else {
1052 Ok(Response::builder()
1053 .status(StatusCode::OK)
1054 .body(make_response_body(b""))
1055 .unwrap())
1056 }
1057 })
1058 }
1059 }
1060
1061 let call_count = Arc::new(Mutex::new(0));
1062 let service = RetryAfterService {
1063 call_count: call_count.clone(),
1064 };
1065
1066 let retry_config = RetryConfig {
1067 backoff: ExponentialBackoff {
1068 initial: Duration::from_secs(10), jitter: false,
1070 ..ExponentialBackoff::default()
1071 },
1072 ignore_retry_after: false, ..RetryConfig::default()
1074 };
1075 let layer = RetryLayer::new(retry_config);
1076 let mut retry_service = layer.layer(service);
1077
1078 let req = Request::builder()
1079 .method(Method::POST)
1080 .uri("http://example.com")
1081 .body(Full::new(Bytes::new()))
1082 .unwrap();
1083
1084 let start = std::time::Instant::now();
1085 let result = retry_service.call(req).await;
1086 let elapsed = start.elapsed();
1087
1088 assert!(result.is_ok());
1089 assert_eq!(*call_count.lock().unwrap(), 2);
1090
1091 assert!(
1093 elapsed < Duration::from_secs(1),
1094 "Expected quick retry using Retry-After, but took {elapsed:?}",
1095 );
1096 }
1097
1098 #[tokio::test]
1100 async fn test_retry_layer_ignores_retry_after_when_configured() {
1101 use std::sync::{Arc, Mutex};
1102
1103 #[derive(Clone)]
1104 struct RetryAfterService {
1105 call_count: Arc<Mutex<usize>>,
1106 }
1107
1108 impl Service<Request<Full<Bytes>>> for RetryAfterService {
1109 type Response = Response<ResponseBody>;
1110 type Error = HttpError;
1111 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
1112
1113 fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
1114 Poll::Ready(Ok(()))
1115 }
1116
1117 fn call(&mut self, _req: Request<Full<Bytes>>) -> Self::Future {
1118 let count = self.call_count.clone();
1119 Box::pin(async move {
1120 let mut c = count.lock().unwrap();
1121 *c += 1;
1122 if *c < 2 {
1123 Ok(Response::builder()
1125 .status(StatusCode::TOO_MANY_REQUESTS)
1126 .header(http::header::RETRY_AFTER, "10")
1127 .body(make_response_body(b"Rate limited"))
1128 .unwrap())
1129 } else {
1130 Ok(Response::builder()
1131 .status(StatusCode::OK)
1132 .body(make_response_body(b""))
1133 .unwrap())
1134 }
1135 })
1136 }
1137 }
1138
1139 let call_count = Arc::new(Mutex::new(0));
1140 let service = RetryAfterService {
1141 call_count: call_count.clone(),
1142 };
1143
1144 let retry_config = RetryConfig {
1145 backoff: ExponentialBackoff::fast(), ignore_retry_after: true, ..RetryConfig::default()
1148 };
1149 let layer = RetryLayer::new(retry_config);
1150 let mut retry_service = layer.layer(service);
1151
1152 let req = Request::builder()
1153 .method(Method::POST)
1154 .uri("http://example.com")
1155 .body(Full::new(Bytes::new()))
1156 .unwrap();
1157
1158 let start = std::time::Instant::now();
1159 let result = retry_service.call(req).await;
1160 let elapsed = start.elapsed();
1161
1162 assert!(result.is_ok());
1163 assert_eq!(*call_count.lock().unwrap(), 2);
1164
1165 assert!(
1167 elapsed < Duration::from_secs(1),
1168 "Expected quick retry using backoff policy (1ms), but took {elapsed:?}",
1169 );
1170 }
1171
1172 #[tokio::test]
1175 async fn test_retry_after_clamped_to_backoff_max() {
1176 use std::sync::{Arc, Mutex};
1177
1178 #[derive(Clone)]
1179 struct LargeRetryAfterService {
1180 call_count: Arc<Mutex<usize>>,
1181 }
1182
1183 impl Service<Request<Full<Bytes>>> for LargeRetryAfterService {
1184 type Response = Response<ResponseBody>;
1185 type Error = HttpError;
1186 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
1187
1188 fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
1189 Poll::Ready(Ok(()))
1190 }
1191
1192 fn call(&mut self, _req: Request<Full<Bytes>>) -> Self::Future {
1193 let count = self.call_count.clone();
1194 Box::pin(async move {
1195 let mut c = count.lock().unwrap();
1196 *c += 1;
1197 if *c < 2 {
1198 Ok(Response::builder()
1200 .status(StatusCode::TOO_MANY_REQUESTS)
1201 .header(http::header::RETRY_AFTER, "3600")
1202 .body(make_response_body(b"Rate limited"))
1203 .unwrap())
1204 } else {
1205 Ok(Response::builder()
1206 .status(StatusCode::OK)
1207 .body(make_response_body(b""))
1208 .unwrap())
1209 }
1210 })
1211 }
1212 }
1213
1214 let call_count = Arc::new(Mutex::new(0));
1215 let service = LargeRetryAfterService {
1216 call_count: call_count.clone(),
1217 };
1218
1219 let retry_config = RetryConfig {
1220 backoff: ExponentialBackoff {
1221 initial: Duration::from_millis(1),
1222 max: Duration::from_millis(50), jitter: false,
1224 ..ExponentialBackoff::default()
1225 },
1226 ignore_retry_after: false, ..RetryConfig::default()
1228 };
1229 let layer = RetryLayer::new(retry_config);
1230 let mut retry_service = layer.layer(service);
1231
1232 let req = Request::builder()
1233 .method(Method::POST)
1234 .uri("http://example.com")
1235 .body(Full::new(Bytes::new()))
1236 .unwrap();
1237
1238 let start = std::time::Instant::now();
1239 let result = retry_service.call(req).await;
1240 let elapsed = start.elapsed();
1241
1242 assert!(result.is_ok());
1243 assert_eq!(*call_count.lock().unwrap(), 2);
1244
1245 assert!(
1248 elapsed < Duration::from_secs(1),
1249 "Retry-After should be clamped to backoff.max (50ms), but took {elapsed:?}",
1250 );
1251 }
1252
1253 #[tokio::test]
1254 async fn test_retry_attempt_header_added_on_retry() {
1255 use std::sync::{Arc, Mutex};
1256
1257 #[derive(Clone)]
1258 struct HeaderCapturingService {
1259 call_count: Arc<Mutex<usize>>,
1260 captured_headers: Arc<Mutex<Vec<Option<String>>>>,
1261 }
1262
1263 impl Service<Request<Full<Bytes>>> for HeaderCapturingService {
1264 type Response = Response<ResponseBody>;
1265 type Error = HttpError;
1266 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
1267
1268 fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
1269 Poll::Ready(Ok(()))
1270 }
1271
1272 fn call(&mut self, req: Request<Full<Bytes>>) -> Self::Future {
1273 let count = self.call_count.clone();
1274 let captured_headers = self.captured_headers.clone();
1275
1276 let retry_header = req
1278 .headers()
1279 .get(RETRY_ATTEMPT_HEADER)
1280 .map(|v| v.to_str().unwrap_or("invalid").to_owned());
1281
1282 Box::pin(async move {
1283 let mut c = count.lock().unwrap();
1284 *c += 1;
1285 captured_headers.lock().unwrap().push(retry_header);
1286
1287 if *c < 3 {
1288 Err(HttpError::Transport(Box::new(std::io::Error::new(
1290 std::io::ErrorKind::ConnectionReset,
1291 "connection reset",
1292 ))))
1293 } else {
1294 Ok(Response::builder()
1295 .status(StatusCode::OK)
1296 .body(make_response_body(b""))
1297 .unwrap())
1298 }
1299 })
1300 }
1301 }
1302
1303 let call_count = Arc::new(Mutex::new(0));
1304 let captured_headers = Arc::new(Mutex::new(Vec::new()));
1305 let service = HeaderCapturingService {
1306 call_count: call_count.clone(),
1307 captured_headers: captured_headers.clone(),
1308 };
1309
1310 let retry_config = RetryConfig {
1311 backoff: ExponentialBackoff::fast(),
1312 ..RetryConfig::default()
1313 };
1314 let layer = RetryLayer::new(retry_config);
1315 let mut retry_service = layer.layer(service);
1316
1317 let req = Request::builder()
1318 .method(Method::GET)
1319 .uri("http://example.com")
1320 .body(Full::new(Bytes::new()))
1321 .unwrap();
1322
1323 let result = retry_service.call(req).await;
1324 assert!(result.is_ok());
1325 assert_eq!(*call_count.lock().unwrap(), 3);
1326
1327 let headers = captured_headers.lock().unwrap();
1329 assert_eq!(headers.len(), 3);
1330 assert_eq!(headers[0], None);
1332 assert_eq!(headers[1], Some("1".to_owned()));
1334 assert_eq!(headers[2], Some("2".to_owned()));
1336 }
1337
1338 #[tokio::test]
1340 async fn test_retry_layer_exhausted_returns_ok_with_status() {
1341 use std::sync::{Arc, Mutex};
1342
1343 #[derive(Clone)]
1344 struct AlwaysFailService {
1345 call_count: Arc<Mutex<usize>>,
1346 }
1347
1348 impl Service<Request<Full<Bytes>>> for AlwaysFailService {
1349 type Response = Response<ResponseBody>;
1350 type Error = HttpError;
1351 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
1352
1353 fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
1354 Poll::Ready(Ok(()))
1355 }
1356
1357 fn call(&mut self, _req: Request<Full<Bytes>>) -> Self::Future {
1358 let count = self.call_count.clone();
1359 Box::pin(async move {
1360 *count.lock().unwrap() += 1;
1361 Ok(Response::builder()
1363 .status(StatusCode::INTERNAL_SERVER_ERROR)
1364 .body(make_response_body(b"error"))
1365 .unwrap())
1366 })
1367 }
1368 }
1369
1370 let call_count = Arc::new(Mutex::new(0));
1371 let service = AlwaysFailService {
1372 call_count: call_count.clone(),
1373 };
1374
1375 let retry_config = RetryConfig {
1376 max_retries: 2,
1377 backoff: ExponentialBackoff::fast(),
1378 ..RetryConfig::default()
1379 };
1380 let layer = RetryLayer::new(retry_config);
1381 let mut retry_service = layer.layer(service);
1382
1383 let req = Request::builder()
1384 .method(Method::GET)
1385 .uri("http://example.com")
1386 .body(Full::new(Bytes::new()))
1387 .unwrap();
1388
1389 let result = retry_service.call(req).await;
1390
1391 assert!(result.is_ok());
1393 let resp = result.unwrap();
1394 assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR);
1395
1396 assert_eq!(*call_count.lock().unwrap(), 3);
1398 }
1399
1400 #[tokio::test]
1402 async fn test_retry_layer_non_retryable_status_passes_through() {
1403 use std::sync::{Arc, Mutex};
1404
1405 #[derive(Clone)]
1406 struct NotFoundService {
1407 call_count: Arc<Mutex<usize>>,
1408 }
1409
1410 impl Service<Request<Full<Bytes>>> for NotFoundService {
1411 type Response = Response<ResponseBody>;
1412 type Error = HttpError;
1413 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
1414
1415 fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
1416 Poll::Ready(Ok(()))
1417 }
1418
1419 fn call(&mut self, _req: Request<Full<Bytes>>) -> Self::Future {
1420 let count = self.call_count.clone();
1421 Box::pin(async move {
1422 *count.lock().unwrap() += 1;
1423 Ok(Response::builder()
1424 .status(StatusCode::NOT_FOUND)
1425 .body(make_response_body(b"not found"))
1426 .unwrap())
1427 })
1428 }
1429 }
1430
1431 let call_count = Arc::new(Mutex::new(0));
1432 let service = NotFoundService {
1433 call_count: call_count.clone(),
1434 };
1435
1436 let retry_config = RetryConfig {
1437 max_retries: 3,
1438 backoff: ExponentialBackoff::fast(),
1439 ..RetryConfig::default()
1440 };
1441 let layer = RetryLayer::new(retry_config);
1442 let mut retry_service = layer.layer(service);
1443
1444 let req = Request::builder()
1445 .method(Method::GET)
1446 .uri("http://example.com")
1447 .body(Full::new(Bytes::new()))
1448 .unwrap();
1449
1450 let result = retry_service.call(req).await;
1451
1452 assert!(result.is_ok());
1454 let resp = result.unwrap();
1455 assert_eq!(resp.status(), StatusCode::NOT_FOUND);
1456
1457 assert_eq!(*call_count.lock().unwrap(), 1);
1459 }
1460}