use serde::{Deserialize, Serialize};
use std::path::{Path, PathBuf};
/// A snapshot of one agent session at a given step.
///
/// `CheckpointManager::save` writes it as pretty-printed JSON to
/// `<session_id>_<step>.json`, so the serde field names and order below are
/// the on-disk format.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Checkpoint {
/// Identifier of this checkpoint; not interpreted by `CheckpointManager`.
pub id: String,
/// Name of the agent this checkpoint is associated with.
pub agent_name: String,
/// Session this checkpoint belongs to; used as the file-name prefix on disk.
pub session_id: String,
/// Step number within the session; used in the file name and as the sort key
/// when listing a session's checkpoints.
pub step: usize,
/// Messages stored with this checkpoint.
pub messages: Vec<CheckpointMessage>,
/// Tool calls stored with this checkpoint.
pub tool_calls: Vec<ToolCallRecord>,
/// Free-form partial result strings stored with this checkpoint.
pub partial_results: Vec<String>,
/// Creation time. NOTE(review): presumably seconds since the Unix epoch (the
/// tests fill it that way) — confirm with the code that builds checkpoints.
pub timestamp: u64,
/// Status of the session recorded with this checkpoint.
pub status: CheckpointStatus,
}
/// A single message stored in a [`Checkpoint`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CheckpointMessage {
// `role` is the speaker (the tests use "user" and "assistant"; nothing here
// validates it) and `content` is the message text.
pub role: String, pub content: String,
}
/// Record of one tool invocation stored in a [`Checkpoint`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCallRecord {
/// Name of the tool that was called.
pub tool_name: String,
/// Arguments passed to the tool, flattened into one string.
/// NOTE(review): encoding is not fixed here — confirm whether callers store JSON.
pub arguments: String,
/// Tool output, or `None` when no result was recorded.
pub result: Option<String>,
/// Whether the call was recorded as successful.
pub success: bool,
}
/// Lifecycle status recorded with a [`Checkpoint`].
///
/// Uses serde's default (externally tagged) enum representation, so variant
/// names are part of the on-disk JSON format. Variant meanings below are
/// inferred from their names; payload strings are free-form.
///
/// `Eq` is derived alongside `PartialEq` because every payload (`String`) has
/// total equality.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum CheckpointStatus {
    /// The session had not finished when the checkpoint was taken.
    InProgress,
    /// The session finished.
    Completed,
    /// The session failed; the payload describes the failure.
    Failed(String),
    /// The session was stopped; the payload gives the reason.
    Halted(String),
}
/// Saves, loads, lists and deletes [`Checkpoint`]s as JSON files in a single
/// directory.
///
/// For each session it stores `<session_id>_<step>.json` files plus a
/// `<session_id>_latest.json` file. Despite its extension, that file holds
/// the plain file name (not JSON) of the checkpoint saved most recently.
pub struct CheckpointManager {
/// Directory holding all checkpoint files; created by `CheckpointManager::new`.
checkpoint_dir: PathBuf,
}
impl CheckpointManager {
    /// Suffix given to a file while `write_atomic` is writing it; the file is
    /// renamed to its final name once fully written.
    const TMP_SUFFIX: &'static str = ".tmp";

    /// Creates a manager that stores checkpoints in `checkpoint_dir`,
    /// creating the directory (and any missing parents) first.
    ///
    /// # Errors
    ///
    /// Returns any error from [`std::fs::create_dir_all`].
    pub fn new(checkpoint_dir: &Path) -> std::io::Result<Self> {
        std::fs::create_dir_all(checkpoint_dir)?;
        Ok(Self {
            checkpoint_dir: checkpoint_dir.to_path_buf(),
        })
    }

    /// Creates a manager rooted at the `checkpoints` subdirectory of
    /// `dirs_or_default()`.
    ///
    /// # Errors
    ///
    /// Returns any error from creating that directory.
    pub fn default_dir() -> std::io::Result<Self> {
        let dir = dirs_or_default().join("checkpoints");
        Self::new(&dir)
    }

    /// Persists `checkpoint` as `<session_id>_<step>.json` and points the
    /// session's `<session_id>_latest.json` file at it.
    ///
    /// Saving the same session and step again overwrites the earlier file.
    /// The latest pointer always moves to the checkpoint saved most recently,
    /// even if its step is lower than a previously saved one.
    ///
    /// Each file is first written under a temporary name and then renamed
    /// into place. On platforms where rename replaces atomically (e.g. POSIX),
    /// a failure mid-write cannot leave a truncated file under the final name.
    /// The pointer is only updated after the checkpoint file is in place.
    ///
    /// # Errors
    ///
    /// Returns an `ErrorKind::Other` error if serialization fails, or any I/O
    /// error from writing or renaming the files.
    pub fn save(&self, checkpoint: &Checkpoint) -> std::io::Result<()> {
        let filename = Self::step_file_name(&checkpoint.session_id, checkpoint.step);
        let json = serde_json::to_string_pretty(checkpoint)
            .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
        self.write_atomic(&filename, json.as_bytes())?;
        self.write_atomic(
            &Self::latest_file_name(&checkpoint.session_id),
            filename.as_bytes(),
        )
    }

    /// Loads the checkpoint that `<session_id>_latest.json` points to.
    ///
    /// Returns `Ok(None)` when the session has no pointer file, or when the
    /// pointed-to checkpoint file no longer exists.
    ///
    /// # Errors
    ///
    /// * `ErrorKind::InvalidData` if the pointer does not contain a plain file
    ///   name (e.g. it is empty, `..`, or contains a path separator), or if the
    ///   checkpoint JSON cannot be deserialized.
    /// * Any other I/O error from reading either file.
    pub fn load_latest(&self, session_id: &str) -> std::io::Result<Option<Checkpoint>> {
        let latest_path = self.checkpoint_dir.join(Self::latest_file_name(session_id));
        let pointer = match Self::read_if_exists(&latest_path)? {
            Some(pointer) => pointer,
            None => return Ok(None),
        };
        let target = pointer.trim();
        // The pointer must name a file directly inside the checkpoint directory.
        // This rejects a corrupted/empty pointer (which would otherwise resolve
        // to the directory itself) and anything that could escape the directory.
        if !matches!(Path::new(target).file_name(), Some(name) if name == target) {
            return Err(std::io::Error::new(
                std::io::ErrorKind::InvalidData,
                format!(
                    "invalid latest-checkpoint pointer in {}",
                    latest_path.display()
                ),
            ));
        }
        let json = match Self::read_if_exists(&self.checkpoint_dir.join(target))? {
            Some(json) => json,
            None => return Ok(None),
        };
        let checkpoint: Checkpoint = serde_json::from_str(&json)
            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
        Ok(Some(checkpoint))
    }

    /// Returns every checkpoint of `session_id`, sorted by ascending `step`.
    ///
    /// Only files named exactly `<session_id>_<digits>.json` are considered.
    /// The latest pointer, leftover temp files, and files of other sessions
    /// whose id merely starts with `<session_id>_` are therefore ignored.
    /// Matching files whose contents fail to deserialize are skipped.
    ///
    /// # Errors
    ///
    /// Returns any I/O error from listing the directory or reading a matching
    /// file.
    pub fn list_checkpoints(&self, session_id: &str) -> std::io::Result<Vec<Checkpoint>> {
        let mut checkpoints = Vec::new();
        for entry in std::fs::read_dir(&self.checkpoint_dir)? {
            let entry = entry?;
            let name = entry.file_name().to_string_lossy().into_owned();
            let is_step_file = matches!(
                Self::session_part(&name, session_id),
                Some(part) if Self::is_step_part(part)
            );
            if !is_step_file {
                continue;
            }
            let json = std::fs::read_to_string(entry.path())?;
            // Best effort: one unparseable file should not hide the others.
            if let Ok(cp) = serde_json::from_str::<Checkpoint>(&json) {
                checkpoints.push(cp);
            }
        }
        checkpoints.sort_by_key(|c| c.step);
        Ok(checkpoints)
    }

    /// Deletes every file this manager writes for `session_id`:
    /// * its step files,
    /// * its latest pointer,
    /// * temp files left by an interrupted save.
    ///
    /// Returns the number of files removed. Files of other sessions, including
    /// sessions whose id starts with `<session_id>_`, are left alone.
    ///
    /// # Errors
    ///
    /// Returns the first I/O error encountered. Files removed before the
    /// error stay removed.
    pub fn cleanup(&self, session_id: &str) -> std::io::Result<usize> {
        let mut removed = 0;
        for entry in std::fs::read_dir(&self.checkpoint_dir)? {
            let entry = entry?;
            let name = entry.file_name().to_string_lossy().into_owned();
            if Self::is_session_file(&name, session_id) {
                std::fs::remove_file(entry.path())?;
                removed += 1;
            }
        }
        Ok(removed)
    }

    /// Returns `true` if `session_id` has a latest-pointer file.
    ///
    /// This does not check that the pointed-to checkpoint still exists or
    /// parses; use `load_latest` for that.
    pub fn has_checkpoint(&self, session_id: &str) -> bool {
        self.checkpoint_dir
            .join(Self::latest_file_name(session_id))
            .exists()
    }

    /// File name of the checkpoint for `session_id` at `step`.
    fn step_file_name(session_id: &str, step: usize) -> String {
        format!("{}_{}.json", session_id, step)
    }

    /// File name of the latest-pointer file for `session_id`.
    fn latest_file_name(session_id: &str) -> String {
        format!("{}_latest.json", session_id)
    }

    /// For a file named `<session_id>_<part>.json`, returns `<part>`.
    fn session_part<'a>(file_name: &'a str, session_id: &str) -> Option<&'a str> {
        file_name
            .strip_prefix(session_id)?
            .strip_prefix('_')?
            .strip_suffix(".json")
    }

    /// True if `part` is a step number as written by `save`: non-empty ASCII digits.
    fn is_step_part(part: &str) -> bool {
        !part.is_empty() && part.bytes().all(|b| b.is_ascii_digit())
    }

    /// True if `file_name` is a step file, the latest pointer, or a temp file
    /// of either, belonging to exactly `session_id`.
    fn is_session_file(file_name: &str, session_id: &str) -> bool {
        let file_name = file_name
            .strip_suffix(Self::TMP_SUFFIX)
            .unwrap_or(file_name);
        matches!(
            Self::session_part(file_name, session_id),
            Some(part) if part == "latest" || Self::is_step_part(part)
        )
    }

    /// Reads `path` to a string, mapping "file not found" to `Ok(None)`.
    /// Avoids a racy `exists()` check before the read.
    fn read_if_exists(path: &Path) -> std::io::Result<Option<String>> {
        match std::fs::read_to_string(path) {
            Ok(contents) => Ok(Some(contents)),
            Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(None),
            Err(e) => Err(e),
        }
    }

    /// Writes `contents` to `<checkpoint_dir>/<file_name>` via a temp file
    /// plus rename, so readers never observe a partially written file under
    /// the final name. On failure the temp file is removed (best effort).
    fn write_atomic(&self, file_name: &str, contents: &[u8]) -> std::io::Result<()> {
        let final_path = self.checkpoint_dir.join(file_name);
        let tmp_path = self
            .checkpoint_dir
            .join(format!("{}{}", file_name, Self::TMP_SUFFIX));
        let result = std::fs::write(&tmp_path, contents)
            .and_then(|()| std::fs::rename(&tmp_path, &final_path));
        if result.is_err() {
            // The original error is the one worth reporting; ignore cleanup failures.
            let _ = std::fs::remove_file(&tmp_path);
        }
        result
    }
}
/// Base directory for this application's data: the platform data directory
/// from `dirs::data_dir()`, joined with `ares`.
///
/// If no data directory can be determined, it falls back to the OS temporary
/// directory ([`std::env::temp_dir`]). A hard-coded `/tmp` would not be a
/// meaningful location on Windows and would ignore `TMPDIR`. Data under the
/// fallback may be cleared by the OS.
fn dirs_or_default() -> PathBuf {
    dirs::data_dir()
        .unwrap_or_else(std::env::temp_dir)
        .join("ares")
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::{SystemTime, UNIX_EPOCH};

    /// Creates a fresh temporary directory that is deleted when the guard drops.
    fn temp_dir() -> tempfile::TempDir {
        tempfile::tempdir().unwrap()
    }

    /// Builds an `InProgress` checkpoint for `session` at `step` with two
    /// messages, one successful tool call and one partial result.
    fn sample_checkpoint(session: &str, step: usize) -> Checkpoint {
        Checkpoint {
            id: format!("{}-{}", session, step),
            agent_name: "test-agent".into(),
            session_id: session.into(),
            step,
            messages: vec![
                CheckpointMessage { role: "user".into(), content: "Hello".into() },
                CheckpointMessage { role: "assistant".into(), content: "Hi there".into() },
            ],
            tool_calls: vec![ToolCallRecord {
                tool_name: "search".into(),
                arguments: "query".into(),
                result: Some("found it".into()),
                success: true,
            }],
            partial_results: vec!["partial output".into()],
            timestamp: SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs(),
            status: CheckpointStatus::InProgress,
        }
    }

    #[test]
    fn test_save_and_load() {
        let dir = temp_dir();
        let mgr = CheckpointManager::new(dir.path()).unwrap();
        mgr.save(&sample_checkpoint("sess1", 0)).unwrap();
        let loaded = mgr
            .load_latest("sess1")
            .unwrap()
            .expect("a checkpoint was just saved");
        assert_eq!(loaded.session_id, "sess1");
        assert_eq!(loaded.step, 0);
        assert_eq!(loaded.messages.len(), 2);
        assert_eq!(loaded.tool_calls.len(), 1);
        assert_eq!(loaded.partial_results, vec!["partial output".to_string()]);
        assert_eq!(loaded.status, CheckpointStatus::InProgress);
    }

    #[test]
    fn test_load_nonexistent() {
        let dir = temp_dir();
        let mgr = CheckpointManager::new(dir.path()).unwrap();
        assert!(mgr.load_latest("nonexistent").unwrap().is_none());
        assert!(!mgr.has_checkpoint("nonexistent"));
        assert!(mgr.list_checkpoints("nonexistent").unwrap().is_empty());
    }

    #[test]
    fn test_multiple_steps() {
        let dir = temp_dir();
        let mgr = CheckpointManager::new(dir.path()).unwrap();
        for step in 0..3 {
            mgr.save(&sample_checkpoint("sess1", step)).unwrap();
        }
        let latest = mgr.load_latest("sess1").unwrap().unwrap();
        assert_eq!(latest.step, 2, "latest should be step 2");
        let steps: Vec<usize> = mgr
            .list_checkpoints("sess1")
            .unwrap()
            .iter()
            .map(|c| c.step)
            .collect();
        assert_eq!(steps, vec![0, 1, 2]);
    }

    #[test]
    fn test_resave_same_step_overwrites() {
        let dir = temp_dir();
        let mgr = CheckpointManager::new(dir.path()).unwrap();
        mgr.save(&sample_checkpoint("sess1", 0)).unwrap();
        let mut cp = sample_checkpoint("sess1", 0);
        cp.status = CheckpointStatus::Completed;
        mgr.save(&cp).unwrap();
        let loaded = mgr.load_latest("sess1").unwrap().unwrap();
        assert_eq!(loaded.status, CheckpointStatus::Completed);
        assert_eq!(mgr.list_checkpoints("sess1").unwrap().len(), 1);
    }

    #[test]
    fn test_cleanup() {
        let dir = temp_dir();
        let mgr = CheckpointManager::new(dir.path()).unwrap();
        mgr.save(&sample_checkpoint("sess1", 0)).unwrap();
        mgr.save(&sample_checkpoint("sess1", 1)).unwrap();
        assert!(mgr.has_checkpoint("sess1"));
        let removed = mgr.cleanup("sess1").unwrap();
        // Two step files plus the `sess1_latest.json` pointer.
        assert_eq!(removed, 3);
        assert!(!mgr.has_checkpoint("sess1"));
        assert!(mgr.list_checkpoints("sess1").unwrap().is_empty());
        assert!(mgr.load_latest("sess1").unwrap().is_none());
    }

    #[test]
    fn test_separate_sessions() {
        let dir = temp_dir();
        let mgr = CheckpointManager::new(dir.path()).unwrap();
        mgr.save(&sample_checkpoint("sess1", 0)).unwrap();
        mgr.save(&sample_checkpoint("sess2", 0)).unwrap();
        assert!(mgr.has_checkpoint("sess1"));
        assert!(mgr.has_checkpoint("sess2"));
        mgr.cleanup("sess1").unwrap();
        assert!(!mgr.has_checkpoint("sess1"));
        assert!(mgr.has_checkpoint("sess2"));
        assert_eq!(mgr.list_checkpoints("sess2").unwrap().len(), 1);
    }

    #[test]
    fn test_checkpoint_status_serialization() {
        let dir = temp_dir();
        let mgr = CheckpointManager::new(dir.path()).unwrap();
        let statuses = vec![
            CheckpointStatus::Completed,
            CheckpointStatus::Failed("OOM".into()),
            CheckpointStatus::Halted("user abort".into()),
        ];
        for status in statuses {
            let mut cp = sample_checkpoint("sess1", 0);
            cp.status = status.clone();
            mgr.save(&cp).unwrap();
            let loaded = mgr.load_latest("sess1").unwrap().unwrap();
            assert_eq!(loaded.status, status);
        }
    }
}