use std::{future::Future, time::Duration};
use tracing::{debug, warn};
use crate::BotError;
/// Configuration for retrying a fallible async operation.
///
/// Build one with [`RetryPolicy::new`] / [`Default`] and the builder-style setters,
/// then drive an operation with [`RetryPolicy::run`].
#[derive(Clone, Debug)]
pub struct RetryPolicy {
    /// Total number of times the operation may be invoked, counting the first call.
    pub max_attempts: u32,
    /// Starting delay for the exponential backoff applied to HTTP errors.
    pub base_delay: Duration,
    /// Upper bound on any single sleep between attempts.
    pub max_delay: Duration,
}
impl Default for RetryPolicy {
    /// Three attempts, a 200 ms backoff base, and at most 60 s per sleep.
    fn default() -> Self {
        let max_attempts = 3;
        let base_delay = Duration::from_millis(200);
        let max_delay = Duration::from_secs(60);
        Self {
            max_attempts,
            base_delay,
            max_delay,
        }
    }
}
impl RetryPolicy {
    /// Creates a policy with the default settings (equivalent to `RetryPolicy::default()`).
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the total number of attempts, counting the first call.
    ///
    /// Values below 1 are clamped to 1 so the operation always runs at least once.
    pub fn max_attempts(mut self, n: u32) -> Self {
        self.max_attempts = n.max(1);
        self
    }

    /// Sets the starting delay for exponential backoff on HTTP errors.
    pub fn base_delay(mut self, d: Duration) -> Self {
        self.base_delay = d;
        self
    }

    /// Sets the upper bound on any single sleep between attempts.
    ///
    /// This cap applies to both HTTP backoff and server-requested flood-wait sleeps.
    pub fn max_delay(mut self, d: Duration) -> Self {
        self.max_delay = d;
        self
    }

    /// Runs `f`, retrying transient failures according to this policy.
    ///
    /// `f` is called once per attempt to produce a fresh future. Two kinds of error are retried:
    /// - `BotError::Api { code: 429, retry_after: Some(secs), .. }`: sleeps `secs` seconds,
    ///   capped at `max_delay`;
    /// - `BotError::Http(_)`: sleeps with exponential backoff starting at `base_delay`,
    ///   capped at `max_delay`.
    ///
    /// # Errors
    ///
    /// Returns any other error immediately. Otherwise returns the last error once
    /// `max_attempts` attempts have been made.
    pub async fn run<T, F, Fut>(&self, mut f: F) -> Result<T, BotError>
    where
        F: FnMut() -> Fut,
        // No `Send` bound: nothing here needs it, and omitting it lets callers
        // retry `!Send` futures (e.g. on a `LocalSet`).
        Fut: Future<Output = Result<T, BotError>>,
    {
        let mut attempt = 0u32;
        loop {
            attempt += 1;
            let err = match f().await {
                Ok(v) => return Ok(v),
                Err(e) => e,
            };
            // Out of attempts: surface the error without sleeping.
            if attempt >= self.max_attempts {
                return Err(err);
            }
            let delay = match self.retry_delay(&err, attempt) {
                Some(d) => d,
                None => return Err(err),
            };
            debug!(attempt, delay_ms = delay.as_millis(), "retry sleep");
            tokio::time::sleep(delay).await;
        }
    }

    /// Returns how long to sleep before retrying after `err` on attempt `attempt`.
    ///
    /// Returns `None` when `err` is not retryable.
    fn retry_delay(&self, err: &BotError, attempt: u32) -> Option<Duration> {
        match err {
            BotError::Api {
                code: 429,
                retry_after: Some(secs),
                ..
            } => {
                warn!(
                    attempt,
                    retry_after = secs,
                    "flood-wait; sleeping before retry"
                );
                Some(Duration::from_secs(*secs as u64).min(self.max_delay))
            }
            BotError::Http(_) => {
                let d = backoff(self.base_delay, attempt, self.max_delay);
                warn!(attempt, delay_ms = d.as_millis(), "HTTP error; retrying");
                Some(d)
            }
            _ => None,
        }
    }
}
/// Exponential backoff: `base * 2^(attempt - 1)`, capped at `max`.
///
/// `attempt` is 1-based; attempt 0 is treated like attempt 1. The computation is
/// done in `u128` nanoseconds with saturating arithmetic. This preserves
/// sub-millisecond delays, which previously rounded down to zero. It also never
/// overflows, even for huge `attempt` values or `max == Duration::MAX`.
fn backoff(base: Duration, attempt: u32, max: Duration) -> Duration {
    // Shifts of 128 or more overflow u128; saturate so the cap below takes over.
    let factor = 1u128
        .checked_shl(attempt.saturating_sub(1))
        .unwrap_or(u128::MAX);
    let nanos = base
        .as_nanos()
        .saturating_mul(factor)
        .min(max.as_nanos());
    // `nanos <= max.as_nanos()`, so the whole-seconds part is at most
    // `max.as_secs()` and always fits in u64; the remainder is < 1e9.
    Duration::new(
        (nanos / 1_000_000_000) as u64,
        (nanos % 1_000_000_000) as u32,
    )
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn backoff_grows_exponentially() {
        let base = Duration::from_millis(100);
        let max = Duration::from_secs(10);
        assert_eq!(backoff(base, 1, max), Duration::from_millis(100));
        assert_eq!(backoff(base, 2, max), Duration::from_millis(200));
        assert_eq!(backoff(base, 3, max), Duration::from_millis(400));
        assert_eq!(backoff(base, 4, max), Duration::from_millis(800));
    }

    #[test]
    fn backoff_caps_at_max() {
        let base = Duration::from_millis(100);
        let max = Duration::from_millis(300);
        assert_eq!(backoff(base, 3, max), Duration::from_millis(300));
        assert_eq!(backoff(base, 4, max), Duration::from_millis(300));
    }

    /// `BotError::Other` is not one of the variants `RetryPolicy::run` retries,
    /// so the operation must be invoked exactly once and its error returned.
    ///
    /// (This test was previously named `retries_on_http_error`, but it never
    /// produced an HTTP error and never exercised the retry path.)
    #[tokio::test]
    async fn does_not_retry_non_retryable_error() {
        use std::sync::atomic::{AtomicU32, Ordering};
        use std::sync::Arc;

        let calls = Arc::new(AtomicU32::new(0));
        let calls2 = Arc::clone(&calls);
        let policy = RetryPolicy::new()
            .max_attempts(3)
            .base_delay(Duration::from_millis(1));
        let result: Result<i32, BotError> = policy
            .run(|| {
                let c = Arc::clone(&calls2);
                async move {
                    c.fetch_add(1, Ordering::SeqCst);
                    Err::<i32, BotError>(BotError::Other("simulated non-retryable".into()))
                }
            })
            .await;
        assert!(result.is_err());
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }
}