use super::context::QueryContext;
use super::types::{BoxFuture, Middleware, MiddlewareResult, Next, QueryResponse};
use crate::QueryError;
use std::time::Duration;
#[derive(Debug, Clone)]
pub struct RetryConfig {
pub max_retries: u32,
pub initial_delay: Duration,
pub max_delay: Duration,
pub backoff_multiplier: f64,
pub jitter: bool,
pub retry_on: RetryPredicate,
}
impl Default for RetryConfig {
fn default() -> Self {
Self {
max_retries: 3,
initial_delay: Duration::from_millis(100),
max_delay: Duration::from_secs(10),
backoff_multiplier: 2.0,
jitter: true,
retry_on: RetryPredicate::Default,
}
}
}
impl RetryConfig {
pub fn new() -> Self {
Self::default()
}
pub fn max_retries(mut self, n: u32) -> Self {
self.max_retries = n;
self
}
pub fn initial_delay(mut self, delay: Duration) -> Self {
self.initial_delay = delay;
self
}
pub fn max_delay(mut self, delay: Duration) -> Self {
self.max_delay = delay;
self
}
pub fn backoff_multiplier(mut self, multiplier: f64) -> Self {
self.backoff_multiplier = multiplier;
self
}
pub fn jitter(mut self, enabled: bool) -> Self {
self.jitter = enabled;
self
}
pub fn retry_on(mut self, predicate: RetryPredicate) -> Self {
self.retry_on = predicate;
self
}
pub fn delay_for_attempt(&self, attempt: u32) -> Duration {
let base_delay =
self.initial_delay.as_millis() as f64 * self.backoff_multiplier.powi(attempt as i32);
let delay_ms = base_delay.min(self.max_delay.as_millis() as f64);
let final_delay = if self.jitter {
let jitter = delay_ms * 0.25 * rand_jitter();
delay_ms + jitter
} else {
delay_ms
};
Duration::from_millis(final_delay as u64)
}
}
#[derive(Debug, Clone)]
pub enum RetryPredicate {
Default,
Always,
Never,
ConnectionOnly,
TimeoutOnly,
Custom(Vec<RetryableError>),
}
impl RetryPredicate {
pub fn should_retry(&self, error: &QueryError) -> bool {
match self {
Self::Default => error.is_connection_error() || error.is_timeout(),
Self::Always => true,
Self::Never => false,
Self::ConnectionOnly => error.is_connection_error(),
Self::TimeoutOnly => error.is_timeout(),
Self::Custom(errors) => errors.iter().any(|e| e.matches(error)),
}
}
}
#[derive(Debug, Clone, Copy)]
pub enum RetryableError {
Connection,
Timeout,
Database,
Transaction,
}
impl RetryableError {
pub fn matches(&self, error: &QueryError) -> bool {
match self {
Self::Connection => error.is_connection_error(),
Self::Timeout => error.is_timeout(),
Self::Database => matches!(
error.code,
crate::error::ErrorCode::SqlSyntax
| crate::error::ErrorCode::InvalidParameter
| crate::error::ErrorCode::QueryTooComplex
),
Self::Transaction => matches!(
error.code,
crate::error::ErrorCode::TransactionFailed
| crate::error::ErrorCode::Deadlock
| crate::error::ErrorCode::SerializationFailure
| crate::error::ErrorCode::TransactionClosed
),
}
}
}
fn rand_jitter() -> f64 {
use std::collections::hash_map::RandomState;
use std::hash::{BuildHasher, Hasher};
let hasher = RandomState::new().build_hasher();
let hash = hasher.finish();
(hash % 1000) as f64 / 1000.0
}
pub struct RetryMiddleware {
config: RetryConfig,
}
impl RetryMiddleware {
pub fn new(config: RetryConfig) -> Self {
Self { config }
}
pub fn default_config() -> Self {
Self::new(RetryConfig::default())
}
pub fn config(&self) -> &RetryConfig {
&self.config
}
}
impl Default for RetryMiddleware {
fn default() -> Self {
Self::default_config()
}
}
impl Middleware for RetryMiddleware {
fn handle<'a>(
&'a self,
ctx: QueryContext,
next: Next<'a>,
) -> BoxFuture<'a, MiddlewareResult<QueryResponse>> {
Box::pin(async move {
next.run(ctx).await
})
}
fn name(&self) -> &'static str {
"RetryMiddleware"
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_retry_config_default() {
let config = RetryConfig::default();
assert_eq!(config.max_retries, 3);
assert!(config.jitter);
}
#[test]
fn test_retry_config_builder() {
let config = RetryConfig::new()
.max_retries(5)
.initial_delay(Duration::from_millis(50))
.jitter(false);
assert_eq!(config.max_retries, 5);
assert_eq!(config.initial_delay, Duration::from_millis(50));
assert!(!config.jitter);
}
#[test]
fn test_delay_calculation() {
let config = RetryConfig::new()
.initial_delay(Duration::from_millis(100))
.backoff_multiplier(2.0)
.jitter(false);
assert_eq!(config.delay_for_attempt(0), Duration::from_millis(100));
assert_eq!(config.delay_for_attempt(1), Duration::from_millis(200));
assert_eq!(config.delay_for_attempt(2), Duration::from_millis(400));
}
#[test]
fn test_delay_max_cap() {
let config = RetryConfig::new()
.initial_delay(Duration::from_secs(1))
.max_delay(Duration::from_secs(5))
.backoff_multiplier(10.0)
.jitter(false);
assert_eq!(config.delay_for_attempt(2), Duration::from_secs(5));
}
#[test]
fn test_retry_predicate_default() {
let predicate = RetryPredicate::Default;
assert!(predicate.should_retry(&QueryError::connection("test")));
assert!(predicate.should_retry(&QueryError::timeout(1000)));
assert!(!predicate.should_retry(&QueryError::not_found("User")));
}
#[test]
fn test_retry_predicate_custom() {
let predicate =
RetryPredicate::Custom(vec![RetryableError::Connection, RetryableError::Database]);
assert!(predicate.should_retry(&QueryError::connection("test")));
assert!(predicate.should_retry(&QueryError::sql_syntax("error", "SELECT")));
assert!(!predicate.should_retry(&QueryError::timeout(1000)));
}
}