enact_core/graph/
compiled.rs1use super::edge::{ConditionalEdge, Edge, EdgeTarget};
6use super::node::{DynNode, NodeState};
7use futures::future::join_all;
8use std::collections::HashMap;
9
10pub struct CompiledGraph {
12 pub(crate) nodes: HashMap<String, DynNode>,
13 pub(crate) edges: Vec<Edge>,
14 pub(crate) conditional_edges: Vec<ConditionalEdge>,
15 pub(crate) entry_point: String,
16}
17
18impl CompiledGraph {
19 pub fn get_node(&self, name: &str) -> Option<&DynNode> {
21 self.nodes.get(name)
22 }
23
24 pub fn entry_point(&self) -> &str {
26 &self.entry_point
27 }
28
29 pub fn get_next(&self, from: &str, output: &str) -> Vec<EdgeTarget> {
31 let mut targets = Vec::new();
32
33 for ce in &self.conditional_edges {
35 if ce.from == from {
36 targets.push((ce.router)(output));
37 }
38 }
39
40 for edge in &self.edges {
42 if edge.from == from {
43 targets.push(edge.to.clone());
44 }
45 }
46
47 targets
48 }
49
50 pub async fn run(&self, input: impl Into<String>) -> anyhow::Result<NodeState> {
52 let initial_state = NodeState::from_string(&input.into());
53 self.run_with_state(initial_state).await
54 }
55
56 pub async fn run_with_state(&self, initial_state: NodeState) -> anyhow::Result<NodeState> {
62 let mut current_node = self.entry_point.clone();
63 let mut state = initial_state;
64
65 loop {
66 let node = self
68 .nodes
69 .get(¤t_node)
70 .ok_or_else(|| anyhow::anyhow!("Node '{}' not found", current_node))?;
71
72 tracing::debug!(node = %current_node, "Executing node");
74 state = node.execute(state).await?;
75
76 let output = state.as_str().unwrap_or_default().to_string();
78
79 let next_targets = self.get_next(¤t_node, &output);
81
82 if next_targets.is_empty() {
83 tracing::debug!(node = %current_node, "No outgoing edges, ending");
85 break;
86 }
87
88 let has_end = next_targets.iter().any(|t| matches!(t, EdgeTarget::End));
90 if has_end {
91 tracing::debug!("Reached END");
92 break;
93 }
94
95 let node_targets: Vec<String> = next_targets
97 .iter()
98 .filter_map(|t| match t {
99 EdgeTarget::Node(n) => Some(n.clone()),
100 EdgeTarget::End => None,
101 })
102 .collect();
103
104 if node_targets.is_empty() {
105 break;
106 }
107
108 if node_targets.len() == 1 {
110 current_node = node_targets[0].clone();
111 continue;
112 }
113
114 tracing::debug!(
116 targets = ?node_targets,
117 "Executing {} nodes in parallel",
118 node_targets.len()
119 );
120
121 let parallel_results = self
123 .execute_nodes_parallel(&node_targets, state.clone())
124 .await?;
125
126 if let Some(last_state) = parallel_results.into_iter().last() {
130 state = last_state;
131 }
132
133 tracing::debug!("Parallel execution complete");
137 break;
138 }
139
140 Ok(state)
141 }
142
143 async fn execute_nodes_parallel(
147 &self,
148 node_names: &[String],
149 input_state: NodeState,
150 ) -> anyhow::Result<Vec<NodeState>> {
151 let futures: Vec<_> = node_names
152 .iter()
153 .filter_map(|name| {
154 self.nodes.get(name).map(|node| {
155 let state = input_state.clone();
156 let node_name = name.clone();
157 async move {
158 tracing::debug!(node = %node_name, "Executing parallel node");
159 node.execute(state).await
160 }
161 })
162 })
163 .collect();
164
165 let results = join_all(futures).await;
166
167 let successful: Vec<NodeState> = results.into_iter().filter_map(|r| r.ok()).collect();
169
170 if successful.is_empty() {
171 anyhow::bail!("All parallel nodes failed");
172 }
173
174 Ok(successful)
175 }
176
177 pub fn node_count(&self) -> usize {
179 self.nodes.len()
180 }
181
182 pub fn edge_count(&self) -> usize {
184 self.edges.len() + self.conditional_edges.len()
185 }
186}