use crate::channels::{BaseChannel, LastValue};
use crate::checkpoint::{BaseCheckpointSaver, CheckpointMetadata, StateSnapshot};
use crate::config::Config;
use crate::errors::{Error, Result};
use crate::graph::{START, END};
use crate::nodes::{Node, PregelNode, ChannelWrite, NodeArc};
use crate::pregel::{Branch, Pregel};
use crate::state::State;
use crate::types::{StreamEvent, StreamMode};
use futures::stream::Stream;
use std::collections::{HashMap, HashSet};
use std::pin::Pin;
use std::sync::Arc;
/// Builder for a state graph: named nodes connected by static and conditional edges.
///
/// Configure it with `add_node`, `add_edge`, `set_entry_point`, `set_finish_point`, etc.,
/// then call `compile` to obtain a runnable [`CompiledGraph`].
pub struct StateGraph<S: State> {
    /// Registered nodes, keyed by node name.
    nodes: HashMap<String, Box<dyn Node<S>>>,
    /// Static edges: source name -> target names, in the order they were added.
    edges: HashMap<String, Vec<String>>,
    /// Conditional branches keyed by source name; at most one branch per source.
    conditional_edges: HashMap<String, Box<dyn Branch<S>>>,
    /// Node that receives the graph input; `compile` fails while this is `None`.
    entry_point: Option<String>,
    /// Node names designated as finish points.
    finish_points: HashSet<String>,
}
impl<S: State> StateGraph<S> {
    /// Creates an empty graph with no nodes, edges, entry point, or finish points.
    pub fn new() -> Self {
        Self {
            nodes: HashMap::new(),
            edges: HashMap::new(),
            conditional_edges: HashMap::new(),
            entry_point: None,
            finish_points: HashSet::new(),
        }
    }

    /// Registers `node` under `name`.
    ///
    /// Registering a second node under an existing name replaces the earlier node.
    pub fn add_node(&mut self, name: impl Into<String>, node: impl Node<S> + 'static) -> &mut Self {
        self.nodes.insert(name.into(), Box::new(node));
        self
    }

    /// Adds a static edge `from -> to`.
    ///
    /// Endpoints are not checked here. [`StateGraph::compile`] validates them, so edges may be
    /// declared before the nodes they connect. `from` may be [`START`] and `to` may be [`END`].
    pub fn add_edge(&mut self, from: impl Into<String>, to: impl Into<String>) -> &mut Self {
        let from = from.into();
        let to = to.into();
        self.edges.entry(from).or_default().push(to);
        self
    }

    /// Registers a conditional branch for `source`, replacing any branch already registered
    /// for that source.
    ///
    /// NOTE(review): [`StateGraph::compile`] does not currently read conditional edges, so
    /// branches registered here have no effect on the compiled graph.
    /// TODO: wire them into `Pregel`.
    pub fn add_conditional_edges(
        &mut self,
        source: impl Into<String>,
        branch: impl Branch<S> + 'static,
    ) -> &mut Self {
        self.conditional_edges.insert(source.into(), Box::new(branch));
        self
    }

    /// Sets the node that receives the graph input, replacing any previous entry point.
    pub fn set_entry_point(&mut self, node: impl Into<String>) -> &mut Self {
        self.entry_point = Some(node.into());
        self
    }

    /// Adds `node` to the set of finish points.
    ///
    /// Despite the `set_` prefix, this accumulates: previously added finish points are kept.
    pub fn set_finish_point(&mut self, node: impl Into<String>) -> &mut Self {
        self.finish_points.insert(node.into());
        self
    }

    /// Adds every name in `nodes` to the set of finish points.
    ///
    /// Accepts any iterable of string-like values, such as a `Vec<&str>`, an array, or an
    /// iterator.
    pub fn add_finish_points<I>(&mut self, nodes: I) -> &mut Self
    where
        I: IntoIterator,
        I::Item: Into<String>,
    {
        self.finish_points.extend(nodes.into_iter().map(Into::into));
        self
    }

    /// Validates the graph and compiles it into an executable [`CompiledGraph`].
    ///
    /// Each node becomes a [`PregelNode`] with `<name>_input` as its input channel and a
    /// [`ChannelWrite`] to `<name>_output`. A node's triggers are:
    /// - [`START`] if it is the entry point or the target of an edge from `START`;
    /// - `<source>_output` for every edge `source -> node`.
    ///
    /// A node with no triggers falls back to its own `<name>_input` channel. `START`, `END`,
    /// and every `<name>_input` / `<name>_output` channel are created as `LastValue<S>`.
    ///
    /// # Errors
    ///
    /// Returns an invalid-graph error if:
    /// - no entry point is set, or the entry point is not a registered node;
    /// - an edge starts anywhere other than a node or `START`;
    /// - an edge ends anywhere other than a node or `END`;
    /// - a finish point is not a registered node.
    pub fn compile(
        self,
        checkpointer: Option<Arc<dyn BaseCheckpointSaver>>,
    ) -> Result<CompiledGraph<S>> {
        let Some(entry_point) = self.entry_point else {
            return Err(Error::invalid_graph("No entry point set"));
        };
        if !self.nodes.contains_key(&entry_point) {
            return Err(Error::invalid_graph(format!(
                "Entry point '{}' is not a valid node",
                entry_point
            )));
        }
        Self::validate_topology(&self.nodes, &self.edges, &self.finish_points)?;

        // Invert the edge map once: target node -> channels whose writes should trigger it.
        let mut incoming: HashMap<String, Vec<String>> = HashMap::new();
        for (source, targets) in &self.edges {
            // No `<START>_output` channel exists, so an edge from START listens on START itself.
            let channel = if source == START {
                START.to_string()
            } else {
                format!("{}_output", source)
            };
            for target in targets {
                incoming
                    .entry(target.clone())
                    .or_default()
                    .push(channel.clone());
            }
        }

        let mut pregel_nodes = HashMap::with_capacity(self.nodes.len());
        for (name, node) in self.nodes {
            let mut triggers = incoming.remove(&name).unwrap_or_default();
            if name == entry_point {
                triggers.push(START.to_string());
            }
            // Sort and dedup: duplicate edges must not yield duplicate triggers, and the
            // order must not depend on HashMap iteration order.
            triggers.sort_unstable();
            triggers.dedup();
            if triggers.is_empty() {
                triggers.push(format!("{}_input", name));
            }
            let pregel_node = PregelNode::new(
                name.clone(),
                vec![format!("{}_input", name)],
                triggers,
                Arc::from(node) as NodeArc<S>,
                vec![ChannelWrite::new(format!("{}_output", name))],
            );
            pregel_nodes.insert(name, pregel_node);
        }

        let mut channels: HashMap<String, Box<dyn BaseChannel>> =
            HashMap::with_capacity(2 + 2 * pregel_nodes.len());
        channels.insert(START.to_string(), Box::new(LastValue::<S>::new()));
        channels.insert(END.to_string(), Box::new(LastValue::<S>::new()));
        for node_name in pregel_nodes.keys() {
            channels.insert(
                format!("{}_input", node_name),
                Box::new(LastValue::<S>::new()),
            );
            channels.insert(
                format!("{}_output", node_name),
                Box::new(LastValue::<S>::new()),
            );
        }

        let pregel = Pregel::new(
            pregel_nodes,
            channels,
            checkpointer.clone(),
            entry_point.clone(),
            self.finish_points.clone(),
            // Not needed after this point, so move it rather than clone it.
            self.edges,
        );
        Ok(CompiledGraph {
            pregel,
            entry_point,
            finish_points: self.finish_points,
            checkpointer,
        })
    }

    /// Checks that every edge endpoint and every finish point names a registered node.
    ///
    /// Edges may additionally start at [`START`] and end at [`END`].
    fn validate_topology(
        nodes: &HashMap<String, Box<dyn Node<S>>>,
        edges: &HashMap<String, Vec<String>>,
        finish_points: &HashSet<String>,
    ) -> Result<()> {
        for (source, targets) in edges {
            if source != START && !nodes.contains_key(source) {
                return Err(Error::invalid_graph(format!(
                    "Edge source '{}' is not a valid node",
                    source
                )));
            }
            for target in targets {
                if target != END && !nodes.contains_key(target) {
                    return Err(Error::invalid_graph(format!(
                        "Edge target '{}' (from '{}') is not a valid node",
                        target, source
                    )));
                }
            }
        }
        for finish in finish_points {
            if !nodes.contains_key(finish) {
                return Err(Error::invalid_graph(format!(
                    "Finish point '{}' is not a valid node",
                    finish
                )));
            }
        }
        Ok(())
    }
}
impl<S: State> Default for StateGraph<S> {
fn default() -> Self {
Self::new()
}
}
/// Executable form of a [`StateGraph`], produced by `StateGraph::compile`.
pub struct CompiledGraph<S: State> {
    /// Pregel engine that all run, stream, and state queries are delegated to.
    pregel: Pregel<S>,
    /// Name of the entry-point node, taken from the source graph.
    entry_point: String,
    /// Finish-point node names, taken from the source graph.
    finish_points: HashSet<String>,
    /// Optional checkpoint store; `update_state` fails when this is `None`.
    checkpointer: Option<Arc<dyn BaseCheckpointSaver>>,
}
impl<S: State> CompiledGraph<S> {
pub async fn invoke(&mut self, input: S, config: Config) -> Result<S> {
self.pregel.invoke(input, config).await
}
pub async fn stream(
&mut self,
input: S,
config: Config,
mode: StreamMode,
) -> Result<Pin<Box<dyn Stream<Item = Result<StreamEvent>> + Send>>> {
self.pregel.stream(input, config, mode).await
}
pub async fn get_state(&self, config: &Config) -> Result<Option<StateSnapshot<S>>> {
self.pregel.get_state(config).await
}
pub async fn get_state_history(
&self,
config: &Config,
limit: Option<usize>,
) -> Result<Vec<StateSnapshot<S>>> {
self.pregel.get_state_history(config, limit).await
}
pub async fn update_state(&mut self, config: Config, values: S) -> Result<Config> {
if let Some(checkpointer) = &self.checkpointer {
let mut tuple = checkpointer
.get_tuple(&config)
.await?
.ok_or_else(|| Error::checkpoint("No checkpoint found for config"))?;
let mut current_state = S::from_value(
tuple
.checkpoint
.get_channel("__start__")
.ok_or_else(|| Error::checkpoint("No state in checkpoint"))?
.clone(),
)?;
current_state.merge(values)?;
tuple.checkpoint.set_channel("__start__", current_state.to_value()?);
let metadata = CheckpointMetadata {
step: tuple.metadata.step + 1,
source: "update_state".to_string(),
created_at: chrono::Utc::now(),
extra: HashMap::new(),
};
checkpointer.put(&tuple.checkpoint, &metadata, &config).await
} else {
Err(Error::checkpoint("No checkpointer configured"))
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::checkpoint_backends::memory::MemorySaver;
    use serde::{Deserialize, Serialize};

    /// Minimal state: a counter whose `merge` adds the incoming count.
    #[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
    struct TestState {
        count: i32,
    }

    impl crate::state::State for TestState {
        fn merge(&mut self, other: Self) -> Result<()> {
            self.count += other.count;
            Ok(())
        }
    }

    #[tokio::test]
    async fn test_state_graph_basic() {
        let mut graph = StateGraph::new();
        graph.add_node("increment", |mut state: TestState, _config: &Config| async move {
            state.count += 1;
            Ok(state)
        });
        graph.set_entry_point("increment");
        graph.set_finish_point("increment");
        let mut app = graph.compile(None).unwrap();
        let result = app.invoke(TestState { count: 0 }, Config::default()).await.unwrap();
        assert_eq!(result.count, 1);
    }

    #[tokio::test]
    async fn test_state_graph_chain() {
        let mut graph = StateGraph::new();
        graph.add_node("add_one", |mut state: TestState, _config: &Config| async move {
            state.count += 1;
            Ok(state)
        });
        graph.add_node("multiply_two", |mut state: TestState, _config: &Config| async move {
            state.count *= 2;
            Ok(state)
        });
        graph.set_entry_point("add_one");
        graph.add_edge("add_one", "multiply_two");
        graph.set_finish_point("multiply_two");
        let mut app = graph.compile(None).unwrap();
        let result = app.invoke(TestState { count: 5 }, Config::default()).await.unwrap();
        // (5 + 1) * 2: the nodes must run in edge order.
        assert_eq!(result.count, 12);
    }

    #[tokio::test]
    async fn test_state_graph_with_checkpointer() {
        let mut graph = StateGraph::new();
        graph.add_node("increment", |mut state: TestState, _config: &Config| async move {
            state.count += 1;
            Ok(state)
        });
        graph.set_entry_point("increment");
        graph.set_finish_point("increment");
        let checkpointer = Arc::new(MemorySaver::new());
        let mut app = graph.compile(Some(checkpointer)).unwrap();
        let config = Config::new().with_thread_id("test-123");
        let result = app.invoke(TestState { count: 0 }, config.clone()).await.unwrap();
        assert_eq!(result.count, 1);
        let snapshot = app.get_state(&config).await.unwrap();
        assert!(snapshot.is_some());
    }

    #[test]
    fn test_state_graph_no_entry_point() {
        let mut graph = StateGraph::<TestState>::new();
        graph.add_node("test", |s: TestState, _config: &Config| async move { Ok(s) });
        let Err(e) = graph.compile(None) else {
            panic!("compiling a graph without an entry point must fail");
        };
        assert!(e.to_string().contains("entry point"));
    }

    #[test]
    fn test_state_graph_unknown_entry_point() {
        let mut graph = StateGraph::<TestState>::new();
        graph.add_node("test", |s: TestState, _config: &Config| async move { Ok(s) });
        graph.set_entry_point("missing");
        let Err(e) = graph.compile(None) else {
            panic!("compiling with an unregistered entry point must fail");
        };
        assert!(e.to_string().contains("missing"));
    }
}