use std::sync::Arc;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use synaptic_core::SynapticError;
use synaptic_graph::{
CheckpointConfig, Checkpointer, Node, NodeOutput, State, StateGraph, StoreCheckpointer, END,
};
/// Shared test state: a running tally plus the names of the nodes that
/// touched it, in the order they were recorded.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
struct CounterState {
    /// Running count; each `IncrementNode` run adds one.
    counter: usize,
    /// Node names appended as nodes execute.
    visited: Vec<String>,
}

impl State for CounterState {
    /// Folds `other` into `self`: the counters are added together and
    /// `other`'s visited names are appended after the existing ones.
    fn merge(&mut self, other: Self) {
        let CounterState { counter, visited } = other;
        self.counter += counter;
        self.visited.extend(visited);
    }
}
/// Graph node that bumps `counter` by one and records its own name in
/// `visited`, so tests can observe both how many and which nodes ran.
struct IncrementNode {
    /// Label pushed onto `CounterState::visited` each time the node runs.
    name: String,
}

#[async_trait]
impl Node<CounterState> for IncrementNode {
    /// Returns the incoming state with one more tick and this node's label.
    async fn process(
        &self,
        mut state: CounterState,
    ) -> Result<NodeOutput<CounterState>, SynapticError> {
        let label = self.name.clone();
        state.counter += 1;
        state.visited.push(label);
        Ok(state.into())
    }
}
/// A straight `a -> b -> END` pipeline runs each node exactly once, in order.
#[tokio::test]
async fn simple_linear_graph() {
    let compiled = StateGraph::new()
        .add_node("a", IncrementNode { name: "a".into() })
        .add_node("b", IncrementNode { name: "b".into() })
        .add_edge("a", "b")
        .add_edge("b", END)
        .set_entry_point("a")
        .compile()
        .unwrap();

    let output = compiled.invoke(CounterState::default()).await.unwrap();
    let final_state = output.into_state();

    assert_eq!(final_state.counter, 2);
    assert_eq!(final_state.visited, vec!["a", "b"]);
}
/// The router after `start` sees the post-`start` counter: below 2 it goes
/// `left`, otherwise `right`. Both branches end the graph.
#[tokio::test]
async fn conditional_routing() {
    let graph = StateGraph::new()
        .add_node(
            "start",
            IncrementNode {
                name: "start".into(),
            },
        )
        .add_node(
            "left",
            IncrementNode {
                name: "left".into(),
            },
        )
        .add_node(
            "right",
            IncrementNode {
                name: "right".into(),
            },
        )
        .set_entry_point("start")
        .add_conditional_edges("start", |state: &CounterState| {
            let branch = if state.counter >= 2 { "right" } else { "left" };
            branch.to_string()
        })
        .add_edge("left", END)
        .add_edge("right", END)
        .compile()
        .unwrap();

    // Default input: `start` lifts the counter 0 -> 1, which routes left.
    let from_zero = graph
        .invoke(CounterState::default())
        .await
        .unwrap()
        .into_state();
    assert_eq!(from_zero.visited, vec!["start", "left"]);

    // Pre-seeded input: `start` lifts the counter 5 -> 6, which routes right.
    let preloaded = CounterState {
        counter: 5,
        ..Default::default()
    };
    let from_five = graph.invoke(preloaded).await.unwrap().into_state();
    assert_eq!(from_five.visited, vec!["start", "right"]);
}
/// With `interrupt_before(["b"])`, the run pauses before `b` executes. It
/// reports an interrupt and leaves a checkpoint whose next node is `b`.
#[tokio::test]
async fn interrupt_before_stops_execution() {
    let saver = Arc::new(StoreCheckpointer::new(Arc::new(
        synaptic_store::InMemoryStore::new(),
    )));
    let graph = StateGraph::new()
        .add_node("a", IncrementNode { name: "a".into() })
        .add_node("b", IncrementNode { name: "b".into() })
        .add_edge("a", "b")
        .add_edge("b", END)
        .set_entry_point("a")
        .interrupt_before(vec!["b".to_string()])
        .compile()
        .unwrap()
        .with_checkpointer(saver.clone());
    let config = CheckpointConfig::new("thread-1");
    let result = graph
        .invoke_with_config(CounterState::default(), Some(config.clone()))
        .await
        .unwrap();
    assert!(result.is_interrupted());
    let interrupt_val = result
        .interrupt_value()
        .expect("an interrupted result should carry an interrupt value");
    let reason = interrupt_val["reason"]
        .as_str()
        .expect("interrupt value should contain a string 'reason'");
    assert!(
        reason.contains("interrupted before node 'b'"),
        "got: {reason}"
    );
    let cp = saver
        .get(&config)
        .await
        .unwrap()
        .expect("a checkpoint should be saved for the interrupted thread");
    // `assert_eq!` prints both sides on failure, unlike `assert!(a == b)`.
    assert_eq!(
        cp.next_node.as_deref(),
        Some("b"),
        "checkpoint should point at the node that was interrupted"
    );
}
/// With `interrupt_after(["a"])`, the run pauses once `a` has finished and
/// reports an interrupt whose reason names `a`.
#[tokio::test]
async fn interrupt_after_stops_execution() {
    let store = Arc::new(synaptic_store::InMemoryStore::new());
    let saver = Arc::new(StoreCheckpointer::new(store));

    let graph = StateGraph::new()
        .add_node("a", IncrementNode { name: "a".into() })
        .add_node("b", IncrementNode { name: "b".into() })
        .add_edge("a", "b")
        .add_edge("b", END)
        .set_entry_point("a")
        .interrupt_after(vec!["a".to_string()])
        .compile()
        .unwrap()
        .with_checkpointer(saver.clone());

    let cfg = CheckpointConfig::new("thread-2");
    let outcome = graph
        .invoke_with_config(CounterState::default(), Some(cfg))
        .await
        .unwrap();
    assert!(outcome.is_interrupted());

    let payload = outcome.interrupt_value().unwrap();
    let reason = payload["reason"].as_str().unwrap();
    assert!(
        reason.contains("interrupted after node 'a'"),
        "got: {reason}"
    );
}
/// A run interrupted before `b` can be continued on the same thread by a graph
/// without the interrupt, which then finishes `b`.
///
/// The first run's outcome is asserted explicitly. Otherwise an error or a
/// missing checkpoint would go unnoticed: a fresh run from the default state
/// also ends with counter 2 and `["a", "b"]`, so the final assertions alone
/// cannot tell a resume from a restart.
#[tokio::test]
async fn resume_from_checkpoint() {
    let saver = Arc::new(StoreCheckpointer::new(Arc::new(
        synaptic_store::InMemoryStore::new(),
    )));
    let graph = StateGraph::new()
        .add_node("a", IncrementNode { name: "a".into() })
        .add_node("b", IncrementNode { name: "b".into() })
        .add_edge("a", "b")
        .add_edge("b", END)
        .set_entry_point("a")
        .interrupt_before(vec!["b".to_string()])
        .compile()
        .unwrap()
        .with_checkpointer(saver.clone());
    let config = CheckpointConfig::new("thread-3");
    let first = graph
        .invoke_with_config(CounterState::default(), Some(config.clone()))
        .await
        .expect("first run should pause at the interrupt, not fail");
    assert!(first.is_interrupted(), "first run should stop before 'b'");

    // Precondition: there must be a checkpoint waiting at 'b' to resume from.
    let cp = saver
        .get(&config)
        .await
        .unwrap()
        .expect("interrupted run should have saved a checkpoint");
    assert_eq!(cp.next_node.as_deref(), Some("b"));

    // Same topology but no interrupt, so the resumed run can proceed through 'b'.
    let graph2 = StateGraph::new()
        .add_node("a", IncrementNode { name: "a".into() })
        .add_node("b", IncrementNode { name: "b".into() })
        .add_edge("a", "b")
        .add_edge("b", END)
        .set_entry_point("a")
        .compile()
        .unwrap()
        .with_checkpointer(saver.clone());
    let result = graph2
        .invoke_with_config(CounterState::default(), Some(config))
        .await
        .unwrap()
        .into_state();
    assert_eq!(result.counter, 2);
    assert_eq!(result.visited, vec!["a", "b"]);
}
/// `update_state` merges a patch into the stored checkpoint via
/// `CounterState::merge`. The counter becomes the sum, and both the original
/// and the injected visit names are present.
#[tokio::test]
async fn update_state_modifies_checkpoint() {
    let saver = Arc::new(StoreCheckpointer::new(Arc::new(
        synaptic_store::InMemoryStore::new(),
    )));
    let graph = StateGraph::new()
        .add_node("a", IncrementNode { name: "a".into() })
        .add_node("b", IncrementNode { name: "b".into() })
        .add_edge("a", "b")
        .add_edge("b", END)
        .set_entry_point("a")
        .interrupt_before(vec!["b".to_string()])
        .compile()
        .unwrap()
        .with_checkpointer(saver.clone());
    let config = CheckpointConfig::new("thread-4");
    // Don't silently discard this result: the test depends on the run having
    // paused after 'a' with a checkpoint saved.
    let first = graph
        .invoke_with_config(CounterState::default(), Some(config.clone()))
        .await
        .expect("first run should pause at the interrupt, not fail");
    assert!(first.is_interrupted(), "first run should stop before 'b'");

    let update = CounterState {
        counter: 10,
        visited: vec!["injected".to_string()],
    };
    graph.update_state(&config, update).await.unwrap();
    let cp = saver
        .get(&config)
        .await
        .unwrap()
        .expect("checkpoint should still exist after update_state");
    let state: CounterState = serde_json::from_value(cp.state).unwrap();
    // 1 from node 'a' plus the 10 injected by the update.
    assert_eq!(state.counter, 11);
    assert!(state.visited.contains(&"a".to_string()));
    assert!(state.visited.contains(&"injected".to_string()));
}
/// An `a <-> b` cycle with no exit must be cut off by the executor's
/// iteration limit instead of looping forever.
#[tokio::test]
async fn max_iterations_guard() {
    let cyclic = StateGraph::new()
        .add_node("a", IncrementNode { name: "a".into() })
        .add_node("b", IncrementNode { name: "b".into() })
        .add_edge("a", "b")
        .add_edge("b", "a")
        .set_entry_point("a")
        .compile()
        .unwrap();

    let outcome = cyclic.invoke(CounterState::default()).await;
    let err = outcome.unwrap_err();
    let message = err.to_string();
    assert!(message.contains("max iterations"), "got: {message}");
}