1use std::time::{Duration, Instant};
4use std::collections::HashMap;
5use serde::{Deserialize, Serialize};
6use thiserror::Error;
7use tokio::time::sleep;
8
9#[derive(Error, Debug)]
10pub enum RetryError {
11 #[error("Max retries exceeded: {0}")]
12 MaxRetriesExceeded(usize),
13 #[error("Circuit breaker open: {0}")]
14 CircuitBreakerOpen(String),
15 #[error("Operation timeout: {0}")]
16 Timeout(String),
17 #[error("Fatal error: {0}")]
18 Fatal(String),
19}
20
21#[derive(Debug, Clone, Serialize, Deserialize)]
23pub enum RetryStrategy {
24 None,
26 Fixed { delay: Duration, max_attempts: usize },
28 Exponential {
30 initial_delay: Duration,
31 max_delay: Duration,
32 multiplier: f64,
33 max_attempts: usize,
34 },
35 Fibonacci {
37 initial_delay: Duration,
38 max_delay: Duration,
39 max_attempts: usize,
40 },
41 Custom { max_attempts: usize, custom_logic: String },
43}
44
45impl Default for RetryStrategy {
46 fn default() -> Self {
47 Self::Exponential {
48 initial_delay: Duration::from_millis(100),
49 max_delay: Duration::from_secs(30),
50 multiplier: 2.0,
51 max_attempts: 5,
52 }
53 }
54}
55
56#[derive(Debug, Clone, Serialize, Deserialize)]
58pub enum CircuitBreakerState {
59 Closed, Open, HalfOpen, }
63
64#[derive(Debug, Clone, Serialize, Deserialize)]
66pub struct CircuitBreakerConfig {
67 pub failure_threshold: usize,
68 pub recovery_timeout: Duration,
69 pub expected_duration: Duration,
70 pub min_calls: usize,
71}
72
73impl Default for CircuitBreakerConfig {
74 fn default() -> Self {
75 Self {
76 failure_threshold: 5,
77 recovery_timeout: Duration::from_secs(60),
78 expected_duration: Duration::from_secs(1),
79 min_calls: 3,
80 }
81 }
82}
83
84pub struct CircuitBreaker {
86 config: CircuitBreakerConfig,
87 state: CircuitBreakerState,
88 failure_count: usize,
89 success_count: usize,
90 last_failure_time: Option<Instant>,
91 last_state_change: Instant,
92}
93
94impl CircuitBreaker {
95 pub fn new(config: CircuitBreakerConfig) -> Self {
96 Self {
97 config,
98 state: CircuitBreakerState::Closed,
99 failure_count: 0,
100 success_count: 0,
101 last_failure_time: None,
102 last_state_change: Instant::now(),
103 }
104 }
105
106 pub fn can_execute(&mut self) -> bool {
107 match self.state {
108 CircuitBreakerState::Closed => true,
109 CircuitBreakerState::Open => {
110 if self.should_attempt_reset() {
111 self.transition_to_half_open();
112 true
113 } else {
114 false
115 }
116 }
117 CircuitBreakerState::HalfOpen => true,
118 }
119 }
120
121 pub fn on_success(&mut self) {
122 match self.state {
123 CircuitBreakerState::Closed => {
124 self.failure_count = 0;
125 }
126 CircuitBreakerState::HalfOpen => {
127 self.transition_to_closed();
128 }
129 CircuitBreakerState::Open => {}
130 }
131 }
132
133 pub fn on_failure(&mut self) {
134 self.failure_count += 1;
135 self.last_failure_time = Some(Instant::now());
136
137 match self.state {
138 CircuitBreakerState::Closed => {
139 if self.failure_count >= self.config.failure_threshold {
140 self.transition_to_open();
141 }
142 }
143 CircuitBreakerState::HalfOpen => {
144 self.transition_to_open();
145 }
146 CircuitBreakerState::Open => {}
147 }
148 }
149
150 fn should_attempt_reset(&self) -> bool {
151 if let Some(last_failure) = self.last_failure_time {
152 last_failure.elapsed() >= self.config.recovery_timeout
153 } else {
154 false
155 }
156 }
157
158 fn transition_to_open(&mut self) {
159 self.state = CircuitBreakerState::Open;
160 self.last_state_change = Instant::now();
161 }
162
163 fn transition_to_half_open(&mut self) {
164 self.state = CircuitBreakerState::HalfOpen;
165 self.last_state_change = Instant::now();
166 self.failure_count = 0;
167 self.success_count = 0;
168 }
169
170 fn transition_to_closed(&mut self) {
171 self.state = CircuitBreakerState::Closed;
172 self.last_state_change = Instant::now();
173 self.failure_count = 0;
174 self.success_count = 0;
175 }
176
177 pub fn state(&self) -> &CircuitBreakerState {
178 &self.state
179 }
180
181 pub fn failure_count(&self) -> usize {
182 self.failure_count
183 }
184
185 pub fn success_count(&self) -> usize {
186 self.success_count
187 }
188}
189
190pub struct RetryManager {
192 strategy: RetryStrategy,
193 circuit_breaker: CircuitBreaker,
194 operation_history: HashMap<String, OperationStats>,
195}
196
/// Running totals for one named operation.
///
/// `Default` yields the all-zero "never attempted" record. `PartialEq`/`Eq`
/// let callers compare snapshots. Every field type already supports these
/// traits, so the derives are free and backward-compatible.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct OperationStats {
    /// Attempts started, successful or not.
    pub total_attempts: usize,
    /// Attempts that returned `Ok`.
    pub successful_attempts: usize,
    /// Attempts that returned `Err`.
    pub failed_attempts: usize,
    /// Wall-clock time spent across all attempts (sleeps between retries excluded).
    pub total_duration: Duration,
    /// When the most recent attempt started; `None` if never attempted.
    pub last_attempt: Option<Instant>,
}
206
207impl RetryManager {
208 pub fn new(strategy: RetryStrategy, circuit_breaker_config: CircuitBreakerConfig) -> Self {
209 Self {
210 strategy,
211 circuit_breaker: CircuitBreaker::new(circuit_breaker_config),
212 operation_history: HashMap::new(),
213 }
214 }
215
216 pub async fn execute_with_retry<F, T, E>(
218 &mut self,
219 operation_name: &str,
220 operation: F,
221 ) -> Result<T, RetryError>
222 where
223 F: Fn() -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<T, E>> + Send>> + Send + Sync,
224 E: std::error::Error + Send + Sync + 'static,
225 {
226 let start_time = Instant::now();
227 let mut attempt = 0;
228
229 if !self.circuit_breaker.can_execute() {
231 return Err(RetryError::CircuitBreakerOpen(
232 format!("Circuit breaker is {:?}", self.circuit_breaker.state())
233 ));
234 }
235
236 loop {
237 attempt += 1;
238 let attempt_start = Instant::now();
239
240 let stats = self.operation_history.entry(operation_name.to_string()).or_insert_with(|| OperationStats {
242 total_attempts: 0,
243 successful_attempts: 0,
244 failed_attempts: 0,
245 total_duration: Duration::ZERO,
246 last_attempt: None,
247 });
248
249 stats.total_attempts += 1;
250 stats.last_attempt = Some(Instant::now());
251
252 let result = operation().await;
254
255 match result {
256 Ok(value) => {
257 self.circuit_breaker.on_success();
259 stats.successful_attempts += 1;
260 stats.total_duration += attempt_start.elapsed();
261
262 return Ok(value);
263 }
264 Err(error) => {
265 self.circuit_breaker.on_failure();
267 stats.failed_attempts += 1;
268 stats.total_duration += attempt_start.elapsed();
269
270 if !self.should_retry(attempt, &error) {
272 return Err(RetryError::MaxRetriesExceeded(attempt));
273 }
274
275 let delay = self.calculate_delay(attempt);
277
278 sleep(delay).await;
280 }
281 }
282 }
283 }
284
285 fn should_retry(&self, attempt: usize, error: &dyn std::error::Error) -> bool {
286 let max_attempts = match &self.strategy {
287 RetryStrategy::None => return false,
288 RetryStrategy::Fixed { max_attempts, .. } => *max_attempts,
289 RetryStrategy::Exponential { max_attempts, .. } => *max_attempts,
290 RetryStrategy::Fibonacci { max_attempts, .. } => *max_attempts,
291 RetryStrategy::Custom { max_attempts, .. } => *max_attempts,
292 };
293
294 attempt < max_attempts
295 }
296
297 fn calculate_delay(&self, attempt: usize) -> Duration {
298 match &self.strategy {
299 RetryStrategy::None => Duration::ZERO,
300 RetryStrategy::Fixed { delay, .. } => *delay,
301 RetryStrategy::Exponential { initial_delay, max_delay, multiplier, .. } => {
302 let delay = initial_delay.as_millis() as f64 * multiplier.powi(attempt as i32 - 1);
303 let delay_ms = delay.min(max_delay.as_millis() as f64) as u64;
304 Duration::from_millis(delay_ms)
305 }
306 RetryStrategy::Fibonacci { initial_delay, max_delay, .. } => {
307 let delay = initial_delay.as_millis() as u64 * fibonacci(attempt);
308 Duration::from_millis(delay.min(max_delay.as_millis() as u64))
309 }
310 RetryStrategy::Custom { .. } => Duration::from_millis(100), }
312 }
313
314 pub fn get_operation_stats(&self, operation_name: &str) -> Option<&OperationStats> {
316 self.operation_history.get(operation_name)
317 }
318
319 pub fn get_all_stats(&self) -> &HashMap<String, OperationStats> {
321 &self.operation_history
322 }
323
324 pub fn reset_circuit_breaker(&mut self) {
326 self.circuit_breaker = CircuitBreaker::new(self.circuit_breaker.config.clone());
327 }
328
329 pub fn update_strategy(&mut self, strategy: RetryStrategy) {
331 self.strategy = strategy;
332 }
333}
334
/// Returns the `n`-th Fibonacci number with `fibonacci(1) == fibonacci(2) == 1`.
///
/// `fibonacci(0)` is also 1. The result saturates at `u64::MAX` for
/// `n >= 94`, where the true value no longer fits in a `u64`. The previous
/// `a + b` panicked in debug builds and wrapped silently in release builds.
fn fibonacci(n: usize) -> u64 {
    let (mut prev, mut curr) = (1u64, 1u64);
    // For n <= 2 the range is empty and the seed value 1 is returned.
    for _ in 2..n {
        let next = prev.saturating_add(curr);
        prev = curr;
        curr = next;
    }
    curr
}
350
/// Classification hooks for deciding whether and when to retry after an error.
pub trait RetryableError {
    /// Whether another attempt might succeed.
    fn is_retryable(&self) -> bool;
    /// Whether the error should stop retrying immediately.
    fn is_fatal(&self) -> bool;
    /// Delay the error asks for before the next attempt, if any.
    fn retry_after(&self) -> Option<Duration>;
}

/// Blanket implementation: every `std::error::Error` is retryable, never
/// fatal, and expresses no preferred delay.
///
/// NOTE(review): because this covers all (sized) `Error` types, no error type
/// can supply its own classification; a manual impl would conflict with this one.
/// `RetryManager::should_retry` in this file also does not consult this trait.
impl<E> RetryableError for E
where
    E: std::error::Error,
{
    fn is_retryable(&self) -> bool {
        true
    }

    fn is_fatal(&self) -> bool {
        false
    }

    fn retry_after(&self) -> Option<Duration> {
        None
    }
}
375
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    // Nothing below awaits, so plain `#[test]` suffices. `#[tokio::test]`
    // would build a whole async runtime per test for no benefit.

    #[test]
    fn test_circuit_breaker_creation() {
        let config = CircuitBreakerConfig::default();
        let breaker = CircuitBreaker::new(config);
        assert!(matches!(breaker.state(), CircuitBreakerState::Closed));
    }

    #[test]
    fn test_circuit_breaker_opens_after_threshold() {
        let config = CircuitBreakerConfig {
            failure_threshold: 2,
            ..CircuitBreakerConfig::default()
        };
        let mut breaker = CircuitBreaker::new(config);

        breaker.on_failure();
        assert!(matches!(breaker.state(), CircuitBreakerState::Closed));

        breaker.on_failure();
        assert!(matches!(breaker.state(), CircuitBreakerState::Open));
        // The default recovery timeout is 60 s, so the breaker must refuse calls now.
        assert!(!breaker.can_execute());
    }

    #[test]
    fn test_retry_manager_creation() {
        let strategy = RetryStrategy::default();
        let config = CircuitBreakerConfig::default();
        let manager = RetryManager::new(strategy, config);
        assert!(manager.operation_history.is_empty());
    }

    #[test]
    fn test_fibonacci_calculation() {
        assert_eq!(fibonacci(1), 1);
        assert_eq!(fibonacci(2), 1);
        assert_eq!(fibonacci(3), 2);
        assert_eq!(fibonacci(4), 3);
        assert_eq!(fibonacci(5), 5);
    }

    #[test]
    fn test_retry_strategy_default() {
        let strategy = RetryStrategy::default();
        match strategy {
            RetryStrategy::Exponential { initial_delay, max_delay, multiplier, max_attempts } => {
                assert_eq!(initial_delay, Duration::from_millis(100));
                assert_eq!(max_delay, Duration::from_secs(30));
                assert_eq!(multiplier, 2.0);
                assert_eq!(max_attempts, 5);
            }
            _ => panic!("Expected exponential strategy"),
        }
    }

    #[test]
    fn test_exponential_delay_grows_and_caps() {
        let strategy = RetryStrategy::Exponential {
            initial_delay: Duration::from_millis(100),
            max_delay: Duration::from_secs(1),
            multiplier: 2.0,
            max_attempts: 10,
        };
        let manager = RetryManager::new(strategy, CircuitBreakerConfig::default());

        assert_eq!(manager.calculate_delay(1), Duration::from_millis(100));
        assert_eq!(manager.calculate_delay(2), Duration::from_millis(200));
        assert_eq!(manager.calculate_delay(3), Duration::from_millis(400));
        assert_eq!(manager.calculate_delay(10), Duration::from_secs(1));
    }
}
418}