use crate::error::{AppError, AppResult};
use std::time::Duration;
use tokio::time::sleep;
/// Calls `op` up to `max_attempts` times, sleeping between failed attempts.
///
/// Behaviour per attempt:
/// - `Ok(value)` is returned immediately.
/// - An error for which [`is_retryable`] is `false` is returned as-is, without
///   further attempts.
/// - `AppError::RateLimited` waits `retry_after_secs` seconds (60 when the value
///   is `None`), capped at 300 seconds, before the next attempt.
/// - Any other retryable error waits 1 s, 2 s, then 4 s; every later retry keeps
///   waiting 4 s.
///
/// # Errors
///
/// Returns `AppError::ProviderUnavailable` when the final attempt fails with a
/// retryable error, or when `max_attempts` is 0 (in which case `op` is never
/// called). Non-retryable errors are propagated unchanged.
pub async fn retry_with_backoff<F, Fut, T>(mut op: F, max_attempts: u8) -> AppResult<T>
where
    F: FnMut() -> Fut,
    Fut: std::future::Future<Output = AppResult<T>>,
{
    // Backoff schedule for transient failures, indexed by zero-based attempt.
    let delays = [
        Duration::from_secs(1),
        Duration::from_secs(2),
        Duration::from_secs(4),
    ];
    for attempt in 0..max_attempts {
        match op().await {
            Ok(value) => return Ok(value),
            Err(e) if !is_retryable(&e) => return Err(e),
            // Out of attempts: do not sleep, report the provider as unavailable.
            // `attempt + 1` cannot overflow because `attempt < max_attempts <= u8::MAX`.
            Err(_) if attempt + 1 == max_attempts => {
                return Err(AppError::ProviderUnavailable);
            }
            Err(AppError::RateLimited { retry_after_secs }) => {
                let wait = retry_after_secs.unwrap_or(60).min(300);
                tracing::warn!(
                    target: "events",
                    event = "retry",
                    attempt = attempt + 1,
                    next_delay_secs = wait,
                    "rate limited (HTTP 429); honouring Retry-After"
                );
                sleep(Duration::from_secs(wait)).await;
            }
            Err(_) => {
                // Clamp to the last schedule entry. Without the clamp, attempts
                // past the end of the table would retry immediately with no
                // delay and no log event.
                let slot = usize::from(attempt).min(delays.len() - 1);
                let delay = delays[slot];
                tracing::warn!(
                    target: "events",
                    event = "retry",
                    attempt = attempt + 1,
                    next_delay_secs = delay.as_secs(),
                    "transient failure; backing off"
                );
                sleep(delay).await;
            }
        }
    }
    // Only reached when `max_attempts` is 0 and the loop body never ran.
    Err(AppError::ProviderUnavailable)
}
/// Reports whether `err` is one of the variants the retry loop tries again on:
/// `Timeout`, `Http`, `RateLimited` or `ProviderUnavailable`. Every other
/// variant yields `false`.
fn is_retryable(err: &AppError) -> bool {
    let timeout_or_http = matches!(err, AppError::Timeout(_) | AppError::Http(_));
    let rate_limited = matches!(err, AppError::RateLimited { .. });
    let unavailable = matches!(err, AppError::ProviderUnavailable);
    timeout_or_http || rate_limited || unavailable
}
/// Failure counter that reports "open" once the number of failures recorded
/// since the last success reaches a fixed threshold.
#[derive(Debug)]
#[non_exhaustive]
pub struct CircuitBreaker {
    /// Number of failures at which the breaker counts as open.
    threshold: u8,
    /// Failures recorded since construction or the last `record_success`.
    failures: u8,
}

impl CircuitBreaker {
    /// Creates a breaker with no recorded failures that opens at `threshold`.
    pub fn new(threshold: u8) -> Self {
        let failures = 0;
        Self {
            threshold,
            failures,
        }
    }

    /// Records one failure and returns `true` if the breaker is now open.
    ///
    /// The counter saturates at `u8::MAX` instead of wrapping.
    pub fn record_failure(&mut self) -> bool {
        let bumped = self.failures.saturating_add(1);
        self.failures = bumped;
        self.is_open()
    }

    /// Clears the failure count, keeping the configured threshold.
    pub fn record_success(&mut self) {
        *self = Self::new(self.threshold);
    }

    /// Returns `true` while the failure count is at or above the threshold.
    pub fn is_open(&self) -> bool {
        self.threshold <= self.failures
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU8, Ordering};

    // Every test starts with tokio's clock paused (`start_paused = true`) and
    // measures the wait through `tokio::time::Instant`.

    /// A first-call `RateLimited` carrying `Some(2)` is followed by a second,
    /// successful call after at least 2 seconds on the tokio clock.
    #[tokio::test(start_paused = true)]
    async fn rate_limited_waits_retry_after_seconds() {
        let calls = AtomicU8::new(0);
        let start = tokio::time::Instant::now();
        let op = || async {
            match calls.fetch_add(1, Ordering::SeqCst) {
                0 => Err(AppError::RateLimited {
                    retry_after_secs: Some(2),
                }),
                _ => Ok(42u8),
            }
        };
        let result = retry_with_backoff(op, 3).await;
        assert_eq!(result.expect("second attempt succeeds"), 42);
        assert_eq!(calls.load(Ordering::SeqCst), 2);
        let elapsed = start.elapsed();
        assert!(
            elapsed >= Duration::from_secs(2),
            "virtual clock advanced only {elapsed:?}"
        );
    }

    /// With `retry_after_secs: None` the retry waits at least 60 seconds.
    #[tokio::test(start_paused = true)]
    async fn rate_limited_without_header_waits_60s_fallback() {
        let calls = AtomicU8::new(0);
        let start = tokio::time::Instant::now();
        let op = || async {
            match calls.fetch_add(1, Ordering::SeqCst) {
                0 => Err(AppError::RateLimited {
                    retry_after_secs: None,
                }),
                _ => Ok(()),
            }
        };
        let result = retry_with_backoff(op, 3).await;
        assert!(result.is_ok());
        let elapsed = start.elapsed();
        assert!(
            elapsed >= Duration::from_secs(60),
            "virtual clock advanced only {elapsed:?}"
        );
    }

    /// A `retry_after_secs` of 9999 is capped: the wait lands in [300 s, 360 s).
    #[tokio::test(start_paused = true)]
    async fn rate_limited_wait_is_capped_at_300s() {
        let calls = AtomicU8::new(0);
        let start = tokio::time::Instant::now();
        let op = || async {
            match calls.fetch_add(1, Ordering::SeqCst) {
                0 => Err(AppError::RateLimited {
                    retry_after_secs: Some(9999),
                }),
                _ => Ok(()),
            }
        };
        let result = retry_with_backoff(op, 3).await;
        assert!(result.is_ok());
        let elapsed = start.elapsed();
        let expected_window = Duration::from_secs(300)..Duration::from_secs(360);
        assert!(
            expected_window.contains(&elapsed),
            "expected ~300s of virtual wait, got {elapsed:?}"
        );
    }
}