use serde::{Deserialize, Serialize};
use super::checkpoint::{Checkpoint, CheckpointBlob, CheckpointStoreError};
use super::store::BlobCheckpointStore;
use crate::state::State;
use crate::state::workflow_state::WorkflowState;
/// Translates between typed [`Checkpoint`] values and opaque [`CheckpointBlob`]s.
///
/// The state type `S` defaults to [`State`]. Implementors must be
/// `Send + Sync`.
pub trait CheckpointCodec<S = State>: Send + Sync
where
    S: WorkflowState,
{
    /// Encodes `cp` into a [`CheckpointBlob`].
    ///
    /// `graph_hash` is handed to the codec alongside the checkpoint; the
    /// matching [`deserialize`](Self::deserialize) call receives an
    /// `expected_hash` counterpart.
    ///
    /// # Errors
    /// Returns a [`CheckpointStoreError`] when the checkpoint cannot be encoded.
    fn serialize(
        &self,
        cp: &Checkpoint<S>,
        graph_hash: u64,
    ) -> Result<CheckpointBlob, CheckpointStoreError>;

    /// Decodes `blob` back into a typed [`Checkpoint`].
    ///
    /// `expected_hash` is the graph hash the caller expects the blob to carry.
    ///
    /// # Errors
    /// Returns a [`CheckpointStoreError`] when the blob cannot be decoded.
    fn deserialize(
        &self,
        blob: &CheckpointBlob,
        expected_hash: u64,
    ) -> Result<Checkpoint<S>, CheckpointStoreError>;
}
/// Stateless [`CheckpointCodec`] marker type, generic over the workflow state `S`.
///
/// The only field is a `PhantomData<S>`, so the value carries no runtime data.
///
/// `Debug` and `Default` are implemented by hand rather than derived: the
/// derives would add `S: Debug` / `S: Default` bounds to the generated impls
/// even though `S` is never stored, making the codec unusable with state
/// types that lack those traits.
pub struct SerdeCheckpointCodec<S: WorkflowState = State> {
    _phantom: std::marker::PhantomData<S>,
}

impl<S: WorkflowState> Default for SerdeCheckpointCodec<S> {
    /// Builds the codec; no bound beyond `WorkflowState` is required of `S`.
    fn default() -> Self {
        Self {
            _phantom: std::marker::PhantomData,
        }
    }
}

impl<S: WorkflowState> std::fmt::Debug for SerdeCheckpointCodec<S> {
    /// Produces the same output the derived impl would, without needing `S: Debug`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("SerdeCheckpointCodec")
            .field("_phantom", &self._phantom)
            .finish()
    }
}
impl<S> SerdeCheckpointCodec<S>
where
    S: WorkflowState,
{
    /// Creates a codec for state type `S`.
    ///
    /// The codec holds no data, so construction is trivial.
    pub fn new() -> Self {
        let _phantom = std::marker::PhantomData::<S>;
        Self { _phantom }
    }
}
/// JSON implementation of [`CheckpointCodec`] built on `serde_json`.
impl<S> CheckpointCodec<S> for SerdeCheckpointCodec<S>
where
    S: WorkflowState + Serialize + for<'de> Deserialize<'de>,
{
    /// Encodes `cp` as JSON bytes and wraps them in a [`CheckpointBlob`].
    ///
    /// The blob is built from a clone of the checkpoint's id, the JSON
    /// payload, `graph_hash`, and the checkpoint's `created_at` value.
    ///
    /// # Errors
    /// Returns [`CheckpointStoreError::Serialization`] carrying the
    /// `serde_json` error text when encoding fails.
    fn serialize(
        &self,
        cp: &Checkpoint<S>,
        graph_hash: u64,
    ) -> Result<CheckpointBlob, CheckpointStoreError> {
        serde_json::to_vec(cp)
            .map_err(|err| CheckpointStoreError::Serialization(err.to_string()))
            .map(|payload| {
                CheckpointBlob::new(
                    cp.checkpoint_id.clone(),
                    payload,
                    graph_hash,
                    cp.created_at,
                )
            })
    }

    /// Decodes the JSON payload of `blob` into a [`Checkpoint`].
    ///
    /// The blob's recorded graph hash is compared with `expected_hash`
    /// before any parsing is attempted.
    ///
    /// # Errors
    /// * [`CheckpointStoreError::GraphMismatch`] when the blob's graph hash
    ///   differs from `expected_hash`.
    /// * [`CheckpointStoreError::Corrupted`] carrying the `serde_json` error
    ///   text when the payload cannot be parsed.
    fn deserialize(
        &self,
        blob: &CheckpointBlob,
        expected_hash: u64,
    ) -> Result<Checkpoint<S>, CheckpointStoreError> {
        let actual = blob.graph_hash;
        if actual != expected_hash {
            return Err(CheckpointStoreError::GraphMismatch {
                expected: expected_hash,
                actual,
            });
        }

        serde_json::from_slice::<Checkpoint<S>>(&blob.data)
            .map_err(|err| CheckpointStoreError::Corrupted(err.to_string()))
    }
}
/// Pairs a borrowed [`BlobCheckpointStore`] with a codec so callers can work
/// with typed [`Checkpoint`] values instead of raw blobs.
///
/// The store is borrowed for `'a`; the codec is owned. `S` is the workflow
/// state type and defaults to [`State`].
pub struct TypedCheckpointStore<'a, Codec, S = State>
where
    S: WorkflowState,
{
    /// Underlying blob-level storage.
    store: &'a dyn BlobCheckpointStore,
    /// Codec used to translate between blobs and typed checkpoints.
    codec: Codec,
    /// Ties the otherwise-unused state type `S` to this struct.
    _phantom: std::marker::PhantomData<S>,
}
impl<'a, Codec, S: WorkflowState> TypedCheckpointStore<'a, Codec, S> {
    /// Wraps `store` together with `codec`.
    ///
    /// No bound is placed on `Codec` here; the typed operations are only
    /// available once `Codec` implements [`CheckpointCodec`] for `S`.
    pub fn new(store: &'a dyn BlobCheckpointStore, codec: Codec) -> Self {
        let _phantom = std::marker::PhantomData::<S>;
        Self {
            store,
            codec,
            _phantom,
        }
    }
}
impl<'a, Codec, S> TypedCheckpointStore<'a, Codec, S>
where
    S: WorkflowState + Serialize + for<'de> Deserialize<'de>,
    Codec: CheckpointCodec<S>,
{
    /// Encodes `checkpoint` with the codec and saves the resulting blob
    /// under `trace_id`.
    ///
    /// # Errors
    /// Propagates any codec error from encoding, or any error returned by the
    /// underlying store's `save_with_trace`.
    pub async fn save_with_trace(
        &self,
        trace_id: &super::checkpoint::TraceId,
        checkpoint: &Checkpoint<S>,
        graph_hash: u64,
    ) -> Result<(), CheckpointStoreError> {
        // Encode first so a codec failure never reaches the store.
        let encoded = self.codec.serialize(checkpoint, graph_hash)?;
        let outcome = self.store.save_with_trace(trace_id, &encoded).await;
        outcome
    }

    /// Loads the blob stored for `id` and decodes it.
    ///
    /// Returns `Ok(None)` when the store has no blob for `id`.
    ///
    /// # Errors
    /// Propagates store errors, and codec errors raised while decoding with
    /// `expected_hash`.
    pub async fn load(
        &self,
        id: &super::checkpoint::CheckpointId,
        expected_hash: u64,
    ) -> Result<Option<Checkpoint<S>>, CheckpointStoreError> {
        let found = self.store.load(id).await?;
        // Option<Result<..>> -> Result<Option<..>>: a missing blob stays `Ok(None)`.
        found
            .map(|blob| self.codec.deserialize(&blob, expected_hash))
            .transpose()
    }

    /// Loads the blob the store reports as latest for `trace_id` and decodes it.
    ///
    /// Returns `Ok(None)` when the store has nothing for `trace_id`.
    ///
    /// # Errors
    /// Propagates store errors, and codec errors raised while decoding with
    /// `expected_hash`.
    pub async fn load_latest(
        &self,
        trace_id: &super::checkpoint::TraceId,
        expected_hash: u64,
    ) -> Result<Option<Checkpoint<S>>, CheckpointStoreError> {
        let found = self.store.load_latest(trace_id).await?;
        // Option<Result<..>> -> Result<Option<..>>: a missing blob stays `Ok(None)`.
        found
            .map(|blob| self.codec.deserialize(&blob, expected_hash))
            .transpose()
    }
}