1use crate::errors::KodeBridgeError;
2use rand::{random_range, rngs::StdRng, SeedableRng as _};
3use std::time::{Duration, Instant};
4use tracing::{debug, warn};
5
6pub type RetryFn = Box<dyn Fn(&KodeBridgeError, usize) -> bool + Send + Sync>;
8
9pub struct RetryConfig {
11 pub max_attempts: usize,
13 pub base_delay: Duration,
15 pub max_delay: Duration,
17 pub backoff_strategy: BackoffStrategy,
19 pub jitter_strategy: JitterStrategy,
21 pub should_retry_fn: Option<RetryFn>,
23}
24
25impl Clone for RetryConfig {
26 fn clone(&self) -> Self {
27 Self {
28 max_attempts: self.max_attempts,
29 base_delay: self.base_delay,
30 max_delay: self.max_delay,
31 backoff_strategy: self.backoff_strategy,
32 jitter_strategy: self.jitter_strategy,
33 should_retry_fn: None, }
35 }
36}
37
/// How the un-jittered delay evolves across retries.
#[derive(Debug, Clone, Copy)]
pub enum BackoffStrategy {
    /// Every retry waits `base_delay`.
    Fixed,
    /// Delay is `base_delay * multiplier^(attempt - 1)`.
    Exponential {
        /// Growth factor applied once per additional attempt.
        multiplier: f64,
    },
    /// Delay is `base_delay + increment * (attempt - 1)`.
    Linear {
        /// Amount added per additional attempt.
        increment: Duration,
    },
}
47
/// How randomness is applied to a computed retry delay.
#[derive(Debug, Clone, Copy)]
pub enum JitterStrategy {
    /// Use the capped backoff delay unchanged.
    None,
    /// Add a random 0–50% of the capped delay on top of it.
    Full,
    /// Add a random 0–25% of the capped delay on top of it.
    Partial,
    /// Pick uniformly between `base_delay` and three times the previous delay,
    /// bounded by `max_delay`.
    Decorrelated,
}
59
60impl Default for RetryConfig {
61 fn default() -> Self {
62 Self {
63 max_attempts: 3,
64 base_delay: Duration::from_millis(100),
65 max_delay: Duration::from_secs(30),
66 backoff_strategy: BackoffStrategy::Exponential { multiplier: 2.0 },
67 jitter_strategy: JitterStrategy::Partial,
68 should_retry_fn: None,
69 }
70 }
71}
72
73impl RetryConfig {
74 pub fn new() -> Self {
76 Self::default()
77 }
78
79 pub const fn max_attempts(mut self, max_attempts: usize) -> Self {
81 self.max_attempts = max_attempts;
82 self
83 }
84
85 pub const fn base_delay(mut self, delay: Duration) -> Self {
87 self.base_delay = delay;
88 self
89 }
90
91 pub const fn max_delay(mut self, delay: Duration) -> Self {
93 self.max_delay = delay;
94 self
95 }
96
97 pub const fn exponential_backoff(mut self, multiplier: f64) -> Self {
99 self.backoff_strategy = BackoffStrategy::Exponential { multiplier };
100 self
101 }
102
103 pub const fn fixed_backoff(mut self) -> Self {
105 self.backoff_strategy = BackoffStrategy::Fixed;
106 self
107 }
108
109 pub const fn linear_backoff(mut self, increment: Duration) -> Self {
111 self.backoff_strategy = BackoffStrategy::Linear { increment };
112 self
113 }
114
115 pub const fn jitter(mut self, strategy: JitterStrategy) -> Self {
117 self.jitter_strategy = strategy;
118 self
119 }
120
121 pub fn should_retry<F>(mut self, f: F) -> Self
123 where
124 F: Fn(&KodeBridgeError, usize) -> bool + Send + Sync + 'static,
125 {
126 self.should_retry_fn = Some(Box::new(f));
127 self
128 }
129
130 pub fn for_network_operations() -> Self {
132 Self::new()
133 .max_attempts(5)
134 .base_delay(Duration::from_millis(50))
135 .max_delay(Duration::from_secs(10))
136 .exponential_backoff(2.0)
137 .jitter(JitterStrategy::Full)
138 }
139
140 pub fn for_rate_limited_apis() -> Self {
141 Self::new()
142 .max_attempts(10)
143 .base_delay(Duration::from_secs(1))
144 .max_delay(Duration::from_secs(60))
145 .exponential_backoff(1.5)
146 .jitter(JitterStrategy::Decorrelated)
147 }
148
149 pub fn for_quick_operations() -> Self {
150 Self::new()
151 .max_attempts(2)
152 .base_delay(Duration::from_millis(10))
153 .max_delay(Duration::from_millis(100))
154 .fixed_backoff()
155 .jitter(JitterStrategy::None)
156 }
157
158 pub fn for_put_requests() -> Self {
160 Self::new()
161 .max_attempts(2) .base_delay(Duration::from_millis(25))
163 .max_delay(Duration::from_millis(200))
164 .exponential_backoff(1.5) .jitter(JitterStrategy::Partial)
166 }
167
168 pub fn for_large_put_requests() -> Self {
170 Self::new()
171 .max_attempts(3) .base_delay(Duration::from_millis(50))
173 .max_delay(Duration::from_millis(500))
174 .linear_backoff(Duration::from_millis(50))
175 .jitter(JitterStrategy::Partial)
176 }
177}
178
/// Per-call bookkeeping for a retry loop.
///
/// [`RetryExecutor::execute`] creates one of these for each call. It increments
/// `attempt` before every attempt and adds each failed attempt's duration to
/// `total_elapsed`; backoff sleeps are not counted. It also records the most
/// recently computed delay in `last_delay`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct RetryState {
    /// Number of attempts started so far (1 during the first attempt).
    attempt: usize,
    /// Accumulated time spent inside failed attempts.
    total_elapsed: Duration,
    /// Delay chosen before the most recent retry; seeds decorrelated jitter.
    last_delay: Duration,
}

impl RetryState {
    /// Creates a fresh state: no attempts, zero elapsed time, zero last delay.
    ///
    /// Equivalent to [`RetryState::default`], but usable in `const` contexts.
    #[must_use]
    pub const fn new() -> Self {
        Self {
            attempt: 0,
            total_elapsed: Duration::ZERO,
            last_delay: Duration::ZERO,
        }
    }

    /// Number of attempts started so far.
    pub const fn attempt(&self) -> usize {
        self.attempt
    }

    /// Time accumulated across failed attempts (excludes backoff sleeps).
    pub const fn total_elapsed(&self) -> Duration {
        self.total_elapsed
    }

    /// Most recent delay computed before a retry, or zero if none yet.
    pub const fn last_delay(&self) -> Duration {
        self.last_delay
    }
}
214
215pub struct RetryExecutor {
217 config: RetryConfig,
218}
219
220impl RetryExecutor {
221 pub const fn new(config: RetryConfig) -> Self {
222 Self { config }
223 }
224
225 pub async fn execute<F, Fut, T>(&self, mut operation: F) -> Result<T, KodeBridgeError>
227 where
228 F: FnMut() -> Fut + Send,
229 Fut: std::future::Future<Output = Result<T, KodeBridgeError>> + Send,
230 T: Send,
231 {
232 let mut state = RetryState::new();
233 let mut rng = StdRng::from_seed([0u8; 32]); loop {
236 state.attempt += 1;
237 let attempt_start = Instant::now();
238
239 debug!("Retry attempt {} starting", state.attempt);
240
241 match operation().await {
242 Ok(result) => {
243 if state.attempt > 1 {
244 debug!(
245 "Operation succeeded on attempt {} after {}ms",
246 state.attempt,
247 state.total_elapsed.as_millis()
248 );
249 }
250 return Ok(result);
251 }
252 Err(error) => {
253 let attempt_duration = attempt_start.elapsed();
254 state.total_elapsed += attempt_duration;
255
256 let should_retry = if let Some(ref custom_fn) = self.config.should_retry_fn {
258 custom_fn(&error, state.attempt)
259 } else {
260 self.default_should_retry(&error, state.attempt)
261 };
262
263 if !should_retry || state.attempt >= self.config.max_attempts {
264 warn!(
265 "Operation failed after {} attempts in {}ms: {}",
266 state.attempt,
267 state.total_elapsed.as_millis(),
268 error
269 );
270 return Err(error);
271 }
272
273 let next_delay = self.calculate_delay(&mut state, &mut rng);
275
276 debug!(
277 "Retrying after {}ms (attempt {}/{}, error: {})",
278 next_delay.as_millis(),
279 state.attempt,
280 self.config.max_attempts,
281 error
282 );
283
284 tokio::time::sleep(next_delay).await;
285 }
286 }
287 }
288 }
289
290 pub async fn execute_with_context<F, Fut, T>(
292 &self,
293 operation_name: &str,
294 operation: F,
295 ) -> Result<T, KodeBridgeError>
296 where
297 F: FnMut() -> Fut + Send,
298 Fut: std::future::Future<Output = Result<T, KodeBridgeError>> + Send,
299 T: Send,
300 {
301 debug!("Starting retry execution for operation: {}", operation_name);
302
303 match self.execute(operation).await {
304 Ok(result) => {
305 debug!("Operation '{}' completed successfully", operation_name);
306 Ok(result)
307 }
308 Err(error) => {
309 warn!("Operation '{}' failed with error: {}", operation_name, error);
310 Err(KodeBridgeError::custom(format!(
311 "Operation '{}' failed after retries: {}",
312 operation_name, error
313 )))
314 }
315 }
316 }
317
318 const fn default_should_retry(&self, error: &KodeBridgeError, attempt: usize) -> bool {
320 use KodeBridgeError::*;
321
322 match error {
323 Io(_) | Connection { .. } | Timeout { .. } | StreamClosed => true,
325
326 ServerError { status } => *status >= 500,
328 ClientError { .. } | InvalidRequest { .. } => false,
329
330 HttpParse(_) | Http(_) | Protocol { .. } => false,
332
333 Configuration { .. } => false,
335
336 Json(_) | JsonSerialize { .. } => false,
338
339 Utf8(_) | FromUtf8(_) => false,
341
342 PoolExhausted => attempt <= 5, Custom { .. } => false,
347
348 InvalidStatusCode(_) => false,
350 }
351 }
352
353 fn calculate_delay(&self, state: &mut RetryState, _rng: &mut impl rand::Rng) -> Duration {
355 let base_delay = match self.config.backoff_strategy {
356 BackoffStrategy::Fixed => self.config.base_delay,
357 BackoffStrategy::Exponential { multiplier } => {
358 if state.attempt == 1 {
359 self.config.base_delay
360 } else {
361 let exponential = (self.config.base_delay.as_millis() as f64
362 * multiplier.powi((state.attempt - 1) as i32)) as u64;
363 Duration::from_millis(exponential)
364 }
365 }
366 BackoffStrategy::Linear { increment } => self.config.base_delay + increment * (state.attempt as u32 - 1),
367 };
368
369 let capped_delay = std::cmp::min(base_delay, self.config.max_delay);
371
372 let final_delay = match self.config.jitter_strategy {
374 JitterStrategy::None => capped_delay,
375 JitterStrategy::Full => {
376 let jitter = random_range(0..=capped_delay.as_millis() / 2) as u64;
377 capped_delay + Duration::from_millis(jitter)
378 }
379 JitterStrategy::Partial => {
380 let jitter = random_range(0..=capped_delay.as_millis() / 4) as u64;
381 capped_delay + Duration::from_millis(jitter)
382 }
383 JitterStrategy::Decorrelated => {
384 let min_delay = self.config.base_delay.as_millis() as u64;
386 let max_delay = std::cmp::min(
387 (state.last_delay.as_millis() as u64 * 3).max(min_delay),
388 self.config.max_delay.as_millis() as u64,
389 );
390 Duration::from_millis(random_range(min_delay..=max_delay))
391 }
392 };
393
394 state.last_delay = final_delay;
395 final_delay
396 }
397}
398
399pub async fn retry<F, Fut, T>(config: RetryConfig, operation: F) -> Result<T, KodeBridgeError>
401where
402 F: FnMut() -> Fut + Send,
403 Fut: std::future::Future<Output = Result<T, KodeBridgeError>> + Send,
404 T: Send,
405{
406 RetryExecutor::new(config).execute(operation).await
407}
408
409pub async fn retry_default<F, Fut, T>(operation: F) -> Result<T, KodeBridgeError>
411where
412 F: FnMut() -> Fut + Send,
413 Fut: std::future::Future<Output = Result<T, KodeBridgeError>> + Send,
414 T: Send,
415{
416 retry(RetryConfig::default(), operation).await
417}
418
419#[derive(Debug)]
421pub struct CircuitBreaker {
422 failure_threshold: usize,
423 recovery_timeout: Duration,
424 consecutive_failures: usize,
425 last_failure_time: Option<Instant>,
426 state: CircuitState,
427}
428
/// Lifecycle states of a circuit breaker.
#[derive(Debug, Clone, PartialEq)]
enum CircuitState {
    /// Calls pass through; failures are counted.
    Closed,
    /// Calls are rejected until the recovery timeout elapses.
    Open,
    /// The recovery timeout has elapsed; the next call's outcome decides whether
    /// the circuit closes or re-opens.
    HalfOpen,
}
435
436impl CircuitBreaker {
437 pub const fn new(failure_threshold: usize, recovery_timeout: Duration) -> Self {
438 Self {
439 failure_threshold,
440 recovery_timeout,
441 consecutive_failures: 0,
442 last_failure_time: None,
443 state: CircuitState::Closed,
444 }
445 }
446
447 pub async fn execute<F, Fut, T>(&mut self, operation: F) -> Result<T, KodeBridgeError>
448 where
449 F: FnOnce() -> Fut + Send,
450 Fut: std::future::Future<Output = Result<T, KodeBridgeError>> + Send,
451 T: Send,
452 {
453 if self.state == CircuitState::Open {
454 if let Some(last_failure) = self.last_failure_time {
455 if last_failure.elapsed() >= self.recovery_timeout {
456 debug!("Circuit breaker entering half-open state");
457 self.state = CircuitState::HalfOpen;
458 } else {
459 return Err(KodeBridgeError::custom("Circuit breaker is open"));
460 }
461 } else {
462 return Err(KodeBridgeError::custom("Circuit breaker is open"));
463 }
464 }
465
466 match operation().await {
467 Ok(result) => {
468 if self.state == CircuitState::HalfOpen {
470 debug!("Circuit breaker closing after successful operation");
471 }
472 self.consecutive_failures = 0;
473 self.last_failure_time = None;
474 self.state = CircuitState::Closed;
475 Ok(result)
476 }
477 Err(error) => {
478 self.consecutive_failures += 1;
480 self.last_failure_time = Some(Instant::now());
481
482 if self.consecutive_failures >= self.failure_threshold {
483 debug!(
484 "Circuit breaker opening after {} consecutive failures",
485 self.consecutive_failures
486 );
487 self.state = CircuitState::Open;
488 }
489
490 Err(error)
491 }
492 }
493 }
494
495 pub const fn is_open(&self) -> bool {
496 matches!(self.state, CircuitState::Open)
497 }
498
499 pub const fn reset(&mut self) {
500 self.consecutive_failures = 0;
501 self.last_failure_time = None;
502 self.state = CircuitState::Closed;
503 }
504}
505
// Unit tests for the retry executor, backoff computation, and circuit breaker.
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    // An operation that succeeds immediately returns its value on the first attempt.
    #[tokio::test]
    async fn test_retry_success_on_first_attempt() {
        let config = RetryConfig::new().max_attempts(3);
        let executor = RetryExecutor::new(config);

        let result = executor
            .execute(|| async { Ok::<i32, KodeBridgeError>(42) })
            .await;

        assert_eq!(result.unwrap(), 42);
    }

    // Two retryable connection errors followed by success: the third attempt wins.
    #[tokio::test]
    async fn test_retry_success_after_failures() {
        let config = RetryConfig::new()
            .max_attempts(3)
            .base_delay(Duration::from_millis(1));
        let executor = RetryExecutor::new(config);
        // Shared counter so each invocation of the closure knows how many attempts ran.
        let attempt_count = Arc::new(AtomicUsize::new(0));

        let result = executor
            .execute(|| {
                let count = Arc::clone(&attempt_count);
                async move {
                    let current = count.fetch_add(1, Ordering::SeqCst);
                    if current < 2 {
                        Err(KodeBridgeError::connection("Temporary failure"))
                    } else {
                        Ok(42)
                    }
                }
            })
            .await;

        assert_eq!(result.unwrap(), 42);
        assert_eq!(attempt_count.load(Ordering::SeqCst), 3);
    }

    // A retryable error that never clears stops after exactly `max_attempts` calls.
    #[tokio::test]
    async fn test_retry_max_attempts_exceeded() {
        let config = RetryConfig::new()
            .max_attempts(2)
            .base_delay(Duration::from_millis(1));
        let executor = RetryExecutor::new(config);
        let attempt_count = Arc::new(AtomicUsize::new(0));

        let result = executor
            .execute(|| {
                let count = Arc::clone(&attempt_count);
                async move {
                    count.fetch_add(1, Ordering::SeqCst);
                    Err::<i32, _>(KodeBridgeError::connection("Always fails"))
                }
            })
            .await;

        assert!(result.is_err());
        assert_eq!(attempt_count.load(Ordering::SeqCst), 2);
    }

    // A 400 client error is not retried by the default predicate: exactly one call.
    #[tokio::test]
    async fn test_retry_non_retriable_error() {
        let config = RetryConfig::new()
            .max_attempts(3)
            .base_delay(Duration::from_millis(1));
        let executor = RetryExecutor::new(config);
        let attempt_count = Arc::new(AtomicUsize::new(0));

        let result = executor
            .execute(|| {
                let count = Arc::clone(&attempt_count);
                async move {
                    count.fetch_add(1, Ordering::SeqCst);
                    Err::<i32, _>(KodeBridgeError::ClientError { status: 400 })
                }
            })
            .await;

        assert!(result.is_err());
        assert_eq!(attempt_count.load(Ordering::SeqCst), 1);
    }

    // With a threshold of 2, the first failure keeps the breaker closed and the
    // second opens it. A call made right afterwards is rejected without running.
    #[tokio::test]
    async fn test_circuit_breaker() {
        let mut breaker = CircuitBreaker::new(2, Duration::from_millis(100));

        let result = breaker
            .execute(|| async { Err::<i32, _>(KodeBridgeError::connection("Failure 1")) })
            .await;
        assert!(result.is_err());
        assert!(!breaker.is_open());

        let result = breaker
            .execute(|| async { Err::<i32, _>(KodeBridgeError::connection("Failure 2")) })
            .await;
        assert!(result.is_err());
        assert!(breaker.is_open());

        // Assumes this runs well within the 100 ms recovery window, so even a
        // succeeding operation is short-circuited.
        let result = breaker
            .execute(|| async { Ok::<i32, KodeBridgeError>(42) })
            .await;
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("Circuit breaker is open"));
    }

    // Exponential backoff with multiplier 2 and no jitter doubles the delay each attempt.
    #[test]
    fn test_backoff_strategies() {
        let mut state = RetryState::new();
        let mut rng = StdRng::from_seed([0u8; 32]);
        let config = RetryConfig::new()
            .exponential_backoff(2.0)
            .base_delay(Duration::from_millis(100))
            .jitter(JitterStrategy::None);
        let executor = RetryExecutor::new(config);

        state.attempt = 1;
        let delay1 = executor.calculate_delay(&mut state, &mut rng);
        assert_eq!(delay1, Duration::from_millis(100));

        state.attempt = 2;
        let delay2 = executor.calculate_delay(&mut state, &mut rng);
        assert_eq!(delay2, Duration::from_millis(200));

        state.attempt = 3;
        let delay3 = executor.calculate_delay(&mut state, &mut rng);
        assert_eq!(delay3, Duration::from_millis(400));
    }

    // Preset constructors carry the expected attempt counts and base delays.
    #[test]
    fn test_retry_config_builder() {
        let config = RetryConfig::for_network_operations();
        assert_eq!(config.max_attempts, 5);
        assert_eq!(config.base_delay, Duration::from_millis(50));

        let config = RetryConfig::for_rate_limited_apis();
        assert_eq!(config.max_attempts, 10);
        assert_eq!(config.base_delay, Duration::from_secs(1));
    }
}