use std::time::Duration;
use crate::error::{CortexError, CortexResult};
/// How [`with_retry`] reacts when an operation fails with a retryable error.
#[derive(Debug, Clone)]
pub enum RetryPolicy {
    /// Run the operation a single time and hand back its result unchanged.
    None,
    /// Retry with exponentially growing delays between attempts.
    Backoff {
        /// Retries allowed after the first attempt (total attempts = `max_retries + 1`).
        max_retries: u32,
        /// Delay before the first retry; doubled after every further failure.
        base_delay: Duration,
        /// Upper bound applied whenever the delay is doubled.
        max_delay: Duration,
    },
}

impl RetryPolicy {
    /// A policy that never retries.
    #[must_use]
    pub fn none() -> Self {
        RetryPolicy::None
    }

    /// Backoff preset: 3 retries, 500 ms initial delay, capped at 10 s.
    #[must_use]
    pub fn query() -> Self {
        Self::custom(3, Duration::from_millis(500), Duration::from_secs(10))
    }

    /// Backoff preset: 2 retries, 1 s initial delay, capped at 15 s.
    #[must_use]
    pub fn idempotent() -> Self {
        Self::custom(2, Duration::from_secs(1), Duration::from_secs(15))
    }

    /// Backoff preset used by `stop()` callers: 2 retries, 1 s initial delay, capped at 15 s.
    ///
    /// The values currently match [`RetryPolicy::idempotent`]. They are spelled out
    /// separately so the two presets can be tuned independently.
    #[must_use]
    pub fn stop() -> Self {
        Self::custom(2, Duration::from_secs(1), Duration::from_secs(15))
    }

    /// Builds a [`RetryPolicy::Backoff`] from explicit parameters, stored verbatim.
    #[must_use]
    pub fn custom(max_retries: u32, base_delay: Duration, max_delay: Duration) -> Self {
        RetryPolicy::Backoff {
            max_retries,
            base_delay,
            max_delay,
        }
    }
}
/// Runs `operation` under `policy`, retrying retryable failures with exponential backoff.
///
/// With [`RetryPolicy::None`] the operation runs once and its result is returned as-is.
/// With [`RetryPolicy::Backoff`] the operation runs up to `max_retries + 1` times.
/// Between attempts this function logs a warning and sleeps. The delay starts at
/// `base_delay`, doubles after each failure, and never exceeds `max_delay`.
///
/// # Errors
///
/// - Any error for which `is_retryable()` is false is returned immediately.
/// - If every attempt under a `Backoff` policy fails with a retryable error, returns
///   `CortexError::RetriesExhausted`. That error carries the attempt count and the
///   last error.
pub async fn with_retry<F, Fut, T>(policy: &RetryPolicy, mut operation: F) -> CortexResult<T>
where
    F: FnMut() -> Fut,
    Fut: std::future::Future<Output = CortexResult<T>>,
{
    let (max_retries, base_delay, max_delay) = match policy {
        RetryPolicy::None => return operation().await,
        RetryPolicy::Backoff {
            max_retries,
            base_delay,
            max_delay,
        } => (*max_retries, *base_delay, *max_delay),
    };

    // Clamp the first delay as well, so a custom policy with base_delay > max_delay
    // never sleeps longer than the configured cap.
    let mut delay = std::cmp::min(base_delay, max_delay);
    let mut attempt: u32 = 0;

    // Single loop with explicit exits. Every path returns, so no attempt can
    // accidentally run past the budget.
    loop {
        let err = match operation().await {
            Ok(result) => return Ok(result),
            Err(e) => e,
        };

        if !err.is_retryable() {
            return Err(err);
        }

        if attempt >= max_retries {
            return Err(CortexError::RetriesExhausted {
                // Saturate so max_retries == u32::MAX cannot overflow the count.
                attempts: attempt.saturating_add(1),
                last_error: Box::new(err),
            });
        }

        tracing::warn!(
            attempt = attempt + 1, // attempt < max_retries here, so this cannot overflow
            max = max_retries.saturating_add(1),
            error = %err,
            delay_ms = u64::try_from(delay.as_millis()).unwrap_or(u64::MAX),
            "Retrying after transient error"
        );
        tokio::time::sleep(delay).await;

        // `Duration * u32` panics on overflow; saturate before applying the cap.
        delay = std::cmp::min(delay.saturating_mul(2), max_delay);
        attempt += 1;
    }
}
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicUsize, Ordering};

    use super::*;

    /// Backoff with tiny delays so retry tests spend almost no wall-clock time sleeping.
    fn fast_policy(max_retries: u32) -> RetryPolicy {
        RetryPolicy::custom(max_retries, Duration::from_millis(1), Duration::from_millis(10))
    }

    #[tokio::test]
    async fn test_policy_none_returns_immediately() {
        let attempts = AtomicUsize::new(0);
        let result: CortexResult<i32> = with_retry(&RetryPolicy::none(), || {
            attempts.fetch_add(1, Ordering::SeqCst);
            async { Err(CortexError::Timeout { seconds: 1 }) }
        })
        .await;
        // Even a retryable error must come back unwrapped and unretried under `None`.
        assert!(matches!(result, Err(CortexError::Timeout { .. })));
        assert_eq!(attempts.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_retry_succeeds_on_second_attempt() {
        let attempts = AtomicUsize::new(0);
        let result = with_retry(&fast_policy(3), || {
            let n = attempts.fetch_add(1, Ordering::SeqCst);
            async move {
                if n == 0 {
                    Err(CortexError::Timeout { seconds: 1 })
                } else {
                    Ok::<i32, CortexError>(42)
                }
            }
        })
        .await;
        assert_eq!(result.unwrap(), 42);
        assert_eq!(attempts.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn test_non_retryable_fails_immediately() {
        let attempts = AtomicUsize::new(0);
        let result: CortexResult<i32> = with_retry(&RetryPolicy::query(), || {
            attempts.fetch_add(1, Ordering::SeqCst);
            async { Err(CortexError::NoHeadsetFound) }
        })
        .await;
        assert!(matches!(result, Err(CortexError::NoHeadsetFound)));
        assert_eq!(attempts.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_retries_exhausted() {
        let attempts = AtomicUsize::new(0);
        let result: CortexResult<i32> = with_retry(&fast_policy(2), || {
            attempts.fetch_add(1, Ordering::SeqCst);
            async { Err(CortexError::Timeout { seconds: 1 }) }
        })
        .await;
        assert!(matches!(result, Err(CortexError::RetriesExhausted { attempts: 3, .. })));
        assert_eq!(attempts.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_zero_retries_runs_once() {
        // Boundary: a Backoff policy with no retries still makes exactly one attempt.
        let attempts = AtomicUsize::new(0);
        let result: CortexResult<i32> = with_retry(&fast_policy(0), || {
            attempts.fetch_add(1, Ordering::SeqCst);
            async { Err(CortexError::Timeout { seconds: 1 }) }
        })
        .await;
        assert!(matches!(result, Err(CortexError::RetriesExhausted { attempts: 1, .. })));
        assert_eq!(attempts.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_policy_constructors() {
        assert!(matches!(RetryPolicy::none(), RetryPolicy::None));
        assert!(matches!(RetryPolicy::query(), RetryPolicy::Backoff { max_retries: 3, .. }));
        assert!(matches!(RetryPolicy::idempotent(), RetryPolicy::Backoff { max_retries: 2, .. }));
        assert!(matches!(RetryPolicy::stop(), RetryPolicy::Backoff { max_retries: 2, .. }));
        let custom = RetryPolicy::custom(5, Duration::from_millis(7), Duration::from_secs(3));
        assert!(matches!(
            custom,
            RetryPolicy::Backoff { max_retries: 5, base_delay, max_delay }
                if base_delay == Duration::from_millis(7) && max_delay == Duration::from_secs(3)
        ));
    }
}