// forge_reasoning/verification/retry.rs

use std::time::Duration;

use serde::{Deserialize, Serialize};

9#[derive(Clone, Debug, Serialize, Deserialize)]
11pub struct RetryConfig {
12 pub max_retries: u32,
14 pub initial_delay: Duration,
16 pub max_delay: Duration,
18 pub backoff_factor: f64,
20 pub jitter: bool,
22}
23
24impl Default for RetryConfig {
25 fn default() -> Self {
26 Self {
27 max_retries: 3,
28 initial_delay: Duration::from_millis(100),
29 max_delay: Duration::from_secs(30),
30 backoff_factor: 2.0,
31 jitter: true,
32 }
33 }
34}
35
36pub async fn execute_with_retry<F, Fut, T, E>(
55 mut operation: F,
56 config: RetryConfig,
57) -> Result<T, E>
58where
59 F: FnMut() -> Fut,
60 Fut: std::future::Future<Output = Result<T, E>>,
61{
62 let mut attempt = 0;
63
64 loop {
65 match operation().await {
67 Ok(result) => return Ok(result),
68 Err(error) if attempt < config.max_retries && is_retryable_internal(&error) => {
69 let delay_ms = config.initial_delay.as_millis() as f64
71 * config.backoff_factor.powi(attempt as i32);
72
73 let mut delay = Duration::from_millis(delay_ms as u64);
74
75 if config.jitter {
77 let jitter_factor = 0.5 + rand::random::<f64>(); delay = Duration::from_millis((delay.as_millis() as f64 * jitter_factor) as u64);
79 }
80
81 if delay > config.max_delay {
83 delay = config.max_delay;
84 }
85
86 tokio::time::sleep(delay).await;
88 attempt += 1;
89 }
90 Err(error) => return Err(error),
91 }
92 }
93}
94
/// Reports whether `error` should be retried.
///
/// This is currently a permissive placeholder: every error, of any type, is reported
/// as retryable and the value itself is never inspected.
pub fn is_retryable<E>(error: &E) -> bool {
    let _ = error;
    true
}

/// Retry policy consulted by `execute_with_retry`; no error is treated as fatal.
fn is_retryable_internal<E>(_error: &E) -> bool {
    const RETRY_EVERY_ERROR: bool = true;
    RETRY_EVERY_ERROR
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;
    use std::time::Instant;

    #[tokio::test]
    async fn test_retry_success_on_first_attempt() {
        let config = RetryConfig::default();
        let attempt_count = Arc::new(AtomicU32::new(0));

        let result = execute_with_retry(
            || {
                let attempt_count = attempt_count.clone();
                async move {
                    attempt_count.fetch_add(1, Ordering::SeqCst);
                    Ok::<_, String>("success")
                }
            },
            config,
        )
        .await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "success");
        assert_eq!(attempt_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_retry_success_after_two_attempts() {
        let config = RetryConfig::default();
        let attempt_count = Arc::new(AtomicU32::new(0));

        let result = execute_with_retry(
            || {
                let attempt_count = attempt_count.clone();
                async move {
                    let count = attempt_count.fetch_add(1, Ordering::SeqCst);
                    if count < 1 {
                        Err::<(), _>("error")
                    } else {
                        Ok(())
                    }
                }
            },
            config,
        )
        .await;

        assert!(result.is_ok());
        assert_eq!(attempt_count.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn test_retry_failure_after_max_retries() {
        let config = RetryConfig {
            max_retries: 2,
            ..Default::default()
        };
        let attempt_count = Arc::new(AtomicU32::new(0));

        let result = execute_with_retry(
            || {
                let attempt_count = attempt_count.clone();
                async move {
                    attempt_count.fetch_add(1, Ordering::SeqCst);
                    Err::<(), _>("persistent error")
                }
            },
            config,
        )
        .await;

        // The final error is propagated unchanged.
        assert_eq!(result, Err("persistent error"));
        // One initial attempt plus `max_retries` retries.
        assert_eq!(attempt_count.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_exponential_backoff() {
        let config = RetryConfig {
            max_retries: 3,
            initial_delay: Duration::from_millis(10),
            max_delay: Duration::from_millis(100),
            backoff_factor: 2.0,
            jitter: false,
        };

        let attempt_count = Arc::new(AtomicU32::new(0));

        let started = Instant::now();
        let result = execute_with_retry(
            || {
                let attempt_count = attempt_count.clone();
                async move {
                    let count = attempt_count.fetch_add(1, Ordering::SeqCst);
                    if count < 2 {
                        Err::<(), _>("retry me")
                    } else {
                        Ok(())
                    }
                }
            },
            config,
        )
        .await;
        let elapsed = started.elapsed();

        assert!(result.is_ok());
        assert_eq!(attempt_count.load(Ordering::SeqCst), 3);
        // Two retries without jitter sleep 10ms and then 20ms; tokio sleeps never end early.
        assert!(
            elapsed >= Duration::from_millis(30),
            "expected at least 30ms of backoff, got {:?}",
            elapsed
        );
    }

    #[test]
    fn test_retry_config_default() {
        let config = RetryConfig::default();
        assert_eq!(config.max_retries, 3);
        assert_eq!(config.initial_delay, Duration::from_millis(100));
        assert_eq!(config.max_delay, Duration::from_secs(30));
        assert_eq!(config.backoff_factor, 2.0);
        assert!(config.jitter);
    }

    #[tokio::test]
    async fn test_max_delay_capping() {
        let config = RetryConfig {
            max_retries: 5,
            initial_delay: Duration::from_millis(10),
            max_delay: Duration::from_millis(50),
            backoff_factor: 10.0,
            jitter: false,
        };

        let attempt_count = Arc::new(AtomicU32::new(0));

        let started = Instant::now();
        let result = execute_with_retry(
            || {
                let attempt_count = attempt_count.clone();
                async move {
                    attempt_count.fetch_add(1, Ordering::SeqCst);
                    Err::<(), _>("error")
                }
            },
            config,
        )
        .await;
        let elapsed = started.elapsed();

        assert!(result.is_err());
        assert_eq!(attempt_count.load(Ordering::SeqCst), 6);
        // Delays are 10, 50, 50, 50, 50 ms once capped (210ms total). Uncapped they
        // would sum to over 111s, so a generous upper bound proves the cap applied.
        assert!(
            elapsed >= Duration::from_millis(210),
            "expected at least 210ms of backoff, got {:?}",
            elapsed
        );
        assert!(
            elapsed < Duration::from_secs(5),
            "max_delay cap was not applied, took {:?}",
            elapsed
        );
    }

    #[tokio::test]
    async fn test_jitter_randomization() {
        let config = RetryConfig {
            max_retries: 5,
            initial_delay: Duration::from_millis(100),
            backoff_factor: 1.0,
            jitter: true,
            ..Default::default()
        };

        let attempt_count = Arc::new(AtomicU32::new(0));

        let started = Instant::now();
        let _ = execute_with_retry(
            || {
                let attempt_count = attempt_count.clone();
                async move {
                    attempt_count.fetch_add(1, Ordering::SeqCst);
                    Err::<(), _>("error")
                }
            },
            config,
        )
        .await;
        let elapsed = started.elapsed();

        assert_eq!(attempt_count.load(Ordering::SeqCst), 6);
        // Jitter scales each 100ms delay by at least 0.5, so five retries take >= 250ms.
        assert!(
            elapsed >= Duration::from_millis(250),
            "jittered delays were shorter than the 0.5x floor: {:?}",
            elapsed
        );
    }
}