1use crate::node::{Node, NodeId};
4use std::collections::{HashMap, HashSet, VecDeque};
5use std::sync::{Arc, Mutex};
6
7pub type ExecutionContext = HashMap<String, String>;
9
/// Aggregated result of running a [`Dag`]: the merged key/value context plus
/// per-node and per-branch copies of each node's outputs.
#[derive(Debug, Clone)]
pub struct ExecutionResult {
    /// Union of all node outputs merged in execution order (later nodes
    /// overwrite earlier values for the same key).
    pub context: ExecutionContext,
    /// Raw outputs of each individual node, keyed by its `NodeId`.
    pub node_outputs: HashMap<NodeId, HashMap<String, String>>,
    /// Accumulated outputs of every node that carries a `branch_id`,
    /// keyed by that branch id.
    pub branch_outputs: HashMap<usize, HashMap<String, String>>,
}
20
21impl ExecutionResult {
22 pub fn new() -> Self {
24 Self {
25 context: HashMap::new(),
26 node_outputs: HashMap::new(),
27 branch_outputs: HashMap::new(),
28 }
29 }
30
31 pub fn get(&self, key: &str) -> Option<&String> {
33 self.context.get(key)
34 }
35
36 pub fn get_node_outputs(&self, node_id: NodeId) -> Option<&HashMap<String, String>> {
38 self.node_outputs.get(&node_id)
39 }
40
41 pub fn get_branch_outputs(&self, branch_id: usize) -> Option<&HashMap<String, String>> {
43 self.branch_outputs.get(&branch_id)
44 }
45
46 pub fn get_from_node(&self, node_id: NodeId, key: &str) -> Option<&String> {
48 self.node_outputs.get(&node_id).and_then(|outputs| outputs.get(key))
49 }
50
51 pub fn get_from_branch(&self, branch_id: usize, key: &str) -> Option<&String> {
53 self.branch_outputs.get(&branch_id).and_then(|outputs| outputs.get(key))
54 }
55
56 pub fn contains_key(&self, key: &str) -> bool {
58 self.context.contains_key(key)
59 }
60}
61
/// A directed acyclic graph of [`Node`]s with a precomputed execution plan.
pub struct Dag {
    /// All nodes of the graph, in the order they were supplied.
    nodes: Vec<Node>,
    /// Topologically sorted node ids: every node appears after all of its
    /// dependencies.
    execution_order: Vec<NodeId>,
    /// Nodes grouped by dependency depth; nodes within one level have no
    /// edges between them and may run concurrently.
    execution_levels: Vec<Vec<NodeId>>,
}
71
72impl Dag {
73 pub fn new(nodes: Vec<Node>) -> Self {
80 let execution_order = Self::topological_sort(&nodes);
81 let execution_levels = Self::compute_execution_levels(&nodes, &execution_order);
82
83 Self {
84 nodes,
85 execution_order,
86 execution_levels,
87 }
88 }
89
90 fn topological_sort(nodes: &[Node]) -> Vec<NodeId> {
92 let mut in_degree: HashMap<NodeId, usize> = HashMap::new();
93 let mut adj_list: HashMap<NodeId, Vec<NodeId>> = HashMap::new();
94
95 for node in nodes {
97 in_degree.entry(node.id).or_insert(0);
98 adj_list.entry(node.id).or_insert_with(Vec::new);
99
100 for &dep in &node.dependencies {
101 *in_degree.entry(node.id).or_insert(0) += 1;
102 adj_list.entry(dep).or_insert_with(Vec::new).push(node.id);
103 }
104 }
105
106 let mut queue: VecDeque<NodeId> = in_degree
108 .iter()
109 .filter(|(_, °ree)| degree == 0)
110 .map(|(&id, _)| id)
111 .collect();
112
113 let mut result = Vec::new();
114
115 while let Some(node_id) = queue.pop_front() {
116 result.push(node_id);
117
118 if let Some(neighbors) = adj_list.get(&node_id) {
119 for &neighbor in neighbors {
120 if let Some(degree) = in_degree.get_mut(&neighbor) {
121 *degree -= 1;
122 if *degree == 0 {
123 queue.push_back(neighbor);
124 }
125 }
126 }
127 }
128 }
129
130 result
131 }
132
133 fn compute_execution_levels(nodes: &[Node], execution_order: &[NodeId]) -> Vec<Vec<NodeId>> {
138 let mut levels: Vec<Vec<NodeId>> = Vec::new();
139 let mut node_level: HashMap<NodeId, usize> = HashMap::new();
140
141 for &node_id in execution_order {
142 let node = nodes.iter().find(|n| n.id == node_id).unwrap();
143
144 let level = if node.dependencies.is_empty() {
146 0
147 } else {
148 node.dependencies
149 .iter()
150 .filter_map(|dep_id| node_level.get(dep_id))
151 .max()
152 .map(|&max_level| max_level + 1)
153 .unwrap_or(0)
154 };
155
156 node_level.insert(node_id, level);
157
158 while levels.len() <= level {
160 levels.push(Vec::new());
161 }
162 levels[level].push(node_id);
163 }
164
165 levels
166 }
167
168 pub fn execute(&self, parallel: bool, max_threads: Option<usize>) -> ExecutionContext {
176 self.execute_detailed(parallel, max_threads).context
177 }
178
    /// Runs every node and returns the full [`ExecutionResult`].
    ///
    /// * `parallel` — when `false`, nodes run one at a time in topological
    ///   order; when `true`, nodes are processed level by level and each
    ///   multi-node level runs on scoped threads.
    /// * `max_threads` — caps concurrent threads per level by chunking the
    ///   level (clamped to at least 1); `None` spawns one thread per node.
    ///
    /// Node outputs are merged into the shared context, recorded per node,
    /// and — for nodes carrying a `branch_id` — accumulated per branch.
    pub fn execute_detailed(&self, parallel: bool, max_threads: Option<usize>) -> ExecutionResult {
        let mut result = ExecutionResult::new();

        if !parallel {
            // Sequential path: each node sees the context produced by every
            // node that preceded it in topological order.
            for &node_id in &self.execution_order {
                if let Some(node) = self.nodes.iter().find(|n| n.id == node_id) {
                    let outputs = node.execute(&result.context);

                    result.context.extend(outputs.clone());

                    let node_outputs: HashMap<String, String> = outputs.clone();
                    result.node_outputs.insert(node_id, node_outputs);

                    if let Some(branch_id) = node.branch_id {
                        result.branch_outputs
                            .entry(branch_id)
                            .or_insert_with(HashMap::new)
                            .extend(outputs);
                    }
                }
            }
        } else {
            // Parallel path: levels run strictly in order; nodes inside one
            // level are independent by construction.
            for level in &self.execution_levels {
                if level.len() == 1 {
                    // Single-node level: run inline, no threads needed.
                    let node_id = level[0];
                    if let Some(node) = self.nodes.iter().find(|n| n.id == node_id) {
                        let outputs = node.execute(&result.context);

                        result.context.extend(outputs.clone());
                        result.node_outputs.insert(node_id, outputs.clone());

                        if let Some(branch_id) = node.branch_id {
                            result.branch_outputs
                                .entry(branch_id)
                                .or_insert_with(HashMap::new)
                                .extend(outputs);
                        }
                    }
                } else {
                    // Snapshot the context once: every node in this level
                    // sees the same pre-level state and never observes a
                    // level peer's outputs (consistent with the dependency
                    // structure, which puts dependents in later levels).
                    let context = Arc::new(result.context.clone());
                    let nodes_to_execute: Vec<_> = level.iter()
                        .filter_map(|&node_id| {
                            self.nodes.iter().find(|n| n.id == node_id)
                        })
                        .collect();

                    // Chunk size bounds concurrency: each chunk's nodes run
                    // at the same time, chunks run one after another.
                    let chunk_size = if let Some(max) = max_threads {
                        max.max(1)
                    } else {
                        nodes_to_execute.len()
                    };

                    let outputs = Arc::new(Mutex::new(Vec::new()));

                    for chunk in nodes_to_execute.chunks(chunk_size) {
                        // Scoped threads allow borrowing `node` from the
                        // enclosing stack frame without 'static bounds.
                        std::thread::scope(|s| {
                            for node in chunk {
                                let context = Arc::clone(&context);
                                let outputs = Arc::clone(&outputs);

                                s.spawn(move || {
                                    let node_outputs = node.execute(&context);
                                    outputs.lock().unwrap().push((node.id, node.branch_id, node_outputs));
                                });
                            }
                        });
                    }

                    // Merge the level's outputs back into the shared result.
                    // NOTE(review): entries arrive in thread-completion
                    // order, so if two level peers write the same key the
                    // winner is nondeterministic — confirm peers never share
                    // output keys.
                    let collected_outputs = outputs.lock().unwrap();
                    for (node_id, branch_id, node_outputs) in collected_outputs.iter() {
                        result.context.extend(node_outputs.clone());
                        result.node_outputs.insert(*node_id, node_outputs.clone());

                        if let Some(bid) = branch_id {
                            result.branch_outputs
                                .entry(*bid)
                                .or_insert_with(HashMap::new)
                                .extend(node_outputs.clone());
                        }
                    }
                }
            }
        }

        result
    }
283
284 pub fn to_mermaid(&self) -> String {
289 let mut mermaid = String::from("graph TD\n");
290
291 for node in &self.nodes {
293 let node_label = node.display_name();
294 mermaid.push_str(&format!(" {}[\"{}\"]\n", node.id, node_label));
295 }
296
297 let mut edges_added: HashSet<(NodeId, NodeId)> = HashSet::new();
299 for node in &self.nodes {
300 for &dep_id in &node.dependencies {
301 let edge = (dep_id, node.id);
302 if !edges_added.contains(&edge) {
303 let dep_node = self.nodes.iter().find(|n| n.id == dep_id);
305
306 let mut port_labels = Vec::new();
308
309 for (broadcast_var, impl_var) in &node.input_mapping {
311 if let Some(dep) = dep_node {
313 if dep.output_mapping.values().any(|v| v == broadcast_var) {
315 port_labels.push(format!("{} → {}", broadcast_var, impl_var));
316 }
317 }
318 }
319
320 if port_labels.is_empty() {
322 mermaid.push_str(&format!(" {} --> {}\n", dep_id, node.id));
323 } else {
324 let label = port_labels.join("<br/>");
325 mermaid.push_str(&format!(" {} -->|{}| {}\n", dep_id, label, node.id));
326 }
327
328 edges_added.insert(edge);
329 }
330 }
331 }
332
333 for node in &self.nodes {
335 if node.is_branch {
336 mermaid.push_str(&format!(" style {} fill:#e1f5ff\n", node.id));
337 }
338 }
339
340 for node in &self.nodes {
342 if let Some(variant_idx) = node.variant_index {
343 let colors = ["#ffe1e1", "#e1ffe1", "#ffe1ff", "#ffffe1"];
344 let color = colors[variant_idx % colors.len()];
345 mermaid.push_str(&format!(" style {} fill:{}\n", node.id, color));
346 }
347 }
348
349 mermaid
350 }
351
352 pub fn execution_order(&self) -> &[NodeId] {
354 &self.execution_order
355 }
356
357 pub fn execution_levels(&self) -> &[Vec<NodeId>] {
359 &self.execution_levels
360 }
361
362 pub fn nodes(&self) -> &[Node] {
364 &self.nodes
365 }
366
367 pub fn stats(&self) -> DagStats {
369 DagStats {
370 node_count: self.nodes.len(),
371 depth: self.execution_levels.len(),
372 max_parallelism: self
373 .execution_levels
374 .iter()
375 .map(|level| level.len())
376 .max()
377 .unwrap_or(0),
378 branch_count: self.nodes.iter().filter(|n| n.is_branch).count(),
379 variant_count: self
380 .nodes
381 .iter()
382 .filter_map(|n| n.variant_index)
383 .max()
384 .map(|max| max + 1)
385 .unwrap_or(0),
386 }
387 }
388}
389
/// Summary metrics describing the shape of a [`Dag`].
#[derive(Debug, Clone)]
pub struct DagStats {
    /// Total number of nodes in the graph.
    pub node_count: usize,
    /// Number of execution levels (length of the longest dependency chain).
    pub depth: usize,
    /// Size of the largest execution level.
    pub max_parallelism: usize,
    /// Number of branch nodes.
    pub branch_count: usize,
    /// Number of distinct variants (highest variant index + 1).
    pub variant_count: usize,
}

impl DagStats {
    /// Formats the statistics as a human-readable multi-line report.
    pub fn summary(&self) -> String {
        let mut report = String::from("DAG Statistics:");
        report.push_str(&format!("\n- Nodes: {}", self.node_count));
        report.push_str(&format!("\n- Depth: {} levels", self.depth));
        report.push_str(&format!("\n- Max Parallelism: {} nodes", self.max_parallelism));
        report.push_str(&format!("\n- Branches: {}", self.branch_count));
        report.push_str(&format!("\n- Variants: {}", self.variant_count));
        report
    }
}