Skip to main content

enact_core/graph/
graph_schema.rs

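//! StateGraph - builder for creating graphs
//!
//! Includes cycle detection over direct edges to prevent infinite loops (DAG invariant).
//! Conditional edges, whose targets are chosen by router functions, are not part of that check.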
4
5use super::edge::{ConditionalEdge, Edge, EdgeTarget};
6use super::node::{DynNode, FunctionNode, Node, NodeState};
7use super::CompiledGraph;
8use std::collections::{HashMap, HashSet};
9use std::future::Future;
10use std::sync::Arc;
11
12/// StateGraph - fluent builder for creating DAGs
13pub struct StateGraph {
14    pub nodes: HashMap<String, DynNode>,
15    pub edges: Vec<Edge>,
16    pub conditional_edges: Vec<ConditionalEdge>,
17    pub entry_point: Option<String>,
18}
19
20impl StateGraph {
21    /// Create a new empty graph
22    pub fn new() -> Self {
23        Self {
24            nodes: HashMap::new(),
25            edges: Vec::new(),
26            conditional_edges: Vec::new(),
27            entry_point: None,
28        }
29    }
30
31    /// Add a node with a function
32    pub fn add_node<F, Fut>(mut self, name: impl Into<String>, func: F) -> Self
33    where
34        F: Fn(NodeState) -> Fut + Send + Sync + 'static,
35        Fut: Future<Output = anyhow::Result<NodeState>> + Send + 'static,
36    {
37        let name = name.into();
38        let node = Arc::new(FunctionNode::new(name.clone(), func));
39        self.nodes.insert(name, node);
40        self
41    }
42
43    /// Add a pre-built node
44    pub fn add_node_impl(mut self, node: impl Node + 'static) -> Self {
45        let name = node.name().to_string();
46        self.nodes.insert(name, Arc::new(node));
47        self
48    }
49
50    /// Add a direct edge between two nodes
51    pub fn add_edge(mut self, from: impl Into<String>, to: impl Into<String>) -> Self {
52        self.edges.push(Edge::new(from, EdgeTarget::node(to)));
53        self
54    }
55
56    /// Add an edge to END
57    pub fn add_edge_to_end(mut self, from: impl Into<String>) -> Self {
58        self.edges.push(Edge::new(from, EdgeTarget::End));
59        self
60    }
61
62    /// Add a conditional edge with a router function
63    pub fn add_conditional_edge<F>(mut self, from: impl Into<String>, router: F) -> Self
64    where
65        F: Fn(&str) -> EdgeTarget + Send + Sync + 'static,
66    {
67        self.conditional_edges.push(ConditionalEdge {
68            from: from.into(),
69            router: Arc::new(router),
70        });
71        self
72    }
73
74    /// Set the entry point node
75    pub fn set_entry_point(mut self, name: impl Into<String>) -> Self {
76        self.entry_point = Some(name.into());
77        self
78    }
79
80    /// Compile the graph for execution
81    pub fn compile(self) -> anyhow::Result<CompiledGraph> {
82        // Validate: must have at least one node
83        if self.nodes.is_empty() {
84            anyhow::bail!("Graph must have at least one node");
85        }
86
87        // Determine entry point
88        let entry_point = self.entry_point.clone().or_else(|| {
89            // If no entry point set, use first node added
90            self.nodes.keys().next().cloned()
91        });
92
93        let entry_point = entry_point.ok_or_else(|| anyhow::anyhow!("No entry point defined"))?;
94
95        // Validate: entry point must exist
96        if !self.nodes.contains_key(&entry_point) {
97            anyhow::bail!("Entry point '{}' does not exist", entry_point);
98        }
99
100        // Validate: all edge sources/targets exist
101        for edge in &self.edges {
102            if !self.nodes.contains_key(&edge.from) {
103                anyhow::bail!("Edge source '{}' does not exist", edge.from);
104            }
105            if let EdgeTarget::Node(ref target) = edge.to {
106                if !self.nodes.contains_key(target) {
107                    anyhow::bail!("Edge target '{}' does not exist", target);
108                }
109            }
110        }
111
112        // Validate: no cycles (DAG invariant)
113        // Build adjacency list from edges
114        let adjacency = self.build_adjacency_list();
115        if let Some(cycle) = self.detect_cycle(&adjacency, &entry_point) {
116            anyhow::bail!(
117                "Graph contains a cycle: {} -> ... -> {}. Cycles are not allowed in DAGs.",
118                cycle.first().unwrap_or(&"?".to_string()),
119                cycle.last().unwrap_or(&"?".to_string())
120            );
121        }
122
123        Ok(CompiledGraph {
124            nodes: self.nodes,
125            edges: self.edges,
126            conditional_edges: self.conditional_edges,
127            entry_point,
128        })
129    }
130
131    /// Build adjacency list from edges
132    fn build_adjacency_list(&self) -> HashMap<String, Vec<String>> {
133        let mut adjacency: HashMap<String, Vec<String>> = HashMap::new();
134
135        // Initialize all nodes with empty neighbor lists
136        for node_name in self.nodes.keys() {
137            adjacency.entry(node_name.clone()).or_default();
138        }
139
140        // Add edges
141        for edge in &self.edges {
142            if let EdgeTarget::Node(ref target) = edge.to {
143                adjacency
144                    .entry(edge.from.clone())
145                    .or_default()
146                    .push(target.clone());
147            }
148        }
149
150        adjacency
151    }
152
153    /// Detect cycle using DFS
154    /// Returns Some(cycle_path) if a cycle is found, None otherwise
155    fn detect_cycle(
156        &self,
157        adjacency: &HashMap<String, Vec<String>>,
158        entry_point: &str,
159    ) -> Option<Vec<String>> {
160        let mut visited = HashSet::new();
161        let mut rec_stack = HashSet::new();
162        let mut path = Vec::new();
163
164        if self.dfs_cycle_detect(
165            entry_point,
166            adjacency,
167            &mut visited,
168            &mut rec_stack,
169            &mut path,
170        ) {
171            return Some(path);
172        }
173
174        // Also check from all nodes in case graph has disconnected components
175        for node in self.nodes.keys() {
176            if !visited.contains(node) {
177                path.clear();
178                if self.dfs_cycle_detect(node, adjacency, &mut visited, &mut rec_stack, &mut path) {
179                    return Some(path);
180                }
181            }
182        }
183
184        None
185    }
186
187    /// DFS helper for cycle detection
188    #[allow(clippy::only_used_in_recursion)]
189    fn dfs_cycle_detect(
190        &self,
191        node: &str,
192        adjacency: &HashMap<String, Vec<String>>,
193        visited: &mut HashSet<String>,
194        rec_stack: &mut HashSet<String>,
195        path: &mut Vec<String>,
196    ) -> bool {
197        visited.insert(node.to_string());
198        rec_stack.insert(node.to_string());
199        path.push(node.to_string());
200
201        if let Some(neighbors) = adjacency.get(node) {
202            for neighbor in neighbors {
203                if !visited.contains(neighbor) {
204                    if self.dfs_cycle_detect(neighbor, adjacency, visited, rec_stack, path) {
205                        return true;
206                    }
207                } else if rec_stack.contains(neighbor) {
208                    // Found a cycle
209                    path.push(neighbor.clone());
210                    return true;
211                }
212            }
213        }
214
215        rec_stack.remove(node);
216        path.pop();
217        false
218    }
219}
220
221impl Default for StateGraph {
222    fn default() -> Self {
223        Self::new()
224    }
225}