1use crate::errors::KodeBridgeError;
2use rand::{random_range, rngs::StdRng, SeedableRng};
3use std::time::{Duration, Instant};
4use tracing::{debug, warn};
5
6pub type RetryFn = Box<dyn Fn(&KodeBridgeError, usize) -> bool + Send + Sync>;
8
9pub struct RetryConfig {
11 pub max_attempts: usize,
13 pub base_delay: Duration,
15 pub max_delay: Duration,
17 pub backoff_strategy: BackoffStrategy,
19 pub jitter_strategy: JitterStrategy,
21 pub should_retry_fn: Option<RetryFn>,
23}
24
25impl Clone for RetryConfig {
26 fn clone(&self) -> Self {
27 Self {
28 max_attempts: self.max_attempts,
29 base_delay: self.base_delay,
30 max_delay: self.max_delay,
31 backoff_strategy: self.backoff_strategy,
32 jitter_strategy: self.jitter_strategy,
33 should_retry_fn: None, }
35 }
36}
37
/// Rule for growing the wait between consecutive attempts.
///
/// The value produced by the strategy is subsequently limited by
/// `RetryConfig::max_delay`.
#[derive(Debug, Clone, Copy)]
pub enum BackoffStrategy {
    /// Always wait the configured base delay.
    Fixed,
    /// Wait `base_delay * multiplier^(attempt - 1)`.
    Exponential { multiplier: f64 },
    /// Wait `base_delay + increment * (attempt - 1)`.
    Linear { increment: Duration },
}
47
/// Rule for randomising the delay so that retries do not line up.
#[derive(Debug, Clone, Copy)]
pub enum JitterStrategy {
    /// Use the backoff delay exactly as computed.
    None,
    /// Add a random extra of up to half of the backoff delay.
    Full,
    /// Add a random extra of up to a quarter of the backoff delay.
    Partial,
    /// Pick a random delay between the base delay and three times the
    /// previously used delay, bounded by the configured maximum delay.
    Decorrelated,
}
59
60impl Default for RetryConfig {
61 fn default() -> Self {
62 Self {
63 max_attempts: 3,
64 base_delay: Duration::from_millis(100),
65 max_delay: Duration::from_secs(30),
66 backoff_strategy: BackoffStrategy::Exponential { multiplier: 2.0 },
67 jitter_strategy: JitterStrategy::Partial,
68 should_retry_fn: None,
69 }
70 }
71}
72
73impl RetryConfig {
74 pub fn new() -> Self {
76 Self::default()
77 }
78
79 pub fn max_attempts(mut self, max_attempts: usize) -> Self {
81 self.max_attempts = max_attempts;
82 self
83 }
84
85 pub fn base_delay(mut self, delay: Duration) -> Self {
87 self.base_delay = delay;
88 self
89 }
90
91 pub fn max_delay(mut self, delay: Duration) -> Self {
93 self.max_delay = delay;
94 self
95 }
96
97 pub fn exponential_backoff(mut self, multiplier: f64) -> Self {
99 self.backoff_strategy = BackoffStrategy::Exponential { multiplier };
100 self
101 }
102
103 pub fn fixed_backoff(mut self) -> Self {
105 self.backoff_strategy = BackoffStrategy::Fixed;
106 self
107 }
108
109 pub fn linear_backoff(mut self, increment: Duration) -> Self {
111 self.backoff_strategy = BackoffStrategy::Linear { increment };
112 self
113 }
114
115 pub fn jitter(mut self, strategy: JitterStrategy) -> Self {
117 self.jitter_strategy = strategy;
118 self
119 }
120
121 pub fn should_retry<F>(mut self, f: F) -> Self
123 where
124 F: Fn(&KodeBridgeError, usize) -> bool + Send + Sync + 'static,
125 {
126 self.should_retry_fn = Some(Box::new(f));
127 self
128 }
129
130 pub fn for_network_operations() -> Self {
132 Self::new()
133 .max_attempts(5)
134 .base_delay(Duration::from_millis(50))
135 .max_delay(Duration::from_secs(10))
136 .exponential_backoff(2.0)
137 .jitter(JitterStrategy::Full)
138 }
139
140 pub fn for_rate_limited_apis() -> Self {
141 Self::new()
142 .max_attempts(10)
143 .base_delay(Duration::from_secs(1))
144 .max_delay(Duration::from_secs(60))
145 .exponential_backoff(1.5)
146 .jitter(JitterStrategy::Decorrelated)
147 }
148
149 pub fn for_quick_operations() -> Self {
150 Self::new()
151 .max_attempts(3)
152 .base_delay(Duration::from_millis(10))
153 .max_delay(Duration::from_millis(500))
154 .linear_backoff(Duration::from_millis(50))
155 .jitter(JitterStrategy::Partial)
156 }
157}
158
/// Bookkeeping for a single retry run.
#[derive(Debug)]
pub struct RetryState {
    /// Number of attempts started so far.
    attempt: usize,
    /// Accumulated running time of the attempts that failed.
    total_elapsed: Duration,
    /// Delay most recently computed for a retry.
    last_delay: Duration,
}

impl Default for RetryState {
    /// Same as [`RetryState::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl RetryState {
    /// Creates a state with no attempts made and zero durations.
    pub fn new() -> Self {
        let nothing_yet = Duration::ZERO;
        Self {
            attempt: 0,
            total_elapsed: nothing_yet,
            last_delay: nothing_yet,
        }
    }

    /// Number of attempts started so far.
    pub fn attempt(&self) -> usize {
        self.attempt
    }

    /// Accumulated running time recorded so far.
    pub fn total_elapsed(&self) -> Duration {
        self.total_elapsed
    }

    /// Delay most recently computed for a retry.
    pub fn last_delay(&self) -> Duration {
        self.last_delay
    }
}
194
195pub struct RetryExecutor {
197 config: RetryConfig,
198}
199
200impl RetryExecutor {
201 pub fn new(config: RetryConfig) -> Self {
202 Self { config }
203 }
204
205 pub async fn execute<F, Fut, T>(&self, mut operation: F) -> Result<T, KodeBridgeError>
207 where
208 F: FnMut() -> Fut,
209 Fut: std::future::Future<Output = Result<T, KodeBridgeError>>,
210 {
211 let mut state = RetryState::new();
212 let mut rng = StdRng::from_seed([0u8; 32]); loop {
215 state.attempt += 1;
216 let attempt_start = Instant::now();
217
218 debug!("Retry attempt {} starting", state.attempt);
219
220 match operation().await {
221 Ok(result) => {
222 if state.attempt > 1 {
223 debug!(
224 "Operation succeeded on attempt {} after {}ms",
225 state.attempt,
226 state.total_elapsed.as_millis()
227 );
228 }
229 return Ok(result);
230 }
231 Err(error) => {
232 let attempt_duration = attempt_start.elapsed();
233 state.total_elapsed += attempt_duration;
234
235 let should_retry = if let Some(ref custom_fn) = self.config.should_retry_fn {
237 custom_fn(&error, state.attempt)
238 } else {
239 self.default_should_retry(&error, state.attempt)
240 };
241
242 if !should_retry || state.attempt >= self.config.max_attempts {
243 warn!(
244 "Operation failed after {} attempts in {}ms: {}",
245 state.attempt,
246 state.total_elapsed.as_millis(),
247 error
248 );
249 return Err(error);
250 }
251
252 let next_delay = self.calculate_delay(&mut state, &mut rng);
254
255 debug!(
256 "Retrying after {}ms (attempt {}/{}, error: {})",
257 next_delay.as_millis(),
258 state.attempt,
259 self.config.max_attempts,
260 error
261 );
262
263 tokio::time::sleep(next_delay).await;
264 }
265 }
266 }
267 }
268
269 pub async fn execute_with_context<F, Fut, T>(
271 &self,
272 operation_name: &str,
273 operation: F,
274 ) -> Result<T, KodeBridgeError>
275 where
276 F: FnMut() -> Fut,
277 Fut: std::future::Future<Output = Result<T, KodeBridgeError>>,
278 {
279 debug!("Starting retry execution for operation: {}", operation_name);
280
281 match self.execute(operation).await {
282 Ok(result) => {
283 debug!("Operation '{}' completed successfully", operation_name);
284 Ok(result)
285 }
286 Err(error) => {
287 warn!(
288 "Operation '{}' failed with error: {}",
289 operation_name, error
290 );
291 Err(KodeBridgeError::custom(format!(
292 "Operation '{}' failed after retries: {}",
293 operation_name, error
294 )))
295 }
296 }
297 }
298
299 fn default_should_retry(&self, error: &KodeBridgeError, attempt: usize) -> bool {
301 use KodeBridgeError::*;
302
303 match error {
304 Io(_) | Connection { .. } | Timeout { .. } | StreamClosed => true,
306
307 ServerError { status } => *status >= 500,
309 ClientError { .. } | InvalidRequest { .. } => false,
310
311 HttpParse(_) | Http(_) | Protocol { .. } => false,
313
314 Configuration { .. } => false,
316
317 Json(_) | JsonSerialize { .. } => false,
319
320 Utf8(_) | FromUtf8(_) => false,
322
323 PoolExhausted => attempt <= 5, Custom { .. } => false,
328
329 InvalidStatusCode(_) => false,
331 }
332 }
333
334 fn calculate_delay(&self, state: &mut RetryState, _rng: &mut impl rand::Rng) -> Duration {
336 let base_delay = match self.config.backoff_strategy {
337 BackoffStrategy::Fixed => self.config.base_delay,
338 BackoffStrategy::Exponential { multiplier } => {
339 if state.attempt == 1 {
340 self.config.base_delay
341 } else {
342 let exponential = (self.config.base_delay.as_millis() as f64
343 * multiplier.powi((state.attempt - 1) as i32))
344 as u64;
345 Duration::from_millis(exponential)
346 }
347 }
348 BackoffStrategy::Linear { increment } => {
349 self.config.base_delay + increment * (state.attempt as u32 - 1)
350 }
351 };
352
353 let capped_delay = std::cmp::min(base_delay, self.config.max_delay);
355
356 let final_delay = match self.config.jitter_strategy {
358 JitterStrategy::None => capped_delay,
359 JitterStrategy::Full => {
360 let jitter = random_range(0..=capped_delay.as_millis() / 2) as u64;
361 capped_delay + Duration::from_millis(jitter)
362 }
363 JitterStrategy::Partial => {
364 let jitter = random_range(0..=capped_delay.as_millis() / 4) as u64;
365 capped_delay + Duration::from_millis(jitter)
366 }
367 JitterStrategy::Decorrelated => {
368 let min_delay = self.config.base_delay.as_millis() as u64;
370 let max_delay = std::cmp::min(
371 (state.last_delay.as_millis() as u64 * 3).max(min_delay),
372 self.config.max_delay.as_millis() as u64,
373 );
374 Duration::from_millis(random_range(min_delay..=max_delay))
375 }
376 };
377
378 state.last_delay = final_delay;
379 final_delay
380 }
381}
382
383pub async fn retry<F, Fut, T>(config: RetryConfig, operation: F) -> Result<T, KodeBridgeError>
385where
386 F: FnMut() -> Fut,
387 Fut: std::future::Future<Output = Result<T, KodeBridgeError>>,
388{
389 RetryExecutor::new(config).execute(operation).await
390}
391
392pub async fn retry_default<F, Fut, T>(operation: F) -> Result<T, KodeBridgeError>
394where
395 F: FnMut() -> Fut,
396 Fut: std::future::Future<Output = Result<T, KodeBridgeError>>,
397{
398 retry(RetryConfig::default(), operation).await
399}
400
401#[derive(Debug)]
403pub struct CircuitBreaker {
404 failure_threshold: usize,
405 recovery_timeout: Duration,
406 consecutive_failures: usize,
407 last_failure_time: Option<Instant>,
408 state: CircuitState,
409}
410
411#[derive(Debug, Clone, PartialEq)]
412enum CircuitState {
413 Closed, Open, HalfOpen, }
417
418impl CircuitBreaker {
419 pub fn new(failure_threshold: usize, recovery_timeout: Duration) -> Self {
420 Self {
421 failure_threshold,
422 recovery_timeout,
423 consecutive_failures: 0,
424 last_failure_time: None,
425 state: CircuitState::Closed,
426 }
427 }
428
429 pub async fn execute<F, Fut, T>(&mut self, operation: F) -> Result<T, KodeBridgeError>
430 where
431 F: FnOnce() -> Fut,
432 Fut: std::future::Future<Output = Result<T, KodeBridgeError>>,
433 {
434 if self.state == CircuitState::Open {
435 if let Some(last_failure) = self.last_failure_time {
436 if last_failure.elapsed() >= self.recovery_timeout {
437 debug!("Circuit breaker entering half-open state");
438 self.state = CircuitState::HalfOpen;
439 } else {
440 return Err(KodeBridgeError::custom("Circuit breaker is open"));
441 }
442 } else {
443 return Err(KodeBridgeError::custom("Circuit breaker is open"));
444 }
445 }
446
447 match operation().await {
448 Ok(result) => {
449 if self.state == CircuitState::HalfOpen {
451 debug!("Circuit breaker closing after successful operation");
452 }
453 self.consecutive_failures = 0;
454 self.last_failure_time = None;
455 self.state = CircuitState::Closed;
456 Ok(result)
457 }
458 Err(error) => {
459 self.consecutive_failures += 1;
461 self.last_failure_time = Some(Instant::now());
462
463 if self.consecutive_failures >= self.failure_threshold {
464 debug!(
465 "Circuit breaker opening after {} consecutive failures",
466 self.consecutive_failures
467 );
468 self.state = CircuitState::Open;
469 }
470
471 Err(error)
472 }
473 }
474 }
475
476 pub fn is_open(&self) -> bool {
477 matches!(self.state, CircuitState::Open)
478 }
479
480 pub fn reset(&mut self) {
481 self.consecutive_failures = 0;
482 self.last_failure_time = None;
483 self.state = CircuitState::Closed;
484 }
485}
486
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    // An operation that succeeds immediately yields its value.
    #[tokio::test]
    async fn test_retry_success_on_first_attempt() {
        let executor = RetryExecutor::new(RetryConfig::new().max_attempts(3));

        let outcome = executor
            .execute(|| async { Ok::<i32, KodeBridgeError>(42) })
            .await;

        assert_eq!(outcome.unwrap(), 42);
    }

    // Two connection failures followed by a success: three calls in total.
    #[tokio::test]
    async fn test_retry_success_after_failures() {
        let executor = RetryExecutor::new(
            RetryConfig::new()
                .max_attempts(3)
                .base_delay(Duration::from_millis(1)),
        );
        let calls = Arc::new(AtomicUsize::new(0));

        let outcome = executor
            .execute(|| {
                let calls = Arc::clone(&calls);
                async move {
                    let earlier_calls = calls.fetch_add(1, Ordering::SeqCst);
                    if earlier_calls >= 2 {
                        Ok(42)
                    } else {
                        Err(KodeBridgeError::connection("Temporary failure"))
                    }
                }
            })
            .await;

        assert_eq!(outcome.unwrap(), 42);
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    // A permanently failing operation is called exactly `max_attempts` times.
    #[tokio::test]
    async fn test_retry_max_attempts_exceeded() {
        let executor = RetryExecutor::new(
            RetryConfig::new()
                .max_attempts(2)
                .base_delay(Duration::from_millis(1)),
        );
        let calls = Arc::new(AtomicUsize::new(0));

        let outcome = executor
            .execute(|| {
                let calls = Arc::clone(&calls);
                async move {
                    calls.fetch_add(1, Ordering::SeqCst);
                    Err::<i32, _>(KodeBridgeError::connection("Always fails"))
                }
            })
            .await;

        assert!(outcome.is_err());
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }

    // Client errors are not retried, so the operation runs only once.
    #[tokio::test]
    async fn test_retry_non_retriable_error() {
        let executor = RetryExecutor::new(
            RetryConfig::new()
                .max_attempts(3)
                .base_delay(Duration::from_millis(1)),
        );
        let calls = Arc::new(AtomicUsize::new(0));

        let outcome = executor
            .execute(|| {
                let calls = Arc::clone(&calls);
                async move {
                    calls.fetch_add(1, Ordering::SeqCst);
                    Err::<i32, _>(KodeBridgeError::ClientError { status: 400 })
                }
            })
            .await;

        assert!(outcome.is_err());
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    // The breaker opens on the second failure and then rejects calls.
    #[tokio::test]
    async fn test_circuit_breaker() {
        let mut breaker = CircuitBreaker::new(2, Duration::from_millis(100));

        let first = breaker
            .execute(|| async { Err::<i32, _>(KodeBridgeError::connection("Failure 1")) })
            .await;
        assert!(first.is_err());
        assert!(!breaker.is_open());

        let second = breaker
            .execute(|| async { Err::<i32, _>(KodeBridgeError::connection("Failure 2")) })
            .await;
        assert!(second.is_err());
        assert!(breaker.is_open());

        let rejected = breaker
            .execute(|| async { Ok::<i32, KodeBridgeError>(42) })
            .await;
        assert!(rejected.is_err());
        let message = rejected.unwrap_err().to_string();
        assert!(message.contains("Circuit breaker is open"));
    }

    // Exponential backoff without jitter doubles the 100 ms base each attempt.
    #[test]
    fn test_backoff_strategies() {
        let mut state = RetryState::new();
        let mut rng = StdRng::from_seed([0u8; 32]);
        let executor = RetryExecutor::new(
            RetryConfig::new()
                .exponential_backoff(2.0)
                .base_delay(Duration::from_millis(100))
                .jitter(JitterStrategy::None),
        );

        for (attempt, expected_ms) in [(1usize, 100u64), (2, 200), (3, 400)] {
            state.attempt = attempt;
            let delay = executor.calculate_delay(&mut state, &mut rng);
            assert_eq!(delay, Duration::from_millis(expected_ms));
        }
    }

    // The presets expose the documented attempt counts and base delays.
    #[test]
    fn test_retry_config_builder() {
        let network = RetryConfig::for_network_operations();
        assert_eq!(network.max_attempts, 5);
        assert_eq!(network.base_delay, Duration::from_millis(50));

        let rate_limited = RetryConfig::for_rate_limited_apis();
        assert_eq!(rate_limited.max_attempts, 10);
        assert_eq!(rate_limited.base_delay, Duration::from_secs(1));
    }
}