Skip to main content

lellm_graph/checkpoint/
checkpoint.rs

1//! Checkpoint — 执行恢复的唯一数据源。
2//!
3//! 分层架构:
4//! ```text
5//! ExecutionEngine (Trigger)
6//!   ↓ on_checkpoint(&state, &frame_info)
7//! CheckpointSink (SPI — 策略层)
8//!   ↓ 自行决定
9//! MemorySink → FrameStack (内存)
10//! DiskSink   → 每 N 步 snapshot → 磁盘
11//! NetworkSink → protobuf → remote
12//! ```
13//!
14//! # Trigger / Storage 分离
15//!
16//! **ExecutionEngine** 负责定义一致的 checkpoint 语义——什么时候产生一个恢复点。
17//! 唯一的位置:`execute() → commit() → checkpoint() → route()`。
18//!
19//! **CheckpointSink** 负责决定是否真的保存、保存到哪里、保存多少。
20//! Engine 只管借用 `&dyn CheckpointSink<S>`,不知道 FrameStack、磁盘、网络。
21//!
22//! # Phase 6: Execution Frame Snapshot
23//!
24//! 核心洞察:checkpoint 不是保存 state,而是保存 execution position + state projection。
25//!
26//! ```text
27//! checkpoint 的边界单位是 Graph Execution Frame,不是 WorkflowState 或 Node。
28//!
29//! 正确模型:
30//!   Graph Execution = Frame Stack
31//!
32//! Frame = {
33//!     graph_id,
34//!     node_id,
35//!     state_snapshot,
36//!     cursor,
37//! }
38//!
39//! checkpoint = FrameStack snapshot
40//! ```
41
42use std::fmt::Debug;
43
44use serde::{Deserialize, Serialize};
45
46use crate::state::State;
47use crate::state::workflow_state::WorkflowState;
48
49// ─── CheckpointId ──────────────────────────────────────────────
50
51/// Checkpoint 唯一标识。
52#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
53pub struct CheckpointId(pub uuid::Uuid);
54
55impl std::fmt::Display for CheckpointId {
56    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
57        write!(f, "{}", self.0)
58    }
59}
60
61// ─── NodeId ────────────────────────────────────────────────────
62
63/// 节点标识。
64#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
65pub struct NodeId(pub String);
66
67impl std::fmt::Display for NodeId {
68    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
69        write!(f, "{}", self.0)
70    }
71}
72
73// ─── Checkpoint ────────────────────────────────────────────────
74
75/// 执行检查点 — 物化快照 + 执行游标。
76///
77/// Checkpoint 的唯一职责:恢复(Restore)。
78/// 给我一个 Checkpoint,我就能从 `current_node` 开始,用 `state` 继续执行。
79///
80/// # P0-1: Checkpoint Projection
81///
82/// `state` 字段使用 `S::Checkpoint`(关联类型),不是 `S`(Runtime State)。
83/// 这保证:
84/// - Runtime State 可以包含不可序列化字段(`Arc<dyn ...>`, `Sender`, `Cache`)
85/// - Checkpoint 只序列化必要字段
86/// - 编译期保证可序列化
87///
88/// # Graph Compatibility
89///
90/// `graph_hash` 记录创建 Checkpoint 时的图结构指纹。
91/// 恢复时必须校验:`graph_hash` 不匹配 → 拒绝恢复(不允许 silent mismatch)。
92#[derive(Debug, Clone, Serialize, Deserialize)]
93pub struct Checkpoint<S: WorkflowState = State> {
94    /// 唯一标识
95    pub checkpoint_id: CheckpointId,
96    /// 下一个要执行的节点
97    pub current_node: NodeId,
98    /// 物化状态快照(P0-1: 使用 Checkpoint 关联类型,不是 raw State)
99    pub state: S::Checkpoint,
100    /// 图结构指纹 — 恢复时校验兼容性
101    pub graph_hash: u64,
102    /// 创建时间
103    pub created_at: std::time::SystemTime,
104}
105
106impl<S: WorkflowState> Checkpoint<S> {
107    /// 从 Runtime State 创建 Checkpoint(使用 snapshot() 投影)。
108    pub fn new(current_node: impl Into<String>, state: &S, graph_hash: u64) -> Self {
109        Self {
110            checkpoint_id: CheckpointId(uuid::Uuid::new_v4()),
111            current_node: NodeId(current_node.into()),
112            state: state.snapshot(),
113            graph_hash,
114            created_at: std::time::SystemTime::now(),
115        }
116    }
117
118    /// 从 Checkpoint 恢复 Runtime State(使用 restore())。
119    pub fn restore_state(self) -> S {
120        S::restore(self.state)
121    }
122}
123
124// ─── CheckpointBlob ────────────────────────────────────────────
125
126/// 跨 Codec 的统一载体 — 存储层操作的对象。
127///
128/// 将序列化后的二进制数据与元数据打包,供 CheckpointStore 使用。
129/// 存储层无需知道 State 类型或序列化格式。
130///
131/// `graph_hash` 作为 correctness invariant 存储:
132/// 恢复时校验 `graph_hash` 不匹配 → reject,不允许 silent mismatch。
133#[derive(Debug, Clone)]
134pub struct CheckpointBlob {
135    /// Checkpoint 唯一标识
136    pub id: CheckpointId,
137    /// 序列化后的二进制数据(格式由 Codec 决定)
138    pub data: Vec<u8>,
139    /// 图结构指纹 — 恢复时校验兼容性
140    pub graph_hash: u64,
141    /// 创建时间
142    pub created_at: std::time::SystemTime,
143}
144
145impl CheckpointBlob {
146    pub fn new(
147        id: CheckpointId,
148        data: Vec<u8>,
149        graph_hash: u64,
150        created_at: std::time::SystemTime,
151    ) -> Self {
152        Self {
153            id,
154            data,
155            graph_hash,
156            created_at,
157        }
158    }
159}
160
161// ─── CheckpointStoreError ──────────────────────────────────────
162
163/// Checkpoint 存储操作错误。
164#[derive(Debug, thiserror::Error)]
165pub enum CheckpointStoreError {
166    #[error("storage error: {0}")]
167    Storage(String),
168    #[error("checkpoint not found: {0}")]
169    NotFound(CheckpointId),
170    #[error("corrupted checkpoint: {0}")]
171    Corrupted(String),
172    #[error("serialization error: {0}")]
173    Serialization(String),
174    #[error("graph mismatch: expected hash {expected:#018x}, got {actual:#018x}")]
175    GraphMismatch { expected: u64, actual: u64 },
176}
177
178// ─── TraceId Re-export ─────────────────────────────────────────
179
180/// 从 ids 模块重导出 TraceId。
181///
182/// 注意:Checkpoint 结构体**不包含** trace_id。
183/// 关联关系由存储层组织(如同一目录下的文件)。
184pub use crate::ids::TraceId;
185
186// ─── CheckpointPolicy 已迁移 ──────────────────────────────────
187
188/// 向后兼容 — CheckpointPolicy 已迁移至 checkpoint_policy 模块。
189/// v0.5 使用 TriggerPolicy + RetentionPolicy 替代。
190#[allow(deprecated)]
191// ─── Phase 6: Execution Frame Snapshot ────────────────────────
192
193/// 执行帧 — 保存单个 Graph 的执行位置。
194///
195/// 可序列化 — 用于 SessionCheckpoint 持久化。
196#[derive(Clone, Serialize, Deserialize)]
197pub struct Frame<S: WorkflowState = State> {
198    /// 图 ID
199    pub graph_id: String,
200
201    /// 当前节点 ID
202    pub node_id: String,
203
204    /// 状态快照(P0-1: 使用 Checkpoint 关联类型,可序列化)
205    pub state: S::Checkpoint,
206
207    /// 执行游标(节点索引或步骤数)
208    pub cursor: usize,
209}
210
211impl<S: WorkflowState> Frame<S> {
212    /// 从 Runtime State 创建 Frame(使用 snapshot() 投影)。
213    pub fn new(graph_id: String, node_id: String, state: &S, cursor: usize) -> Self {
214        Self {
215            graph_id,
216            node_id,
217            state: state.snapshot(),
218            cursor,
219        }
220    }
221}
222
223impl<S: WorkflowState> Debug for Frame<S>
224where
225    S::Checkpoint: Debug,
226{
227    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
228        f.debug_struct("Frame")
229            .field("graph_id", &self.graph_id)
230            .field("node_id", &self.node_id)
231            .field("state", &self.state)
232            .field("cursor", &self.cursor)
233            .finish()
234    }
235}
236
237/// 帧栈 — 保存完整的执行位置历史。
238///
239/// 可序列化 — 用于 SessionCheckpoint 持久化。
240#[derive(Clone, Serialize, Deserialize)]
241pub struct FrameStack<S: WorkflowState = State>
242where
243    S::Checkpoint: Debug,
244{
245    /// 帧列表(从外到内)
246    frames: Vec<Frame<S>>,
247}
248
249impl<S: WorkflowState> FrameStack<S>
250where
251    S::Checkpoint: Debug,
252{
253    /// 创建空的帧栈。
254    pub fn new() -> Self {
255        Self { frames: Vec::new() }
256    }
257
258    /// Push 一个新帧。
259    pub fn push(&mut self, frame: Frame<S>) {
260        self.frames.push(frame);
261    }
262
263    /// Pop 最后一个帧。
264    pub fn pop(&mut self) -> Option<Frame<S>> {
265        self.frames.pop()
266    }
267
268    /// 获取当前帧(最顶层)。
269    pub fn current(&self) -> Option<&Frame<S>> {
270        self.frames.last()
271    }
272
273    /// 获取帧数量。
274    pub fn depth(&self) -> usize {
275        self.frames.len()
276    }
277
278    /// 检查是否为空。
279    pub fn is_empty(&self) -> bool {
280        self.frames.is_empty()
281    }
282
283    /// 获取所有帧的引用。
284    pub fn frames(&self) -> &[Frame<S>] {
285        &self.frames
286    }
287}
288
289impl<S: WorkflowState> Default for FrameStack<S>
290where
291    S::Checkpoint: Debug,
292{
293    fn default() -> Self {
294        Self::new()
295    }
296}
297
298impl<S: WorkflowState> Debug for FrameStack<S>
299where
300    S::Checkpoint: Debug,
301{
302    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
303        f.debug_struct("FrameStack")
304            .field("frames", &self.frames)
305            .finish()
306    }
307}
308
309// ─── FrameInfo ─────────────────────────────────────────────────
310
311/// Checkpoint 边界描述 — Engine 传递给 Sink 的最小上下文。
312///
313/// 设计原则:极简。Engine 只传递"我到了哪里",Sink 自行决定
314/// 是否记录、如何记录、记录多少。
315///
316/// - 想做节流?Sink 自己维护计数器。
317/// - 想做脏检查?Sink 自己缓存上次 snapshot 的 hash。
318/// - 想过滤特定节点?Sink 匹配 `node_id`。
319#[derive(Debug, Clone)]
320pub struct FrameInfo {
321    /// 当前节点 ID(commit 刚完成的节点)
322    pub node_id: String,
323    /// 执行步数(从 run_inline 入口开始计数)
324    pub step: usize,
325}
326
327impl FrameInfo {
328    /// 创建 FrameInfo。
329    pub fn new(node_id: impl Into<String>, step: usize) -> Self {
330        Self {
331            node_id: node_id.into(),
332            step,
333        }
334    }
335}
336
337// ─── CheckpointSink ────────────────────────────────────────────
338
339/// Checkpoint Sink SPI — 执行引擎通知 Sink 到达了合法的恢复边界。
340///
341/// Engine 保证:
342/// - 每次调用时,State 已 commit(mutation 已 apply),状态是一致的。
343/// - 调用顺序:`execute() → commit() → on_checkpoint() → route()`。
344///
345/// Sink 自行决定:
346/// - 是否记录(节流、过滤)
347/// - 是否 snapshot(借用 `&S`,Sink 决定是否 clone)
348/// - 序列化格式(serde、protobuf、binary)
349/// - 存储后端(内存、磁盘、网络)
350///
351/// # 设计原则
352///
353/// Engine 不拥有 Checkpoint 生命周期,只借用 Sink。
354/// 这与 D6 原则一致——Engine 不知道 FrameStack。
355pub trait CheckpointSink<S: WorkflowState>: Send + Sync {
356    /// 节点完成,State 已 commit。
357    ///
358    /// `state` 是借用——Sink 决定是否 snapshot/clone。
359    /// `frame` 描述当前执行位置。
360    fn on_checkpoint(&mut self, state: &S, frame: &FrameInfo);
361}
362
363/// 空 Sink — 不记录任何内容。
364///
365/// 用于不需要 Checkpoint 的场景(如 ToolUseLoop 的简单调用)。
366#[derive(Debug, Default)]
367pub struct NoopCheckpointSink;
368
369impl<S: WorkflowState> CheckpointSink<S> for NoopCheckpointSink {
370    fn on_checkpoint(&mut self, _state: &S, _frame: &FrameInfo) {
371        // 什么都不做
372    }
373}
374
375/// 内存 Sink — 将所有 checkpoint 记录到内存。
376///
377/// 用于调试、测试、time travel。
378///
379/// 要求 `S::Checkpoint: Debug`(Frame 需要 Debug)。
380pub struct MemorySink<S: WorkflowState = State> {
381    pub frames: Vec<Frame<S>>,
382}
383
384impl<S: WorkflowState> Debug for MemorySink<S>
385where
386    S::Checkpoint: Debug,
387{
388    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
389        f.debug_struct("MemorySink")
390            .field("frames", &self.frames)
391            .finish()
392    }
393}
394
395impl<S: WorkflowState> Default for MemorySink<S> {
396    fn default() -> Self {
397        Self { frames: Vec::new() }
398    }
399}
400
401impl<S: WorkflowState> MemorySink<S> {
402    pub fn new() -> Self {
403        Self::default()
404    }
405
406    pub fn into_frames(self) -> Vec<Frame<S>> {
407        self.frames
408    }
409}
410
411impl<S: WorkflowState> CheckpointSink<S> for MemorySink<S>
412where
413    S::Checkpoint: Sync,
414{
415    fn on_checkpoint(&mut self, state: &S, frame: &FrameInfo) {
416        self.frames.push(Frame {
417            graph_id: String::new(), // Engine 不传递 graph_id,由 Sink 填充
418            node_id: frame.node_id.clone(),
419            state: state.snapshot(),
420            cursor: frame.step,
421        });
422    }
423}
424
425// ─── Tests ─────────────────────────────────────────────────────
426
427#[cfg(test)]
428mod tests {
429    use super::*;
430    use crate::state::StateMerge;
431    use crate::{GraphBuilder, NodeKind, TaskNode};
432    use std::sync::Arc;
433    use tokio_util::sync::CancellationToken;
434
435    #[tokio::test]
436    async fn test_auto_checkpoint_via_memory_sink() {
437        // 创建一个简单的 Graph: start → a → b → end
438        let mut builder = GraphBuilder::<State, StateMerge>::new("test");
439        builder.start("a");
440        builder.node("a", NodeKind::Task(TaskNode::new("a", |_| Ok(()))));
441        builder.node("b", NodeKind::Task(TaskNode::new("b", |_| Ok(()))));
442        builder.end("b");
443        builder.edge("a", "b");
444        let graph = Arc::new(builder.build().unwrap());
445
446        // 创建 MemorySink
447        let mut sink = MemorySink::<State>::new();
448
449        // 创建 Engine 并绑定 sink
450        let mut state = State::new();
451        let mut engine: crate::ExecutionEngine<'_, State> = crate::ExecutionEngine::new(
452            &mut state,
453            None,
454            CancellationToken::new(),
455            Some(&mut sink),
456            None,
457        );
458
459        // 执行
460        let mut cb = crate::graph::NoopStepCallback;
461        graph.run_inline(&mut engine, 100, &mut cb).await.unwrap();
462
463        // 验证:应该有 2 个 checkpoint(a 和 b)
464        assert_eq!(sink.frames.len(), 2);
465        assert_eq!(sink.frames[0].node_id, "a");
466        assert_eq!(sink.frames[1].node_id, "b");
467        assert_eq!(sink.frames[0].cursor, 1);
468        assert_eq!(sink.frames[1].cursor, 2);
469    }
470
471    #[tokio::test]
472    async fn test_noop_checkpoint_sink() {
473        // 验证 NoopCheckpointSink 不记录任何内容
474        let mut builder = GraphBuilder::<State, StateMerge>::new("test");
475        builder.start("a");
476        builder.node("a", NodeKind::Task(TaskNode::new("a", |_| Ok(()))));
477        builder.end("a");
478        let graph = Arc::new(builder.build().unwrap());
479
480        let mut sink = NoopCheckpointSink;
481        let mut state = State::new();
482        let mut engine: crate::ExecutionEngine<'_, State> = crate::ExecutionEngine::new(
483            &mut state,
484            None,
485            CancellationToken::new(),
486            Some(&mut sink),
487            None,
488        );
489
490        let mut cb = crate::graph::NoopStepCallback;
491        graph.run_inline(&mut engine, 100, &mut cb).await.unwrap();
492        // NoopSink 不记录,无需断言
493    }
494
495    #[test]
496    fn test_frame_info_minimal() {
497        let info = FrameInfo::new("test_node", 42);
498        assert_eq!(info.node_id, "test_node");
499        assert_eq!(info.step, 42);
500    }
501}