use crate::agent::swarm::knowledge::types::{
CheckpointResult, CheckpointTask, ExecutionCheckpoint,
};
impl ExecutionCheckpoint {
    /// Creates a checkpoint for `user_message` and its planned `tasks`.
    ///
    /// The new checkpoint has no recorded results, is not finished, and has
    /// its heartbeat set to the current wall-clock time.
    pub fn new(user_message: String, tasks: Vec<CheckpointTask>) -> Self {
        Self {
            user_message,
            tasks,
            completed_task_ids: Vec::new(),
            completed_results: Vec::new(),
            failed_task_ids: Vec::new(),
            last_heartbeat_ms: current_unix_millis(),
            finished: false,
        }
    }

    /// Records the outcome of one task and refreshes the heartbeat.
    ///
    /// The task id is appended to the completed list when `result.success` is
    /// true and to the failed list otherwise; the full result is always kept
    /// in `completed_results`.
    pub fn record_completion(&mut self, result: CheckpointResult) {
        let bucket = if result.success {
            &mut self.completed_task_ids
        } else {
            &mut self.failed_task_ids
        };
        bucket.push(result.id.clone());
        self.completed_results.push(result);
        self.last_heartbeat_ms = current_unix_millis();
    }

    /// Refreshes the heartbeat timestamp without recording any result.
    pub fn heartbeat(&mut self) {
        self.last_heartbeat_ms = current_unix_millis();
    }

    /// Returns `true` when the checkpoint is not marked finished and at least
    /// one planned task has neither a completed nor a failed record.
    ///
    /// This is decided per task id rather than by comparing list lengths, so
    /// an id recorded more than once (or an id that is not part of `tasks`)
    /// cannot make unfinished work look finished. The answer therefore always
    /// agrees with [`Self::remaining_task_ids`].
    pub fn is_incomplete(&self) -> bool {
        if self.finished {
            return false;
        }
        let settled = self.settled_task_ids();
        self.tasks
            .iter()
            .any(|task| !settled.contains(task.id.as_str()))
    }

    /// Returns the ids of planned tasks that have neither a completed nor a
    /// failed record, in the order they appear in `tasks`.
    pub fn remaining_task_ids(&self) -> Vec<String> {
        let settled = self.settled_task_ids();
        self.tasks
            .iter()
            .filter(|task| !settled.contains(task.id.as_str()))
            .map(|task| task.id.clone())
            .collect()
    }

    /// Returns `true` when strictly more than `timeout_ms` milliseconds have
    /// passed since the last heartbeat.
    ///
    /// The subtraction saturates at zero, so a heartbeat timestamp that lies
    /// ahead of the current clock counts as zero elapsed time.
    pub fn is_stale(&self, timeout_ms: u64) -> bool {
        let elapsed_ms = current_unix_millis().saturating_sub(self.last_heartbeat_ms);
        elapsed_ms > timeout_ms
    }

    /// Marks the checkpoint as finished and refreshes the heartbeat.
    pub fn mark_finished(&mut self) {
        self.finished = true;
        self.last_heartbeat_ms = current_unix_millis();
    }

    /// Collects every id that has a completed or failed record into a set so
    /// membership checks do not rescan both lists for each task.
    fn settled_task_ids(&self) -> std::collections::HashSet<&str> {
        self.completed_task_ids
            .iter()
            .chain(&self.failed_task_ids)
            .map(String::as_str)
            .collect()
    }
}
/// Serializes `checkpoint` as pretty-printed JSON and stores it at `path`.
///
/// Missing parent directories are created. The JSON is first written to a
/// sibling file named `<path>.tmp` and then renamed over `path`, so a write
/// that is interrupted part-way does not leave a truncated checkpoint at
/// `path`. The data is not explicitly synced to disk.
///
/// # Errors
///
/// Returns an error if serialization fails (wrapped via
/// `std::io::Error::other`), or if creating the directories, writing the
/// staging file, or renaming it fails. On a write or rename failure the
/// staging file is removed on a best-effort basis.
pub fn save_checkpoint(
    checkpoint: &ExecutionCheckpoint,
    path: &std::path::Path,
) -> std::io::Result<()> {
    let json = serde_json::to_string_pretty(checkpoint).map_err(std::io::Error::other)?;
    if let Some(parent) = path.parent() {
        std::fs::create_dir_all(parent)?;
    }

    // Stage next to the destination so the final rename stays within one
    // directory (and therefore one filesystem).
    let mut staged = path.as_os_str().to_owned();
    staged.push(".tmp");
    let staged = std::path::PathBuf::from(staged);

    let outcome = std::fs::write(&staged, json).and_then(|()| std::fs::rename(&staged, path));
    if outcome.is_err() {
        // Best effort: do not leave a stray staging file behind; the original
        // error is what the caller needs to see.
        let _ = std::fs::remove_file(&staged);
    }
    outcome
}
/// Reads the file at `path` and deserializes its JSON content into an
/// [`ExecutionCheckpoint`].
///
/// # Errors
///
/// Returns the underlying I/O error when the file cannot be read as a UTF-8
/// string, or the JSON parse error wrapped via `std::io::Error::other` when
/// the content does not deserialize.
pub fn load_checkpoint(path: &std::path::Path) -> std::io::Result<ExecutionCheckpoint> {
    let raw = std::fs::read_to_string(path)?;
    match serde_json::from_str::<ExecutionCheckpoint>(&raw) {
        Ok(checkpoint) => Ok(checkpoint),
        Err(parse_err) => Err(std::io::Error::other(parse_err)),
    }
}
/// Deletes the checkpoint file at `path`, treating an already-missing file as
/// success.
///
/// The file is removed directly instead of being probed with `exists()`
/// first: that avoids the window in which another process could delete the
/// file between the check and the removal, and it surfaces errors (such as
/// permission problems) that an existence probe would silently report as
/// "not there".
///
/// # Errors
///
/// Returns any error from removing the file other than
/// `std::io::ErrorKind::NotFound`.
pub fn clear_checkpoint(path: &std::path::Path) -> std::io::Result<()> {
    match std::fs::remove_file(path) {
        Ok(()) => Ok(()),
        Err(err) if err.kind() == std::io::ErrorKind::NotFound => Ok(()),
        Err(err) => Err(err),
    }
}
/// Returns the current wall-clock time as whole milliseconds since the Unix
/// epoch.
///
/// A system clock set before the epoch yields `0` (the default `Duration`).
/// The millisecond count is converted with a checked conversion and saturates
/// at `u64::MAX` instead of silently truncating the `u128` value.
fn current_unix_millis() -> u64 {
    let since_epoch = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or_default();
    u64::try_from(since_epoch.as_millis()).unwrap_or(u64::MAX)
}