use crate::kernel::{GraphId, RunId};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
use svix_ksuid::KsuidLike;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Checkpoint {
pub id: String,
pub run_id: RunId,
pub graph_id: Option<GraphId>,
pub current_node: Option<String>,
pub state: Value,
pub messages: Vec<MessageRecord>,
pub tool_results: HashMap<String, Value>,
pub created_at: chrono::DateTime<chrono::Utc>,
pub metadata: HashMap<String, Value>,
}
/// A single message recorded in a [`Checkpoint`].
///
/// Records compare equal when both `role` and `content` match.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct MessageRecord {
    /// Free-form role label such as `"user"` or `"assistant"`; not validated.
    pub role: String,
    /// The message text.
    pub content: String,
}
impl Checkpoint {
pub fn new(run_id: RunId) -> Self {
Self {
id: format!("ckpt_{}", svix_ksuid::Ksuid::new(None, None)),
run_id,
graph_id: None,
current_node: None,
state: Value::Null,
messages: Vec::new(),
tool_results: HashMap::new(),
created_at: chrono::Utc::now(),
metadata: HashMap::new(),
}
}
pub fn with_state(mut self, state: Value) -> Self {
self.state = state;
self
}
pub fn with_node(mut self, node: impl Into<String>) -> Self {
self.current_node = Some(node.into());
self
}
pub fn add_message(&mut self, role: impl Into<String>, content: impl Into<String>) {
self.messages.push(MessageRecord {
role: role.into(),
content: content.into(),
});
}
pub fn add_tool_result(&mut self, tool_name: impl Into<String>, result: Value) {
self.tool_results.insert(tool_name.into(), result);
}
pub fn with_agent_name(mut self, name: impl Into<String>) -> Self {
self.metadata
.insert("agent_name".to_string(), Value::String(name.into()));
self
}
pub fn agent_name(&self) -> Option<&str> {
self.metadata.get("agent_name").and_then(|v| v.as_str())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::kernel::ExecutionId;

    /// Fresh checkpoint for a newly generated execution id.
    fn blank() -> Checkpoint {
        Checkpoint::new(ExecutionId::new())
    }

    #[test]
    fn test_checkpoint_new() {
        let expected_run = ExecutionId::new();
        let ckpt = Checkpoint::new(expected_run.clone());

        // Identity fields.
        assert!(ckpt.id.starts_with("ckpt_"));
        assert_eq!(ckpt.run_id.as_str(), expected_run.as_str());

        // Optional fields start unset.
        assert!(ckpt.graph_id.is_none());
        assert!(ckpt.current_node.is_none());
        assert!(ckpt.state.is_null());

        // Collections start empty.
        assert!(ckpt.messages.is_empty());
        assert!(ckpt.tool_results.is_empty());
        assert!(ckpt.metadata.is_empty());
    }

    #[test]
    fn test_checkpoint_with_agent_name() {
        let ckpt = blank().with_agent_name("my_agent");
        assert_eq!(ckpt.agent_name(), Some("my_agent"));
    }

    #[test]
    fn test_checkpoint_agent_name_none_when_not_set() {
        let ckpt = blank();
        assert_eq!(ckpt.agent_name(), None);
    }

    #[test]
    fn test_checkpoint_builder_chain() {
        let ckpt = blank()
            .with_state(Value::String("test_state".to_string()))
            .with_node("node_1")
            .with_agent_name("coder");

        assert_eq!(ckpt.state.as_str(), Some("test_state"));
        assert_eq!(ckpt.current_node.as_deref(), Some("node_1"));
        assert_eq!(ckpt.agent_name(), Some("coder"));
    }

    #[test]
    fn test_checkpoint_agent_name_serialization() {
        let original = blank().with_agent_name("assistant");
        let encoded = serde_json::to_string(&original).expect("checkpoint should serialize");
        let decoded: Checkpoint =
            serde_json::from_str(&encoded).expect("checkpoint should deserialize");
        assert_eq!(decoded.agent_name(), Some("assistant"));
    }

    #[test]
    fn test_checkpoint_add_message() {
        let mut ckpt = blank();
        ckpt.add_message("user", "Hello");
        ckpt.add_message("assistant", "Hi there!");

        let transcript: Vec<(&str, &str)> = ckpt
            .messages
            .iter()
            .map(|m| (m.role.as_str(), m.content.as_str()))
            .collect();
        assert_eq!(transcript, [("user", "Hello"), ("assistant", "Hi there!")]);
    }

    #[test]
    fn test_checkpoint_add_tool_result() {
        let mut ckpt = blank();
        ckpt.add_tool_result("read_file", Value::String("file contents".to_string()));

        assert_eq!(ckpt.tool_results.len(), 1);
        let stored = ckpt.tool_results.get("read_file").and_then(Value::as_str);
        assert_eq!(stored, Some("file contents"));
    }
}