use super::edge::{ConditionalEdge, Edge, EdgeTarget};
use super::node::{DynNode, FunctionNode, Node, NodeState};
use super::CompiledGraph;
use std::collections::{HashMap, HashSet};
use std::future::Future;
use std::sync::Arc;
/// Builder for a graph of named nodes joined by static and conditional edges.
///
/// Nodes and edges are accumulated through the chainable `add_*` methods;
/// [`StateGraph::compile`] then validates the result and converts it into a
/// [`CompiledGraph`].
pub struct StateGraph {
/// Registered nodes, keyed by node name.
pub nodes: HashMap<String, DynNode>,
/// Static edges, in the order they were added.
pub edges: Vec<Edge>,
/// Edges whose target is chosen by a router closure, in insertion order.
pub conditional_edges: Vec<ConditionalEdge>,
/// Name of the entry-point node, or `None` if it has not been set.
pub entry_point: Option<String>,
}
impl StateGraph {
/// Creates an empty graph: no nodes, no edges and no entry point.
pub fn new() -> Self {
    let nodes = HashMap::new();
    let edges = Vec::new();
    let conditional_edges = Vec::new();
    Self {
        nodes,
        edges,
        conditional_edges,
        entry_point: None,
    }
}
/// Registers an async function as the node called `name`.
///
/// `func` is wrapped in a [`FunctionNode`] and stored under `name`; a node
/// already stored under that name is replaced. Returns the graph so calls
/// can be chained.
pub fn add_node<F, Fut>(mut self, name: impl Into<String>, func: F) -> Self
where
    F: Fn(NodeState) -> Fut + Send + Sync + 'static,
    Fut: Future<Output = anyhow::Result<NodeState>> + Send + 'static,
{
    let key: String = name.into();
    // The node keeps its own copy of the name; the map owns the other.
    let wrapped = FunctionNode::new(key.clone(), func);
    self.nodes.insert(key, Arc::new(wrapped));
    self
}
/// Registers an existing [`Node`] implementation.
///
/// The node is stored under the name it reports through `name()`; a node
/// already stored under that name is replaced. Returns the graph so calls
/// can be chained.
pub fn add_node_impl(mut self, node: impl Node + 'static) -> Self {
    // Read the key before the node is moved into the `Arc`.
    let key = node.name().to_string();
    let shared = Arc::new(node);
    self.nodes.insert(key, shared);
    self
}
/// Appends a static edge from node `from` to node `to`.
///
/// Neither name is checked here; [`StateGraph::compile`] performs the
/// validation. Returns the graph so calls can be chained.
pub fn add_edge(mut self, from: impl Into<String>, to: impl Into<String>) -> Self {
    let target = EdgeTarget::node(to);
    let edge = Edge::new(from, target);
    self.edges.push(edge);
    self
}
/// Appends a static edge from node `from` to [`EdgeTarget::End`].
///
/// Returns the graph so calls can be chained.
pub fn add_edge_to_end(mut self, from: impl Into<String>) -> Self {
    let terminal = Edge::new(from, EdgeTarget::End);
    self.edges.push(terminal);
    self
}
/// Appends a conditional edge leaving node `from`.
///
/// `router` maps a `&str` to the [`EdgeTarget`] to follow.
/// NOTE(review): which string the router receives is decided by the code
/// that executes the compiled graph, not visible here — confirm there.
/// Returns the graph so calls can be chained.
pub fn add_conditional_edge<F>(mut self, from: impl Into<String>, router: F) -> Self
where
    F: Fn(&str) -> EdgeTarget + Send + Sync + 'static,
{
    let edge = ConditionalEdge {
        from: from.into(),
        router: Arc::new(router),
    };
    self.conditional_edges.push(edge);
    self
}
/// Sets the entry-point node, replacing any earlier choice.
///
/// The name is not checked here; [`StateGraph::compile`] rejects an entry
/// point that is not a registered node. Returns the graph so calls can be
/// chained.
pub fn set_entry_point(mut self, name: impl Into<String>) -> Self {
    let entry: String = name.into();
    self.entry_point = Some(entry);
    self
}
/// Validates the graph and converts it into a [`CompiledGraph`].
///
/// When no entry point was set explicitly, the lexicographically smallest
/// node name is used so that the choice is the same on every run.
///
/// # Errors
///
/// Returns an error when:
/// - the graph has no nodes;
/// - the entry point is not a registered node;
/// - a static edge starts at, or points to, a node that is not registered;
/// - a conditional edge starts at a node that is not registered;
/// - the static edges contain a cycle (conditional edges are not part of
///   this check, because their targets are only known at run time).
pub fn compile(self) -> anyhow::Result<CompiledGraph> {
    if self.nodes.is_empty() {
        anyhow::bail!("Graph must have at least one node");
    }

    // `HashMap` iteration order is unspecified, so taking "the first key"
    // would pick a different default from run to run. The minimum name is
    // deterministic and is still one of the registered nodes.
    let entry_point = self
        .entry_point
        .clone()
        .or_else(|| self.nodes.keys().min().cloned())
        .ok_or_else(|| anyhow::anyhow!("No entry point defined"))?;
    if !self.nodes.contains_key(&entry_point) {
        anyhow::bail!("Entry point '{}' does not exist", entry_point);
    }

    for edge in &self.edges {
        if !self.nodes.contains_key(&edge.from) {
            anyhow::bail!("Edge source '{}' does not exist", edge.from);
        }
        if let EdgeTarget::Node(ref target) = edge.to {
            if !self.nodes.contains_key(target) {
                anyhow::bail!("Edge target '{}' does not exist", target);
            }
        }
    }

    // Conditional edges get the same source check as static edges; their
    // targets come from the router at run time and cannot be checked here.
    for edge in &self.conditional_edges {
        if !self.nodes.contains_key(&edge.from) {
            anyhow::bail!("Conditional edge source '{}' does not exist", edge.from);
        }
    }

    let adjacency = self.build_adjacency_list();
    if let Some(cycle) = self.detect_cycle(&adjacency, &entry_point) {
        // Print the whole path found by the cycle check, ending at the node
        // that was reached a second time, instead of only its two endpoints.
        anyhow::bail!(
            "Graph contains a cycle: {}. Cycles are not allowed in DAGs.",
            cycle.join(" -> ")
        );
    }

    Ok(CompiledGraph {
        nodes: self.nodes,
        edges: self.edges,
        conditional_edges: self.conditional_edges,
        entry_point,
    })
}
/// Builds the successor list of every node from the static edges.
///
/// Every registered node gets an entry, possibly empty. Edges whose target
/// is not a node (such as [`EdgeTarget::End`]) and all conditional edges
/// are left out. Successors appear in the order their edges were added.
fn build_adjacency_list(&self) -> HashMap<String, Vec<String>> {
    // Seed the map so that nodes without outgoing edges are still present.
    let mut successors: HashMap<String, Vec<String>> = self
        .nodes
        .keys()
        .map(|name| (name.clone(), Vec::new()))
        .collect();

    let node_links = self.edges.iter().filter_map(|edge| match &edge.to {
        EdgeTarget::Node(target) => Some((&edge.from, target)),
        _ => None,
    });
    for (source, target) in node_links {
        successors
            .entry(source.clone())
            .or_default()
            .push(target.clone());
    }

    successors
}
/// Searches the static-edge graph for a cycle.
///
/// The search starts at `entry_point` and then continues from every node
/// that has not been visited yet, so parts of the graph that cannot be
/// reached from the entry point are covered as well. Returns the path
/// recorded by the depth-first search when a cycle is found, `None`
/// otherwise.
fn detect_cycle(
    &self,
    adjacency: &HashMap<String, Vec<String>>,
    entry_point: &str,
) -> Option<Vec<String>> {
    let mut visited: HashSet<String> = HashSet::new();
    let mut rec_stack = HashSet::new();
    let mut path = Vec::new();

    // Entry point first, then all nodes; already-visited ones are skipped.
    let roots = std::iter::once(entry_point).chain(self.nodes.keys().map(String::as_str));
    for root in roots {
        if visited.contains(root) {
            continue;
        }
        path.clear();
        if self.dfs_cycle_detect(root, adjacency, &mut visited, &mut rec_stack, &mut path) {
            return Some(path);
        }
    }

    None
}
/// One depth-first step of the cycle search.
///
/// Marks `node` as visited and as part of the current recursion stack,
/// then follows its successors in `adjacency`.
///
/// Returns `true` as soon as an edge leads back to a node that is still on
/// the recursion stack. `path` then holds exactly the cycle: it starts at
/// the re-entered node, lists the nodes walked since, and ends with the
/// re-entered node again (`[b, c, b]` for `b -> c -> b`). Nodes that were
/// walked before the cycle was entered are removed.
///
/// Returns `false` when no cycle is reachable from `node`; `node` has then
/// been taken off `rec_stack` and `path` again.
// `self` is only passed along to the recursive call.
#[allow(clippy::only_used_in_recursion)]
fn dfs_cycle_detect(
    &self,
    node: &str,
    adjacency: &HashMap<String, Vec<String>>,
    visited: &mut HashSet<String>,
    rec_stack: &mut HashSet<String>,
    path: &mut Vec<String>,
) -> bool {
    visited.insert(node.to_string());
    rec_stack.insert(node.to_string());
    path.push(node.to_string());

    if let Some(neighbors) = adjacency.get(node) {
        for neighbor in neighbors {
            if !visited.contains(neighbor) {
                if self.dfs_cycle_detect(neighbor, adjacency, visited, rec_stack, path) {
                    return true;
                }
            } else if rec_stack.contains(neighbor) {
                // Back edge. `path` mirrors the recursion stack, so
                // `neighbor` is in it; drop everything in front of it so
                // the lead-in from the search root is not reported as part
                // of the cycle.
                if let Some(cycle_start) = path.iter().position(|seen| seen == neighbor) {
                    path.drain(..cycle_start);
                }
                path.push(neighbor.clone());
                return true;
            }
        }
    }

    // Fully explored without finding a cycle: backtrack.
    rec_stack.remove(node);
    path.pop();
    false
}
}
impl Default for StateGraph {
fn default() -> Self {
Self::new()
}
}