use std::time::{Duration, Instant};
use tracing::{debug, warn};
use crate::error::KrafkaError;
use crate::util::BackoffPolicy;
/// Configuration describing how many times, and how quickly, a failed
/// operation may be retried.
///
/// Values are built with [`RetryPolicy::new`] / [`RetryPolicy::default`] and
/// then adjusted through the `with_*` builder methods.
#[non_exhaustive]
#[derive(Debug, Clone)]
pub struct RetryPolicy {
/// Retry cap: a retry is only permitted while the attempt counter is
/// strictly below this value.
pub(crate) max_retries: u32,
/// Backoff schedule (initial/max delay, multiplier, jitter) that
/// `calculate_backoff` delegates to.
pub(crate) backoff: BackoffPolicy,
/// Optional overall time budget for an operation, measured from the start
/// instant of a `RetryContext`. `None` disables the deadline check.
pub(crate) delivery_timeout: Option<Duration>,
}
impl Default for RetryPolicy {
    /// Builds the baseline policy: three retries, the default
    /// [`BackoffPolicy`], and a 120-second delivery timeout.
    fn default() -> Self {
        let backoff = BackoffPolicy::default();
        let delivery_timeout = Some(Duration::from_secs(120));
        Self {
            max_retries: 3,
            backoff,
            delivery_timeout,
        }
    }
}
impl RetryPolicy {
pub fn new() -> Self {
Self::default()
}
pub fn no_retries() -> Self {
Self {
max_retries: 0,
..Self::default()
}
}
pub fn max_retries(&self) -> u32 {
self.max_retries
}
pub fn initial_backoff(&self) -> Duration {
self.backoff.initial_backoff
}
pub fn max_backoff(&self) -> Duration {
self.backoff.max_backoff
}
pub fn backoff_multiplier(&self) -> f64 {
self.backoff.backoff_multiplier
}
pub fn jitter_factor(&self) -> f64 {
self.backoff.jitter_factor
}
pub fn delivery_timeout(&self) -> Option<Duration> {
self.delivery_timeout
}
pub fn with_max_retries(mut self, max_retries: u32) -> Self {
self.max_retries = max_retries;
self
}
pub fn with_initial_backoff(mut self, duration: Duration) -> Self {
self.backoff.initial_backoff = duration;
self
}
pub fn with_max_backoff(mut self, duration: Duration) -> Self {
self.backoff.max_backoff = duration;
self
}
pub fn with_backoff_multiplier(mut self, multiplier: f64) -> Self {
self.backoff.backoff_multiplier = if multiplier.is_finite() {
multiplier.max(1.0)
} else {
warn!("backoff_multiplier is not finite ({multiplier}); using default 2.0");
2.0
};
self
}
pub fn with_jitter_factor(mut self, factor: f64) -> Self {
self.backoff.jitter_factor = factor.clamp(0.0, 1.0);
self
}
pub fn with_delivery_timeout(mut self, timeout: Option<Duration>) -> Self {
self.delivery_timeout = timeout;
self
}
#[inline]
pub fn calculate_backoff(&self, attempt: u32) -> Duration {
self.backoff.calculate_backoff(attempt)
}
#[inline]
pub fn should_retry(&self, error: &KrafkaError, attempt: u32) -> bool {
attempt < self.max_retries && error.is_retriable()
}
#[inline]
pub fn max_retries_reached(&self, attempt: u32) -> bool {
attempt >= self.max_retries
}
}
/// Mutable bookkeeping for one retried operation: the policy in force, the
/// number of retries scheduled so far, and when the operation started.
#[derive(Debug)]
pub struct RetryContext {
/// Policy consulted on every recorded failure.
policy: RetryPolicy,
/// Number of retries scheduled so far; starts at 0 and is incremented each
/// time `record_failure` grants a retry.
attempt: u32,
/// Operation name attached to log events.
operation: String,
/// Start instant against which the policy's delivery timeout is measured.
started_at: Instant,
}
impl RetryContext {
    /// Starts tracking `operation` under `policy`, with the clock starting now.
    pub fn new(policy: RetryPolicy, operation: impl Into<String>) -> Self {
        let operation = operation.into();
        Self::new_with_start(policy, operation, Instant::now())
    }

    /// Starts tracking `operation` under `policy`, measuring the delivery
    /// timeout from the caller-supplied `started_at` instant.
    pub fn new_with_start(
        policy: RetryPolicy,
        operation: impl Into<String>,
        started_at: Instant,
    ) -> Self {
        let operation = operation.into();
        Self {
            policy,
            attempt: 0,
            operation,
            started_at,
        }
    }

    /// Number of retries scheduled so far.
    #[inline]
    pub fn attempt(&self) -> u32 {
        self.attempt
    }

    /// Name of the operation being tracked.
    #[inline]
    pub fn operation(&self) -> &str {
        self.operation.as_str()
    }

    /// Records a failed attempt and decides whether to try again.
    ///
    /// Returns `Some(delay)` with the time to wait before the next attempt, or
    /// `None` when the caller should give up. Giving up happens when the
    /// delivery timeout has elapsed, the error is not retriable, or the retry
    /// cap has been reached. The attempt counter is only advanced when a
    /// retry is granted, and the returned delay never exceeds the time left
    /// in the delivery budget.
    pub fn record_failure(&mut self, error: &KrafkaError) -> Option<Duration> {
        let elapsed = self.started_at.elapsed();
        let budget = self.policy.delivery_timeout;

        // The overall deadline is checked first, before the per-error checks.
        if budget.is_some_and(|deadline| elapsed >= deadline) {
            warn!(
                operation = %self.operation,
                attempt = self.attempt,
                elapsed_ms = elapsed.as_millis(),
                error = %error,
                "Delivery timeout exceeded, giving up"
            );
            return None;
        }

        if !self.policy.should_retry(error, self.attempt) {
            // A retriable error that was still refused means the cap was hit.
            if error.is_retriable() {
                warn!(
                    operation = %self.operation,
                    attempt = self.attempt,
                    max_retries = self.policy.max_retries,
                    error = %error,
                    "Max retries reached, giving up"
                );
            } else {
                debug!(
                    operation = %self.operation,
                    error = %error,
                    "Non-retriable error, not retrying"
                );
            }
            return None;
        }

        self.attempt += 1;
        let scheduled = self.policy.calculate_backoff(self.attempt);
        // Never sleep past the remaining delivery budget.
        let backoff = match budget {
            Some(deadline) => scheduled.min(deadline.saturating_sub(elapsed)),
            None => scheduled,
        };

        debug!(
            operation = %self.operation,
            attempt = self.attempt,
            max_retries = self.policy.max_retries,
            backoff_ms = backoff.as_millis(),
            error = %error,
            "Retrying after failure"
        );
        Some(backoff)
    }

    /// Logs a success event, but only if at least one retry was needed.
    pub fn record_success(&self) {
        if self.attempt == 0 {
            return;
        }
        debug!(
            operation = %self.operation,
            attempt = self.attempt,
            "Succeeded after retries"
        );
    }

    /// Sleeps asynchronously for `backoff`; a zero duration returns immediately.
    pub async fn wait(&self, backoff: Duration) {
        if backoff.is_zero() {
            return;
        }
        tokio::time::sleep(backoff).await;
    }
}
#[cfg(test)]
// Tests may unwrap/expect/panic freely: a panic is simply a test failure.
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)]
mod tests {
    use super::*;

    /// Returns an instant `age` in the past, or `None` if the platform's
    /// monotonic clock cannot represent it.
    ///
    /// `Instant - Duration` panics when the result is not representable,
    /// which can happen where the clock starts near zero (for example on a
    /// freshly booted machine). `checked_sub` lets each test decide how to
    /// react instead of panicking with an opaque overflow message.
    fn instant_ago(age: Duration) -> Option<Instant> {
        Instant::now().checked_sub(age)
    }

    /// The default policy exposes the expected retry cap and backoff schedule.
    #[test]
    fn test_retry_policy_default() {
        let policy = RetryPolicy::default();
        assert_eq!(policy.max_retries, 3);
        assert_eq!(policy.initial_backoff(), Duration::from_millis(100));
        assert_eq!(policy.max_backoff(), Duration::from_secs(10));
        assert_eq!(policy.backoff_multiplier(), 2.0);
    }

    /// `no_retries` sets the retry cap to zero.
    #[test]
    fn test_retry_policy_no_retries() {
        let policy = RetryPolicy::no_retries();
        assert_eq!(policy.max_retries, 0);
    }

    /// Every builder method stores the value it was given.
    #[test]
    fn test_retry_policy_builder() {
        let policy = RetryPolicy::new()
            .with_max_retries(5)
            .with_initial_backoff(Duration::from_millis(50))
            .with_max_backoff(Duration::from_secs(5))
            .with_backoff_multiplier(3.0)
            .with_jitter_factor(0.2);
        assert_eq!(policy.max_retries, 5);
        assert_eq!(policy.initial_backoff(), Duration::from_millis(50));
        assert_eq!(policy.max_backoff(), Duration::from_secs(5));
        assert_eq!(policy.backoff_multiplier(), 3.0);
        assert_eq!(policy.jitter_factor(), 0.2);
    }

    /// Without jitter the delay doubles per attempt, starting from zero at attempt 0.
    #[test]
    fn test_calculate_backoff_exponential() {
        let policy = RetryPolicy::new()
            .with_initial_backoff(Duration::from_millis(100))
            .with_backoff_multiplier(2.0)
            .with_jitter_factor(0.0);
        assert_eq!(policy.calculate_backoff(0), Duration::ZERO);
        assert_eq!(policy.calculate_backoff(1), Duration::from_millis(100));
        assert_eq!(policy.calculate_backoff(2), Duration::from_millis(200));
        assert_eq!(policy.calculate_backoff(3), Duration::from_millis(400));
    }

    /// The computed delay never exceeds `max_backoff`.
    #[test]
    fn test_calculate_backoff_capped() {
        let policy = RetryPolicy::new()
            .with_initial_backoff(Duration::from_secs(1))
            .with_max_backoff(Duration::from_secs(5))
            .with_backoff_multiplier(10.0)
            .with_jitter_factor(0.0);
        assert_eq!(policy.calculate_backoff(2), Duration::from_secs(5));
    }

    /// An extreme attempt number saturates at `max_backoff` instead of overflowing.
    #[test]
    fn test_calculate_backoff_handles_max_attempt() {
        let policy = RetryPolicy::new()
            .with_initial_backoff(Duration::from_millis(100))
            .with_max_backoff(Duration::from_secs(10))
            .with_jitter_factor(0.0);
        assert_eq!(policy.calculate_backoff(u32::MAX), Duration::from_secs(10));
    }

    /// Retries are allowed only below the cap and only for retriable errors.
    #[test]
    fn test_should_retry() {
        let policy = RetryPolicy::new().with_max_retries(3);
        let retriable_error = KrafkaError::timeout("test");
        assert!(policy.should_retry(&retriable_error, 0));
        assert!(policy.should_retry(&retriable_error, 1));
        assert!(policy.should_retry(&retriable_error, 2));
        assert!(!policy.should_retry(&retriable_error, 3));
        assert!(!policy.should_retry(&retriable_error, 4));
        let non_retriable = KrafkaError::config("test");
        assert!(!policy.should_retry(&non_retriable, 0));
    }

    /// The context grants exactly `max_retries` retries, then gives up
    /// without advancing the attempt counter any further.
    #[test]
    fn test_retry_context() {
        let policy = RetryPolicy::new().with_max_retries(3);
        let mut ctx = RetryContext::new(policy, "test_operation");
        assert_eq!(ctx.attempt(), 0);
        assert_eq!(ctx.operation(), "test_operation");
        let error = KrafkaError::timeout("test");
        let backoff = ctx.record_failure(&error);
        assert!(backoff.is_some());
        assert_eq!(ctx.attempt(), 1);
        let backoff = ctx.record_failure(&error);
        assert!(backoff.is_some());
        assert_eq!(ctx.attempt(), 2);
        let backoff = ctx.record_failure(&error);
        assert!(backoff.is_some());
        assert_eq!(ctx.attempt(), 3);
        let backoff = ctx.record_failure(&error);
        assert!(backoff.is_none());
        assert_eq!(ctx.attempt(), 3);
    }

    /// A non-retriable error ends the retry loop immediately.
    #[test]
    fn test_retry_context_non_retriable() {
        let policy = RetryPolicy::new().with_max_retries(5);
        let mut ctx = RetryContext::new(policy, "test");
        let error = KrafkaError::config("invalid config");
        let backoff = ctx.record_failure(&error);
        assert!(backoff.is_none());
    }

    /// Out-of-range jitter factors are clamped into `0.0..=1.0`.
    #[test]
    fn test_jitter_factor_clamped() {
        let policy = RetryPolicy::new().with_jitter_factor(2.0);
        assert_eq!(policy.jitter_factor(), 1.0);
        let policy = RetryPolicy::new().with_jitter_factor(-0.5);
        assert_eq!(policy.jitter_factor(), 0.0);
    }

    /// Even with maximum jitter the delay stays at or above `initial_backoff`.
    #[test]
    fn test_calculate_backoff_never_below_initial_backoff() {
        let policy = RetryPolicy::new()
            .with_initial_backoff(Duration::from_millis(100))
            .with_backoff_multiplier(2.0)
            .with_jitter_factor(1.0);
        let floor = policy.initial_backoff();
        for attempt in 1..=5 {
            let backoff = policy.calculate_backoff(attempt);
            assert!(
                backoff >= floor,
                "attempt {attempt}: backoff {backoff:?} fell below initial_backoff {floor:?}"
            );
        }
    }

    /// With jitter enabled, repeated calls for the same attempt do not all agree.
    #[test]
    fn test_calculate_backoff_jitter_produces_varying_results() {
        let policy = RetryPolicy::new()
            .with_initial_backoff(Duration::from_millis(100))
            .with_backoff_multiplier(2.0)
            .with_jitter_factor(0.5);
        let backoffs: Vec<Duration> = (0..50).map(|_| policy.calculate_backoff(2)).collect();
        let unique_count = {
            let mut unique: Vec<u128> = backoffs.iter().map(|d| d.as_nanos()).collect();
            unique.sort();
            unique.dedup();
            unique.len()
        };
        assert!(
            unique_count > 1,
            "with jitter_factor > 0, calculate_backoff should produce varying results, but got {} unique values",
            unique_count
        );
    }

    /// Once the delivery timeout has elapsed, the context gives up even
    /// though retries remain, and the attempt counter is left untouched.
    #[test]
    fn test_delivery_timeout_gives_up_when_exceeded() {
        let policy = RetryPolicy::new()
            .with_max_retries(10)
            .with_delivery_timeout(Some(Duration::from_millis(50)));
        let started_at = instant_ago(Duration::from_millis(100))
            .expect("monotonic clock should be able to represent 100ms ago");
        let mut ctx = RetryContext::new_with_start(policy, "test_timeout", started_at);
        let error = KrafkaError::timeout("test");
        let result = ctx.record_failure(&error);
        assert!(
            result.is_none(),
            "should give up when delivery timeout exceeded"
        );
        assert_eq!(ctx.attempt(), 0);
    }

    /// The returned delay is shortened to fit the remaining delivery budget.
    #[test]
    fn test_delivery_timeout_clamps_backoff_to_remaining_budget() {
        let policy = RetryPolicy::new()
            .with_max_retries(5)
            .with_initial_backoff(Duration::from_secs(10))
            .with_jitter_factor(0.0)
            .with_delivery_timeout(Some(Duration::from_secs(1)));
        let started_at = instant_ago(Duration::from_millis(900))
            .expect("monotonic clock should be able to represent 900ms ago");
        let mut ctx = RetryContext::new_with_start(policy, "test_clamp", started_at);
        let error = KrafkaError::timeout("test");
        let backoff = ctx
            .record_failure(&error)
            .expect("should still retry within budget");
        assert!(
            backoff <= Duration::from_millis(150),
            "backoff ({backoff:?}) should be clamped to remaining delivery budget"
        );
    }

    /// With the delivery timeout disabled, elapsed time alone never stops retries.
    #[test]
    fn test_delivery_timeout_disabled_does_not_limit_retries() {
        let policy = RetryPolicy::new()
            .with_max_retries(2)
            .with_delivery_timeout(None);
        // An hour ago may not be representable on a machine that booted
        // recently. With no deadline configured any start instant is valid,
        // so fall back to "now" rather than panicking.
        let started_at =
            instant_ago(Duration::from_secs(3600)).unwrap_or_else(Instant::now);
        let mut ctx = RetryContext::new_with_start(policy, "test_no_timeout", started_at);
        let error = KrafkaError::timeout("test");
        let backoff = ctx.record_failure(&error);
        assert!(
            backoff.is_some(),
            "should retry when delivery_timeout is None"
        );
    }
}