use faucet_core::FaucetError;
use std::future::Future;
use std::time::Duration;
const MAX_CONSECUTIVE_RATE_LIMITS: u32 = 10;
pub async fn execute_with_retry<F, Fut, T>(
max_retries: u32,
base_backoff: Duration,
mut operation: F,
) -> Result<T, FaucetError>
where
F: FnMut() -> Fut,
Fut: Future<Output = Result<T, FaucetError>>,
{
let mut attempt = 0u32;
let mut rate_limited = 0u32;
loop {
match operation().await {
Ok(val) => return Ok(val),
Err(FaucetError::RateLimited(wait)) => {
if rate_limited >= MAX_CONSECUTIVE_RATE_LIMITS {
tracing::error!("rate limited {rate_limited} consecutive times; giving up");
return Err(FaucetError::RateLimited(wait));
}
rate_limited += 1;
tracing::warn!(
"rate limited; retrying after {wait:?} ({rate_limited}/{MAX_CONSECUTIVE_RATE_LIMITS})"
);
tokio::time::sleep(wait).await;
}
Err(e) if e.is_retriable() && attempt < max_retries => {
rate_limited = 0;
tracing::warn!(
"request failed (attempt {}/{}): {e}",
attempt + 1,
max_retries + 1
);
let wait = faucet_core::retry::backoff_with_jitter(base_backoff, attempt);
tokio::time::sleep(wait).await;
attempt += 1;
}
Err(e) => {
if !e.is_retriable() {
tracing::error!("non-retriable error: {e}");
} else {
tracing::error!("request failed after {} attempts: {e}", attempt + 1);
}
return Err(e);
}
}
}
}
#[cfg(test)]
mod tests {
    //! Behavioural tests for `execute_with_retry`. A shared atomic counter records
    //! how many times the operation closure was invoked.
    use super::*;
    use faucet_core::FaucetError;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;

    /// Creates a fresh shared invocation counter starting at zero.
    fn new_counter() -> Arc<AtomicU32> {
        Arc::new(AtomicU32::new(0))
    }

    #[tokio::test]
    async fn execute_with_retry_success_on_first_try() {
        let outcome = execute_with_retry(3, Duration::from_millis(1), || async {
            Ok::<_, FaucetError>(42)
        })
        .await;
        assert_eq!(outcome.unwrap(), 42);
    }

    #[tokio::test]
    async fn execute_with_retry_non_retriable_fails_immediately() {
        let calls = new_counter();
        let outcome = execute_with_retry(3, Duration::from_millis(1), {
            let calls = Arc::clone(&calls);
            move || {
                calls.fetch_add(1, Ordering::SeqCst);
                async { Err::<i32, _>(FaucetError::Auth("bad credentials".into())) }
            }
        })
        .await;
        assert!(outcome.is_err());
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn execute_with_retry_retriable_exhausts_retries() {
        let calls = new_counter();
        let outcome = execute_with_retry(2, Duration::from_millis(1), {
            let calls = Arc::clone(&calls);
            move || {
                calls.fetch_add(1, Ordering::SeqCst);
                async {
                    Err::<i32, _>(FaucetError::HttpStatus {
                        status: 500,
                        url: "http://test".into(),
                        body: "error".into(),
                    })
                }
            }
        })
        .await;
        assert!(outcome.is_err());
        // One initial call plus `max_retries` (2) retries.
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn execute_with_retry_succeeds_after_transient_failure() {
        let calls = new_counter();
        let outcome = execute_with_retry(3, Duration::from_millis(1), {
            let calls = Arc::clone(&calls);
            move || {
                let seen = calls.fetch_add(1, Ordering::SeqCst);
                async move {
                    match seen {
                        0 | 1 => Err::<i32, _>(FaucetError::HttpStatus {
                            status: 502,
                            url: "http://test".into(),
                            body: "bad gateway".into(),
                        }),
                        _ => Ok(99),
                    }
                }
            }
        })
        .await;
        assert_eq!(outcome.unwrap(), 99);
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn execute_with_retry_rate_limited_does_not_count_as_attempt() {
        let calls = new_counter();
        // `max_retries = 0`: success is only reachable if the rate limit is free.
        let outcome = execute_with_retry(0, Duration::from_millis(1), {
            let calls = Arc::clone(&calls);
            move || {
                let seen = calls.fetch_add(1, Ordering::SeqCst);
                async move {
                    if seen != 0 {
                        Ok::<i32, FaucetError>(42)
                    } else {
                        Err(FaucetError::RateLimited(Duration::from_millis(1)))
                    }
                }
            }
        })
        .await;
        assert_eq!(outcome.unwrap(), 42);
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn execute_with_retry_perpetual_rate_limit_is_bounded() {
        let calls = new_counter();
        let outcome = execute_with_retry(3, Duration::from_millis(1), {
            let calls = Arc::clone(&calls);
            move || {
                calls.fetch_add(1, Ordering::SeqCst);
                async { Err::<i32, _>(FaucetError::RateLimited(Duration::from_millis(1))) }
            }
        })
        .await;
        assert!(matches!(outcome, Err(FaucetError::RateLimited(_))));
        // Every permitted rate-limit retry plus the final call that triggers give-up.
        assert_eq!(calls.load(Ordering::SeqCst), MAX_CONSECUTIVE_RATE_LIMITS + 1);
    }
}