graph_sp/
builder.rs

1//! Graph builder with implicit connections API
2
3use crate::dag::Dag;
4use crate::node::{Node, NodeId};
5use std::collections::{HashMap, HashSet};
6use std::sync::Arc;
7
/// Conversion into the ordered list of stringified values a variant sweep uses.
pub trait IntoVariantValues {
    /// Consumes `self` and returns one `String` per variant value, in order.
    fn into_variant_values(self) -> Vec<String>;
}

/// Owned strings are already in the target representation and pass through untouched.
impl IntoVariantValues for Vec<String> {
    fn into_variant_values(self) -> Vec<String> {
        self
    }
}

/// String slices are copied into owned `String`s.
impl IntoVariantValues for Vec<&str> {
    fn into_variant_values(self) -> Vec<String> {
        self.into_iter().map(str::to_owned).collect()
    }
}

/// Floating-point values are rendered through their `Display` formatting.
impl IntoVariantValues for Vec<f64> {
    fn into_variant_values(self) -> Vec<String> {
        self.iter().map(ToString::to_string).collect()
    }
}

/// Integer values are rendered through their `Display` formatting.
impl IntoVariantValues for Vec<i32> {
    fn into_variant_values(self) -> Vec<String> {
        self.iter().map(ToString::to_string).collect()
    }
}
40
41/// Helper struct for linearly spaced values
42pub struct Linspace {
43    start: f64,
44    end: f64,
45    count: usize,
46}
47
48impl Linspace {
49    pub fn new(start: f64, end: f64, count: usize) -> Self {
50        Self { start, end, count }
51    }
52}
53
54impl IntoVariantValues for Linspace {
55    fn into_variant_values(self) -> Vec<String> {
56        if self.count == 0 {
57            return Vec::new();
58        }
59        
60        let step = if self.count > 1 {
61            (self.end - self.start) / (self.count - 1) as f64
62        } else {
63            0.0
64        };
65        
66        (0..self.count)
67            .map(|i| {
68                let value = self.start + step * i as f64;
69                value.to_string()
70            })
71            .collect()
72    }
73}
74
75/// Helper struct for logarithmically spaced values
76pub struct Logspace {
77    start: f64,
78    end: f64,
79    count: usize,
80}
81
82impl Logspace {
83    pub fn new(start: f64, end: f64, count: usize) -> Self {
84        Self { start, end, count }
85    }
86}
87
88impl IntoVariantValues for Logspace {
89    fn into_variant_values(self) -> Vec<String> {
90        if self.count == 0 || self.start <= 0.0 || self.end <= 0.0 {
91            return Vec::new();
92        }
93        
94        let log_start = self.start.ln();
95        let log_end = self.end.ln();
96        let step = if self.count > 1 {
97            (log_end - log_start) / (self.count - 1) as f64
98        } else {
99            0.0
100        };
101        
102        (0..self.count)
103            .map(|i| {
104                let value = (log_start + step * i as f64).exp();
105                value.to_string()
106            })
107            .collect()
108    }
109}
110
111/// Helper struct for geometric progression
112pub struct Geomspace {
113    start: f64,
114    ratio: f64,
115    count: usize,
116}
117
118impl Geomspace {
119    pub fn new(start: f64, ratio: f64, count: usize) -> Self {
120        Self { start, ratio, count }
121    }
122}
123
124impl IntoVariantValues for Geomspace {
125    fn into_variant_values(self) -> Vec<String> {
126        (0..self.count)
127            .map(|i| {
128                let value = self.start * self.ratio.powi(i as i32);
129                value.to_string()
130            })
131            .collect()
132    }
133}
134
135/// Helper struct for custom generator functions
136pub struct Generator<F>
137where
138    F: Fn(usize) -> String,
139{
140    count: usize,
141    generator: F,
142}
143
144impl<F> Generator<F>
145where
146    F: Fn(usize) -> String,
147{
148    pub fn new(count: usize, generator: F) -> Self {
149        Self { count, generator }
150    }
151}
152
153impl<F> IntoVariantValues for Generator<F>
154where
155    F: Fn(usize) -> String,
156{
157    fn into_variant_values(self) -> Vec<String> {
158        (0..self.count).map(|i| (self.generator)(i)).collect()
159    }
160}
161
162/// Graph builder for constructing graphs with implicit node connections
163pub struct Graph {
164    /// All nodes in the graph
165    nodes: Vec<Node>,
166    /// Counter for generating unique node IDs
167    next_id: NodeId,
168    /// The last added node ID (for implicit connections)
169    last_node_id: Option<NodeId>,
170    /// Track the last branch point for sequential .branch() calls
171    last_branch_point: Option<NodeId>,
172    /// Subgraph builders for branches with their IDs
173    branches: Vec<(usize, Graph)>,
174    /// Next branch ID counter
175    next_branch_id: usize,
176    /// Track nodes that should be merged together
177    merge_targets: Vec<NodeId>,
178}
179
180impl Graph {
181    /// Create a new graph
182    pub fn new() -> Self {
183        Self {
184            nodes: Vec::new(),
185            next_id: 0,
186            last_node_id: None,
187            last_branch_point: None,
188            branches: Vec::new(),
189            next_branch_id: 1,
190            merge_targets: Vec::new(),
191        }
192    }
193
194    /// Get a unique branch ID for tracking branches
195    fn get_branch_id(&mut self) -> usize {
196        let id = self.next_branch_id;
197        self.next_branch_id += 1;
198        id
199    }
200
201    /// Add a node to the graph with implicit connections
202    ///
203    /// # Arguments
204    ///
205    /// * `function_handle` - The function to execute for this node
206    /// * `label` - Optional label for visualization
207    /// * `inputs` - Optional list of (broadcast_var, impl_var) tuples for inputs
208    /// * `outputs` - Optional list of (impl_var, broadcast_var) tuples for outputs
209    ///
210    /// # Implicit Connection Behavior
211    ///
212    /// - The first node added has no dependencies
213    /// - Subsequent nodes automatically depend on the previous node
214    /// - This creates a natural sequential flow unless `.branch()` is used
215    ///
216    /// # Function Signature
217    ///
218    /// Functions receive two parameters:
219    /// - `inputs: &HashMap<String, String>` - Mapped input variables (impl_var names)
220    /// - `variant_params: &HashMap<String, String>` - Variant parameter values
221    ///
222    /// Functions return outputs using impl_var names, which get mapped to broadcast_var names.
223    ///
224    /// # Example
225    ///
226    /// ```ignore
227    /// // Function sees "input_data", context has "data"
228    /// // Function returns "output_value", gets stored as "result" in context
229    /// graph.add(
230    ///     process_fn,
231    ///     Some("Process"),
232    ///     Some(vec![("data", "input_data")]),     // (broadcast, impl)
233    ///     Some(vec![("output_value", "result")])  // (impl, broadcast)
234    /// );
235    /// ```
236    pub fn add<F>(
237        &mut self,
238        function_handle: F,
239        label: Option<&str>,
240        inputs: Option<Vec<(&str, &str)>>,
241        outputs: Option<Vec<(&str, &str)>>,
242    ) -> &mut Self
243    where
244        F: Fn(&std::collections::HashMap<String, String>, &std::collections::HashMap<String, String>) -> std::collections::HashMap<String, String>
245            + Send
246            + Sync
247            + 'static,
248    {
249        let id = self.next_id;
250        self.next_id += 1;
251
252        // Build input_mapping: broadcast_var -> impl_var
253        let input_mapping: HashMap<String, String> = inputs
254            .unwrap_or_default()
255            .iter()
256            .map(|(broadcast, impl_var)| (broadcast.to_string(), impl_var.to_string()))
257            .collect();
258
259        // Build output_mapping: impl_var -> broadcast_var
260        let output_mapping: HashMap<String, String> = outputs
261            .unwrap_or_default()
262            .iter()
263            .map(|(impl_var, broadcast)| (impl_var.to_string(), broadcast.to_string()))
264            .collect();
265
266        let mut node = Node::new(
267            id,
268            Arc::new(function_handle),
269            label.map(|s| s.to_string()),
270            input_mapping,
271            output_mapping,
272        );
273
274        // Implicit connection: connect to the last added node or merge targets
275        if !self.merge_targets.is_empty() {
276            // Connect to all merge targets
277            node.dependencies.extend(self.merge_targets.iter().copied());
278            self.merge_targets.clear();
279        } else if let Some(prev_id) = self.last_node_id {
280            node.dependencies.push(prev_id);
281        }
282
283        self.nodes.push(node);
284        self.last_node_id = Some(id);
285        
286        // Reset branch point after adding a regular node
287        self.last_branch_point = None;
288
289        self
290    }
291
292    /// Insert a branching subgraph
293    ///
294    /// # Implicit Branching Behavior
295    ///
296    /// - Sequential `.branch()` calls without `.add()` between them implicitly
297    ///   branch from the same node
298    /// - This allows creating multiple parallel execution paths easily
299    ///
300    /// # Arguments
301    ///
302    /// * `subgraph` - A configured Graph representing the branch
303    ///
304    /// # Returns
305    ///
306    /// Returns the branch ID for use in merge operations
307    pub fn branch(&mut self, mut subgraph: Graph) -> usize {
308        // Assign a branch ID to this subgraph
309        let branch_id = self.get_branch_id();
310        
311        // Determine the branch point
312        let branch_point = if let Some(bp) = self.last_branch_point {
313            // Sequential .branch() calls - use the same branch point
314            bp
315        } else {
316            // First branch after .add() - branch from last node
317            if let Some(last_id) = self.last_node_id {
318                self.last_branch_point = Some(last_id);
319                last_id
320            } else {
321                // No previous node, subgraph starts independently
322                self.branches.push((branch_id, subgraph));
323                return branch_id;
324            }
325        };
326
327        // Connect the first node of the subgraph to the branch point
328        if let Some(first_node) = subgraph.nodes.first_mut() {
329            if !first_node.dependencies.contains(&branch_point) {
330                first_node.dependencies.push(branch_point);
331            }
332            first_node.is_branch = true;
333            first_node.branch_id = Some(branch_id);
334        }
335        
336        // Mark all nodes in this branch with the branch ID
337        for node in &mut subgraph.nodes {
338            node.branch_id = Some(branch_id);
339        }
340
341        // Store subgraph with its branch ID
342        self.branches.push((branch_id, subgraph));
343
344        branch_id
345    }
346
347    /// Create configuration sweep variants using a factory function (sigexec-style)
348    ///
349    /// Takes a factory function and an array of parameter values. The factory is called
350    /// with each parameter value to create a node function for that variant.
351    ///
352    /// # Arguments
353    ///
354    /// * `factory` - Function that takes a parameter value and returns a node function
355    /// * `param_values` - Array of parameter values to sweep over
356    /// * `label` - Optional label for visualization (default: None)
357    /// * `inputs` - Optional list of (broadcast_var, impl_var) tuples for inputs
358    /// * `outputs` - Optional list of (impl_var, broadcast_var) tuples for outputs
359    ///
360    /// # Example
361    ///
362    /// ```ignore
363    /// fn make_scaler(factor: f64) -> impl Fn(&HashMap<String, String>, &HashMap<String, String>) -> HashMap<String, String> {
364    ///     move |inputs, _variant_params| {
365    ///         let mut outputs = HashMap::new();
366    ///         if let Some(val) = inputs.get("x").and_then(|s| s.parse::<f64>().ok()) {
367    ///             outputs.insert("scaled_x".to_string(), (val * factor).to_string());
368    ///         }
369    ///         outputs
370    ///     }
371    /// }
372    ///
373    /// graph.variant(
374    ///     make_scaler,
375    ///     vec![2.0, 3.0, 5.0],
376    ///     Some("Scale"),
377    ///     Some(vec![("data", "x")]),          // (broadcast, impl)
378    ///     Some(vec![("scaled_x", "result")])  // (impl, broadcast)
379    /// );
380    /// ```
381    ///
382    /// # Behavior
383    ///
384    /// - Creates one node per parameter value
385    /// - Each node is created by calling factory(param_value)
386    /// - Nodes still receive both regular inputs and variant_params
387    /// - All variants branch from the same point and can execute in parallel
388    pub fn variant<F, P, NF>(
389        &mut self,
390        factory: F,
391        param_values: Vec<P>,
392        label: Option<&str>,
393        inputs: Option<Vec<(&str, &str)>>,
394        outputs: Option<Vec<(&str, &str)>>,
395    ) -> &mut Self
396    where
397        F: Fn(P) -> NF,
398        P: ToString + Clone,
399        NF: Fn(&std::collections::HashMap<String, String>, &std::collections::HashMap<String, String>) -> std::collections::HashMap<String, String>
400            + Send
401            + Sync
402            + 'static,
403    {
404        // Remember the branch point before adding variants
405        let branch_point = self.last_node_id;
406        
407        // Create a variant node for each parameter value
408        for (idx, param_value) in param_values.iter().enumerate() {
409            // Create the node function using the factory
410            let node_fn = factory(param_value.clone());
411            
412            let id = self.next_id;
413            self.next_id += 1;
414
415            // Build input_mapping: broadcast_var -> impl_var
416            let input_mapping: HashMap<String, String> = inputs
417                .as_ref()
418                .unwrap_or(&vec![])
419                .iter()
420                .map(|(broadcast, impl_var)| (broadcast.to_string(), impl_var.to_string()))
421                .collect();
422
423            // Build output_mapping: impl_var -> broadcast_var
424            let output_mapping: HashMap<String, String> = outputs
425                .as_ref()
426                .unwrap_or(&vec![])
427                .iter()
428                .map(|(impl_var, broadcast)| (impl_var.to_string(), broadcast.to_string()))
429                .collect();
430
431            let mut node = Node::new(
432                id,
433                Arc::new(node_fn),
434                label.map(|s| format!("{} (v{})", s, idx)),
435                input_mapping,
436                output_mapping,
437            );
438
439            // Set variant index and param value
440            node.variant_index = Some(idx);
441            node.variant_params.insert("param_value".to_string(), param_value.to_string());
442
443            // Connect to branch point (all variants branch from same node)
444            if let Some(bp_id) = branch_point {
445                node.dependencies.push(bp_id);
446                node.is_branch = true;
447            }
448
449            self.nodes.push(node);
450        }
451
452        // Don't update last_node_id - variants don't create sequential flow
453        // Set last_branch_point for potential merge
454        self.last_branch_point = branch_point;
455
456        self
457    }
458
459    /// Merge multiple branches back together with a merge function
460    ///
461    /// After branching, use `.merge()` to bring parallel paths back to a single point.
462    /// The merge function receives outputs from all specified branches and combines them.
463    ///
464    /// # Arguments
465    ///
466    /// * `merge_fn` - Function that combines outputs from all branches
467    /// * `label` - Optional label for visualization
468    /// * `inputs` - List of (branch_id, broadcast_var, impl_var) tuples specifying which branch outputs to merge
469    /// * `outputs` - Optional list of (impl_var, broadcast_var) tuples for outputs
470    ///
471    /// # Example
472    ///
473    /// ```ignore
474    /// graph.add(source_fn, Some("Source"), None, Some(vec![("src_out", "data")]));
475    /// 
476    /// let mut branch_a = Graph::new();
477    /// branch_a.add(process_a, Some("Process A"), Some(vec![("data", "input")]), Some(vec![("output", "result")]));
478    /// 
479    /// let mut branch_b = Graph::new();
480    /// branch_b.add(process_b, Some("Process B"), Some(vec![("data", "input")]), Some(vec![("output", "result")]));
481    /// 
482    /// let branch_a_id = graph.branch(branch_a);
483    /// let branch_b_id = graph.branch(branch_b);
484    /// 
485    /// // Merge function combines results from both branches
486    /// // Branches can use same output name "result", merge maps them distinctly
487    /// graph.merge(
488    ///     combine_fn,
489    ///     Some("Combine"),
490    ///     vec![
491    ///         (branch_a_id, "result", "a_result"),    // (branch, broadcast, impl)
492    ///         (branch_b_id, "result", "b_result")
493    ///     ],
494    ///     Some(vec![("combined", "final")])            // (impl, broadcast)
495    /// );
496    /// ```
497    pub fn merge<F>(
498        &mut self,
499        merge_fn: F,
500        label: Option<&str>,
501        inputs: Vec<(usize, &str, &str)>,
502        outputs: Option<Vec<(&str, &str)>>,
503    ) -> &mut Self
504    where
505        F: Fn(&std::collections::HashMap<String, String>, &std::collections::HashMap<String, String>) -> std::collections::HashMap<String, String>
506            + Send
507            + Sync
508            + 'static,
509    {
510        // First, integrate all pending branches into the main graph
511        let branches = std::mem::take(&mut self.branches);
512        let mut branch_terminals = Vec::new();
513        
514        for (_branch_id, branch) in branches {
515            if let Some(last_id) = branch.last_node_id {
516                branch_terminals.push(last_id);
517            }
518            self.merge_branch(branch);
519        }
520        
521        // Create the merge node
522        let id = self.next_id;
523        self.next_id += 1;
524
525        // Build input_mapping with branch-specific resolution
526        // For merge, we need special handling: (branch_id, broadcast_var) -> impl_var
527        // This will be handled in execution by looking at branch_id field of dependency nodes
528        let input_mapping: HashMap<String, String> = inputs
529            .iter()
530            .map(|(branch_id, broadcast_var, impl_var)| {
531                // Store as "branch_id:broadcast_var" -> impl_var for unique identification
532                (format!("{}:{}", branch_id, broadcast_var), impl_var.to_string())
533            })
534            .collect();
535
536        // Build output_mapping: impl_var -> broadcast_var
537        let output_mapping: HashMap<String, String> = outputs
538            .unwrap_or_default()
539            .iter()
540            .map(|(impl_var, broadcast)| (impl_var.to_string(), broadcast.to_string()))
541            .collect();
542
543        let mut node = Node::new(
544            id,
545            Arc::new(merge_fn),
546            label.map(|s| s.to_string()),
547            input_mapping,
548            output_mapping,
549        );
550
551        // Connect to all branch terminals
552        node.dependencies.extend(branch_terminals);
553
554        self.nodes.push(node);
555        self.last_node_id = Some(id);
556        
557        // Reset branch point
558        self.last_branch_point = None;
559        
560        self
561    }
562
563    /// Build the final DAG from the graph builder
564    ///
565    /// This performs the implicit inspection phase:
566    /// - Full graph traversal
567    /// - Execution path optimization
568    /// - Data flow connection determination
569    /// - Identification of parallelizable operations
570    pub fn build(mut self) -> Dag {
571        // Merge all branch subgraphs into main node list
572        let branches = std::mem::take(&mut self.branches);
573        for (_branch_id, branch) in branches {
574            self.merge_branch(branch);
575        }
576
577        Dag::new(self.nodes)
578    }
579
580    /// Merge a branch builder's nodes into this builder
581    fn merge_branch(&mut self, branch: Graph) {
582        // Create a mapping from old branch IDs to new IDs
583        let mut id_mapping: HashMap<NodeId, NodeId> = HashMap::new();
584        
585        // Get the set of existing node IDs in the main graph (before merging)
586        let existing_ids: HashSet<NodeId> = self.nodes.iter().map(|n| n.id).collect();
587        
588        // Renumber all nodes from the branch
589        for mut node in branch.nodes {
590            let old_id = node.id;
591            let new_id = self.next_id;
592            self.next_id += 1;
593            
594            id_mapping.insert(old_id, new_id);
595            node.id = new_id;
596            
597            // Update dependencies with new IDs
598            // Only remap dependencies that were part of the branch (not from main graph)
599            node.dependencies = node.dependencies
600                .iter()
601                .map(|&dep_id| {
602                    if existing_ids.contains(&dep_id) {
603                        // This dependency is from the main graph, keep it as-is
604                        dep_id
605                    } else {
606                        // This dependency is from the branch, remap it
607                        *id_mapping.get(&dep_id).unwrap_or(&dep_id)
608                    }
609                })
610                .collect();
611            
612            self.nodes.push(node);
613        }
614
615        // Recursively merge nested branches
616        for (_branch_id, nested_branch) in branch.branches {
617            self.merge_branch(nested_branch);
618        }
619    }
620}
621
622impl Default for Graph {
623    fn default() -> Self {
624        Self::new()
625    }
626}