1use crate::errors::KodeBridgeError;
2use rand::{random_range, rngs::StdRng, SeedableRng};
3use std::time::{Duration, Instant};
4use tracing::{debug, warn};
5
6pub type RetryFn = Box<dyn Fn(&KodeBridgeError, usize) -> bool + Send + Sync>;
8
9pub struct RetryConfig {
11 pub max_attempts: usize,
13 pub base_delay: Duration,
15 pub max_delay: Duration,
17 pub backoff_strategy: BackoffStrategy,
19 pub jitter_strategy: JitterStrategy,
21 pub should_retry_fn: Option<RetryFn>,
23}
24
25impl Clone for RetryConfig {
26 fn clone(&self) -> Self {
27 Self {
28 max_attempts: self.max_attempts,
29 base_delay: self.base_delay,
30 max_delay: self.max_delay,
31 backoff_strategy: self.backoff_strategy,
32 jitter_strategy: self.jitter_strategy,
33 should_retry_fn: None, }
35 }
36}
37
/// How the un-jittered delay grows after each failed attempt.
///
/// With `n` the 1-based number of the attempt that just failed, the resulting delay is
/// then capped at `max_delay`.
#[derive(Clone, Copy, Debug)]
pub enum BackoffStrategy {
    /// Every retry waits `base_delay`.
    Fixed,
    /// Waits `base_delay * multiplier^(n - 1)`.
    Exponential {
        /// Growth factor applied per retry.
        multiplier: f64,
    },
    /// Waits `base_delay + increment * (n - 1)`.
    Linear {
        /// Amount added per retry.
        increment: Duration,
    },
}
47
/// How randomness is mixed into the capped backoff delay.
#[derive(Clone, Copy, Debug)]
pub enum JitterStrategy {
    /// Use the capped backoff delay unchanged.
    None,
    /// Add a random 0-50% (whole milliseconds) on top of the capped delay.
    Full,
    /// Add a random 0-25% (whole milliseconds) on top of the capped delay.
    Partial,
    /// Ignore the backoff strategy entirely.
    ///
    /// Picks uniformly between `base_delay` and three times the previous delay, with
    /// `max_delay` as the upper bound.
    Decorrelated,
}
59
60impl Default for RetryConfig {
61 fn default() -> Self {
62 Self {
63 max_attempts: 3,
64 base_delay: Duration::from_millis(100),
65 max_delay: Duration::from_secs(30),
66 backoff_strategy: BackoffStrategy::Exponential { multiplier: 2.0 },
67 jitter_strategy: JitterStrategy::Partial,
68 should_retry_fn: None,
69 }
70 }
71}
72
73impl RetryConfig {
74 pub fn new() -> Self {
76 Self::default()
77 }
78
79 pub fn max_attempts(mut self, max_attempts: usize) -> Self {
81 self.max_attempts = max_attempts;
82 self
83 }
84
85 pub fn base_delay(mut self, delay: Duration) -> Self {
87 self.base_delay = delay;
88 self
89 }
90
91 pub fn max_delay(mut self, delay: Duration) -> Self {
93 self.max_delay = delay;
94 self
95 }
96
97 pub fn exponential_backoff(mut self, multiplier: f64) -> Self {
99 self.backoff_strategy = BackoffStrategy::Exponential { multiplier };
100 self
101 }
102
103 pub fn fixed_backoff(mut self) -> Self {
105 self.backoff_strategy = BackoffStrategy::Fixed;
106 self
107 }
108
109 pub fn linear_backoff(mut self, increment: Duration) -> Self {
111 self.backoff_strategy = BackoffStrategy::Linear { increment };
112 self
113 }
114
115 pub fn jitter(mut self, strategy: JitterStrategy) -> Self {
117 self.jitter_strategy = strategy;
118 self
119 }
120
121 pub fn should_retry<F>(mut self, f: F) -> Self
123 where
124 F: Fn(&KodeBridgeError, usize) -> bool + Send + Sync + 'static,
125 {
126 self.should_retry_fn = Some(Box::new(f));
127 self
128 }
129
130 pub fn for_network_operations() -> Self {
132 Self::new()
133 .max_attempts(5)
134 .base_delay(Duration::from_millis(50))
135 .max_delay(Duration::from_secs(10))
136 .exponential_backoff(2.0)
137 .jitter(JitterStrategy::Full)
138 }
139
140 pub fn for_rate_limited_apis() -> Self {
141 Self::new()
142 .max_attempts(10)
143 .base_delay(Duration::from_secs(1))
144 .max_delay(Duration::from_secs(60))
145 .exponential_backoff(1.5)
146 .jitter(JitterStrategy::Decorrelated)
147 }
148
149 pub fn for_quick_operations() -> Self {
150 Self::new()
151 .max_attempts(2)
152 .base_delay(Duration::from_millis(10))
153 .max_delay(Duration::from_millis(100))
154 .fixed_backoff()
155 .jitter(JitterStrategy::None)
156 }
157
158 pub fn for_put_requests() -> Self {
160 Self::new()
161 .max_attempts(2) .base_delay(Duration::from_millis(25))
163 .max_delay(Duration::from_millis(200))
164 .exponential_backoff(1.5) .jitter(JitterStrategy::Partial)
166 }
167
168 pub fn for_large_put_requests() -> Self {
170 Self::new()
171 .max_attempts(3) .base_delay(Duration::from_millis(50))
173 .max_delay(Duration::from_millis(500))
174 .linear_backoff(Duration::from_millis(50))
175 .jitter(JitterStrategy::Partial)
176 }
177}
178
/// Bookkeeping for one `execute` call.
#[derive(Debug)]
pub struct RetryState {
    /// 1-based number of the attempt in progress (0 before the first attempt).
    attempt: usize,
    /// Time accumulated by the retry loop so far.
    total_elapsed: Duration,
    /// Delay chosen before the most recent retry; feeds decorrelated jitter.
    last_delay: Duration,
}
186
187impl Default for RetryState {
188 fn default() -> Self {
189 Self {
190 attempt: 0,
191 total_elapsed: Duration::ZERO,
192 last_delay: Duration::ZERO,
193 }
194 }
195}
196
197impl RetryState {
198 pub fn new() -> Self {
199 Self::default()
200 }
201
202 pub fn attempt(&self) -> usize {
203 self.attempt
204 }
205
206 pub fn total_elapsed(&self) -> Duration {
207 self.total_elapsed
208 }
209
210 pub fn last_delay(&self) -> Duration {
211 self.last_delay
212 }
213}
214
215pub struct RetryExecutor {
217 config: RetryConfig,
218}
219
220impl RetryExecutor {
221 pub fn new(config: RetryConfig) -> Self {
222 Self { config }
223 }
224
225 pub async fn execute<F, Fut, T>(&self, mut operation: F) -> Result<T, KodeBridgeError>
227 where
228 F: FnMut() -> Fut,
229 Fut: std::future::Future<Output = Result<T, KodeBridgeError>>,
230 {
231 let mut state = RetryState::new();
232 let mut rng = StdRng::from_seed([0u8; 32]); loop {
235 state.attempt += 1;
236 let attempt_start = Instant::now();
237
238 debug!("Retry attempt {} starting", state.attempt);
239
240 match operation().await {
241 Ok(result) => {
242 if state.attempt > 1 {
243 debug!(
244 "Operation succeeded on attempt {} after {}ms",
245 state.attempt,
246 state.total_elapsed.as_millis()
247 );
248 }
249 return Ok(result);
250 }
251 Err(error) => {
252 let attempt_duration = attempt_start.elapsed();
253 state.total_elapsed += attempt_duration;
254
255 let should_retry = if let Some(ref custom_fn) = self.config.should_retry_fn {
257 custom_fn(&error, state.attempt)
258 } else {
259 self.default_should_retry(&error, state.attempt)
260 };
261
262 if !should_retry || state.attempt >= self.config.max_attempts {
263 warn!(
264 "Operation failed after {} attempts in {}ms: {}",
265 state.attempt,
266 state.total_elapsed.as_millis(),
267 error
268 );
269 return Err(error);
270 }
271
272 let next_delay = self.calculate_delay(&mut state, &mut rng);
274
275 debug!(
276 "Retrying after {}ms (attempt {}/{}, error: {})",
277 next_delay.as_millis(),
278 state.attempt,
279 self.config.max_attempts,
280 error
281 );
282
283 tokio::time::sleep(next_delay).await;
284 }
285 }
286 }
287 }
288
289 pub async fn execute_with_context<F, Fut, T>(
291 &self,
292 operation_name: &str,
293 operation: F,
294 ) -> Result<T, KodeBridgeError>
295 where
296 F: FnMut() -> Fut,
297 Fut: std::future::Future<Output = Result<T, KodeBridgeError>>,
298 {
299 debug!("Starting retry execution for operation: {}", operation_name);
300
301 match self.execute(operation).await {
302 Ok(result) => {
303 debug!("Operation '{}' completed successfully", operation_name);
304 Ok(result)
305 }
306 Err(error) => {
307 warn!(
308 "Operation '{}' failed with error: {}",
309 operation_name, error
310 );
311 Err(KodeBridgeError::custom(format!(
312 "Operation '{}' failed after retries: {}",
313 operation_name, error
314 )))
315 }
316 }
317 }
318
319 fn default_should_retry(&self, error: &KodeBridgeError, attempt: usize) -> bool {
321 use KodeBridgeError::*;
322
323 match error {
324 Io(_) | Connection { .. } | Timeout { .. } | StreamClosed => true,
326
327 ServerError { status } => *status >= 500,
329 ClientError { .. } | InvalidRequest { .. } => false,
330
331 HttpParse(_) | Http(_) | Protocol { .. } => false,
333
334 Configuration { .. } => false,
336
337 Json(_) | JsonSerialize { .. } => false,
339
340 Utf8(_) | FromUtf8(_) => false,
342
343 PoolExhausted => attempt <= 5, Custom { .. } => false,
348
349 InvalidStatusCode(_) => false,
351 }
352 }
353
354 fn calculate_delay(&self, state: &mut RetryState, _rng: &mut impl rand::Rng) -> Duration {
356 let base_delay = match self.config.backoff_strategy {
357 BackoffStrategy::Fixed => self.config.base_delay,
358 BackoffStrategy::Exponential { multiplier } => {
359 if state.attempt == 1 {
360 self.config.base_delay
361 } else {
362 let exponential = (self.config.base_delay.as_millis() as f64
363 * multiplier.powi((state.attempt - 1) as i32))
364 as u64;
365 Duration::from_millis(exponential)
366 }
367 }
368 BackoffStrategy::Linear { increment } => {
369 self.config.base_delay + increment * (state.attempt as u32 - 1)
370 }
371 };
372
373 let capped_delay = std::cmp::min(base_delay, self.config.max_delay);
375
376 let final_delay = match self.config.jitter_strategy {
378 JitterStrategy::None => capped_delay,
379 JitterStrategy::Full => {
380 let jitter = random_range(0..=capped_delay.as_millis() / 2) as u64;
381 capped_delay + Duration::from_millis(jitter)
382 }
383 JitterStrategy::Partial => {
384 let jitter = random_range(0..=capped_delay.as_millis() / 4) as u64;
385 capped_delay + Duration::from_millis(jitter)
386 }
387 JitterStrategy::Decorrelated => {
388 let min_delay = self.config.base_delay.as_millis() as u64;
390 let max_delay = std::cmp::min(
391 (state.last_delay.as_millis() as u64 * 3).max(min_delay),
392 self.config.max_delay.as_millis() as u64,
393 );
394 Duration::from_millis(random_range(min_delay..=max_delay))
395 }
396 };
397
398 state.last_delay = final_delay;
399 final_delay
400 }
401}
402
403pub async fn retry<F, Fut, T>(config: RetryConfig, operation: F) -> Result<T, KodeBridgeError>
405where
406 F: FnMut() -> Fut,
407 Fut: std::future::Future<Output = Result<T, KodeBridgeError>>,
408{
409 RetryExecutor::new(config).execute(operation).await
410}
411
412pub async fn retry_default<F, Fut, T>(operation: F) -> Result<T, KodeBridgeError>
414where
415 F: FnMut() -> Fut,
416 Fut: std::future::Future<Output = Result<T, KodeBridgeError>>,
417{
418 retry(RetryConfig::default(), operation).await
419}
420
421#[derive(Debug)]
423pub struct CircuitBreaker {
424 failure_threshold: usize,
425 recovery_timeout: Duration,
426 consecutive_failures: usize,
427 last_failure_time: Option<Instant>,
428 state: CircuitState,
429}
430
/// Phase of a circuit breaker.
#[derive(Clone, Debug, PartialEq)]
enum CircuitState {
    /// Calls pass through normally.
    Closed,
    /// Calls are rejected without running.
    Open,
    /// The recovery timeout has elapsed; the next call runs as a trial.
    HalfOpen,
}
437
438impl CircuitBreaker {
439 pub fn new(failure_threshold: usize, recovery_timeout: Duration) -> Self {
440 Self {
441 failure_threshold,
442 recovery_timeout,
443 consecutive_failures: 0,
444 last_failure_time: None,
445 state: CircuitState::Closed,
446 }
447 }
448
449 pub async fn execute<F, Fut, T>(&mut self, operation: F) -> Result<T, KodeBridgeError>
450 where
451 F: FnOnce() -> Fut,
452 Fut: std::future::Future<Output = Result<T, KodeBridgeError>>,
453 {
454 if self.state == CircuitState::Open {
455 if let Some(last_failure) = self.last_failure_time {
456 if last_failure.elapsed() >= self.recovery_timeout {
457 debug!("Circuit breaker entering half-open state");
458 self.state = CircuitState::HalfOpen;
459 } else {
460 return Err(KodeBridgeError::custom("Circuit breaker is open"));
461 }
462 } else {
463 return Err(KodeBridgeError::custom("Circuit breaker is open"));
464 }
465 }
466
467 match operation().await {
468 Ok(result) => {
469 if self.state == CircuitState::HalfOpen {
471 debug!("Circuit breaker closing after successful operation");
472 }
473 self.consecutive_failures = 0;
474 self.last_failure_time = None;
475 self.state = CircuitState::Closed;
476 Ok(result)
477 }
478 Err(error) => {
479 self.consecutive_failures += 1;
481 self.last_failure_time = Some(Instant::now());
482
483 if self.consecutive_failures >= self.failure_threshold {
484 debug!(
485 "Circuit breaker opening after {} consecutive failures",
486 self.consecutive_failures
487 );
488 self.state = CircuitState::Open;
489 }
490
491 Err(error)
492 }
493 }
494 }
495
496 pub fn is_open(&self) -> bool {
497 matches!(self.state, CircuitState::Open)
498 }
499
500 pub fn reset(&mut self) {
501 self.consecutive_failures = 0;
502 self.last_failure_time = None;
503 self.state = CircuitState::Closed;
504 }
505}
506
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    #[tokio::test]
    async fn test_retry_success_on_first_attempt() {
        let config = RetryConfig::new().max_attempts(3);
        let executor = RetryExecutor::new(config);

        let result = executor
            .execute(|| async { Ok::<i32, KodeBridgeError>(42) })
            .await;

        assert_eq!(result.unwrap(), 42);
    }

    #[tokio::test]
    async fn test_retry_success_after_failures() {
        let config = RetryConfig::new()
            .max_attempts(3)
            .base_delay(Duration::from_millis(1));
        let executor = RetryExecutor::new(config);
        let attempt_count = Arc::new(AtomicUsize::new(0));

        let result = executor
            .execute(|| {
                let count = attempt_count.clone();
                async move {
                    let current = count.fetch_add(1, Ordering::SeqCst);
                    if current < 2 {
                        Err(KodeBridgeError::connection("Temporary failure"))
                    } else {
                        Ok(42)
                    }
                }
            })
            .await;

        assert_eq!(result.unwrap(), 42);
        assert_eq!(attempt_count.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_retry_max_attempts_exceeded() {
        let config = RetryConfig::new()
            .max_attempts(2)
            .base_delay(Duration::from_millis(1));
        let executor = RetryExecutor::new(config);
        let attempt_count = Arc::new(AtomicUsize::new(0));

        let result = executor
            .execute(|| {
                let count = attempt_count.clone();
                async move {
                    count.fetch_add(1, Ordering::SeqCst);
                    Err::<i32, _>(KodeBridgeError::connection("Always fails"))
                }
            })
            .await;

        assert!(result.is_err());
        assert_eq!(attempt_count.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn test_retry_non_retriable_error() {
        let config = RetryConfig::new()
            .max_attempts(3)
            .base_delay(Duration::from_millis(1));
        let executor = RetryExecutor::new(config);
        let attempt_count = Arc::new(AtomicUsize::new(0));

        let result = executor
            .execute(|| {
                let count = attempt_count.clone();
                async move {
                    count.fetch_add(1, Ordering::SeqCst);
                    Err::<i32, _>(KodeBridgeError::ClientError { status: 400 })
                }
            })
            .await;

        assert!(result.is_err());
        // Client errors are permanent by default, so there must be no retry.
        assert_eq!(attempt_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_custom_should_retry_overrides_default() {
        // ClientError is not retried by default; this predicate allows exactly one retry.
        let config = RetryConfig::new()
            .max_attempts(5)
            .base_delay(Duration::from_millis(1))
            .should_retry(|_, attempt| attempt < 2);
        let executor = RetryExecutor::new(config);
        let attempt_count = Arc::new(AtomicUsize::new(0));

        let result = executor
            .execute(|| {
                let count = attempt_count.clone();
                async move {
                    count.fetch_add(1, Ordering::SeqCst);
                    Err::<i32, _>(KodeBridgeError::ClientError { status: 400 })
                }
            })
            .await;

        assert!(result.is_err());
        assert_eq!(attempt_count.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn test_execute_with_context_wraps_error() {
        let executor = RetryExecutor::new(RetryConfig::new().max_attempts(1));

        let result = executor
            .execute_with_context("fetch", || async {
                Err::<i32, _>(KodeBridgeError::ClientError { status: 404 })
            })
            .await;

        let message = result.unwrap_err().to_string();
        assert!(message.contains("Operation 'fetch' failed after retries"));
    }

    #[tokio::test]
    async fn test_circuit_breaker() {
        let mut breaker = CircuitBreaker::new(2, Duration::from_millis(100));

        let result = breaker
            .execute(|| async { Err::<i32, _>(KodeBridgeError::connection("Failure 1")) })
            .await;
        assert!(result.is_err());
        assert!(!breaker.is_open());

        let result = breaker
            .execute(|| async { Err::<i32, _>(KodeBridgeError::connection("Failure 2")) })
            .await;
        assert!(result.is_err());
        assert!(breaker.is_open());

        // Still inside the recovery timeout: the call is rejected without running.
        let result = breaker
            .execute(|| async { Ok::<i32, KodeBridgeError>(42) })
            .await;
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("Circuit breaker is open"));
    }

    #[tokio::test]
    async fn test_circuit_breaker_recovers_after_timeout() {
        // A zero recovery timeout lets the very next call through as a half-open trial.
        let mut breaker = CircuitBreaker::new(1, Duration::ZERO);

        let result = breaker
            .execute(|| async { Err::<i32, _>(KodeBridgeError::connection("Failure")) })
            .await;
        assert!(result.is_err());
        assert!(breaker.is_open());

        let result = breaker
            .execute(|| async { Ok::<i32, KodeBridgeError>(7) })
            .await;
        assert_eq!(result.unwrap(), 7);
        assert!(!breaker.is_open());
    }

    #[tokio::test]
    async fn test_circuit_breaker_reset() {
        let mut breaker = CircuitBreaker::new(1, Duration::from_secs(60));

        let result = breaker
            .execute(|| async { Err::<i32, _>(KodeBridgeError::connection("Failure")) })
            .await;
        assert!(result.is_err());
        assert!(breaker.is_open());

        breaker.reset();
        assert!(!breaker.is_open());

        let result = breaker
            .execute(|| async { Ok::<i32, KodeBridgeError>(1) })
            .await;
        assert_eq!(result.unwrap(), 1);
    }

    #[test]
    fn test_backoff_strategies() {
        let mut state = RetryState::new();
        let mut rng = StdRng::from_seed([0u8; 32]);
        let config = RetryConfig::new()
            .exponential_backoff(2.0)
            .base_delay(Duration::from_millis(100))
            .jitter(JitterStrategy::None);
        let executor = RetryExecutor::new(config);

        state.attempt = 1;
        let delay1 = executor.calculate_delay(&mut state, &mut rng);
        assert_eq!(delay1, Duration::from_millis(100));

        state.attempt = 2;
        let delay2 = executor.calculate_delay(&mut state, &mut rng);
        assert_eq!(delay2, Duration::from_millis(200));

        state.attempt = 3;
        let delay3 = executor.calculate_delay(&mut state, &mut rng);
        assert_eq!(delay3, Duration::from_millis(400));
    }

    #[test]
    fn test_fixed_and_linear_backoff() {
        let mut rng = StdRng::from_seed([0u8; 32]);

        let fixed = RetryExecutor::new(
            RetryConfig::new()
                .fixed_backoff()
                .base_delay(Duration::from_millis(30))
                .jitter(JitterStrategy::None),
        );
        let mut state = RetryState::new();
        for attempt in 1..=3 {
            state.attempt = attempt;
            assert_eq!(
                fixed.calculate_delay(&mut state, &mut rng),
                Duration::from_millis(30)
            );
        }

        let linear = RetryExecutor::new(
            RetryConfig::new()
                .linear_backoff(Duration::from_millis(50))
                .base_delay(Duration::from_millis(50))
                .jitter(JitterStrategy::None),
        );
        let mut state = RetryState::new();
        for (attempt, expected_ms) in [(1, 50), (2, 100), (3, 150)] {
            state.attempt = attempt;
            assert_eq!(
                linear.calculate_delay(&mut state, &mut rng),
                Duration::from_millis(expected_ms)
            );
        }
    }

    #[test]
    fn test_backoff_capped_by_max_delay() {
        let mut rng = StdRng::from_seed([0u8; 32]);
        let executor = RetryExecutor::new(
            RetryConfig::new()
                .exponential_backoff(2.0)
                .base_delay(Duration::from_millis(100))
                .max_delay(Duration::from_millis(250))
                .jitter(JitterStrategy::None),
        );
        let mut state = RetryState::new();

        state.attempt = 2;
        assert_eq!(
            executor.calculate_delay(&mut state, &mut rng),
            Duration::from_millis(200)
        );

        // 400 ms uncapped, clamped to max_delay.
        state.attempt = 3;
        assert_eq!(
            executor.calculate_delay(&mut state, &mut rng),
            Duration::from_millis(250)
        );
    }

    #[test]
    fn test_jitter_stays_within_bounds() {
        let mut rng = StdRng::from_seed([0u8; 32]);
        let base = Duration::from_millis(100);

        let partial = RetryExecutor::new(
            RetryConfig::new()
                .fixed_backoff()
                .base_delay(base)
                .jitter(JitterStrategy::Partial),
        );
        let full = RetryExecutor::new(
            RetryConfig::new()
                .fixed_backoff()
                .base_delay(base)
                .jitter(JitterStrategy::Full),
        );

        for attempt in 1..=10 {
            let mut state = RetryState::new();
            state.attempt = attempt;
            let d = partial.calculate_delay(&mut state, &mut rng);
            assert!(
                d >= base && d <= Duration::from_millis(125),
                "partial jitter out of range: {:?}",
                d
            );

            let d = full.calculate_delay(&mut state, &mut rng);
            assert!(
                d >= base && d <= Duration::from_millis(150),
                "full jitter out of range: {:?}",
                d
            );
        }

        let decorrelated = RetryExecutor::new(
            RetryConfig::new()
                .base_delay(Duration::from_millis(10))
                .max_delay(Duration::from_secs(1))
                .jitter(JitterStrategy::Decorrelated),
        );
        let mut state = RetryState::new();

        // No previous delay yet, so the range collapses to exactly base_delay.
        state.attempt = 1;
        let first = decorrelated.calculate_delay(&mut state, &mut rng);
        assert_eq!(first, Duration::from_millis(10));

        state.attempt = 2;
        let second = decorrelated.calculate_delay(&mut state, &mut rng);
        assert!(
            second >= Duration::from_millis(10) && second <= Duration::from_millis(30),
            "decorrelated jitter out of range: {:?}",
            second
        );
    }

    #[test]
    fn test_retry_config_builder() {
        let config = RetryConfig::for_network_operations();
        assert_eq!(config.max_attempts, 5);
        assert_eq!(config.base_delay, Duration::from_millis(50));

        let config = RetryConfig::for_rate_limited_apis();
        assert_eq!(config.max_attempts, 10);
        assert_eq!(config.base_delay, Duration::from_secs(1));

        let config = RetryConfig::for_quick_operations();
        assert_eq!(config.max_attempts, 2);
        assert!(matches!(config.backoff_strategy, BackoffStrategy::Fixed));
        assert!(matches!(config.jitter_strategy, JitterStrategy::None));
    }
}