
agent_runtime/retry.rs

use crate::error::RuntimeError;
use rand::Rng;
use std::time::Duration;

/// Policy for retrying failed operations with exponential backoff
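///
/// # Example
///
/// A minimal construction sketch using struct-update syntax over the defaults
/// defined below (crate path as in the `execute` example further down):
///
/// ```
/// use agent_runtime::retry::RetryPolicy;
/// use std::time::Duration;
///
/// let policy = RetryPolicy {
///     max_attempts: 4,
///     initial_delay: Duration::from_millis(250),
///     ..Default::default()
/// };
/// assert_eq!(policy.max_attempts, 4);
/// ```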
#[derive(Debug, Clone)]
pub struct RetryPolicy {
    /// Maximum number of retry attempts (0 = no retries)
    pub max_attempts: u32,

    /// Initial delay before first retry
    pub initial_delay: Duration,

    /// Maximum delay between retries (before jitter is added)
    pub max_delay: Duration,

    /// Multiplier for exponential backoff (typically 2.0)
    pub backoff_multiplier: f64,

    /// Random jitter added to prevent thundering herd (0.0 - 1.0).
    /// 0.0 = no jitter; 1.0 = up to 100% additional delay
    /// (delay + delay * random(0..1))
    pub jitter_factor: f64,

    /// Maximum total duration for all retries
    pub max_total_duration: Option<Duration>,
}

impl Default for RetryPolicy {
    fn default() -> Self {
        Self {
            max_attempts: 3,
            initial_delay: Duration::from_millis(100),
            max_delay: Duration::from_secs(30),
            backoff_multiplier: 2.0,
            jitter_factor: 0.1,
            max_total_duration: Some(Duration::from_secs(60)),
        }
    }
}

impl RetryPolicy {
    /// Create a new retry policy with custom settings
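    ///
    /// A short sketch; fields not passed here fall back to `Default`:
    ///
    /// ```
    /// use agent_runtime::retry::RetryPolicy;
    /// use std::time::Duration;
    ///
    /// let policy = RetryPolicy::new(5, Duration::from_millis(200));
    /// assert_eq!(policy.max_attempts, 5);
    /// assert_eq!(policy.initial_delay, Duration::from_millis(200));
    /// ```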
    pub fn new(max_attempts: u32, initial_delay: Duration) -> Self {
        Self {
            max_attempts,
            initial_delay,
            ..Default::default()
        }
    }

    /// Disable retries
    pub fn no_retry() -> Self {
        Self {
            max_attempts: 0,
            ..Default::default()
        }
    }

    /// Aggressive retry for critical operations
    pub fn aggressive() -> Self {
        Self {
            max_attempts: 5,
            initial_delay: Duration::from_millis(50),
            max_delay: Duration::from_secs(10),
            backoff_multiplier: 1.5,
            jitter_factor: 0.2,
            max_total_duration: Some(Duration::from_secs(30)),
        }
    }

    /// Conservative retry for expensive operations
    pub fn conservative() -> Self {
        Self {
            max_attempts: 2,
            initial_delay: Duration::from_secs(1),
            max_delay: Duration::from_secs(60),
            backoff_multiplier: 3.0,
            jitter_factor: 0.1,
            max_total_duration: Some(Duration::from_secs(120)),
        }
    }

    /// Calculate delay for a given attempt number (0-indexed)
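    ///
    /// # Example
    ///
    /// A quick sketch of the backoff progression with jitter disabled; the
    /// numbers assume the `Default` settings above (100 ms initial delay,
    /// 2.0 multiplier):
    ///
    /// ```
    /// use agent_runtime::retry::RetryPolicy;
    /// use std::time::Duration;
    ///
    /// let policy = RetryPolicy { jitter_factor: 0.0, ..Default::default() };
    /// assert_eq!(policy.delay_for_attempt(0), Duration::from_millis(100));
    /// assert_eq!(policy.delay_for_attempt(1), Duration::from_millis(200));
    /// assert_eq!(policy.delay_for_attempt(2), Duration::from_millis(400));
    /// ```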
    pub fn delay_for_attempt(&self, attempt: u32) -> Duration {
        let base_delay =
            self.initial_delay.as_millis() as f64 * self.backoff_multiplier.powi(attempt as i32);

        let clamped = base_delay.min(self.max_delay.as_millis() as f64);

        // Add additive jitter: up to `jitter_factor * clamped` extra on top of the
        // clamped base, so the final delay can exceed `max_delay` by that margin.
        let jittered = if self.jitter_factor > 0.0 {
            let mut rng = rand::thread_rng();
            let jitter = rng.gen::<f64>() * self.jitter_factor * clamped;
            clamped + jitter
        } else {
            clamped
        };

        Duration::from_millis(jittered as u64)
    }

    /// Execute an async operation with retry logic
    ///
    /// # Example
    /// ```no_run
    /// use agent_runtime::retry::RetryPolicy;
    /// use agent_runtime::{RuntimeError, LlmError};
    ///
    /// # async fn example() -> Result<String, RuntimeError> {
    /// let policy = RetryPolicy::default();
    /// let result = policy.execute(
    ///     "fetch_data",
    ///     || async {
    ///         // Your operation here - returns Result<T, impl Into<RuntimeError>>
    ///         Ok::<String, LlmError>("success".to_string())
    ///     }
    /// ).await?;
    /// # Ok(result)
    /// # }
    /// ```
    pub async fn execute<F, Fut, T, E>(
        &self,
        operation_name: &str,
        mut operation: F,
    ) -> Result<T, RuntimeError>
    where
        F: FnMut() -> Fut,
        Fut: std::future::Future<Output = Result<T, E>>,
        E: Into<RuntimeError> + Clone,
    {
        let start = std::time::Instant::now();
        let mut last_error = None;

        for attempt in 0..=self.max_attempts {
            // Check if we've exceeded max total duration
            if let Some(max_duration) = self.max_total_duration {
                if start.elapsed() > max_duration {
                    break;
                }
            }

            // Execute the operation
            match operation().await {
                Ok(result) => return Ok(result),
                Err(e) => {
                    let runtime_error: RuntimeError = e.into();

                    // Check if error is retryable
                    let should_retry = match &runtime_error {
                        RuntimeError::Llm(llm_err) => llm_err.is_retryable(),
                        _ => false, // Only retry LLM errors for now
                    };

                    last_error = Some(runtime_error.clone());

                    // Don't retry if:
                    // - This was the last attempt
                    // - Error is not retryable
                    if attempt >= self.max_attempts || !should_retry {
                        break;
                    }

                    // Calculate delay and sleep
                    let delay = self.delay_for_attempt(attempt);
                    tokio::time::sleep(delay).await;
                }
            }
        }

        // All attempts exhausted
        Err(RuntimeError::RetryExhausted {
            operation: operation_name.to_string(),
            attempts: self.max_attempts + 1,
            last_error: Box::new(last_error.unwrap()),
        })
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::LlmError;

    #[test]
    fn test_delay_calculation() {
        let policy = RetryPolicy {
            max_attempts: 3,
            initial_delay: Duration::from_millis(100),
            max_delay: Duration::from_secs(10),
            backoff_multiplier: 2.0,
            jitter_factor: 0.0, // No jitter for predictable tests
            max_total_duration: None,
        };

        assert_eq!(policy.delay_for_attempt(0).as_millis(), 100);
        assert_eq!(policy.delay_for_attempt(1).as_millis(), 200);
        assert_eq!(policy.delay_for_attempt(2).as_millis(), 400);
    }

    #[test]
    fn test_max_delay_clamp() {
        let policy = RetryPolicy {
            max_attempts: 10,
            initial_delay: Duration::from_secs(1),
            max_delay: Duration::from_secs(5),
            backoff_multiplier: 2.0,
            jitter_factor: 0.0,
            max_total_duration: None,
        };

        // After enough attempts, should clamp to max_delay
        let delay = policy.delay_for_attempt(10);
        assert_eq!(delay, Duration::from_secs(5));
    }

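    #[test]
    fn test_jitter_bounds() {
        // A sketch of the jitter behavior: jitter is added on top of the clamped
        // base delay, so the result should stay in [base, base * (1 + jitter_factor)].
        let policy = RetryPolicy {
            max_attempts: 3,
            initial_delay: Duration::from_millis(100),
            max_delay: Duration::from_secs(10),
            backoff_multiplier: 2.0,
            jitter_factor: 0.5,
            max_total_duration: None,
        };

        for _ in 0..100 {
            let delay = policy.delay_for_attempt(0).as_millis();
            assert!((100..=150).contains(&delay));
        }
    }
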
    #[tokio::test]
    async fn test_retry_success_on_second_attempt() {
        let policy = RetryPolicy::default();
        let attempts = std::sync::Arc::new(std::sync::atomic::AtomicU32::new(0));
        let attempts_clone = attempts.clone();

        let result: Result<&str, RuntimeError> = policy
            .execute("test_op", move || {
                let attempts = attempts_clone.clone();
                async move {
                    let count = attempts.fetch_add(1, std::sync::atomic::Ordering::SeqCst) + 1;
                    if count == 1 {
                        Err(LlmError::network("Network error"))
                    } else {
                        Ok("success")
                    }
                }
            })
            .await;

        assert!(result.is_ok());
        assert_eq!(attempts.load(std::sync::atomic::Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn test_retry_exhausted() {
        let policy = RetryPolicy::new(2, Duration::from_millis(10));
        let attempts = std::sync::Arc::new(std::sync::atomic::AtomicU32::new(0));
        let attempts_clone = attempts.clone();

        let result: Result<&str, RuntimeError> = policy
            .execute("test_op", move || {
                let attempts = attempts_clone.clone();
                async move {
                    attempts.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
                    Err(LlmError::network("Network error"))
                }
            })
            .await;

        assert!(result.is_err());
        assert_eq!(attempts.load(std::sync::atomic::Ordering::SeqCst), 3); // Initial + 2 retries

        match result.unwrap_err() {
            RuntimeError::RetryExhausted {
                attempts: retry_attempts,
                ..
            } => {
                assert_eq!(retry_attempts, 3);
            }
            _ => panic!("Expected RetryExhausted error"),
        }
    }

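    #[tokio::test]
    async fn test_no_retry_policy_single_attempt() {
        // A sketch using the same pattern as above: with retries disabled the
        // operation should run exactly once, even for a retryable error.
        let policy = RetryPolicy::no_retry();
        let attempts = std::sync::Arc::new(std::sync::atomic::AtomicU32::new(0));
        let attempts_clone = attempts.clone();

        let result: Result<&str, RuntimeError> = policy
            .execute("test_op", move || {
                let attempts = attempts_clone.clone();
                async move {
                    attempts.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
                    Err(LlmError::network("Network error"))
                }
            })
            .await;

        assert!(result.is_err());
        assert_eq!(attempts.load(std::sync::atomic::Ordering::SeqCst), 1);
    }
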
    #[tokio::test]
    async fn test_no_retry_on_non_retryable_error() {
        let policy = RetryPolicy::default();
        let attempts = std::sync::Arc::new(std::sync::atomic::AtomicU32::new(0));
        let attempts_clone = attempts.clone();

        let result: Result<&str, RuntimeError> = policy
            .execute("test_op", move || {
                let attempts = attempts_clone.clone();
                async move {
                    attempts.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
                    Err(LlmError {
                        code: crate::error::LlmErrorCode::InvalidRequest,
                        message: "Bad request".to_string(),
                        provider: None,
                        model: None,
                        retryable: false,
                    })
                }
            })
            .await;

        assert!(result.is_err());
        assert_eq!(attempts.load(std::sync::atomic::Ordering::SeqCst), 1); // Should not retry non-retryable errors
    }
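
    #[tokio::test]
    async fn test_max_total_duration_stops_retries() {
        // A sketch of the total-duration budget: after the first failed attempt
        // the 50 ms backoff sleep pushes elapsed time past the 10 ms budget, so
        // the loop should stop even though more attempts are allowed.
        let policy = RetryPolicy {
            max_attempts: 10,
            initial_delay: Duration::from_millis(50),
            max_delay: Duration::from_millis(50),
            backoff_multiplier: 1.0,
            jitter_factor: 0.0,
            max_total_duration: Some(Duration::from_millis(10)),
        };
        let attempts = std::sync::Arc::new(std::sync::atomic::AtomicU32::new(0));
        let attempts_clone = attempts.clone();

        let result: Result<&str, RuntimeError> = policy
            .execute("test_op", move || {
                let attempts = attempts_clone.clone();
                async move {
                    attempts.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
                    Err(LlmError::network("Network error"))
                }
            })
            .await;

        assert!(result.is_err());
        assert_eq!(attempts.load(std::sync::atomic::Ordering::SeqCst), 1);
    }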
}