use http_extensions::HttpError;
use seatbelt::Recovery;
use seatbelt::retry::{Retry, RetryLayer};
use seatbelt::typestates::Set;
use crate::http_recovery::detect_recovery;
use crate::{HttpClone, HttpRecovery, HttpRequest, HttpResponse};
/// A [`RetryLayer`] specialised for HTTP: the input is an [`HttpRequest`] and the output is an
/// `http_extensions::Result<HttpResponse>`.
///
/// `S1` and `S2` are typestate parameters, both defaulting to [`Set`]. In this module,
/// `HttpRetryLayerExt::http_clone` produces a layer with `S1 = Set`, and
/// `HttpRetryLayerExt::http_recovery` produces a layer with `S2 = Set`.
pub type HttpRetryLayer<S1 = Set, S2 = Set> = RetryLayer<HttpRequest, http_extensions::Result<HttpResponse>, S1, S2>;
/// A [`Retry`] specialised for HTTP requests and `http_extensions::Result<HttpResponse>` outputs.
///
/// NOTE(review): `S` is presumably the wrapped inner service type — confirm against `seatbelt`.
pub type HttpRetry<S> = Retry<HttpRequest, http_extensions::Result<HttpResponse>, S>;
/// HTTP-specific configuration helpers for [`HttpRetryLayer`].
///
/// This trait is sealed via `sealed::Sealed`, which lives in a crate-private module and is
/// implemented only for [`HttpRetryLayer`], so the trait cannot be implemented outside this crate.
pub trait HttpRetryLayerExt<S1, S2>: sealed::Sealed {
    /// Applies the default HTTP configuration and returns a fully configured layer.
    ///
    /// This is equivalent to chaining [`Self::http_clone`] with [`HttpClone::default()`],
    /// [`Self::http_recovery`] with [`HttpRecovery::default()`], and
    /// [`Self::http_restore_request`].
    fn http_configure_defaults(self) -> HttpRetryLayer;
    /// Configures how the input request is cloned for attempts, using `clone_strategy`.
    ///
    /// The strategy is consulted with the current attempt and the previous recovery
    /// information supplied by the retry layer.
    fn http_clone(self, clone_strategy: HttpClone) -> HttpRetryLayer<Set, S2>;
    /// Configures how each attempt's outcome is classified for retrying, using `recovery`
    /// (anything convertible into [`HttpRecovery`]).
    fn http_recovery(self, recovery: impl Into<HttpRecovery>) -> HttpRetryLayer<S1, Set>;
    /// Enables restoring the original request from a failed attempt's [`HttpError`] so that it
    /// can be sent again.
    ///
    /// The request is only taken back from errors whose recovery kind is
    /// `RecoveryKind::Unavailable`; for any other kind the request stays attached to the error.
    fn http_restore_request(self) -> HttpRetryLayer<S1, S2>;
}
/// The only implementation of [`HttpRetryLayerExt`]; each method forwards to the generic
/// `RetryLayer` builder API with HTTP-specific closures.
impl<S1, S2> HttpRetryLayerExt<S1, S2> for HttpRetryLayer<S1, S2> {
    fn http_configure_defaults(self) -> HttpRetryLayer {
        // `http_clone` fixes `S1 = Set` and `http_recovery` fixes `S2 = Set`, so the chain
        // always ends in `HttpRetryLayer<Set, Set>` regardless of the starting typestate.
        self.http_clone(HttpClone::default())
            .http_recovery(HttpRecovery::default())
            .http_restore_request()
    }

    fn http_clone(self, clone_strategy: HttpClone) -> HttpRetryLayer<Set, S2> {
        // The strategy is moved into the closure; it receives the request plus the attempt
        // and previous recovery info taken from the retry arguments.
        self.clone_input_with(move |request, args| {
            clone_strategy.try_clone(request, args.attempt(), args.previous_recovery())
        })
    }

    fn http_recovery(self, recovery: impl Into<HttpRecovery>) -> HttpRetryLayer<S1, Set> {
        // Convert once up front so the closure only borrows the concrete `HttpRecovery`.
        let recovery = recovery.into();
        self.recovery_with(move |out, args| detect_recovery(out, &recovery, args.clock()))
    }

    // Spelled out as `HttpRetryLayer<S1, S2>` (rather than `Self`) to match the trait
    // declaration and the sibling methods above.
    fn http_restore_request(self) -> HttpRetryLayer<S1, S2> {
        self.restore_input_from_error(|error, _args| extract_http_request(error))
    }
}
pub(crate) mod sealed {
use super::*;
#[expect(unnameable_types, reason = "intentional, sealed trait pattern")]
pub trait Sealed {}
impl<S1, S2> Sealed for HttpRetryLayer<S1, S2> {}
}
fn extract_http_request(error: &mut HttpError) -> Option<HttpRequest> {
if error.recovery().kind() != seatbelt::RecoveryKind::Unavailable {
return None;
}
error.take_request()
}
#[cfg_attr(coverage_nightly, coverage(off))]
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::{Arc, Mutex};
    use futures::executor::block_on;
    use http::{Method, StatusCode};
    use http_extensions::routing::{BaseUriConflict, Router};
    use http_extensions::{FakeHandler, HttpRequestBuilder, HttpResponseBuilder};
    use layered::{Service, Stack};
    use seatbelt::Attempt;
    use templated_uri::BaseUri;
    use tick::ClockControl;
    use super::*;

    /// With the default configuration, a request using the fake builder's default method
    /// (presumably a safe method such as GET — confirm) is retried after a 500 and the
    /// subsequent 200 is returned.
    #[test]
    fn retry_recovers_with_safe_methods() {
        // NOTE(review): auto-advancing timers presumably lets retry delays elapse without
        // real waiting — confirm against `tick::ClockControl`.
        let clock = ClockControl::default().auto_advance_timers(true).to_clock();
        let context = crate::HttpResilienceContext::new(&clock);
        // The fake handler answers 500 on the first call and 200 on the second.
        let service = (
            HttpRetry::layer("test", &context).http_configure_defaults(),
            FakeHandler::from_status_codes([StatusCode::INTERNAL_SERVER_ERROR, StatusCode::OK]),
        )
            .into_service();
        let request = HttpRequestBuilder::new_fake().uri("https://example.com").build().unwrap();
        let response = block_on(service.execute(request)).unwrap();
        assert_eq!(response.status(), StatusCode::OK);
    }

    /// With the default configuration, a POST that receives a 500 is not retried: the first
    /// 500 response is returned even though the handler would answer 200 next.
    #[test]
    fn retry_fails_with_unsafe_methods() {
        let clock = ClockControl::default().auto_advance_timers(true).to_clock();
        let context = crate::HttpResilienceContext::new(&clock);
        let service = (
            HttpRetry::layer("test", &context).http_configure_defaults(),
            FakeHandler::from_status_codes([StatusCode::INTERNAL_SERVER_ERROR, StatusCode::OK]),
        )
            .into_service();
        let request = HttpRequestBuilder::new_fake()
            .uri("https://example.com")
            .method(Method::POST)
            .build()
            .unwrap();
        let response = block_on(service.execute(request)).unwrap();
        assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
    }

    /// A POST is retried when the handler fails with an `unavailable` error that carries the
    /// request back, exercising `http_restore_request` / `extract_http_request`.
    #[test]
    fn restore_request_from_unavailable_error() {
        let call_count = Arc::new(AtomicU32::new(0));
        let counter = Arc::clone(&call_count);
        // Calls 0 and 1 fail with an unavailable error that hands the request back; call 2
        // succeeds with 200.
        let handler = FakeHandler::from_fn(move |req| {
            let n = counter.fetch_add(1, Ordering::Relaxed);
            if n < 2 {
                Err(HttpError::unavailable("service down").with_request(req))
            } else {
                HttpResponseBuilder::new_fake().status(StatusCode::OK).build()
            }
        });
        let clock = ClockControl::default().auto_advance_timers(true).to_clock();
        let context = crate::HttpResilienceContext::new(&clock);
        // Two retries on top of the initial attempt are needed to reach the successful call.
        let service = (
            HttpRetry::layer("test", &context)
                .http_configure_defaults()
                .handle_unavailable(true)
                .max_retry_attempts(2),
            handler,
        )
            .into_service();
        let request = HttpRequestBuilder::new_fake()
            .method(Method::POST)
            .uri("https://example.com")
            .build()
            .unwrap();
        let response = block_on(service.execute(request)).unwrap();
        assert_eq!(response.status(), StatusCode::OK);
        // Initial attempt + 2 retries.
        assert_eq!(call_count.load(Ordering::Relaxed), 3);
    }

    /// For an error whose recovery kind is not `Unavailable`, `extract_http_request` returns
    /// `None` and leaves the attached request in place.
    #[test]
    fn extract_http_request_returns_none_for_non_unavailable_error() {
        let request = HttpRequestBuilder::new_fake().uri("https://example.com").build().unwrap();
        let mut error = HttpError::other("transient failure", seatbelt::RecoveryInfo::retry(), "test").with_request(request);
        assert!(extract_http_request(&mut error).is_none());
        // The request must still be attached to the error.
        assert!(error.take_request().is_some());
    }

    /// A custom `Router` attached to the request changes the base URI between attempts; the
    /// handler records every URI it sees and only succeeds on the last attempt.
    #[test]
    fn retry_routes_attempts_with_custom_router() {
        let clock = ClockControl::default().auto_advance_timers(true).to_clock();
        let context = crate::HttpResilienceContext::new(&clock);
        let captured_uris: Arc<Mutex<Vec<String>>> = Arc::new(Mutex::new(Vec::new()));
        let captured_uris_for_handler = Arc::clone(&captured_uris);
        let handler = FakeHandler::from_fn(move |request: HttpRequest| {
            captured_uris_for_handler
                .lock()
                .expect("mutex is only accessed in single-threaded test")
                .push(request.uri().to_string());
            // The `Attempt` extension is read from the request; return 500 until the final
            // attempt so that every configured attempt is exercised.
            let attempt = request.extensions().get::<Attempt>().unwrap();
            let status = if attempt.is_last() {
                StatusCode::OK
            } else {
                StatusCode::INTERNAL_SERVER_ERROR
            };
            HttpResponseBuilder::new_fake().status(status).build()
        });
        // Attempt index 1 maps to retry-1, anything else to retry-2.
        // NOTE(review): the meaning of the `true` argument is defined by `Router::custom`;
        // the expected URIs below show the first attempt still uses the original base URI.
        let router = Router::custom(
            |ctx| {
                Some(match ctx.attempt() {
                    1 => BaseUri::from_static("https://retry-1.example.com"),
                    _ => BaseUri::from_static("https://retry-2.example.com"),
                })
            },
            true,
        )
        .conflict_policy(BaseUriConflict::UseRouted);
        let service = (
            HttpRetry::layer("test", &context).http_configure_defaults().max_retry_attempts(2),
            handler,
        )
            .into_service();
        let request = HttpRequestBuilder::new_fake()
            .uri("https://primary.example.com/items")
            .extension(router)
            .build()
            .unwrap();
        let response = block_on(service.execute(request)).unwrap();
        assert_eq!(response.status(), StatusCode::OK);
        let uris = captured_uris
            .lock()
            .expect("mutex is only accessed in single-threaded test")
            .clone();
        // The path is preserved while the base URI is replaced on each retry.
        assert_eq!(
            uris,
            vec![
                "https://primary.example.com/items".to_string(),
                "https://retry-1.example.com/items".to_string(),
                "https://retry-2.example.com/items".to_string(),
            ],
        );
    }
}