use std::{future::Future, time::Duration};
use rand::RngExt;
use rand::distr::Uniform;
use serde::{Deserialize, Serialize};
use tokio::time::sleep;
use tokio_util::sync::CancellationToken;
use crate::CamelError;
/// Retrying is switched on unless configuration turns it off.
const DEFAULT_ENABLED: bool = true;
/// Attempt budget, counting the very first try.
const DEFAULT_MAX_ATTEMPTS: u32 = 10;
/// Base delay used for the first back-off.
const DEFAULT_INITIAL_DELAY: Duration = Duration::from_millis(100);
/// Growth factor between consecutive back-off delays.
const DEFAULT_MULTIPLIER: f64 = 2.0;
/// Upper bound on any single back-off delay (30 seconds).
const DEFAULT_MAX_DELAY: Duration = Duration::from_secs(30);
/// Width of the jitter window as a fraction of the computed delay.
const DEFAULT_JITTER_FACTOR: f64 = 0.2;

/// Default for the `enabled` field (serde fallback and `Default`).
fn default_enabled() -> bool {
    DEFAULT_ENABLED
}

/// Default for the `max_attempts` field (serde fallback and `Default`).
fn default_max_attempts() -> u32 {
    DEFAULT_MAX_ATTEMPTS
}

/// Default for the `initial_delay` field (serde fallback and `Default`).
fn default_initial_delay() -> Duration {
    DEFAULT_INITIAL_DELAY
}

/// Default for the `multiplier` field (serde fallback and `Default`).
fn default_multiplier() -> f64 {
    DEFAULT_MULTIPLIER
}

/// Default for the `max_delay` field (serde fallback and `Default`).
fn default_max_delay() -> Duration {
    DEFAULT_MAX_DELAY
}

/// Default for the `jitter_factor` field (serde fallback and `Default`).
fn default_jitter_factor() -> f64 {
    DEFAULT_JITTER_FACTOR
}
/// Deserializes a `Duration` from an integer number of milliseconds.
fn deserialize_duration_ms<'de, D>(d: D) -> Result<Duration, D::Error>
where
    D: serde::Deserializer<'de>,
{
    let ms = u64::deserialize(d)?;
    Ok(Duration::from_millis(ms))
}

/// Serializes a `Duration` as an integer number of milliseconds, the inverse of
/// [`deserialize_duration_ms`].
///
/// Sub-millisecond precision is truncated; durations longer than `u64::MAX`
/// milliseconds saturate to `u64::MAX`.
fn serialize_duration_ms<S>(d: &Duration, s: S) -> Result<S::Ok, S::Error>
where
    S: serde::Serializer,
{
    let ms = u64::try_from(d.as_millis()).unwrap_or(u64::MAX);
    s.serialize_u64(ms)
}

/// Retry / back-off settings for transient network failures.
///
/// Durations are (de)serialized as integer milliseconds under the `initial_delay_ms`
/// and `max_delay_ms` keys, so a serialized policy round-trips. Any field missing
/// from the input falls back to its `default_*` value.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[must_use]
pub struct NetworkRetryPolicy {
    /// Master switch; when `false`, no retries are made.
    #[serde(default = "default_enabled")]
    pub enabled: bool,
    /// Maximum number of attempts, the first one included; `0` means unlimited.
    #[serde(default = "default_max_attempts")]
    pub max_attempts: u32,
    /// Base back-off delay, i.e. the delay before the first retry (before jitter).
    #[serde(
        default = "default_initial_delay",
        rename = "initial_delay_ms",
        serialize_with = "serialize_duration_ms",
        deserialize_with = "deserialize_duration_ms"
    )]
    pub initial_delay: Duration,
    /// Exponential growth factor: retry `n` (0-based) waits `initial_delay * multiplier^n`.
    #[serde(default = "default_multiplier")]
    pub multiplier: f64,
    /// Upper bound applied to every back-off delay, jitter included.
    #[serde(
        default = "default_max_delay",
        rename = "max_delay_ms",
        serialize_with = "serialize_duration_ms",
        deserialize_with = "deserialize_duration_ms"
    )]
    pub max_delay: Duration,
    /// Width of the random jitter window as a fraction of the computed delay.
    #[serde(default = "default_jitter_factor")]
    pub jitter_factor: f64,
}
impl Default for NetworkRetryPolicy {
    /// Builds the policy used when nothing is configured: every field takes the value
    /// of its `default_*` helper, the same values serde fills in for missing keys.
    fn default() -> Self {
        let (initial_delay, max_delay) = (default_initial_delay(), default_max_delay());
        Self {
            max_attempts: default_max_attempts(),
            multiplier: default_multiplier(),
            jitter_factor: default_jitter_factor(),
            enabled: default_enabled(),
            initial_delay,
            max_delay,
        }
    }
}
impl NetworkRetryPolicy {
    /// A policy with retries switched off; every other field keeps its default value.
    pub fn disabled() -> Self {
        Self {
            enabled: false,
            ..Self::default()
        }
    }

    /// Computes the back-off delay to wait after the failure of attempt `attempt`
    /// (0-based).
    ///
    /// The base delay is `initial_delay * multiplier^attempt`, capped at `max_delay`.
    /// A uniformly distributed jitter in `[-r/2, r/2)` with `r = base * jitter_factor`
    /// is then added. The result is clamped to `[0, max_delay]` and truncated to whole
    /// milliseconds.
    ///
    /// Never panics. Degenerate jitter parameters, such as a non-finite or vanishingly
    /// small jitter window from configuration, simply disable jitter for that call.
    #[must_use]
    pub fn delay_for(&self, attempt: u32) -> Duration {
        let cap_ms = self.max_delay.as_millis();
        let base_ms = self.initial_delay.as_millis() as f64;
        // `powi` takes an `i32`; saturate rather than let `as` wrap huge attempt numbers
        // into negative exponents, which would collapse the delay towards zero.
        let exponent = i32::try_from(attempt).unwrap_or(i32::MAX);
        let computed_ms = (base_ms * self.multiplier.powi(exponent)).min(cap_ms as f64);
        let jitter_range = computed_ms * self.jitter_factor;
        let jitter = if jitter_range > 0.0 {
            let half = jitter_range / 2.0;
            // `Uniform::new` rejects empty ranges (`half` underflows to 0 for a tiny
            // factor) and non-finite ones (an enormous factor). These values come from
            // configuration, so fall back to no jitter instead of panicking.
            Uniform::new(-half, half).map_or(0.0, |dist| rand::rng().sample(dist))
        } else {
            0.0
        };
        let final_ms = (computed_ms + jitter).max(0.0) as u64;
        Duration::from_millis(final_ms.min(u64::try_from(cap_ms).unwrap_or(u64::MAX)))
    }

    /// Returns whether another attempt may be made once `attempt` attempts have
    /// already been made.
    ///
    /// Always `false` when the policy is disabled. A `max_attempts` of `0` places no
    /// limit on the number of attempts.
    #[must_use]
    pub fn should_retry(&self, attempt: u32) -> bool {
        self.enabled && (self.max_attempts == 0 || attempt < self.max_attempts)
    }
}
pub async fn retry_async<T, Op, Fut, IsRetryable, E>(
policy: &NetworkRetryPolicy,
label: Option<&'static str>,
op: Op,
is_retryable: IsRetryable,
) -> Result<T, E>
where
Op: FnMut() -> Fut,
Fut: Future<Output = Result<T, E>>,
IsRetryable: Fn(&E) -> bool,
E: std::fmt::Display,
{
retry_async_inner(policy, op, is_retryable, None, label).await
}
/// Shared retry loop behind [`retry_async`] and [`retry_async_cancelable`].
///
/// Calls `op` until one of three things happens:
/// - it succeeds;
/// - it returns an error rejected by `is_retryable`;
/// - `policy.should_retry` refuses another attempt.
///
/// Between attempts it logs a warning and sleeps for `policy.delay_for(attempt)`. If
/// `cancel` is provided and fires during that sleep, the most recent error is
/// returned immediately. An in-flight `op` call is not interrupted.
async fn retry_async_inner<T, Op, Fut, IsRetryable, E>(
    policy: &NetworkRetryPolicy,
    mut op: Op,
    is_retryable: IsRetryable,
    cancel: Option<&CancellationToken>,
    label: Option<&'static str>,
) -> Result<T, E>
where
    Op: FnMut() -> Fut,
    Fut: Future<Output = Result<T, E>>,
    IsRetryable: Fn(&E) -> bool,
    E: std::fmt::Display,
{
    // Zero-based index of the attempt currently being made.
    let mut attempt = 0u32;
    loop {
        let err = match op().await {
            Ok(val) => return Ok(val),
            Err(err) => err,
        };
        // `attempt + 1` attempts have been made so far. Saturate so an unlimited policy
        // (`max_attempts == 0`) can never overflow the counter (a panic in debug builds).
        let completed = attempt.saturating_add(1);
        if !is_retryable(&err) || !policy.should_retry(completed) {
            return Err(err);
        }
        let delay = policy.delay_for(attempt);
        if let Some(component) = label {
            tracing::warn!(
                component,
                attempt,
                delay_ms = delay.as_millis(),
                error = %err,
                "{component}: transient error — retrying"
            );
        } else {
            tracing::warn!(
                attempt,
                delay_ms = delay.as_millis(),
                error = %err,
                "transient error — retrying"
            );
        }
        if let Some(token) = cancel {
            // `biased` polls cancellation first, so an already-cancelled token wins
            // over a zero-length sleep.
            tokio::select! {
                biased;
                _ = token.cancelled() => return Err(err),
                _ = sleep(delay) => {}
            }
        } else {
            sleep(delay).await;
        }
        attempt = completed;
    }
}
/// Like [`retry_async`], but the back-off between attempts can be cut short by `cancel`.
///
/// If the token is, or becomes, cancelled while waiting between attempts, the most
/// recent error is returned without making further attempts. An attempt that is
/// already running is not interrupted.
pub async fn retry_async_cancelable<T, Op, Fut, IsRetryable, E>(
    policy: &NetworkRetryPolicy,
    label: Option<&'static str>,
    op: Op,
    is_retryable: IsRetryable,
    cancel: &CancellationToken,
) -> Result<T, E>
where
    Op: FnMut() -> Fut,
    Fut: Future<Output = Result<T, E>>,
    IsRetryable: Fn(&E) -> bool,
    E: std::fmt::Display,
{
    let cancellation = Some(cancel);
    retry_async_inner(policy, op, is_retryable, cancellation, label).await
}
/// Classifies a [`CamelError`] as transient (worth retrying) or permanent.
///
/// `Io` errors are always treated as transient. `ProcessorError` and
/// `ProcessorErrorWithSource` are transient only when their message contains the
/// `[TRANSIENT]` marker. Every other variant is permanent.
pub fn is_retryable_camel_error(err: &CamelError) -> bool {
    const TRANSIENT_MARKER: &str = "[TRANSIENT]";
    match err {
        CamelError::Io(_) => true,
        CamelError::ProcessorError(msg) => msg.contains(TRANSIENT_MARKER),
        CamelError::ProcessorErrorWithSource(msg, _) => msg.contains(TRANSIENT_MARKER),
        _ => false,
    }
}
// Unit tests for this module are loaded from `network_retry_tests.rs` via `#[path]`
// and compiled only for test builds.
#[cfg(test)]
#[path = "network_retry_tests.rs"]
mod tests;