pub mod checkpoint;
pub mod persistence;
pub mod recovery;
pub mod transitions;
#[cfg(test)]
mod tests;
use crate::cook::execution::errors::MapReduceError;
use crate::cook::execution::mapreduce::{AgentResult, MapReduceConfig};
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use tokio::sync::RwLock;
use tracing::error;
/// Coordinates job state: persists [`JobState`] through a [`StateStore`] and
/// records [`StateEvent`]s in an in-memory audit log.
pub struct StateManager {
/// Backend used to save, load, list and delete job state.
store: Arc<dyn StateStore + Send + Sync>,
/// Table of allowed phase transitions (built by `StateMachine::new`).
transitions: StateMachine,
/// Events appended by `log_event`; kept in memory behind an async `RwLock`.
audit_log: Arc<RwLock<Vec<StateEvent>>>,
}
/// Asynchronous storage backend for job state.
///
/// Implementations must be `Send + Sync`. Every operation reports failure as a
/// [`StateError`].
#[async_trait::async_trait]
pub trait StateStore: Send + Sync {
/// Persists `state`.
async fn save(&self, state: &JobState) -> Result<(), StateError>;
/// Loads the state stored for `job_id`.
///
/// `Ok(None)` is treated by `StateManager::update_state` as "no such job"
/// and mapped to `StateError::NotFound`.
async fn load(&self, job_id: &str) -> Result<Option<JobState>, StateError>;
/// Returns a summary for each job known to the store.
async fn list(&self) -> Result<Vec<JobSummary>, StateError>;
/// Removes the state stored for `job_id`.
// NOTE(review): behaviour for an unknown `job_id` is implementation-defined — confirm.
async fn delete(&self, job_id: &str) -> Result<(), StateError>;
}
/// Persisted state of a single job.
///
/// `StateManager::create_job` builds the initial value (phase `Setup`, empty
/// collections, no checkpoint, `is_complete == false`, `total_items == 0`) and
/// `StateManager::update_state` mutates it through a caller-supplied closure.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JobState {
/// Job identifier; `create_job` sets it from its `job_id` argument.
pub id: String,
/// Current phase of the job.
pub phase: PhaseType,
/// Checkpoint attached to the job, if any (`None` for a newly created job).
pub checkpoint: Option<Checkpoint>,
/// Presumably the identifiers of items already processed — confirm against
/// the code that fills it.
pub processed_items: HashSet<String>,
/// Presumably the identifiers of items whose processing failed — confirm.
pub failed_items: Vec<String>,
/// Named JSON values carried with the job.
pub variables: HashMap<String, Value>,
/// Set once by `create_job`.
pub created_at: DateTime<Utc>,
/// Refreshed by `update_state` on every successful update.
pub updated_at: DateTime<Utc>,
/// Copy of the configuration the job was created with.
pub config: MapReduceConfig,
/// Agent results keyed by string.
// NOTE(review): the key is presumably the item or agent id — confirm.
pub agent_results: HashMap<String, AgentResult>,
/// Completion flag; `false` for a newly created job.
pub is_complete: bool,
/// Item count; `0` for a newly created job.
// NOTE(review): presumably the total number of work items — confirm where it is set.
pub total_items: usize,
}
/// Snapshot data that can be attached to a job via `JobState::checkpoint`.
// NOTE(review): creation and validation are not visible in this part of the
// file; presumably they live in the `checkpoint` module — confirm.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Checkpoint {
/// Phase recorded in the checkpoint.
pub phase: PhaseType,
/// Presumably identifiers of the items processed when the checkpoint was taken — confirm.
pub items_processed: Vec<String>,
/// Agent results stored with the checkpoint.
pub agent_results: Vec<AgentResult>,
/// Time associated with the checkpoint.
pub timestamp: DateTime<Utc>,
/// Checksum string.
// NOTE(review): algorithm and covered data are not shown here — confirm.
pub checksum: String,
/// Checkpoint version number.
pub version: u32,
}
/// Condensed view of a job, as returned by `StateStore::list`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JobSummary {
/// Identifier of the summarised job.
pub job_id: String,
/// Phase of the job.
pub phase: PhaseType,
/// Item-level progress counters.
pub progress: JobProgress,
/// Creation time of the job.
pub created_at: DateTime<Utc>,
/// Last update time of the job.
pub updated_at: DateTime<Utc>,
/// Completion flag of the job.
pub is_complete: bool,
}
/// Item counters describing how far a job has progressed.
// NOTE(review): nothing in this part of the file computes these values, so the
// relationship between the counters is an assumption to verify at the producer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JobProgress {
/// Total number of items.
pub total_items: usize,
/// Number of completed items.
pub completed_items: usize,
/// Number of failed items.
pub failed_items: usize,
/// Number of items still pending.
pub pending_items: usize,
/// Completion ratio.
// NOTE(review): scale (0–1 vs 0–100) is not visible here — confirm.
pub completion_percentage: f64,
}
/// Phase of a job's lifecycle.
///
/// The allowed moves between phases are defined by `StateMachine::new`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum PhaseType {
/// Initial phase assigned by `StateManager::create_job`.
Setup,
/// Map phase.
Map,
/// Reduce phase.
Reduce,
/// Terminal phase: `StateMachine::new` registers no outgoing transitions.
Completed,
/// Terminal phase: `StateMachine::new` registers no outgoing transitions.
Failed,
}
/// One entry of the audit log kept by `StateManager`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StateEvent {
/// When the event was recorded.
pub timestamp: DateTime<Utc>,
/// What happened.
pub event_type: StateEventType,
/// Job the event belongs to; `StateManager::get_state_history` filters on it.
pub job_id: String,
/// Optional free-form detail text.
pub details: Option<String>,
}
/// Kind of occurrence recorded in a [`StateEvent`].
///
/// In this part of the file only `JobCreated` and `PhaseTransition` are
/// emitted (by `StateManager::create_job` and `StateManager::update_state`);
/// the remaining variants are presumably produced by the sibling modules.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum StateEventType {
/// A job was created and saved.
JobCreated,
/// The job's phase changed from `from` to `to`.
PhaseTransition { from: PhaseType, to: PhaseType },
/// A checkpoint with the given version was created.
CheckpointCreated { version: u32 },
/// Recovery started from the checkpoint with the given version.
RecoveryStarted { checkpoint_version: u32 },
/// `count` items were processed.
ItemsProcessed { count: usize },
/// `count` items failed.
ItemsFailed { count: usize },
/// The job completed.
JobCompleted,
/// The job failed for the given reason.
JobFailed { reason: String },
}
/// Data describing how to resume a job.
// NOTE(review): not constructed in this part of the file; presumably built by
// the `recovery` module — confirm field semantics there.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RecoveryPlan {
/// Phase at which execution should resume.
pub resume_phase: PhaseType,
/// Items still to be handled, as JSON values.
pub pending_items: Vec<Value>,
/// Identifiers of items to skip.
pub skip_items: HashSet<String>,
/// Named JSON values to restore.
pub variables: HashMap<String, Value>,
/// Agent results to restore, keyed by string.
pub agent_results: HashMap<String, AgentResult>,
}
/// Lookup table of the phase transitions a job is allowed to make.
pub struct StateMachine {
/// Maps a phase to the phases it may move to.
transitions: HashMap<PhaseType, Vec<PhaseType>>,
}
/// Errors reported by the state management layer.
///
/// Converted into `MapReduceError` by the `From<StateError>` implementation in
/// this file.
#[derive(Debug, thiserror::Error)]
pub enum StateError {
/// Saving state failed; the payload is a description of the failure.
#[error("Failed to persist state: {0}")]
PersistenceError(String),
/// Loading state failed; the payload is a description of the failure.
#[error("Failed to load state: {0}")]
LoadError(String),
/// A phase change from `from` to `to` is not permitted.
#[error("Invalid transition from {from:?} to {to:?}")]
InvalidTransition { from: PhaseType, to: PhaseType },
/// Checkpoint validation failed; the payload is a description of the failure.
#[error("Checkpoint validation failed: {0}")]
ValidationError(String),
/// No state exists for the job id carried in the payload.
#[error("State not found for job: {0}")]
NotFound(String),
/// The state was changed by someone else while being modified.
#[error("State was modified concurrently")]
ConcurrentModification,
}
impl StateManager {
pub fn new(store: Arc<dyn StateStore + Send + Sync>) -> Self {
Self {
store,
transitions: StateMachine::new(),
audit_log: Arc::new(RwLock::new(Vec::new())),
}
}
pub async fn create_job(
&self,
config: &MapReduceConfig,
job_id: String,
) -> Result<JobState, StateError> {
let state = JobState {
id: job_id.clone(),
phase: PhaseType::Setup,
checkpoint: None,
processed_items: HashSet::new(),
failed_items: Vec::new(),
variables: HashMap::new(),
created_at: Utc::now(),
updated_at: Utc::now(),
config: config.clone(),
agent_results: HashMap::new(),
is_complete: false,
total_items: 0,
};
self.store.save(&state).await?;
self.log_event(StateEvent {
timestamp: Utc::now(),
event_type: StateEventType::JobCreated,
job_id,
details: None,
})
.await;
Ok(state)
}
pub async fn update_state<F>(&self, job_id: &str, updater: F) -> Result<JobState, StateError>
where
F: FnOnce(&mut JobState) -> Result<(), StateError>,
{
let mut state = self
.store
.load(job_id)
.await?
.ok_or_else(|| StateError::NotFound(job_id.to_string()))?;
let original_phase = state.phase;
updater(&mut state)?;
state.updated_at = Utc::now();
self.store.save(&state).await?;
if state.phase != original_phase {
self.log_event(StateEvent {
timestamp: Utc::now(),
event_type: StateEventType::PhaseTransition {
from: original_phase,
to: state.phase,
},
job_id: job_id.to_string(),
details: None,
})
.await;
}
Ok(state)
}
pub async fn get_state(&self, job_id: &str) -> Result<Option<JobState>, StateError> {
self.store.load(job_id).await
}
pub async fn list_jobs(&self) -> Result<Vec<JobSummary>, StateError> {
self.store.list().await
}
pub async fn get_state_history(&self, job_id: &str) -> Vec<StateEvent> {
let log = self.audit_log.read().await;
log.iter()
.filter(|event| event.job_id == job_id)
.cloned()
.collect()
}
async fn log_event(&self, event: StateEvent) {
let mut log = self.audit_log.write().await;
log.push(event);
}
}
impl Default for StateMachine {
fn default() -> Self {
Self::new()
}
}
impl StateMachine {
pub fn new() -> Self {
let mut transitions = HashMap::new();
transitions.insert(PhaseType::Setup, vec![PhaseType::Map, PhaseType::Failed]);
transitions.insert(
PhaseType::Map,
vec![PhaseType::Reduce, PhaseType::Completed, PhaseType::Failed],
);
transitions.insert(
PhaseType::Reduce,
vec![PhaseType::Completed, PhaseType::Failed],
);
transitions.insert(PhaseType::Failed, vec![]); transitions.insert(PhaseType::Completed, vec![]);
Self { transitions }
}
pub fn is_valid_transition(&self, from: PhaseType, to: PhaseType) -> bool {
self.transitions
.get(&from)
.map(|valid| valid.contains(&to))
.unwrap_or(false)
}
pub fn get_valid_transitions(&self, from: PhaseType) -> Vec<PhaseType> {
self.transitions.get(&from).cloned().unwrap_or_default()
}
}
impl From<StateError> for MapReduceError {
fn from(err: StateError) -> Self {
match err {
StateError::NotFound(job_id) => MapReduceError::JobNotFound { job_id },
StateError::ValidationError(details) => MapReduceError::CheckpointCorrupted {
job_id: String::new(),
version: 0,
details,
},
_ => MapReduceError::General {
message: err.to_string(),
source: None,
},
}
}
}