Skip to main content

enact_core/graph/
checkpoint.rs

1//! Checkpoint types for save/resume execution
2
3use crate::kernel::{GraphId, RunId};
4use serde::{Deserialize, Serialize};
5use serde_json::Value;
6use std::collections::HashMap;
7use svix_ksuid::KsuidLike;
8
9/// Checkpoint - saved state of execution
10#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct Checkpoint {
12    /// Unique checkpoint ID
13    pub id: String,
14    /// Run this checkpoint belongs to
15    pub run_id: RunId,
16    /// Graph being executed
17    pub graph_id: Option<GraphId>,
18    /// Current node in execution
19    pub current_node: Option<String>,
20    /// State at this checkpoint
21    pub state: Value,
22    /// Messages history (for LLM agents)
23    pub messages: Vec<MessageRecord>,
24    /// Tool results collected so far
25    pub tool_results: HashMap<String, Value>,
26    /// Created timestamp
27    pub created_at: chrono::DateTime<chrono::Utc>,
28    /// Metadata
29    pub metadata: HashMap<String, Value>,
30}
31
32/// Message record for checkpoint
33#[derive(Debug, Clone, Serialize, Deserialize)]
34pub struct MessageRecord {
35    pub role: String,
36    pub content: String,
37}
38
39impl Checkpoint {
40    /// Create a new checkpoint
41    pub fn new(run_id: RunId) -> Self {
42        Self {
43            id: format!("ckpt_{}", svix_ksuid::Ksuid::new(None, None)),
44            run_id,
45            graph_id: None,
46            current_node: None,
47            state: Value::Null,
48            messages: Vec::new(),
49            tool_results: HashMap::new(),
50            created_at: chrono::Utc::now(),
51            metadata: HashMap::new(),
52        }
53    }
54
55    /// Set the current state
56    pub fn with_state(mut self, state: Value) -> Self {
57        self.state = state;
58        self
59    }
60
61    /// Set the current node
62    pub fn with_node(mut self, node: impl Into<String>) -> Self {
63        self.current_node = Some(node.into());
64        self
65    }
66
67    /// Add a message to history
68    pub fn add_message(&mut self, role: impl Into<String>, content: impl Into<String>) {
69        self.messages.push(MessageRecord {
70            role: role.into(),
71            content: content.into(),
72        });
73    }
74
75    /// Add a tool result
76    pub fn add_tool_result(&mut self, tool_name: impl Into<String>, result: Value) {
77        self.tool_results.insert(tool_name.into(), result);
78    }
79
80    /// Set the agent name in metadata
81    pub fn with_agent_name(mut self, name: impl Into<String>) -> Self {
82        self.metadata
83            .insert("agent_name".to_string(), Value::String(name.into()));
84        self
85    }
86
87    /// Get the agent name from metadata
88    pub fn agent_name(&self) -> Option<&str> {
89        self.metadata.get("agent_name").and_then(|v| v.as_str())
90    }
91}
92
#[cfg(test)]
mod tests {
    use super::*;
    use crate::kernel::ExecutionId;

    /// Fresh checkpoint bound to a newly generated run ID.
    fn blank() -> Checkpoint {
        Checkpoint::new(ExecutionId::new())
    }

    #[test]
    fn test_checkpoint_new() {
        let rid = ExecutionId::new();
        let ckpt = Checkpoint::new(rid.clone());

        // Identity fields come from the constructor arguments / generator.
        assert!(ckpt.id.starts_with("ckpt_"));
        assert_eq!(ckpt.run_id.as_str(), rid.as_str());

        // Everything else starts out empty.
        assert!(ckpt.graph_id.is_none());
        assert!(ckpt.current_node.is_none());
        assert_eq!(ckpt.state, Value::Null);
        assert!(ckpt.messages.is_empty());
        assert!(ckpt.tool_results.is_empty());
        assert!(ckpt.metadata.is_empty());
    }

    #[test]
    fn test_checkpoint_with_agent_name() {
        let ckpt = blank().with_agent_name("my_agent");
        assert_eq!(ckpt.agent_name(), Some("my_agent"));
    }

    #[test]
    fn test_checkpoint_agent_name_none_when_not_set() {
        let ckpt = blank();
        assert!(ckpt.agent_name().is_none());
    }

    #[test]
    fn test_checkpoint_builder_chain() {
        let ckpt = blank()
            .with_state(Value::String("test_state".to_string()))
            .with_node("node_1")
            .with_agent_name("coder");

        assert_eq!(ckpt.state, Value::String("test_state".to_string()));
        assert_eq!(ckpt.current_node.as_deref(), Some("node_1"));
        assert_eq!(ckpt.agent_name(), Some("coder"));
    }

    #[test]
    fn test_checkpoint_agent_name_serialization() {
        let original = blank().with_agent_name("assistant");

        // Round-trip through a JSON string.
        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: Checkpoint = serde_json::from_str(&encoded).unwrap();

        assert_eq!(decoded.agent_name(), Some("assistant"));
    }

    #[test]
    fn test_checkpoint_add_message() {
        let mut ckpt = blank();
        ckpt.add_message("user", "Hello");
        ckpt.add_message("assistant", "Hi there!");

        // Compare the whole history at once: length, order and contents.
        let history: Vec<(&str, &str)> = ckpt
            .messages
            .iter()
            .map(|m| (m.role.as_str(), m.content.as_str()))
            .collect();
        assert_eq!(history, [("user", "Hello"), ("assistant", "Hi there!")]);
    }

    #[test]
    fn test_checkpoint_add_tool_result() {
        let mut ckpt = blank();
        let expected = Value::String("file contents".to_string());

        ckpt.add_tool_result("read_file", expected.clone());

        assert_eq!(ckpt.tool_results.len(), 1);
        assert_eq!(ckpt.tool_results.get("read_file"), Some(&expected));
    }
}