use std::future::Future;
use std::time::Duration;
use osproxy_spi::SpiError;
/// Settings for a bounded retry loop with exponential backoff.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RetryPolicy {
    /// Total number of calls `with_retry` may make, the first one included.
    ///
    /// The operation is always called at least once, so `0` behaves like `1`.
    pub max_attempts: u32,
    /// Delay before the first retry; the multiplier doubles for each later retry.
    pub base_backoff: Duration,
    /// Ceiling applied to every computed delay.
    pub max_backoff: Duration,
}

impl Default for RetryPolicy {
    /// Three attempts, starting at 50 ms and capped at one second.
    fn default() -> Self {
        let base_backoff = Duration::from_millis(50);
        let max_backoff = Duration::from_secs(1);
        Self {
            max_attempts: 3,
            base_backoff,
            max_backoff,
        }
    }
}

impl RetryPolicy {
    /// Returns the delay to wait after the zero-based `attempt` has failed.
    ///
    /// The delay is `base_backoff * 2^attempt` for `attempt < 32`; from 32
    /// onwards the multiplier is pinned at `u32::MAX`. The multiplication
    /// saturates instead of overflowing, and the result never exceeds
    /// `max_backoff`.
    fn backoff(self, attempt: u32) -> Duration {
        // `checked_shl` returns `None` once the shift amount reaches the bit
        // width of `u32`; fall back to the largest possible multiplier there.
        let multiplier = match 1u32.checked_shl(attempt) {
            Some(power_of_two) => power_of_two,
            None => u32::MAX,
        };
        let uncapped = self.base_backoff.saturating_mul(multiplier);
        if uncapped > self.max_backoff {
            self.max_backoff
        } else {
            uncapped
        }
    }
}
/// Runs `op` until it succeeds, fails with a non-retryable error, or the
/// attempt budget in `policy` is spent.
///
/// `op` is called again only when the returned error reports `retryable()`
/// and fewer than `policy.max_attempts` calls have been made so far. Before
/// each retry the task waits for `policy.backoff(..)`; a zero delay skips the
/// sleep entirely. `op` is always called at least once.
///
/// # Errors
///
/// Returns the error from the last call to `op` when that error is not
/// retryable or when no attempts remain.
pub(crate) async fn with_retry<T, F, Fut>(policy: RetryPolicy, mut op: F) -> Result<T, SpiError>
where
    F: FnMut() -> Fut,
    Fut: Future<Output = Result<T, SpiError>>,
{
    // Number of failed calls so far; also the exponent fed to `backoff`.
    let mut failures: u32 = 0;
    loop {
        let err = match op().await {
            Ok(value) => return Ok(value),
            Err(err) => err,
        };

        // `failures + 1` counts the calls made so far, this one included.
        let may_retry = err.retryable() && failures + 1 < policy.max_attempts;
        if !may_retry {
            return Err(err);
        }

        let delay = policy.backoff(failures);
        if !delay.is_zero() {
            tokio::time::sleep(delay).await;
        }
        failures += 1;
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::cell::Cell;

    /// Builds a policy with zero backoff so the retry loop never sleeps and
    /// the tests only exercise attempt counting.
    fn zero_backoff_policy(max_attempts: u32) -> RetryPolicy {
        RetryPolicy {
            max_attempts,
            base_backoff: Duration::ZERO,
            max_backoff: Duration::ZERO,
        }
    }

    /// A backend error flagged as retryable.
    fn backend_unavailable() -> SpiError {
        SpiError::PlacementBackend { retryable: true }
    }

    #[tokio::test]
    async fn retries_a_transient_backend_then_succeeds() {
        let calls = Cell::new(0);
        let out: Result<u8, SpiError> = with_retry(zero_backoff_policy(3), || {
            calls.set(calls.get() + 1);
            let call_no = calls.get();
            // The first two calls fail; the third one succeeds.
            async move {
                if call_no >= 3 {
                    Ok(7)
                } else {
                    Err(backend_unavailable())
                }
            }
        })
        .await;

        assert_eq!(out.unwrap(), 7);
        assert_eq!(calls.get(), 3);
    }

    #[tokio::test]
    async fn gives_up_after_max_attempts_with_the_retryable_error() {
        let calls = Cell::new(0);
        let out: Result<u8, SpiError> = with_retry(zero_backoff_policy(2), || {
            calls.set(calls.get() + 1);
            async { Err(backend_unavailable()) }
        })
        .await;

        assert!(out.is_err());
        assert_eq!(calls.get(), 2, "exactly max_attempts tries");
    }

    #[tokio::test]
    async fn does_not_retry_a_non_retryable_error() {
        let calls = Cell::new(0);
        let out: Result<u8, SpiError> = with_retry(RetryPolicy::default(), || {
            calls.set(calls.get() + 1);
            async {
                let partition = osproxy_core::PartitionId::from("p");
                Err(SpiError::PlacementMissing { partition })
            }
        })
        .await;

        assert!(out.is_err());
        assert_eq!(calls.get(), 1, "a definitive error is not retried");
    }
}