Skip to main content

lellm_graph/checkpoint/
checkpoint_codec.rs

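//! CheckpointCodec — the serialization layer: object ↔ binary representation.
//!
//! Serializes a `Checkpoint<S>` into a `CheckpointBlob`, decoupling the storage
//! layer from the concrete State type.
//!
//! # Design
//!
//! ```text
//! Checkpoint<S> ──serialize()──▶ CheckpointBlob ──deserialize()──▶ Checkpoint<S>
//! ```
//!
//! A codec implementation may choose any serialization format (JSON, MessagePack,
//! Bincode, etc.); the storage layer only handles `CheckpointBlob` and never needs
//! to know the State type or the serialization format.
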
14use serde::{Deserialize, Serialize};
15
16use super::checkpoint::{Checkpoint, CheckpointBlob, CheckpointStoreError};
17use super::store::BlobCheckpointStore;
18use crate::state::State;
19use crate::state::workflow_state::WorkflowState;
20
21// ─── CheckpointCodec Trait ─────────────────────────────────────
22
23/// Checkpoint 序列化/反序列化接口。
24///
25/// # 泛型参数
26///
27/// - `S` — 类型化状态(默认 `State` = HashMap,向后兼容)
28pub trait CheckpointCodec<S: WorkflowState = State>: Send + Sync {
29    /// 将 Checkpoint 序列化为二进制 Blob。
30    ///
31    /// `graph_hash` 由调用方提供(从 `Graph::hash_u64()` 获取),
32    /// 写入 Blob 作为 correctness invariant。
33    fn serialize(
34        &self,
35        cp: &Checkpoint<S>,
36        graph_hash: u64,
37    ) -> Result<CheckpointBlob, CheckpointStoreError>;
38
39    /// 从二进制 Blob 反序列化为 Checkpoint。
40    ///
41    /// 如果 Blob 中的 `graph_hash` 与 `expected_hash` 不匹配,
42    /// 返回 `CheckpointStoreError::GraphMismatch`。
43    fn deserialize(
44        &self,
45        blob: &CheckpointBlob,
46        expected_hash: u64,
47    ) -> Result<Checkpoint<S>, CheckpointStoreError>;
48}
49
50// ─── SerdeCheckpointCodec ──────────────────────────────────────
51
52/// 基于 Serde + JSON 的默认 Codec 实现。
53///
54/// 使用 `serde_json` 进行序列化,适用于大多数场景。
55/// 对于性能敏感场景,可替换为 Bincode 或 MessagePack。
56#[derive(Debug, Default)]
57pub struct SerdeCheckpointCodec<S: WorkflowState = State> {
58    _phantom: std::marker::PhantomData<S>,
59}
60
61impl<S: WorkflowState> SerdeCheckpointCodec<S> {
62    pub fn new() -> Self {
63        Self {
64            _phantom: std::marker::PhantomData,
65        }
66    }
67}
68
69impl<S> CheckpointCodec<S> for SerdeCheckpointCodec<S>
70where
71    S: WorkflowState + Serialize + for<'de> Deserialize<'de>,
72{
73    fn serialize(
74        &self,
75        cp: &Checkpoint<S>,
76        graph_hash: u64,
77    ) -> Result<CheckpointBlob, CheckpointStoreError> {
78        let data = serde_json::to_vec(cp)
79            .map_err(|e| CheckpointStoreError::Serialization(e.to_string()))?;
80        Ok(CheckpointBlob::new(
81            cp.checkpoint_id.clone(),
82            data,
83            graph_hash,
84            cp.created_at,
85        ))
86    }
87
88    fn deserialize(
89        &self,
90        blob: &CheckpointBlob,
91        expected_hash: u64,
92    ) -> Result<Checkpoint<S>, CheckpointStoreError> {
93        if blob.graph_hash != expected_hash {
94            return Err(CheckpointStoreError::GraphMismatch {
95                expected: expected_hash,
96                actual: blob.graph_hash,
97            });
98        }
99        let cp: Checkpoint<S> = serde_json::from_slice(&blob.data)
100            .map_err(|e| CheckpointStoreError::Corrupted(e.to_string()))?;
101        Ok(cp)
102    }
103}
104
105// ─── TypedCheckpointStore ──────────────────────────────────────
106
107/// 类型化 Checkpoint 存储 — Codec + BlobStore 的组合。
108///
109/// 将 `Checkpoint<S>` 的保存/加载委托给 Codec 进行序列化,
110/// 再通过 BlobCheckpointStore 进行持久化。
111///
112/// # 示例
113///
114/// ```rust,ignore
115/// let store = InMemoryBlobStore::new();
116/// let codec = SerdeCheckpointCodec::<State>::new();
117/// let typed = TypedCheckpointStore::new(&store, codec);
118///
119/// typed.save_with_trace(&trace_id, &checkpoint, graph_hash).await?;
120/// let restored = typed.load(&id, graph_hash).await?;
121/// ```
122pub struct TypedCheckpointStore<'a, Codec, S: WorkflowState = State> {
123    store: &'a dyn BlobCheckpointStore,
124    codec: Codec,
125    _phantom: std::marker::PhantomData<S>,
126}
127
128impl<'a, Codec, S> TypedCheckpointStore<'a, Codec, S>
129where
130    S: WorkflowState,
131{
132    pub fn new(store: &'a dyn BlobCheckpointStore, codec: Codec) -> Self {
133        Self {
134            store,
135            codec,
136            _phantom: std::marker::PhantomData,
137        }
138    }
139}
140
141impl<'a, Codec, S> TypedCheckpointStore<'a, Codec, S>
142where
143    S: WorkflowState + Serialize + for<'de> Deserialize<'de>,
144    Codec: CheckpointCodec<S>,
145{
146    /// 保存 Checkpoint 并关联 trace_id。
147    ///
148    /// `graph_hash` 由调用方提供(从 `Graph::hash_u64()` 获取),
149    /// 写入 Blob 作为 correctness invariant。
150    pub async fn save_with_trace(
151        &self,
152        trace_id: &super::checkpoint::TraceId,
153        checkpoint: &Checkpoint<S>,
154        graph_hash: u64,
155    ) -> Result<(), CheckpointStoreError> {
156        let blob = self.codec.serialize(checkpoint, graph_hash)?;
157        self.store.save_with_trace(trace_id, &blob).await
158    }
159
160    /// 加载指定 ID 的 Checkpoint。
161    ///
162    /// 校验 `graph_hash`:不匹配则返回 `GraphMismatch` 错误。
163    pub async fn load(
164        &self,
165        id: &super::checkpoint::CheckpointId,
166        expected_hash: u64,
167    ) -> Result<Option<Checkpoint<S>>, CheckpointStoreError> {
168        match self.store.load(id).await? {
169            Some(blob) => Ok(Some(self.codec.deserialize(&blob, expected_hash)?)),
170            None => Ok(None),
171        }
172    }
173
174    /// 加载 trace 最新的 Checkpoint。
175    ///
176    /// 校验 `graph_hash`:不匹配则返回 `GraphMismatch` 错误。
177    pub async fn load_latest(
178        &self,
179        trace_id: &super::checkpoint::TraceId,
180        expected_hash: u64,
181    ) -> Result<Option<Checkpoint<S>>, CheckpointStoreError> {
182        match self.store.load_latest(trace_id).await? {
183            Some(blob) => Ok(Some(self.codec.deserialize(&blob, expected_hash)?)),
184            None => Ok(None),
185        }
186    }
187}