use chrono::{DateTime, Duration, Utc};
use serde::{Deserialize, Serialize};
use sqlx::PgPool;
use std::collections::HashMap;
use std::sync::Arc;
use tasker_shared::config::tasker::TaskerConfig;
use uuid::Uuid;
/// Tuning parameters for [`BackoffCalculator`].
///
/// Exponential delays are computed as `base_delay_seconds * multiplier.powi(attempts)`
/// and capped at `max_delay_seconds`; server-requested (`Retry-After`) delays are
/// capped at the same maximum.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BackoffCalculatorConfig {
/// Starting delay in seconds (the exponential delay when a step has zero recorded attempts).
pub base_delay_seconds: u32,
/// Upper bound, in seconds, used to cap computed and server-requested delays.
pub max_delay_seconds: u32,
/// Growth factor raised to the power of the step's attempt count.
pub multiplier: f64,
/// Whether random jitter is applied to exponential delays.
pub jitter_enabled: bool,
/// Maximum jitter as a fraction of the delay (e.g. `0.1` allows a shift of up to ±10%).
pub max_jitter: f64,
}
impl Default for BackoffCalculatorConfig {
fn default() -> Self {
Self {
base_delay_seconds: 1,
max_delay_seconds: 300, multiplier: 2.0,
jitter_enabled: true,
max_jitter: 0.1, }
}
}
impl From<Arc<TaskerConfig>> for BackoffCalculatorConfig {
fn from(config: Arc<TaskerConfig>) -> BackoffCalculatorConfig {
let base_delay_seconds = config
.common
.backoff
.default_backoff_seconds
.first()
.copied()
.unwrap_or(1);
BackoffCalculatorConfig {
base_delay_seconds,
max_delay_seconds: config.common.backoff.max_backoff_seconds,
multiplier: config.common.backoff.backoff_multiplier,
jitter_enabled: config.common.backoff.jitter_enabled,
max_jitter: config.common.backoff.jitter_max_percentage,
}
}
}
impl From<&TaskerConfig> for BackoffCalculatorConfig {
    /// Maps the `common.backoff` section of a [`TaskerConfig`] onto the calculator config.
    ///
    /// Only the first entry of `default_backoff_seconds` is used as the base delay;
    /// an empty list falls back to 1 second.
    fn from(config: &TaskerConfig) -> BackoffCalculatorConfig {
        let backoff = &config.common.backoff;
        let base_delay_seconds = match backoff.default_backoff_seconds.first() {
            Some(&seconds) => seconds,
            None => 1,
        };
        BackoffCalculatorConfig {
            base_delay_seconds,
            max_delay_seconds: backoff.max_backoff_seconds,
            multiplier: backoff.backoff_multiplier,
            jitter_enabled: backoff.jitter_enabled,
            max_jitter: backoff.jitter_max_percentage,
        }
    }
}
/// Per-failure information passed to [`BackoffCalculator::calculate_and_apply_backoff`].
///
/// Built with [`BackoffContext::new`] and the `with_*` builder methods.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BackoffContext {
/// Response headers. A `Retry-After` entry (name matched case-insensitively by the
/// calculator) takes precedence over exponential backoff.
pub headers: HashMap<String, String>,
/// Optional description of the error that triggered the backoff; not consulted by
/// `BackoffCalculator` in this module.
pub error_context: Option<String>,
/// Free-form caller data; not consulted by `BackoffCalculator` in this module.
pub metadata: HashMap<String, serde_json::Value>,
}
impl Default for BackoffContext {
fn default() -> Self {
Self::new()
}
}
impl BackoffContext {
pub fn new() -> Self {
Self {
headers: HashMap::new(),
error_context: None,
metadata: HashMap::new(),
}
}
pub fn with_header(mut self, key: String, value: String) -> Self {
self.headers.insert(key, value);
self
}
pub fn with_error(mut self, error: String) -> Self {
self.error_context = Some(error);
self
}
pub fn with_metadata(mut self, key: String, value: serde_json::Value) -> Self {
self.metadata.insert(key, value);
self
}
}
/// Computes retry delays for workflow steps and persists them to
/// `tasker.workflow_steps` (`backoff_request_seconds`, `last_attempted_at`).
#[derive(Clone, Debug)]
pub struct BackoffCalculator {
/// Delay tuning parameters.
config: BackoffCalculatorConfig,
/// Pool used to read step attempts and write backoff state.
pool: PgPool,
}
impl BackoffCalculator {
    /// Creates a calculator with explicit tuning parameters.
    pub fn new(config: BackoffCalculatorConfig, pool: PgPool) -> Self {
        Self { config, pool }
    }

    /// Creates a calculator using [`BackoffCalculatorConfig::default`].
    pub fn with_defaults(pool: PgPool) -> Self {
        Self::new(BackoffCalculatorConfig::default(), pool)
    }

    /// Computes a retry delay for `step_uuid` and persists it.
    ///
    /// A parseable `Retry-After` header in `context` wins (capped at
    /// `max_delay_seconds`); otherwise an exponential delay is derived from the
    /// step's recorded `attempts`.
    ///
    /// # Errors
    ///
    /// Returns [`BackoffError::Database`] if reading or updating the step fails,
    /// including when no step with `step_uuid` exists.
    pub async fn calculate_and_apply_backoff(
        &self,
        step_uuid: &Uuid,
        context: BackoffContext,
    ) -> Result<BackoffResult, BackoffError> {
        if let Some(retry_after) = self.extract_retry_after_header(&context) {
            self.apply_server_requested_backoff(step_uuid, retry_after)
                .await
        } else {
            self.apply_exponential_backoff(step_uuid, &context).await
        }
    }

    /// Reads the `Retry-After` header (name matched ASCII case-insensitively).
    ///
    /// Accepts either a non-negative integer number of seconds or an RFC 2822
    /// date; surrounding whitespace is ignored. A date that is not in the
    /// future, a negative number, or an unparseable value yields `None`.
    /// Dates more than `u32::MAX` seconds away saturate instead of wrapping.
    fn extract_retry_after_header(&self, context: &BackoffContext) -> Option<u32> {
        let raw = context
            .headers
            .iter()
            .find(|(key, _)| key.eq_ignore_ascii_case("retry-after"))
            .map(|(_, value)| value.trim())?;

        if let Ok(seconds) = raw.parse::<u32>() {
            return Some(seconds);
        }

        let date = DateTime::parse_from_rfc2822(raw).ok()?;
        let remaining = date.signed_duration_since(Utc::now()).num_seconds();
        if remaining > 0 {
            // Saturate rather than truncate; callers cap at max_delay_seconds anyway.
            Some(remaining.min(i64::from(u32::MAX)) as u32)
        } else {
            None
        }
    }

    /// Stores `delay_seconds` as the step's `backoff_request_seconds` and stamps
    /// `last_attempted_at` / `updated_at` with `NOW()`, inside one transaction.
    ///
    /// The row is locked with `SELECT ... FOR UPDATE` via `fetch_one` first, so a
    /// missing step fails with an error instead of silently updating zero rows.
    async fn update_backoff_atomic(
        &self,
        step_uuid: &Uuid,
        delay_seconds: u32,
    ) -> Result<(), BackoffError> {
        // The column is bound as i32; clamp so a very large configured maximum
        // cannot wrap into a negative stored delay.
        let stored_delay = delay_seconds.min(i32::MAX as u32) as i32;

        let mut tx = self.pool.begin().await.map_err(BackoffError::Database)?;
        sqlx::query!(
            "SELECT workflow_step_uuid
FROM tasker.workflow_steps
WHERE workflow_step_uuid = $1
FOR UPDATE",
            step_uuid
        )
        .fetch_one(&mut *tx)
        .await
        .map_err(BackoffError::Database)?;
        sqlx::query!(
            "UPDATE tasker.workflow_steps
SET backoff_request_seconds = $1,
last_attempted_at = NOW(),
updated_at = NOW()
WHERE workflow_step_uuid = $2",
            stored_delay,
            step_uuid
        )
        .execute(&mut *tx)
        .await
        .map_err(BackoffError::Database)?;
        tx.commit().await.map_err(BackoffError::Database)?;
        Ok(())
    }

    /// Persists a server-requested delay, capped at `max_delay_seconds`.
    /// No jitter is applied: the server asked for a specific wait.
    async fn apply_server_requested_backoff(
        &self,
        step_uuid: &Uuid,
        retry_after_seconds: u32,
    ) -> Result<BackoffResult, BackoffError> {
        let delay_seconds = retry_after_seconds.min(self.config.max_delay_seconds);
        self.update_backoff_atomic(step_uuid, delay_seconds).await?;
        Ok(BackoffResult {
            delay_seconds,
            backoff_type: BackoffType::ServerRequested,
            next_retry_at: Utc::now() + Duration::seconds(i64::from(delay_seconds)),
        })
    }

    /// Reads the step's `attempts`, computes an exponential delay and persists it.
    ///
    /// NOTE(review): `attempts` is read outside the update transaction, so a
    /// concurrent change between the read and the update is not observed.
    async fn apply_exponential_backoff(
        &self,
        step_uuid: &Uuid,
        _context: &BackoffContext,
    ) -> Result<BackoffResult, BackoffError> {
        let step = sqlx::query!(
            "SELECT attempts FROM tasker.workflow_steps WHERE workflow_step_uuid = $1",
            step_uuid
        )
        .fetch_one(&self.pool)
        .await
        .map_err(BackoffError::Database)?;
        // A NULL (or nonsensical negative) attempt count is treated as zero.
        let attempts = step.attempts.unwrap_or(0).max(0) as u32;
        let delay_seconds = self.exponential_delay_seconds(attempts);
        self.update_backoff_atomic(step_uuid, delay_seconds).await?;
        Ok(BackoffResult {
            delay_seconds,
            backoff_type: BackoffType::Exponential,
            next_retry_at: Utc::now() + Duration::seconds(i64::from(delay_seconds)),
        })
    }

    /// Pure delay formula: `base_delay_seconds * multiplier^attempts`, capped at
    /// `max_delay_seconds`, then optionally jittered and re-capped so the result
    /// never exceeds the configured maximum.
    fn exponential_delay_seconds(&self, attempts: u32) -> u32 {
        let max_delay = self.config.max_delay_seconds;
        let exponent = attempts.min(i32::MAX as u32) as i32;
        let exponential_delay =
            f64::from(self.config.base_delay_seconds) * self.config.multiplier.powi(exponent);
        // f64 -> u32 `as` saturates (inf -> u32::MAX, NaN/negative -> 0), and
        // `f64::min` ignores a NaN operand, so overflow cannot wrap here.
        let capped = exponential_delay.min(f64::from(max_delay)) as u32;
        if self.config.jitter_enabled {
            // Jitter may push the delay above the cap; clamp it back.
            self.apply_jitter(capped).min(max_delay)
        } else {
            capped
        }
    }

    /// Randomly shifts `delay_seconds` up or down by at most
    /// `delay_seconds * max_jitter` (truncated to whole seconds), saturating at
    /// `0` and `u32::MAX`. Returns the input unchanged when that range is 0.
    fn apply_jitter(&self, delay_seconds: u32) -> u32 {
        use rand::Rng;
        let jitter_range = (delay_seconds as f64 * self.config.max_jitter) as u32;
        if jitter_range == 0 {
            return delay_seconds;
        }
        let mut rng = rand::rng();
        let jitter = rng.random_range(0..=jitter_range);
        if rng.random_bool(0.5) {
            delay_seconds.saturating_add(jitter)
        } else {
            delay_seconds.saturating_sub(jitter)
        }
    }

    /// Returns `true` once `last_attempted_at + backoff_request_seconds` has
    /// passed, or when either column is NULL.
    ///
    /// NOTE(review): compares against `Utc::now().naive_utc()`, which assumes
    /// `last_attempted_at` holds UTC wall-clock time — confirm the column type
    /// and the database session time zone used by `NOW()`.
    ///
    /// # Errors
    ///
    /// Returns [`BackoffError::Database`] if the query fails or the step is missing.
    pub async fn is_ready_to_retry(&self, step_uuid: Uuid) -> Result<bool, BackoffError> {
        let step = sqlx::query!(
            r#"
SELECT backoff_request_seconds, last_attempted_at
FROM tasker.workflow_steps
WHERE workflow_step_uuid = $1
"#,
            step_uuid
        )
        .fetch_one(&self.pool)
        .await
        .map_err(BackoffError::Database)?;
        match (step.backoff_request_seconds, step.last_attempted_at) {
            (Some(backoff_seconds), Some(last_attempt)) => {
                let retry_available_at = last_attempt + Duration::seconds(backoff_seconds as i64);
                Ok(Utc::now().naive_utc() >= retry_available_at)
            }
            // No recorded backoff (or never attempted): nothing to wait for.
            _ => Ok(true),
        }
    }

    /// Sets the step's `backoff_request_seconds` to NULL.
    ///
    /// Succeeds even if no row matches `step_uuid`.
    ///
    /// # Errors
    ///
    /// Returns [`BackoffError::Database`] if the update fails.
    pub async fn clear_backoff(&self, step_uuid: Uuid) -> Result<(), BackoffError> {
        sqlx::query!(
            "UPDATE tasker.workflow_steps SET backoff_request_seconds = NULL WHERE workflow_step_uuid = $1",
            step_uuid
        )
        .execute(&self.pool)
        .await
        .map_err(BackoffError::Database)?;
        Ok(())
    }
}
/// Outcome of a backoff calculation. When produced by [`BackoffCalculator`], the
/// delay has already been written to the step row.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BackoffResult {
/// Delay stored for the step, in seconds.
pub delay_seconds: u32,
/// Whether the delay came from a `Retry-After` header or the exponential formula.
pub backoff_type: BackoffType,
/// `Utc::now() + delay_seconds`, taken by the calculator after the database update.
pub next_retry_at: DateTime<Utc>,
}
/// Source of a computed backoff delay.
///
/// Serializes as the bare variant name (e.g. `"Exponential"`). Derives `Copy`
/// and `PartialEq`/`Eq` so callers can compare with `==` rather than
/// `matches!` or Debug-string comparisons.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum BackoffType {
    /// Delay taken from the server's `Retry-After` header, capped at the configured maximum.
    ServerRequested,
    /// Delay computed from the exponential formula in [`BackoffCalculatorConfig`].
    Exponential,
}
/// Errors returned by [`BackoffCalculator`].
#[derive(Debug, thiserror::Error)]
pub enum BackoffError {
/// A query or transaction failed. A missing step also surfaces here (sqlx's
/// `fetch_one` reports "no row" as `sqlx::Error::RowNotFound`), not as `StepNotFound`.
#[error("Database error: {0}")]
Database(#[from] sqlx::Error),
/// Configuration was rejected. Not constructed by the calculator in this module.
#[error("Invalid configuration: {0}")]
InvalidConfig(String),
/// NOTE(review): steps are keyed by `Uuid` in this module, but this variant carries
/// an `i64` and is not constructed here — confirm whether it is still used elsewhere.
#[error("Step not found: {0}")]
StepNotFound(i64),
}
#[cfg(test)]
mod tests {
    //! Unit tests for the backoff configuration, context builder, result types and
    //! the pure (non-database) helpers of `BackoffCalculator`. Calculators are
    //! built over a lazily-connecting pool, which is never actually contacted.

    use super::*;

    // ---- Configuration & context builders -------------------------------------

    #[test]
    fn test_backoff_config_default() {
        let config = BackoffCalculatorConfig::default();
        assert_eq!(config.base_delay_seconds, 1);
        assert_eq!(config.max_delay_seconds, 300);
        assert_eq!(config.multiplier, 2.0);
        assert!(config.jitter_enabled);
        assert_eq!(config.max_jitter, 0.1);
    }

    #[test]
    fn test_backoff_context_builder() {
        let context = BackoffContext::new()
            .with_header("retry-after".to_string(), "60".to_string())
            .with_error("Rate limited".to_string());
        assert_eq!(context.headers.get("retry-after"), Some(&"60".to_string()));
        assert_eq!(context.error_context, Some("Rate limited".to_string()));
    }

    #[tokio::test]
    async fn test_extract_retry_after_seconds() {
        // Exercise the real parser instead of re-implementing the parse inline.
        let calculator = make_calculator();
        let context =
            BackoffContext::new().with_header("retry-after".to_string(), "120".to_string());
        assert_eq!(calculator.extract_retry_after_header(&context), Some(120));
    }

    #[test]
    fn test_backoff_context_default() {
        let context = BackoffContext::default();
        assert!(context.headers.is_empty());
        assert!(context.error_context.is_none());
        assert!(context.metadata.is_empty());
    }

    #[test]
    fn test_backoff_context_with_metadata() {
        let context = BackoffContext::new()
            .with_metadata("retry_count".to_string(), serde_json::json!(3))
            .with_metadata("source".to_string(), serde_json::json!("api_gateway"));
        assert_eq!(context.metadata.len(), 2);
        assert_eq!(context.metadata["retry_count"], serde_json::json!(3));
        assert_eq!(context.metadata["source"], serde_json::json!("api_gateway"));
    }

    #[test]
    fn test_backoff_context_full_builder_chain() {
        let context = BackoffContext::new()
            .with_header("Retry-After".to_string(), "60".to_string())
            .with_header("X-RateLimit-Reset".to_string(), "1700000000".to_string())
            .with_error("429 Too Many Requests".to_string())
            .with_metadata("endpoint".to_string(), serde_json::json!("/api/v1/tasks"));
        assert_eq!(context.headers.len(), 2);
        assert!(context.error_context.is_some());
        assert_eq!(context.metadata.len(), 1);
    }

    #[test]
    fn test_backoff_config_custom_values() {
        let config = BackoffCalculatorConfig {
            base_delay_seconds: 5,
            max_delay_seconds: 600,
            multiplier: 3.0,
            jitter_enabled: false,
            max_jitter: 0.2,
        };
        assert_eq!(config.base_delay_seconds, 5);
        assert_eq!(config.max_delay_seconds, 600);
        assert_eq!(config.multiplier, 3.0);
        assert!(!config.jitter_enabled);
        assert_eq!(config.max_jitter, 0.2);
    }

    #[test]
    fn test_backoff_config_serialization_roundtrip() {
        let config = BackoffCalculatorConfig::default();
        let json = serde_json::to_string(&config).expect("serialize");
        let deserialized: BackoffCalculatorConfig =
            serde_json::from_str(&json).expect("deserialize");
        assert_eq!(deserialized.base_delay_seconds, config.base_delay_seconds);
        assert_eq!(deserialized.max_delay_seconds, config.max_delay_seconds);
        assert_eq!(deserialized.multiplier, config.multiplier);
        assert_eq!(deserialized.jitter_enabled, config.jitter_enabled);
        assert_eq!(deserialized.max_jitter, config.max_jitter);
    }

    // ---- Result & error types --------------------------------------------------

    #[test]
    fn test_backoff_result_construction() {
        let now = Utc::now();
        let result = BackoffResult {
            delay_seconds: 30,
            backoff_type: BackoffType::Exponential,
            next_retry_at: now + Duration::seconds(30),
        };
        assert_eq!(result.delay_seconds, 30);
        assert!(matches!(result.backoff_type, BackoffType::Exponential));
        assert!(result.next_retry_at > now);
    }

    #[test]
    fn test_backoff_result_server_requested_type() {
        let result = BackoffResult {
            delay_seconds: 120,
            backoff_type: BackoffType::ServerRequested,
            next_retry_at: Utc::now() + Duration::seconds(120),
        };
        assert_eq!(result.delay_seconds, 120);
        assert!(matches!(result.backoff_type, BackoffType::ServerRequested));
    }

    #[test]
    fn test_backoff_error_display_messages() {
        let db_err = BackoffError::Database(sqlx::Error::ColumnNotFound("col".to_string()));
        assert!(db_err.to_string().contains("Database error"));
        let config_err = BackoffError::InvalidConfig("negative delay".to_string());
        assert_eq!(
            config_err.to_string(),
            "Invalid configuration: negative delay"
        );
        let step_err = BackoffError::StepNotFound(99);
        assert_eq!(step_err.to_string(), "Step not found: 99");
    }

    #[test]
    fn test_backoff_result_serialization() {
        let result = BackoffResult {
            delay_seconds: 60,
            backoff_type: BackoffType::Exponential,
            next_retry_at: Utc::now(),
        };
        let json = serde_json::to_value(&result).expect("serialize");
        assert_eq!(json["delay_seconds"], 60);
        assert_eq!(json["backoff_type"], "Exponential");
    }

    // ---- Helpers ---------------------------------------------------------------

    /// Calculator with default config over a lazy pool (requires a Tokio runtime).
    fn make_calculator() -> BackoffCalculator {
        let pool = PgPool::connect_lazy("postgresql://test").expect("lazy pool");
        BackoffCalculator::with_defaults(pool)
    }

    /// Calculator with a custom config over a lazy pool (requires a Tokio runtime).
    fn make_calculator_with_config(config: BackoffCalculatorConfig) -> BackoffCalculator {
        let pool = PgPool::connect_lazy("postgresql://test").expect("lazy pool");
        BackoffCalculator::new(config, pool)
    }

    // ---- Retry-After parsing ---------------------------------------------------

    #[tokio::test]
    async fn test_extract_retry_after_case_insensitive() {
        let calculator = make_calculator();
        let ctx = BackoffContext::new().with_header("retry-after".to_string(), "30".to_string());
        assert_eq!(calculator.extract_retry_after_header(&ctx), Some(30));
        let ctx = BackoffContext::new().with_header("Retry-After".to_string(), "60".to_string());
        assert_eq!(calculator.extract_retry_after_header(&ctx), Some(60));
        let ctx = BackoffContext::new().with_header("RETRY-AFTER".to_string(), "90".to_string());
        assert_eq!(calculator.extract_retry_after_header(&ctx), Some(90));
    }

    #[tokio::test]
    async fn test_extract_retry_after_missing_header() {
        let calculator = make_calculator();
        let ctx = BackoffContext::new().with_header("X-Custom".to_string(), "value".to_string());
        assert_eq!(calculator.extract_retry_after_header(&ctx), None);
    }

    #[tokio::test]
    async fn test_extract_retry_after_invalid_value() {
        let calculator = make_calculator();
        let ctx = BackoffContext::new()
            .with_header("retry-after".to_string(), "not-a-number".to_string());
        assert_eq!(calculator.extract_retry_after_header(&ctx), None);
    }

    #[tokio::test]
    async fn test_extract_retry_after_empty_headers() {
        let calculator = make_calculator();
        let ctx = BackoffContext::new();
        assert_eq!(calculator.extract_retry_after_header(&ctx), None);
    }

    #[tokio::test]
    async fn test_extract_retry_after_rfc2822_date() {
        let calculator = make_calculator();
        let future = Utc::now() + Duration::seconds(120);
        let rfc2822 = future.to_rfc2822();
        let ctx = BackoffContext::new().with_header("Retry-After".to_string(), rfc2822);
        let result = calculator.extract_retry_after_header(&ctx);
        assert!(result.is_some());
        let seconds = result.unwrap();
        // RFC 2822 drops sub-second precision, so allow a small window around 120.
        assert!(
            (118..=122).contains(&seconds),
            "Expected ~120, got {seconds}"
        );
    }

    #[tokio::test]
    async fn test_extract_retry_after_rfc2822_past_date() {
        let calculator = make_calculator();
        let past = Utc::now() - Duration::seconds(60);
        let rfc2822 = past.to_rfc2822();
        let ctx = BackoffContext::new().with_header("Retry-After".to_string(), rfc2822);
        assert_eq!(calculator.extract_retry_after_header(&ctx), None);
    }

    #[tokio::test]
    async fn test_extract_retry_after_zero_seconds() {
        let calculator = make_calculator();
        let ctx = BackoffContext::new().with_header("retry-after".to_string(), "0".to_string());
        assert_eq!(calculator.extract_retry_after_header(&ctx), Some(0));
    }

    #[tokio::test]
    async fn test_extract_retry_after_large_value() {
        let calculator = make_calculator();
        let ctx = BackoffContext::new().with_header("retry-after".to_string(), "86400".to_string());
        assert_eq!(calculator.extract_retry_after_header(&ctx), Some(86400));
    }

    #[tokio::test]
    async fn test_extract_retry_after_negative_value() {
        let calculator = make_calculator();
        let ctx = BackoffContext::new().with_header("retry-after".to_string(), "-30".to_string());
        assert_eq!(calculator.extract_retry_after_header(&ctx), None);
    }

    #[tokio::test]
    async fn test_extract_retry_after_empty_value() {
        let calculator = make_calculator();
        let ctx = BackoffContext::new().with_header("retry-after".to_string(), String::new());
        assert_eq!(calculator.extract_retry_after_header(&ctx), None);
    }

    // ---- Jitter ----------------------------------------------------------------

    #[tokio::test]
    async fn test_apply_jitter_within_bounds() {
        let config = BackoffCalculatorConfig {
            jitter_enabled: true,
            max_jitter: 0.1,
            ..Default::default()
        };
        let calculator = make_calculator_with_config(config);
        for _ in 0..50 {
            let jittered = calculator.apply_jitter(100);
            assert!(jittered >= 90, "Jitter too low: {jittered}");
            assert!(jittered <= 110, "Jitter too high: {jittered}");
        }
    }

    #[tokio::test]
    async fn test_apply_jitter_zero_delay() {
        let calculator = make_calculator();
        let jittered = calculator.apply_jitter(0);
        assert_eq!(jittered, 0, "Zero delay should remain zero");
    }

    #[tokio::test]
    async fn test_apply_jitter_small_delay_no_underflow() {
        let calculator = make_calculator();
        let jittered = calculator.apply_jitter(1);
        assert_eq!(jittered, 1);
    }

    #[tokio::test]
    async fn test_apply_jitter_large_delay() {
        let config = BackoffCalculatorConfig {
            max_jitter: 0.1,
            ..Default::default()
        };
        let calculator = make_calculator_with_config(config);
        for _ in 0..50 {
            let jittered = calculator.apply_jitter(10_000);
            assert!(jittered >= 9000, "Jitter too low: {jittered}");
            assert!(jittered <= 11000, "Jitter too high: {jittered}");
        }
    }

    #[tokio::test]
    async fn test_apply_jitter_zero_max_jitter() {
        let config = BackoffCalculatorConfig {
            max_jitter: 0.0,
            ..Default::default()
        };
        let calculator = make_calculator_with_config(config);
        let jittered = calculator.apply_jitter(100);
        assert_eq!(jittered, 100);
    }

    #[tokio::test]
    async fn test_apply_jitter_saturating_behavior() {
        let config = BackoffCalculatorConfig {
            max_jitter: 0.5,
            ..Default::default()
        };
        let calculator = make_calculator_with_config(config);
        for _ in 0..20 {
            let jittered = calculator.apply_jitter(u32::MAX);
            assert!(jittered > 0);
        }
    }

    // ---- Trait impls -----------------------------------------------------------

    #[test]
    fn test_backoff_type_serialization_roundtrip() {
        let types = [BackoffType::ServerRequested, BackoffType::Exponential];
        for bt in &types {
            let json = serde_json::to_string(bt).expect("serialize");
            let deserialized: BackoffType = serde_json::from_str(&json).expect("deserialize");
            assert_eq!(format!("{deserialized:?}"), format!("{bt:?}"));
        }
    }

    #[test]
    fn test_backoff_type_debug() {
        assert_eq!(
            format!("{:?}", BackoffType::ServerRequested),
            "ServerRequested"
        );
        assert_eq!(format!("{:?}", BackoffType::Exponential), "Exponential");
    }

    #[test]
    fn test_backoff_error_invalid_config_display() {
        let err = BackoffError::InvalidConfig("multiplier must be > 0".to_string());
        assert_eq!(
            err.to_string(),
            "Invalid configuration: multiplier must be > 0"
        );
    }

    #[test]
    fn test_backoff_error_step_not_found_display() {
        let err = BackoffError::StepNotFound(42);
        assert_eq!(err.to_string(), "Step not found: 42");
    }

    #[test]
    fn test_backoff_context_header_overwrite() {
        let context = BackoffContext::new()
            .with_header("retry-after".to_string(), "60".to_string())
            .with_header("retry-after".to_string(), "120".to_string());
        assert_eq!(context.headers.get("retry-after"), Some(&"120".to_string()));
        assert_eq!(context.headers.len(), 1);
    }

    #[test]
    fn test_backoff_context_metadata_overwrite() {
        let context = BackoffContext::new()
            .with_metadata("count".to_string(), serde_json::json!(1))
            .with_metadata("count".to_string(), serde_json::json!(2));
        assert_eq!(context.metadata["count"], serde_json::json!(2));
        assert_eq!(context.metadata.len(), 1);
    }

    #[test]
    fn test_backoff_result_clone() {
        let result = BackoffResult {
            delay_seconds: 45,
            backoff_type: BackoffType::Exponential,
            next_retry_at: Utc::now(),
        };
        let cloned = result.clone();
        assert_eq!(cloned.delay_seconds, 45);
        assert!(matches!(cloned.backoff_type, BackoffType::Exponential));
    }

    #[tokio::test]
    async fn test_backoff_calculator_debug() {
        let calculator = make_calculator();
        let debug = format!("{calculator:?}");
        assert!(debug.contains("BackoffCalculator"));
        assert!(debug.contains("config"));
    }

    #[tokio::test]
    async fn test_backoff_calculator_clone() {
        let calculator = make_calculator();
        let cloned = calculator.clone();
        assert_eq!(
            cloned.config.base_delay_seconds,
            calculator.config.base_delay_seconds
        );
    }
}