use std::{future::Future, time::Duration};
use anyhow::Result;
use tracing::warn;
/// Runs `operation`, retrying with exponential backoff while it keeps failing
/// with an error that [`is_retryable`] accepts.
///
/// # Parameters
///
/// * `operation` - factory that produces a fresh future for every call.
/// * `max_attempts` - maximum number of *retries* after the initial call, so
///   `operation` is invoked at most `max_attempts + 1` times.
/// * `base_delay_ms` - delay before the first retry, in milliseconds. Each
///   later retry doubles the previous delay (`base_delay_ms * 2^attempt`),
///   saturating at `u64::MAX` instead of overflowing.
///
/// # Errors
///
/// Returns the most recent error from `operation` when that error is not
/// retryable, or when all `max_attempts` retries have been used up.
pub async fn with_retry<F, Fut, T>(operation: F, max_attempts: u32, base_delay_ms: u64) -> Result<T>
where
    F: Fn() -> Fut,
    Fut: Future<Output = Result<T>>,
{
    let mut attempt = 0u32;
    loop {
        match operation().await {
            Ok(result) => return Ok(result),
            Err(e) if attempt < max_attempts && is_retryable(&e) => {
                // `2u64.pow(attempt)` overflows once `attempt` reaches 64 (a panic
                // in debug builds, a wrap to 0 in release). That is reachable with
                // `base_delay_ms == 0` and a large `max_attempts`, so clamp the
                // multiplier and let `saturating_mul` cap the final delay.
                let multiplier = 2u64.checked_pow(attempt).unwrap_or(u64::MAX);
                let delay = base_delay_ms.saturating_mul(multiplier);
                warn!(
                    // `attempt < max_attempts <= u32::MAX`, so `+ 1` cannot overflow.
                    attempt = attempt + 1,
                    max_attempts,
                    delay_ms = delay,
                    error = %e,
                    "RPC call failed, retrying"
                );
                tokio::time::sleep(Duration::from_millis(delay)).await;
                attempt += 1;
            }
            Err(e) => return Err(e),
        }
    }
}
/// Reports whether `e` looks like a transient failure worth retrying.
///
/// The error is rendered with `to_string()`, lowercased, and searched for a
/// fixed set of substrings, so matching is case-insensitive and a marker may
/// appear anywhere in the message.
///
/// NOTE(review): `to_string()` goes through `Display`, which for
/// `anyhow::Error` presumably renders only the outermost message. A retryable
/// cause hidden behind added context would then not match. Confirm against
/// how callers wrap their errors.
fn is_retryable(e: &anyhow::Error) -> bool {
    // Lowercase fragments that mark an error as transient.
    const TRANSIENT_MARKERS: [&str; 6] = [
        "429",
        "too many requests",
        "timeout",
        "timed out",
        "connection",
        "temporarily unavailable",
    ];

    let lowered = e.to_string().to_lowercase();
    TRANSIENT_MARKERS
        .iter()
        .any(|&marker| lowered.contains(marker))
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// An operation that succeeds immediately has its value passed through.
    #[tokio::test]
    async fn succeeds_on_first_try() {
        let always_ok = || async { Ok::<_, anyhow::Error>(42) };

        let value = with_retry(always_ok, 3, 10).await.unwrap();

        assert_eq!(value, 42);
    }

    /// Two retryable failures followed by a success yield the success value
    /// after exactly three invocations.
    #[tokio::test]
    async fn retries_on_transient_error_then_succeeds() {
        let calls = AtomicU32::new(0);
        let flaky = || {
            // `fetch_add` returns the count *before* this call.
            let previous_calls = calls.fetch_add(1, Ordering::SeqCst);
            async move {
                if previous_calls >= 2 {
                    Ok(99)
                } else {
                    Err(anyhow::anyhow!("429 Too Many Requests"))
                }
            }
        };

        let outcome = with_retry(flaky, 3, 10).await;

        assert_eq!(outcome.unwrap(), 99);
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    /// An error that matches no retry marker is returned after a single call.
    #[tokio::test]
    async fn fails_immediately_on_non_retryable_error() {
        let calls = AtomicU32::new(0);
        let always_invalid = || {
            calls.fetch_add(1, Ordering::SeqCst);
            async { Err::<i32, _>(anyhow::anyhow!("invalid account data")) }
        };

        let outcome = with_retry(always_invalid, 3, 10).await;

        assert!(outcome.is_err());
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    /// With `max_attempts = 2`, a persistently retryable error is attempted
    /// three times in total (initial call plus two retries) before giving up.
    #[tokio::test]
    async fn exhausts_retries_and_returns_last_error() {
        let calls = AtomicU32::new(0);
        let always_refused = || {
            calls.fetch_add(1, Ordering::SeqCst);
            async { Err::<i32, _>(anyhow::anyhow!("connection refused")) }
        };

        let outcome = with_retry(always_refused, 2, 10).await;

        assert!(outcome.is_err());
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }
}