use std::time::Duration;
use sea_orm::{ConnectionTrait, DbBackend, Statement, Value};
use uuid::Uuid;
/// Strategy used to space out re-runs of a failed job.
#[derive(Debug, Clone)]
pub enum RetryPolicy {
    /// Backoff computed by `exponential_delay`: nothing for attempt 1, then
    /// `base_delay * 4^(attempt - 2)` (capped at one week) plus per-job jitter.
    Exponential {
        /// Retry limit carried with the policy.
        max_retries: i32,
        /// Delay used for attempt 2; later attempts scale it by powers of four.
        base_delay: Duration,
    },
    /// Nothing for attempt 1, then the same `delay` for every later attempt.
    Fixed {
        /// Retry limit carried with the policy.
        max_retries: i32,
        /// Constant wait applied from attempt 2 onward.
        delay: Duration,
    },
    /// No backoff at all: every attempt gets a zero delay.
    None,
}
impl RetryPolicy {
    /// Builds a [`RetryPolicy::Exponential`] policy.
    pub fn exponential(max_retries: i32, base_delay: Duration) -> Self {
        Self::Exponential { max_retries, base_delay }
    }

    /// Builds a [`RetryPolicy::Fixed`] policy.
    pub fn fixed(max_retries: i32, delay: Duration) -> Self {
        Self::Fixed { max_retries, delay }
    }

    /// Builds a [`RetryPolicy::None`] policy.
    pub fn none() -> Self {
        Self::None
    }

    /// Returns how long to wait for the given attempt count.
    ///
    /// An `attempts` value of 1 or less always yields [`Duration::ZERO`]. After that,
    /// `Fixed` returns its `delay` and `Exponential` delegates to `exponential_delay`,
    /// which also mixes in jitter keyed on `job_id`. `None` is always zero.
    pub(crate) fn backoff_delay(&self, attempts: i32, job_id: Uuid) -> Duration {
        match *self {
            Self::None => Duration::ZERO,
            Self::Fixed { .. } if attempts <= 1 => Duration::ZERO,
            Self::Fixed { delay, .. } => delay,
            Self::Exponential { base_delay, .. } => {
                exponential_delay(base_delay, attempts, job_id)
            }
        }
    }
}
/// Returns a deterministic offset in `[0, base)` derived from `job_id`.
///
/// The same job always gets the same offset, and different ids generally get different
/// ones. A zero `base` yields zero. A `base` longer than `u64::MAX` nanoseconds is treated
/// as exactly `u64::MAX` nanoseconds.
fn jitter(base: Duration, job_id: Uuid) -> Duration {
    // Clamp to u64 so the result is guaranteed to fit `Duration::from_nanos`.
    let span = u64::try_from(base.as_nanos()).unwrap_or(u64::MAX);
    if span == 0 {
        return Duration::ZERO;
    }
    let offset = job_id.as_u128() % u128::from(span);
    // `offset < span <= u64::MAX`, so this narrowing cannot truncate.
    Duration::from_nanos(offset as u64)
}
/// Computes the exponential backoff for `attempts`, plus deterministic per-job jitter.
///
/// Attempt 1 (or anything lower) waits nothing. Attempt `n >= 2` waits
/// `base * 4^(n - 2)`, capped at one week. The cap applies before the jitter, so
/// jitter in `[0, base)` from `jitter(base, job_id)` is added on top and the total can
/// exceed one week by less than `base`. A zero `base` always yields zero.
fn exponential_delay(base: Duration, attempts: i32, job_id: Uuid) -> Duration {
    /// Upper bound, in seconds, on the exponential component (one week).
    const MAX_BACKOFF_SECS: f64 = 7.0 * 24.0 * 3600.0;

    // The zero-base guard is required, not cosmetic. For large `attempts` the multiplier
    // overflows to +inf and `0.0 * inf` is NaN. `f64::min` then returns the non-NaN
    // operand (the cap), which would turn a zero base into a one-week delay.
    if attempts <= 1 || base.is_zero() {
        return Duration::ZERO;
    }
    let multiplier = 4.0_f64.powi(attempts - 2);
    // With base > 0 the product is finite or +inf (never NaN), so `min` caps it safely
    // and `from_secs_f64` cannot panic.
    let secs = (base.as_secs_f64() * multiplier).min(MAX_BACKOFF_SECS);
    Duration::from_secs_f64(secs) + jitter(base, job_id)
}
pub(crate) async fn apply_failure(
db: &impl ConnectionTrait,
job_id: Uuid,
error: &str,
attempts: i32,
max_retries: i32,
policy: &RetryPolicy,
) -> Result<(), sea_orm::DbErr> {
let new_attempts = attempts + 1;
if new_attempts < max_retries {
let delay_secs = policy.backoff_delay(new_attempts, job_id).as_secs_f64();
db.execute(Statement::from_sql_and_values(
DbBackend::Postgres,
r#"UPDATE rapina_jobs
SET attempts = attempts + 1,
last_error = $1,
status = 'pending',
run_at = NOW() + make_interval(secs => $2),
locked_until = NULL,
started_at = NULL
WHERE id = $3::uuid"#,
[
Value::String(Some(Box::new(error.to_owned()))),
Value::Double(Some(delay_secs)),
Value::String(Some(Box::new(job_id.to_string()))),
],
))
.await?;
} else {
db.execute(Statement::from_sql_and_values(
DbBackend::Postgres,
r#"UPDATE rapina_jobs
SET attempts = attempts + 1,
last_error = $1,
status = 'failed',
finished_at = NOW()
WHERE id = $2::uuid"#,
[
Value::String(Some(Box::new(error.to_owned()))),
Value::String(Some(Box::new(job_id.to_string()))),
],
))
.await?;
}
Ok(())
}
/// Marks job `job_id` as `'completed'`, stamping `finished_at` and clearing `locked_until`.
///
/// # Errors
/// Propagates any [`sea_orm::DbErr`] raised while executing the update.
pub(crate) async fn apply_success(
    db: &impl ConnectionTrait,
    job_id: Uuid,
) -> Result<(), sea_orm::DbErr> {
    const COMPLETE_SQL: &str = "UPDATE rapina_jobs\n\
        SET status = 'completed',\n\
        finished_at = NOW(),\n\
        locked_until = NULL\n\
        WHERE id = $1::uuid";

    let id_param = Value::String(Some(Box::new(job_id.to_string())));
    let stmt = Statement::from_sql_and_values(DbBackend::Postgres, COMPLETE_SQL, vec![id_param]);
    db.execute(stmt).await.map(|_| ())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixed job id so the id-derived jitter is reproducible across runs.
    const JOB_ID: Uuid = Uuid::from_u128(0xdeadbeef_cafe_babe_1234_56789abcdef0);

    /// Asserts that `exponential_delay` with a 1s base lands in `[lo, hi)` seconds.
    fn assert_exponential_window(attempt: i32, lo: u32, hi: u32) {
        let base = Duration::from_secs(1);
        let delay = exponential_delay(base, attempt, JOB_ID);
        assert!(delay >= base * lo, "attempt {attempt}: {delay:?} is below {lo}s");
        assert!(delay < base * hi, "attempt {attempt}: {delay:?} is not below {hi}s");
    }

    #[test]
    fn exponential_attempt_1_is_immediate() {
        let delay = exponential_delay(Duration::from_secs(1), 1, JOB_ID);
        assert_eq!(delay, Duration::ZERO);
    }

    #[test]
    fn exponential_attempt_2_equals_base() {
        assert_exponential_window(2, 1, 2);
    }

    #[test]
    fn exponential_attempt_3_is_4x_base() {
        assert_exponential_window(3, 4, 5);
    }

    #[test]
    fn exponential_attempt_4_is_16x_base() {
        assert_exponential_window(4, 16, 17);
    }

    #[test]
    fn exponential_caps_at_one_week() {
        let base = Duration::from_secs(1);
        let one_week = Duration::from_secs(7 * 24 * 3600);
        let delay = exponential_delay(base, 50, JOB_ID);
        // Jitter is added after the cap, so allow up to one extra `base`.
        assert!(delay <= one_week + base);
    }

    #[test]
    fn fixed_attempt_1_is_immediate() {
        let policy = RetryPolicy::fixed(5, Duration::from_secs(10));
        assert_eq!(policy.backoff_delay(1, JOB_ID), Duration::ZERO);
    }

    #[test]
    fn fixed_attempt_2_returns_configured_delay() {
        let configured = Duration::from_secs(10);
        let policy = RetryPolicy::fixed(5, configured);
        assert_eq!(policy.backoff_delay(2, JOB_ID), configured);
    }

    #[test]
    fn none_always_returns_zero() {
        let policy = RetryPolicy::none();
        assert!((1..=5).all(|attempt| policy.backoff_delay(attempt, JOB_ID) == Duration::ZERO));
    }

    #[test]
    fn jitter_is_within_range() {
        let base = Duration::from_secs(10);
        assert!(jitter(base, JOB_ID) < base);
    }

    #[test]
    fn jitter_zero_base_returns_zero() {
        assert_eq!(jitter(Duration::ZERO, JOB_ID), Duration::ZERO);
    }

    #[test]
    fn jitter_is_deterministic() {
        let base = Duration::from_secs(10);
        let first = jitter(base, JOB_ID);
        let second = jitter(base, JOB_ID);
        assert_eq!(first, second);
    }

    #[test]
    fn different_job_ids_produce_different_jitter() {
        let base = Duration::from_secs(10);
        let (one, two) = (Uuid::from_u128(1), Uuid::from_u128(2));
        assert_ne!(jitter(base, one), jitter(base, two));
    }
}