prax_query/middleware/
retry.rs

1//! Retry middleware for automatic query retry on transient failures.
2
3use super::context::QueryContext;
4use super::types::{BoxFuture, Middleware, MiddlewareResult, Next, QueryResponse};
5use crate::QueryError;
6use std::time::Duration;
7
8/// Configuration for retry behavior.
9#[derive(Debug, Clone)]
10pub struct RetryConfig {
11    /// Maximum number of retry attempts.
12    pub max_retries: u32,
13    /// Initial delay between retries.
14    pub initial_delay: Duration,
15    /// Maximum delay between retries.
16    pub max_delay: Duration,
17    /// Multiplier for exponential backoff.
18    pub backoff_multiplier: f64,
19    /// Whether to add jitter to delays.
20    pub jitter: bool,
21    /// Predicate to determine if error is retryable.
22    pub retry_on: RetryPredicate,
23}
24
25impl Default for RetryConfig {
26    fn default() -> Self {
27        Self {
28            max_retries: 3,
29            initial_delay: Duration::from_millis(100),
30            max_delay: Duration::from_secs(10),
31            backoff_multiplier: 2.0,
32            jitter: true,
33            retry_on: RetryPredicate::Default,
34        }
35    }
36}
37
38impl RetryConfig {
39    /// Create a new retry config.
40    pub fn new() -> Self {
41        Self::default()
42    }
43
44    /// Set maximum retries.
45    pub fn max_retries(mut self, n: u32) -> Self {
46        self.max_retries = n;
47        self
48    }
49
50    /// Set initial delay.
51    pub fn initial_delay(mut self, delay: Duration) -> Self {
52        self.initial_delay = delay;
53        self
54    }
55
56    /// Set maximum delay.
57    pub fn max_delay(mut self, delay: Duration) -> Self {
58        self.max_delay = delay;
59        self
60    }
61
62    /// Set backoff multiplier.
63    pub fn backoff_multiplier(mut self, multiplier: f64) -> Self {
64        self.backoff_multiplier = multiplier;
65        self
66    }
67
68    /// Enable or disable jitter.
69    pub fn jitter(mut self, enabled: bool) -> Self {
70        self.jitter = enabled;
71        self
72    }
73
74    /// Set retry predicate.
75    pub fn retry_on(mut self, predicate: RetryPredicate) -> Self {
76        self.retry_on = predicate;
77        self
78    }
79
80    /// Calculate delay for a given attempt.
81    pub fn delay_for_attempt(&self, attempt: u32) -> Duration {
82        let base_delay =
83            self.initial_delay.as_millis() as f64 * self.backoff_multiplier.powi(attempt as i32);
84
85        let delay_ms = base_delay.min(self.max_delay.as_millis() as f64);
86
87        let final_delay = if self.jitter {
88            // Add up to 25% jitter
89            let jitter = delay_ms * 0.25 * rand_jitter();
90            delay_ms + jitter
91        } else {
92            delay_ms
93        };
94
95        Duration::from_millis(final_delay as u64)
96    }
97}
98
99/// Predicate for determining if an error should trigger a retry.
100#[derive(Debug, Clone)]
101pub enum RetryPredicate {
102    /// Default: retry on connection and timeout errors.
103    Default,
104    /// Retry on any error.
105    Always,
106    /// Never retry.
107    Never,
108    /// Retry only on connection errors.
109    ConnectionOnly,
110    /// Retry only on timeout errors.
111    TimeoutOnly,
112    /// Custom list of error types to retry.
113    Custom(Vec<RetryableError>),
114}
115
116impl RetryPredicate {
117    /// Check if an error should be retried.
118    pub fn should_retry(&self, error: &QueryError) -> bool {
119        match self {
120            Self::Default => error.is_connection_error() || error.is_timeout(),
121            Self::Always => true,
122            Self::Never => false,
123            Self::ConnectionOnly => error.is_connection_error(),
124            Self::TimeoutOnly => error.is_timeout(),
125            Self::Custom(errors) => errors.iter().any(|e| e.matches(error)),
126        }
127    }
128}
129
/// Error categories that can be listed in [`RetryPredicate::Custom`].
///
/// Each category is checked against a concrete error with
/// `RetryableError::matches`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RetryableError {
    /// Errors for which `QueryError::is_connection_error` returns `true`.
    Connection,
    /// Errors for which `QueryError::is_timeout` returns `true`.
    Timeout,
    /// Errors whose code is `SqlSyntax`, `InvalidParameter` or `QueryTooComplex`.
    Database,
    /// Errors whose code is `TransactionFailed`, `Deadlock`,
    /// `SerializationFailure` or `TransactionClosed`.
    Transaction,
}
142
143impl RetryableError {
144    /// Check if this error type matches the given error.
145    pub fn matches(&self, error: &QueryError) -> bool {
146        match self {
147            Self::Connection => error.is_connection_error(),
148            Self::Timeout => error.is_timeout(),
149            Self::Database => matches!(
150                error.code,
151                crate::error::ErrorCode::SqlSyntax
152                    | crate::error::ErrorCode::InvalidParameter
153                    | crate::error::ErrorCode::QueryTooComplex
154            ),
155            Self::Transaction => matches!(
156                error.code,
157                crate::error::ErrorCode::TransactionFailed
158                    | crate::error::ErrorCode::Deadlock
159                    | crate::error::ErrorCode::SerializationFailure
160                    | crate::error::ErrorCode::TransactionClosed
161            ),
162        }
163    }
164}
165
/// Return a pseudo-random fraction in `[0.0, 0.999]`, in steps of `0.001`.
///
/// A freshly constructed std `RandomState` (which std seeds with random keys)
/// serves as the entropy source, so no external RNG crate is needed. Intended
/// only for spreading out retry timing, not as a general-purpose RNG.
fn rand_jitter() -> f64 {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hasher};

    // Nothing is written to the hasher; the random keys alone vary the output.
    let entropy: u64 = RandomState::new().build_hasher().finish();
    let bucket = entropy % 1000;
    bucket as f64 / 1000.0
}
175
176/// Middleware that automatically retries failed queries.
177///
178/// # Example
179///
180/// ```rust,ignore
181/// use prax_query::middleware::{RetryMiddleware, RetryConfig};
182/// use std::time::Duration;
183///
184/// let retry = RetryMiddleware::new(
185///     RetryConfig::new()
186///         .max_retries(5)
187///         .initial_delay(Duration::from_millis(50))
188///         .backoff_multiplier(2.0)
189/// );
190/// ```
191pub struct RetryMiddleware {
192    config: RetryConfig,
193}
194
195impl RetryMiddleware {
196    /// Create a new retry middleware with the given config.
197    pub fn new(config: RetryConfig) -> Self {
198        Self { config }
199    }
200
201    /// Create with default configuration.
202    pub fn default_config() -> Self {
203        Self::new(RetryConfig::default())
204    }
205
206    /// Get the retry configuration.
207    pub fn config(&self) -> &RetryConfig {
208        &self.config
209    }
210}
211
212impl Default for RetryMiddleware {
213    fn default() -> Self {
214        Self::default_config()
215    }
216}
217
218impl Middleware for RetryMiddleware {
219    fn handle<'a>(
220        &'a self,
221        ctx: QueryContext,
222        next: Next<'a>,
223    ) -> BoxFuture<'a, MiddlewareResult<QueryResponse>> {
224        Box::pin(async move {
225            // For now, just pass through - actual retry logic would need
226            // to be able to re-execute the query which requires different design
227            // This is a placeholder that shows the structure
228
229            // In a real implementation, we would:
230            // 1. Execute the query
231            // 2. If it fails with a retryable error, wait and retry
232            // 3. Track retry attempts
233            // 4. Eventually return success or final failure
234
235            next.run(ctx).await
236        })
237    }
238
239    fn name(&self) -> &'static str {
240        "RetryMiddleware"
241    }
242}
243
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_retry_config_default() {
        let config = RetryConfig::default();
        assert_eq!(config.max_retries, 3);
        assert!(config.jitter);
    }

    #[test]
    fn test_retry_config_builder() {
        let config = RetryConfig::new()
            .max_retries(5)
            .initial_delay(Duration::from_millis(50))
            .jitter(false);

        assert_eq!(config.max_retries, 5);
        assert_eq!(config.initial_delay, Duration::from_millis(50));
        assert!(!config.jitter);
    }

    #[test]
    fn test_delay_calculation() {
        let config = RetryConfig::new()
            .initial_delay(Duration::from_millis(100))
            .backoff_multiplier(2.0)
            .jitter(false);

        assert_eq!(config.delay_for_attempt(0), Duration::from_millis(100));
        assert_eq!(config.delay_for_attempt(1), Duration::from_millis(200));
        assert_eq!(config.delay_for_attempt(2), Duration::from_millis(400));
    }

    #[test]
    fn test_delay_max_cap() {
        let config = RetryConfig::new()
            .initial_delay(Duration::from_secs(1))
            .max_delay(Duration::from_secs(5))
            .backoff_multiplier(10.0)
            .jitter(false);

        // Should be capped at 5 seconds
        assert_eq!(config.delay_for_attempt(2), Duration::from_secs(5));
    }

    #[test]
    fn test_delay_jitter_bounds() {
        // Jitter adds at most 25% on top of the un-jittered delay.
        let config = RetryConfig::new()
            .initial_delay(Duration::from_millis(100))
            .max_delay(Duration::from_secs(10))
            .backoff_multiplier(2.0)
            .jitter(true);

        for _ in 0..50 {
            let delay = config.delay_for_attempt(0);
            assert!(
                delay >= Duration::from_millis(100) && delay <= Duration::from_millis(125),
                "jittered delay out of range: {:?}",
                delay
            );
        }
    }

    #[test]
    fn test_rand_jitter_range() {
        for _ in 0..100 {
            let j = rand_jitter();
            assert!((0.0..1.0).contains(&j), "jitter out of range: {}", j);
        }
    }

    #[test]
    fn test_retry_predicate_default() {
        let predicate = RetryPredicate::Default;

        assert!(predicate.should_retry(&QueryError::connection("test")));
        assert!(predicate.should_retry(&QueryError::timeout(1000)));
        assert!(!predicate.should_retry(&QueryError::not_found("User")));
    }

    #[test]
    fn test_retry_predicate_always_and_never() {
        assert!(RetryPredicate::Always.should_retry(&QueryError::not_found("User")));
        assert!(!RetryPredicate::Never.should_retry(&QueryError::connection("test")));
        assert!(!RetryPredicate::Never.should_retry(&QueryError::timeout(1000)));
    }

    #[test]
    fn test_retry_predicate_connection_and_timeout_only() {
        let connection_only = RetryPredicate::ConnectionOnly;
        assert!(connection_only.should_retry(&QueryError::connection("test")));
        assert!(!connection_only.should_retry(&QueryError::timeout(1000)));
        assert!(!connection_only.should_retry(&QueryError::not_found("User")));

        let timeout_only = RetryPredicate::TimeoutOnly;
        assert!(timeout_only.should_retry(&QueryError::timeout(1000)));
        assert!(!timeout_only.should_retry(&QueryError::not_found("User")));
    }

    #[test]
    fn test_retry_predicate_custom() {
        let predicate =
            RetryPredicate::Custom(vec![RetryableError::Connection, RetryableError::Database]);

        assert!(predicate.should_retry(&QueryError::connection("test")));
        assert!(predicate.should_retry(&QueryError::sql_syntax("error", "SELECT")));
        assert!(!predicate.should_retry(&QueryError::timeout(1000)));
    }

    #[test]
    fn test_retry_middleware_default() {
        let middleware = RetryMiddleware::default();
        assert_eq!(middleware.config().max_retries, 3);
        assert_eq!(middleware.name(), "RetryMiddleware");
    }
}