1use parking_lot::RwLock;
22use std::future::Future;
23use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
24use std::time::{Duration, Instant};
25
/// Classification of failures that callers may choose to treat as transient.
///
/// Note: `RetryPolicy::execute` does not consult this type; it retries on any `Err`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum RetryableError {
    /// The operation timed out.
    Timeout,
    /// A connection could not be established.
    ConnectionFailed,
    /// The service reported that it is unavailable.
    ServiceUnavailable,
    /// The caller was rate limited.
    RateLimited,
}
38
/// Settings for retrying a fallible operation with exponential backoff.
#[derive(Clone, Debug)]
pub struct RetryPolicy {
    /// Total number of attempts, counting the first one.
    pub max_attempts: u32,
    /// Delay used after the first failed attempt; later delays grow from it.
    pub initial_delay: Duration,
    /// Ceiling for the exponentially growing backoff delay.
    pub max_delay: Duration,
    /// Growth factor applied to the delay after each further failure.
    pub multiplier: f64,
    /// Whether each delay is randomly stretched to de-synchronise concurrent callers.
    pub jitter: bool,
}
53
54impl Default for RetryPolicy {
55 fn default() -> Self {
56 Self {
57 max_attempts: 3,
58 initial_delay: Duration::from_millis(100),
59 max_delay: Duration::from_secs(5),
60 multiplier: 2.0,
61 jitter: true,
62 }
63 }
64}
65
66impl RetryPolicy {
67 pub fn new(max_attempts: u32, initial_delay: Duration) -> Self {
69 Self {
70 max_attempts,
71 initial_delay,
72 ..Default::default()
73 }
74 }
75
76 pub fn delay_for_attempt(&self, attempt: u32) -> Duration {
78 if attempt == 0 {
79 return Duration::ZERO;
80 }
81
82 let base_delay_ms = self.initial_delay.as_millis() as f64;
83 let delay_ms = base_delay_ms * self.multiplier.powi(attempt as i32 - 1);
84 let capped_delay =
85 Duration::from_millis(delay_ms.min(self.max_delay.as_millis() as f64) as u64);
86
87 if self.jitter {
88 let jitter_factor = 1.0 + (rand::random::<f64>() * 0.25);
90 Duration::from_millis((capped_delay.as_millis() as f64 * jitter_factor) as u64)
91 } else {
92 capped_delay
93 }
94 }
95
96 pub async fn execute<F, Fut, T, E>(&self, mut operation: F) -> Result<T, E>
98 where
99 F: FnMut() -> Fut,
100 Fut: Future<Output = Result<T, E>>,
101 E: std::fmt::Debug,
102 {
103 let mut attempt = 0;
104
105 loop {
106 attempt += 1;
107
108 match operation().await {
109 Ok(result) => return Ok(result),
110 Err(e) => {
111 if attempt >= self.max_attempts {
112 tracing::warn!(
113 attempt = attempt,
114 max_attempts = self.max_attempts,
115 error = ?e,
116 "Retry exhausted"
117 );
118 return Err(e);
119 }
120
121 let delay = self.delay_for_attempt(attempt);
122 tracing::debug!(
123 attempt = attempt,
124 delay_ms = delay.as_millis(),
125 error = ?e,
126 "Retrying after delay"
127 );
128
129 tokio::time::sleep(delay).await;
130 }
131 }
132 }
133 }
134}
135
/// Lifecycle state of a circuit breaker.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CircuitState {
    /// Requests flow normally while consecutive failures are counted.
    Closed,
    /// Requests are rejected until the open timeout has elapsed.
    Open,
    /// Requests are allowed on probation: enough successes close the breaker, a failure reopens it.
    HalfOpen,
}
146
147#[derive(Debug)]
149pub struct CircuitBreaker {
150 failure_threshold: u32,
152 success_threshold: u32,
154 timeout: Duration,
156 state: RwLock<CircuitState>,
158 failure_count: AtomicU32,
160 success_count: AtomicU32,
162 opened_at: RwLock<Option<Instant>>,
164}
165
166impl CircuitBreaker {
167 pub fn new(failure_threshold: u32, success_threshold: u32, timeout: Duration) -> Self {
169 Self {
170 failure_threshold,
171 success_threshold,
172 timeout,
173 state: RwLock::new(CircuitState::Closed),
174 failure_count: AtomicU32::new(0),
175 success_count: AtomicU32::new(0),
176 opened_at: RwLock::new(None),
177 }
178 }
179
180 pub fn state(&self) -> CircuitState {
182 self.maybe_transition_from_open();
183 *self.state.read()
184 }
185
186 pub fn allow_request(&self) -> bool {
188 self.maybe_transition_from_open();
189 let state = *self.state.read();
190 matches!(state, CircuitState::Closed | CircuitState::HalfOpen)
191 }
192
193 pub fn record_success(&self) {
195 let state = *self.state.read();
196 match state {
197 CircuitState::Closed => {
198 self.failure_count.store(0, Ordering::SeqCst);
200 }
201 CircuitState::HalfOpen => {
202 let count = self.success_count.fetch_add(1, Ordering::SeqCst) + 1;
203 if count >= self.success_threshold {
204 self.close();
205 }
206 }
207 CircuitState::Open => {}
208 }
209 }
210
211 pub fn record_failure(&self) {
213 let state = *self.state.read();
214 match state {
215 CircuitState::Closed => {
216 let count = self.failure_count.fetch_add(1, Ordering::SeqCst) + 1;
217 if count >= self.failure_threshold {
218 self.open();
219 }
220 }
221 CircuitState::HalfOpen => {
222 self.open();
224 }
225 CircuitState::Open => {}
226 }
227 }
228
229 fn open(&self) {
231 tracing::warn!("Circuit breaker opened");
232 *self.state.write() = CircuitState::Open;
233 *self.opened_at.write() = Some(Instant::now());
234 self.success_count.store(0, Ordering::SeqCst);
235 }
236
237 fn close(&self) {
239 tracing::info!("Circuit breaker closed");
240 *self.state.write() = CircuitState::Closed;
241 self.failure_count.store(0, Ordering::SeqCst);
242 self.success_count.store(0, Ordering::SeqCst);
243 *self.opened_at.write() = None;
244 }
245
246 fn maybe_transition_from_open(&self) {
248 let state = *self.state.read();
249 if state != CircuitState::Open {
250 return;
251 }
252
253 if let Some(opened_at) = *self.opened_at.read() {
254 if opened_at.elapsed() >= self.timeout {
255 tracing::info!("Circuit breaker transitioning to half-open");
256 *self.state.write() = CircuitState::HalfOpen;
257 self.success_count.store(0, Ordering::SeqCst);
258 }
259 }
260 }
261
262 pub async fn execute<F, Fut, T, E>(&self, operation: F) -> Result<T, CircuitBreakerError<E>>
264 where
265 F: FnOnce() -> Fut,
266 Fut: Future<Output = Result<T, E>>,
267 {
268 if !self.allow_request() {
269 return Err(CircuitBreakerError::Open);
270 }
271
272 match operation().await {
273 Ok(result) => {
274 self.record_success();
275 Ok(result)
276 }
277 Err(e) => {
278 self.record_failure();
279 Err(CircuitBreakerError::Inner(e))
280 }
281 }
282 }
283}
284
/// Error returned by `CircuitBreaker::execute`.
#[derive(Debug)]
pub enum CircuitBreakerError<E> {
    /// The breaker rejected the call without running the operation.
    Open,
    /// The operation ran and failed with this error.
    Inner(E),
}

impl<E: std::fmt::Display> std::fmt::Display for CircuitBreakerError<E> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Open => write!(f, "Circuit breaker is open"),
            // Transparent: the wrapper adds nothing to the inner message.
            Self::Inner(e) => write!(f, "{}", e),
        }
    }
}

impl<E: std::error::Error> std::error::Error for CircuitBreakerError<E> {
    /// `Inner` already displays the wrapped error verbatim, so forward to *its*
    /// source. Chain-walking reporters then keep the full cause chain without
    /// printing the inner message twice.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::Open => None,
            Self::Inner(e) => e.source(),
        }
    }
}
304
/// A set of named time limits for an operation.
#[derive(Clone, Debug)]
pub struct TimeoutConfig {
    /// Connect limit (stored only; not enforced by `TimeoutConfig::execute`).
    pub connect: Duration,
    /// Read limit (stored only; not enforced by `TimeoutConfig::execute`).
    pub read: Duration,
    /// Write limit (stored only; not enforced by `TimeoutConfig::execute`).
    pub write: Duration,
    /// Overall limit; the one `TimeoutConfig::execute` applies.
    pub total: Duration,
}
317
318impl Default for TimeoutConfig {
319 fn default() -> Self {
320 Self {
321 connect: Duration::from_secs(5),
322 read: Duration::from_secs(30),
323 write: Duration::from_secs(30),
324 total: Duration::from_secs(60),
325 }
326 }
327}
328
329impl TimeoutConfig {
330 pub fn new(connect: Duration, total: Duration) -> Self {
332 Self {
333 connect,
334 read: total,
335 write: total,
336 total,
337 }
338 }
339
340 pub async fn execute<F, Fut, T>(&self, operation: F) -> Result<T, TimeoutError>
342 where
343 F: FnOnce() -> Fut,
344 Fut: Future<Output = T>,
345 {
346 match tokio::time::timeout(self.total, operation()).await {
347 Ok(result) => Ok(result),
348 Err(_) => Err(TimeoutError {
349 timeout: self.total,
350 }),
351 }
352 }
353}
354
/// Error produced when an operation exceeds its time limit.
#[derive(Debug)]
pub struct TimeoutError {
    /// The limit that was exceeded.
    pub timeout: Duration,
}

impl std::fmt::Display for TimeoutError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let TimeoutError { timeout } = self;
        write!(f, "Operation timed out after {:?}", timeout)
    }
}

impl std::error::Error for TimeoutError {}
369
/// Token-bucket rate limiter usable through `&self` from several threads.
///
/// The bucket holds up to `ceil(requests_per_second)` tokens (at least 1), starts
/// full, and refills continuously at `requests_per_second` tokens per second. Each
/// successful [`RateLimiter::try_acquire`] consumes one token.
#[derive(Debug)]
pub struct RateLimiter {
    /// Bucket capacity.
    max_tokens: u32,
    /// Tokens currently available.
    tokens: AtomicU32,
    /// Monotonic reference point; `last_refill` is measured relative to it so that
    /// wall-clock adjustments cannot stall or flood the limiter.
    epoch: Instant,
    /// Nanoseconds after `epoch` up to which elapsed time has been turned into tokens.
    last_refill: AtomicU64,
    /// Tokens added per second.
    refill_rate: f64,
}

impl RateLimiter {
    /// Creates a limiter allowing `requests_per_second` on average, with bursts of up
    /// to `ceil(requests_per_second)` requests (minimum 1).
    ///
    /// A rate that is zero, negative or NaN never refills, so only the single
    /// initial token is ever granted.
    pub fn new(requests_per_second: f64) -> Self {
        let max_tokens = (requests_per_second.ceil() as u32).max(1);
        Self {
            max_tokens,
            tokens: AtomicU32::new(max_tokens),
            epoch: Instant::now(),
            last_refill: AtomicU64::new(0),
            refill_rate: requests_per_second,
        }
    }

    /// Takes one token if available; returns `false` when the bucket is empty.
    pub fn try_acquire(&self) -> bool {
        self.refill();
        self.tokens
            .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |current| {
                current.checked_sub(1)
            })
            .is_ok()
    }

    /// Nanoseconds elapsed since `epoch`, saturating at `u64::MAX` (~584 years).
    fn now_nanos(&self) -> u64 {
        u64::try_from(self.epoch.elapsed().as_nanos()).unwrap_or(u64::MAX)
    }

    /// Converts the time elapsed since the last refill into whole tokens.
    fn refill(&self) {
        let now = self.now_nanos();
        let last = self.last_refill.load(Ordering::SeqCst);
        let elapsed_secs = now.saturating_sub(last) as f64 / 1e9;
        // `as` saturates: NaN/negative -> 0, huge values -> u32::MAX.
        let tokens_to_add = (elapsed_secs * self.refill_rate) as u32;
        if tokens_to_add == 0 {
            return;
        }

        // Advance only by the time the whole tokens account for, so the fractional
        // remainder carries over instead of being discarded (which under-delivers the rate).
        let credited_nanos = (f64::from(tokens_to_add) / self.refill_rate * 1e9) as u64;
        let next = last.saturating_add(credited_nanos).min(now);
        if self
            .last_refill
            .compare_exchange(last, next, Ordering::SeqCst, Ordering::SeqCst)
            .is_err()
        {
            // Another thread claimed this interval and will add its tokens.
            return;
        }

        // Atomic read-modify-write: a plain load + store would overwrite concurrent
        // `try_acquire` decrements and mint tokens. `saturating_add` avoids overflow at high rates.
        let _ = self
            .tokens
            .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |current| {
                Some(current.saturating_add(tokens_to_add).min(self.max_tokens))
            });
    }

    /// Returns the number of tokens available right now, after refilling.
    pub fn remaining(&self) -> u32 {
        self.refill();
        self.tokens.load(Ordering::SeqCst)
    }
}
449
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_retry_policy_delay() {
        let policy = RetryPolicy {
            max_attempts: 3,
            initial_delay: Duration::from_millis(100),
            max_delay: Duration::from_secs(5),
            multiplier: 2.0,
            jitter: false,
        };

        // (attempt, expected delay in ms): attempt 0 means "no wait", then doubling.
        let cases = [(0, 0), (1, 100), (2, 200), (3, 400)];
        for (attempt, millis) in cases {
            assert_eq!(
                policy.delay_for_attempt(attempt),
                Duration::from_millis(millis),
                "attempt {}",
                attempt
            );
        }
    }

    #[test]
    fn test_circuit_breaker_states() {
        let breaker = CircuitBreaker::new(2, 1, Duration::from_millis(100));

        // Fresh breaker is closed and lets traffic through.
        assert_eq!(breaker.state(), CircuitState::Closed);
        assert!(breaker.allow_request());

        // One failure is below the threshold of two.
        breaker.record_failure();
        assert_eq!(breaker.state(), CircuitState::Closed);

        // The second failure trips it.
        breaker.record_failure();
        assert_eq!(breaker.state(), CircuitState::Open);
        assert!(!breaker.allow_request());

        // After the 100 ms timeout it admits probes.
        std::thread::sleep(Duration::from_millis(150));
        assert_eq!(breaker.state(), CircuitState::HalfOpen);
        assert!(breaker.allow_request());

        // A single success (threshold 1) closes it again.
        breaker.record_success();
        assert_eq!(breaker.state(), CircuitState::Closed);
    }

    #[test]
    fn test_rate_limiter() {
        let limiter = RateLimiter::new(10.0);

        let granted = (0..10).filter(|_| limiter.try_acquire()).count();
        assert_eq!(granted, 10);

        assert!(!limiter.try_acquire());
    }
}