use std::sync::Arc;
use serde::Serialize;
use crate::checkpoint::Checkpointer;
use crate::error::{GraphError, Result};
use crate::graph::CompiledGraph;
use crate::node::ExecutionConfig;
use crate::state::State;
/// Summary of a single checkpoint in a thread's history, as produced by
/// [`TimeTravelHandle::steps`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub struct StepInfo {
    /// Step number recorded on the checkpoint.
    pub step: usize,
    /// Identifier of the checkpoint; this is the value passed to
    /// `ExecutionConfig::with_resume_from` when resuming.
    pub checkpoint_id: String,
    /// Checkpoint creation time formatted as RFC 3339 (always `Some` when built by `steps`).
    pub timestamp: Option<String>,
    /// Node names recorded as pending on the checkpoint.
    pub pending_nodes: Vec<String>,
    /// Keys present in the checkpointed state.
    pub state_keys: Vec<String>,
}
/// Handle for inspecting, forking, replaying and resuming the checkpoint history of one thread.
///
/// Usually obtained through `CompiledGraph::time_travel`; it borrows the graph for `'g`
/// so that `resume_from` can re-execute it.
pub struct TimeTravelHandle<'g> {
    /// Graph re-executed by `resume_from`.
    pub(crate) graph: &'g CompiledGraph,
    /// Thread whose checkpoints this handle reads.
    pub(crate) thread_id: String,
    /// Checkpoint store used to list existing checkpoints and save forked ones.
    pub(crate) checkpointer: Arc<dyn Checkpointer>,
}
impl<'g> TimeTravelHandle<'g> {
pub fn new(
graph: &'g CompiledGraph,
thread_id: &str,
checkpointer: Arc<dyn Checkpointer>,
) -> Self {
Self { graph, thread_id: thread_id.to_string(), checkpointer }
}
pub async fn steps(&self) -> Result<Vec<StepInfo>> {
let checkpoints = self.checkpointer.list(&self.thread_id).await?;
let mut steps: Vec<StepInfo> = checkpoints
.into_iter()
.map(|cp| StepInfo {
step: cp.step,
checkpoint_id: cp.checkpoint_id,
timestamp: Some(cp.created_at.to_rfc3339()),
pending_nodes: cp.pending_nodes,
state_keys: cp.state.keys().cloned().collect(),
})
.collect();
steps.sort_by_key(|s| s.step);
Ok(steps)
}
pub async fn resume_from(&self, step: usize, config: ExecutionConfig) -> Result<State> {
let checkpoints = self.checkpointer.list(&self.thread_id).await?;
let checkpoint = checkpoints.into_iter().find(|cp| cp.step == step).ok_or_else(|| {
GraphError::CheckpointError(format!(
"no checkpoint found at step {step} for thread '{}'",
self.thread_id
))
})?;
let resume_config = ExecutionConfig::new(&config.thread_id)
.with_resume_from(&checkpoint.checkpoint_id)
.with_recursion_limit(config.recursion_limit);
self.graph.invoke(State::new(), resume_config).await
}
pub async fn fork_at(&self, step: usize, new_thread_id: &str) -> Result<()> {
let checkpoints = self.checkpointer.list(&self.thread_id).await?;
let checkpoint = checkpoints.into_iter().find(|cp| cp.step == step).ok_or_else(|| {
GraphError::CheckpointError(format!(
"no checkpoint found at step {step} for thread '{}'",
self.thread_id
))
})?;
let forked = crate::state::Checkpoint::new(
new_thread_id,
checkpoint.state,
checkpoint.step,
checkpoint.pending_nodes,
);
self.checkpointer.save(&forked).await?;
Ok(())
}
pub async fn replay(
&self,
from_step: usize,
to_step: Option<usize>,
) -> Result<Vec<(usize, State)>> {
let mut checkpoints = self.checkpointer.list(&self.thread_id).await?;
checkpoints.sort_by_key(|cp| cp.step);
let results: Vec<(usize, State)> = checkpoints
.into_iter()
.filter(|cp| cp.step >= from_step && to_step.is_none_or(|end| cp.step <= end))
.map(|cp| (cp.step, cp.state))
.collect();
if results.is_empty() || results[0].0 != from_step {
return Err(GraphError::CheckpointError(format!(
"no checkpoint found at step {from_step} for thread '{}'",
self.thread_id
)));
}
Ok(results)
}
}
impl CompiledGraph {
pub fn time_travel(&self, thread_id: &str) -> TimeTravelHandle<'_> {
let checkpointer = self
.checkpointer
.clone()
.expect("time_travel requires a checkpointer to be configured");
TimeTravelHandle::new(self, thread_id, checkpointer)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::checkpoint::MemoryCheckpointer;
    use crate::graph::StateGraph;
    use crate::node::{ExecutionConfig, FunctionNode, NodeContext, NodeOutput};
    use crate::state::{Checkpoint, State, StateSchema};
    use serde_json::json;

    /// Builds a one-node graph (`increment` adds 1 to `count`) backed by an in-memory
    /// checkpointer, returning both so tests can seed and inspect checkpoints.
    fn build_test_graph() -> (CompiledGraph, Arc<MemoryCheckpointer>) {
        let checkpointer = Arc::new(MemoryCheckpointer::new());
        let node = FunctionNode::new("increment", |ctx: NodeContext| {
            let count = ctx.get("count").and_then(|v| v.as_i64()).unwrap_or(0);
            Box::pin(async move { Ok(NodeOutput::new().with_update("count", json!(count + 1))) })
        });
        let graph = StateGraph::new(StateSchema::simple(&["count"]))
            .add_node(node)
            .add_edge("__start__", "increment")
            .add_edge("increment", "__end__")
            .compile()
            .unwrap()
            .with_checkpointer_arc(checkpointer.clone());
        (graph, checkpointer)
    }

    /// Saves `count` checkpoints for `thread_id` at steps `0..count`, each with
    /// `count == step` and `increment` pending.
    async fn seed_checkpoints(checkpointer: &MemoryCheckpointer, thread_id: &str, count: usize) {
        for step in 0..count {
            let mut state = State::new();
            state.insert("count".to_string(), json!(step));
            let cp = Checkpoint::new(thread_id, state, step, vec!["increment".to_string()]);
            checkpointer.save(&cp).await.unwrap();
        }
    }

    #[tokio::test]
    async fn test_steps_lists_checkpoints_in_order() {
        let (graph, checkpointer) = build_test_graph();
        seed_checkpoints(&checkpointer, "thread_1", 3).await;
        let handle = graph.time_travel("thread_1");
        let steps = handle.steps().await.unwrap();
        let step_numbers: Vec<usize> = steps.iter().map(|s| s.step).collect();
        assert_eq!(step_numbers, vec![0, 1, 2]);
        for info in &steps {
            assert_eq!(info.pending_nodes, vec!["increment".to_string()]);
            assert_eq!(info.state_keys, vec!["count".to_string()]);
            assert!(info.timestamp.is_some());
        }
    }

    #[tokio::test]
    async fn test_fork_at_creates_new_thread() {
        let (graph, checkpointer) = build_test_graph();
        seed_checkpoints(&checkpointer, "thread_1", 5).await;
        let handle = graph.time_travel("thread_1");
        handle.fork_at(2, "thread_1_fork").await.unwrap();
        let forked_checkpoints = checkpointer.list("thread_1_fork").await.unwrap();
        assert_eq!(forked_checkpoints.len(), 1);
        assert_eq!(forked_checkpoints[0].step, 2);
        assert_eq!(forked_checkpoints[0].state.get("count"), Some(&json!(2)));
        let original_checkpoints = checkpointer.list("thread_1").await.unwrap();
        assert_eq!(original_checkpoints.len(), 5);
    }

    #[tokio::test]
    async fn test_fork_at_nonexistent_step_errors() {
        let (graph, checkpointer) = build_test_graph();
        seed_checkpoints(&checkpointer, "thread_1", 3).await;
        let handle = graph.time_travel("thread_1");
        let result = handle.fork_at(99, "new_thread").await;
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("no checkpoint found at step 99"));
    }

    #[tokio::test]
    async fn test_replay_returns_states_in_range() {
        let (graph, checkpointer) = build_test_graph();
        seed_checkpoints(&checkpointer, "thread_1", 5).await;
        let handle = graph.time_travel("thread_1");
        let results = handle.replay(1, Some(3)).await.unwrap();
        assert_eq!(results.len(), 3);
        assert_eq!(
            results[0],
            (1, {
                let mut s = State::new();
                s.insert("count".to_string(), json!(1));
                s
            })
        );
        assert_eq!(results[1].0, 2);
        assert_eq!(results[2].0, 3);
    }

    #[tokio::test]
    async fn test_replay_single_step_range() {
        let (graph, checkpointer) = build_test_graph();
        seed_checkpoints(&checkpointer, "thread_1", 5).await;
        let handle = graph.time_travel("thread_1");
        let results = handle.replay(2, Some(2)).await.unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].0, 2);
    }

    #[tokio::test]
    async fn test_replay_to_end() {
        let (graph, checkpointer) = build_test_graph();
        seed_checkpoints(&checkpointer, "thread_1", 5).await;
        let handle = graph.time_travel("thread_1");
        let results = handle.replay(2, None).await.unwrap();
        assert_eq!(results.len(), 3);
        assert_eq!(results[0].0, 2);
        assert_eq!(results[1].0, 3);
        assert_eq!(results[2].0, 4);
    }

    #[tokio::test]
    async fn test_replay_nonexistent_from_step_errors() {
        let (graph, checkpointer) = build_test_graph();
        seed_checkpoints(&checkpointer, "thread_1", 3).await;
        let handle = graph.time_travel("thread_1");
        let result = handle.replay(99, None).await;
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("no checkpoint found at step 99"));
    }

    #[tokio::test]
    async fn test_replay_inverted_range_errors() {
        let (graph, checkpointer) = build_test_graph();
        seed_checkpoints(&checkpointer, "thread_1", 5).await;
        let handle = graph.time_travel("thread_1");
        assert!(handle.replay(3, Some(1)).await.is_err());
    }

    #[tokio::test]
    async fn test_resume_from_executes_graph() {
        let (graph, checkpointer) = build_test_graph();
        let mut state = State::new();
        state.insert("count".to_string(), json!(5));
        let cp = Checkpoint::new("thread_1", state, 0, vec!["increment".to_string()]);
        checkpointer.save(&cp).await.unwrap();
        let handle = graph.time_travel("thread_1");
        let config = ExecutionConfig::new("thread_1");
        let final_state = handle.resume_from(0, config).await.unwrap();
        assert_eq!(final_state.get("count"), Some(&json!(6)));
    }

    #[tokio::test]
    async fn test_resume_from_nonexistent_step_errors() {
        let (graph, checkpointer) = build_test_graph();
        seed_checkpoints(&checkpointer, "thread_1", 3).await;
        let handle = graph.time_travel("thread_1");
        let config = ExecutionConfig::new("thread_1");
        let result = handle.resume_from(99, config).await;
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("no checkpoint found at step 99"));
    }

    #[tokio::test]
    async fn test_fork_independence() {
        let (graph, checkpointer) = build_test_graph();
        seed_checkpoints(&checkpointer, "thread_1", 5).await;
        let handle = graph.time_travel("thread_1");
        handle.fork_at(2, "forked").await.unwrap();
        // Writing to the fork must not leak into the original thread.
        let mut new_state = State::new();
        new_state.insert("count".to_string(), json!(100));
        let new_cp = Checkpoint::new("forked", new_state, 3, vec![]);
        checkpointer.save(&new_cp).await.unwrap();
        let original = checkpointer.list("thread_1").await.unwrap();
        assert_eq!(original.len(), 5);
        for cp in &original {
            assert_ne!(cp.state.get("count"), Some(&json!(100)));
        }
    }
}