use super::{JobProgress, JobState, JobSummary, PhaseType, StateError, StateStore};
use crate::cook::execution::state::{DefaultJobStateManager, JobStateManager, MapReduceJobState};
use std::sync::Arc;
use tracing::{debug, error, info};
#[cfg(test)]
use std::collections::HashMap;
/// [`StateStore`] implementation that delegates to a shared [`JobStateManager`].
pub struct DefaultStateStore {
/// Backend queried by the trait methods to fetch job state and list resumable jobs.
state_manager: Arc<dyn JobStateManager>,
}
impl DefaultStateStore {
    /// Builds a store backed by a [`DefaultJobStateManager`] rooted at
    /// `~/.prodigy/state/<repository_name>`.
    ///
    /// The path is passed through verbatim: `PathBuf::from` performs no `~`
    /// expansion. NOTE(review): confirm that `DefaultJobStateManager` resolves
    /// the home directory itself; otherwise the state lands in a literal `~`
    /// directory relative to the current working directory.
    pub fn new(repository_name: String) -> Self {
        let state_dir =
            std::path::PathBuf::from(format!("~/.prodigy/state/{}", repository_name));
        let manager = DefaultJobStateManager::new(state_dir);
        Self::from_manager(Arc::new(manager))
    }

    /// Wraps an already-constructed state manager.
    pub fn from_manager(state_manager: Arc<dyn JobStateManager>) -> Self {
        Self { state_manager }
    }
}
#[async_trait::async_trait]
impl StateStore for DefaultStateStore {
    /// Accepts a save request without persisting anything.
    ///
    /// Only a debug message is emitted; the call always returns `Ok(())`.
    async fn save(&self, state: &JobState) -> Result<(), StateError> {
        debug!("Save state for job {} (not fully implemented)", state.id);
        Ok(())
    }

    /// Loads the state of `job_id` through the underlying state manager.
    ///
    /// # Errors
    ///
    /// Returns [`StateError::LoadError`] for any manager failure whose message
    /// does not contain `"not found"`; a "not found" failure is reported as
    /// `Ok(None)` instead.
    async fn load(&self, job_id: &str) -> Result<Option<JobState>, StateError> {
        match self.state_manager.get_job_state(job_id).await {
            Ok(mapreduce_state) => {
                let state = from_mapreduce_job_state(mapreduce_state);
                debug!("Loaded state for job {}", job_id);
                Ok(Some(state))
            }
            // The manager exposes no typed "missing" error here, so absence is
            // detected by message text.
            Err(e) if e.to_string().contains("not found") => {
                debug!("No state found for job {}", job_id);
                Ok(None)
            }
            Err(e) => {
                error!("Failed to load state for job {}: {}", job_id, e);
                Err(StateError::LoadError(e.to_string()))
            }
        }
    }

    /// Lists the resumable jobs known to the state manager as [`JobSummary`] values.
    ///
    /// Every summary is reported with `PhaseType::Map`. `pending_items` is
    /// clamped at zero so inconsistent counters cannot underflow.
    ///
    /// # Errors
    ///
    /// Returns [`StateError::LoadError`] when the manager fails to list jobs.
    async fn list(&self) -> Result<Vec<JobSummary>, StateError> {
        let jobs = self
            .state_manager
            .list_resumable_jobs()
            .await
            .map_err(|e| StateError::LoadError(e.to_string()))?;

        let summaries = jobs
            .into_iter()
            .map(|job| {
                // Plain subtraction would panic (debug) or wrap (release) if
                // completed + failed ever exceeded the recorded total.
                let pending_items = job
                    .total_items
                    .saturating_sub(job.completed_items)
                    .saturating_sub(job.failed_items);
                let completion_percentage = if job.total_items > 0 {
                    (job.completed_items as f64 / job.total_items as f64) * 100.0
                } else {
                    0.0
                };
                let progress = JobProgress {
                    total_items: job.total_items,
                    completed_items: job.completed_items,
                    failed_items: job.failed_items,
                    pending_items,
                    completion_percentage,
                };
                JobSummary {
                    job_id: job.job_id,
                    phase: PhaseType::Map,
                    progress,
                    created_at: job.started_at,
                    updated_at: job.updated_at,
                    is_complete: job.is_complete,
                }
            })
            .collect();
        Ok(summaries)
    }

    /// Accepts a delete request without removing anything.
    ///
    /// Only an info message is emitted; the call always returns `Ok(())`.
    async fn delete(&self, job_id: &str) -> Result<(), StateError> {
        info!("Delete requested for job {} (not implemented)", job_id);
        Ok(())
    }
}
/// Converts a [`MapReduceJobState`] into the store-level [`JobState`].
///
/// The phase is derived with `map_phase_from_state` before any field is moved
/// out of `state`. `checkpoint` is always `None`; no checkpoint data is carried
/// over by this conversion.
fn from_mapreduce_job_state(state: MapReduceJobState) -> JobState {
    let phase = map_phase_from_state(&state);
    JobState {
        id: state.job_id,
        phase,
        checkpoint: None,
        processed_items: state.completed_agents,
        // Failed items must come from the failure records, not from the queue
        // of items still waiting to run (`pending_items`).
        // NOTE(review): assumes `failed_agents` is a map keyed by item id —
        // confirm against the `MapReduceJobState` definition.
        failed_items: state.failed_agents.keys().cloned().collect(),
        variables: state.variables,
        created_at: state.started_at,
        updated_at: state.updated_at,
        config: state.config,
        agent_results: state.agent_results,
        is_complete: state.is_complete,
        total_items: state.total_items,
    }
}
/// Derives the current [`PhaseType`] from a job's persisted flags.
///
/// Checks are made in priority order: a complete job is `Completed`; otherwise
/// a started reduce phase yields `Reduce`; otherwise a finished setup yields
/// `Map`; anything else is still in `Setup`.
fn map_phase_from_state(state: &MapReduceJobState) -> PhaseType {
    if state.is_complete {
        return PhaseType::Completed;
    }

    // A missing reduce-phase record counts as "not started".
    let reduce_started = matches!(&state.reduce_phase_state, Some(reduce) if reduce.started);
    if reduce_started {
        return PhaseType::Reduce;
    }

    if state.setup_completed {
        PhaseType::Map
    } else {
        PhaseType::Setup
    }
}
/// Test-only [`StateStore`] that keeps job states in an in-process map.
#[cfg(test)]
pub struct InMemoryStateStore {
/// Job states keyed by job id, guarded by an async read/write lock.
states: Arc<tokio::sync::RwLock<HashMap<String, JobState>>>,
}
#[cfg(test)]
impl InMemoryStateStore {
    /// Creates a store that holds no job states.
    pub fn new() -> Self {
        let empty = HashMap::new();
        let guarded = tokio::sync::RwLock::new(empty);
        Self {
            states: Arc::new(guarded),
        }
    }
}
#[cfg(test)]
impl Default for InMemoryStateStore {
    /// Equivalent to [`InMemoryStateStore::new`]: an empty store.
    fn default() -> Self {
        InMemoryStateStore::new()
    }
}
#[cfg(test)]
#[async_trait::async_trait]
impl StateStore for InMemoryStateStore {
    /// Stores a clone of `state` under its id, replacing any previous entry.
    async fn save(&self, state: &JobState) -> Result<(), StateError> {
        let mut guard = self.states.write().await;
        let key = state.id.clone();
        guard.insert(key, state.clone());
        debug!("Saved state for job {}", state.id);
        Ok(())
    }

    /// Returns a clone of the state stored for `job_id`, or `None` if absent.
    async fn load(&self, job_id: &str) -> Result<Option<JobState>, StateError> {
        let guard = self.states.read().await;
        let found = guard.get(job_id).cloned();
        debug!("Loaded state for job {}: {}", job_id, found.is_some());
        Ok(found)
    }

    /// Summarises every stored job, in the map's iteration order.
    ///
    /// Pending items are `total - (processed + failed)`, clamped at zero.
    async fn list(&self) -> Result<Vec<JobSummary>, StateError> {
        let guard = self.states.read().await;
        let mut summaries: Vec<JobSummary> = Vec::with_capacity(guard.len());

        for job in guard.values() {
            let completed_items = job.processed_items.len();
            let failed_items = job.failed_items.len();
            let total_items = job.total_items;
            let pending_items = total_items.saturating_sub(completed_items + failed_items);
            // Avoid dividing by zero for jobs that have no items.
            let completion_percentage = match total_items {
                0 => 0.0,
                n => (completed_items as f64 / n as f64) * 100.0,
            };

            summaries.push(JobSummary {
                job_id: job.id.clone(),
                phase: job.phase,
                progress: JobProgress {
                    total_items,
                    completed_items,
                    failed_items,
                    pending_items,
                    completion_percentage,
                },
                created_at: job.created_at,
                updated_at: job.updated_at,
                is_complete: job.is_complete,
            });
        }

        Ok(summaries)
    }

    /// Removes the state stored for `job_id`; a missing id is not an error.
    async fn delete(&self, job_id: &str) -> Result<(), StateError> {
        let mut guard = self.states.write().await;
        guard.remove(job_id);
        debug!("Deleted state for job {}", job_id);
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cook::execution::mapreduce::MapReduceConfig;
    use chrono::Utc;
    use std::collections::HashMap;

    /// Builds a job state in the `Setup` phase with no processed or failed items.
    fn sample_state(id: &str, total_items: usize) -> JobState {
        JobState {
            id: id.to_string(),
            phase: PhaseType::Setup,
            checkpoint: None,
            processed_items: Default::default(),
            failed_items: Vec::new(),
            variables: HashMap::new(),
            created_at: Utc::now(),
            updated_at: Utc::now(),
            config: MapReduceConfig::default(),
            agent_results: HashMap::new(),
            is_complete: false,
            total_items,
        }
    }

    /// A saved state can be loaded back with its identifying fields intact.
    #[tokio::test]
    async fn test_state_persistence() {
        let store = InMemoryStateStore::new();
        let state = sample_state("test-job-123", 10);

        store.save(&state).await.unwrap();
        let loaded = store.load(&state.id).await.unwrap().unwrap();

        assert_eq!(loaded.id, state.id);
        assert_eq!(loaded.phase, state.phase);
        assert_eq!(loaded.total_items, state.total_items);
    }

    /// Smoke test: listing through the default store must succeed.
    ///
    /// The set of jobs depends on whatever state exists on the machine, so only
    /// the success of the call is asserted.
    #[tokio::test]
    async fn test_list_jobs() {
        let store = DefaultStateStore::new("test-repo".to_string());
        let listed = store.list().await;
        assert!(listed.is_ok(), "listing jobs failed: {:?}", listed.err());
    }

    /// The in-memory store reports progress for saved jobs and forgets deleted ones.
    #[tokio::test]
    async fn test_in_memory_list_and_delete() {
        let store = InMemoryStateStore::new();
        store.save(&sample_state("job-a", 10)).await.unwrap();

        let jobs = store.list().await.unwrap();
        assert_eq!(jobs.len(), 1);
        let summary = &jobs[0];
        assert_eq!(summary.job_id, "job-a");
        assert_eq!(summary.progress.total_items, 10);
        assert_eq!(summary.progress.completed_items, 0);
        assert_eq!(summary.progress.failed_items, 0);
        assert_eq!(summary.progress.pending_items, 10);
        assert!(!summary.is_complete);

        store.delete("job-a").await.unwrap();
        assert!(store.list().await.unwrap().is_empty());
        assert!(store.load("job-a").await.unwrap().is_none());
    }
}