prax_query/middleware/
retry.rs

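//! Retry middleware for automatic query retry on transient failures.
//!
//! Note: [`RetryMiddleware`] currently forwards each query unchanged; only the retry
//! configuration (backoff delays and error predicates) is functional so far.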
2
3use super::context::QueryContext;
4use super::types::{BoxFuture, Middleware, MiddlewareResult, Next, QueryResponse};
5use crate::QueryError;
6use std::time::Duration;
7
8/// Configuration for retry behavior.
9#[derive(Debug, Clone)]
10pub struct RetryConfig {
11    /// Maximum number of retry attempts.
12    pub max_retries: u32,
13    /// Initial delay between retries.
14    pub initial_delay: Duration,
15    /// Maximum delay between retries.
16    pub max_delay: Duration,
17    /// Multiplier for exponential backoff.
18    pub backoff_multiplier: f64,
19    /// Whether to add jitter to delays.
20    pub jitter: bool,
21    /// Predicate to determine if error is retryable.
22    pub retry_on: RetryPredicate,
23}
24
25impl Default for RetryConfig {
26    fn default() -> Self {
27        Self {
28            max_retries: 3,
29            initial_delay: Duration::from_millis(100),
30            max_delay: Duration::from_secs(10),
31            backoff_multiplier: 2.0,
32            jitter: true,
33            retry_on: RetryPredicate::Default,
34        }
35    }
36}
37
38impl RetryConfig {
39    /// Create a new retry config.
40    pub fn new() -> Self {
41        Self::default()
42    }
43
44    /// Set maximum retries.
45    pub fn max_retries(mut self, n: u32) -> Self {
46        self.max_retries = n;
47        self
48    }
49
50    /// Set initial delay.
51    pub fn initial_delay(mut self, delay: Duration) -> Self {
52        self.initial_delay = delay;
53        self
54    }
55
56    /// Set maximum delay.
57    pub fn max_delay(mut self, delay: Duration) -> Self {
58        self.max_delay = delay;
59        self
60    }
61
62    /// Set backoff multiplier.
63    pub fn backoff_multiplier(mut self, multiplier: f64) -> Self {
64        self.backoff_multiplier = multiplier;
65        self
66    }
67
68    /// Enable or disable jitter.
69    pub fn jitter(mut self, enabled: bool) -> Self {
70        self.jitter = enabled;
71        self
72    }
73
74    /// Set retry predicate.
75    pub fn retry_on(mut self, predicate: RetryPredicate) -> Self {
76        self.retry_on = predicate;
77        self
78    }
79
80    /// Calculate delay for a given attempt.
81    pub fn delay_for_attempt(&self, attempt: u32) -> Duration {
82        let base_delay =
83            self.initial_delay.as_millis() as f64 * self.backoff_multiplier.powi(attempt as i32);
84
85        let delay_ms = base_delay.min(self.max_delay.as_millis() as f64);
86
87        let final_delay = if self.jitter {
88            // Add up to 25% jitter
89            let jitter = delay_ms * 0.25 * rand_jitter();
90            delay_ms + jitter
91        } else {
92            delay_ms
93        };
94
95        Duration::from_millis(final_delay as u64)
96    }
97}
98
99/// Predicate for determining if an error should trigger a retry.
100#[derive(Debug, Clone)]
101pub enum RetryPredicate {
102    /// Default: retry on connection and timeout errors.
103    Default,
104    /// Retry on any error.
105    Always,
106    /// Never retry.
107    Never,
108    /// Retry only on connection errors.
109    ConnectionOnly,
110    /// Retry only on timeout errors.
111    TimeoutOnly,
112    /// Custom list of error types to retry.
113    Custom(Vec<RetryableError>),
114}
115
116impl RetryPredicate {
117    /// Check if an error should be retried.
118    pub fn should_retry(&self, error: &QueryError) -> bool {
119        match self {
120            Self::Default => error.is_connection_error() || error.is_timeout(),
121            Self::Always => true,
122            Self::Never => false,
123            Self::ConnectionOnly => error.is_connection_error(),
124            Self::TimeoutOnly => error.is_timeout(),
125            Self::Custom(errors) => errors.iter().any(|e| e.matches(error)),
126        }
127    }
128}
129
/// Categories of query errors that a `RetryPredicate::Custom` policy can opt into.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RetryableError {
    /// Connection errors (`QueryError::is_connection_error`).
    Connection,
    /// Timeout errors (`QueryError::is_timeout`).
    Timeout,
    /// Database errors: the `SqlSyntax`, `InvalidParameter` and `QueryTooComplex` codes.
    ///
    /// NOTE(review): these codes look deterministic rather than transient, so
    /// retrying them is likely wasted work — confirm this category is intended.
    Database,
    /// Transaction errors: the `TransactionFailed`, `Deadlock`,
    /// `SerializationFailure` and `TransactionClosed` codes.
    Transaction,
}
142
143impl RetryableError {
144    /// Check if this error type matches the given error.
145    pub fn matches(&self, error: &QueryError) -> bool {
146        match self {
147            Self::Connection => error.is_connection_error(),
148            Self::Timeout => error.is_timeout(),
149            Self::Database => matches!(
150                error.code,
151                crate::error::ErrorCode::SqlSyntax
152                    | crate::error::ErrorCode::InvalidParameter
153                    | crate::error::ErrorCode::QueryTooComplex
154            ),
155            Self::Transaction => matches!(
156                error.code,
157                crate::error::ErrorCode::TransactionFailed
158                    | crate::error::ErrorCode::Deadlock
159                    | crate::error::ErrorCode::SerializationFailure
160                    | crate::error::ErrorCode::TransactionClosed
161            ),
162        }
163    }
164}
165
/// Returns a pseudo-random value in `[0.0, 1.0)` with a resolution of 1/1000,
/// without pulling in an external RNG crate.
fn rand_jitter() -> f64 {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hasher};

    // Each `RandomState::new()` carries fresh random keys, so finishing an
    // empty hasher built from it gives a cheap pseudo-random `u64`.
    let seed: u64 = RandomState::new().build_hasher().finish();
    let bucket = seed % 1000;
    bucket as f64 / 1000.0
}
175
176/// Middleware that automatically retries failed queries.
177///
178/// # Example
179///
180/// ```rust,ignore
181/// use prax_query::middleware::{RetryMiddleware, RetryConfig};
182/// use std::time::Duration;
183///
184/// let retry = RetryMiddleware::new(
185///     RetryConfig::new()
186///         .max_retries(5)
187///         .initial_delay(Duration::from_millis(50))
188///         .backoff_multiplier(2.0)
189/// );
190/// ```
191pub struct RetryMiddleware {
192    config: RetryConfig,
193}
194
195impl RetryMiddleware {
196    /// Create a new retry middleware with the given config.
197    pub fn new(config: RetryConfig) -> Self {
198        Self { config }
199    }
200
201    /// Create with default configuration.
202    pub fn default_config() -> Self {
203        Self::new(RetryConfig::default())
204    }
205
206    /// Get the retry configuration.
207    pub fn config(&self) -> &RetryConfig {
208        &self.config
209    }
210}
211
212impl Default for RetryMiddleware {
213    fn default() -> Self {
214        Self::default_config()
215    }
216}
217
218impl Middleware for RetryMiddleware {
219    fn handle<'a>(
220        &'a self,
221        ctx: QueryContext,
222        next: Next<'a>,
223    ) -> BoxFuture<'a, MiddlewareResult<QueryResponse>> {
224        Box::pin(async move {
225            // For now, just pass through - actual retry logic would need
226            // to be able to re-execute the query which requires different design
227            // This is a placeholder that shows the structure
228            
229
230            // In a real implementation, we would:
231            // 1. Execute the query
232            // 2. If it fails with a retryable error, wait and retry
233            // 3. Track retry attempts
234            // 4. Eventually return success or final failure
235
236            next.run(ctx).await
237        })
238    }
239
240    fn name(&self) -> &'static str {
241        "RetryMiddleware"
242    }
243}
244
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_retry_config_default() {
        let config = RetryConfig::default();
        assert_eq!(config.max_retries, 3);
        assert!(config.jitter);
    }

    #[test]
    fn test_retry_config_builder() {
        let config = RetryConfig::new()
            .max_retries(5)
            .initial_delay(Duration::from_millis(50))
            .jitter(false);

        assert_eq!(config.max_retries, 5);
        assert_eq!(config.initial_delay, Duration::from_millis(50));
        assert!(!config.jitter);
    }

    #[test]
    fn test_delay_calculation() {
        let config = RetryConfig::new()
            .initial_delay(Duration::from_millis(100))
            .backoff_multiplier(2.0)
            .jitter(false);

        assert_eq!(config.delay_for_attempt(0), Duration::from_millis(100));
        assert_eq!(config.delay_for_attempt(1), Duration::from_millis(200));
        assert_eq!(config.delay_for_attempt(2), Duration::from_millis(400));
    }

    #[test]
    fn test_delay_max_cap() {
        let config = RetryConfig::new()
            .initial_delay(Duration::from_secs(1))
            .max_delay(Duration::from_secs(5))
            .backoff_multiplier(10.0)
            .jitter(false);

        // Should be capped at 5 seconds
        assert_eq!(config.delay_for_attempt(2), Duration::from_secs(5));
    }

    #[test]
    fn test_delay_jitter_adds_at_most_a_quarter() {
        let config = RetryConfig::new()
            .initial_delay(Duration::from_millis(100))
            .max_delay(Duration::from_secs(10))
            .backoff_multiplier(2.0)
            .jitter(true);

        // Jitter is random, so sample several times; every draw must stay
        // within [base, base * 1.25].
        for _ in 0..50 {
            let delay = config.delay_for_attempt(0);
            assert!(delay >= Duration::from_millis(100), "{:?}", delay);
            assert!(delay <= Duration::from_millis(125), "{:?}", delay);
        }
    }

    #[test]
    fn test_retry_predicate_default() {
        let predicate = RetryPredicate::Default;

        assert!(predicate.should_retry(&QueryError::connection("test")));
        assert!(predicate.should_retry(&QueryError::timeout(1000)));
        assert!(!predicate.should_retry(&QueryError::not_found("User")));
    }

    #[test]
    fn test_retry_predicate_always_and_never() {
        assert!(RetryPredicate::Always.should_retry(&QueryError::not_found("User")));
        assert!(!RetryPredicate::Never.should_retry(&QueryError::connection("test")));
        assert!(!RetryPredicate::Never.should_retry(&QueryError::timeout(1000)));
    }

    #[test]
    fn test_retry_predicate_single_kind() {
        let connection_only = RetryPredicate::ConnectionOnly;
        assert!(connection_only.should_retry(&QueryError::connection("test")));
        assert!(!connection_only.should_retry(&QueryError::timeout(1000)));

        let timeout_only = RetryPredicate::TimeoutOnly;
        assert!(timeout_only.should_retry(&QueryError::timeout(1000)));
        assert!(!timeout_only.should_retry(&QueryError::not_found("User")));
    }

    #[test]
    fn test_retry_predicate_custom() {
        let predicate =
            RetryPredicate::Custom(vec![RetryableError::Connection, RetryableError::Database]);

        assert!(predicate.should_retry(&QueryError::connection("test")));
        assert!(predicate.should_retry(&QueryError::sql_syntax("error", "SELECT")));
        assert!(!predicate.should_retry(&QueryError::timeout(1000)));
    }
}