// prax_query/middleware/retry.rs
use std::time::Duration;

use super::context::QueryContext;
use super::types::{BoxFuture, Middleware, MiddlewareResult, Next, QueryResponse};
use crate::QueryError;

8#[derive(Debug, Clone)]
10pub struct RetryConfig {
11 pub max_retries: u32,
13 pub initial_delay: Duration,
15 pub max_delay: Duration,
17 pub backoff_multiplier: f64,
19 pub jitter: bool,
21 pub retry_on: RetryPredicate,
23}
24
25impl Default for RetryConfig {
26 fn default() -> Self {
27 Self {
28 max_retries: 3,
29 initial_delay: Duration::from_millis(100),
30 max_delay: Duration::from_secs(10),
31 backoff_multiplier: 2.0,
32 jitter: true,
33 retry_on: RetryPredicate::Default,
34 }
35 }
36}
37
38impl RetryConfig {
39 pub fn new() -> Self {
41 Self::default()
42 }
43
44 pub fn max_retries(mut self, n: u32) -> Self {
46 self.max_retries = n;
47 self
48 }
49
50 pub fn initial_delay(mut self, delay: Duration) -> Self {
52 self.initial_delay = delay;
53 self
54 }
55
56 pub fn max_delay(mut self, delay: Duration) -> Self {
58 self.max_delay = delay;
59 self
60 }
61
62 pub fn backoff_multiplier(mut self, multiplier: f64) -> Self {
64 self.backoff_multiplier = multiplier;
65 self
66 }
67
68 pub fn jitter(mut self, enabled: bool) -> Self {
70 self.jitter = enabled;
71 self
72 }
73
74 pub fn retry_on(mut self, predicate: RetryPredicate) -> Self {
76 self.retry_on = predicate;
77 self
78 }
79
80 pub fn delay_for_attempt(&self, attempt: u32) -> Duration {
82 let base_delay =
83 self.initial_delay.as_millis() as f64 * self.backoff_multiplier.powi(attempt as i32);
84
85 let delay_ms = base_delay.min(self.max_delay.as_millis() as f64);
86
87 let final_delay = if self.jitter {
88 let jitter = delay_ms * 0.25 * rand_jitter();
90 delay_ms + jitter
91 } else {
92 delay_ms
93 };
94
95 Duration::from_millis(final_delay as u64)
96 }
97}
98
99#[derive(Debug, Clone)]
101pub enum RetryPredicate {
102 Default,
104 Always,
106 Never,
108 ConnectionOnly,
110 TimeoutOnly,
112 Custom(Vec<RetryableError>),
114}
115
116impl RetryPredicate {
117 pub fn should_retry(&self, error: &QueryError) -> bool {
119 match self {
120 Self::Default => error.is_connection_error() || error.is_timeout(),
121 Self::Always => true,
122 Self::Never => false,
123 Self::ConnectionOnly => error.is_connection_error(),
124 Self::TimeoutOnly => error.is_timeout(),
125 Self::Custom(errors) => errors.iter().any(|e| e.matches(error)),
126 }
127 }
128}
129
/// A category of error that a [`RetryPredicate::Custom`] list can opt into retrying.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RetryableError {
    /// Errors for which `QueryError::is_connection_error` is true.
    Connection,
    /// Errors for which `QueryError::is_timeout` is true.
    Timeout,
    /// Errors with code `SqlSyntax`, `InvalidParameter` or `QueryTooComplex`.
    ///
    /// NOTE(review): these codes look non-transient, so retrying them may never
    /// succeed — confirm this grouping is intended.
    Database,
    /// Errors with code `TransactionFailed`, `Deadlock`, `SerializationFailure`
    /// or `TransactionClosed`.
    Transaction,
}

143impl RetryableError {
144 pub fn matches(&self, error: &QueryError) -> bool {
146 match self {
147 Self::Connection => error.is_connection_error(),
148 Self::Timeout => error.is_timeout(),
149 Self::Database => matches!(
150 error.code,
151 crate::error::ErrorCode::SqlSyntax
152 | crate::error::ErrorCode::InvalidParameter
153 | crate::error::ErrorCode::QueryTooComplex
154 ),
155 Self::Transaction => matches!(
156 error.code,
157 crate::error::ErrorCode::TransactionFailed
158 | crate::error::ErrorCode::Deadlock
159 | crate::error::ErrorCode::SerializationFailure
160 | crate::error::ErrorCode::TransactionClosed
161 ),
162 }
163 }
164}
165
/// Returns a pseudo-random value in `[0.0, 0.999]` with a step of `0.001`.
///
/// It avoids a dependency on a RNG crate by taking the finished hash of an empty
/// input under a freshly created `RandomState`, whose keys are randomized.
fn rand_jitter() -> f64 {
    use std::collections::hash_map::RandomState;
    use std::hash::BuildHasher;
    use std::hash::Hasher;

    let sample = RandomState::new().build_hasher().finish() % 1000;
    sample as f64 / 1000.0
}

176pub struct RetryMiddleware {
192 config: RetryConfig,
193}
194
195impl RetryMiddleware {
196 pub fn new(config: RetryConfig) -> Self {
198 Self { config }
199 }
200
201 pub fn default_config() -> Self {
203 Self::new(RetryConfig::default())
204 }
205
206 pub fn config(&self) -> &RetryConfig {
208 &self.config
209 }
210}
211
212impl Default for RetryMiddleware {
213 fn default() -> Self {
214 Self::default_config()
215 }
216}
217
218impl Middleware for RetryMiddleware {
219 fn handle<'a>(
220 &'a self,
221 ctx: QueryContext,
222 next: Next<'a>,
223 ) -> BoxFuture<'a, MiddlewareResult<QueryResponse>> {
224 Box::pin(async move {
225 next.run(ctx).await
237 })
238 }
239
240 fn name(&self) -> &'static str {
241 "RetryMiddleware"
242 }
243}
244
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_retry_config_default() {
        let config = RetryConfig::default();
        assert_eq!(config.max_retries, 3);
        assert!(config.jitter);
    }

    #[test]
    fn test_retry_config_builder() {
        let config = RetryConfig::new()
            .max_retries(5)
            .initial_delay(Duration::from_millis(50))
            .jitter(false);

        assert_eq!(config.max_retries, 5);
        assert_eq!(config.initial_delay, Duration::from_millis(50));
        assert!(!config.jitter);
    }

    #[test]
    fn test_delay_calculation() {
        let config = RetryConfig::new()
            .initial_delay(Duration::from_millis(100))
            .backoff_multiplier(2.0)
            .jitter(false);

        assert_eq!(config.delay_for_attempt(0), Duration::from_millis(100));
        assert_eq!(config.delay_for_attempt(1), Duration::from_millis(200));
        assert_eq!(config.delay_for_attempt(2), Duration::from_millis(400));
    }

    #[test]
    fn test_delay_max_cap() {
        let config = RetryConfig::new()
            .initial_delay(Duration::from_secs(1))
            .max_delay(Duration::from_secs(5))
            .backoff_multiplier(10.0)
            .jitter(false);

        assert_eq!(config.delay_for_attempt(2), Duration::from_secs(5));
    }

    // Jitter only ever adds to the delay, and by less than a quarter of it.
    #[test]
    fn test_delay_jitter_bounds() {
        let config = RetryConfig::new()
            .initial_delay(Duration::from_millis(100))
            .backoff_multiplier(2.0)
            .jitter(true);

        for _ in 0..50 {
            let delay = config.delay_for_attempt(1);
            assert!(delay >= Duration::from_millis(200));
            assert!(delay <= Duration::from_millis(250));
        }
    }

    // Huge attempt numbers must resolve to the configured cap, not overflow.
    #[test]
    fn test_delay_large_attempt_is_capped() {
        let config = RetryConfig::new()
            .initial_delay(Duration::from_millis(100))
            .max_delay(Duration::from_secs(10))
            .backoff_multiplier(2.0)
            .jitter(false);

        assert_eq!(config.delay_for_attempt(64), Duration::from_secs(10));
        assert_eq!(
            config.delay_for_attempt(i32::MAX as u32),
            Duration::from_secs(10)
        );
    }

    #[test]
    fn test_retry_predicate_default() {
        let predicate = RetryPredicate::Default;

        assert!(predicate.should_retry(&QueryError::connection("test")));
        assert!(predicate.should_retry(&QueryError::timeout(1000)));
        assert!(!predicate.should_retry(&QueryError::not_found("User")));
    }

    #[test]
    fn test_retry_predicate_always_and_never() {
        assert!(RetryPredicate::Always.should_retry(&QueryError::not_found("User")));
        assert!(!RetryPredicate::Never.should_retry(&QueryError::connection("test")));
        assert!(!RetryPredicate::Never.should_retry(&QueryError::timeout(1000)));
    }

    #[test]
    fn test_retry_predicate_single_kind() {
        assert!(RetryPredicate::ConnectionOnly.should_retry(&QueryError::connection("test")));
        assert!(!RetryPredicate::ConnectionOnly.should_retry(&QueryError::timeout(1000)));
        assert!(RetryPredicate::TimeoutOnly.should_retry(&QueryError::timeout(1000)));
        assert!(!RetryPredicate::TimeoutOnly.should_retry(&QueryError::not_found("User")));
    }

    #[test]
    fn test_retry_predicate_custom() {
        let predicate =
            RetryPredicate::Custom(vec![RetryableError::Connection, RetryableError::Database]);

        assert!(predicate.should_retry(&QueryError::connection("test")));
        assert!(predicate.should_retry(&QueryError::sql_syntax("error", "SELECT")));
        assert!(!predicate.should_retry(&QueryError::timeout(1000)));
    }
}