graph_sp/executor/mod.rs

1//! Parallel execution engine for DAG graphs.
2
3use crate::core::{Graph, PortData, Result};
4use dashmap::DashMap;
5use futures::stream::{FuturesUnordered, StreamExt};
6use std::collections::HashMap;
7use std::sync::Arc;
8
9/// Executor for running graphs with parallel execution
10#[derive(Clone)]
11pub struct Executor {
12    /// Maximum number of concurrent tasks (reserved for future use)
13    #[allow(dead_code)]
14    max_concurrency: usize,
15}
16
17impl Executor {
18    /// Create a new executor with default concurrency
19    pub fn new() -> Self {
20        Self {
21            max_concurrency: num_cpus::get(),
22        }
23    }
24
25    /// Create a new executor with specified concurrency limit
26    pub fn with_concurrency(max_concurrency: usize) -> Self {
27        Self { max_concurrency }
28    }
29
30    /// Execute a graph and return the results
31    pub async fn execute(&self, graph: &mut Graph) -> Result<ExecutionResult> {
32        // Validate the graph first
33        graph.validate()?;
34
35        // Get topological order to determine dependencies
36        let topo_order = graph.topological_order()?;
37
38        // Track execution state - map from node_id to outputs
39        let execution_state: Arc<DashMap<String, HashMap<String, PortData>>> =
40            Arc::new(DashMap::new());
41
42        // Build dependency levels for parallel execution
43        let levels = self.build_dependency_levels(graph, &topo_order)?;
44
45        // Execute each level in parallel
46        for level in levels {
47            let mut tasks = FuturesUnordered::new();
48
49            for node_id in level {
50                let node = graph.get_node(&node_id)?.clone();
51                let edges = graph
52                    .incoming_edges(&node_id)?
53                    .iter()
54                    .map(|e| (*e).clone())
55                    .collect::<Vec<_>>();
56                let state = Arc::clone(&execution_state);
57
58                // Spawn a blocking task for each node (nodes execute synchronously)
59                let task = tokio::task::spawn_blocking(move || {
60                    let mut node = node;
61
62                    // Collect inputs from incoming edges
63                    for edge in edges {
64                        if let Some(source_outputs) = state.get(&edge.from_node) {
65                            if let Some(data) = source_outputs.get(&edge.from_port) {
66                                node.set_input(edge.to_port.clone(), data.clone());
67                            }
68                        }
69                    }
70
71                    // Execute the node
72                    let result = node.execute();
73
74                    (node.config.id.clone(), node.outputs.clone(), result)
75                });
76
77                tasks.push(task);
78            }
79
80            // Wait for all nodes in this level to complete
81            while let Some(result) = tasks.next().await {
82                let (node_id, outputs, exec_result) = result.map_err(|e| {
83                    crate::core::GraphError::ExecutionError(format!("Task join error: {}", e))
84                })?;
85                exec_result?;
86                execution_state.insert(node_id, outputs);
87            }
88        }
89
90        // Collect results
91        let mut node_outputs = HashMap::new();
92        for entry in execution_state.iter() {
93            node_outputs.insert(entry.key().clone(), entry.value().clone());
94        }
95
96        Ok(ExecutionResult {
97            success: true,
98            node_outputs,
99            errors: Vec::new(),
100        })
101    }
102
103    /// Build dependency levels for parallel execution
104    /// All nodes in the same level can execute in parallel
105    fn build_dependency_levels(
106        &self,
107        graph: &Graph,
108        topo_order: &[String],
109    ) -> Result<Vec<Vec<String>>> {
110        let mut levels: Vec<Vec<String>> = Vec::new();
111        let mut node_level: HashMap<String, usize> = HashMap::new();
112
113        // Assign each node to a level based on its dependencies
114        for node_id in topo_order {
115            let incoming = graph.incoming_edges(node_id)?;
116
117            // Find the maximum level of all dependencies
118            let max_dep_level = incoming
119                .iter()
120                .filter_map(|edge| node_level.get(&edge.from_node))
121                .max()
122                .copied();
123
124            // This node goes one level after its dependencies
125            let level = max_dep_level.map(|l| l + 1).unwrap_or(0);
126            node_level.insert(node_id.clone(), level);
127
128            // Ensure we have enough levels
129            while levels.len() <= level {
130                levels.push(Vec::new());
131            }
132
133            levels[level].push(node_id.clone());
134        }
135
136        Ok(levels)
137    }
138}
139
140impl Default for Executor {
141    fn default() -> Self {
142        Self::new()
143    }
144}
145
146/// Result of graph execution
147#[derive(Debug, Clone)]
148pub struct ExecutionResult {
149    /// Whether execution was successful
150    pub success: bool,
151    /// Outputs from each node
152    pub node_outputs:
153        std::collections::HashMap<String, std::collections::HashMap<String, PortData>>,
154    /// Any errors that occurred
155    pub errors: Vec<String>,
156}
157
158impl ExecutionResult {
159    /// Get output from a specific node and port
160    pub fn get_output(&self, node_id: &str, port_id: &str) -> Option<&PortData> {
161        self.node_outputs.get(node_id)?.get(port_id)
162    }
163
164    /// Check if execution was successful
165    pub fn is_success(&self) -> bool {
166        self.success
167    }
168}
169
/// Local helper, named after the `num_cpus` crate, used to size the default executor.
mod num_cpus {
    /// Parallelism reported by [`std::thread::available_parallelism`], or `1`
    /// when it cannot be determined.
    pub fn get() -> usize {
        std::thread::available_parallelism().map_or(1, std::num::NonZeroUsize::get)
    }
}
178
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::{Edge, Node, NodeConfig, Port};
    use std::collections::HashMap;
    use std::sync::Arc;

    /// Assert that `$node.$port` in `$result` holds exactly `PortData::Int($expected)`.
    ///
    /// Fails if the output is missing or is a different variant. This is a macro
    /// rather than a function so the payload type of `PortData::Int` does not
    /// have to be named here.
    macro_rules! assert_int_output {
        ($result:expr, $node:expr, $port:expr, $expected:expr) => {
            match $result.get_output($node, $port) {
                Some(PortData::Int(val)) => {
                    assert_eq!(*val, $expected, "unexpected value at {}.{}", $node, $port)
                }
                other => panic!(
                    "expected Int output at {}.{}, got {:?}",
                    $node, $port, other
                ),
            }
        };
    }

    #[tokio::test]
    async fn test_executor_simple_graph() {
        let mut graph = Graph::new();

        // A single node that doubles its preset input.
        let config = NodeConfig::new(
            "double",
            "Double Node",
            vec![Port::simple("input")],
            vec![Port::simple("output")],
            Arc::new(|inputs: &HashMap<String, PortData>| {
                let mut outputs = HashMap::new();
                if let Some(PortData::Int(val)) = inputs.get("input") {
                    outputs.insert("output".to_string(), PortData::Int(val * 2));
                }
                Ok(outputs)
            }),
        );

        let mut node = Node::new(config);
        node.set_input("input", PortData::Int(21));

        graph.add(node).unwrap();

        let executor = Executor::new();
        let result = executor.execute(&mut graph).await.unwrap();

        assert!(result.is_success());
        assert_int_output!(result, "double", "output", 42);
    }

    #[tokio::test]
    async fn test_executor_linear_pipeline() {
        let mut graph = Graph::new();

        // Node 1: output 10.
        let config1 = NodeConfig::new(
            "source",
            "Source Node",
            vec![],
            vec![Port::simple("output")],
            Arc::new(|_: &HashMap<String, PortData>| {
                let mut outputs = HashMap::new();
                outputs.insert("output".to_string(), PortData::Int(10));
                Ok(outputs)
            }),
        );

        // Node 2: double the input.
        let config2 = NodeConfig::new(
            "double",
            "Double Node",
            vec![Port::simple("input")],
            vec![Port::simple("output")],
            Arc::new(|inputs: &HashMap<String, PortData>| {
                let mut outputs = HashMap::new();
                if let Some(PortData::Int(val)) = inputs.get("input") {
                    outputs.insert("output".to_string(), PortData::Int(val * 2));
                }
                Ok(outputs)
            }),
        );

        // Node 3: add 5.
        let config3 = NodeConfig::new(
            "add5",
            "Add 5 Node",
            vec![Port::simple("input")],
            vec![Port::simple("output")],
            Arc::new(|inputs: &HashMap<String, PortData>| {
                let mut outputs = HashMap::new();
                if let Some(PortData::Int(val)) = inputs.get("input") {
                    outputs.insert("output".to_string(), PortData::Int(val + 5));
                }
                Ok(outputs)
            }),
        );

        graph.add(Node::new(config1)).unwrap();
        graph.add(Node::new(config2)).unwrap();
        graph.add(Node::new(config3)).unwrap();

        graph
            .add_edge(Edge::new("source", "output", "double", "input"))
            .unwrap();
        graph
            .add_edge(Edge::new("double", "output", "add5", "input"))
            .unwrap();

        let executor = Executor::new();
        let result = executor.execute(&mut graph).await.unwrap();

        assert!(result.is_success());

        // Previously these checks were `if let` without `else` and could pass
        // silently when an output was missing; they are now strict.
        assert_int_output!(result, "source", "output", 10);
        assert_int_output!(result, "double", "output", 20);
        assert_int_output!(result, "add5", "output", 25);
    }

    #[tokio::test]
    async fn test_executor_small_concurrency_limit_runs_every_node() {
        // Two independent sources in one level fan into a sum node. Every node
        // must still run when the executor's concurrency limit is tiny (0 or 1).
        for &limit in &[0usize, 1] {
            let mut graph = Graph::new();

            let left = NodeConfig::new(
                "left",
                "Left Source",
                vec![],
                vec![Port::simple("output")],
                Arc::new(|_: &HashMap<String, PortData>| {
                    let mut outputs = HashMap::new();
                    outputs.insert("output".to_string(), PortData::Int(2));
                    Ok(outputs)
                }),
            );

            let right = NodeConfig::new(
                "right",
                "Right Source",
                vec![],
                vec![Port::simple("output")],
                Arc::new(|_: &HashMap<String, PortData>| {
                    let mut outputs = HashMap::new();
                    outputs.insert("output".to_string(), PortData::Int(3));
                    Ok(outputs)
                }),
            );

            let sum = NodeConfig::new(
                "sum",
                "Sum Node",
                vec![Port::simple("a"), Port::simple("b")],
                vec![Port::simple("output")],
                Arc::new(|inputs: &HashMap<String, PortData>| {
                    let mut outputs = HashMap::new();
                    if let (Some(PortData::Int(a)), Some(PortData::Int(b))) =
                        (inputs.get("a"), inputs.get("b"))
                    {
                        outputs.insert("output".to_string(), PortData::Int(a + b));
                    }
                    Ok(outputs)
                }),
            );

            graph.add(Node::new(left)).unwrap();
            graph.add(Node::new(right)).unwrap();
            graph.add(Node::new(sum)).unwrap();

            graph
                .add_edge(Edge::new("left", "output", "sum", "a"))
                .unwrap();
            graph
                .add_edge(Edge::new("right", "output", "sum", "b"))
                .unwrap();

            let executor = Executor::with_concurrency(limit);
            let result = executor.execute(&mut graph).await.unwrap();

            assert!(result.is_success());
            assert_int_output!(result, "left", "output", 2);
            assert_int_output!(result, "right", "output", 3);
            assert_int_output!(result, "sum", "output", 5);
        }
    }
}