Skip to main content

forge_reasoning/verification/
retry.rs

1//! Retry logic with exponential backoff and jitter
2//!
3//! This module provides retry logic for verification checks that fail due to
4//! transient issues like timeouts or network errors.
5
6use std::time::Duration;
7use serde::{Deserialize, Serialize};
8
9/// Configuration for retry behavior
10#[derive(Clone, Debug, Serialize, Deserialize)]
11pub struct RetryConfig {
12    /// Maximum number of retry attempts
13    pub max_retries: u32,
14    /// Initial delay before first retry
15    pub initial_delay: Duration,
16    /// Maximum delay cap
17    pub max_delay: Duration,
18    /// Backoff multiplier (default: 2.0 for exponential)
19    pub backoff_factor: f64,
20    /// Whether to add jitter to delay
21    pub jitter: bool,
22}
23
24impl Default for RetryConfig {
25    fn default() -> Self {
26        Self {
27            max_retries: 3,
28            initial_delay: Duration::from_millis(100),
29            max_delay: Duration::from_secs(30),
30            backoff_factor: 2.0,
31            jitter: true,
32        }
33    }
34}
35
36/// Execute an async operation with retry logic
37///
38/// # Arguments
39/// * `operation` - Async operation that returns Result<T, E>
40/// * `config` - Retry configuration
41///
42/// # Returns
43/// * `Ok(T)` - Operation succeeded
44/// * `Err(E)` - Operation failed after all retries
45///
46/// # Behavior
47/// - On success: returns immediately
48/// - On error when attempt < max_retries:
49///   * Calculates delay: initial_delay * backoff_factor^attempt
50///   * Adds jitter if enabled: delay * (0.5 + random::<f32>())
51///   * Caps at max_delay
52///   * Sleeps and retries
53/// - On final error: returns error
54pub async fn execute_with_retry<F, Fut, T, E>(
55    mut operation: F,
56    config: RetryConfig,
57) -> Result<T, E>
58where
59    F: FnMut() -> Fut,
60    Fut: std::future::Future<Output = Result<T, E>>,
61{
62    let mut attempt = 0;
63
64    loop {
65        // Attempt the operation
66        match operation().await {
67            Ok(result) => return Ok(result),
68            Err(error) if attempt < config.max_retries && is_retryable_internal(&error) => {
69                // Calculate delay with exponential backoff
70                let delay_ms = config.initial_delay.as_millis() as f64
71                    * config.backoff_factor.powi(attempt as i32);
72
73                let mut delay = Duration::from_millis(delay_ms as u64);
74
75                // Add jitter if enabled
76                if config.jitter {
77                    let jitter_factor = 0.5 + rand::random::<f64>(); // 0.5 to 1.5
78                    delay = Duration::from_millis((delay.as_millis() as f64 * jitter_factor) as u64);
79                }
80
81                // Cap at max_delay
82                if delay > config.max_delay {
83                    delay = config.max_delay;
84                }
85
86                // Sleep before retry
87                tokio::time::sleep(delay).await;
88                attempt += 1;
89            }
90            Err(error) => return Err(error),
91        }
92    }
93}
94
/// Check if an error should trigger a retry
///
/// Intended policy: timeouts, panics, and IO errors are retryable, while
/// validation errors are NOT retryable.
///
/// That classification is not implemented yet: the error value is ignored
/// and this function currently returns `true` for every input. Callers must
/// not rely on it to filter out validation errors.
///
/// `E` may be unsized, so trait objects (`&dyn std::error::Error`) and
/// string slices (`&str`) can be passed directly.
pub fn is_retryable<E: ?Sized>(_error: &E) -> bool {
    // Placeholder until error-type inspection is added.
    true
}
104
/// Retry predicate consulted by `execute_with_retry` after a failed attempt.
///
/// Placeholder implementation: the error is not inspected and every failure
/// is treated as retryable. Per the original note, Task 3 is expected to
/// narrow this based on the actual error types.
fn is_retryable_internal<E>(_err: &E) -> bool {
    true
}
111
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;
    use std::time::Instant;

    /// A first-attempt success must return without any retry.
    #[tokio::test]
    async fn test_retry_success_on_first_attempt() {
        let config = RetryConfig::default();
        let attempt_count = Arc::new(AtomicU32::new(0));

        let result = execute_with_retry(
            || {
                let attempt_count = attempt_count.clone();
                async move {
                    attempt_count.fetch_add(1, Ordering::SeqCst);
                    Ok::<_, String>("success")
                }
            },
            config,
        )
        .await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "success");
        assert_eq!(attempt_count.load(Ordering::SeqCst), 1);
    }

    /// One failure followed by a success takes exactly two attempts.
    #[tokio::test]
    async fn test_retry_success_after_two_attempts() {
        let config = RetryConfig::default();
        let attempt_count = Arc::new(AtomicU32::new(0));

        let result = execute_with_retry(
            || {
                let attempt_count = attempt_count.clone();
                async move {
                    let count = attempt_count.fetch_add(1, Ordering::SeqCst);
                    if count < 1 {
                        Err::<(), _>("error")
                    } else {
                        Ok(())
                    }
                }
            },
            config,
        )
        .await;

        assert!(result.is_ok());
        assert_eq!(attempt_count.load(Ordering::SeqCst), 2);
    }

    /// A persistent failure is surfaced after `max_retries` retries.
    #[tokio::test]
    async fn test_retry_failure_after_max_retries() {
        let config = RetryConfig {
            max_retries: 2,
            ..Default::default()
        };
        let attempt_count = Arc::new(AtomicU32::new(0));

        let result = execute_with_retry(
            || {
                let attempt_count = attempt_count.clone();
                async move {
                    attempt_count.fetch_add(1, Ordering::SeqCst);
                    Err::<(), _>("persistent error")
                }
            },
            config,
        )
        .await;

        assert!(result.is_err());
        // Should have initial attempt + 2 retries = 3 total
        assert_eq!(attempt_count.load(Ordering::SeqCst), 3);
    }

    /// Without jitter, two retries must wait at least 10 ms + 20 ms.
    #[tokio::test]
    async fn test_exponential_backoff() {
        let config = RetryConfig {
            max_retries: 3,
            initial_delay: Duration::from_millis(10), // Short delay for testing
            max_delay: Duration::from_millis(100),
            backoff_factor: 2.0,
            jitter: false,
        };

        let attempt_count = Arc::new(AtomicU32::new(0));
        let started = Instant::now();

        let result = execute_with_retry(
            || {
                let attempt_count = attempt_count.clone();
                async move {
                    let count = attempt_count.fetch_add(1, Ordering::SeqCst);
                    if count < 2 {
                        Err::<(), _>("retry me")
                    } else {
                        Ok(())
                    }
                }
            },
            config,
        )
        .await;

        let elapsed = started.elapsed();

        assert!(result.is_ok());
        // Should have initial attempt + 2 retries = 3 total
        assert_eq!(attempt_count.load(Ordering::SeqCst), 3);
        // Only a lower bound is asserted: scheduling can add time but the
        // two sleeps (10 ms, then 20 ms) cannot take less than their sum.
        assert!(
            elapsed >= Duration::from_millis(30),
            "expected at least 30ms of backoff, got {:?}",
            elapsed
        );
    }

    /// The default configuration exposes the documented values.
    #[test]
    fn test_retry_config_default() {
        let config = RetryConfig::default();
        assert_eq!(config.max_retries, 3);
        assert_eq!(config.initial_delay, Duration::from_millis(100));
        assert_eq!(config.max_delay, Duration::from_secs(30));
        assert_eq!(config.backoff_factor, 2.0);
        assert!(config.jitter);
    }

    /// `max_delay` must bound every sleep even with an aggressive factor.
    #[tokio::test]
    async fn test_max_delay_capping() {
        let config = RetryConfig {
            max_retries: 5,
            initial_delay: Duration::from_millis(10),
            max_delay: Duration::from_millis(50), // Low cap
            backoff_factor: 10.0,                 // High growth factor
            jitter: false,
        };

        let attempt_count = Arc::new(AtomicU32::new(0));
        let started = Instant::now();

        let _ = execute_with_retry(
            || {
                let attempt_count = attempt_count.clone();
                async move {
                    attempt_count.fetch_add(1, Ordering::SeqCst);
                    Err::<(), _>("error")
                }
            },
            config,
        )
        .await;

        let elapsed = started.elapsed();

        // Should have attempted max_retries + 1 times
        assert_eq!(attempt_count.load(Ordering::SeqCst), 6);
        // Capped delays are 10 + 50 * 4 = 210 ms in total.
        assert!(
            elapsed >= Duration::from_millis(210),
            "expected at least 210ms of capped backoff, got {:?}",
            elapsed
        );
        // Uncapped delays would sum to roughly 111 s (10 ms * (1 + 10 + ... + 10^4)),
        // so a generous 5 s ceiling still proves the cap was applied.
        assert!(
            elapsed < Duration::from_secs(5),
            "max_delay cap was not applied, elapsed {:?}",
            elapsed
        );
    }

    /// Jitter scales each delay by at least 0.5, so five 100 ms retries
    /// cannot finish in under 250 ms.
    #[tokio::test]
    async fn test_jitter_randomization() {
        let config = RetryConfig {
            max_retries: 5,
            initial_delay: Duration::from_millis(100),
            backoff_factor: 1.0,
            jitter: true,
            ..Default::default()
        };

        let attempt_count = Arc::new(AtomicU32::new(0));
        let started = Instant::now();

        let _ = execute_with_retry(
            || {
                let attempt_count = attempt_count.clone();
                async move {
                    attempt_count.fetch_add(1, Ordering::SeqCst);
                    Err::<(), _>("error")
                }
            },
            config,
        )
        .await;

        let elapsed = started.elapsed();

        // With jitter, we should have completed all retries
        assert_eq!(attempt_count.load(Ordering::SeqCst), 6); // initial + 5 retries
        assert!(
            elapsed >= Duration::from_millis(250),
            "jittered delays fell below the 0.5x lower bound, elapsed {:?}",
            elapsed
        );
    }
}
291}