1use crate::core::{Graph, PortData, Result};
4use dashmap::DashMap;
5use futures::stream::{FuturesUnordered, StreamExt};
6use std::collections::HashMap;
7use std::sync::Arc;
8
/// Runs the nodes of a [`Graph`] level by level.
///
/// The executor carries only configuration, so cloning it is cheap.
#[derive(Clone)]
pub struct Executor {
    /// Concurrency limit chosen at construction time (see `new` and
    /// `with_concurrency`).
    // The lint allowance is kept so the struct compiles warning-free even in
    // builds where nothing reads this field.
    #[allow(dead_code)]
    max_concurrency: usize,
}
16
17impl Executor {
18 pub fn new() -> Self {
20 Self {
21 max_concurrency: num_cpus::get(),
22 }
23 }
24
25 pub fn with_concurrency(max_concurrency: usize) -> Self {
27 Self { max_concurrency }
28 }
29
30 pub async fn execute(&self, graph: &mut Graph) -> Result<ExecutionResult> {
32 graph.validate()?;
34
35 let topo_order = graph.topological_order()?;
37
38 let execution_state: Arc<DashMap<String, HashMap<String, PortData>>> =
40 Arc::new(DashMap::new());
41
42 let levels = self.build_dependency_levels(graph, &topo_order)?;
44
45 for level in levels {
47 let mut tasks = FuturesUnordered::new();
48
49 for node_id in level {
50 let node = graph.get_node(&node_id)?.clone();
51 let edges = graph
52 .incoming_edges(&node_id)?
53 .iter()
54 .map(|e| (*e).clone())
55 .collect::<Vec<_>>();
56 let state = Arc::clone(&execution_state);
57
58 let task = tokio::task::spawn_blocking(move || {
60 let mut node = node;
61
62 for edge in edges {
64 if let Some(source_outputs) = state.get(&edge.from_node) {
65 if let Some(data) = source_outputs.get(&edge.from_port) {
66 node.set_input(edge.to_port.clone(), data.clone());
67 }
68 }
69 }
70
71 let result = node.execute();
73
74 (node.config.id.clone(), node.outputs.clone(), result)
75 });
76
77 tasks.push(task);
78 }
79
80 while let Some(result) = tasks.next().await {
82 let (node_id, outputs, exec_result) = result.map_err(|e| {
83 crate::core::GraphError::ExecutionError(format!("Task join error: {}", e))
84 })?;
85 exec_result?;
86 execution_state.insert(node_id, outputs);
87 }
88 }
89
90 let mut node_outputs = HashMap::new();
92 for entry in execution_state.iter() {
93 node_outputs.insert(entry.key().clone(), entry.value().clone());
94 }
95
96 Ok(ExecutionResult {
97 success: true,
98 node_outputs,
99 errors: Vec::new(),
100 })
101 }
102
103 fn build_dependency_levels(
106 &self,
107 graph: &Graph,
108 topo_order: &[String],
109 ) -> Result<Vec<Vec<String>>> {
110 let mut levels: Vec<Vec<String>> = Vec::new();
111 let mut node_level: HashMap<String, usize> = HashMap::new();
112
113 for node_id in topo_order {
115 let incoming = graph.incoming_edges(node_id)?;
116
117 let max_dep_level = incoming
119 .iter()
120 .filter_map(|edge| node_level.get(&edge.from_node))
121 .max()
122 .copied();
123
124 let level = max_dep_level.map(|l| l + 1).unwrap_or(0);
126 node_level.insert(node_id.clone(), level);
127
128 while levels.len() <= level {
130 levels.push(Vec::new());
131 }
132
133 levels[level].push(node_id.clone());
134 }
135
136 Ok(levels)
137 }
138}
139
140impl Default for Executor {
141 fn default() -> Self {
142 Self::new()
143 }
144}
145
146#[derive(Debug, Clone)]
148pub struct ExecutionResult {
149 pub success: bool,
151 pub node_outputs:
153 std::collections::HashMap<String, std::collections::HashMap<String, PortData>>,
154 pub errors: Vec<String>,
156}
157
158impl ExecutionResult {
159 pub fn get_output(&self, node_id: &str, port_id: &str) -> Option<&PortData> {
161 self.node_outputs.get(node_id)?.get(port_id)
162 }
163
164 pub fn is_success(&self) -> bool {
166 self.success
167 }
168}
169
/// Minimal local stand-in for a CPU-count helper, built on the standard library.
mod num_cpus {
    /// Returns the parallelism reported by
    /// `std::thread::available_parallelism`, or `1` when that query fails.
    pub fn get() -> usize {
        match std::thread::available_parallelism() {
            Ok(count) => count.get(),
            Err(_) => 1,
        }
    }
}
178
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::{Edge, Node, NodeConfig, Port};
    use std::collections::HashMap;
    use std::sync::Arc;

    /// Config for a node with id `double` that emits twice the integer it
    /// receives on `input`; it emits nothing when `input` is absent or not an
    /// integer.
    fn double_config() -> NodeConfig {
        NodeConfig::new(
            "double",
            "Double Node",
            vec![Port::simple("input")],
            vec![Port::simple("output")],
            Arc::new(|inputs: &HashMap<String, PortData>| {
                let mut outputs = HashMap::new();
                if let Some(PortData::Int(val)) = inputs.get("input") {
                    outputs.insert("output".to_string(), PortData::Int(val * 2));
                }
                Ok(outputs)
            }),
        )
    }

    /// A single node with a preset input must run and expose its output.
    #[tokio::test]
    async fn test_executor_simple_graph() {
        let mut graph = Graph::new();

        let mut node = Node::new(double_config());
        node.set_input("input", PortData::Int(21));

        graph.add(node).unwrap();

        let executor = Executor::new();
        let result = executor.execute(&mut graph).await.unwrap();

        assert!(result.is_success());
        assert!(
            matches!(
                result.get_output("double", "output"),
                Some(PortData::Int(42))
            ),
            "Expected output"
        );
    }

    /// Values must flow source -> double -> add5 along the edges.
    #[tokio::test]
    async fn test_executor_linear_pipeline() {
        let mut graph = Graph::new();

        let source_config = NodeConfig::new(
            "source",
            "Source Node",
            vec![],
            vec![Port::simple("output")],
            Arc::new(|_: &HashMap<String, PortData>| {
                let mut outputs = HashMap::new();
                outputs.insert("output".to_string(), PortData::Int(10));
                Ok(outputs)
            }),
        );

        let add5_config = NodeConfig::new(
            "add5",
            "Add 5 Node",
            vec![Port::simple("input")],
            vec![Port::simple("output")],
            Arc::new(|inputs: &HashMap<String, PortData>| {
                let mut outputs = HashMap::new();
                if let Some(PortData::Int(val)) = inputs.get("input") {
                    outputs.insert("output".to_string(), PortData::Int(val + 5));
                }
                Ok(outputs)
            }),
        );

        graph.add(Node::new(source_config)).unwrap();
        graph.add(Node::new(double_config())).unwrap();
        graph.add(Node::new(add5_config)).unwrap();

        graph
            .add_edge(Edge::new("source", "output", "double", "input"))
            .unwrap();
        graph
            .add_edge(Edge::new("double", "output", "add5", "input"))
            .unwrap();

        let executor = Executor::new();
        let result = executor.execute(&mut graph).await.unwrap();

        assert!(result.is_success());

        // `matches!` fails the test when an output is missing or has another
        // variant, instead of silently skipping the check.
        assert!(
            matches!(
                result.get_output("source", "output"),
                Some(PortData::Int(10))
            ),
            "source should emit 10"
        );
        assert!(
            matches!(
                result.get_output("double", "output"),
                Some(PortData::Int(20))
            ),
            "double should emit 20"
        );
        assert!(
            matches!(
                result.get_output("add5", "output"),
                Some(PortData::Int(25))
            ),
            "add5 should emit 25"
        );
    }
}