enact_core/graph/
graph_schema.rs1use super::edge::{ConditionalEdge, Edge, EdgeTarget};
6use super::node::{DynNode, FunctionNode, Node, NodeState};
7use super::CompiledGraph;
8use std::collections::{HashMap, HashSet};
9use std::future::Future;
10use std::sync::Arc;
11
12pub struct StateGraph {
14 pub nodes: HashMap<String, DynNode>,
15 pub edges: Vec<Edge>,
16 pub conditional_edges: Vec<ConditionalEdge>,
17 pub entry_point: Option<String>,
18}
19
20impl StateGraph {
21 pub fn new() -> Self {
23 Self {
24 nodes: HashMap::new(),
25 edges: Vec::new(),
26 conditional_edges: Vec::new(),
27 entry_point: None,
28 }
29 }
30
31 pub fn add_node<F, Fut>(mut self, name: impl Into<String>, func: F) -> Self
33 where
34 F: Fn(NodeState) -> Fut + Send + Sync + 'static,
35 Fut: Future<Output = anyhow::Result<NodeState>> + Send + 'static,
36 {
37 let name = name.into();
38 let node = Arc::new(FunctionNode::new(name.clone(), func));
39 self.nodes.insert(name, node);
40 self
41 }
42
43 pub fn add_node_impl(mut self, node: impl Node + 'static) -> Self {
45 let name = node.name().to_string();
46 self.nodes.insert(name, Arc::new(node));
47 self
48 }
49
50 pub fn add_edge(mut self, from: impl Into<String>, to: impl Into<String>) -> Self {
52 self.edges.push(Edge::new(from, EdgeTarget::node(to)));
53 self
54 }
55
56 pub fn add_edge_to_end(mut self, from: impl Into<String>) -> Self {
58 self.edges.push(Edge::new(from, EdgeTarget::End));
59 self
60 }
61
62 pub fn add_conditional_edge<F>(mut self, from: impl Into<String>, router: F) -> Self
64 where
65 F: Fn(&str) -> EdgeTarget + Send + Sync + 'static,
66 {
67 self.conditional_edges.push(ConditionalEdge {
68 from: from.into(),
69 router: Arc::new(router),
70 });
71 self
72 }
73
74 pub fn set_entry_point(mut self, name: impl Into<String>) -> Self {
76 self.entry_point = Some(name.into());
77 self
78 }
79
80 pub fn compile(self) -> anyhow::Result<CompiledGraph> {
82 if self.nodes.is_empty() {
84 anyhow::bail!("Graph must have at least one node");
85 }
86
87 let entry_point = self.entry_point.clone().or_else(|| {
89 self.nodes.keys().next().cloned()
91 });
92
93 let entry_point = entry_point.ok_or_else(|| anyhow::anyhow!("No entry point defined"))?;
94
95 if !self.nodes.contains_key(&entry_point) {
97 anyhow::bail!("Entry point '{}' does not exist", entry_point);
98 }
99
100 for edge in &self.edges {
102 if !self.nodes.contains_key(&edge.from) {
103 anyhow::bail!("Edge source '{}' does not exist", edge.from);
104 }
105 if let EdgeTarget::Node(ref target) = edge.to {
106 if !self.nodes.contains_key(target) {
107 anyhow::bail!("Edge target '{}' does not exist", target);
108 }
109 }
110 }
111
112 let adjacency = self.build_adjacency_list();
115 if let Some(cycle) = self.detect_cycle(&adjacency, &entry_point) {
116 anyhow::bail!(
117 "Graph contains a cycle: {} -> ... -> {}. Cycles are not allowed in DAGs.",
118 cycle.first().unwrap_or(&"?".to_string()),
119 cycle.last().unwrap_or(&"?".to_string())
120 );
121 }
122
123 Ok(CompiledGraph {
124 nodes: self.nodes,
125 edges: self.edges,
126 conditional_edges: self.conditional_edges,
127 entry_point,
128 })
129 }
130
131 fn build_adjacency_list(&self) -> HashMap<String, Vec<String>> {
133 let mut adjacency: HashMap<String, Vec<String>> = HashMap::new();
134
135 for node_name in self.nodes.keys() {
137 adjacency.entry(node_name.clone()).or_default();
138 }
139
140 for edge in &self.edges {
142 if let EdgeTarget::Node(ref target) = edge.to {
143 adjacency
144 .entry(edge.from.clone())
145 .or_default()
146 .push(target.clone());
147 }
148 }
149
150 adjacency
151 }
152
153 fn detect_cycle(
156 &self,
157 adjacency: &HashMap<String, Vec<String>>,
158 entry_point: &str,
159 ) -> Option<Vec<String>> {
160 let mut visited = HashSet::new();
161 let mut rec_stack = HashSet::new();
162 let mut path = Vec::new();
163
164 if self.dfs_cycle_detect(
165 entry_point,
166 adjacency,
167 &mut visited,
168 &mut rec_stack,
169 &mut path,
170 ) {
171 return Some(path);
172 }
173
174 for node in self.nodes.keys() {
176 if !visited.contains(node) {
177 path.clear();
178 if self.dfs_cycle_detect(node, adjacency, &mut visited, &mut rec_stack, &mut path) {
179 return Some(path);
180 }
181 }
182 }
183
184 None
185 }
186
187 #[allow(clippy::only_used_in_recursion)]
189 fn dfs_cycle_detect(
190 &self,
191 node: &str,
192 adjacency: &HashMap<String, Vec<String>>,
193 visited: &mut HashSet<String>,
194 rec_stack: &mut HashSet<String>,
195 path: &mut Vec<String>,
196 ) -> bool {
197 visited.insert(node.to_string());
198 rec_stack.insert(node.to_string());
199 path.push(node.to_string());
200
201 if let Some(neighbors) = adjacency.get(node) {
202 for neighbor in neighbors {
203 if !visited.contains(neighbor) {
204 if self.dfs_cycle_detect(neighbor, adjacency, visited, rec_stack, path) {
205 return true;
206 }
207 } else if rec_stack.contains(neighbor) {
208 path.push(neighbor.clone());
210 return true;
211 }
212 }
213 }
214
215 rec_stack.remove(node);
216 path.pop();
217 false
218 }
219}
220
221impl Default for StateGraph {
222 fn default() -> Self {
223 Self::new()
224 }
225}