graph_sp/builder.rs
1//! Graph builder with implicit connections API
2
3use crate::dag::Dag;
4use crate::node::{Node, NodeId};
5use std::collections::{HashMap, HashSet};
6use std::sync::Arc;
7
/// Conversion into the ordered list of string values that a variant sweep iterates over.
///
/// Implemented in this module for plain vectors (`Vec<String>`, `Vec<&str>`,
/// `Vec<f64>`, `Vec<i32>`) and for the range helpers [`Linspace`], [`Logspace`],
/// [`Geomspace`] and [`Generator`].
pub trait IntoVariantValues {
    /// Consume `self`, producing one string per variant in sweep order.
    fn into_variant_values(self) -> Vec<String>;
}
12
13/// Implement for Vec<String> - direct list of values
14impl IntoVariantValues for Vec<String> {
15 fn into_variant_values(self) -> Vec<String> {
16 self
17 }
18}
19
20/// Implement for Vec<&str> - direct list of string slices
21impl IntoVariantValues for Vec<&str> {
22 fn into_variant_values(self) -> Vec<String> {
23 self.into_iter().map(|s| s.to_string()).collect()
24 }
25}
26
27/// Implement for Vec<f64> - list of numeric values
28impl IntoVariantValues for Vec<f64> {
29 fn into_variant_values(self) -> Vec<String> {
30 self.into_iter().map(|v| v.to_string()).collect()
31 }
32}
33
34/// Implement for Vec<i32> - list of integer values
35impl IntoVariantValues for Vec<i32> {
36 fn into_variant_values(self) -> Vec<String> {
37 self.into_iter().map(|v| v.to_string()).collect()
38 }
39}
40
41/// Helper struct for linearly spaced values
42pub struct Linspace {
43 start: f64,
44 end: f64,
45 count: usize,
46}
47
48impl Linspace {
49 pub fn new(start: f64, end: f64, count: usize) -> Self {
50 Self { start, end, count }
51 }
52}
53
54impl IntoVariantValues for Linspace {
55 fn into_variant_values(self) -> Vec<String> {
56 if self.count == 0 {
57 return Vec::new();
58 }
59
60 let step = if self.count > 1 {
61 (self.end - self.start) / (self.count - 1) as f64
62 } else {
63 0.0
64 };
65
66 (0..self.count)
67 .map(|i| {
68 let value = self.start + step * i as f64;
69 value.to_string()
70 })
71 .collect()
72 }
73}
74
75/// Helper struct for logarithmically spaced values
76pub struct Logspace {
77 start: f64,
78 end: f64,
79 count: usize,
80}
81
82impl Logspace {
83 pub fn new(start: f64, end: f64, count: usize) -> Self {
84 Self { start, end, count }
85 }
86}
87
88impl IntoVariantValues for Logspace {
89 fn into_variant_values(self) -> Vec<String> {
90 if self.count == 0 || self.start <= 0.0 || self.end <= 0.0 {
91 return Vec::new();
92 }
93
94 let log_start = self.start.ln();
95 let log_end = self.end.ln();
96 let step = if self.count > 1 {
97 (log_end - log_start) / (self.count - 1) as f64
98 } else {
99 0.0
100 };
101
102 (0..self.count)
103 .map(|i| {
104 let value = (log_start + step * i as f64).exp();
105 value.to_string()
106 })
107 .collect()
108 }
109}
110
111/// Helper struct for geometric progression
112pub struct Geomspace {
113 start: f64,
114 ratio: f64,
115 count: usize,
116}
117
118impl Geomspace {
119 pub fn new(start: f64, ratio: f64, count: usize) -> Self {
120 Self { start, ratio, count }
121 }
122}
123
124impl IntoVariantValues for Geomspace {
125 fn into_variant_values(self) -> Vec<String> {
126 (0..self.count)
127 .map(|i| {
128 let value = self.start * self.ratio.powi(i as i32);
129 value.to_string()
130 })
131 .collect()
132 }
133}
134
135/// Helper struct for custom generator functions
136pub struct Generator<F>
137where
138 F: Fn(usize) -> String,
139{
140 count: usize,
141 generator: F,
142}
143
144impl<F> Generator<F>
145where
146 F: Fn(usize) -> String,
147{
148 pub fn new(count: usize, generator: F) -> Self {
149 Self { count, generator }
150 }
151}
152
153impl<F> IntoVariantValues for Generator<F>
154where
155 F: Fn(usize) -> String,
156{
157 fn into_variant_values(self) -> Vec<String> {
158 (0..self.count).map(|i| (self.generator)(i)).collect()
159 }
160}
161
162/// Graph builder for constructing graphs with implicit node connections
163pub struct Graph {
164 /// All nodes in the graph
165 nodes: Vec<Node>,
166 /// Counter for generating unique node IDs
167 next_id: NodeId,
168 /// The last added node ID (for implicit connections)
169 last_node_id: Option<NodeId>,
170 /// Track the last branch point for sequential .branch() calls
171 last_branch_point: Option<NodeId>,
172 /// Subgraph builders for branches with their IDs
173 branches: Vec<(usize, Graph)>,
174 /// Next branch ID counter
175 next_branch_id: usize,
176 /// Track nodes that should be merged together
177 merge_targets: Vec<NodeId>,
178}
179
180impl Graph {
181 /// Create a new graph
182 pub fn new() -> Self {
183 Self {
184 nodes: Vec::new(),
185 next_id: 0,
186 last_node_id: None,
187 last_branch_point: None,
188 branches: Vec::new(),
189 next_branch_id: 1,
190 merge_targets: Vec::new(),
191 }
192 }
193
194 /// Get a unique branch ID for tracking branches
195 fn get_branch_id(&mut self) -> usize {
196 let id = self.next_branch_id;
197 self.next_branch_id += 1;
198 id
199 }
200
201 /// Add a node to the graph with implicit connections
202 ///
203 /// # Arguments
204 ///
205 /// * `function_handle` - The function to execute for this node
206 /// * `label` - Optional label for visualization
207 /// * `inputs` - Optional list of (broadcast_var, impl_var) tuples for inputs
208 /// * `outputs` - Optional list of (impl_var, broadcast_var) tuples for outputs
209 ///
210 /// # Implicit Connection Behavior
211 ///
212 /// - The first node added has no dependencies
213 /// - Subsequent nodes automatically depend on the previous node
214 /// - This creates a natural sequential flow unless `.branch()` is used
215 ///
216 /// # Function Signature
217 ///
218 /// Functions receive two parameters:
219 /// - `inputs: &HashMap<String, String>` - Mapped input variables (impl_var names)
220 /// - `variant_params: &HashMap<String, String>` - Variant parameter values
221 ///
222 /// Functions return outputs using impl_var names, which get mapped to broadcast_var names.
223 ///
224 /// # Example
225 ///
226 /// ```ignore
227 /// // Function sees "input_data", context has "data"
228 /// // Function returns "output_value", gets stored as "result" in context
229 /// graph.add(
230 /// process_fn,
231 /// Some("Process"),
232 /// Some(vec![("data", "input_data")]), // (broadcast, impl)
233 /// Some(vec![("output_value", "result")]) // (impl, broadcast)
234 /// );
235 /// ```
236 pub fn add<F>(
237 &mut self,
238 function_handle: F,
239 label: Option<&str>,
240 inputs: Option<Vec<(&str, &str)>>,
241 outputs: Option<Vec<(&str, &str)>>,
242 ) -> &mut Self
243 where
244 F: Fn(&std::collections::HashMap<String, String>, &std::collections::HashMap<String, String>) -> std::collections::HashMap<String, String>
245 + Send
246 + Sync
247 + 'static,
248 {
249 let id = self.next_id;
250 self.next_id += 1;
251
252 // Build input_mapping: broadcast_var -> impl_var
253 let input_mapping: HashMap<String, String> = inputs
254 .unwrap_or_default()
255 .iter()
256 .map(|(broadcast, impl_var)| (broadcast.to_string(), impl_var.to_string()))
257 .collect();
258
259 // Build output_mapping: impl_var -> broadcast_var
260 let output_mapping: HashMap<String, String> = outputs
261 .unwrap_or_default()
262 .iter()
263 .map(|(impl_var, broadcast)| (impl_var.to_string(), broadcast.to_string()))
264 .collect();
265
266 let mut node = Node::new(
267 id,
268 Arc::new(function_handle),
269 label.map(|s| s.to_string()),
270 input_mapping,
271 output_mapping,
272 );
273
274 // Implicit connection: connect to the last added node or merge targets
275 if !self.merge_targets.is_empty() {
276 // Connect to all merge targets
277 node.dependencies.extend(self.merge_targets.iter().copied());
278 self.merge_targets.clear();
279 } else if let Some(prev_id) = self.last_node_id {
280 node.dependencies.push(prev_id);
281 }
282
283 self.nodes.push(node);
284 self.last_node_id = Some(id);
285
286 // Reset branch point after adding a regular node
287 self.last_branch_point = None;
288
289 self
290 }
291
292 /// Insert a branching subgraph
293 ///
294 /// # Implicit Branching Behavior
295 ///
296 /// - Sequential `.branch()` calls without `.add()` between them implicitly
297 /// branch from the same node
298 /// - This allows creating multiple parallel execution paths easily
299 ///
300 /// # Arguments
301 ///
302 /// * `subgraph` - A configured Graph representing the branch
303 ///
304 /// # Returns
305 ///
306 /// Returns the branch ID for use in merge operations
307 pub fn branch(&mut self, mut subgraph: Graph) -> usize {
308 // Assign a branch ID to this subgraph
309 let branch_id = self.get_branch_id();
310
311 // Determine the branch point
312 let branch_point = if let Some(bp) = self.last_branch_point {
313 // Sequential .branch() calls - use the same branch point
314 bp
315 } else {
316 // First branch after .add() - branch from last node
317 if let Some(last_id) = self.last_node_id {
318 self.last_branch_point = Some(last_id);
319 last_id
320 } else {
321 // No previous node, subgraph starts independently
322 self.branches.push((branch_id, subgraph));
323 return branch_id;
324 }
325 };
326
327 // Connect the first node of the subgraph to the branch point
328 if let Some(first_node) = subgraph.nodes.first_mut() {
329 if !first_node.dependencies.contains(&branch_point) {
330 first_node.dependencies.push(branch_point);
331 }
332 first_node.is_branch = true;
333 first_node.branch_id = Some(branch_id);
334 }
335
336 // Mark all nodes in this branch with the branch ID
337 for node in &mut subgraph.nodes {
338 node.branch_id = Some(branch_id);
339 }
340
341 // Store subgraph with its branch ID
342 self.branches.push((branch_id, subgraph));
343
344 branch_id
345 }
346
347 /// Create configuration sweep variants using a factory function (sigexec-style)
348 ///
349 /// Takes a factory function and an array of parameter values. The factory is called
350 /// with each parameter value to create a node function for that variant.
351 ///
352 /// # Arguments
353 ///
354 /// * `factory` - Function that takes a parameter value and returns a node function
355 /// * `param_values` - Array of parameter values to sweep over
356 /// * `label` - Optional label for visualization (default: None)
357 /// * `inputs` - Optional list of (broadcast_var, impl_var) tuples for inputs
358 /// * `outputs` - Optional list of (impl_var, broadcast_var) tuples for outputs
359 ///
360 /// # Example
361 ///
362 /// ```ignore
363 /// fn make_scaler(factor: f64) -> impl Fn(&HashMap<String, String>, &HashMap<String, String>) -> HashMap<String, String> {
364 /// move |inputs, _variant_params| {
365 /// let mut outputs = HashMap::new();
366 /// if let Some(val) = inputs.get("x").and_then(|s| s.parse::<f64>().ok()) {
367 /// outputs.insert("scaled_x".to_string(), (val * factor).to_string());
368 /// }
369 /// outputs
370 /// }
371 /// }
372 ///
373 /// graph.variant(
374 /// make_scaler,
375 /// vec![2.0, 3.0, 5.0],
376 /// Some("Scale"),
377 /// Some(vec![("data", "x")]), // (broadcast, impl)
378 /// Some(vec![("scaled_x", "result")]) // (impl, broadcast)
379 /// );
380 /// ```
381 ///
382 /// # Behavior
383 ///
384 /// - Creates one node per parameter value
385 /// - Each node is created by calling factory(param_value)
386 /// - Nodes still receive both regular inputs and variant_params
387 /// - All variants branch from the same point and can execute in parallel
388 pub fn variant<F, P, NF>(
389 &mut self,
390 factory: F,
391 param_values: Vec<P>,
392 label: Option<&str>,
393 inputs: Option<Vec<(&str, &str)>>,
394 outputs: Option<Vec<(&str, &str)>>,
395 ) -> &mut Self
396 where
397 F: Fn(P) -> NF,
398 P: ToString + Clone,
399 NF: Fn(&std::collections::HashMap<String, String>, &std::collections::HashMap<String, String>) -> std::collections::HashMap<String, String>
400 + Send
401 + Sync
402 + 'static,
403 {
404 // Remember the branch point before adding variants
405 let branch_point = self.last_node_id;
406
407 // Create a variant node for each parameter value
408 for (idx, param_value) in param_values.iter().enumerate() {
409 // Create the node function using the factory
410 let node_fn = factory(param_value.clone());
411
412 let id = self.next_id;
413 self.next_id += 1;
414
415 // Build input_mapping: broadcast_var -> impl_var
416 let input_mapping: HashMap<String, String> = inputs
417 .as_ref()
418 .unwrap_or(&vec![])
419 .iter()
420 .map(|(broadcast, impl_var)| (broadcast.to_string(), impl_var.to_string()))
421 .collect();
422
423 // Build output_mapping: impl_var -> broadcast_var
424 let output_mapping: HashMap<String, String> = outputs
425 .as_ref()
426 .unwrap_or(&vec![])
427 .iter()
428 .map(|(impl_var, broadcast)| (impl_var.to_string(), broadcast.to_string()))
429 .collect();
430
431 let mut node = Node::new(
432 id,
433 Arc::new(node_fn),
434 label.map(|s| format!("{} (v{})", s, idx)),
435 input_mapping,
436 output_mapping,
437 );
438
439 // Set variant index and param value
440 node.variant_index = Some(idx);
441 node.variant_params.insert("param_value".to_string(), param_value.to_string());
442
443 // Connect to branch point (all variants branch from same node)
444 if let Some(bp_id) = branch_point {
445 node.dependencies.push(bp_id);
446 node.is_branch = true;
447 }
448
449 self.nodes.push(node);
450 }
451
452 // Don't update last_node_id - variants don't create sequential flow
453 // Set last_branch_point for potential merge
454 self.last_branch_point = branch_point;
455
456 self
457 }
458
459 /// Merge multiple branches back together with a merge function
460 ///
461 /// After branching, use `.merge()` to bring parallel paths back to a single point.
462 /// The merge function receives outputs from all specified branches and combines them.
463 ///
464 /// # Arguments
465 ///
466 /// * `merge_fn` - Function that combines outputs from all branches
467 /// * `label` - Optional label for visualization
468 /// * `inputs` - List of (branch_id, broadcast_var, impl_var) tuples specifying which branch outputs to merge
469 /// * `outputs` - Optional list of (impl_var, broadcast_var) tuples for outputs
470 ///
471 /// # Example
472 ///
473 /// ```ignore
474 /// graph.add(source_fn, Some("Source"), None, Some(vec![("src_out", "data")]));
475 ///
476 /// let mut branch_a = Graph::new();
477 /// branch_a.add(process_a, Some("Process A"), Some(vec![("data", "input")]), Some(vec![("output", "result")]));
478 ///
479 /// let mut branch_b = Graph::new();
480 /// branch_b.add(process_b, Some("Process B"), Some(vec![("data", "input")]), Some(vec![("output", "result")]));
481 ///
482 /// let branch_a_id = graph.branch(branch_a);
483 /// let branch_b_id = graph.branch(branch_b);
484 ///
485 /// // Merge function combines results from both branches
486 /// // Branches can use same output name "result", merge maps them distinctly
487 /// graph.merge(
488 /// combine_fn,
489 /// Some("Combine"),
490 /// vec![
491 /// (branch_a_id, "result", "a_result"), // (branch, broadcast, impl)
492 /// (branch_b_id, "result", "b_result")
493 /// ],
494 /// Some(vec![("combined", "final")]) // (impl, broadcast)
495 /// );
496 /// ```
497 pub fn merge<F>(
498 &mut self,
499 merge_fn: F,
500 label: Option<&str>,
501 inputs: Vec<(usize, &str, &str)>,
502 outputs: Option<Vec<(&str, &str)>>,
503 ) -> &mut Self
504 where
505 F: Fn(&std::collections::HashMap<String, String>, &std::collections::HashMap<String, String>) -> std::collections::HashMap<String, String>
506 + Send
507 + Sync
508 + 'static,
509 {
510 // First, integrate all pending branches into the main graph
511 let branches = std::mem::take(&mut self.branches);
512 let mut branch_terminals = Vec::new();
513
514 for (_branch_id, branch) in branches {
515 if let Some(last_id) = branch.last_node_id {
516 branch_terminals.push(last_id);
517 }
518 self.merge_branch(branch);
519 }
520
521 // Create the merge node
522 let id = self.next_id;
523 self.next_id += 1;
524
525 // Build input_mapping with branch-specific resolution
526 // For merge, we need special handling: (branch_id, broadcast_var) -> impl_var
527 // This will be handled in execution by looking at branch_id field of dependency nodes
528 let input_mapping: HashMap<String, String> = inputs
529 .iter()
530 .map(|(branch_id, broadcast_var, impl_var)| {
531 // Store as "branch_id:broadcast_var" -> impl_var for unique identification
532 (format!("{}:{}", branch_id, broadcast_var), impl_var.to_string())
533 })
534 .collect();
535
536 // Build output_mapping: impl_var -> broadcast_var
537 let output_mapping: HashMap<String, String> = outputs
538 .unwrap_or_default()
539 .iter()
540 .map(|(impl_var, broadcast)| (impl_var.to_string(), broadcast.to_string()))
541 .collect();
542
543 let mut node = Node::new(
544 id,
545 Arc::new(merge_fn),
546 label.map(|s| s.to_string()),
547 input_mapping,
548 output_mapping,
549 );
550
551 // Connect to all branch terminals
552 node.dependencies.extend(branch_terminals);
553
554 self.nodes.push(node);
555 self.last_node_id = Some(id);
556
557 // Reset branch point
558 self.last_branch_point = None;
559
560 self
561 }
562
563 /// Build the final DAG from the graph builder
564 ///
565 /// This performs the implicit inspection phase:
566 /// - Full graph traversal
567 /// - Execution path optimization
568 /// - Data flow connection determination
569 /// - Identification of parallelizable operations
570 pub fn build(mut self) -> Dag {
571 // Merge all branch subgraphs into main node list
572 let branches = std::mem::take(&mut self.branches);
573 for (_branch_id, branch) in branches {
574 self.merge_branch(branch);
575 }
576
577 Dag::new(self.nodes)
578 }
579
580 /// Merge a branch builder's nodes into this builder
581 fn merge_branch(&mut self, branch: Graph) {
582 // Create a mapping from old branch IDs to new IDs
583 let mut id_mapping: HashMap<NodeId, NodeId> = HashMap::new();
584
585 // Get the set of existing node IDs in the main graph (before merging)
586 let existing_ids: HashSet<NodeId> = self.nodes.iter().map(|n| n.id).collect();
587
588 // Renumber all nodes from the branch
589 for mut node in branch.nodes {
590 let old_id = node.id;
591 let new_id = self.next_id;
592 self.next_id += 1;
593
594 id_mapping.insert(old_id, new_id);
595 node.id = new_id;
596
597 // Update dependencies with new IDs
598 // Only remap dependencies that were part of the branch (not from main graph)
599 node.dependencies = node.dependencies
600 .iter()
601 .map(|&dep_id| {
602 if existing_ids.contains(&dep_id) {
603 // This dependency is from the main graph, keep it as-is
604 dep_id
605 } else {
606 // This dependency is from the branch, remap it
607 *id_mapping.get(&dep_id).unwrap_or(&dep_id)
608 }
609 })
610 .collect();
611
612 self.nodes.push(node);
613 }
614
615 // Recursively merge nested branches
616 for (_branch_id, nested_branch) in branch.branches {
617 self.merge_branch(nested_branch);
618 }
619 }
620}
621
622impl Default for Graph {
623 fn default() -> Self {
624 Self::new()
625 }
626}