atomcode_core/provider/
retry.rs1use std::time::Duration;
9
/// Retry behaviour for outbound HTTP requests sent through `send_with_retry`
/// and `send_with_retry_blocking`.
///
/// `max_attempts` counts the initial request as well, so a value of `3` means
/// one request plus up to two retries. The wait before each retry grows
/// exponentially from `base_delay` and is capped at `max_delay`. Jitter is then
/// applied, which can push the wait somewhat past `max_delay`. A server
/// `Retry-After` header (seconds form) on a retryable response replaces the
/// computed wait.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RetryPolicy {
    /// Total number of attempts, including the first request.
    pub max_attempts: u32,
    /// Backoff before the first retry; doubles for each further retry.
    pub base_delay: Duration,
    /// Cap applied to the exponential backoff before jitter is added.
    pub max_delay: Duration,
}

impl RetryPolicy {
    /// Production defaults: 3 attempts, 500 ms initial backoff, 8 s cap.
    #[must_use]
    pub fn default_policy() -> Self {
        Self {
            max_attempts: 3,
            base_delay: Duration::from_millis(500),
            max_delay: Duration::from_secs(8),
        }
    }

    /// Test-only policy: same attempt count as the default but with
    /// millisecond-scale delays so retry tests finish quickly.
    #[cfg(test)]
    #[must_use]
    pub fn testing() -> Self {
        Self {
            max_attempts: 3,
            base_delay: Duration::from_millis(1),
            max_delay: Duration::from_millis(10),
        }
    }
}

impl Default for RetryPolicy {
    /// Equivalent to [`RetryPolicy::default_policy`].
    fn default() -> Self {
        Self::default_policy()
    }
}
44
45fn is_retryable_status(status: reqwest::StatusCode) -> bool {
47 matches!(status.as_u16(), 408 | 425 | 429 | 500 | 502 | 503 | 504)
48}
49
50fn is_retryable_error(err: &reqwest::Error) -> bool {
52 err.is_timeout() || err.is_connect()
53}
54
55fn parse_retry_after(headers: &reqwest::header::HeaderMap) -> Option<Duration> {
58 let value = headers.get(reqwest::header::RETRY_AFTER)?.to_str().ok()?;
59 let secs: u64 = value.trim().parse().ok()?;
60 Some(Duration::from_secs(secs))
61}
62
63fn compute_backoff(attempt: u32, policy: &RetryPolicy) -> Duration {
65 let exp = policy
66 .base_delay
67 .saturating_mul(1u32 << attempt.saturating_sub(1).min(16));
68 let capped = exp.min(policy.max_delay);
69
70 let nanos = std::time::SystemTime::now()
72 .duration_since(std::time::UNIX_EPOCH)
73 .map(|d| d.subsec_nanos())
74 .unwrap_or(0);
75 let range = (capped.as_millis() / 2) as u64; let jitter_ms = if range > 0 { (nanos as u64) % range } else { 0 };
77 let jitter = Duration::from_millis(jitter_ms);
78 let floor = capped.saturating_sub(Duration::from_millis(range / 2));
80 floor + jitter
81}
82
83pub async fn send_with_retry(
93 builder: reqwest::RequestBuilder,
94 policy: &RetryPolicy,
95) -> Result<reqwest::Response, reqwest::Error> {
96 let (client, built) = builder.build_split();
101 let req = built?;
102 let mut last_err: Option<reqwest::Error> = None;
103 for attempt in 1..=policy.max_attempts {
104 let this_req = match req.try_clone() {
111 Some(c) => c,
112 None => {
113 return match last_err {
114 Some(e) => Err(e),
115 None => client.execute(req).await,
116 };
117 }
118 };
119 match client.execute(this_req).await {
120 Ok(resp) => {
121 if is_retryable_status(resp.status()) && attempt < policy.max_attempts {
122 let wait = parse_retry_after(resp.headers())
123 .unwrap_or_else(|| compute_backoff(attempt, policy));
124 tokio::time::sleep(wait).await;
125 continue;
126 }
127 return Ok(resp);
128 }
129 Err(e) => {
130 if is_retryable_error(&e) && attempt < policy.max_attempts {
131 let wait = compute_backoff(attempt, policy);
132 last_err = Some(e);
133 tokio::time::sleep(wait).await;
134 continue;
135 }
136 return Err(e);
137 }
138 }
139 }
140 Err(last_err.expect("send_with_retry: loop terminated without error or response"))
143}
144
145pub fn send_with_retry_blocking(
149 builder: reqwest::blocking::RequestBuilder,
150 policy: &RetryPolicy,
151) -> Result<reqwest::blocking::Response, reqwest::Error> {
152 let (client, built) = builder.build_split();
153 let req = built?;
154 let mut last_err: Option<reqwest::Error> = None;
155 for attempt in 1..=policy.max_attempts {
156 let this_req = match req.try_clone() {
157 Some(c) => c,
158 None => {
159 return match last_err {
160 Some(e) => Err(e),
161 None => client.execute(req),
162 };
163 }
164 };
165 match client.execute(this_req) {
166 Ok(resp) => {
167 if is_retryable_status(resp.status()) && attempt < policy.max_attempts {
168 let wait = parse_retry_after(resp.headers())
169 .unwrap_or_else(|| compute_backoff(attempt, policy));
170 std::thread::sleep(wait);
171 continue;
172 }
173 return Ok(resp);
174 }
175 Err(e) => {
176 if is_retryable_error(&e) && attempt < policy.max_attempts {
177 let wait = compute_backoff(attempt, policy);
178 last_err = Some(e);
179 std::thread::sleep(wait);
180 continue;
181 }
182 return Err(e);
183 }
184 }
185 }
186 Err(last_err.expect("send_with_retry_blocking: loop terminated without error or response"))
187}
188
// Unit tests for the pure helpers (Retry-After parsing, status classification,
// backoff), followed by end-to-end tests of `send_with_retry` against a local
// `wiremock` server.
#[cfg(test)]
mod tests {
    use super::*;
    use reqwest::header::{HeaderMap, HeaderValue, RETRY_AFTER};

    #[test]
    fn parse_retry_after_seconds() {
        let mut h = HeaderMap::new();
        h.insert(RETRY_AFTER, HeaderValue::from_static("3"));
        assert_eq!(parse_retry_after(&h), Some(Duration::from_secs(3)));
    }

    #[test]
    fn parse_retry_after_missing_returns_none() {
        let h = HeaderMap::new();
        assert_eq!(parse_retry_after(&h), None);
    }

    // Only the delay-seconds form is supported; the HTTP-date form yields
    // `None`, so callers fall back to the computed backoff.
    #[test]
    fn parse_retry_after_http_date_returns_none() {
        let mut h = HeaderMap::new();
        h.insert(
            RETRY_AFTER,
            HeaderValue::from_static("Wed, 21 Oct 2015 07:28:00 GMT"),
        );
        assert_eq!(parse_retry_after(&h), None);
    }

    #[test]
    fn retryable_status_includes_429_and_5xx() {
        assert!(is_retryable_status(reqwest::StatusCode::TOO_MANY_REQUESTS));
        assert!(is_retryable_status(
            reqwest::StatusCode::INTERNAL_SERVER_ERROR
        ));
        assert!(is_retryable_status(reqwest::StatusCode::BAD_GATEWAY));
        assert!(is_retryable_status(
            reqwest::StatusCode::SERVICE_UNAVAILABLE
        ));
        assert!(is_retryable_status(reqwest::StatusCode::GATEWAY_TIMEOUT));
        assert!(is_retryable_status(reqwest::StatusCode::REQUEST_TIMEOUT));
    }

    // Auth and validation failures are treated as permanent and must not be
    // retried.
    #[test]
    fn retryable_status_excludes_auth_and_validation() {
        assert!(!is_retryable_status(reqwest::StatusCode::UNAUTHORIZED));
        assert!(!is_retryable_status(reqwest::StatusCode::FORBIDDEN));
        assert!(!is_retryable_status(reqwest::StatusCode::BAD_REQUEST));
        assert!(!is_retryable_status(reqwest::StatusCode::NOT_FOUND));
    }

    #[test]
    fn backoff_respects_max_delay() {
        let policy = RetryPolicy {
            max_attempts: 10,
            base_delay: Duration::from_millis(500),
            max_delay: Duration::from_secs(1),
        };
        let d = compute_backoff(10, &policy);
        // Jitter is applied after the `max_delay` cap and can push the result
        // somewhat above it, so the bound leaves headroom over 1 s.
        assert!(d <= Duration::from_millis(1500), "got {:?}", d);
    }

    use wiremock::matchers::{method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    // Client with short (2 s) connect/request timeouts so a misbehaving test
    // fails quickly instead of hanging.
    fn client() -> reqwest::Client {
        reqwest::Client::builder()
            .connect_timeout(Duration::from_secs(2))
            .timeout(Duration::from_secs(2))
            .build()
            .unwrap()
    }

    // `.expect(1)` makes wiremock check that exactly one request was sent,
    // i.e. a 200 is not retried.
    #[tokio::test]
    async fn succeeds_on_first_try() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/chat"))
            .respond_with(ResponseTemplate::new(200).set_body_string("ok"))
            .expect(1)
            .mount(&server)
            .await;

        let builder = client().post(format!("{}/chat", server.uri())).body("req");
        let resp = send_with_retry(builder, &RetryPolicy::testing())
            .await
            .unwrap();
        assert_eq!(resp.status(), 200);
    }

    // The 500 mock answers only once (`up_to_n_times(1)`); the retry is then
    // served by the 200 mock.
    #[tokio::test]
    async fn retries_on_500_then_succeeds() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/chat"))
            .respond_with(ResponseTemplate::new(500))
            .up_to_n_times(1)
            .mount(&server)
            .await;
        Mock::given(method("POST"))
            .and(path("/chat"))
            .respond_with(ResponseTemplate::new(200).set_body_string("ok"))
            .mount(&server)
            .await;

        let builder = client().post(format!("{}/chat", server.uri())).body("req");
        let resp = send_with_retry(builder, &RetryPolicy::testing())
            .await
            .unwrap();
        assert_eq!(resp.status(), 200);
    }

    // With `max_attempts: 3`, exactly three requests are made. The final 500
    // is returned as `Ok(response)`, not converted into an error.
    #[tokio::test]
    async fn exhausts_on_persistent_500() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/chat"))
            .respond_with(ResponseTemplate::new(500))
            .expect(3)
            .mount(&server)
            .await;

        let builder = client().post(format!("{}/chat", server.uri())).body("req");
        let resp = send_with_retry(builder, &RetryPolicy::testing())
            .await
            .unwrap();
        assert_eq!(resp.status(), 500);
    }

    // A non-retryable status is returned after a single request.
    #[tokio::test]
    async fn does_not_retry_on_401() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/chat"))
            .respond_with(ResponseTemplate::new(401))
            .expect(1)
            .mount(&server)
            .await;

        let builder = client().post(format!("{}/chat", server.uri())).body("req");
        let resp = send_with_retry(builder, &RetryPolicy::testing())
            .await
            .unwrap();
        assert_eq!(resp.status(), 401);
    }

    // `RetryPolicy::testing()` caps backoff at 10 ms, so an elapsed time of
    // ~1 s can only come from honoring the `Retry-After: 1` header.
    #[tokio::test]
    async fn honors_retry_after_on_429() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/chat"))
            .respond_with(ResponseTemplate::new(429).insert_header("Retry-After", "1"))
            .up_to_n_times(1)
            .mount(&server)
            .await;
        Mock::given(method("POST"))
            .and(path("/chat"))
            .respond_with(ResponseTemplate::new(200).set_body_string("ok"))
            .mount(&server)
            .await;

        let start = std::time::Instant::now();
        let builder = client().post(format!("{}/chat", server.uri())).body("req");
        let resp = send_with_retry(builder, &RetryPolicy::testing())
            .await
            .unwrap();
        let elapsed = start.elapsed();
        assert_eq!(resp.status(), 200);
        assert!(
            elapsed >= Duration::from_millis(900),
            "expected ~1s wait from Retry-After, got {:?}",
            elapsed
        );
    }

    // Binds an ephemeral port and immediately releases it, so connecting to it
    // is expected to be refused. After the retries are exhausted, the transport
    // error is returned.
    #[tokio::test]
    async fn retries_on_connect_error() {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        drop(listener);

        let builder = client().post(format!("http://{}/chat", addr)).body("req");
        let err = send_with_retry(builder, &RetryPolicy::testing())
            .await
            .unwrap_err();
        // NOTE(review): accepts either flag, presumably because the refused
        // connection may be reported as connect or generic request error.
        assert!(err.is_connect() || err.is_request(), "got {:?}", err);
    }

    // Regression test: a header value containing a newline makes request
    // building fail. `send_with_retry` must return that builder error instead
    // of panicking. `catch_unwind` turns a panic into a descriptive test failure.
    #[tokio::test]
    async fn send_with_retry_returns_builder_error_instead_of_panicking() {
        let result = std::panic::AssertUnwindSafe(async {
            let builder = client()
                .post("http://127.0.0.1:1/")
                .header("Authorization", "Bearer token-with\n-newline");
            send_with_retry(builder, &RetryPolicy::testing()).await
        });
        let outcome = futures::FutureExt::catch_unwind(result).await;
        let inner = match outcome {
            Ok(r) => r,
            Err(_) => panic!(
                "send_with_retry panicked on builder-error input \
                 (regression of the user's reported crash)"
            ),
        };
        let err = inner.expect_err(
            "builder with illegal header value must produce Err, \
             not Ok",
        );
        assert!(
            err.is_builder(),
            "expected is_builder() error, got {:?}",
            err
        );
    }
}