use std::error::Error as StdError;
use std::future::Future;
use std::result::Result as StdResult;
use std::time::Duration;
use sift_error::prelude::*;
use tonic;
/// Configuration for retry behavior: an attempt budget plus an
/// exponential-backoff schedule with an upper bound.
#[derive(Debug, Clone)]
pub struct RetryConfig {
    /// Maximum number of attempts, including the first call.
    pub max_attempts: usize,
    /// Delay after the first failed attempt; later delays grow geometrically.
    pub base_delay: Duration,
    /// Cap applied to every computed delay.
    pub max_delay: Duration,
    /// Geometric growth factor between consecutive delays.
    pub backoff_multiplier: f64,
}

impl Default for RetryConfig {
    /// Three attempts, 100 ms base delay, 5 s cap, doubling per retry.
    fn default() -> Self {
        RetryConfig {
            max_attempts: 3,
            base_delay: Duration::from_millis(100),
            max_delay: Duration::from_secs(5),
            backoff_multiplier: 2.0,
        }
    }
}

impl RetryConfig {
    /// Returns the delay to sleep after the given (1-based) failed attempt.
    ///
    /// Attempts 0 and 1 map to `base_delay`; attempt `n` maps to
    /// `base_delay * backoff_multiplier^(n-1)`, clamped to `max_delay`.
    pub fn backoff(&self, attempt: usize) -> Duration {
        match attempt {
            0 | 1 => self.base_delay,
            n => {
                let growth = self.backoff_multiplier.powf((n - 1) as f64);
                let millis = self.base_delay.as_millis() as f64 * growth;
                // `as u64` saturates on overflow; the cap bounds it regardless.
                Duration::from_millis(millis as u64).min(self.max_delay)
            }
        }
    }
}
/// Policy deciding whether an error of type `E` warrants another attempt.
pub trait RetryDecider<E> {
    /// Returns `true` if the operation that produced `err` should be retried.
    fn should_retry(&self, err: &E) -> bool;
}
/// Default retry policy for gRPC-backed calls.
///
/// Looks for a `tonic::Status` anywhere in the error's source chain and,
/// if one is found, retries only on transient status codes. When no status
/// is present, falls back to a coarse classification by error kind.
pub struct DefaultGrpcRetry;

impl RetryDecider<sift_error::Error> for DefaultGrpcRetry {
    fn should_retry(&self, err: &sift_error::Error) -> bool {
        // Walk the cause chain and grab the first gRPC status, if any.
        let status = std::iter::successors(err.source(), |cause| cause.source())
            .find_map(|cause| cause.downcast_ref::<tonic::Status>());
        match status {
            // A status is the most precise signal: decide solely from its code.
            Some(status) => matches!(
                status.code(),
                tonic::Code::Unavailable
                    | tonic::Code::ResourceExhausted
                    | tonic::Code::DeadlineExceeded
            ),
            // No status anywhere in the chain: decide from the error kind.
            None => matches!(
                err.kind(),
                ErrorKind::GrpcConnectError
                    | ErrorKind::RetrieveAssetError
                    | ErrorKind::RetrieveIngestionConfigError
                    | ErrorKind::RetrieveRunError
            ),
        }
    }
}
/// Wrapper that adds retry behavior around a cloneable value `T`.
///
/// `D` is the policy that decides which errors are retryable; it defaults
/// to `DefaultGrpcRetry`.
#[derive(Clone, Debug)]
pub struct Retrying<T, D = DefaultGrpcRetry> {
    // The wrapped value; cloned once per attempt in `call`.
    inner: T,
    // Attempt budget and backoff schedule.
    cfg: RetryConfig,
    // Policy consulted after each failure.
    decider: D,
}
impl<T> Retrying<T> {
pub fn new(inner: T, cfg: RetryConfig) -> Self {
Self {
inner,
cfg,
decider: DefaultGrpcRetry,
}
}
}
impl<T, D> Retrying<T, D> {
    /// Replaces the retry decider, keeping the wrapped value and config.
    pub fn with_decider<D2>(self, decider: D2) -> Retrying<T, D2> {
        // Destructure so the old decider is dropped and the rest moves over.
        let Retrying { inner, cfg, .. } = self;
        Retrying { inner, cfg, decider }
    }

    /// Borrows the wrapped value directly, bypassing retry behavior.
    pub fn inner(&self) -> &T {
        &self.inner
    }
}
impl<T, D> Retrying<T, D>
where
    T: Clone,
{
    /// Invokes `f` with a clone of the wrapped value, retrying on failure.
    ///
    /// Makes up to `cfg.max_attempts` attempts — at least one, even if the
    /// config says zero — sleeping `cfg.backoff(attempt)` between attempts.
    /// A failure is retried only while the decider approves; otherwise (or
    /// once the budget is exhausted) the last error is returned.
    pub async fn call<F, Fut, R, E>(&self, mut f: F) -> StdResult<R, E>
    where
        F: FnMut(T) -> Fut,
        Fut: Future<Output = StdResult<R, E>>,
        D: RetryDecider<E>,
    {
        // Clamp the budget to at least one attempt. The original loop
        // panicked on `max_attempts == 0` because it broke out with no
        // error ever captured; always trying once removes that panic path.
        let max_attempts = self.cfg.max_attempts.max(1);
        let mut attempt = 1;
        loop {
            match f(self.inner.clone()).await {
                Ok(result) => return Ok(result),
                Err(e) => {
                    // Stop on the final attempt or a non-retryable error.
                    if attempt >= max_attempts || !self.decider.should_retry(&e) {
                        return Err(e);
                    }
                    tokio::time::sleep(self.cfg.backoff(attempt)).await;
                    attempt += 1;
                }
            }
        }
    }
}
/// Extension trait so any value can be wrapped via `value.retrying(cfg)`.
pub trait RetryExt: Sized {
    /// Wraps `self` in a `Retrying` adapter with the default gRPC decider.
    fn retrying(self, cfg: RetryConfig) -> Retrying<Self> {
        Retrying::new(self, cfg)
    }
}
// Blanket impl: every sized type gets `.retrying(cfg)` for free.
impl<T> RetryExt for T {}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};

    // Backoff grows geometrically from the base delay.
    #[test]
    fn test_backoff_calculation() {
        let cfg = RetryConfig::default();
        assert_eq!(cfg.backoff(1), Duration::from_millis(100));
        assert_eq!(cfg.backoff(2), Duration::from_millis(200));
        assert_eq!(cfg.backoff(3), Duration::from_millis(400));
        assert_eq!(cfg.backoff(4), Duration::from_millis(800));
    }

    // Computed delays never exceed the configured maximum.
    #[test]
    fn test_backoff_caps_at_max() {
        let cfg = RetryConfig {
            max_attempts: 10,
            base_delay: Duration::from_millis(100),
            max_delay: Duration::from_millis(500),
            backoff_multiplier: 2.0,
        };
        let delay = cfg.backoff(10);
        assert_eq!(delay, Duration::from_millis(500));
    }

    // Transient failures are retried until the call eventually succeeds.
    #[tokio::test]
    async fn test_retry_loop_succeeds_after_failures() {
        let counter = Arc::new(AtomicUsize::new(0));
        let cfg = RetryConfig {
            max_attempts: 3,
            base_delay: Duration::from_millis(10),
            max_delay: Duration::from_secs(5),
            backoff_multiplier: 2.0,
        };
        let retrying = Retrying::new((), cfg);
        let result = retrying
            .call(|_| {
                let counter = counter.clone();
                async move {
                    let attempts = counter.fetch_add(1, Ordering::SeqCst) + 1;
                    if attempts < 3 {
                        Err::<(), sift_error::Error>(Error::new_msg(
                            ErrorKind::RetrieveAssetError,
                            "temporary failure",
                        ))
                    } else {
                        Ok(())
                    }
                }
            })
            .await;
        assert!(result.is_ok());
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    // A persistently failing call is attempted exactly `max_attempts` times.
    #[tokio::test]
    async fn test_retry_loop_exhausts_attempts() {
        let counter = Arc::new(AtomicUsize::new(0));
        let cfg = RetryConfig {
            max_attempts: 3,
            base_delay: Duration::from_millis(10),
            max_delay: Duration::from_secs(5),
            backoff_multiplier: 2.0,
        };
        let retrying = Retrying::new((), cfg);
        let result = retrying
            .call(|_| {
                let counter = counter.clone();
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                    Err::<(), sift_error::Error>(Error::new_msg(
                        ErrorKind::RetrieveAssetError,
                        "persistent failure",
                    ))
                }
            })
            .await;
        assert!(result.is_err());
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    // Status codes embedded in the error chain drive the retry decision.
    #[test]
    fn test_default_grpc_retry_with_tonic_status() {
        let decider = DefaultGrpcRetry;
        let unavailable = Error::new(
            ErrorKind::RetrieveAssetError,
            tonic::Status::unavailable("service unavailable"),
        );
        assert!(decider.should_retry(&unavailable));
        let resource_exhausted = Error::new(
            ErrorKind::RetrieveAssetError,
            tonic::Status::resource_exhausted("resource exhausted"),
        );
        assert!(decider.should_retry(&resource_exhausted));
        let deadline_exceeded = Error::new(
            ErrorKind::RetrieveAssetError,
            tonic::Status::deadline_exceeded("deadline exceeded"),
        );
        assert!(decider.should_retry(&deadline_exceeded));
        let invalid_argument = Error::new(
            ErrorKind::ArgumentValidationError,
            tonic::Status::invalid_argument("invalid argument"),
        );
        assert!(!decider.should_retry(&invalid_argument));
        let not_found = Error::new(
            ErrorKind::NotFoundError,
            tonic::Status::not_found("not found"),
        );
        // Fixed: `&not_found` was corrupted to `¬_found` (U+00AC, a mangled
        // `&not;` HTML entity), which does not compile.
        assert!(!decider.should_retry(&not_found));
    }

    // Without a tonic::Status in the chain, the error kind decides.
    #[test]
    fn test_default_grpc_retry_with_error_kind_fallback() {
        let decider = DefaultGrpcRetry;
        let grpc_connect_error = Error::new_msg(ErrorKind::GrpcConnectError, "connection failed");
        assert!(decider.should_retry(&grpc_connect_error));
        let retrieve_asset_error =
            Error::new_msg(ErrorKind::RetrieveAssetError, "retrieval failed");
        assert!(decider.should_retry(&retrieve_asset_error));
        let retrieve_ingestion_config_error =
            Error::new_msg(ErrorKind::RetrieveIngestionConfigError, "retrieval failed");
        assert!(decider.should_retry(&retrieve_ingestion_config_error));
        let retrieve_run_error = Error::new_msg(ErrorKind::RetrieveRunError, "retrieval failed");
        assert!(decider.should_retry(&retrieve_run_error));
        let argument_error = Error::new_msg(ErrorKind::ArgumentValidationError, "bad argument");
        assert!(!decider.should_retry(&argument_error));
        let not_found_error = Error::new_msg(ErrorKind::NotFoundError, "not found");
        // Fixed: same `&not` -> `¬` corruption as above.
        assert!(!decider.should_retry(&not_found_error));
    }

    // Non-retryable errors short-circuit after a single attempt.
    #[tokio::test]
    async fn test_no_retry_on_non_retryable_error() {
        let counter = Arc::new(AtomicUsize::new(0));
        let cfg = RetryConfig {
            max_attempts: 3,
            base_delay: Duration::from_millis(10),
            max_delay: Duration::from_secs(5),
            backoff_multiplier: 2.0,
        };
        let retrying = Retrying::new((), cfg);
        let result = retrying
            .call(|_| {
                let counter = counter.clone();
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                    Err::<(), sift_error::Error>(Error::new(
                        ErrorKind::ArgumentValidationError,
                        tonic::Status::invalid_argument("invalid argument"),
                    ))
                }
            })
            .await;
        assert!(result.is_err());
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }
}