use anyhow::{anyhow, Result};
use chrono::{DateTime, Duration as ChronoDuration, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::RwLock;
use tracing::{debug, info, warn};
use crate::cook::retry_v2::{BackoffStrategy, RetryConfig};
/// Serializable snapshot of all retry-related state, produced by
/// [`RetryStateManager::create_checkpoint_state`] and consumed by
/// [`RetryStateManager::restore_from_checkpoint`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RetryCheckpointState {
    /// Per-command retry progress, keyed by command id.
    pub command_retry_states: HashMap<String, CommandRetryState>,
    /// Workflow-wide retry config. NOTE(review): always `None` when produced by
    /// `RetryStateManager` in this file — confirm any other producers before relying on it.
    pub global_retry_config: Option<RetryConfig>,
    /// Completed retry executions. NOTE(review): checkpointed empty by this file's manager.
    pub retry_execution_history: Vec<RetryExecution>,
    /// Circuit-breaker state per command id.
    pub circuit_breaker_states: HashMap<String, CircuitBreakerState>,
    /// Correlation-id lookup. NOTE(review): checkpointed empty by this file's manager.
    pub retry_correlation_map: HashMap<String, String>,
    /// When this snapshot was taken; used on restore to shift pending retry times.
    pub checkpointed_at: DateTime<Utc>,
}
/// Retry progress for a single command, including its backoff cursor and
/// full attempt history.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CommandRetryState {
    /// Identifier of the command this state belongs to.
    pub command_id: String,
    /// Attempts made so far (incremented on every recorded attempt).
    pub attempt_count: u32,
    /// Attempt ceiling, captured from the `RetryConfig` at first attempt.
    pub max_attempts: u32,
    /// Timestamp of the most recent attempt, if any.
    pub last_attempt_at: Option<DateTime<Utc>>,
    /// Scheduled time of the next retry, set after a failed attempt.
    pub next_retry_at: Option<DateTime<Utc>>,
    /// Mutable backoff cursor (current delay, fibonacci counters, ...).
    pub backoff_state: BackoffState,
    /// Every attempt recorded for this command, in order.
    pub retry_history: Vec<RetryAttempt>,
    /// Config snapshot taken when the state was first created.
    pub retry_config: Option<RetryConfig>,
    /// NOTE(review): set at creation and never flipped in this file — circuit
    /// breaking is tracked in the separate breaker map; confirm before use.
    pub is_circuit_broken: bool,
    /// Absolute deadline after which no further retries are allowed.
    pub retry_budget_expires_at: Option<DateTime<Utc>>,
    /// Accumulated duration of all attempts.
    /// NOTE(review): serialized with serde's default `{secs, nanos}` form,
    /// unlike the humantime-encoded durations elsewhere — confirm intended.
    pub total_retry_duration: Duration,
}
/// Mutable cursor for a backoff strategy; advanced each time a new retry
/// delay is computed.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BackoffState {
    /// Which backoff algorithm is in effect.
    pub strategy: BackoffStrategy,
    /// Delay most recently applied (the starting point for the next one).
    #[serde(with = "humantime_serde")]
    pub current_delay: Duration,
    /// Initial delay from the config; used by the `Fixed` strategy and as a fallback.
    #[serde(with = "humantime_serde")]
    pub base_delay: Duration,
    /// Hard cap applied to every computed delay.
    #[serde(with = "humantime_serde")]
    pub max_delay: Duration,
    /// Exponential multiplier. NOTE(review): delay calculation in this file
    /// reads the factor from `strategy` instead — this field looks like
    /// serialized-state-only; confirm before removing.
    pub multiplier: f64,
    /// Whether random jitter is added to each delay.
    pub jitter_enabled: bool,
    /// Jitter amplitude as a fraction of the base delay (jitter is centered on 0).
    pub jitter_factor: f64,
    /// Previous fibonacci term (seconds); `None` until the strategy first runs.
    pub fibonacci_prev: Option<u64>,
    /// Current fibonacci term (seconds); `None` until the strategy first runs.
    pub fibonacci_curr: Option<u64>,
}
/// Outcome of a single execution attempt of a command.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RetryAttempt {
    /// 1-based attempt number.
    pub attempt_number: u32,
    /// Wall-clock time the attempt ran.
    pub executed_at: DateTime<Utc>,
    /// How long the attempt took (human-readable when serialized).
    #[serde(with = "humantime_serde")]
    pub duration: Duration,
    /// Whether the attempt succeeded.
    pub success: bool,
    /// Error message when the attempt failed.
    pub error: Option<String>,
    /// Backoff delay that was waited before this attempt.
    #[serde(with = "humantime_serde")]
    pub backoff_applied: Duration,
    /// Process exit code, when one was observed.
    pub exit_code: Option<i32>,
}
/// Summary of one complete retry run (all attempts) for a command.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RetryExecution {
    /// Command the run belongs to.
    pub command_id: String,
    /// Correlation id linking this run to external tracing/logging.
    pub correlation_id: String,
    /// When the first attempt started.
    pub started_at: DateTime<Utc>,
    /// When the run finished; `None` while still in flight.
    pub completed_at: Option<DateTime<Utc>>,
    /// Total attempts made during the run.
    pub total_attempts: u32,
    /// Whether the run ultimately succeeded.
    pub succeeded: bool,
    /// Last error seen when the run failed.
    pub final_error: Option<String>,
}
/// Per-command circuit breaker: Closed -> Open after `failure_threshold`
/// consecutive failures; Open -> HalfOpen after `recovery_timeout`;
/// HalfOpen -> Closed after `half_open_max_calls` successful probes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CircuitBreakerState {
    /// Current position in the breaker state machine.
    pub state: CircuitState,
    /// Consecutive failures observed while Closed.
    pub failure_count: u32,
    /// Failure count at which the breaker opens.
    pub failure_threshold: u32,
    /// Time of the most recent failure; anchors the recovery window.
    pub last_failure_at: Option<DateTime<Utc>>,
    /// How long the breaker stays Open before probing.
    #[serde(with = "humantime_serde")]
    pub recovery_timeout: Duration,
    /// Successful probes required in HalfOpen to close the breaker.
    pub half_open_max_calls: u32,
    /// Successful probes recorded so far while HalfOpen.
    pub half_open_success_count: u32,
}
/// Position in the circuit-breaker state machine.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum CircuitState {
    /// Normal operation; failures are being counted.
    Closed,
    /// Tripped; retries are refused until the recovery timeout elapses.
    Open,
    /// Probing; a limited number of calls are allowed through.
    HalfOpen,
}
/// Thread-safe owner of retry state: per-command progress, per-command
/// circuit breakers, and the last restored checkpoint.
///
/// All maps are behind `tokio::sync::RwLock` and shared via `Arc`, so the
/// manager can be used concurrently from async tasks.
pub struct RetryStateManager {
    /// Last checkpoint restored via `restore_from_checkpoint`, if any.
    checkpoint_state: Arc<RwLock<Option<RetryCheckpointState>>>,
    /// Retry progress keyed by command id.
    command_states: Arc<RwLock<HashMap<String, CommandRetryState>>>,
    /// Circuit breakers keyed by command id.
    circuit_breakers: Arc<RwLock<HashMap<String, CircuitBreakerState>>>,
}
impl Default for RetryStateManager {
fn default() -> Self {
Self::new()
}
}
impl RetryStateManager {
#[cfg(test)]
pub(crate) fn get_command_states(&self) -> Arc<RwLock<HashMap<String, CommandRetryState>>> {
self.command_states.clone()
}
#[cfg(test)]
pub(crate) fn get_circuit_breakers(&self) -> Arc<RwLock<HashMap<String, CircuitBreakerState>>> {
self.circuit_breakers.clone()
}
pub fn new() -> Self {
Self {
checkpoint_state: Arc::new(RwLock::new(None)),
command_states: Arc::new(RwLock::new(HashMap::new())),
circuit_breakers: Arc::new(RwLock::new(HashMap::new())),
}
}
pub async fn create_checkpoint_state(&self) -> Result<RetryCheckpointState> {
let command_states = self.command_states.read().await;
let circuit_breakers = self.circuit_breakers.read().await;
let checkpoint = RetryCheckpointState {
command_retry_states: command_states.clone(),
global_retry_config: None, retry_execution_history: Vec::new(), circuit_breaker_states: circuit_breakers.clone(),
retry_correlation_map: HashMap::new(), checkpointed_at: Utc::now(),
};
Ok(checkpoint)
}
pub async fn restore_from_checkpoint(&self, checkpoint: &RetryCheckpointState) -> Result<()> {
info!(
"Restoring retry state from checkpoint at {}",
checkpoint.checkpointed_at
);
self.validate_checkpoint_consistency(checkpoint)?;
let mut command_states = self.command_states.write().await;
for (command_id, state) in &checkpoint.command_retry_states {
debug!(
"Restoring retry state for command {}: {} attempts",
command_id, state.attempt_count
);
let mut restored_state = state.clone();
if let Some(next_retry) = state.next_retry_at {
let elapsed = Utc::now() - checkpoint.checkpointed_at;
restored_state.next_retry_at = Some(next_retry + elapsed);
}
command_states.insert(command_id.clone(), restored_state);
}
let mut circuit_breakers = self.circuit_breakers.write().await;
for (command_id, cb_state) in &checkpoint.circuit_breaker_states {
debug!(
"Restoring circuit breaker for {}: {:?}",
command_id, cb_state.state
);
let mut restored_cb = cb_state.clone();
if cb_state.state == CircuitState::Open {
if let Some(last_failure) = cb_state.last_failure_at {
let elapsed = Utc::now() - last_failure;
if elapsed.num_seconds() as u64 >= cb_state.recovery_timeout.as_secs() {
restored_cb.state = CircuitState::HalfOpen;
restored_cb.half_open_success_count = 0;
info!(
"Circuit breaker for {} transitioned to half-open",
command_id
);
}
}
}
circuit_breakers.insert(command_id.clone(), restored_cb);
}
let mut checkpoint_state = self.checkpoint_state.write().await;
*checkpoint_state = Some(checkpoint.clone());
Ok(())
}
pub async fn get_command_retry_state(&self, command_id: &str) -> Option<CommandRetryState> {
let states = self.command_states.read().await;
states.get(command_id).cloned()
}
pub async fn update_retry_state(
&self,
command_id: &str,
attempt: RetryAttempt,
config: &RetryConfig,
) -> Result<()> {
let mut states = self.command_states.write().await;
let state = states
.entry(command_id.to_string())
.or_insert_with(|| CommandRetryState {
command_id: command_id.to_string(),
attempt_count: 0,
max_attempts: config.attempts,
last_attempt_at: None,
next_retry_at: None,
backoff_state: self.create_initial_backoff_state(config),
retry_history: Vec::new(),
retry_config: Some(config.clone()),
is_circuit_broken: false,
retry_budget_expires_at: config
.retry_budget
.map(|budget| Utc::now() + ChronoDuration::from_std(budget).unwrap()),
total_retry_duration: Duration::from_secs(0),
});
state.attempt_count += 1;
state.last_attempt_at = Some(attempt.executed_at);
state.retry_history.push(attempt.clone());
state.total_retry_duration += attempt.duration;
if !attempt.success && state.attempt_count < state.max_attempts {
let next_delay = self.calculate_next_delay(&mut state.backoff_state)?;
state.next_retry_at = Some(Utc::now() + ChronoDuration::from_std(next_delay)?);
state.backoff_state.current_delay = next_delay;
}
if !attempt.success {
self.update_circuit_breaker(command_id, false).await?;
} else {
self.update_circuit_breaker(command_id, true).await?;
}
Ok(())
}
pub async fn can_retry(&self, command_id: &str) -> Result<bool> {
let states = self.command_states.read().await;
let circuit_breakers = self.circuit_breakers.read().await;
if let Some(state) = states.get(command_id) {
if state.attempt_count >= state.max_attempts {
debug!("Command {} exceeded max attempts", command_id);
return Ok(false);
}
if let Some(expires_at) = state.retry_budget_expires_at {
if Utc::now() >= expires_at {
debug!("Command {} retry budget expired", command_id);
return Ok(false);
}
}
if let Some(cb) = circuit_breakers.get(command_id) {
if cb.state == CircuitState::Open {
debug!("Circuit breaker open for command {}", command_id);
return Ok(false);
}
}
Ok(true)
} else {
Ok(true)
}
}
fn calculate_next_delay(&self, backoff_state: &mut BackoffState) -> Result<Duration> {
let base_delay = match &backoff_state.strategy {
BackoffStrategy::Fixed => backoff_state.base_delay,
BackoffStrategy::Linear { increment } => backoff_state.current_delay + *increment,
BackoffStrategy::Exponential { base } => {
let millis = (backoff_state.current_delay.as_millis() as f64 * base) as u64;
Duration::from_millis(millis)
}
BackoffStrategy::Fibonacci => {
let (prev, curr) = if let (Some(p), Some(c)) =
(backoff_state.fibonacci_prev, backoff_state.fibonacci_curr)
{
(p, c)
} else {
(1, 1)
};
let next = prev + curr;
backoff_state.fibonacci_prev = Some(curr);
backoff_state.fibonacci_curr = Some(next);
Duration::from_secs(next)
}
BackoffStrategy::Custom { delays } => {
let index = backoff_state.current_delay.as_secs() as usize;
delays
.get(index)
.or_else(|| delays.last())
.copied()
.unwrap_or(backoff_state.base_delay)
}
};
let delay = if backoff_state.jitter_enabled {
let jitter_range = base_delay.as_millis() as f64 * backoff_state.jitter_factor;
let jitter = rand::random::<f64>() * jitter_range - (jitter_range / 2.0);
let millis = (base_delay.as_millis() as f64 + jitter).max(0.0) as u64;
Duration::from_millis(millis)
} else {
base_delay
};
Ok(delay.min(backoff_state.max_delay))
}
fn create_initial_backoff_state(&self, config: &RetryConfig) -> BackoffState {
BackoffState {
strategy: config.backoff.clone(),
current_delay: config.initial_delay,
base_delay: config.initial_delay,
max_delay: config.max_delay,
multiplier: match &config.backoff {
BackoffStrategy::Exponential { base } => *base,
_ => 2.0,
},
jitter_enabled: config.jitter,
jitter_factor: config.jitter_factor,
fibonacci_prev: None,
fibonacci_curr: None,
}
}
async fn update_circuit_breaker(&self, command_id: &str, success: bool) -> Result<()> {
let mut breakers = self.circuit_breakers.write().await;
let breaker = breakers.entry(command_id.to_string()).or_insert_with(|| {
CircuitBreakerState {
state: CircuitState::Closed,
failure_count: 0,
failure_threshold: 5, last_failure_at: None,
recovery_timeout: Duration::from_secs(60),
half_open_max_calls: 3,
half_open_success_count: 0,
}
});
match breaker.state {
CircuitState::Closed => {
if !success {
breaker.failure_count += 1;
breaker.last_failure_at = Some(Utc::now());
if breaker.failure_count >= breaker.failure_threshold {
breaker.state = CircuitState::Open;
warn!("Circuit breaker opened for command {}", command_id);
}
} else {
breaker.failure_count = 0;
}
}
CircuitState::Open => {
if let Some(last_failure) = breaker.last_failure_at {
let elapsed = Utc::now() - last_failure;
if elapsed.num_seconds() as u64 >= breaker.recovery_timeout.as_secs() {
breaker.state = CircuitState::HalfOpen;
breaker.half_open_success_count = 0;
info!("Circuit breaker half-open for command {}", command_id);
}
}
}
CircuitState::HalfOpen => {
if success {
breaker.half_open_success_count += 1;
if breaker.half_open_success_count >= breaker.half_open_max_calls {
breaker.state = CircuitState::Closed;
breaker.failure_count = 0;
info!("Circuit breaker closed for command {}", command_id);
}
} else {
breaker.state = CircuitState::Open;
breaker.last_failure_at = Some(Utc::now());
warn!("Circuit breaker re-opened for command {}", command_id);
}
}
}
Ok(())
}
fn validate_checkpoint_consistency(&self, checkpoint: &RetryCheckpointState) -> Result<()> {
for (command_id, state) in &checkpoint.command_retry_states {
if state.attempt_count > state.max_attempts + 1 {
return Err(anyhow!(
"Inconsistent retry state for {}: attempts {} > max {}",
command_id,
state.attempt_count,
state.max_attempts
));
}
if state.retry_history.len() as u32 != state.attempt_count {
warn!(
"Retry history mismatch for {}: {} history entries vs {} attempts",
command_id,
state.retry_history.len(),
state.attempt_count
);
}
}
Ok(())
}
pub async fn clear_command_state(&self, command_id: &str) {
let mut states = self.command_states.write().await;
let mut breakers = self.circuit_breakers.write().await;
states.remove(command_id);
breakers.remove(command_id);
debug!("Cleared retry state for command {}", command_id);
}
pub async fn get_retry_summary(&self) -> HashMap<String, (u32, u32, bool)> {
let states = self.command_states.read().await;
let breakers = self.circuit_breakers.read().await;
let mut summary = HashMap::new();
for (command_id, state) in states.iter() {
let is_open = breakers
.get(command_id)
.map(|b| b.state == CircuitState::Open)
.unwrap_or(false);
summary.insert(
command_id.clone(),
(state.attempt_count, state.max_attempts, is_open),
);
}
summary
}
}
/// Cross-cutting retry view coordinating per-work-item retries with
/// dead-letter-queue retries.
/// NOTE(review): not referenced elsewhere in this file — producers/consumers
/// presumably live in other modules; confirm before changing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CoordinatedRetryState {
    /// Retry state per work item, keyed by work-item id.
    pub work_item_retries: HashMap<String, WorkItemRetryState>,
    /// Items currently being retried out of the dead-letter queue.
    pub dlq_retries: Vec<DlqRetryState>,
    /// Whether the combined state passed consistency validation.
    pub consistency_valid: bool,
}
/// Retry progress of a single work item, including which agent last ran it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkItemRetryState {
    /// Identifier of the work item.
    pub work_item_id: String,
    /// Agent that executed the most recent attempt.
    pub agent_id: String,
    /// Attempts made so far.
    pub attempt_count: u32,
    /// Timestamp of the most recent attempt.
    pub last_attempt_at: DateTime<Utc>,
}
/// Retry bookkeeping for a work item that has entered the dead-letter queue.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DlqRetryState {
    /// Identifier of the work item.
    pub work_item_id: String,
    /// Retries attempted since the item entered the DLQ.
    pub dlq_retry_count: u32,
    /// When the item entered the DLQ.
    pub entered_dlq_at: DateTime<Utc>,
}
#[cfg(test)]
mod tests {
    use super::*;

    // Record one failed attempt, checkpoint, restore into a fresh manager, and
    // verify the per-command attempt count survives the round trip.
    #[tokio::test]
    async fn test_retry_state_persistence_and_restoration() {
        let manager = RetryStateManager::new();
        let attempt = RetryAttempt {
            attempt_number: 1,
            executed_at: Utc::now(),
            duration: Duration::from_secs(2),
            success: false,
            error: Some("Test error".to_string()),
            backoff_applied: Duration::from_secs(0),
            exit_code: Some(1),
        };
        let config = RetryConfig::default();
        manager
            .update_retry_state("test_cmd", attempt, &config)
            .await
            .unwrap();
        let checkpoint = manager.create_checkpoint_state().await.unwrap();
        assert_eq!(checkpoint.command_retry_states.len(), 1);
        let new_manager = RetryStateManager::new();
        new_manager
            .restore_from_checkpoint(&checkpoint)
            .await
            .unwrap();
        let restored = new_manager.get_command_retry_state("test_cmd").await;
        assert!(restored.is_some());
        assert_eq!(restored.unwrap().attempt_count, 1);
    }

    // Five consecutive failures should trip the default breaker (threshold 5):
    // can_retry must refuse and the checkpointed breaker state must be Open.
    #[tokio::test]
    async fn test_circuit_breaker_state_transitions() {
        let manager = RetryStateManager::new();
        for i in 0..5 {
            let attempt = RetryAttempt {
                attempt_number: i + 1,
                executed_at: Utc::now(),
                duration: Duration::from_secs(1),
                success: false,
                error: Some("Failed".to_string()),
                backoff_applied: Duration::from_secs(i as u64),
                exit_code: Some(1),
            };
            let config = RetryConfig::default();
            manager
                .update_retry_state("test_cmd", attempt, &config)
                .await
                .unwrap();
        }
        let can_retry = manager.can_retry("test_cmd").await.unwrap();
        assert!(!can_retry, "Circuit should be open after failures");
        let checkpoint = manager.create_checkpoint_state().await.unwrap();
        let cb_state = checkpoint.circuit_breaker_states.get("test_cmd").unwrap();
        assert_eq!(cb_state.state, CircuitState::Open);
    }

    // With ample attempts remaining, an expired retry budget alone must block
    // further retries. The expiry is forced directly on the private state map
    // to avoid depending on wall-clock timing.
    #[tokio::test]
    async fn test_retry_budget_enforcement() {
        let manager = RetryStateManager::new();
        let config = RetryConfig {
            retry_budget: Some(Duration::from_secs(5)),
            attempts: 100, ..Default::default()
        };
        let attempt = RetryAttempt {
            attempt_number: 1,
            executed_at: Utc::now() - ChronoDuration::seconds(10), duration: Duration::from_secs(1),
            success: false,
            error: Some("Failed".to_string()),
            backoff_applied: Duration::from_secs(0),
            exit_code: Some(1),
        };
        manager
            .update_retry_state("budget_cmd", attempt, &config)
            .await
            .unwrap();
        {
            // Backdate the budget expiry so can_retry sees it as spent.
            let mut states = manager.command_states.write().await;
            if let Some(state) = states.get_mut("budget_cmd") {
                state.retry_budget_expires_at = Some(Utc::now() - ChronoDuration::seconds(1));
            }
        }
        let can_retry = manager.can_retry("budget_cmd").await.unwrap();
        assert!(!can_retry, "Should not retry after budget expired");
    }
}