1use crate::config::{ExponentialBackoff, RetryConfig, RetryTrigger};
2use crate::error::HttpError;
3use crate::response::{ResponseBody, parse_retry_after};
4use bytes::Bytes;
5use http::{HeaderValue, Request, Response};
6use http_body_util::{BodyExt, Full};
7use rand::Rng;
8use std::future::Future;
9use std::pin::Pin;
10use std::task::{Context, Poll};
11use std::time::Duration;
12use tower::{Layer, Service, ServiceExt};
13
/// Header inserted on retried requests (attempt >= 1) carrying the 1-based
/// retry attempt number; absent on the first attempt.
pub const RETRY_ATTEMPT_HEADER: &str = "X-Retry-Attempt";
17
/// Tower [`Layer`] that wraps a service with retry behavior driven by a
/// [`RetryConfig`], optionally bounded by an overall deadline across all
/// attempts.
#[derive(Clone)]
pub struct RetryLayer {
    // Retry policy: triggers, attempt budget, backoff, drain settings.
    config: RetryConfig,
    // Optional total budget covering every attempt plus backoff sleeps.
    total_timeout: Option<Duration>,
}
27
28impl RetryLayer {
29 #[must_use]
31 pub fn new(config: RetryConfig) -> Self {
32 Self {
33 config,
34 total_timeout: None,
35 }
36 }
37
38 #[must_use]
40 pub fn with_total_timeout(config: RetryConfig, total_timeout: Option<Duration>) -> Self {
41 Self {
42 config,
43 total_timeout,
44 }
45 }
46}
47
48impl<S> Layer<S> for RetryLayer {
49 type Service = RetryService<S>;
50
51 fn layer(&self, inner: S) -> Self::Service {
52 RetryService {
53 inner,
54 config: self.config.clone(),
55 total_timeout: self.total_timeout,
56 }
57 }
58}
59
/// Service produced by [`RetryLayer`]: dispatches each attempt to `inner`
/// and re-issues the request according to the retry configuration.
#[derive(Clone)]
pub struct RetryService<S> {
    // Wrapped service each attempt is dispatched to.
    inner: S,
    // Retry policy shared by all requests through this service.
    config: RetryConfig,
    // Optional budget covering all attempts and backoff sleeps of one call.
    total_timeout: Option<Duration>,
}
80
impl<S> Service<Request<Full<Bytes>>> for RetryService<S>
where
    S: Service<Request<Full<Bytes>>, Response = Response<ResponseBody>, Error = HttpError>
        + Clone
        + Send
        + 'static,
    S::Future: Send,
{
    type Response = S::Response;
    type Error = HttpError;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        // Readiness is delegated to the wrapped service.
        self.inner.poll_ready(cx)
    }

    /// Drives the request to completion, retrying per `config` on retryable
    /// statuses/errors, sleeping the computed backoff between attempts, and
    /// honoring the optional total deadline across all attempts.
    fn call(&mut self, req: Request<Full<Bytes>>) -> Self::Future {
        // Standard tower pattern: take the service that was polled ready and
        // leave a clone behind, so the boxed future owns a usable service.
        let clone = self.inner.clone();
        let inner = std::mem::replace(&mut self.inner, clone);
        let config = self.config.clone();
        let total_timeout = self.total_timeout;

        // Split the request so parts and body can be rebuilt per attempt.
        let (parts, body_bytes) = req.into_parts();

        // Capture version and extensions up front; they are explicitly
        // reapplied to each rebuilt request below.
        let http_version = parts.version;

        let extensions = parts.extensions.clone();

        // A request counts as idempotent-by-key when the configured
        // idempotency header name is present on it.
        let has_idempotency_key = config
            .idempotency_key_header
            .as_ref()
            .is_some_and(|name| parts.headers.contains_key(name));

        let parts = std::sync::Arc::new(parts);

        Box::pin(async move {
            let method = parts.method.clone();

            // Host/authority and request/correlation id are used only to
            // enrich the retry log lines.
            let url_host = parts
                .uri
                .authority()
                .map(ToString::to_string)
                .or_else(|| parts.uri.host().map(ToOwned::to_owned))
                .unwrap_or_else(|| "unknown".to_owned());
            let request_id = parts
                .headers
                .get("x-request-id")
                .or_else(|| parts.headers.get("x-correlation-id"))
                .and_then(|v| v.to_str().ok())
                .map(String::from);

            // Absolute deadline shared by every attempt and backoff sleep.
            let deadline_info = total_timeout.map(|t| (tokio::time::Instant::now() + t, t));

            let mut attempt = 0usize;
            loop {
                // Give up immediately if the total budget is already spent.
                if let Some((deadline, timeout_duration)) = deadline_info
                    && tokio::time::Instant::now() >= deadline
                {
                    return Err(HttpError::DeadlineExceeded(timeout_duration));
                }

                // Rebuild the request from the shared parts and buffered body.
                let mut req = Request::from_parts((*parts).clone(), body_bytes.clone());

                *req.version_mut() = http_version;

                *req.extensions_mut() = extensions.clone();

                // Mark retries (attempt >= 1) so downstream layers/servers can
                // observe the attempt number.
                if attempt > 0 {
                    if let Ok(value) = HeaderValue::try_from(attempt.to_string()) {
                        req.headers_mut().insert(RETRY_ATTEMPT_HEADER, value);
                    }
                }

                let mut svc = inner.clone();
                svc.ready().await?;

                match svc.call(req).await {
                    Ok(resp) => {
                        let status_code = resp.status().as_u16();
                        let trigger = RetryTrigger::Status(status_code);

                        // Retry on status only while attempts remain and the
                        // config allows this status/method combination.
                        if config.max_retries > 0
                            && attempt < config.max_retries
                            && config.should_retry(trigger, &method, has_idempotency_key)
                        {
                            // Prefer the server's Retry-After hint unless the
                            // config ignores it; fall back to exponential
                            // backoff either way.
                            let retry_after = parse_retry_after(resp.headers());
                            let backoff_duration = if config.ignore_retry_after {
                                calculate_backoff(&config.backoff, attempt)
                            } else {
                                retry_after
                                    .unwrap_or_else(|| calculate_backoff(&config.backoff, attempt))
                            };

                            // Decide whether to drain the discarded response
                            // body (bounded by `drain_limit`) so the underlying
                            // connection may be reused: skipped entirely when
                            // configured off, or when the declared
                            // Content-Length exceeds the limit.
                            let drain_limit = config.retry_response_drain_limit;
                            let should_drain = if config.skip_drain_on_retry {
                                tracing::trace!("Skipping drain: skip_drain_on_retry enabled");
                                false
                            } else if let Some(content_length) = resp
                                .headers()
                                .get(http::header::CONTENT_LENGTH)
                                .and_then(|v| v.to_str().ok())
                                .and_then(|s| s.parse::<u64>().ok())
                            {
                                if content_length > drain_limit as u64 {
                                    tracing::debug!(
                                        content_length,
                                        drain_limit,
                                        "Skipping drain: Content-Length exceeds limit"
                                    );
                                    false
                                } else {
                                    true
                                }
                            } else {
                                true
                            };

                            if should_drain
                                && let Err(e) = drain_response_body(resp, drain_limit).await
                            {
                                // Best-effort: a failed drain only costs
                                // connection reuse, never the retry itself.
                                tracing::debug!(
                                    error = %e,
                                    "Failed to drain response body before retry; connection may not be reused"
                                );
                            }

                            // Never sleep past the total deadline; bail out if
                            // no budget remains.
                            let effective_backoff =
                                if let Some((deadline, timeout_duration)) = deadline_info {
                                    let remaining = deadline
                                        .saturating_duration_since(tokio::time::Instant::now());
                                    if remaining.is_zero() {
                                        return Err(HttpError::DeadlineExceeded(timeout_duration));
                                    }
                                    backoff_duration.min(remaining)
                                } else {
                                    backoff_duration
                                };

                            tracing::debug!(
                                retry = attempt + 1,
                                max_retries = config.max_retries,
                                status = status_code,
                                trigger = ?trigger,
                                method = %method,
                                host = %url_host,
                                request_id = ?request_id,
                                backoff_ms = effective_backoff.as_millis(),
                                retry_after_used = retry_after.is_some() && !config.ignore_retry_after,
                                "Retrying request after status code"
                            );
                            tokio::time::sleep(effective_backoff).await;
                            attempt += 1;
                            continue;
                        }

                        // Success, non-retryable status, or retries exhausted:
                        // hand the response back unchanged.
                        return Ok(resp);
                    }
                    Err(err) => {
                        // Out of attempts: surface the error as-is.
                        if config.max_retries == 0 || attempt >= config.max_retries {
                            return Err(err);
                        }

                        let trigger = get_retry_trigger(&err);
                        if !config.should_retry(trigger, &method, has_idempotency_key) {
                            return Err(err);
                        }

                        // Errors carry no Retry-After; always use backoff.
                        let backoff_duration = calculate_backoff(&config.backoff, attempt);

                        // Same deadline clamping as the status path above.
                        let effective_backoff =
                            if let Some((deadline, timeout_duration)) = deadline_info {
                                let remaining =
                                    deadline.saturating_duration_since(tokio::time::Instant::now());
                                if remaining.is_zero() {
                                    return Err(HttpError::DeadlineExceeded(timeout_duration));
                                }
                                backoff_duration.min(remaining)
                            } else {
                                backoff_duration
                            };

                        tracing::debug!(
                            retry = attempt + 1,
                            max_retries = config.max_retries,
                            error = %err,
                            trigger = ?trigger,
                            method = %method,
                            host = %url_host,
                            request_id = ?request_id,
                            backoff_ms = effective_backoff.as_millis(),
                            "Retrying request after error"
                        );
                        tokio::time::sleep(effective_backoff).await;
                        attempt += 1;
                    }
                }
            }
        })
    }
}
312
313async fn drain_response_body(
337 response: Response<ResponseBody>,
338 limit: usize,
339) -> Result<(), HttpError> {
340 let (_parts, body) = response.into_parts();
341 let mut body = std::pin::pin!(body);
342 let mut drained = 0usize;
343
344 while let Some(frame) = body.frame().await {
345 let frame = frame.map_err(HttpError::Transport)?;
346 if let Some(chunk) = frame.data_ref() {
347 drained += chunk.len();
348 if drained >= limit {
349 break;
351 }
352 }
353 }
354
355 Ok(())
356}
357
358fn get_retry_trigger(err: &HttpError) -> RetryTrigger {
360 match err {
361 HttpError::Transport(_) => RetryTrigger::TransportError,
362 HttpError::Timeout(_) => RetryTrigger::Timeout,
363 _ => RetryTrigger::NonRetryable,
365 }
366}
367
368pub fn calculate_backoff(backoff: &ExponentialBackoff, attempt: usize) -> Duration {
372 const MAX_BACKOFF_SECS: f64 = 86400.0;
374
375 let attempt_i32 = i32::try_from(attempt).unwrap_or(i32::MAX);
378
379 let multiplier = if backoff.multiplier.is_finite() && backoff.multiplier >= 0.0 {
381 backoff.multiplier
382 } else {
383 1.0
384 };
385
386 let initial_secs = backoff.initial.as_secs_f64();
388 let initial_secs = if initial_secs.is_finite() && initial_secs >= 0.0 {
389 initial_secs
390 } else {
391 0.0
392 };
393
394 let max_secs = backoff.max.as_secs_f64();
396 let max_secs = if max_secs.is_finite() && max_secs >= 0.0 {
397 max_secs.min(MAX_BACKOFF_SECS)
398 } else {
399 MAX_BACKOFF_SECS
400 };
401
402 let base_duration = initial_secs * multiplier.powi(attempt_i32);
404
405 let clamped = if base_duration.is_finite() {
407 base_duration.min(max_secs).max(0.0)
408 } else {
409 max_secs
410 };
411 let duration = Duration::from_secs_f64(clamped);
412
413 let duration = if backoff.jitter {
415 let mut rng = rand::rng();
416 let jitter_factor = rng.random_range(0.0..=0.25);
417 let jitter = duration.mul_f64(jitter_factor);
418 duration + jitter
419 } else {
420 duration
421 };
422
423 let max_duration = Duration::from_secs_f64(max_secs);
425 duration.min(max_duration)
426}
427
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use super::*;
    use crate::config::IDEMPOTENCY_KEY_HEADER;
    use bytes::Bytes;
    use http::{Method, Request, Response, StatusCode};
    use http_body_util::Full;

    // Builds a boxed ResponseBody from a byte slice for stub responses.
    fn make_response_body(data: &[u8]) -> ResponseBody {
        let body = Full::new(Bytes::from(data.to_vec()));
        body.map_err(|e| -> Box<dyn std::error::Error + Send + Sync> { Box::new(e) })
            .boxed()
    }

    // A 200 response passes through untouched with exactly one call.
    #[tokio::test]
    async fn test_retry_layer_successful_request() {
        use std::sync::{Arc, Mutex};

        #[derive(Clone)]
        struct CountingService {
            call_count: Arc<Mutex<usize>>,
        }

        impl Service<Request<Full<Bytes>>> for CountingService {
            type Response = Response<ResponseBody>;
            type Error = HttpError;
            type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

            fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
                Poll::Ready(Ok(()))
            }

            fn call(&mut self, _req: Request<Full<Bytes>>) -> Self::Future {
                let count = self.call_count.clone();
                Box::pin(async move {
                    *count.lock().unwrap() += 1;
                    let response = Response::builder()
                        .status(StatusCode::OK)
                        .body(make_response_body(b""))
                        .unwrap();
                    Ok(response)
                })
            }
        }

        let call_count = Arc::new(Mutex::new(0));
        let service = CountingService {
            call_count: call_count.clone(),
        };

        let retry_config = RetryConfig::default();
        let layer = RetryLayer::new(retry_config);
        let mut retry_service = layer.layer(service);

        let req = Request::builder()
            .method(Method::GET)
            .uri("http://example.com")
            .body(Full::new(Bytes::new()))
            .unwrap();

        let result = retry_service.call(req).await;
        assert!(result.is_ok());
        assert_eq!(*call_count.lock().unwrap(), 1);
    }

    // POST (no idempotency key) must NOT be retried on a 5xx: one call only,
    // and the 500 response is returned to the caller.
    #[tokio::test]
    async fn test_retry_layer_post_not_retried_on_5xx() {
        use std::sync::{Arc, Mutex};

        #[derive(Clone)]
        struct ServerErrorService {
            call_count: Arc<Mutex<usize>>,
        }

        impl Service<Request<Full<Bytes>>> for ServerErrorService {
            type Response = Response<ResponseBody>;
            type Error = HttpError;
            type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

            fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
                Poll::Ready(Ok(()))
            }

            fn call(&mut self, _req: Request<Full<Bytes>>) -> Self::Future {
                let count = self.call_count.clone();
                Box::pin(async move {
                    *count.lock().unwrap() += 1;
                    Ok(Response::builder()
                        .status(StatusCode::INTERNAL_SERVER_ERROR)
                        .body(make_response_body(b"Internal Server Error"))
                        .unwrap())
                })
            }
        }

        let call_count = Arc::new(Mutex::new(0));
        let service = ServerErrorService {
            call_count: call_count.clone(),
        };

        let retry_config = RetryConfig {
            backoff: ExponentialBackoff::fast(),
            ..RetryConfig::default()
        };
        let layer = RetryLayer::new(retry_config);
        let mut retry_service = layer.layer(service);

        let req = Request::builder()
            .method(Method::POST)
            .uri("http://example.com")
            .body(Full::new(Bytes::new()))
            .unwrap();

        let result = retry_service.call(req).await;
        assert!(result.is_ok());
        let resp = result.unwrap();
        assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR);
        assert_eq!(*call_count.lock().unwrap(), 1);
    }

    // GET is retried on 5xx until the service succeeds (3rd call here).
    #[tokio::test]
    async fn test_retry_layer_get_retried_on_5xx() {
        use std::sync::{Arc, Mutex};

        #[derive(Clone)]
        struct FailThenSucceedService {
            call_count: Arc<Mutex<usize>>,
        }

        impl Service<Request<Full<Bytes>>> for FailThenSucceedService {
            type Response = Response<ResponseBody>;
            type Error = HttpError;
            type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

            fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
                Poll::Ready(Ok(()))
            }

            fn call(&mut self, _req: Request<Full<Bytes>>) -> Self::Future {
                let count = self.call_count.clone();
                Box::pin(async move {
                    let mut c = count.lock().unwrap();
                    *c += 1;
                    if *c < 3 {
                        Ok(Response::builder()
                            .status(StatusCode::INTERNAL_SERVER_ERROR)
                            .body(make_response_body(b"Internal Server Error"))
                            .unwrap())
                    } else {
                        Ok(Response::builder()
                            .status(StatusCode::OK)
                            .body(make_response_body(b""))
                            .unwrap())
                    }
                })
            }
        }

        let call_count = Arc::new(Mutex::new(0));
        let service = FailThenSucceedService {
            call_count: call_count.clone(),
        };

        let retry_config = RetryConfig {
            backoff: ExponentialBackoff::fast(),
            ..RetryConfig::default()
        };
        let layer = RetryLayer::new(retry_config);
        let mut retry_service = layer.layer(service);

        let req = Request::builder()
            .method(Method::GET)
            .uri("http://example.com")
            .body(Full::new(Bytes::new()))
            .unwrap();

        let result = retry_service.call(req).await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap().status(), StatusCode::OK);
        assert_eq!(*call_count.lock().unwrap(), 3);
    }

    // 429 is retried even for POST without an idempotency key.
    #[tokio::test]
    async fn test_retry_layer_always_retries_429() {
        use std::sync::{Arc, Mutex};

        #[derive(Clone)]
        struct RateLimitThenSucceedService {
            call_count: Arc<Mutex<usize>>,
        }

        impl Service<Request<Full<Bytes>>> for RateLimitThenSucceedService {
            type Response = Response<ResponseBody>;
            type Error = HttpError;
            type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

            fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
                Poll::Ready(Ok(()))
            }

            fn call(&mut self, _req: Request<Full<Bytes>>) -> Self::Future {
                let count = self.call_count.clone();
                Box::pin(async move {
                    let mut c = count.lock().unwrap();
                    *c += 1;
                    if *c < 2 {
                        Ok(Response::builder()
                            .status(StatusCode::TOO_MANY_REQUESTS)
                            .body(make_response_body(b"Rate limited"))
                            .unwrap())
                    } else {
                        Ok(Response::builder()
                            .status(StatusCode::OK)
                            .body(make_response_body(b""))
                            .unwrap())
                    }
                })
            }
        }

        let call_count = Arc::new(Mutex::new(0));
        let service = RateLimitThenSucceedService {
            call_count: call_count.clone(),
        };

        let retry_config = RetryConfig {
            backoff: ExponentialBackoff::fast(),
            ..RetryConfig::default()
        };
        let layer = RetryLayer::new(retry_config);
        let mut retry_service = layer.layer(service);

        let req = Request::builder()
            .method(Method::POST)
            .uri("http://example.com")
            .body(Full::new(Bytes::new()))
            .unwrap();

        let result = retry_service.call(req).await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap().status(), StatusCode::OK);
        assert_eq!(*call_count.lock().unwrap(), 2);
    }

    // Transport errors on a GET are retried until success.
    #[tokio::test]
    async fn test_retry_layer_retries_transport_errors() {
        use std::sync::{Arc, Mutex};

        #[derive(Clone)]
        struct FailThenSucceedService {
            call_count: Arc<Mutex<usize>>,
        }

        impl Service<Request<Full<Bytes>>> for FailThenSucceedService {
            type Response = Response<ResponseBody>;
            type Error = HttpError;
            type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

            fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
                Poll::Ready(Ok(()))
            }

            fn call(&mut self, _req: Request<Full<Bytes>>) -> Self::Future {
                let count = self.call_count.clone();
                Box::pin(async move {
                    let mut c = count.lock().unwrap();
                    *c += 1;
                    if *c < 3 {
                        Err(HttpError::Transport(Box::new(std::io::Error::new(
                            std::io::ErrorKind::ConnectionReset,
                            "connection reset",
                        ))))
                    } else {
                        Ok(Response::builder()
                            .status(StatusCode::OK)
                            .body(make_response_body(b""))
                            .unwrap())
                    }
                })
            }
        }

        let call_count = Arc::new(Mutex::new(0));
        let service = FailThenSucceedService {
            call_count: call_count.clone(),
        };

        let retry_config = RetryConfig {
            backoff: ExponentialBackoff::fast(),
            ..RetryConfig::default()
        };
        let layer = RetryLayer::new(retry_config);
        let mut retry_service = layer.layer(service);

        let req = Request::builder()
            .method(Method::GET)
            .uri("http://example.com")
            .body(Full::new(Bytes::new()))
            .unwrap();

        let result = retry_service.call(req).await;
        assert!(result.is_ok());
        assert_eq!(*call_count.lock().unwrap(), 3);
    }

    // POST (no idempotency key) is NOT retried on transport errors: the
    // error propagates after a single call.
    #[tokio::test]
    async fn test_retry_layer_post_not_retried_on_transport_error() {
        use std::sync::{Arc, Mutex};

        #[derive(Clone)]
        struct TransportErrorService {
            call_count: Arc<Mutex<usize>>,
        }

        impl Service<Request<Full<Bytes>>> for TransportErrorService {
            type Response = Response<ResponseBody>;
            type Error = HttpError;
            type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

            fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
                Poll::Ready(Ok(()))
            }

            fn call(&mut self, _req: Request<Full<Bytes>>) -> Self::Future {
                let count = self.call_count.clone();
                Box::pin(async move {
                    *count.lock().unwrap() += 1;
                    Err(HttpError::Transport(Box::new(std::io::Error::new(
                        std::io::ErrorKind::ConnectionReset,
                        "connection reset",
                    ))))
                })
            }
        }

        let call_count = Arc::new(Mutex::new(0));
        let service = TransportErrorService {
            call_count: call_count.clone(),
        };

        let retry_config = RetryConfig {
            backoff: ExponentialBackoff::fast(),
            ..RetryConfig::default()
        };
        let layer = RetryLayer::new(retry_config);
        let mut retry_service = layer.layer(service);

        let req = Request::builder()
            .method(Method::POST)
            .uri("http://example.com")
            .body(Full::new(Bytes::new()))
            .unwrap();

        let result = retry_service.call(req).await;
        assert!(result.is_err());
        assert_eq!(*call_count.lock().unwrap(), 1);
    }

    // A POST carrying the idempotency-key header IS retried on transport
    // errors, succeeding on the 3rd call.
    #[tokio::test]
    async fn test_retry_layer_post_with_idempotency_key_retried() {
        use std::sync::{Arc, Mutex};

        #[derive(Clone)]
        struct FailThenSucceedService {
            call_count: Arc<Mutex<usize>>,
        }

        impl Service<Request<Full<Bytes>>> for FailThenSucceedService {
            type Response = Response<ResponseBody>;
            type Error = HttpError;
            type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

            fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
                Poll::Ready(Ok(()))
            }

            fn call(&mut self, _req: Request<Full<Bytes>>) -> Self::Future {
                let count = self.call_count.clone();
                Box::pin(async move {
                    let mut c = count.lock().unwrap();
                    *c += 1;
                    if *c < 3 {
                        Err(HttpError::Transport(Box::new(std::io::Error::new(
                            std::io::ErrorKind::ConnectionReset,
                            "connection reset",
                        ))))
                    } else {
                        Ok(Response::builder()
                            .status(StatusCode::OK)
                            .body(make_response_body(b""))
                            .unwrap())
                    }
                })
            }
        }

        let call_count = Arc::new(Mutex::new(0));
        let service = FailThenSucceedService {
            call_count: call_count.clone(),
        };

        let retry_config = RetryConfig {
            backoff: ExponentialBackoff::fast(),
            ..RetryConfig::default()
        };
        let layer = RetryLayer::new(retry_config);
        let mut retry_service = layer.layer(service);

        let req = Request::builder()
            .method(Method::POST)
            .uri("http://example.com")
            .header(IDEMPOTENCY_KEY_HEADER, "unique-key-123")
            .body(Full::new(Bytes::new()))
            .unwrap();

        let result = retry_service.call(req).await;
        assert!(result.is_ok());
        assert_eq!(*call_count.lock().unwrap(), 3);
    }

    // JSON (deserialization) errors are non-retryable: one call, error out.
    #[tokio::test]
    async fn test_retry_layer_does_not_retry_json_errors() {
        use std::sync::{Arc, Mutex};

        #[derive(Clone)]
        struct JsonErrorService {
            call_count: Arc<Mutex<usize>>,
        }

        impl Service<Request<Full<Bytes>>> for JsonErrorService {
            type Response = Response<ResponseBody>;
            type Error = HttpError;
            type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

            fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
                Poll::Ready(Ok(()))
            }

            fn call(&mut self, _req: Request<Full<Bytes>>) -> Self::Future {
                let count = self.call_count.clone();
                Box::pin(async move {
                    *count.lock().unwrap() += 1;
                    let err: serde_json::Error =
                        serde_json::from_str::<serde_json::Value>("invalid").unwrap_err();
                    Err(HttpError::Json(err))
                })
            }
        }

        let call_count = Arc::new(Mutex::new(0));
        let service = JsonErrorService {
            call_count: call_count.clone(),
        };

        let retry_config = RetryConfig::default();
        let layer = RetryLayer::new(retry_config);
        let mut retry_service = layer.layer(service);

        let req = Request::builder()
            .method(Method::GET)
            .uri("http://example.com")
            .body(Full::new(Bytes::new()))
            .unwrap();

        let result = retry_service.call(req).await;
        assert!(result.is_err());
        assert_eq!(*call_count.lock().unwrap(), 1);
    }

    // Pure exponential growth (100ms * 2^n), capped at `max`.
    #[test]
    fn test_calculate_backoff_no_jitter() {
        let backoff = ExponentialBackoff {
            initial: Duration::from_millis(100),
            max: Duration::from_secs(10),
            multiplier: 2.0,
            jitter: false,
        };

        let backoff0 = calculate_backoff(&backoff, 0);
        assert_eq!(backoff0, Duration::from_millis(100));

        let backoff1 = calculate_backoff(&backoff, 1);
        assert_eq!(backoff1, Duration::from_millis(200));

        let backoff2 = calculate_backoff(&backoff, 2);
        assert_eq!(backoff2, Duration::from_millis(400));

        let backoff_capped = calculate_backoff(&backoff, 10);
        assert_eq!(backoff_capped, Duration::from_secs(10));
    }

    // Jitter adds at most 25% on top of the base delay.
    #[test]
    fn test_calculate_backoff_with_jitter() {
        let backoff = ExponentialBackoff {
            initial: Duration::from_millis(100),
            max: Duration::from_secs(10),
            multiplier: 2.0,
            jitter: true,
        };

        let backoff0 = calculate_backoff(&backoff, 0);
        assert!(backoff0 >= Duration::from_millis(100));
        assert!(backoff0 <= Duration::from_millis(125));
    }

    // NaN multiplier is sanitized to 1.0 (constant delay).
    #[test]
    fn test_calculate_backoff_with_nan_multiplier() {
        let backoff = ExponentialBackoff {
            initial: Duration::from_millis(100),
            max: Duration::from_secs(10),
            multiplier: f64::NAN,
            jitter: false,
        };

        let result = calculate_backoff(&backoff, 0);
        assert_eq!(result, Duration::from_millis(100));

        let result1 = calculate_backoff(&backoff, 1);
        assert_eq!(result1, Duration::from_millis(100));
    }

    // Infinite multiplier is likewise sanitized to 1.0.
    #[test]
    fn test_calculate_backoff_with_infinity_multiplier() {
        let backoff = ExponentialBackoff {
            initial: Duration::from_millis(100),
            max: Duration::from_secs(10),
            multiplier: f64::INFINITY,
            jitter: false,
        };

        let result = calculate_backoff(&backoff, 0);
        assert_eq!(result, Duration::from_millis(100));
    }

    // Negative multiplier is sanitized to 1.0 (no negative durations).
    #[test]
    fn test_calculate_backoff_with_negative_multiplier() {
        let backoff = ExponentialBackoff {
            initial: Duration::from_millis(100),
            max: Duration::from_secs(10),
            multiplier: -2.0,
            jitter: false,
        };

        let result = calculate_backoff(&backoff, 0);
        assert_eq!(result, Duration::from_millis(100));
    }

    // usize::MAX attempts saturates instead of overflowing; capped at `max`.
    #[test]
    fn test_calculate_backoff_with_huge_attempt() {
        let backoff = ExponentialBackoff {
            initial: Duration::from_millis(100),
            max: Duration::from_secs(10),
            multiplier: 2.0,
            jitter: false,
        };

        let result = calculate_backoff(&backoff, usize::MAX);
        assert_eq!(result, Duration::from_secs(10));
    }

    // Retry-After: 0 overrides the (deliberately long) 10s backoff config,
    // so the retry happens almost immediately.
    #[tokio::test]
    async fn test_retry_layer_uses_retry_after_header() {
        use std::sync::{Arc, Mutex};

        #[derive(Clone)]
        struct RetryAfterService {
            call_count: Arc<Mutex<usize>>,
        }

        impl Service<Request<Full<Bytes>>> for RetryAfterService {
            type Response = Response<ResponseBody>;
            type Error = HttpError;
            type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

            fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
                Poll::Ready(Ok(()))
            }

            fn call(&mut self, _req: Request<Full<Bytes>>) -> Self::Future {
                let count = self.call_count.clone();
                Box::pin(async move {
                    let mut c = count.lock().unwrap();
                    *c += 1;
                    if *c < 2 {
                        Ok(Response::builder()
                            .status(StatusCode::TOO_MANY_REQUESTS)
                            .header(http::header::RETRY_AFTER, "0")
                            .body(make_response_body(b"Rate limited"))
                            .unwrap())
                    } else {
                        Ok(Response::builder()
                            .status(StatusCode::OK)
                            .body(make_response_body(b""))
                            .unwrap())
                    }
                })
            }
        }

        let call_count = Arc::new(Mutex::new(0));
        let service = RetryAfterService {
            call_count: call_count.clone(),
        };

        let retry_config = RetryConfig {
            backoff: ExponentialBackoff {
                initial: Duration::from_secs(10),
                jitter: false,
                ..ExponentialBackoff::default()
            },
            ignore_retry_after: false,
            ..RetryConfig::default()
        };
        let layer = RetryLayer::new(retry_config);
        let mut retry_service = layer.layer(service);

        let req = Request::builder()
            .method(Method::POST)
            .uri("http://example.com")
            .body(Full::new(Bytes::new()))
            .unwrap();

        let start = std::time::Instant::now();
        let result = retry_service.call(req).await;
        let elapsed = start.elapsed();

        assert!(result.is_ok());
        assert_eq!(*call_count.lock().unwrap(), 2);

        assert!(
            elapsed < Duration::from_secs(1),
            "Expected quick retry using Retry-After, but took {elapsed:?}",
        );
    }

    // With ignore_retry_after the server's 10s hint is ignored and the fast
    // local backoff is used instead.
    #[tokio::test]
    async fn test_retry_layer_ignores_retry_after_when_configured() {
        use std::sync::{Arc, Mutex};

        #[derive(Clone)]
        struct RetryAfterService {
            call_count: Arc<Mutex<usize>>,
        }

        impl Service<Request<Full<Bytes>>> for RetryAfterService {
            type Response = Response<ResponseBody>;
            type Error = HttpError;
            type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

            fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
                Poll::Ready(Ok(()))
            }

            fn call(&mut self, _req: Request<Full<Bytes>>) -> Self::Future {
                let count = self.call_count.clone();
                Box::pin(async move {
                    let mut c = count.lock().unwrap();
                    *c += 1;
                    if *c < 2 {
                        Ok(Response::builder()
                            .status(StatusCode::TOO_MANY_REQUESTS)
                            .header(http::header::RETRY_AFTER, "10")
                            .body(make_response_body(b"Rate limited"))
                            .unwrap())
                    } else {
                        Ok(Response::builder()
                            .status(StatusCode::OK)
                            .body(make_response_body(b""))
                            .unwrap())
                    }
                })
            }
        }

        let call_count = Arc::new(Mutex::new(0));
        let service = RetryAfterService {
            call_count: call_count.clone(),
        };

        let retry_config = RetryConfig {
            backoff: ExponentialBackoff::fast(),
            ignore_retry_after: true,
            ..RetryConfig::default()
        };
        let layer = RetryLayer::new(retry_config);
        let mut retry_service = layer.layer(service);

        let req = Request::builder()
            .method(Method::POST)
            .uri("http://example.com")
            .body(Full::new(Bytes::new()))
            .unwrap();

        let start = std::time::Instant::now();
        let result = retry_service.call(req).await;
        let elapsed = start.elapsed();

        assert!(result.is_ok());
        assert_eq!(*call_count.lock().unwrap(), 2);

        assert!(
            elapsed < Duration::from_secs(1),
            "Expected quick retry using backoff policy (1ms), but took {elapsed:?}",
        );
    }

    // X-Retry-Attempt is absent on attempt 0 and counts "1", "2" on retries.
    #[tokio::test]
    async fn test_retry_attempt_header_added_on_retry() {
        use std::sync::{Arc, Mutex};

        #[derive(Clone)]
        struct HeaderCapturingService {
            call_count: Arc<Mutex<usize>>,
            captured_headers: Arc<Mutex<Vec<Option<String>>>>,
        }

        impl Service<Request<Full<Bytes>>> for HeaderCapturingService {
            type Response = Response<ResponseBody>;
            type Error = HttpError;
            type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

            fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
                Poll::Ready(Ok(()))
            }

            fn call(&mut self, req: Request<Full<Bytes>>) -> Self::Future {
                let count = self.call_count.clone();
                let captured_headers = self.captured_headers.clone();

                let retry_header = req
                    .headers()
                    .get(RETRY_ATTEMPT_HEADER)
                    .map(|v| v.to_str().unwrap_or("invalid").to_owned());

                Box::pin(async move {
                    let mut c = count.lock().unwrap();
                    *c += 1;
                    captured_headers.lock().unwrap().push(retry_header);

                    if *c < 3 {
                        Err(HttpError::Transport(Box::new(std::io::Error::new(
                            std::io::ErrorKind::ConnectionReset,
                            "connection reset",
                        ))))
                    } else {
                        Ok(Response::builder()
                            .status(StatusCode::OK)
                            .body(make_response_body(b""))
                            .unwrap())
                    }
                })
            }
        }

        let call_count = Arc::new(Mutex::new(0));
        let captured_headers = Arc::new(Mutex::new(Vec::new()));
        let service = HeaderCapturingService {
            call_count: call_count.clone(),
            captured_headers: captured_headers.clone(),
        };

        let retry_config = RetryConfig {
            backoff: ExponentialBackoff::fast(),
            ..RetryConfig::default()
        };
        let layer = RetryLayer::new(retry_config);
        let mut retry_service = layer.layer(service);

        let req = Request::builder()
            .method(Method::GET)
            .uri("http://example.com")
            .body(Full::new(Bytes::new()))
            .unwrap();

        let result = retry_service.call(req).await;
        assert!(result.is_ok());
        assert_eq!(*call_count.lock().unwrap(), 3);

        let headers = captured_headers.lock().unwrap();
        assert_eq!(headers.len(), 3);
        assert_eq!(headers[0], None);
        assert_eq!(headers[1], Some("1".to_owned()));
        assert_eq!(headers[2], Some("2".to_owned()));
    }

    // When retries are exhausted on a retryable status, the last response is
    // returned Ok (not an error); max_retries=2 means 3 total calls.
    #[tokio::test]
    async fn test_retry_layer_exhausted_returns_ok_with_status() {
        use std::sync::{Arc, Mutex};

        #[derive(Clone)]
        struct AlwaysFailService {
            call_count: Arc<Mutex<usize>>,
        }

        impl Service<Request<Full<Bytes>>> for AlwaysFailService {
            type Response = Response<ResponseBody>;
            type Error = HttpError;
            type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

            fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
                Poll::Ready(Ok(()))
            }

            fn call(&mut self, _req: Request<Full<Bytes>>) -> Self::Future {
                let count = self.call_count.clone();
                Box::pin(async move {
                    *count.lock().unwrap() += 1;
                    Ok(Response::builder()
                        .status(StatusCode::INTERNAL_SERVER_ERROR)
                        .body(make_response_body(b"error"))
                        .unwrap())
                })
            }
        }

        let call_count = Arc::new(Mutex::new(0));
        let service = AlwaysFailService {
            call_count: call_count.clone(),
        };

        let retry_config = RetryConfig {
            max_retries: 2,
            backoff: ExponentialBackoff::fast(),
            ..RetryConfig::default()
        };
        let layer = RetryLayer::new(retry_config);
        let mut retry_service = layer.layer(service);

        let req = Request::builder()
            .method(Method::GET)
            .uri("http://example.com")
            .body(Full::new(Bytes::new()))
            .unwrap();

        let result = retry_service.call(req).await;

        assert!(result.is_ok());
        let resp = result.unwrap();
        assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR);

        assert_eq!(*call_count.lock().unwrap(), 3);
    }

    // A non-retryable status (404) passes straight through after one call,
    // even with retry budget remaining.
    #[tokio::test]
    async fn test_retry_layer_non_retryable_status_passes_through() {
        use std::sync::{Arc, Mutex};

        #[derive(Clone)]
        struct NotFoundService {
            call_count: Arc<Mutex<usize>>,
        }

        impl Service<Request<Full<Bytes>>> for NotFoundService {
            type Response = Response<ResponseBody>;
            type Error = HttpError;
            type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

            fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
                Poll::Ready(Ok(()))
            }

            fn call(&mut self, _req: Request<Full<Bytes>>) -> Self::Future {
                let count = self.call_count.clone();
                Box::pin(async move {
                    *count.lock().unwrap() += 1;
                    Ok(Response::builder()
                        .status(StatusCode::NOT_FOUND)
                        .body(make_response_body(b"not found"))
                        .unwrap())
                })
            }
        }

        let call_count = Arc::new(Mutex::new(0));
        let service = NotFoundService {
            call_count: call_count.clone(),
        };

        let retry_config = RetryConfig {
            max_retries: 3,
            backoff: ExponentialBackoff::fast(),
            ..RetryConfig::default()
        };
        let layer = RetryLayer::new(retry_config);
        let mut retry_service = layer.layer(service);

        let req = Request::builder()
            .method(Method::GET)
            .uri("http://example.com")
            .body(Full::new(Bytes::new()))
            .unwrap();

        let result = retry_service.call(req).await;

        assert!(result.is_ok());
        let resp = result.unwrap();
        assert_eq!(resp.status(), StatusCode::NOT_FOUND);

        assert_eq!(*call_count.lock().unwrap(), 1);
    }
}