use std::collections::HashMap;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use crate::graph::state::State;
use super::config::CheckpointConfig;
/// A serializable snapshot of graph state tied to a checkpoint configuration.
///
/// Generic over the state type `S`, which must implement [`State`].
// The explicit serde bound replaces the bounds the derive would infer for `S`,
// requiring `Serialize` and `DeserializeOwned` on the state type.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(bound = "S: Serialize + serde::de::DeserializeOwned")]
pub struct StateSnapshot<S: State> {
/// The state values captured by this snapshot.
pub values: S,
/// Names associated with what comes next after this snapshot
/// (presumably the graph nodes still to run; confirm against callers).
pub next: Vec<String>,
/// Checkpoint configuration of this snapshot; the accessors below read its
/// `thread_id`, `checkpoint_id` and `checkpoint_ns` fields.
pub config: CheckpointConfig,
/// Free-form JSON metadata keyed by name.
pub metadata: HashMap<String, Value>,
/// Creation timestamp in UTC; every constructor in this file sets it to `Utc::now()`.
pub created_at: DateTime<Utc>,
/// Optional configuration of a parent checkpoint; `None` unless built via `with_parent`.
pub parent_config: Option<CheckpointConfig>,
/// Optional sequence number; falls back to `None` when the field is missing
/// from deserialized input (via `#[serde(default)]`).
// NOTE(review): the exact meaning of the sequence is not visible here; confirm with the producer.
#[serde(default)]
pub at_seq: Option<u64>,
}
impl<S: State> StateSnapshot<S> {
    /// Creates a snapshot from the given state values, `next` list and
    /// checkpoint configuration.
    ///
    /// The metadata map starts empty, `created_at` is the current UTC time,
    /// and both `parent_config` and `at_seq` are `None`.
    pub fn new(values: S, next: Vec<String>, config: CheckpointConfig) -> Self {
        let created_at = Utc::now();
        let metadata = HashMap::new();
        Self {
            values,
            next,
            config,
            metadata,
            created_at,
            parent_config: None,
            at_seq: None,
        }
    }

    /// Creates a snapshot like [`StateSnapshot::new`], but carrying the
    /// supplied `metadata` map instead of an empty one.
    pub fn with_metadata(
        values: S,
        next: Vec<String>,
        config: CheckpointConfig,
        metadata: HashMap<String, Value>,
    ) -> Self {
        // Every other field takes the same value `new` would assign.
        Self {
            metadata,
            ..Self::new(values, next, config)
        }
    }

    /// Creates a snapshot like [`StateSnapshot::new`], but with
    /// `parent_config` set to `Some(parent_config)`.
    pub fn with_parent(
        values: S,
        next: Vec<String>,
        config: CheckpointConfig,
        parent_config: CheckpointConfig,
    ) -> Self {
        // Every other field takes the same value `new` would assign.
        Self {
            parent_config: Some(parent_config),
            ..Self::new(values, next, config)
        }
    }

    /// Consumes the snapshot and returns it with `at_seq` set to
    /// `Some(at_seq)`; all other fields are carried over unchanged.
    pub fn with_at_seq(self, at_seq: u64) -> Self {
        Self {
            at_seq: Some(at_seq),
            ..self
        }
    }

    /// Returns a reference to the checkpoint id stored in `config`, if any.
    pub fn checkpoint_id(&self) -> Option<&String> {
        let stored_id = &self.config.checkpoint_id;
        stored_id.as_ref()
    }

    /// Returns the thread id stored in `config`.
    pub fn thread_id(&self) -> &str {
        let thread: &str = &self.config.thread_id;
        thread
    }

    /// Builds a `RunnableConfig` for this snapshot.
    ///
    /// The result is created with `RunnableConfig::with_thread_id` from this
    /// snapshot's thread id. When `config` holds a checkpoint id and/or a
    /// checkpoint namespace, they are inserted into `configurable` as JSON
    /// values under the keys `"checkpoint_id"` and `"checkpoint_ns"`.
    pub fn to_config(&self) -> super::config::RunnableConfig {
        use super::config::RunnableConfig;

        let mut runnable = RunnableConfig::with_thread_id(self.thread_id());

        if let Some(id) = self.checkpoint_id() {
            let id_value = serde_json::json!(id);
            runnable
                .configurable
                .insert("checkpoint_id".to_string(), id_value);
        }

        if let Some(namespace) = self.config.checkpoint_ns.as_ref() {
            let ns_value = serde_json::json!(namespace);
            runnable
                .configurable
                .insert("checkpoint_ns".to_string(), ns_value);
        }

        runnable
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::state::MessagesState;

    /// A snapshot built with `new` exposes the thread id of its config and
    /// keeps the `next` list it was given.
    #[test]
    fn test_state_snapshot() {
        let pending = vec!["node1".to_string()];
        let snapshot = StateSnapshot::new(
            MessagesState::new(),
            pending,
            CheckpointConfig::new("thread-1"),
        );

        assert_eq!(snapshot.thread_id(), "thread-1");
        assert_eq!(snapshot.next, vec!["node1"]);
    }
}