use serde::{Deserialize, Serialize};
use crate::state_store::StateStoreProvider;
/// Progress marker that this module stores per reaction/query pair.
///
/// Values are written to the state store as bincode bytes under the key
/// `checkpoint:{query_id}`. bincode encodes struct fields positionally, so
/// reordering, retyping, adding, or removing fields changes the stored format
/// and makes previously written checkpoints undecodable.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct ReactionCheckpoint {
/// NOTE(review): presumably the sequence number of the last result the
/// reaction handled for this query — confirm against the caller.
pub sequence: u64,
/// NOTE(review): presumably a hash of the reaction configuration in effect
/// when the checkpoint was written — confirm against the caller.
pub config_hash: u64,
}
/// Fetches the checkpoint stored for `query_id` within `reaction_id`.
///
/// The value is looked up under the key `checkpoint:{query_id}`.
///
/// # Returns
///
/// `Ok(Some(checkpoint))` when the store holds an entry for that key, and
/// `Ok(None)` when it does not.
///
/// # Errors
///
/// Propagates any error from the store lookup, and fails when the stored
/// bytes cannot be decoded as a bincode [`ReactionCheckpoint`].
pub(crate) async fn read_checkpoint(
    store: &dyn StateStoreProvider,
    reaction_id: &str,
    query_id: &str,
) -> anyhow::Result<Option<ReactionCheckpoint>> {
    let storage_key = format!("checkpoint:{query_id}");

    store
        .get(reaction_id, &storage_key)
        .await?
        .map(|raw| {
            bincode::deserialize::<ReactionCheckpoint>(&raw)
                .map_err(|e| anyhow::anyhow!("Failed to deserialize checkpoint: {e}"))
        })
        // Turn `Option<Result<_>>` into `Result<Option<_>>` so an absent
        // entry is `Ok(None)` and a corrupt entry is an error.
        .transpose()
}
/// Fetches the checkpoints for several queries of one reaction in a single
/// `get_many` call.
///
/// Each query id is looked up under the key `checkpoint:{query_id}`. The
/// returned map is keyed by query id, i.e. the stored key with the
/// `checkpoint:` prefix removed. A returned key that lacks the prefix is used
/// unchanged. The map holds one entry per key/value pair the store returned.
///
/// # Errors
///
/// Propagates any error from the store lookup. Decoding stops at the first
/// entry whose bytes are not a valid bincode [`ReactionCheckpoint`], and that
/// error is returned instead of a partial map.
pub(crate) async fn read_checkpoints_batch(
    store: &dyn StateStoreProvider,
    reaction_id: &str,
    query_ids: &[String],
) -> anyhow::Result<std::collections::HashMap<String, ReactionCheckpoint>> {
    let prefixed: Vec<String> = query_ids
        .iter()
        .map(|q| format!("checkpoint:{q}"))
        .collect();
    let lookup: Vec<&str> = prefixed.iter().map(String::as_str).collect();

    let fetched = store.get_many(reaction_id, &lookup).await?;

    // Collecting into `Result<HashMap<..>>` short-circuits on the first
    // decoding failure.
    fetched
        .into_iter()
        .map(
            |(stored_key, payload)| -> anyhow::Result<(String, ReactionCheckpoint)> {
                let qid = stored_key
                    .strip_prefix("checkpoint:")
                    .unwrap_or(&stored_key)
                    .to_string();
                let decoded =
                    bincode::deserialize::<ReactionCheckpoint>(&payload).map_err(|e| {
                        anyhow::anyhow!("Failed to deserialize checkpoint for query '{qid}': {e}")
                    })?;
                Ok((qid, decoded))
            },
        )
        .collect()
}
/// Stores `checkpoint` for `query_id` within `reaction_id`.
///
/// The checkpoint is encoded with bincode and written under the key
/// `checkpoint:{query_id}` via the store's `set` call.
///
/// # Errors
///
/// Fails when the checkpoint cannot be serialized, and propagates any error
/// from the store write.
pub(crate) async fn write_checkpoint(
    store: &dyn StateStoreProvider,
    reaction_id: &str,
    query_id: &str,
    checkpoint: &ReactionCheckpoint,
) -> anyhow::Result<()> {
    // Encode first so a serialization failure never reaches the store.
    let encoded = bincode::serialize(checkpoint)
        .map_err(|e| anyhow::anyhow!("Failed to serialize checkpoint: {e}"))?;
    let storage_key = format!("checkpoint:{query_id}");

    store.set(reaction_id, &storage_key, encoded).await?;
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A checkpoint encoded with bincode decodes back to an equal value.
    #[test]
    fn bincode_round_trip() {
        let checkpoint = ReactionCheckpoint {
            sequence: 42,
            config_hash: 0xDEAD_BEEF,
        };

        let encoded = bincode::serialize(&checkpoint).unwrap();
        let restored = bincode::deserialize::<ReactionCheckpoint>(&encoded).unwrap();

        assert_eq!(checkpoint, restored);
    }

    /// A checkpoint encoded as JSON text decodes back to an equal value.
    #[test]
    fn serde_json_round_trip() {
        let checkpoint = ReactionCheckpoint {
            sequence: 100,
            config_hash: 12345,
        };

        let text = serde_json::to_string(&checkpoint).unwrap();
        let restored = serde_json::from_str::<ReactionCheckpoint>(&text).unwrap();

        assert_eq!(checkpoint, restored);
    }
}