1use std::sync::atomic::{AtomicU32, Ordering};
6use std::sync::Mutex;
7use std::time::{Duration, Instant};
8
9use tracing::{debug, warn};
10
11use crate::error::ApiError;
12
/// Maximum retry attempts for HTTP 429 (rate-limit) responses.
pub const MAX_RATE_LIMIT_RETRIES: u32 = 5;
/// Maximum retry attempts for HTTP 5xx (server-error) responses.
pub const MAX_5XX_RETRIES: u32 = 3;
/// Base delay used for exponential backoff after a 429.
pub const RATE_LIMIT_BASE_DELAY: Duration = Duration::from_secs(1);
/// Fixed delay between retries after a 5xx response.
pub const SERVER_ERROR_RETRY_DELAY: Duration = Duration::from_secs(2);
/// Consecutive failures required before the circuit breaker opens.
pub const CIRCUIT_BREAKER_THRESHOLD: u32 = 5;
/// How long the circuit breaker stays open after the last failure.
pub const CIRCUIT_BREAKER_RESET_DURATION: Duration = Duration::from_secs(60);
23
/// A failure-counting circuit breaker.
///
/// The breaker "opens" (i.e. [`CircuitBreaker::is_open`] returns `true`)
/// once `threshold` failures have been recorded and the most recent
/// failure is newer than `reset_duration`. It closes again on an explicit
/// [`CircuitBreaker::record_success`] or once `reset_duration` has
/// elapsed since the last failure.
pub struct CircuitBreaker {
    // Count of failures recorded since the last success.
    failures: AtomicU32,
    // Timestamp of the most recent failure; `None` when cleared.
    last_failure: Mutex<Option<Instant>>,
    // Failure count at which the breaker opens.
    threshold: u32,
    // Time window after the last failure during which the breaker stays open.
    reset_duration: Duration,
}
36
37impl CircuitBreaker {
38 pub fn new() -> Self {
39 Self {
40 failures: AtomicU32::new(0),
41 last_failure: Mutex::new(None),
42 threshold: CIRCUIT_BREAKER_THRESHOLD,
43 reset_duration: CIRCUIT_BREAKER_RESET_DURATION,
44 }
45 }
46
47 pub fn is_open(&self) -> bool {
50 let failures = self.failures.load(Ordering::SeqCst);
51 if failures < self.threshold {
52 return false;
53 }
54 let guard = self.last_failure.lock().unwrap();
56 match *guard {
57 None => false,
58 Some(last) => {
59 if last.elapsed() > self.reset_duration {
60 false
63 } else {
64 true
65 }
66 }
67 }
68 }
69
70 pub fn record_success(&self) {
72 self.failures.store(0, Ordering::SeqCst);
73 let mut guard = self.last_failure.lock().unwrap();
74 *guard = None;
75 }
76
77 pub fn record_failure(&self) {
79 let count = self.failures.fetch_add(1, Ordering::SeqCst) + 1;
80 let mut guard = self.last_failure.lock().unwrap();
81 *guard = Some(Instant::now());
82 if count >= self.threshold {
83 warn!("circuit breaker opened after {} failures", count);
84 }
85 }
86}
87
88impl Default for CircuitBreaker {
89 fn default() -> Self {
90 Self::new()
91 }
92}
93
/// Retry policy for [`execute_with_retry`]: independent caps for 429 and
/// 5xx retries, the base delay for exponential backoff, and an optional
/// circuit breaker.
pub struct RetryConfig {
    /// Maximum retries for HTTP 429 (rate-limit) responses.
    pub max_retries_429: u32,
    /// Maximum retries for HTTP 5xx (server-error) responses.
    pub max_retries_5xx: u32,
    /// Base delay fed into [`calculate_backoff`] for 429 retries.
    pub base_delay: Duration,
    /// Optional circuit breaker; `None` disables circuit breaking entirely.
    pub circuit_breaker: Option<CircuitBreaker>,
}
105
106impl Default for RetryConfig {
107 fn default() -> Self {
108 Self {
109 max_retries_429: MAX_RATE_LIMIT_RETRIES,
110 max_retries_5xx: MAX_5XX_RETRIES,
111 base_delay: RATE_LIMIT_BASE_DELAY,
112 circuit_breaker: Some(CircuitBreaker::new()),
113 }
114 }
115}
116
117pub fn calculate_backoff(attempt: u32, base_delay: Duration, retry_after: Option<&str>) -> Duration {
126 if let Some(header) = retry_after {
127 let trimmed = header.trim();
128 if let Ok(secs) = trimmed.parse::<i64>() {
130 if secs <= 0 {
131 return Duration::ZERO;
132 }
133 return Duration::from_secs(secs as u64);
134 }
135 if let Ok(dt) = httpdate::parse_http_date(trimmed) {
137 let now = std::time::SystemTime::now();
138 match dt.duration_since(now) {
139 Ok(d) => return d,
140 Err(_) => return Duration::ZERO, }
142 }
143 }
145
146 if base_delay.is_zero() {
148 return Duration::ZERO;
149 }
150
151 let factor = 1u64.checked_shl(attempt).unwrap_or(u64::MAX);
152 let base_ms = base_delay.as_millis() as u64;
153 let exp_ms = base_ms.saturating_mul(factor);
154 if exp_ms == 0 {
155 return Duration::ZERO;
156 }
157
158 let jitter_range_ms = base_ms / 2;
160 let jitter_ms = if jitter_range_ms > 0 {
161 let nonce = std::time::SystemTime::now()
164 .duration_since(std::time::UNIX_EPOCH)
165 .map(|d| d.subsec_nanos())
166 .unwrap_or(0) as u64;
167 (nonce.wrapping_add(attempt as u64 * 6364136223846793005)) % jitter_range_ms
168 } else {
169 0
170 };
171
172 Duration::from_millis(exp_ms + jitter_ms)
173}
174
/// Executes the built request with retry handling for rate limits (429)
/// and server errors (5xx), plus optional circuit breaking.
///
/// Behavior visible in this function:
/// - If the config's circuit breaker is open, fails fast with
///   `ApiError::CircuitBreakerOpen`. The breaker is checked once, before
///   any attempt — not between retries.
/// - Responses with status < 400 record a circuit-breaker success and are
///   returned as `Ok`.
/// - 429: retries up to `config.max_retries_429` times, sleeping per
///   [`calculate_backoff`] (honoring a `Retry-After` header when present);
///   once exhausted, returns `ApiError::RateLimitExhausted`.
/// - 5xx: records a circuit-breaker failure each time, retries up to
///   `config.max_retries_5xx` times with the fixed
///   `SERVER_ERROR_RETRY_DELAY`; once exhausted, the last error *response*
///   is returned as `Ok` for the caller to inspect.
/// - All other 4xx responses are returned as `Ok` unchanged.
///
/// The request body must be cloneable (`Request::try_clone`): streaming
/// bodies cannot be re-sent and produce an immediate `GoogleApi` error
/// with status 0 — even on the first attempt, since cloning happens
/// before every send.
pub async fn execute_with_retry(
    client: &reqwest::Client,
    request_builder: reqwest::RequestBuilder,
    config: &RetryConfig,
) -> Result<reqwest::Response, ApiError> {
    // Fail fast while the breaker is open; no network traffic is generated.
    if let Some(cb) = &config.circuit_breaker {
        if cb.is_open() {
            return Err(ApiError::CircuitBreakerOpen);
        }
    }

    // Build once; each attempt clones this prototype request.
    let request = request_builder.build()?;

    // Independent counters: 429s and 5xx errors have separate budgets.
    let mut retries_429: u32 = 0;
    let mut retries_5xx: u32 = 0;

    loop {
        // Clone the prototype for this attempt. `try_clone` fails for
        // streaming bodies, which cannot be replayed.
        let attempt_req = request
            .try_clone()
            .ok_or_else(|| ApiError::GoogleApi {
                status: 0,
                message: "request body is not cloneable (streaming body)".to_string(),
            })?;

        let resp = client.execute(attempt_req).await?;
        let status = resp.status().as_u16();

        // Success (2xx/3xx): clear the breaker and hand back the response.
        if status < 400 {
            if let Some(cb) = &config.circuit_breaker {
                cb.record_success();
            }
            return Ok(resp);
        }

        if status == 429 {
            // Budget exhausted: surface a dedicated rate-limit error.
            if retries_429 >= config.max_retries_429 {
                return Err(ApiError::RateLimitExhausted {
                    retries: retries_429,
                });
            }

            // Copy the Retry-After header out so the response can be
            // dropped before sleeping.
            let retry_after_header = resp
                .headers()
                .get("Retry-After")
                .and_then(|v| v.to_str().ok())
                .map(|s| s.to_string());

            let delay = calculate_backoff(
                retries_429,
                config.base_delay,
                retry_after_header.as_deref(),
            );

            debug!(
                delay_ms = delay.as_millis(),
                attempt = retries_429 + 1,
                max = config.max_retries_429,
                "rate limited, retrying"
            );

            // Release the connection before sleeping.
            drop(resp);

            tokio::time::sleep(delay).await;
            retries_429 += 1;
            continue;
        }

        if status >= 500 {
            // Server errors count against the circuit breaker.
            if let Some(cb) = &config.circuit_breaker {
                cb.record_failure();
            }

            // Budget exhausted: return the error response itself (Ok) so
            // the caller can inspect the final status/body.
            if retries_5xx >= config.max_retries_5xx {
                return Ok(resp);
            }

            debug!(
                status,
                attempt = retries_5xx + 1,
                "server error, retrying"
            );

            // Release the connection before the fixed retry delay.
            drop(resp);
            tokio::time::sleep(SERVER_ERROR_RETRY_DELAY).await;
            retries_5xx += 1;
            continue;
        }

        // Remaining 4xx statuses are not retried; caller handles them.
        return Ok(resp);
    }
}
286
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    // Records `times` consecutive failures on the breaker.
    fn trip(cb: &CircuitBreaker, times: u32) {
        (0..times).for_each(|_| cb.record_failure());
    }

    #[test]
    fn test_circuit_breaker_starts_closed() {
        assert!(!CircuitBreaker::new().is_open(), "new circuit breaker should be closed");
    }

    #[test]
    fn test_circuit_breaker_opens_after_threshold() {
        let cb = CircuitBreaker::new();
        trip(&cb, CIRCUIT_BREAKER_THRESHOLD);
        assert!(cb.is_open(), "circuit breaker should be open after {} failures", CIRCUIT_BREAKER_THRESHOLD);
    }

    #[test]
    fn test_circuit_breaker_resets_on_success() {
        let cb = CircuitBreaker::new();
        trip(&cb, CIRCUIT_BREAKER_THRESHOLD);
        assert!(cb.is_open(), "should be open before reset");
        cb.record_success();
        assert!(!cb.is_open(), "should be closed after record_success");
    }

    #[test]
    fn test_circuit_breaker_auto_resets_after_duration() {
        // Build a breaker whose last failure is far in the past, well
        // beyond its (short) reset window.
        let stale = Instant::now() - Duration::from_secs(3600);
        let cb = CircuitBreaker {
            failures: AtomicU32::new(CIRCUIT_BREAKER_THRESHOLD),
            last_failure: Mutex::new(Some(stale)),
            threshold: CIRCUIT_BREAKER_THRESHOLD,
            reset_duration: Duration::from_secs(1),
        };
        assert!(
            !cb.is_open(),
            "circuit breaker should auto-reset after reset_duration"
        );
    }

    #[test]
    fn test_circuit_breaker_stays_closed_below_threshold() {
        let cb = CircuitBreaker::new();
        trip(&cb, CIRCUIT_BREAKER_THRESHOLD - 1);
        assert!(
            !cb.is_open(),
            "circuit breaker should stay closed below threshold"
        );
    }

    #[test]
    fn test_calculate_backoff_base() {
        let base = Duration::from_millis(100);
        let d = calculate_backoff(0, base, None);
        // Attempt 0 is base plus up to half of base as jitter.
        let upper = base + base / 2 + Duration::from_millis(1);
        assert!(
            d >= base && d < upper,
            "attempt 0 should be roughly base_delay, got {:?}",
            d
        );
    }

    #[test]
    fn test_calculate_backoff_exponential() {
        let base = Duration::from_millis(100);
        let d = calculate_backoff(2, base, None);
        // Attempt 2 doubles twice: at least 4x the base.
        assert!(
            d >= base * 4,
            "attempt 2 should be >= 400ms, got {:?}",
            d
        );
    }

    #[test]
    fn test_calculate_backoff_retry_after_seconds() {
        let d = calculate_backoff(0, Duration::from_secs(1), Some("5"));
        assert_eq!(d, Duration::from_secs(5), "Retry-After: 5 should give 5s delay");
    }

    #[test]
    fn test_calculate_backoff_zero_base() {
        let d = calculate_backoff(0, Duration::ZERO, None);
        assert_eq!(d, Duration::ZERO, "zero base_delay should give zero backoff");
    }

    #[test]
    fn test_calculate_backoff_negative_retry_after() {
        let d = calculate_backoff(0, Duration::from_secs(1), Some("-1"));
        assert_eq!(d, Duration::ZERO, "negative Retry-After should give zero delay");
    }

    #[test]
    fn test_retry_config_default() {
        let cfg = RetryConfig::default();
        assert_eq!(cfg.max_retries_429, MAX_RATE_LIMIT_RETRIES);
        assert_eq!(cfg.max_retries_5xx, MAX_5XX_RETRIES);
        assert_eq!(cfg.base_delay, RATE_LIMIT_BASE_DELAY);
        assert!(cfg.circuit_breaker.is_some(), "default should have circuit breaker");
    }
}