use std::future::Future;
use std::pin::Pin;
use std::task::Context;
use std::task::Poll;
use tokio::time::sleep;
use tower::{Layer, Service};
use crate::error::WechatError;
use crate::utils::jittered_delay;
/// Tower [`Layer`] that wraps a service with retry-on-transient-error behaviour.
///
/// Construct it with [`RetryMiddleware::new`] (or `Default`) and tune it with the
/// `with_*` builder methods.
#[derive(Clone, Debug)]
pub struct RetryMiddleware {
    /// Attempt budget per request. The service uses it as the total number of
    /// attempts and always makes at least one.
    max_retries: usize,
    /// Base delay in milliseconds handed to the jittered back-off between attempts.
    delay_ms: u64,
    /// Whether requests reporting `is_idempotent() == false` may be retried.
    retry_post: bool,
}
impl RetryMiddleware {
    /// Creates a layer with the default policy: `max_retries = 3`, `delay_ms = 100`,
    /// and non-idempotent requests are not retried.
    #[must_use]
    pub fn new() -> Self {
        Self {
            max_retries: 3,
            delay_ms: 100,
            retry_post: false,
        }
    }

    /// Sets the attempt budget.
    ///
    /// The service treats this value as the total number of attempts per request and
    /// always makes at least one, so `0` behaves like `1`.
    #[must_use]
    pub fn with_max_retries(mut self, max: usize) -> Self {
        self.max_retries = max;
        self
    }

    /// Sets the base delay, in milliseconds, fed to the jittered back-off between attempts.
    #[must_use]
    pub fn with_delay_ms(mut self, delay: u64) -> Self {
        self.delay_ms = delay;
        self
    }

    /// Allows (`true`) or forbids (`false`) retrying requests whose
    /// [`RetryableRequest::is_idempotent`] returns `false` (e.g. `POST`).
    #[must_use]
    pub fn with_retry_post(mut self, retry: bool) -> Self {
        self.retry_post = retry;
        self
    }

    /// Returns whether `error` is worth retrying; delegates to `WechatError::is_transient`.
    #[must_use]
    pub fn is_retryable_error(error: &WechatError) -> bool {
        error.is_transient()
    }
}
impl Default for RetryMiddleware {
fn default() -> Self {
Self::new()
}
}
impl<S> Layer<S> for RetryMiddleware {
type Service = RetryMiddlewareService<S>;
fn layer(&self, inner: S) -> Self::Service {
RetryMiddlewareService {
inner,
max_retries: self.max_retries,
delay_ms: self.delay_ms,
retry_post: self.retry_post,
}
}
}
/// Service produced by [`RetryMiddleware`]: forwards requests to `inner` and re-issues
/// them when the inner service fails with a retryable error.
#[derive(Clone, Debug)]
pub struct RetryMiddlewareService<S> {
    /// The wrapped service that actually handles requests.
    inner: S,
    /// Total attempt budget copied from the layer (at least one attempt is made).
    pub(crate) max_retries: usize,
    /// Base back-off delay in milliseconds.
    pub(crate) delay_ms: u64,
    /// Whether non-idempotent requests may be retried.
    pub(crate) retry_post: bool,
}
/// Lets the retry middleware ask whether a request may be safely replayed.
pub trait RetryableRequest {
    /// Returns `true` when sending this request more than once has the same effect
    /// as sending it once. Requests answering `false` are only retried when the
    /// middleware was built with `with_retry_post(true)`.
    fn is_idempotent(&self) -> bool;
}
impl RetryableRequest for reqwest::Request {
    /// Reports whether the request's HTTP method is idempotent as defined by
    /// RFC 9110 §9.2.2: `GET`, `HEAD`, `OPTIONS`, `TRACE`, `PUT` and `DELETE`.
    ///
    /// Every other method (`POST`, `PATCH`, `CONNECT` and any extension method) is
    /// reported as non-idempotent. A whitelist is used so that unknown methods default
    /// to the safe answer and are only replayed when `retry_post` is enabled.
    fn is_idempotent(&self) -> bool {
        matches!(
            self.method(),
            &reqwest::Method::GET
                | &reqwest::Method::HEAD
                | &reqwest::Method::OPTIONS
                | &reqwest::Method::TRACE
                | &reqwest::Method::PUT
                | &reqwest::Method::DELETE
        )
    }
}
impl<S, R> Service<R> for RetryMiddlewareService<S>
where
    S: Service<R> + Send + Clone + 'static,
    S::Future: Send,
    S::Error: std::fmt::Debug + Send + 'static,
    S::Response: Send,
    R: Send + Clone + RetryableRequest + 'static,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    /// Readiness is delegated to the wrapped service.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Sends `req` through the wrapped service, retrying on retryable errors.
    ///
    /// * A request whose `is_idempotent()` is `false` is sent exactly once unless
    ///   `retry_post` is set.
    /// * `max_retries` is the total attempt budget; at least one attempt is always made.
    /// * Errors are classified by `check_error_retryable`. A non-retryable error is
    ///   returned immediately. Once the budget is spent, the last attempt's error is
    ///   returned.
    /// * Between attempts the task sleeps for `jittered_delay(delay_ms, attempt)`;
    ///   there is no sleep after the final attempt.
    /// * Before each retry the wrapped service is driven to readiness again. A
    ///   readiness error is returned as-is.
    fn call(&mut self, req: R) -> Self::Future {
        // `poll_ready` readied `self.inner`, not a clone of it. Move the readied instance
        // into the future and leave a fresh clone behind for the next request; calling
        // an un-polled clone would break tower's `poll_ready`/`call` contract.
        let replacement = self.inner.clone();
        let mut inner = std::mem::replace(&mut self.inner, replacement);
        let max_retries = self.max_retries;
        let delay_ms = self.delay_ms;
        let retry_post = self.retry_post;
        Box::pin(async move {
            if !req.is_idempotent() && !retry_post {
                return inner.call(req).await;
            }
            let attempts = max_retries.max(1);
            let mut last_error: Option<S::Error> = None;
            for attempt in 0..attempts {
                if attempt > 0 {
                    // Every retry is a fresh `call`, so readiness must be re-established
                    // first (the first attempt uses the instance readied by `poll_ready`).
                    if let Err(e) = std::future::poll_fn(|cx| inner.poll_ready(cx)).await {
                        return Err(e);
                    }
                }
                match inner.call(req.clone()).await {
                    Ok(response) => return Ok(response),
                    Err(e) if check_error_retryable(&e) => {
                        last_error = Some(e);
                        if attempt + 1 < attempts {
                            sleep(jittered_delay(
                                delay_ms,
                                u32::try_from(attempt).unwrap_or(u32::MAX),
                            ))
                            .await;
                        }
                    }
                    Err(e) => return Err(e),
                }
            }
            // The loop runs at least once and every iteration either returns or records
            // an error, so `last_error` is always populated here.
            Err(last_error.expect("retry loop completed without capturing an error"))
        })
    }
}
fn check_error_retryable<E: std::fmt::Debug + 'static>(error: &E) -> bool {
if let Some(wechat_err) = (error as &dyn std::any::Any).downcast_ref::<WechatError>() {
return RetryMiddleware::is_retryable_error(wechat_err);
}
if let Some(reqwest_err) = (error as &dyn std::any::Any).downcast_ref::<reqwest::Error>() {
return is_retryable_reqwest_error(reqwest_err);
}
false
}
fn is_retryable_reqwest_error(error: &reqwest::Error) -> bool {
match error.status() {
Some(status) => status.is_server_error() || status.as_u16() == 429,
None => true,
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::WechatError;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    #[test]
    fn test_retry_middleware_exists() {
        // Pin the documented defaults and check that `Default` agrees with `new()`.
        let middleware = RetryMiddleware::new();
        assert_eq!(middleware.max_retries, 3);
        assert_eq!(middleware.delay_ms, 100);
        assert!(!middleware.retry_post);
        let default = RetryMiddleware::default();
        assert_eq!(default.max_retries, middleware.max_retries);
        assert_eq!(default.delay_ms, middleware.delay_ms);
        assert_eq!(default.retry_post, middleware.retry_post);
    }

    #[test]
    fn test_is_retryable_error_exhaustive_variants() {
        use crate::error::HttpError;
        let reqwest_error = reqwest::Client::new().get("http://").build().unwrap_err();
        let reqwest_err = WechatError::Http(HttpError::Reqwest(std::sync::Arc::new(reqwest_error)));
        assert!(RetryMiddleware::is_retryable_error(&reqwest_err));
        let decode_err = WechatError::Http(HttpError::Decode("bad".into()));
        assert!(!RetryMiddleware::is_retryable_error(&decode_err));
        assert!(RetryMiddleware::is_retryable_error(&WechatError::Api {
            code: -1,
            message: "busy".into(),
        }));
        assert!(!RetryMiddleware::is_retryable_error(&WechatError::Api {
            code: 40001,
            message: "invalid".into(),
        }));
        let non_retryable: Vec<WechatError> = vec![
            WechatError::Json(serde_json::from_str::<String>("bad").unwrap_err()),
            WechatError::Token("t".into()),
            WechatError::Config("c".into()),
            WechatError::Signature("s".into()),
            WechatError::Crypto("cr".into()),
            WechatError::InvalidAppId("a".into()),
            WechatError::InvalidOpenId("o".into()),
            WechatError::InvalidAccessToken("at".into()),
            WechatError::InvalidAppSecret("as".into()),
            WechatError::InvalidSessionKey("sk".into()),
            WechatError::InvalidUnionId("u".into()),
        ];
        for err in &non_retryable {
            assert!(
                !RetryMiddleware::is_retryable_error(err),
                "Expected non-retryable: {:?}",
                err,
            );
        }
    }

    #[test]
    fn test_retryable_error_codes() {
        let err = WechatError::Api {
            code: -1,
            message: "System busy".to_string(),
        };
        assert!(RetryMiddleware::is_retryable_error(&err));
        let err = WechatError::Api {
            code: 45009,
            message: "API limit".to_string(),
        };
        assert!(RetryMiddleware::is_retryable_error(&err));
        let err = WechatError::Api {
            code: 40001,
            message: "Invalid credential".to_string(),
        };
        assert!(!RetryMiddleware::is_retryable_error(&err));
    }

    #[test]
    fn test_decode_error_not_retryable() {
        use crate::error::HttpError;
        let decode_err = WechatError::Http(HttpError::Decode("response decode error".to_string()));
        assert!(
            !RetryMiddleware::is_retryable_error(&decode_err),
            "Decode errors should not be retried",
        );
    }

    #[tokio::test]
    async fn test_check_error_retryable_for_reqwest_503_status_error() {
        let mock_server = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path("/status-503"))
            .respond_with(ResponseTemplate::new(503))
            .mount(&mock_server)
            .await;
        let response = reqwest::Client::new()
            .get(format!("{}/status-503", mock_server.uri()))
            .send()
            .await
            .unwrap();
        let err = response.error_for_status().unwrap_err();
        assert!(
            check_error_retryable(&err),
            "503 status errors should be considered retryable",
        );
    }

    #[tokio::test]
    async fn test_check_error_retryable_for_reqwest_400_status_error() {
        let mock_server = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path("/status-400"))
            .respond_with(ResponseTemplate::new(400))
            .mount(&mock_server)
            .await;
        let response = reqwest::Client::new()
            .get(format!("{}/status-400", mock_server.uri()))
            .send()
            .await
            .unwrap();
        let err = response.error_for_status().unwrap_err();
        assert!(
            !check_error_retryable(&err),
            "400 status errors should not be considered retryable",
        );
    }

    #[test]
    fn test_middleware_configuration() {
        let middleware = RetryMiddleware::new()
            .with_max_retries(5)
            .with_delay_ms(200)
            .with_retry_post(true);
        assert_eq!(middleware.max_retries, 5);
        assert_eq!(middleware.delay_ms, 200);
        assert!(middleware.retry_post);
    }

    /// Request that is always eligible for retries.
    #[derive(Clone)]
    struct MockIdempotentRequest;
    impl RetryableRequest for MockIdempotentRequest {
        fn is_idempotent(&self) -> bool {
            true
        }
    }

    /// Request that is only retried when `retry_post` is enabled.
    #[derive(Clone)]
    struct NonIdempotentRequest;
    impl RetryableRequest for NonIdempotentRequest {
        fn is_idempotent(&self) -> bool {
            false
        }
    }

    /// Scripted reply; the argument is the zero-based number of the call being answered.
    type Responder = fn(usize) -> Result<String, WechatError>;

    /// Test double that counts calls (the counter is shared by all clones, because the
    /// middleware clones the inner service) and answers each call via `respond`.
    #[derive(Clone)]
    struct CountingService {
        calls: Arc<AtomicUsize>,
        respond: Responder,
    }

    impl CountingService {
        /// Builds the service plus a handle for reading the call count afterwards.
        fn new(respond: Responder) -> (Self, Arc<AtomicUsize>) {
            let calls = Arc::new(AtomicUsize::new(0));
            let service = Self {
                calls: Arc::clone(&calls),
                respond,
            };
            (service, calls)
        }
    }

    impl<Req> Service<Req> for CountingService {
        type Response = String;
        type Error = WechatError;
        type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
        fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }
        fn call(&mut self, _req: Req) -> Self::Future {
            let call_index = self.calls.fetch_add(1, Ordering::SeqCst);
            let reply = (self.respond)(call_index);
            Box::pin(async move { reply })
        }
    }

    /// Retryable WeChat error (`-1`, "system busy").
    fn system_busy() -> WechatError {
        WechatError::Api {
            code: -1,
            message: "system busy".to_string(),
        }
    }

    fn always_busy(_call: usize) -> Result<String, WechatError> {
        Err(system_busy())
    }

    fn always_invalid_credential(_call: usize) -> Result<String, WechatError> {
        Err(WechatError::Api {
            code: 40001,
            message: "invalid credential".to_string(),
        })
    }

    fn always_success(_call: usize) -> Result<String, WechatError> {
        Ok("success".to_string())
    }

    fn busy_then_success(call: usize) -> Result<String, WechatError> {
        if call == 0 {
            Err(system_busy())
        } else {
            Ok("success".to_string())
        }
    }

    #[tokio::test]
    async fn test_max_retries_zero_should_still_attempt_once() {
        let (inner, calls) = CountingService::new(always_success);
        let mut service = RetryMiddleware::new().with_max_retries(0).layer(inner);
        let result: Result<String, WechatError> = service.call(MockIdempotentRequest).await;
        assert!(result.is_ok(), "max_retries=0 should still execute once");
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_non_retryable_error_returns_immediately() {
        let (inner, calls) = CountingService::new(always_invalid_credential);
        let mut service = RetryMiddleware::new().with_max_retries(3).layer(inner);
        let result: Result<String, WechatError> = service.call(MockIdempotentRequest).await;
        assert!(
            matches!(result, Err(WechatError::Api { code: 40001, .. })),
            "Should return the non-retryable error code"
        );
        assert_eq!(
            calls.load(Ordering::SeqCst),
            1,
            "non-retryable errors must not be retried"
        );
    }

    #[tokio::test]
    async fn test_retryable_error_with_max_retries_one() {
        let (inner, calls) = CountingService::new(always_busy);
        let mut service = RetryMiddleware::new()
            .with_max_retries(1)
            .with_delay_ms(1)
            .layer(inner);
        let result: Result<String, WechatError> = service.call(MockIdempotentRequest).await;
        assert!(matches!(result, Err(WechatError::Api { code: -1, .. })));
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_terminal_failure_all_retries_exhausted() {
        let (inner, calls) = CountingService::new(always_busy);
        let mut service = RetryMiddleware::new()
            .with_max_retries(2)
            .with_delay_ms(1)
            .layer(inner);
        let result: Result<String, WechatError> = service.call(MockIdempotentRequest).await;
        assert!(matches!(result, Err(WechatError::Api { code: -1, .. })));
        assert_eq!(calls.load(Ordering::SeqCst), 2, "whole attempt budget should be used");
    }

    #[tokio::test]
    async fn test_retryable_error_then_success_recovers() {
        let (inner, calls) = CountingService::new(busy_then_success);
        let mut service = RetryMiddleware::new()
            .with_max_retries(3)
            .with_delay_ms(1)
            .layer(inner);
        let result: Result<String, WechatError> = service.call(MockIdempotentRequest).await;
        assert_eq!(result.unwrap(), "success");
        assert_eq!(calls.load(Ordering::SeqCst), 2, "should stop retrying after success");
    }

    #[tokio::test]
    async fn test_success_case_no_retry() {
        let (inner, calls) = CountingService::new(always_success);
        let mut service = RetryMiddleware::new()
            .with_max_retries(3)
            .with_delay_ms(100)
            .layer(inner);
        let result: Result<String, WechatError> = service.call(MockIdempotentRequest).await;
        assert_eq!(result.unwrap(), "success");
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_non_idempotent_no_retry() {
        let (inner, calls) = CountingService::new(always_busy);
        let mut service = RetryMiddleware::new()
            .with_max_retries(3)
            .with_retry_post(false)
            .layer(inner);
        let result: Result<String, WechatError> = service.call(NonIdempotentRequest).await;
        assert!(result.is_err());
        assert_eq!(
            calls.load(Ordering::SeqCst),
            1,
            "non-idempotent requests must be sent once when retry_post is off"
        );
    }

    #[tokio::test]
    async fn test_non_idempotent_with_retry_enabled() {
        let (inner, calls) = CountingService::new(always_busy);
        let mut service = RetryMiddleware::new()
            .with_max_retries(2)
            .with_delay_ms(1)
            .with_retry_post(true)
            .layer(inner);
        let result: Result<String, WechatError> = service.call(NonIdempotentRequest).await;
        assert!(result.is_err());
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }
}