enact_core/graph/
checkpoint.rs1use crate::kernel::{GraphId, RunId};
4use serde::{Deserialize, Serialize};
5use serde_json::Value;
6use std::collections::HashMap;
7use svix_ksuid::KsuidLike;
8
9#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct Checkpoint {
12 pub id: String,
14 pub run_id: RunId,
16 pub graph_id: Option<GraphId>,
18 pub current_node: Option<String>,
20 pub state: Value,
22 pub messages: Vec<MessageRecord>,
24 pub tool_results: HashMap<String, Value>,
26 pub created_at: chrono::DateTime<chrono::Utc>,
28 pub metadata: HashMap<String, Value>,
30}
31
32#[derive(Debug, Clone, Serialize, Deserialize)]
34pub struct MessageRecord {
35 pub role: String,
36 pub content: String,
37}
38
39impl Checkpoint {
40 pub fn new(run_id: RunId) -> Self {
42 Self {
43 id: format!("ckpt_{}", svix_ksuid::Ksuid::new(None, None)),
44 run_id,
45 graph_id: None,
46 current_node: None,
47 state: Value::Null,
48 messages: Vec::new(),
49 tool_results: HashMap::new(),
50 created_at: chrono::Utc::now(),
51 metadata: HashMap::new(),
52 }
53 }
54
55 pub fn with_state(mut self, state: Value) -> Self {
57 self.state = state;
58 self
59 }
60
61 pub fn with_node(mut self, node: impl Into<String>) -> Self {
63 self.current_node = Some(node.into());
64 self
65 }
66
67 pub fn add_message(&mut self, role: impl Into<String>, content: impl Into<String>) {
69 self.messages.push(MessageRecord {
70 role: role.into(),
71 content: content.into(),
72 });
73 }
74
75 pub fn add_tool_result(&mut self, tool_name: impl Into<String>, result: Value) {
77 self.tool_results.insert(tool_name.into(), result);
78 }
79
80 pub fn with_agent_name(mut self, name: impl Into<String>) -> Self {
82 self.metadata
83 .insert("agent_name".to_string(), Value::String(name.into()));
84 self
85 }
86
87 pub fn agent_name(&self) -> Option<&str> {
89 self.metadata.get("agent_name").and_then(|v| v.as_str())
90 }
91}
92
#[cfg(test)]
mod tests {
    use super::*;
    use crate::kernel::ExecutionId;

    /// Builds an empty checkpoint bound to a freshly generated execution id.
    fn blank_checkpoint() -> Checkpoint {
        Checkpoint::new(ExecutionId::new())
    }

    /// A new checkpoint carries the prefixed id and the given run id, and is otherwise empty.
    #[test]
    fn test_checkpoint_new() {
        let expected_run = ExecutionId::new();
        let fresh = Checkpoint::new(expected_run.clone());

        assert!(fresh.id.starts_with("ckpt_"));
        assert_eq!(fresh.run_id.as_str(), expected_run.as_str());
        assert!(fresh.graph_id.is_none());
        assert!(fresh.current_node.is_none());
        assert_eq!(fresh.state, Value::Null);
        assert!(fresh.messages.is_empty());
        assert!(fresh.tool_results.is_empty());
        assert!(fresh.metadata.is_empty());
    }

    /// The agent name set through the builder is readable through the getter.
    #[test]
    fn test_checkpoint_with_agent_name() {
        let named = blank_checkpoint().with_agent_name("my_agent");

        assert_eq!(named.agent_name(), Some("my_agent"));
    }

    /// Without `with_agent_name` the getter reports nothing.
    #[test]
    fn test_checkpoint_agent_name_none_when_not_set() {
        let unnamed = blank_checkpoint();

        assert!(unnamed.agent_name().is_none());
    }

    /// Builder methods can be chained and each one leaves its mark.
    #[test]
    fn test_checkpoint_builder_chain() {
        let wanted_state = Value::String("test_state".to_string());
        let built = blank_checkpoint()
            .with_state(wanted_state.clone())
            .with_node("node_1")
            .with_agent_name("coder");

        assert_eq!(built.state, wanted_state);
        assert_eq!(built.current_node, Some("node_1".to_string()));
        assert_eq!(built.agent_name(), Some("coder"));
    }

    /// The agent name survives a JSON round trip.
    #[test]
    fn test_checkpoint_agent_name_serialization() {
        let original = blank_checkpoint().with_agent_name("assistant");

        let encoded = serde_json::to_string(&original).unwrap();
        let restored: Checkpoint = serde_json::from_str(&encoded).unwrap();

        assert_eq!(restored.agent_name(), Some("assistant"));
    }

    /// Messages are appended in call order with their role and content intact.
    #[test]
    fn test_checkpoint_add_message() {
        let expected = [("user", "Hello"), ("assistant", "Hi there!")];
        let mut checkpoint = blank_checkpoint();

        for &(role, content) in expected.iter() {
            checkpoint.add_message(role, content);
        }

        assert_eq!(checkpoint.messages.len(), 2);
        for (record, &(role, content)) in checkpoint.messages.iter().zip(expected.iter()) {
            assert_eq!(record.role, role);
            assert_eq!(record.content, content);
        }
    }

    /// A tool result is stored under the tool's name.
    #[test]
    fn test_checkpoint_add_tool_result() {
        let payload = Value::String("file contents".to_string());
        let mut checkpoint = blank_checkpoint();

        checkpoint.add_tool_result("read_file", payload.clone());

        assert_eq!(checkpoint.tool_results.len(), 1);
        assert_eq!(checkpoint.tool_results.get("read_file"), Some(&payload));
    }
}