use serde::{Deserialize, Serialize};
use std::hash::{Hash, Hasher};
use crate::agent_loop::{ChatMessage, LoopState};
/// Schema version stamped into every checkpoint; files carrying any other version are discarded
/// by `load_checkpoint` and `gc_stale_checkpoints`.
pub const CHECKPOINT_VERSION: u32 = 1;
/// Retention window for checkpoints, in milliseconds (one day).
const CHECKPOINT_MAX_AGE_MS: i64 = 1_000 * 60 * 60 * 24;
/// Snapshot of an agent loop's progress, written as pretty-printed JSON by `save_checkpoint`
/// into a per-agent directory under the system temp directory.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LoopCheckpoint {
/// Schema version; checkpoints whose version differs from `CHECKPOINT_VERSION` are discarded
/// on load and during garbage collection.
pub version: u32,
/// Owning agent; sanitized into the checkpoint's directory name.
pub agent_name: String,
/// Conversation identifier; hashed into the checkpoint file name.
pub conversation_id: String,
/// Correlation identifier; hashed into the checkpoint file name.
pub correlation_id: String,
/// Trace this checkpoint belongs to; `load_checkpoint` discards the file when it differs from
/// the caller's expected trace id.
pub trace_id: String,
/// Loop step at which the snapshot was taken (NOTE(review): presumably an iteration counter —
/// confirm with the agent loop).
pub step: usize,
/// Loop state at snapshot time.
pub state: LoopState,
/// Chat messages to restore when resuming.
pub messages: Vec<ChatMessage>,
/// Tool-call log; the tests store entries such as `"exec:ls"` (format presumably `tool:args` —
/// TODO confirm).
pub tool_call_log: Vec<String>,
/// Fingerprint of the most recent action, if any (NOTE(review): presumably paired with
/// `same_action_repeat_count` for repeat detection — confirm).
pub last_action_fingerprint: Option<String>,
/// Number of times the same action has repeated (semantics defined by the agent loop).
pub same_action_repeat_count: usize,
/// Number of consecutive errors observed (semantics defined by the agent loop).
pub consecutive_error_count: usize,
/// Recent progress signals; the tests use values such as `1` and `-1` (meaning defined by the
/// agent loop — confirm).
pub progress_window: Vec<i32>,
/// Snapshot time in Unix milliseconds; compared against `chrono::Utc::now()` so checkpoints
/// older than `CHECKPOINT_MAX_AGE_MS` are treated as stale.
pub updated_at_unix_ms: i64,
}
/// Builds `<temp>/clawgarden-agent-loop/<sanitized agent>/<hash(conv)>__<hash(corr)>.json`.
///
/// The agent name is sanitized so it is always a single, traversal-free path component; the
/// two ids are hashed (hex-encoded) so arbitrary id strings yield fixed-shape file names.
fn checkpoint_path(
    agent_name: &str,
    conversation_id: &str,
    correlation_id: &str,
) -> std::path::PathBuf {
    let conversation_hash = hash_str(conversation_id);
    let correlation_hash = hash_str(correlation_id);
    let file_name = format!("{:x}__{:x}.json", conversation_hash, correlation_hash);

    std::env::temp_dir()
        .join("clawgarden-agent-loop")
        .join(sanitize(agent_name))
        .join(file_name)
}
/// Replaces every character outside `[A-Za-z0-9_-]` with `_`.
///
/// Each input `char` maps to exactly one output `char`, so a multi-byte character becomes a
/// single underscore. Because `.` and path separators are replaced, the result can be used
/// as one path component.
fn sanitize(input: &str) -> String {
    let mut out = String::with_capacity(input.len());
    for ch in input.chars() {
        let allowed = matches!(ch, 'a'..='z' | 'A'..='Z' | '0'..='9' | '-' | '_');
        out.push(if allowed { ch } else { '_' });
    }
    out
}
/// Hashes `input` to a 64-bit value that stays the same across builds and Rust toolchains.
///
/// The result is embedded in on-disk checkpoint file names, so it must not drift between
/// releases. `DefaultHasher` explicitly makes no such guarantee; after a toolchain upgrade it
/// could silently orphan every existing checkpoint. This uses 64-bit FNV-1a instead.
/// `str`'s `Hash` impl feeds the UTF-8 bytes followed by a `0xff` terminator.
fn hash_str(input: &str) -> u64 {
    /// Minimal 64-bit FNV-1a hasher.
    struct Fnv1a(u64);

    impl Hasher for Fnv1a {
        fn finish(&self) -> u64 {
            self.0
        }

        fn write(&mut self, bytes: &[u8]) {
            for &byte in bytes {
                self.0 ^= u64::from(byte);
                self.0 = self.0.wrapping_mul(0x0000_0100_0000_01b3);
            }
        }
    }

    let mut hasher = Fnv1a(0xcbf2_9ce4_8422_2325);
    input.hash(&mut hasher);
    hasher.finish()
}
/// Persists `checkpoint` as pretty-printed JSON, replacing any previous checkpoint for the
/// same `(agent_name, conversation_id, correlation_id)`.
///
/// The JSON is first written to a uniquely named sibling temp file and then renamed over the
/// destination. Readers (and `gc_stale_checkpoints`) therefore never see a half-written
/// checkpoint at the final path, even if the process dies mid-write.
///
/// # Errors
/// Returns an error if the directory cannot be created, serialization fails, or the temp
/// file cannot be written or renamed into place. The temp file is removed on failure on a
/// best-effort basis.
pub async fn save_checkpoint(checkpoint: &LoopCheckpoint) -> anyhow::Result<()> {
    use std::sync::atomic::{AtomicU64, Ordering};

    /// Distinguishes concurrent saves within one process so they never share a temp file.
    static TMP_SEQ: AtomicU64 = AtomicU64::new(0);

    let path = checkpoint_path(
        &checkpoint.agent_name,
        &checkpoint.conversation_id,
        &checkpoint.correlation_id,
    );
    if let Some(parent) = path.parent() {
        tokio::fs::create_dir_all(parent).await?;
    }
    let json = serde_json::to_string_pretty(checkpoint)?;

    // Same directory as the target, so the rename stays on one filesystem.
    let mut tmp_name = path
        .file_name()
        .map(|name| name.to_os_string())
        .unwrap_or_default();
    tmp_name.push(format!(
        ".{}.{}.tmp",
        std::process::id(),
        TMP_SEQ.fetch_add(1, Ordering::Relaxed)
    ));
    let tmp_path = path.with_file_name(tmp_name);

    let published = match tokio::fs::write(&tmp_path, json).await {
        Ok(()) => tokio::fs::rename(&tmp_path, &path).await,
        Err(e) => Err(e),
    };
    if let Err(e) = published {
        // Best-effort cleanup; the original I/O error is what the caller needs to see.
        let _ = tokio::fs::remove_file(&tmp_path).await;
        return Err(e.into());
    }
    Ok(())
}
/// Loads the resumable checkpoint for `(agent_name, conversation_id, correlation_id)`, if any.
///
/// Returns `Ok(None)` in two groups of cases.
///
/// With deletion (the checkpoint can never be resumed):
/// - the JSON cannot be parsed (a truncated write, or an older schema that no longer
///   deserializes);
/// - its `version` differs from `CHECKPOINT_VERSION`;
/// - its `trace_id` differs from `expected_trace_id`;
/// - it is older than `CHECKPOINT_MAX_AGE_MS`.
///
/// Without deletion:
/// - no file exists;
/// - the stored identity does not match the request. Paths come from a sanitized agent name
///   and 64-bit hashes and can collide, so such a file belongs to another conversation and is
///   left untouched.
///
/// # Errors
/// Returns an error if the file exists but cannot be read (e.g. permissions, invalid UTF-8).
pub async fn load_checkpoint(
    agent_name: &str,
    conversation_id: &str,
    correlation_id: &str,
    expected_trace_id: &str,
) -> anyhow::Result<Option<LoopCheckpoint>> {
    let path = checkpoint_path(agent_name, conversation_id, correlation_id);

    // Read directly rather than probing with `exists()`. This avoids a blocking syscall on the
    // async runtime and the race where the file vanishes between the probe and the read.
    let content = match tokio::fs::read_to_string(&path).await {
        Ok(content) => content,
        Err(e) if e.kind() == std::io::ErrorKind::NotFound => return Ok(None),
        Err(e) => return Err(e.into()),
    };

    let checkpoint: LoopCheckpoint = match serde_json::from_str(&content) {
        Ok(checkpoint) => checkpoint,
        Err(_) => {
            // A schema change usually breaks deserialization before `version` can be checked,
            // so an unparseable file is handled exactly like a version mismatch.
            let _ = tokio::fs::remove_file(&path).await;
            return Ok(None);
        }
    };

    if checkpoint.version != CHECKPOINT_VERSION {
        let _ = tokio::fs::remove_file(&path).await;
        return Ok(None);
    }

    // Never hand back (or delete) another conversation's state that merely shares this path.
    if checkpoint.agent_name != agent_name
        || checkpoint.conversation_id != conversation_id
        || checkpoint.correlation_id != correlation_id
    {
        return Ok(None);
    }

    // Saturating: a corrupt/extreme timestamp must not overflow the subtraction.
    let age_ms = chrono::Utc::now()
        .timestamp_millis()
        .saturating_sub(checkpoint.updated_at_unix_ms);
    if checkpoint.trace_id != expected_trace_id || age_ms > CHECKPOINT_MAX_AGE_MS {
        let _ = tokio::fs::remove_file(&path).await;
        return Ok(None);
    }

    Ok(Some(checkpoint))
}
/// Deletes the checkpoint for `(agent_name, conversation_id, correlation_id)`, if present.
///
/// A missing file is not an error. This includes a file removed concurrently by another task
/// or by `gc_stale_checkpoints`.
///
/// # Errors
/// Returns any I/O error from the deletion other than "not found".
pub async fn clear_checkpoint(
    agent_name: &str,
    conversation_id: &str,
    correlation_id: &str,
) -> anyhow::Result<()> {
    let path = checkpoint_path(agent_name, conversation_id, correlation_id);
    // Remove directly instead of `exists()` + remove: no blocking probe on the async runtime
    // and no window in which a concurrent delete turns into a spurious NotFound error.
    match tokio::fs::remove_file(&path).await {
        Err(e) if e.kind() != std::io::ErrorKind::NotFound => Err(e.into()),
        _ => Ok(()),
    }
}
/// Sweeps every agent directory under the checkpoint root and deletes files that can no
/// longer be resumed.
///
/// A file is deleted when any of these holds:
/// - it cannot be read as text;
/// - it does not parse as a [`LoopCheckpoint`];
/// - it carries a version other than [`CHECKPOINT_VERSION`];
/// - it is older than `CHECKPOINT_MAX_AGE_MS`.
///
/// Returns the number of files actually deleted. A failed delete, or a file that vanished
/// concurrently, is not counted.
///
/// # Errors
/// Returns an error if the checkpoint root exists but cannot be listed, or if listing an agent
/// directory fails.
pub async fn gc_stale_checkpoints() -> anyhow::Result<usize> {
    let mut base = std::env::temp_dir();
    base.push("clawgarden-agent-loop");

    let mut agents = match tokio::fs::read_dir(&base).await {
        Ok(entries) => entries,
        // Nothing has ever been checkpointed on this machine.
        Err(e) if e.kind() == std::io::ErrorKind::NotFound => return Ok(0),
        Err(e) => return Err(e.into()),
    };

    // One timestamp for the whole sweep keeps the staleness cut-off consistent.
    let now_ms = chrono::Utc::now().timestamp_millis();
    let mut removed = 0usize;

    while let Some(agent_entry) = agents.next_entry().await? {
        let agent_dir = agent_entry.path();
        // Async metadata lookup; like `Path::is_dir`, follows symlinks and treats errors as "no".
        let is_dir = tokio::fs::metadata(&agent_dir)
            .await
            .map(|meta| meta.is_dir())
            .unwrap_or(false);
        if !is_dir {
            continue;
        }

        let mut files = tokio::fs::read_dir(&agent_dir).await?;
        while let Some(file_entry) = files.next_entry().await? {
            let file_path = file_entry.path();
            let is_file = tokio::fs::metadata(&file_path)
                .await
                .map(|meta| meta.is_file())
                .unwrap_or(false);
            if !is_file {
                continue;
            }
            if gc_should_remove(&file_path, now_ms).await
                && tokio::fs::remove_file(&file_path).await.is_ok()
            {
                removed += 1;
            }
        }
    }

    Ok(removed)
}

/// Decides whether the checkpoint file at `path` is garbage as of `now_ms` (Unix millis).
async fn gc_should_remove(path: &std::path::Path, now_ms: i64) -> bool {
    let content = match tokio::fs::read_to_string(path).await {
        Ok(content) => content,
        // Already gone (e.g. cleared concurrently): nothing to collect.
        Err(e) if e.kind() == std::io::ErrorKind::NotFound => return false,
        Err(_) => return true,
    };
    match serde_json::from_str::<LoopCheckpoint>(&content) {
        Ok(cp) => {
            let age_ms = now_ms.saturating_sub(cp.updated_at_unix_ms);
            cp.version != CHECKPOINT_VERSION || age_ms > CHECKPOINT_MAX_AGE_MS
        }
        Err(_) => true,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent_loop::{ChatMessage, LoopState};

    /// Current wall-clock time in Unix milliseconds.
    fn now_ms() -> i64 {
        chrono::Utc::now().timestamp_millis()
    }

    /// A timestamp just past the retention window, i.e. one that must count as stale.
    fn expired_ms() -> i64 {
        now_ms() - (CHECKPOINT_MAX_AGE_MS + 1000)
    }

    /// Minimal `Reasoning`-state checkpoint with random conversation/correlation ids, so tests
    /// running concurrently never share a file.
    fn sample(agent: &str, trace_id: &str, updated_at_unix_ms: i64) -> LoopCheckpoint {
        LoopCheckpoint {
            version: CHECKPOINT_VERSION,
            agent_name: agent.to_string(),
            conversation_id: format!("conv_{}", uuid::Uuid::new_v4()),
            correlation_id: format!("corr_{}", uuid::Uuid::new_v4()),
            trace_id: trace_id.to_string(),
            step: 1,
            state: LoopState::Reasoning,
            messages: vec![ChatMessage::user("hello")],
            tool_call_log: vec![],
            last_action_fingerprint: None,
            same_action_repeat_count: 0,
            consecutive_error_count: 0,
            progress_window: vec![],
            updated_at_unix_ms,
        }
    }

    /// Loads `cp` back by its own identity, expecting `trace_id`.
    async fn reload(cp: &LoopCheckpoint, trace_id: &str) -> Option<LoopCheckpoint> {
        load_checkpoint(&cp.agent_name, &cp.conversation_id, &cp.correlation_id, trace_id)
            .await
            .unwrap()
    }

    #[tokio::test]
    async fn test_checkpoint_roundtrip_and_clear() {
        let cp = LoopCheckpoint {
            step: 3,
            tool_call_log: vec!["exec:ls".to_string()],
            last_action_fingerprint: Some("abc".to_string()),
            same_action_repeat_count: 1,
            progress_window: vec![1, -1],
            ..sample("test_agent_cp", "trace-1", now_ms())
        };
        save_checkpoint(&cp).await.unwrap();

        let loaded = reload(&cp, "trace-1")
            .await
            .expect("fresh checkpoint should load");
        assert_eq!(loaded.step, 3);
        assert_eq!(loaded.state, LoopState::Reasoning);

        clear_checkpoint(&cp.agent_name, &cp.conversation_id, &cp.correlation_id)
            .await
            .unwrap();
        assert!(reload(&cp, "trace-1").await.is_none());
    }

    #[tokio::test]
    async fn test_trace_mismatch_checkpoint_is_ignored() {
        let cp = sample("test_agent_cp_trace_mismatch", "trace-saved", now_ms());
        save_checkpoint(&cp).await.unwrap();

        assert!(reload(&cp, "trace-current").await.is_none());
    }

    #[tokio::test]
    async fn test_stale_checkpoint_is_ignored() {
        let cp = sample("test_agent_cp_stale", "trace-2", expired_ms());
        save_checkpoint(&cp).await.unwrap();

        assert!(reload(&cp, "trace-2").await.is_none());
    }

    #[tokio::test]
    async fn test_gc_stale_checkpoints_removes_old_files() {
        let cp = sample("test_agent_cp_gc", "trace-gc", expired_ms());
        save_checkpoint(&cp).await.unwrap();

        let removed = gc_stale_checkpoints().await.unwrap();
        assert!(removed >= 1);
        assert!(reload(&cp, "trace-gc").await.is_none());
    }
}