use crate::config::Config;
use crate::errors::{Error, Result};
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use uuid::Uuid;
/// Channel state captured at one point in time, plus the bookkeeping needed to track how
/// each channel has changed.
///
/// Field declaration order is intentionally left as-is: serde serializes fields in this
/// order, so reordering them would change the serialized output.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Checkpoint {
    /// Version number of the checkpoint record; [`Checkpoint::new`] sets it to `1`.
    pub v: i32,
    /// Identifier of this checkpoint; [`Checkpoint::new`] fills it with a random v4 UUID string.
    pub id: String,
    /// Creation timestamp as text; [`Checkpoint::new`] writes the current UTC time in RFC 3339.
    pub ts: String,
    /// Current JSON value of each channel, keyed by channel name.
    pub channel_values: HashMap<String, serde_json::Value>,
    /// Per-channel version counter, keyed by channel name; bumped by
    /// [`Checkpoint::set_channel`] on every write.
    pub channel_versions: HashMap<String, i32>,
    /// Nested version map. Not written by anything in this module — presumably records,
    /// per consumer, the channel versions it has already observed (TODO confirm).
    pub versions_seen: HashMap<String, HashMap<String, i32>>,
    /// Thread this checkpoint belongs to, if any.
    pub thread_id: Option<String>,
    /// Id of a related earlier checkpoint, if any — presumably the parent in the history
    /// chain (TODO confirm).
    pub parent_id: Option<String>,
}
impl Checkpoint {
    /// Creates an empty checkpoint.
    ///
    /// The result has record version `1`, a fresh random (v4) UUID as its `id`, and the
    /// current UTC time formatted as RFC 3339 in `ts`. It has no channels and no thread or
    /// parent association.
    #[must_use]
    pub fn new() -> Self {
        Self {
            v: 1,
            id: Uuid::new_v4().to_string(),
            ts: Utc::now().to_rfc3339(),
            channel_values: HashMap::new(),
            channel_versions: HashMap::new(),
            versions_seen: HashMap::new(),
            thread_id: None,
            parent_id: None,
        }
    }

    /// Returns this checkpoint with `thread_id` set, replacing any previous value.
    #[must_use]
    pub fn with_thread_id(mut self, thread_id: impl Into<String>) -> Self {
        self.thread_id = Some(thread_id.into());
        self
    }

    /// Stores `value` under channel `name` and bumps that channel's version.
    ///
    /// The first write to a channel gives it version `1`. Each later write increments the
    /// previous version by one. Other channels are left untouched.
    pub fn set_channel(&mut self, name: impl Into<String>, value: serde_json::Value) {
        let name = name.into();
        // The entry API does a single hash lookup. A new channel starts at 0 and is then
        // incremented, which matches the previous `unwrap_or(0) + 1` behavior. Both maps
        // need an owned key, so one clone is unavoidable.
        *self.channel_versions.entry(name.clone()).or_insert(0) += 1;
        self.channel_values.insert(name, value);
    }

    /// Returns the current value of channel `name`, or `None` if it has never been set.
    #[must_use]
    pub fn get_channel(&self, name: &str) -> Option<&serde_json::Value> {
        self.channel_values.get(name)
    }
}
impl Default for Checkpoint {
fn default() -> Self {
Self::new()
}
}
/// A checkpoint bundled with the metadata and config stored alongside it, optionally
/// linked to its parent tuple.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct CheckpointTuple {
    /// The stored checkpoint itself.
    pub checkpoint: Checkpoint,
    /// Metadata recorded together with the checkpoint.
    pub metadata: CheckpointMetadata,
    /// Configuration associated with this checkpoint — presumably the config that
    /// addresses it in the saver (TODO confirm).
    pub config: Config,
    /// Parent tuple, when available. It is boxed because the type refers to itself.
    pub parent: Option<Box<CheckpointTuple>>,
}
/// Descriptive information stored next to a [`Checkpoint`].
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct CheckpointMetadata {
    /// When the metadata was created (UTC).
    pub created_at: DateTime<Utc>,
    /// Step counter — presumably the execution step that produced the checkpoint
    /// (TODO confirm).
    pub step: usize,
    /// Free-form label describing where the checkpoint came from.
    pub source: String,
    /// Arbitrary additional JSON fields, keyed by name.
    pub extra: HashMap<String, serde_json::Value>,
}
impl Default for CheckpointMetadata {
    /// Metadata stamped with the current UTC time, at step `0`, with source `"unknown"`
    /// and no extra fields.
    fn default() -> Self {
        let created_at = Utc::now();
        Self {
            created_at,
            step: 0,
            source: String::from("unknown"),
            extra: HashMap::default(),
        }
    }
}
/// A typed state value of type `S`, paired with the checkpoint, metadata and config it
/// was taken from.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct StateSnapshot<S> {
    /// The state value.
    pub state: S,
    /// Checkpoint this snapshot corresponds to.
    pub checkpoint: Checkpoint,
    /// Metadata stored with that checkpoint.
    pub metadata: CheckpointMetadata,
    /// Configuration associated with that checkpoint.
    pub config: Config,
}
#[async_trait]
pub trait BaseCheckpointSaver: Send + Sync {
async fn get_tuple(&self, config: &Config) -> Result<Option<CheckpointTuple>>;
async fn put(
&self,
checkpoint: &Checkpoint,
metadata: &CheckpointMetadata,
config: &Config,
) -> Result<Config>;
async fn list(&self, config: &Config, limit: Option<usize>) -> Result<Vec<CheckpointTuple>>;
async fn get(&self, checkpoint_id: &str) -> Result<Option<Checkpoint>> {
let config = Config::new().with_checkpoint_id(checkpoint_id);
Ok(self.get_tuple(&config).await?.map(|t| t.checkpoint))
}
async fn delete_thread(&self, thread_id: &str) -> Result<()> {
Err(Error::checkpoint(format!(
"delete_thread not implemented for thread {}",
thread_id
)))
}
async fn prune(&self, thread_id: &str, keep: usize) -> Result<usize> {
let _ = (thread_id, keep);
Err(Error::checkpoint("prune not implemented"))
}
}
// `+ 'static` spells out the default object lifetime bound these aliases already had.

/// Uniquely owned checkpoint saver behind dynamic dispatch.
pub type CheckpointSaverBox = Box<dyn BaseCheckpointSaver + 'static>;
/// Reference-counted checkpoint saver. Because the trait requires `Send + Sync`, it can be
/// shared across threads.
pub type CheckpointSaverArc = std::sync::Arc<dyn BaseCheckpointSaver + 'static>;
#[cfg(test)]
mod tests {
    use super::*;

    /// A new checkpoint starts empty, has an id, and carries an RFC 3339 timestamp.
    #[test]
    fn test_checkpoint_creation() {
        let checkpoint = Checkpoint::new();
        assert_eq!(checkpoint.v, 1);
        assert!(!checkpoint.id.is_empty());
        assert!(checkpoint.channel_values.is_empty());
        assert!(checkpoint.channel_versions.is_empty());
        assert!(checkpoint.versions_seen.is_empty());
        assert!(checkpoint.thread_id.is_none());
        assert!(checkpoint.parent_id.is_none());
        assert!(DateTime::parse_from_rfc3339(&checkpoint.ts).is_ok());
    }

    /// Every call to `new` draws a fresh random id.
    #[test]
    fn test_checkpoint_ids_are_unique() {
        assert_ne!(Checkpoint::new().id, Checkpoint::new().id);
    }

    /// `Default` produces the same empty shape as `new`.
    #[test]
    fn test_checkpoint_default() {
        let checkpoint = Checkpoint::default();
        assert_eq!(checkpoint.v, 1);
        assert!(!checkpoint.id.is_empty());
        assert!(checkpoint.thread_id.is_none());
        assert!(checkpoint.channel_values.is_empty());
    }

    #[test]
    fn test_checkpoint_with_thread_id() {
        let checkpoint = Checkpoint::new().with_thread_id("thread-123");
        assert_eq!(checkpoint.thread_id.as_deref(), Some("thread-123"));
    }

    #[test]
    fn test_checkpoint_set_get_channel() {
        let mut checkpoint = Checkpoint::new();
        checkpoint.set_channel("my_channel", serde_json::json!({"value": 42}));
        let value = checkpoint.get_channel("my_channel").unwrap();
        assert_eq!(value, &serde_json::json!({"value": 42}));
        let version = checkpoint.channel_versions.get("my_channel").unwrap();
        assert_eq!(*version, 1);
        checkpoint.set_channel("my_channel", serde_json::json!({"value": 43}));
        let version = checkpoint.channel_versions.get("my_channel").unwrap();
        assert_eq!(*version, 2);
        assert_eq!(
            checkpoint.get_channel("my_channel"),
            Some(&serde_json::json!({"value": 43}))
        );
    }

    /// Writing one channel must not affect another channel's version, and unknown
    /// channels read as `None`.
    #[test]
    fn test_checkpoint_channel_versions_are_independent() {
        let mut checkpoint = Checkpoint::new();
        checkpoint.set_channel("a", serde_json::json!(1));
        checkpoint.set_channel("a", serde_json::json!(2));
        checkpoint.set_channel("b", serde_json::json!(3));
        assert_eq!(checkpoint.channel_versions.get("a"), Some(&2));
        assert_eq!(checkpoint.channel_versions.get("b"), Some(&1));
        assert_eq!(checkpoint.get_channel("a"), Some(&serde_json::json!(2)));
        assert_eq!(checkpoint.get_channel("missing"), None);
    }

    /// A JSON round-trip preserves identity, timestamp, values and versions.
    #[test]
    fn test_checkpoint_serialization() {
        let mut checkpoint = Checkpoint::new().with_thread_id("test");
        checkpoint.set_channel("count", serde_json::json!(5));
        let json = serde_json::to_string(&checkpoint).unwrap();
        let deserialized: Checkpoint = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.id, checkpoint.id);
        assert_eq!(deserialized.ts, checkpoint.ts);
        assert_eq!(deserialized.thread_id, checkpoint.thread_id);
        assert_eq!(
            deserialized.get_channel("count"),
            checkpoint.get_channel("count")
        );
        assert_eq!(deserialized.channel_versions, checkpoint.channel_versions);
    }

    #[test]
    fn test_checkpoint_metadata() {
        let metadata = CheckpointMetadata {
            step: 5,
            source: "test".to_string(),
            ..Default::default()
        };
        assert_eq!(metadata.step, 5);
        assert_eq!(metadata.source, "test");
    }

    /// The metadata defaults are step 0, source "unknown", and no extra fields.
    #[test]
    fn test_checkpoint_metadata_default() {
        let metadata = CheckpointMetadata::default();
        assert_eq!(metadata.step, 0);
        assert_eq!(metadata.source, "unknown");
        assert!(metadata.extra.is_empty());
    }

    /// A JSON round-trip preserves every metadata field, including `extra`.
    #[test]
    fn test_checkpoint_metadata_serialization() {
        let mut metadata = CheckpointMetadata {
            step: 7,
            source: "loop".to_string(),
            ..Default::default()
        };
        metadata
            .extra
            .insert("note".to_string(), serde_json::json!("hi"));
        let json = serde_json::to_string(&metadata).unwrap();
        let deserialized: CheckpointMetadata = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.created_at, metadata.created_at);
        assert_eq!(deserialized.step, 7);
        assert_eq!(deserialized.source, "loop");
        assert_eq!(deserialized.extra, metadata.extra);
    }
}