mod memory;
mod runner;
#[cfg(feature = "persistence")]
mod surreal;
pub use memory::MemoryCheckpointer;
pub use runner::{CheckpointingRunner, RunResult};
#[cfg(feature = "persistence")]
pub use surreal::SurrealCheckpointer;
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use uuid::Uuid;
use crate::error::RuntimeError;
use crate::state::AgentState;
/// A serializable snapshot of an [`AgentState`] taken at a particular node of a thread.
///
/// Instances are normally built with [`Checkpoint::new`] and then refined with the
/// `with_*` builder methods.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Checkpoint {
    /// Identifier of this checkpoint. [`Checkpoint::new`] fills it with a random
    /// v4 UUID rendered as a string.
    pub id: String,

    /// Identifier of the thread this checkpoint was recorded for.
    pub thread_id: String,

    /// Identifier of the node this checkpoint was recorded at.
    pub node_id: String,

    /// The agent state captured by this checkpoint.
    pub state: AgentState,

    /// Creation timestamp in UTC. [`Checkpoint::new`] sets it to the current time.
    pub created_at: DateTime<Utc>,

    /// Identifier of the parent checkpoint, if any. `None` unless set through
    /// [`Checkpoint::with_parent`] (or present in deserialized data).
    pub parent_id: Option<String>,

    /// Free-form JSON attached to the checkpoint. Because of `#[serde(default)]`,
    /// input that omits this field deserializes to JSON `null` instead of failing.
    #[serde(default)]
    pub metadata: serde_json::Value,
}
impl Checkpoint {
    /// Creates a checkpoint for `thread_id` at `node_id` holding `state`.
    ///
    /// The new checkpoint gets a freshly generated v4 UUID as its `id`, the current
    /// UTC time as `created_at`, no parent, and JSON `null` metadata.
    pub fn new(
        thread_id: impl Into<String>,
        node_id: impl Into<String>,
        state: AgentState,
    ) -> Self {
        // Evaluated in the same order as the struct fields are declared.
        let id = Uuid::new_v4().to_string();
        let thread_id = thread_id.into();
        let node_id = node_id.into();
        let created_at = Utc::now();

        Self {
            id,
            thread_id,
            node_id,
            state,
            created_at,
            parent_id: None,
            metadata: serde_json::Value::Null,
        }
    }

    /// Returns this checkpoint with `parent_id` recorded as its parent,
    /// replacing any parent that was set before.
    pub fn with_parent(self, parent_id: impl Into<String>) -> Self {
        Self {
            parent_id: Some(parent_id.into()),
            ..self
        }
    }

    /// Returns this checkpoint with its metadata replaced by `metadata`.
    pub fn with_metadata(self, metadata: serde_json::Value) -> Self {
        Self { metadata, ..self }
    }
}
/// Options controlling checkpointing behavior.
///
/// The derived `Default` yields every flag `false` and no per-thread limit.
/// NOTE(review): how each option is interpreted is decided by the code that
/// consumes this config (not visible in this file) — confirm there.
#[derive(Debug, Clone, Default)]
pub struct CheckpointConfig {
    /// Presumably requests a checkpoint after every node; see
    /// [`CheckpointConfig::every_node`].
    pub checkpoint_every_node: bool,

    /// Optional upper bound on checkpoints kept per thread; `None` means no
    /// bound was configured.
    pub max_checkpoints_per_thread: Option<usize>,

    /// Presumably allows branching from earlier checkpoints; see
    /// [`CheckpointConfig::with_branching`].
    pub enable_branching: bool,
}
impl CheckpointConfig {
    /// Returns a configuration with `checkpoint_every_node` switched on and
    /// every other option at its default value.
    pub fn every_node() -> Self {
        Self {
            checkpoint_every_node: true,
            ..Self::default()
        }
    }

    /// Returns this configuration with the per-thread checkpoint limit set to `max`.
    pub fn max_per_thread(self, max: usize) -> Self {
        Self {
            max_checkpoints_per_thread: Some(max),
            ..self
        }
    }

    /// Returns this configuration with `enable_branching` switched on.
    pub fn with_branching(self) -> Self {
        Self {
            enable_branching: true,
            ..self
        }
    }
}
/// Storage backend for [`Checkpoint`]s.
///
/// Implementors must be `Send + Sync`. Every operation is asynchronous and
/// reports failures as [`RuntimeError`].
#[async_trait]
pub trait Checkpointer: Send + Sync {
    /// Stores `checkpoint`, taking ownership of it.
    async fn save(&self, checkpoint: Checkpoint) -> Result<(), RuntimeError>;

    /// Looks up a checkpoint for `thread_id`, yielding `Ok(None)` when there is none.
    ///
    /// NOTE(review): presumably this is the thread's most recent checkpoint —
    /// confirm against the implementations.
    async fn load(&self, thread_id: &str) -> Result<Option<Checkpoint>, RuntimeError>;

    /// Looks up the checkpoint whose id is `checkpoint_id`, yielding `Ok(None)`
    /// when there is none.
    async fn load_by_id(&self, checkpoint_id: &str) -> Result<Option<Checkpoint>, RuntimeError>;

    /// Returns the checkpoints stored for `thread_id`.
    ///
    /// The ordering of the returned vector is chosen by the implementation.
    async fn list(&self, thread_id: &str) -> Result<Vec<Checkpoint>, RuntimeError>;

    /// Deletes the checkpoint whose id is `checkpoint_id`.
    async fn delete(&self, checkpoint_id: &str) -> Result<(), RuntimeError>;

    /// Deletes the checkpoints belonging to `thread_id`.
    async fn delete_thread(&self, thread_id: &str) -> Result<(), RuntimeError>;

    /// Returns at most `limit` checkpoints for `thread_id`.
    ///
    /// The default implementation fetches everything through [`Checkpointer::list`]
    /// and keeps the first `limit` entries, so which checkpoints survive depends on
    /// the order `list` returns them in. Errors from `list` are passed through.
    async fn history(
        &self,
        thread_id: &str,
        limit: usize,
    ) -> Result<Vec<Checkpoint>, RuntimeError> {
        self.list(thread_id).await.map(|mut listed| {
            // Drop everything past the first `limit` entries, in place.
            listed.truncate(limit);
            listed
        })
    }
}
/// Something that can produce an [`AgentState`] to continue from, addressed
/// either by thread or by a specific checkpoint.
///
/// Both operations are asynchronous and report failures as [`RuntimeError`].
#[async_trait]
pub trait Resumable {
    /// Produces the state to continue from for `thread_id`.
    ///
    /// NOTE(review): presumably restored from that thread's latest checkpoint —
    /// confirm against the implementors.
    async fn resume(&self, thread_id: &str) -> Result<AgentState, RuntimeError>;

    /// Produces the state to continue from for the checkpoint identified by
    /// `checkpoint_id`.
    async fn resume_from(&self, checkpoint_id: &str) -> Result<AgentState, RuntimeError>;
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A new checkpoint has a generated id, the given thread/node ids, and no parent.
    #[test]
    fn test_checkpoint_creation() {
        let cp = Checkpoint::new("thread-1", "node-a", AgentState::new());

        assert!(!cp.id.is_empty());
        assert_eq!(cp.thread_id, "thread-1");
        assert_eq!(cp.node_id, "node-a");
        assert!(cp.parent_id.is_none());
    }

    /// `with_parent` records the supplied parent id.
    #[test]
    fn test_checkpoint_with_parent() {
        let cp = Checkpoint::new("thread-1", "node-a", AgentState::new())
            .with_parent("parent-123");

        assert_eq!(cp.parent_id.as_deref(), Some("parent-123"));
    }

    /// Chained config builders each set their own option.
    #[test]
    fn test_checkpoint_config() {
        let CheckpointConfig {
            checkpoint_every_node,
            max_checkpoints_per_thread,
            enable_branching,
        } = CheckpointConfig::every_node()
            .max_per_thread(10)
            .with_branching();

        assert!(checkpoint_every_node);
        assert_eq!(max_checkpoints_per_thread, Some(10));
        assert!(enable_branching);
    }
}