use std::sync::Arc;
use std::time::Duration;
use crate::CamelError;
/// Header name used to flag an exchange as a redelivery.
pub const HEADER_REDELIVERED: &str = "CamelRedelivered";
/// Header name carrying the current redelivery attempt count.
pub const HEADER_REDELIVERY_COUNTER: &str = "CamelRedeliveryCounter";
/// Header name carrying the maximum number of redelivery attempts.
pub const HEADER_REDELIVERY_MAX_COUNTER: &str = "CamelRedeliveryMaxCounter";
/// Retry/backoff configuration for redelivering after a failure.
///
/// The delay before retry `attempt` (zero-based) is
/// `initial_delay * multiplier^attempt`, capped at `max_delay`, and then
/// optionally randomised by up to `±jitter_factor` of that capped value.
#[derive(Debug, Clone)]
pub struct RedeliveryPolicy {
    /// Maximum number of redelivery attempts (`0` means no retries).
    pub max_attempts: u32,
    /// Delay before the first redelivery (`attempt == 0`).
    pub initial_delay: Duration,
    /// Factor the delay is multiplied by for each subsequent attempt.
    pub multiplier: f64,
    /// Upper bound applied to the exponential delay before jitter is added.
    pub max_delay: Duration,
    /// Fraction of the delay used as symmetric random jitter; `0.0` disables
    /// jitter. [`RedeliveryPolicy::with_jitter`] keeps it within `[0.0, 1.0]`.
    pub jitter_factor: f64,
}

impl RedeliveryPolicy {
    /// Creates a policy allowing `max_attempts` redeliveries, with a 100 ms
    /// initial delay, a ×2.0 multiplier, a 10 s cap and no jitter.
    #[must_use]
    pub fn new(max_attempts: u32) -> Self {
        Self {
            max_attempts,
            initial_delay: Duration::from_millis(100),
            multiplier: 2.0,
            max_delay: Duration::from_secs(10),
            jitter_factor: 0.0,
        }
    }

    /// Sets the delay used before the first redelivery.
    #[must_use]
    pub fn with_initial_delay(mut self, d: Duration) -> Self {
        self.initial_delay = d;
        self
    }

    /// Sets the per-attempt growth factor of the delay.
    #[must_use]
    pub fn with_multiplier(mut self, m: f64) -> Self {
        self.multiplier = m;
        self
    }

    /// Sets the cap applied to the exponential delay (before jitter).
    #[must_use]
    pub fn with_max_delay(mut self, d: Duration) -> Self {
        self.max_delay = d;
        self
    }

    /// Sets the jitter factor, clamped to `[0.0, 1.0]`; `NaN` is treated as
    /// `0.0` (no jitter).
    #[must_use]
    pub fn with_jitter(mut self, j: f64) -> Self {
        // `f64::clamp` passes NaN through unchanged; normalise it so the stored
        // factor is always a valid fraction.
        self.jitter_factor = if j.is_nan() { 0.0 } else { j.clamp(0.0, 1.0) };
        self
    }

    /// Returns how long to wait before redelivery number `attempt` (zero-based).
    pub fn delay_for(&self, attempt: u32) -> Duration {
        // `powi` takes an `i32`; saturate rather than letting `attempt as i32`
        // wrap to a negative exponent, which would produce a tiny delay.
        let exponent = i32::try_from(attempt).unwrap_or(i32::MAX);
        // Compute in nanoseconds so sub-millisecond delays are not truncated to zero.
        let base_ns = self.initial_delay.as_nanos() as f64 * self.multiplier.powi(exponent);
        let capped_ns = base_ns.min(self.max_delay.as_nanos() as f64);
        let delay_ns = if self.jitter_factor > 0.0 {
            // Uniform offset in [-jitter_factor, +jitter_factor) of the capped delay.
            let jitter = capped_ns * self.jitter_factor * (rand::random::<f64>() * 2.0 - 1.0);
            (capped_ns + jitter).max(0.0)
        } else {
            capped_ns
        };
        // `as u64` saturates on overflow and maps NaN/negatives to 0, so this cannot panic.
        Duration::from_nanos(delay_ns as u64)
    }
}
/// Associates an error predicate with the recovery settings for matching errors.
#[derive(Clone)]
pub struct ExceptionPolicy {
    /// Predicate deciding whether this policy applies to a given error.
    ///
    /// Held in an `Arc` so cloning the policy shares the closure instead of
    /// copying it.
    pub matches: Arc<dyn Fn(&CamelError) -> bool + Send + Sync>,
    /// Redelivery settings for matching errors; `None` means no retry is configured.
    pub retry: Option<RedeliveryPolicy>,
    /// Optional endpoint URI associated with matching errors (e.g. `"log:io-errors"`);
    /// presumably where they are routed — confirm against the error-handler runtime.
    pub handled_by: Option<String>,
}

impl ExceptionPolicy {
    /// Creates a policy for errors accepted by `matches`, with no retry and no
    /// handler URI configured.
    pub fn new(matches: impl Fn(&CamelError) -> bool + Send + Sync + 'static) -> Self {
        Self {
            matches: Arc::new(matches),
            retry: None,
            handled_by: None,
        }
    }
}
/// Error-handling configuration: an optional dead-letter channel plus the
/// exception policies registered via [`ErrorHandlerConfig::on_exception`].
#[derive(Clone)]
pub struct ErrorHandlerConfig {
    /// URI of the dead-letter channel, or `None` for the log-only handler.
    pub dlc_uri: Option<String>,
    /// Exception policies, in the order they were built.
    pub policies: Vec<ExceptionPolicy>,
}

impl ErrorHandlerConfig {
    // Common constructor for the public factories: the given DLC, no policies yet.
    fn with_dlc_uri(dlc_uri: Option<String>) -> Self {
        Self {
            policies: Vec::new(),
            dlc_uri,
        }
    }

    /// Configuration with no dead-letter channel and no exception policies.
    pub fn log_only() -> Self {
        Self::with_dlc_uri(None)
    }

    /// Configuration whose dead-letter channel is the endpoint at `uri`.
    pub fn dead_letter_channel(uri: impl Into<String>) -> Self {
        let uri: String = uri.into();
        Self::with_dlc_uri(Some(uri))
    }

    /// Starts building an exception policy for errors accepted by `matches`.
    ///
    /// Consumes the configuration; `ExceptionPolicyBuilder::build` hands it
    /// back with the new policy appended.
    pub fn on_exception(
        self,
        matches: impl Fn(&CamelError) -> bool + Send + Sync + 'static,
    ) -> ExceptionPolicyBuilder {
        let policy = ExceptionPolicy::new(matches);
        ExceptionPolicyBuilder {
            policy,
            config: self,
        }
    }
}
/// Fluent builder, returned by [`ErrorHandlerConfig::on_exception`], that
/// configures one [`ExceptionPolicy`] and appends it to the configuration when
/// [`build`](Self::build) is called.
///
/// The builder owns the whole configuration, so dropping it without calling
/// `build` discards everything configured so far — hence `#[must_use]`.
#[must_use = "call `.build()` to get the `ErrorHandlerConfig` back"]
pub struct ExceptionPolicyBuilder {
    config: ErrorHandlerConfig,
    policy: ExceptionPolicy,
}

impl ExceptionPolicyBuilder {
    /// Enables redelivery with up to `max_attempts` attempts and default backoff.
    ///
    /// Replaces any previously configured retry settings, so call it before
    /// [`with_backoff`](Self::with_backoff) and [`with_jitter`](Self::with_jitter).
    pub fn retry(mut self, max_attempts: u32) -> Self {
        self.policy.retry = Some(RedeliveryPolicy::new(max_attempts));
        self
    }

    /// Sets the exponential backoff parameters of the retry settings.
    ///
    /// Has no effect unless [`retry`](Self::retry) was called first.
    pub fn with_backoff(mut self, initial: Duration, multiplier: f64, max: Duration) -> Self {
        // Reuse RedeliveryPolicy's own setters so both APIs stay in sync.
        self.policy.retry = self.policy.retry.take().map(|p| {
            p.with_initial_delay(initial)
                .with_multiplier(multiplier)
                .with_max_delay(max)
        });
        self
    }

    /// Sets the jitter factor of the retry settings, clamped to `[0.0, 1.0]`
    /// by [`RedeliveryPolicy::with_jitter`].
    ///
    /// Has no effect unless [`retry`](Self::retry) was called first.
    pub fn with_jitter(mut self, jitter_factor: f64) -> Self {
        self.policy.retry = self
            .policy
            .retry
            .take()
            .map(|p| p.with_jitter(jitter_factor));
        self
    }

    /// Sets the endpoint URI associated with errors matching this policy.
    pub fn handled_by(mut self, uri: impl Into<String>) -> Self {
        self.policy.handled_by = Some(uri.into());
        self
    }

    /// Appends the configured policy to the configuration and returns it.
    pub fn build(self) -> ErrorHandlerConfig {
        let Self { mut config, policy } = self;
        config.policies.push(policy);
        config
    }
}
/// Deprecated alias of [`RedeliveryPolicy`], kept so existing code still compiles.
#[deprecated(since = "0.1.0", note = "Use `RedeliveryPolicy` instead")]
pub type ExponentialBackoff = RedeliveryPolicy;
#[cfg(test)]
mod tests {
    use super::*;
    use crate::CamelError;
    use std::time::Duration;

    #[test]
    fn test_redelivery_policy_defaults() {
        let p = RedeliveryPolicy::new(3);
        assert_eq!(p.max_attempts, 3);
        assert_eq!(p.initial_delay, Duration::from_millis(100));
        assert_eq!(p.multiplier, 2.0);
        assert_eq!(p.max_delay, Duration::from_secs(10));
        assert_eq!(p.jitter_factor, 0.0);
    }

    #[test]
    fn test_exception_policy_matches() {
        let policy = ExceptionPolicy::new(|e| matches!(e, CamelError::ProcessorError(_)));
        assert!((policy.matches)(&CamelError::ProcessorError("oops".into())));
        assert!(!(policy.matches)(&CamelError::Io("io".into())));
    }

    #[test]
    fn test_error_handler_config_log_only() {
        let config = ErrorHandlerConfig::log_only();
        assert!(config.dlc_uri.is_none());
        assert!(config.policies.is_empty());
    }

    #[test]
    fn test_error_handler_config_dlc() {
        let config = ErrorHandlerConfig::dead_letter_channel("log:dlc");
        assert_eq!(config.dlc_uri.as_deref(), Some("log:dlc"));
    }

    #[test]
    fn test_error_handler_config_with_policy() {
        let config = ErrorHandlerConfig::dead_letter_channel("log:dlc")
            .on_exception(|e| matches!(e, CamelError::Io(_)))
            .retry(2)
            .handled_by("log:io-errors")
            .build();
        assert_eq!(config.policies.len(), 1);
        let p = &config.policies[0];
        assert!(p.retry.is_some());
        assert_eq!(p.retry.as_ref().unwrap().max_attempts, 2);
        assert_eq!(p.handled_by.as_deref(), Some("log:io-errors"));
    }

    #[test]
    fn test_builder_applies_backoff_and_jitter_to_retry() {
        let config = ErrorHandlerConfig::log_only()
            .on_exception(|e| matches!(e, CamelError::Io(_)))
            .retry(3)
            .with_backoff(Duration::from_millis(50), 3.0, Duration::from_secs(2))
            .with_jitter(0.25)
            .build();
        let retry = config.policies[0]
            .retry
            .as_ref()
            .expect("retry was configured");
        assert_eq!(retry.max_attempts, 3);
        assert_eq!(retry.initial_delay, Duration::from_millis(50));
        assert_eq!(retry.multiplier, 3.0);
        assert_eq!(retry.max_delay, Duration::from_secs(2));
        assert_eq!(retry.jitter_factor, 0.25);
    }

    #[test]
    fn test_builder_backoff_without_retry_is_noop() {
        // Backoff/jitter only tune an existing retry policy; they never create one.
        let config = ErrorHandlerConfig::log_only()
            .on_exception(|_| true)
            .with_backoff(Duration::from_millis(50), 3.0, Duration::from_secs(2))
            .with_jitter(0.5)
            .build();
        assert!(config.policies[0].retry.is_none());
    }

    #[test]
    fn test_jitter_is_clamped_to_unit_interval() {
        assert_eq!(RedeliveryPolicy::new(1).with_jitter(2.0).jitter_factor, 1.0);
        assert_eq!(RedeliveryPolicy::new(1).with_jitter(-0.5).jitter_factor, 0.0);
    }

    #[test]
    fn test_delay_grows_exponentially_and_is_capped() {
        let policy = RedeliveryPolicy::new(5)
            .with_initial_delay(Duration::from_millis(100))
            .with_multiplier(2.0)
            .with_max_delay(Duration::from_millis(500));
        assert_eq!(policy.delay_for(0), Duration::from_millis(100));
        assert_eq!(policy.delay_for(1), Duration::from_millis(200));
        assert_eq!(policy.delay_for(2), Duration::from_millis(400));
        assert_eq!(policy.delay_for(3), Duration::from_millis(500));
        assert_eq!(policy.delay_for(10), Duration::from_millis(500));
    }

    #[test]
    fn test_jitter_applies_randomness() {
        let policy = RedeliveryPolicy::new(3)
            .with_initial_delay(Duration::from_millis(100))
            .with_jitter(0.5);
        let mut delays = std::collections::HashSet::new();
        for _ in 0..10 {
            delays.insert(policy.delay_for(0));
        }
        assert!(delays.len() > 1, "jitter should produce varying delays");
    }

    #[test]
    fn test_jitter_stays_within_bounds() {
        let policy = RedeliveryPolicy::new(3)
            .with_initial_delay(Duration::from_millis(100))
            .with_jitter(0.5);
        for _ in 0..100 {
            let delay = policy.delay_for(0);
            assert!(
                delay >= Duration::from_millis(50),
                "delay too low: {:?}",
                delay
            );
            assert!(
                delay <= Duration::from_millis(150),
                "delay too high: {:?}",
                delay
            );
        }
    }

    #[test]
    fn test_max_attempts_zero_means_no_retries() {
        let policy = RedeliveryPolicy::new(0);
        assert_eq!(policy.max_attempts, 0);
    }

    #[test]
    fn test_jitter_zero_produces_exact_delay() {
        let policy = RedeliveryPolicy::new(3)
            .with_initial_delay(Duration::from_millis(100))
            .with_jitter(0.0);
        for _ in 0..10 {
            let delay = policy.delay_for(0);
            assert_eq!(delay, Duration::from_millis(100));
        }
    }

    #[test]
    fn test_jitter_one_produces_wide_range() {
        let policy = RedeliveryPolicy::new(3)
            .with_initial_delay(Duration::from_millis(100))
            .with_jitter(1.0);
        let base = Duration::from_millis(100);
        let (mut saw_below, mut saw_above) = (false, false);
        for _ in 0..100 {
            let delay = policy.delay_for(0);
            // A Duration can never be negative, so only the upper bound needs checking.
            assert!(
                delay <= Duration::from_millis(200),
                "delay should be <= 200ms, got {:?}",
                delay
            );
            saw_below |= delay < base;
            saw_above |= delay > base;
        }
        assert!(
            saw_below && saw_above,
            "full jitter should spread delays on both sides of the base delay"
        );
    }
}