//! Retry logic with exponential backoff for storage operations.
//!
//! Shared infrastructure for all hadb ecosystem crates (walrust-core, graphstream, etc.).
//! Implements:
//! - Exponential backoff with full jitter
//! - Error classification (retryable vs non-retryable)
//! - Circuit breaker pattern
//!
//! Based on AWS best practices:
//! https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/
12use anyhow::{anyhow, Result};
13use rand::Rng;
14use serde::{Deserialize, Serialize};
15use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
16use std::sync::Arc;
17use std::time::Duration;
18
/// Configuration for retry behavior.
///
/// Every field carries a serde default, so a partially-specified config
/// (e.g. `{"max_retries": 3}`) deserializes with the remaining fields equal
/// to [`RetryConfig::default`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RetryConfig {
    /// Maximum number of retry attempts (default: 5).
    #[serde(default = "default_max_retries")]
    pub max_retries: u32,

    /// Initial backoff delay in milliseconds (default: 100).
    #[serde(default = "default_base_delay_ms")]
    pub base_delay_ms: u64,

    /// Maximum backoff delay in milliseconds (default: 30000 = 30s).
    #[serde(default = "default_max_delay_ms")]
    pub max_delay_ms: u64,

    /// Enable circuit breaker (default: true).
    #[serde(default = "default_circuit_breaker_enabled")]
    pub circuit_breaker_enabled: bool,

    /// Number of consecutive failures before circuit opens (default: 10).
    #[serde(default = "default_circuit_breaker_threshold")]
    pub circuit_breaker_threshold: u32,

    /// Time to wait before attempting half-open state (milliseconds, default: 60000 = 1min).
    #[serde(default = "default_circuit_breaker_cooldown_ms")]
    pub circuit_breaker_cooldown_ms: u64,
}
46
// Serde default helpers. These must stay in agreement with the values used by
// `RetryConfig::default`.

/// Serde default for `RetryConfig::max_retries`.
fn default_max_retries() -> u32 {
    5
}

/// Serde default for `RetryConfig::base_delay_ms`.
fn default_base_delay_ms() -> u64 {
    100
}

/// Serde default for `RetryConfig::max_delay_ms` (30 seconds).
fn default_max_delay_ms() -> u64 {
    30_000
}

/// Serde default for `RetryConfig::circuit_breaker_enabled`.
fn default_circuit_breaker_enabled() -> bool {
    true
}

/// Serde default for `RetryConfig::circuit_breaker_threshold`.
fn default_circuit_breaker_threshold() -> u32 {
    10
}

/// Serde default for `RetryConfig::circuit_breaker_cooldown_ms` (1 minute).
fn default_circuit_breaker_cooldown_ms() -> u64 {
    60_000
}
65
66impl Default for RetryConfig {
67    fn default() -> Self {
68        Self {
69            max_retries: default_max_retries(),
70            base_delay_ms: default_base_delay_ms(),
71            max_delay_ms: default_max_delay_ms(),
72            circuit_breaker_enabled: default_circuit_breaker_enabled(),
73            circuit_breaker_threshold: default_circuit_breaker_threshold(),
74            circuit_breaker_cooldown_ms: default_circuit_breaker_cooldown_ms(),
75        }
76    }
77}
78
/// Error classification for retry decisions.
///
/// Produced by [`classify_error`]. `Transient` and `Unknown` are treated as
/// retryable (see [`is_retryable`]); the remaining kinds abort the retry loop.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ErrorKind {
    /// Transient error — should retry (500, 502, 503, 504, timeouts, network).
    Transient,
    /// Client error — don't retry, it's a bug (400).
    ClientError,
    /// Authentication error — don't retry without user intervention (401, 403).
    AuthError,
    /// Not found — context dependent (404).
    NotFound,
    /// Unknown error — may retry with caution.
    Unknown,
}
93
94/// Classify an error to determine retry behavior.
95pub fn classify_error(error: &anyhow::Error) -> ErrorKind {
96    let error_str = error.to_string().to_lowercase();
97
98    // HTTP 5xx server errors
99    if error_str.contains("500")
100        || error_str.contains("502")
101        || error_str.contains("503")
102        || error_str.contains("504")
103        || error_str.contains("internal server error")
104        || error_str.contains("bad gateway")
105        || error_str.contains("service unavailable")
106        || error_str.contains("gateway timeout")
107    {
108        return ErrorKind::Transient;
109    }
110
111    // Network and timeout errors
112    if error_str.contains("timeout")
113        || error_str.contains("timed out")
114        || error_str.contains("connection")
115        || error_str.contains("network")
116        || error_str.contains("socket")
117        || error_str.contains("reset")
118        || error_str.contains("broken pipe")
119        || error_str.contains("eof")
120        || error_str.contains("temporarily unavailable")
121    {
122        return ErrorKind::Transient;
123    }
124
125    // AWS SDK specific transient errors
126    if error_str.contains("throttl")
127        || error_str.contains("slowdown")
128        || error_str.contains("reduce your request rate")
129        || error_str.contains("request rate exceeded")
130    {
131        return ErrorKind::Transient;
132    }
133
134    // AWS SDK dispatch failures (from graphstream)
135    if error_str.contains("dispatch failure") {
136        return ErrorKind::Transient;
137    }
138
139    // Injected test errors (from MockStorage)
140    if error_str.contains("service unavailable (injected)") {
141        return ErrorKind::Transient;
142    }
143
144    // Client errors — don't retry
145    if error_str.contains("400") || error_str.contains("bad request") {
146        return ErrorKind::ClientError;
147    }
148
149    // Auth errors — don't retry
150    if error_str.contains("401")
151        || error_str.contains("403")
152        || error_str.contains("unauthorized")
153        || error_str.contains("forbidden")
154        || error_str.contains("access denied")
155        || error_str.contains("invalid credentials")
156        || error_str.contains("expired token")
157    {
158        return ErrorKind::AuthError;
159    }
160
161    // Not found
162    if error_str.contains("404")
163        || error_str.contains("not found")
164        || error_str.contains("no such key")
165    {
166        return ErrorKind::NotFound;
167    }
168
169    ErrorKind::Unknown
170}
171
172/// Check if an error is retryable.
173pub fn is_retryable(error: &anyhow::Error) -> bool {
174    matches!(
175        classify_error(error),
176        ErrorKind::Transient | ErrorKind::Unknown
177    )
178}
179
/// Circuit breaker state.
///
/// Derived on demand by [`CircuitBreaker::state`] from the failure counter
/// and the open timestamp; it is never stored explicitly.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CircuitState {
    /// Normal operation.
    Closed,
    /// Failing — rejecting requests.
    Open,
    /// Cooldown elapsed — testing if service recovered.
    HalfOpen,
}
190
/// Callback invoked when the circuit breaker opens.
///
/// Receives the consecutive-failure count observed at the moment the
/// circuit opened.
pub type OnCircuitOpen = Arc<dyn Fn(u32) + Send + Sync>;

/// Circuit breaker for preventing cascading failures.
///
/// All mutable state lives in atomics, so the breaker can be shared across
/// tasks via `Arc` without any lock.
pub struct CircuitBreaker {
    /// Failures since the last recorded success.
    consecutive_failures: AtomicU32,
    /// Consecutive failures required to open the circuit.
    threshold: u32,
    /// Milliseconds since UNIX epoch when circuit opened (0 = not open).
    opened_at_ms: AtomicU64,
    /// How long the circuit stays open before probing half-open.
    cooldown_ms: u64,
    /// Optional callback when circuit opens (for webhook notifications, logging, etc.).
    on_open: Option<OnCircuitOpen>,
}
204
205impl std::fmt::Debug for CircuitBreaker {
206    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
207        f.debug_struct("CircuitBreaker")
208            .field("consecutive_failures", &self.consecutive_failures)
209            .field("threshold", &self.threshold)
210            .field("opened_at_ms", &self.opened_at_ms)
211            .field("cooldown_ms", &self.cooldown_ms)
212            .field("on_open", &self.on_open.as_ref().map(|_| "..."))
213            .finish()
214    }
215}
216
217impl CircuitBreaker {
218    /// Create a new circuit breaker.
219    pub fn new(threshold: u32, cooldown_ms: u64) -> Self {
220        Self {
221            consecutive_failures: AtomicU32::new(0),
222            threshold,
223            opened_at_ms: AtomicU64::new(0),
224            cooldown_ms,
225            on_open: None,
226        }
227    }
228
229    /// Create a circuit breaker with a callback on open.
230    pub fn with_on_open(threshold: u32, cooldown_ms: u64, on_open: OnCircuitOpen) -> Self {
231        Self {
232            consecutive_failures: AtomicU32::new(0),
233            threshold,
234            opened_at_ms: AtomicU64::new(0),
235            cooldown_ms,
236            on_open: Some(on_open),
237        }
238    }
239
240    /// Get current circuit state.
241    pub fn state(&self) -> CircuitState {
242        let failures = self.consecutive_failures.load(Ordering::Relaxed);
243        let opened_at = self.opened_at_ms.load(Ordering::Relaxed);
244
245        if failures < self.threshold {
246            return CircuitState::Closed;
247        }
248
249        if opened_at == 0 {
250            return CircuitState::Closed;
251        }
252
253        let now_ms = std::time::SystemTime::now()
254            .duration_since(std::time::UNIX_EPOCH)
255            .unwrap_or_default()
256            .as_millis() as u64;
257
258        if now_ms - opened_at >= self.cooldown_ms {
259            CircuitState::HalfOpen
260        } else {
261            CircuitState::Open
262        }
263    }
264
265    /// Record a successful operation.
266    pub fn record_success(&self) {
267        self.consecutive_failures.store(0, Ordering::Relaxed);
268        self.opened_at_ms.store(0, Ordering::Relaxed);
269    }
270
271    /// Record a failed operation.
272    pub fn record_failure(&self) {
273        let failures = self.consecutive_failures.fetch_add(1, Ordering::Relaxed) + 1;
274
275        if failures >= self.threshold && self.opened_at_ms.load(Ordering::Relaxed) == 0 {
276            let now_ms = std::time::SystemTime::now()
277                .duration_since(std::time::UNIX_EPOCH)
278                .unwrap_or_default()
279                .as_millis() as u64;
280            self.opened_at_ms.store(now_ms, Ordering::Relaxed);
281            tracing::warn!(
282                "Circuit breaker opened after {} consecutive failures",
283                failures
284            );
285            if let Some(ref callback) = self.on_open {
286                callback(failures);
287            }
288        }
289    }
290
291    /// Check if request should be allowed.
292    pub fn should_allow(&self) -> bool {
293        match self.state() {
294            CircuitState::Closed => true,
295            CircuitState::HalfOpen => true,
296            CircuitState::Open => false,
297        }
298    }
299
300    /// Get current consecutive failure count.
301    pub fn consecutive_failures(&self) -> u32 {
302        self.consecutive_failures.load(Ordering::Relaxed)
303    }
304}
305
/// Retry policy with exponential backoff.
///
/// Cloning is cheap: the optional circuit breaker sits behind an `Arc`, so
/// clones share (and jointly update) the same breaker state.
#[derive(Debug, Clone)]
pub struct RetryPolicy {
    config: RetryConfig,
    /// Present when `config.circuit_breaker_enabled` was set at construction,
    /// or when a breaker was supplied via `with_circuit_breaker`.
    circuit_breaker: Option<Arc<CircuitBreaker>>,
}
312
313impl RetryPolicy {
314    /// Create a new retry policy.
315    pub fn new(config: RetryConfig) -> Self {
316        let circuit_breaker = if config.circuit_breaker_enabled {
317            Some(Arc::new(CircuitBreaker::new(
318                config.circuit_breaker_threshold,
319                config.circuit_breaker_cooldown_ms,
320            )))
321        } else {
322            None
323        };
324
325        Self {
326            config,
327            circuit_breaker,
328        }
329    }
330
331    /// Create a retry policy with a custom circuit breaker (e.g., with on_open callback).
332    pub fn with_circuit_breaker(config: RetryConfig, cb: Arc<CircuitBreaker>) -> Self {
333        Self {
334            config,
335            circuit_breaker: Some(cb),
336        }
337    }
338
339    /// Create with default config.
340    pub fn default_policy() -> Self {
341        Self::new(RetryConfig::default())
342    }
343
344    /// Get the config.
345    pub fn config(&self) -> &RetryConfig {
346        &self.config
347    }
348
349    /// Get the circuit breaker.
350    pub fn circuit_breaker(&self) -> Option<&Arc<CircuitBreaker>> {
351        self.circuit_breaker.as_ref()
352    }
353
354    /// Calculate backoff delay for a given attempt.
355    ///
356    /// Uses full jitter: sleep = random(0, min(cap, base * 2^attempt))
357    pub fn calculate_delay(&self, attempt: u32) -> Duration {
358        let base = self.config.base_delay_ms;
359        let cap = self.config.max_delay_ms;
360
361        let exp_delay = base.saturating_mul(1u64 << attempt.min(20));
362        let capped_delay = exp_delay.min(cap);
363
364        let jittered = if capped_delay > 0 {
365            rand::thread_rng().gen_range(0..=capped_delay)
366        } else {
367            0
368        };
369
370        Duration::from_millis(jittered)
371    }
372
373    /// Execute an async operation with retry.
374    pub async fn execute<F, Fut, T>(&self, operation: F) -> Result<T>
375    where
376        F: Fn() -> Fut,
377        Fut: std::future::Future<Output = Result<T>>,
378    {
379        if let Some(cb) = &self.circuit_breaker {
380            if !cb.should_allow() {
381                return Err(anyhow!(
382                    "Circuit breaker open — refusing request after {} consecutive failures",
383                    self.config.circuit_breaker_threshold
384                ));
385            }
386        }
387
388        let mut last_error: Option<anyhow::Error> = None;
389
390        for attempt in 0..=self.config.max_retries {
391            match operation().await {
392                Ok(result) => {
393                    if let Some(cb) = &self.circuit_breaker {
394                        cb.record_success();
395                    }
396                    return Ok(result);
397                }
398                Err(e) => {
399                    let error_kind = classify_error(&e);
400                    let retryable =
401                        matches!(error_kind, ErrorKind::Transient | ErrorKind::Unknown);
402
403                    if let Some(cb) = &self.circuit_breaker {
404                        cb.record_failure();
405                    }
406
407                    if !retryable {
408                        tracing::warn!(
409                            "Non-retryable error (kind={:?}): {}",
410                            error_kind,
411                            e
412                        );
413                        return Err(e);
414                    }
415
416                    if attempt < self.config.max_retries {
417                        let delay = self.calculate_delay(attempt);
418                        tracing::debug!(
419                            "Attempt {}/{} failed (kind={:?}), retrying in {:?}: {}",
420                            attempt + 1,
421                            self.config.max_retries + 1,
422                            error_kind,
423                            delay,
424                            e
425                        );
426                        tokio::time::sleep(delay).await;
427                    }
428
429                    last_error = Some(e);
430                }
431            }
432        }
433
434        Err(last_error.unwrap_or_else(|| anyhow!("Retry failed with no error recorded")))
435    }
436
437    /// Execute with context for better error messages.
438    pub async fn execute_with_context<F, Fut, T>(&self, context: &str, operation: F) -> Result<T>
439    where
440        F: Fn() -> Fut,
441        Fut: std::future::Future<Output = Result<T>>,
442    {
443        self.execute(operation)
444            .await
445            .map_err(|e| anyhow!("{}: {}", context, e))
446    }
447}
448
/// Outcome of a retry attempt (for webhooks/logging).
///
/// NOTE(review): nothing in this module constructs `RetryOutcome` —
/// presumably callers in other crates populate it; confirm before changing
/// its shape.
#[derive(Debug, Clone)]
pub struct RetryOutcome {
    /// Label identifying the operation (for log lines / webhook payloads).
    pub operation: String,
    /// Whether the operation ultimately succeeded.
    pub success: bool,
    /// Number of attempts performed.
    pub attempts: u32,
    /// Total wall-clock time spent across all attempts.
    pub total_duration: Duration,
    /// Final error message, if the operation failed.
    pub error: Option<String>,
    /// Classification of the final error, if any.
    pub error_kind: Option<ErrorKind>,
}
459
#[cfg(test)]
mod tests {
    use super::*;

    // --- Error classification -------------------------------------------

    #[test]
    fn test_error_classification() {
        assert_eq!(
            classify_error(&anyhow!("500 Internal Server Error")),
            ErrorKind::Transient
        );
        assert_eq!(
            classify_error(&anyhow!("503 Service Unavailable")),
            ErrorKind::Transient
        );
        assert_eq!(
            classify_error(&anyhow!("Connection timeout")),
            ErrorKind::Transient
        );
        assert_eq!(
            classify_error(&anyhow!("dispatch failure")),
            ErrorKind::Transient
        );
        assert_eq!(
            classify_error(&anyhow!("Storage error: Service unavailable (injected)")),
            ErrorKind::Transient
        );
        assert_eq!(
            classify_error(&anyhow!("401 Unauthorized")),
            ErrorKind::AuthError
        );
        assert_eq!(
            classify_error(&anyhow!("403 Forbidden")),
            ErrorKind::AuthError
        );
        assert_eq!(
            classify_error(&anyhow!("Access Denied")),
            ErrorKind::AuthError
        );
        assert_eq!(
            classify_error(&anyhow!("400 Bad Request")),
            ErrorKind::ClientError
        );
        assert_eq!(
            classify_error(&anyhow!("404 Not Found")),
            ErrorKind::NotFound
        );
        assert_eq!(
            classify_error(&anyhow!("No such key")),
            ErrorKind::NotFound
        );
    }

    #[test]
    fn test_error_classification_throttling() {
        assert_eq!(
            classify_error(&anyhow!("Request rate exceeded")),
            ErrorKind::Transient
        );
        assert_eq!(
            classify_error(&anyhow!("SlowDown: reduce your request rate")),
            ErrorKind::Transient
        );
        assert_eq!(
            classify_error(&anyhow!("throttling exception")),
            ErrorKind::Transient
        );
    }

    #[test]
    fn test_error_classification_network() {
        assert_eq!(
            classify_error(&anyhow!("connection reset by peer")),
            ErrorKind::Transient
        );
        assert_eq!(
            classify_error(&anyhow!("broken pipe")),
            ErrorKind::Transient
        );
        assert_eq!(
            classify_error(&anyhow!("unexpected eof")),
            ErrorKind::Transient
        );
        assert_eq!(
            classify_error(&anyhow!("network unreachable")),
            ErrorKind::Transient
        );
    }

    #[test]
    fn test_error_classification_unknown() {
        assert_eq!(
            classify_error(&anyhow!("some random error")),
            ErrorKind::Unknown
        );
    }

    #[test]
    fn test_is_retryable() {
        assert!(is_retryable(&anyhow!("500 Internal Server Error")));
        assert!(is_retryable(&anyhow!("Connection timeout")));
        assert!(is_retryable(&anyhow!("dispatch failure")));
        assert!(is_retryable(&anyhow!("some unknown error")));
        assert!(!is_retryable(&anyhow!("401 Unauthorized")));
        assert!(!is_retryable(&anyhow!("403 Forbidden")));
        assert!(!is_retryable(&anyhow!("400 Bad Request")));
        assert!(!is_retryable(&anyhow!("404 Not Found")));
    }

    // --- Backoff --------------------------------------------------------

    #[test]
    fn test_backoff_calculation() {
        let policy = RetryPolicy::new(RetryConfig {
            base_delay_ms: 100,
            max_delay_ms: 30_000,
            ..Default::default()
        });

        // Full jitter: delay is random, so only the upper bound is asserted,
        // sampled several times per attempt number.
        for _ in 0..10 {
            let delay = policy.calculate_delay(0);
            assert!(delay <= Duration::from_millis(100));
        }

        for _ in 0..10 {
            let delay = policy.calculate_delay(1);
            assert!(delay <= Duration::from_millis(200));
        }

        for _ in 0..10 {
            let delay = policy.calculate_delay(20);
            assert!(delay <= Duration::from_millis(30_000));
        }
    }

    // --- Circuit breaker ------------------------------------------------

    #[test]
    fn test_circuit_breaker_states() {
        // threshold 3, cooldown 100ms: Closed -> Open after 3 failures,
        // HalfOpen after the cooldown, Closed again after a success.
        let cb = CircuitBreaker::new(3, 100);

        assert_eq!(cb.state(), CircuitState::Closed);
        assert!(cb.should_allow());
        assert_eq!(cb.consecutive_failures(), 0);

        cb.record_failure();
        cb.record_failure();
        assert_eq!(cb.state(), CircuitState::Closed);
        assert!(cb.should_allow());
        assert_eq!(cb.consecutive_failures(), 2);

        cb.record_failure();
        assert_eq!(cb.state(), CircuitState::Open);
        assert!(!cb.should_allow());
        assert_eq!(cb.consecutive_failures(), 3);

        std::thread::sleep(Duration::from_millis(150));
        assert_eq!(cb.state(), CircuitState::HalfOpen);
        assert!(cb.should_allow());

        cb.record_success();
        assert_eq!(cb.state(), CircuitState::Closed);
        assert!(cb.should_allow());
        assert_eq!(cb.consecutive_failures(), 0);
    }

    #[test]
    fn test_circuit_breaker_on_open_callback() {
        let called = Arc::new(AtomicU32::new(0));
        let called_clone = called.clone();
        let on_open: OnCircuitOpen = Arc::new(move |failures| {
            called_clone.store(failures, Ordering::Relaxed);
        });

        let cb = CircuitBreaker::with_on_open(2, 60_000, on_open);

        cb.record_failure();
        assert_eq!(called.load(Ordering::Relaxed), 0);

        cb.record_failure();
        assert_eq!(called.load(Ordering::Relaxed), 2);
    }

    // --- Config ---------------------------------------------------------

    #[test]
    fn test_retry_config_defaults() {
        let config = RetryConfig::default();
        assert_eq!(config.max_retries, 5);
        assert_eq!(config.base_delay_ms, 100);
        assert_eq!(config.max_delay_ms, 30_000);
        assert!(config.circuit_breaker_enabled);
        assert_eq!(config.circuit_breaker_threshold, 10);
        assert_eq!(config.circuit_breaker_cooldown_ms, 60_000);
    }

    #[test]
    fn test_retry_config_serde() {
        let json = r#"{"max_retries": 3, "base_delay_ms": 50}"#;
        let config: RetryConfig = serde_json::from_str(json).unwrap();
        assert_eq!(config.max_retries, 3);
        assert_eq!(config.base_delay_ms, 50);
        // Defaults for missing fields
        assert_eq!(config.max_delay_ms, 30_000);
        assert!(config.circuit_breaker_enabled);
    }

    #[test]
    fn test_retry_policy_no_circuit_breaker() {
        let policy = RetryPolicy::new(RetryConfig {
            circuit_breaker_enabled: false,
            ..Default::default()
        });
        assert!(policy.circuit_breaker().is_none());
    }

    #[test]
    fn test_retry_policy_with_custom_circuit_breaker() {
        let cb = Arc::new(CircuitBreaker::new(5, 30_000));
        let policy = RetryPolicy::with_circuit_breaker(RetryConfig::default(), cb.clone());
        assert!(policy.circuit_breaker().is_some());
        assert_eq!(policy.circuit_breaker().unwrap().consecutive_failures(), 0);
    }

    // --- Async retry loop -----------------------------------------------

    #[tokio::test]
    async fn test_retry_success() {
        let policy = RetryPolicy::default_policy();
        let result: Result<i32> = policy.execute(|| async { Ok(42) }).await;
        assert_eq!(result.unwrap(), 42);
    }

    #[tokio::test]
    async fn test_retry_transient_then_success() {
        let policy = RetryPolicy::new(RetryConfig {
            max_retries: 3,
            base_delay_ms: 10,
            ..Default::default()
        });

        let attempts = std::sync::atomic::AtomicU32::new(0);

        // Fails twice with a transient error, then succeeds.
        let result: Result<i32> = policy
            .execute(|| {
                let attempt = attempts.fetch_add(1, Ordering::Relaxed);
                async move {
                    if attempt < 2 {
                        Err(anyhow!("Service unavailable (injected)"))
                    } else {
                        Ok(42)
                    }
                }
            })
            .await;

        assert_eq!(result.unwrap(), 42);
        assert_eq!(attempts.load(Ordering::Relaxed), 3);
    }

    #[tokio::test]
    async fn test_retry_auth_error_no_retry() {
        let policy = RetryPolicy::new(RetryConfig {
            max_retries: 5,
            base_delay_ms: 10,
            ..Default::default()
        });

        let attempts = std::sync::atomic::AtomicU32::new(0);

        let result: Result<i32> = policy
            .execute(|| {
                attempts.fetch_add(1, Ordering::Relaxed);
                async { Err(anyhow!("401 Unauthorized")) }
            })
            .await;

        assert!(result.is_err());
        assert_eq!(attempts.load(Ordering::Relaxed), 1);
    }

    #[tokio::test]
    async fn test_retry_not_found_no_retry() {
        let policy = RetryPolicy::new(RetryConfig {
            max_retries: 5,
            base_delay_ms: 10,
            ..Default::default()
        });

        let attempts = std::sync::atomic::AtomicU32::new(0);

        let result: Result<i32> = policy
            .execute(|| {
                attempts.fetch_add(1, Ordering::Relaxed);
                async { Err(anyhow!("404 Not Found")) }
            })
            .await;

        assert!(result.is_err());
        assert_eq!(attempts.load(Ordering::Relaxed), 1);
    }

    #[tokio::test]
    async fn test_retry_exhausted() {
        let policy = RetryPolicy::new(RetryConfig {
            max_retries: 2,
            base_delay_ms: 10,
            circuit_breaker_enabled: false,
            ..Default::default()
        });

        let attempts = std::sync::atomic::AtomicU32::new(0);

        let result: Result<i32> = policy
            .execute(|| {
                attempts.fetch_add(1, Ordering::Relaxed);
                async { Err(anyhow!("Service unavailable (injected)")) }
            })
            .await;

        assert!(result.is_err());
        // Initial attempt + 2 retries = 3 attempts
        assert_eq!(attempts.load(Ordering::Relaxed), 3);
    }

    #[tokio::test]
    async fn test_retry_circuit_breaker_blocks() {
        let policy = RetryPolicy::new(RetryConfig {
            max_retries: 1,
            base_delay_ms: 10,
            circuit_breaker_enabled: true,
            circuit_breaker_threshold: 2,
            circuit_breaker_cooldown_ms: 60_000,
            ..Default::default()
        });

        // Trip the circuit breaker
        let cb = policy.circuit_breaker().unwrap();
        cb.record_failure();
        cb.record_failure();

        let result: Result<i32> = policy.execute(|| async { Ok(42) }).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Circuit breaker open"));
    }

    #[tokio::test]
    async fn test_execute_with_context() {
        let policy = RetryPolicy::new(RetryConfig {
            max_retries: 0,
            circuit_breaker_enabled: false,
            ..Default::default()
        });

        let result: Result<i32> = policy
            .execute_with_context("upload segment", || async {
                Err(anyhow!("Service unavailable (injected)"))
            })
            .await;

        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("upload segment"));
        assert!(err.contains("Service unavailable"));
    }
}