lellm_graph/
checkpoint_codec.rs1use serde::{Deserialize, Serialize};
15
16use crate::checkpoint::{Checkpoint, CheckpointBlob, CheckpointStoreError};
17use crate::state::State;
18use crate::store::BlobCheckpointStore;
19use crate::workflow_state::WorkflowState;
20
21pub trait CheckpointCodec<S: WorkflowState = State>: Send + Sync {
29 fn serialize(
34 &self,
35 cp: &Checkpoint<S>,
36 graph_hash: u64,
37 ) -> Result<CheckpointBlob, CheckpointStoreError>;
38
39 fn deserialize(
44 &self,
45 blob: &CheckpointBlob,
46 expected_hash: u64,
47 ) -> Result<Checkpoint<S>, CheckpointStoreError>;
48}
49
50#[derive(Debug, Default)]
57pub struct SerdeCheckpointCodec<S: WorkflowState = State> {
58 _phantom: std::marker::PhantomData<S>,
59}
60
61impl<S: WorkflowState> SerdeCheckpointCodec<S> {
62 pub fn new() -> Self {
63 Self {
64 _phantom: std::marker::PhantomData,
65 }
66 }
67}
68
69impl<S> CheckpointCodec<S> for SerdeCheckpointCodec<S>
70where
71 S: WorkflowState + Serialize + for<'de> Deserialize<'de>,
72{
73 fn serialize(
74 &self,
75 cp: &Checkpoint<S>,
76 graph_hash: u64,
77 ) -> Result<CheckpointBlob, CheckpointStoreError> {
78 let data = serde_json::to_vec(cp)
79 .map_err(|e| CheckpointStoreError::Serialization(e.to_string()))?;
80 Ok(CheckpointBlob::new(
81 cp.checkpoint_id.clone(),
82 data,
83 graph_hash,
84 cp.created_at,
85 ))
86 }
87
88 fn deserialize(
89 &self,
90 blob: &CheckpointBlob,
91 expected_hash: u64,
92 ) -> Result<Checkpoint<S>, CheckpointStoreError> {
93 if blob.graph_hash != expected_hash {
94 return Err(CheckpointStoreError::GraphMismatch {
95 expected: expected_hash,
96 actual: blob.graph_hash,
97 });
98 }
99 let cp: Checkpoint<S> = serde_json::from_slice(&blob.data)
100 .map_err(|e| CheckpointStoreError::Corrupted(e.to_string()))?;
101 Ok(cp)
102 }
103}
104
105pub struct TypedCheckpointStore<'a, Codec, S: WorkflowState = State> {
123 store: &'a dyn BlobCheckpointStore,
124 codec: Codec,
125 _phantom: std::marker::PhantomData<S>,
126}
127
128impl<'a, Codec, S> TypedCheckpointStore<'a, Codec, S>
129where
130 S: WorkflowState,
131{
132 pub fn new(store: &'a dyn BlobCheckpointStore, codec: Codec) -> Self {
133 Self {
134 store,
135 codec,
136 _phantom: std::marker::PhantomData,
137 }
138 }
139}
140
141impl<'a, Codec, S> TypedCheckpointStore<'a, Codec, S>
142where
143 S: WorkflowState + Serialize + for<'de> Deserialize<'de>,
144 Codec: CheckpointCodec<S>,
145{
146 pub async fn save_with_trace(
151 &self,
152 trace_id: &crate::checkpoint::TraceId,
153 checkpoint: &Checkpoint<S>,
154 graph_hash: u64,
155 ) -> Result<(), CheckpointStoreError> {
156 let blob = self.codec.serialize(checkpoint, graph_hash)?;
157 self.store.save_with_trace(trace_id, &blob).await
158 }
159
160 pub async fn load(
164 &self,
165 id: &crate::checkpoint::CheckpointId,
166 expected_hash: u64,
167 ) -> Result<Option<Checkpoint<S>>, CheckpointStoreError> {
168 match self.store.load(id).await? {
169 Some(blob) => Ok(Some(self.codec.deserialize(&blob, expected_hash)?)),
170 None => Ok(None),
171 }
172 }
173
174 pub async fn load_latest(
178 &self,
179 trace_id: &crate::checkpoint::TraceId,
180 expected_hash: u64,
181 ) -> Result<Option<Checkpoint<S>>, CheckpointStoreError> {
182 match self.store.load_latest(trace_id).await? {
183 Some(blob) => Ok(Some(self.codec.deserialize(&blob, expected_hash)?)),
184 None => Ok(None),
185 }
186 }
187}