// graph_sp/dag.rs

//! DAG representation with execution and visualization support
2
use crate::node::{Node, NodeId};
use std::collections::{HashMap, HashSet, VecDeque};
use std::sync::{Arc, Mutex};
6
/// Execution context for storing variable values during graph execution.
///
/// Maps broadcast variable names to their string values; accumulated as
/// nodes execute and consulted by downstream nodes.
pub type ExecutionContext = HashMap<String, String>;
9
/// Execution result that tracks outputs per node and per branch.
#[derive(Debug, Clone)]
pub struct ExecutionResult {
    /// Global execution context (all variables accessible by broadcast name)
    pub context: ExecutionContext,
    /// Outputs per node (node_id -> HashMap of output variables)
    pub node_outputs: HashMap<NodeId, HashMap<String, String>>,
    /// Outputs per branch (branch_id -> HashMap of output variables)
    pub branch_outputs: HashMap<usize, HashMap<String, String>>,
}
20
21impl ExecutionResult {
22    /// Create a new empty execution result
23    pub fn new() -> Self {
24        Self {
25            context: HashMap::new(),
26            node_outputs: HashMap::new(),
27            branch_outputs: HashMap::new(),
28        }
29    }
30    
31    /// Get a value from the global context
32    pub fn get(&self, key: &str) -> Option<&String> {
33        self.context.get(key)
34    }
35    
36    /// Get all outputs from a specific node
37    pub fn get_node_outputs(&self, node_id: NodeId) -> Option<&HashMap<String, String>> {
38        self.node_outputs.get(&node_id)
39    }
40    
41    /// Get all outputs from a specific branch
42    pub fn get_branch_outputs(&self, branch_id: usize) -> Option<&HashMap<String, String>> {
43        self.branch_outputs.get(&branch_id)
44    }
45    
46    /// Get a specific variable from a node
47    pub fn get_from_node(&self, node_id: NodeId, key: &str) -> Option<&String> {
48        self.node_outputs.get(&node_id).and_then(|outputs| outputs.get(key))
49    }
50    
51    /// Get a specific variable from a branch
52    pub fn get_from_branch(&self, branch_id: usize, key: &str) -> Option<&String> {
53        self.branch_outputs.get(&branch_id).and_then(|outputs| outputs.get(key))
54    }
55    
56    /// Check if a variable exists in global context
57    pub fn contains_key(&self, key: &str) -> bool {
58        self.context.contains_key(key)
59    }
60}
61
/// Directed Acyclic Graph representing the optimized execution plan.
///
/// Built once by [`Dag::new`]; the sorted order and level partition are
/// computed up front and reused by every execution.
pub struct Dag {
    /// All nodes in the DAG
    nodes: Vec<Node>,
    /// Execution order (topologically sorted)
    execution_order: Vec<NodeId>,
    /// Levels for parallel execution (nodes at same level can run in parallel)
    execution_levels: Vec<Vec<NodeId>>,
}
71
72impl Dag {
    /// Create a new DAG from a list of nodes.
    ///
    /// Performs implicit inspection:
    /// - Determines execution order via topological sort
    /// - Identifies parallelizable operations (execution levels)
    ///
    /// NOTE(review): the original doc also claimed "validates the graph is
    /// acyclic", but no explicit check exists — on a cycle, the topological
    /// sort silently drops the cyclic nodes from the order. Confirm whether
    /// callers rely on that before adding validation.
    pub fn new(nodes: Vec<Node>) -> Self {
        let execution_order = Self::topological_sort(&nodes);
        let execution_levels = Self::compute_execution_levels(&nodes, &execution_order);

        Self {
            nodes,
            execution_order,
            execution_levels,
        }
    }
89
90    /// Perform topological sort to determine execution order
91    fn topological_sort(nodes: &[Node]) -> Vec<NodeId> {
92        let mut in_degree: HashMap<NodeId, usize> = HashMap::new();
93        let mut adj_list: HashMap<NodeId, Vec<NodeId>> = HashMap::new();
94
95        // Initialize in-degree and adjacency list
96        for node in nodes {
97            in_degree.entry(node.id).or_insert(0);
98            adj_list.entry(node.id).or_insert_with(Vec::new);
99
100            for &dep in &node.dependencies {
101                *in_degree.entry(node.id).or_insert(0) += 1;
102                adj_list.entry(dep).or_insert_with(Vec::new).push(node.id);
103            }
104        }
105
106        // Kahn's algorithm for topological sort
107        let mut queue: VecDeque<NodeId> = in_degree
108            .iter()
109            .filter(|(_, &degree)| degree == 0)
110            .map(|(&id, _)| id)
111            .collect();
112
113        let mut result = Vec::new();
114
115        while let Some(node_id) = queue.pop_front() {
116            result.push(node_id);
117
118            if let Some(neighbors) = adj_list.get(&node_id) {
119                for &neighbor in neighbors {
120                    if let Some(degree) = in_degree.get_mut(&neighbor) {
121                        *degree -= 1;
122                        if *degree == 0 {
123                            queue.push_back(neighbor);
124                        }
125                    }
126                }
127            }
128        }
129
130        result
131    }
132
133    /// Compute execution levels for parallel execution
134    ///
135    /// Nodes at the same level have no dependencies on each other and can
136    /// execute in parallel.
137    fn compute_execution_levels(nodes: &[Node], execution_order: &[NodeId]) -> Vec<Vec<NodeId>> {
138        let mut levels: Vec<Vec<NodeId>> = Vec::new();
139        let mut node_level: HashMap<NodeId, usize> = HashMap::new();
140
141        for &node_id in execution_order {
142            let node = nodes.iter().find(|n| n.id == node_id).unwrap();
143
144            // Find the maximum level of all dependencies
145            let level = if node.dependencies.is_empty() {
146                0
147            } else {
148                node.dependencies
149                    .iter()
150                    .filter_map(|dep_id| node_level.get(dep_id))
151                    .max()
152                    .map(|&max_level| max_level + 1)
153                    .unwrap_or(0)
154            };
155
156            node_level.insert(node_id, level);
157
158            // Add node to its level
159            while levels.len() <= level {
160                levels.push(Vec::new());
161            }
162            levels[level].push(node_id);
163        }
164
165        levels
166    }
167
    /// Execute the DAG (legacy method returning just the context).
    ///
    /// Thin wrapper over [`Dag::execute_detailed`] that discards the
    /// per-node and per-branch bookkeeping and returns only the global
    /// execution context.
    ///
    /// # Arguments
    /// * `parallel` - If true, execute nodes at the same level concurrently
    /// * `max_threads` - Optional maximum number of threads to use per level (None = unlimited)
    pub fn execute(&self, parallel: bool, max_threads: Option<usize>) -> ExecutionContext {
        self.execute_detailed(parallel, max_threads).context
    }
178    
179    /// Execute the DAG with detailed per-node and per-branch tracking
180    ///
181    /// Runs all nodes in topological order and tracks outputs per node and per branch.
182    /// 
183    /// # Arguments
184    /// * `parallel` - If true, execute nodes at the same level concurrently
185    /// * `max_threads` - Optional maximum number of threads to use per level (None = unlimited)
186    pub fn execute_detailed(&self, parallel: bool, max_threads: Option<usize>) -> ExecutionResult {
187        let mut result = ExecutionResult::new();
188
189        if !parallel {
190            // Sequential execution
191            for &node_id in &self.execution_order {
192                if let Some(node) = self.nodes.iter().find(|n| n.id == node_id) {
193                    let outputs = node.execute(&result.context);
194                    
195                    // Store outputs in global context
196                    result.context.extend(outputs.clone());
197                    
198                    // Store outputs per node (using broadcast variable names from output_mapping)
199                    let node_outputs: HashMap<String, String> = outputs.clone();
200                    result.node_outputs.insert(node_id, node_outputs);
201                    
202                    // Store outputs per branch if this node belongs to a branch
203                    if let Some(branch_id) = node.branch_id {
204                        result.branch_outputs
205                            .entry(branch_id)
206                            .or_insert_with(HashMap::new)
207                            .extend(outputs);
208                    }
209                }
210            }
211        } else {
212            // Parallel execution
213            for level in &self.execution_levels {
214                // Execute nodes at the same level in parallel
215                if level.len() == 1 {
216                    // Single node - no need for threading overhead
217                    let node_id = level[0];
218                    if let Some(node) = self.nodes.iter().find(|n| n.id == node_id) {
219                        let outputs = node.execute(&result.context);
220                        
221                        result.context.extend(outputs.clone());
222                        result.node_outputs.insert(node_id, outputs.clone());
223                        
224                        if let Some(branch_id) = node.branch_id {
225                            result.branch_outputs
226                                .entry(branch_id)
227                                .or_insert_with(HashMap::new)
228                                .extend(outputs);
229                        }
230                    }
231                } else {
232                    // Multiple nodes - execute in parallel using scoped threads
233                    let context = Arc::new(result.context.clone());
234                    let nodes_to_execute: Vec<_> = level.iter()
235                        .filter_map(|&node_id| {
236                            self.nodes.iter().find(|n| n.id == node_id)
237                        })
238                        .collect();
239                    
240                    // Limit threads if max_threads is specified
241                    let chunk_size = if let Some(max) = max_threads {
242                        max.max(1) // At least 1 thread
243                    } else {
244                        nodes_to_execute.len() // Unlimited - one thread per node
245                    };
246                    
247                    let outputs = Arc::new(Mutex::new(Vec::new()));
248                    
249                    // Process nodes in chunks to respect max_threads limit
250                    for chunk in nodes_to_execute.chunks(chunk_size) {
251                        std::thread::scope(|s| {
252                            for node in chunk {
253                                let context = Arc::clone(&context);
254                                let outputs = Arc::clone(&outputs);
255                                
256                                s.spawn(move || {
257                                    let node_outputs = node.execute(&context);
258                                    outputs.lock().unwrap().push((node.id, node.branch_id, node_outputs));
259                                });
260                            }
261                        });
262                    }
263                    
264                    // Collect outputs from all parallel executions
265                    let collected_outputs = outputs.lock().unwrap();
266                    for (node_id, branch_id, node_outputs) in collected_outputs.iter() {
267                        result.context.extend(node_outputs.clone());
268                        result.node_outputs.insert(*node_id, node_outputs.clone());
269                        
270                        if let Some(bid) = branch_id {
271                            result.branch_outputs
272                                .entry(*bid)
273                                .or_insert_with(HashMap::new)
274                                .extend(node_outputs.clone());
275                        }
276                    }
277                }
278            }
279        }
280
281        result
282    }
283
284    /// Generate a Mermaid diagram for visualization with port mappings
285    ///
286    /// Returns a string containing a Mermaid flowchart representing the DAG.
287    /// Edge labels show port mappings (broadcast_var → impl_var).
288    pub fn to_mermaid(&self) -> String {
289        let mut mermaid = String::from("graph TD\n");
290
291        // Add all nodes
292        for node in &self.nodes {
293            let node_label = node.display_name();
294            mermaid.push_str(&format!("    {}[\"{}\"]\n", node.id, node_label));
295        }
296
297        // Add edges with port mapping labels
298        let mut edges_added: HashSet<(NodeId, NodeId)> = HashSet::new();
299        for node in &self.nodes {
300            for &dep_id in &node.dependencies {
301                let edge = (dep_id, node.id);
302                if !edges_added.contains(&edge) {
303                    // Find the dependency node to get its output mappings
304                    let dep_node = self.nodes.iter().find(|n| n.id == dep_id);
305                    
306                    // Build port mapping label
307                    let mut port_labels = Vec::new();
308                    
309                    // Show input mappings for the current node that come from this dependency
310                    for (broadcast_var, impl_var) in &node.input_mapping {
311                        // Check if this broadcast var comes from the dependency
312                        if let Some(dep) = dep_node {
313                            // Check if dependency produces this broadcast var
314                            if dep.output_mapping.values().any(|v| v == broadcast_var) {
315                                port_labels.push(format!("{} → {}", broadcast_var, impl_var));
316                            }
317                        }
318                    }
319                    
320                    // Format edge with port labels
321                    if port_labels.is_empty() {
322                        mermaid.push_str(&format!("    {} --> {}\n", dep_id, node.id));
323                    } else {
324                        let label = port_labels.join("<br/>");
325                        mermaid.push_str(&format!("    {} -->|{}| {}\n", dep_id, label, node.id));
326                    }
327                    
328                    edges_added.insert(edge);
329                }
330            }
331        }
332
333        // Add styling for branches
334        for node in &self.nodes {
335            if node.is_branch {
336                mermaid.push_str(&format!("    style {} fill:#e1f5ff\n", node.id));
337            }
338        }
339
340        // Add styling for variants
341        for node in &self.nodes {
342            if let Some(variant_idx) = node.variant_index {
343                let colors = ["#ffe1e1", "#e1ffe1", "#ffe1ff", "#ffffe1"];
344                let color = colors[variant_idx % colors.len()];
345                mermaid.push_str(&format!("    style {} fill:{}\n", node.id, color));
346            }
347        }
348
349        mermaid
350    }
351
    /// Get the execution order (topologically sorted node ids).
    pub fn execution_order(&self) -> &[NodeId] {
        &self.execution_order
    }
356
    /// Get the execution levels; nodes within a single level can run in parallel.
    pub fn execution_levels(&self) -> &[Vec<NodeId>] {
        &self.execution_levels
    }
361
    /// Get all nodes in the DAG.
    pub fn nodes(&self) -> &[Node] {
        &self.nodes
    }
366
367    /// Get statistics about the DAG
368    pub fn stats(&self) -> DagStats {
369        DagStats {
370            node_count: self.nodes.len(),
371            depth: self.execution_levels.len(),
372            max_parallelism: self
373                .execution_levels
374                .iter()
375                .map(|level| level.len())
376                .max()
377                .unwrap_or(0),
378            branch_count: self.nodes.iter().filter(|n| n.is_branch).count(),
379            variant_count: self
380                .nodes
381                .iter()
382                .filter_map(|n| n.variant_index)
383                .max()
384                .map(|max| max + 1)
385                .unwrap_or(0),
386        }
387    }
388}
389
/// Statistics about a DAG.
///
/// `PartialEq`/`Eq` derived so stats snapshots can be compared directly
/// (e.g. in tests); all fields are plain `usize` counters.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DagStats {
    /// Total number of nodes
    pub node_count: usize,
    /// Maximum depth (longest path from source to sink)
    pub depth: usize,
    /// Maximum number of nodes that can execute in parallel
    pub max_parallelism: usize,
    /// Number of branch nodes
    pub branch_count: usize,
    /// Number of variants
    pub variant_count: usize,
}

impl DagStats {
    /// Format stats as a human-readable string.
    ///
    /// Output is identical to the previous version, but built from a single
    /// explicit literal instead of `\`-line-continuations (which strip the
    /// following indentation and are easy to break when editing).
    pub fn summary(&self) -> String {
        format!(
            "DAG Statistics:\n- Nodes: {}\n- Depth: {} levels\n- Max Parallelism: {} nodes\n- Branches: {}\n- Variants: {}",
            self.node_count, self.depth, self.max_parallelism, self.branch_count, self.variant_count
        )
    }
}