// enact_core/graph/compiled.rs

//! `CompiledGraph`: a validated graph that is ready to execute.
//!
//! When a node has several outgoing targets, those targets are executed concurrently.
4
5use super::edge::{ConditionalEdge, Edge, EdgeTarget};
6use super::node::{DynNode, NodeState};
7use futures::future::join_all;
8use std::collections::HashMap;
9
10/// Compiled graph - validated and ready to execute
11pub struct CompiledGraph {
12    pub(crate) nodes: HashMap<String, DynNode>,
13    pub(crate) edges: Vec<Edge>,
14    pub(crate) conditional_edges: Vec<ConditionalEdge>,
15    pub(crate) entry_point: String,
16}
17
18impl CompiledGraph {
19    /// Get a node by name
20    pub fn get_node(&self, name: &str) -> Option<&DynNode> {
21        self.nodes.get(name)
22    }
23
24    /// Get the entry point
25    pub fn entry_point(&self) -> &str {
26        &self.entry_point
27    }
28
29    /// Get the next node(s) after the given node
30    pub fn get_next(&self, from: &str, output: &str) -> Vec<EdgeTarget> {
31        let mut targets = Vec::new();
32
33        // Check conditional edges first
34        for ce in &self.conditional_edges {
35            if ce.from == from {
36                targets.push((ce.router)(output));
37            }
38        }
39
40        // Then check regular edges
41        for edge in &self.edges {
42            if edge.from == from {
43                targets.push(edge.to.clone());
44            }
45        }
46
47        targets
48    }
49
50    /// Run the graph with an initial input
51    pub async fn run(&self, input: impl Into<String>) -> anyhow::Result<NodeState> {
52        let initial_state = NodeState::from_string(&input.into());
53        self.run_with_state(initial_state).await
54    }
55
56    /// Run the graph with an initial state
57    ///
58    /// When multiple targets are available, executes them in parallel.
59    /// This implements the Agentic DAG execution model where independent
60    /// nodes can run concurrently.
61    pub async fn run_with_state(&self, initial_state: NodeState) -> anyhow::Result<NodeState> {
62        let mut current_node = self.entry_point.clone();
63        let mut state = initial_state;
64
65        loop {
66            // Get the current node
67            let node = self
68                .nodes
69                .get(&current_node)
70                .ok_or_else(|| anyhow::anyhow!("Node '{}' not found", current_node))?;
71
72            // Execute the node
73            tracing::debug!(node = %current_node, "Executing node");
74            state = node.execute(state).await?;
75
76            // Get the output for routing
77            let output = state.as_str().unwrap_or_default().to_string();
78
79            // Find next node(s)
80            let next_targets = self.get_next(&current_node, &output);
81
82            if next_targets.is_empty() {
83                // No outgoing edges - end execution
84                tracing::debug!(node = %current_node, "No outgoing edges, ending");
85                break;
86            }
87
88            // Check for END target
89            let has_end = next_targets.iter().any(|t| matches!(t, EdgeTarget::End));
90            if has_end {
91                tracing::debug!("Reached END");
92                break;
93            }
94
95            // Collect node targets (filter out End)
96            let node_targets: Vec<String> = next_targets
97                .iter()
98                .filter_map(|t| match t {
99                    EdgeTarget::Node(n) => Some(n.clone()),
100                    EdgeTarget::End => None,
101                })
102                .collect();
103
104            if node_targets.is_empty() {
105                break;
106            }
107
108            // Single target - sequential execution
109            if node_targets.len() == 1 {
110                current_node = node_targets[0].clone();
111                continue;
112            }
113
114            // Multiple targets - PARALLEL EXECUTION
115            tracing::debug!(
116                targets = ?node_targets,
117                "Executing {} nodes in parallel",
118                node_targets.len()
119            );
120
121            // Execute all target nodes in parallel
122            let parallel_results = self
123                .execute_nodes_parallel(&node_targets, state.clone())
124                .await?;
125
126            // Aggregate results: combine all outputs
127            // For now, we use the last successful result as the state
128            // In a full implementation, this would support custom aggregation strategies
129            if let Some(last_state) = parallel_results.into_iter().last() {
130                state = last_state;
131            }
132
133            // After parallel execution, check if any nodes have outgoing edges
134            // For simplicity, we end after parallel execution
135            // A full implementation would continue with fan-in logic
136            tracing::debug!("Parallel execution complete");
137            break;
138        }
139
140        Ok(state)
141    }
142
143    /// Execute multiple nodes in parallel
144    ///
145    /// Returns results from all nodes that completed successfully.
146    async fn execute_nodes_parallel(
147        &self,
148        node_names: &[String],
149        input_state: NodeState,
150    ) -> anyhow::Result<Vec<NodeState>> {
151        let futures: Vec<_> = node_names
152            .iter()
153            .filter_map(|name| {
154                self.nodes.get(name).map(|node| {
155                    let state = input_state.clone();
156                    let node_name = name.clone();
157                    async move {
158                        tracing::debug!(node = %node_name, "Executing parallel node");
159                        node.execute(state).await
160                    }
161                })
162            })
163            .collect();
164
165        let results = join_all(futures).await;
166
167        // Collect successful results
168        let successful: Vec<NodeState> = results.into_iter().filter_map(|r| r.ok()).collect();
169
170        if successful.is_empty() {
171            anyhow::bail!("All parallel nodes failed");
172        }
173
174        Ok(successful)
175    }
176
177    /// Get node count
178    pub fn node_count(&self) -> usize {
179        self.nodes.len()
180    }
181
182    /// Get edge count
183    pub fn edge_count(&self) -> usize {
184        self.edges.len() + self.conditional_edges.len()
185    }
186}