1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
//! Graph builder with implicit connections API
use crate::dag::Dag;
use crate::distribution::DistTransferFn;
use crate::graph_data::GraphData;
use crate::node::{Node, NodeId};
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
/// Graph builder for constructing graphs with implicit node connections
///
/// Nodes are appended with `add` / `variants`, sub-builders are attached with
/// `branch`, joined with `merge`, and the whole builder is consumed by `build`
/// to produce a `Dag`.
pub struct Graph {
/// All nodes in the graph, in creation order
nodes: Vec<Node>,
/// Counter for generating unique node IDs (starts at 0, bumped once per created node)
next_id: NodeId,
/// Current frontier node IDs (active attach points); replaced by the newly
/// created node IDs after every `add`, `variants` and `merge`
frontier: Vec<NodeId>,
/// Track the last branch points for sequential `.branch()` calls (copies of `frontier`).
/// Reset to `None` by `add` and `merge`; set by `branch` and `variants`.
last_branch_point: Option<Vec<NodeId>>,
/// Subgraph builders for branches with their IDs; pushed by `branch`,
/// drained by `merge` and `build`
branches: Vec<(usize, Graph)>,
/// Next branch ID counter (the first ID handed out is 1)
next_branch_id: usize,
/// Track nodes that should be merged together: consumed as explicit
/// dependencies by the next node created in `add` / `variants`.
/// NOTE(review): nothing in this impl block ever pushes into this vector — confirm it is still used.
merge_targets: Vec<NodeId>,
/// Pending dist_transfers to be applied to nodes by label at `build()` time.
/// label -> DistTransferFn
dist_transfers: HashMap<String, DistTransferFn>,
}
impl Graph {
/// Build an empty graph builder.
///
/// The builder starts with no nodes, an empty frontier, no pending branches
/// and no pending distribution transfers. Node IDs start at 0 and branch IDs
/// start at 1.
pub fn new() -> Self {
    Self {
        // ID counters
        next_id: 0,
        next_branch_id: 1,
        // Node storage and attach points
        nodes: Vec::new(),
        frontier: Vec::new(),
        merge_targets: Vec::new(),
        // Branch bookkeeping
        branches: Vec::new(),
        last_branch_point: None,
        // Label-keyed transfers applied at build time
        dist_transfers: HashMap::new(),
    }
}
/// Hand out the next unused branch ID and advance the counter by one.
fn get_branch_id(&mut self) -> usize {
    let following = self.next_branch_id + 1;
    // Store the successor and return the value that was current on entry.
    std::mem::replace(&mut self.next_branch_id, following)
}
/// Append a node to the builder and make it the new frontier.
///
/// # Arguments
///
/// * `function` - The function to execute for this node (wrapped in an `Arc` here)
/// * `label` - Optional label for visualization
/// * `inputs` - Optional list of (broadcast_var, impl_var) tuples for inputs
/// * `outputs` - Optional list of (impl_var, broadcast_var) tuples for outputs
///
/// # Connection Behavior
///
/// - One node is created per current frontier node (a single node when the
///   frontier is empty); all replicas share the same function, label and mappings.
/// - No dependency on the frontier is recorded here. Dependencies are derived
///   from the input/output mappings when `build()` runs.
/// - Any pending `merge_targets` become explicit dependencies of the first
///   node created and are then cleared.
/// - The frontier becomes the set of created nodes and the remembered branch
///   point is reset.
///
/// # Function Signature
///
/// Functions receive a single parameter:
/// - `inputs: &HashMap<String, GraphData>` - Mapped input variables (impl_var names)
///
/// Functions return outputs using impl_var names, which get mapped to broadcast_var names.
///
/// # Example
///
/// ```ignore
/// // Function sees "input_data", context has "data"
/// // Function returns "output_value", gets stored as "result" in context
/// graph.add(
///     process_fn,
///     Some("Process"),
///     Some(vec![("data", "input_data")]),      // (broadcast, impl)
///     Some(vec![("output_value", "result")])   // (impl, broadcast)
/// );
/// ```
pub fn add<F>(
    &mut self,
    function: F,
    label: Option<&str>,
    inputs: Option<Vec<(&str, &str)>>,
    outputs: Option<Vec<(&str, &str)>>,
) -> &mut Self
where
    F: Fn(&HashMap<String, GraphData>) -> HashMap<String, GraphData>
        + Send
        + Sync
        + 'static,
{
    // broadcast_var -> impl_var
    let input_mapping: HashMap<String, String> = inputs
        .into_iter()
        .flatten()
        .map(|(broadcast, impl_var)| (broadcast.to_owned(), impl_var.to_owned()))
        .collect();
    // impl_var -> broadcast_var
    let output_mapping: HashMap<String, String> = outputs
        .into_iter()
        .flatten()
        .map(|(impl_var, broadcast)| (impl_var.to_owned(), broadcast.to_owned()))
        .collect();

    // One shared, reference-counted function for every replica.
    let shared_fn: crate::node::NodeFunction = Arc::new(function);
    let owned_label: Option<String> = label.map(str::to_owned);

    // Replicate once per frontier node; an empty frontier still yields one node.
    let replica_count = self.frontier.len().max(1);
    let mut new_frontier: Vec<NodeId> = Vec::with_capacity(replica_count);
    for _ in 0..replica_count {
        let id = self.next_id;
        self.next_id += 1;

        let mut node = Node::new(
            id,
            Arc::clone(&shared_fn),
            owned_label.clone(),
            input_mapping.clone(),
            output_mapping.clone(),
        );
        // Draining hands every pending merge target to the first replica and
        // leaves nothing for later ones (a no-op when the list is empty).
        node.dependencies.extend(self.merge_targets.drain(..));

        self.nodes.push(node);
        new_frontier.push(id);
    }

    self.frontier = new_frontier;
    // A regular node ends any run of sequential `.branch()` calls.
    self.last_branch_point = None;
    self
}
/// Insert a branching subgraph
///
/// # Implicit Branching Behavior
///
/// - Sequential `.branch()` calls without `.add()` between them implicitly
/// branch from the same node(s): the first call records the current frontier as
/// the branch point and later calls reuse it
/// - After `.variants()`, the recorded branch point is the frontier that existed
/// before the variants were created
/// - This allows creating multiple parallel execution paths easily
/// - The frontier itself is not changed by this method
///
/// # Arguments
///
/// * `subgraph` - A configured Graph representing the branch
///
/// # Returns
///
/// Returns the branch ID for use in merge operations
pub fn branch(&mut self, mut subgraph: Graph) -> usize {
// Assign a branch ID to this subgraph (shared for all replicates)
let branch_id = self.get_branch_id();
// Apply any pending dist_transfers from the subgraph to its own nodes before copying.
// Unlike `build()`, this overwrites a transfer that is already set on a matching node.
for node in &mut subgraph.nodes {
if let Some(label) = &node.label {
if let Some(transfer) = subgraph.dist_transfers.get(label.as_str()) {
node.dist_transfer = Some(Arc::clone(transfer));
}
}
}
// Determine the branch points (could be multiple - frontier / last_branch_point)
let branch_points: Vec<NodeId> = if let Some(bp_vec) = self.last_branch_point.clone() {
// Sequential .branch() calls - use the same branch point(s)
bp_vec
} else if !self.frontier.is_empty() {
// Branch from current frontier nodes and remember them for the next .branch() call
let v = self.frontier.clone();
self.last_branch_point = Some(v.clone());
v
} else {
// No previous node, subgraph starts independently: nothing is copied now,
// the subgraph is only queued and gets integrated by `merge()` / `build()`
self.branches.push((branch_id, subgraph));
return branch_id;
};
// For each branch point, append a cloned copy of the subgraph and attach to the branch point
for bp in branch_points.iter() {
// Map old node ids to new ids (fresh ids are allocated per branch point)
let mut id_map: HashMap<NodeId, NodeId> = HashMap::new();
for node in &subgraph.nodes {
let new_id = self.next_id;
self.next_id += 1;
id_map.insert(node.id, new_id);
}
// Clone nodes with remapped ids and dependencies
for node in &subgraph.nodes {
// Every node id was inserted into id_map by the loop above, so this lookup cannot fail
let new_id = *id_map.get(&node.id).unwrap();
let mut new_node = Node::new(
new_id,
node.function.clone(),
node.label.clone(),
node.input_mapping.clone(),
node.output_mapping.clone(),
);
// Remap dependencies: a dependency internal to the subgraph is mapped to its copy;
// any dependency that is not a subgraph node id is dropped.
// NOTE(review): an earlier comment said such dependencies attach to the branch point,
// but the code below does not do that — confirm which is intended.
for &dep in &node.dependencies {
if let Some(&mapped) = id_map.get(&dep) {
new_node.dependencies.push(mapped);
}
}
// Attach to the branch point every node that had no dependencies at all in the subgraph
if node.dependencies.is_empty() {
new_node.dependencies.push(*bp);
}
new_node.is_branch = true;
new_node.branch_id = Some(branch_id);
// Preserve the dist_transfer from the source node
new_node.dist_transfer = node.dist_transfer.clone();
self.nodes.push(new_node);
}
}
// Store original subgraph as metadata under the branch ID for reference.
// NOTE(review): `merge()` and `build()` pass every stored subgraph to `merge_branch`,
// which pushes its nodes into `self.nodes` again in addition to the copies made above —
// confirm this is intended.
self.branches.push((branch_id, subgraph));
branch_id
}
/// Create one node per supplied closure, each a variant of the same computation.
///
/// Every closure is wrapped in an `Arc` and replicated once per current
/// frontier node (or once, when the frontier is empty). Replicas record the
/// index of their closure in `variant_index` and, when a label is given, are
/// labelled `"<label> (v<index>)"`.
///
/// Dependency handling per created node:
/// - if `merge_targets` is non-empty, the node depends on those targets and
///   the list is cleared (so only the first created node receives them);
/// - otherwise, if there is a frontier node to attach to, the node depends on
///   it and is flagged as a branch node.
///
/// Afterwards the frontier is the set of created nodes, and the previous
/// frontier (if any) is remembered as the branch point for a following
/// `.branch()` call.
///
/// # Arguments
///
/// * `functions` - Vector of node functions (closures) - automatically wrapped in Arc
/// * `label` - Optional label for visualization (default: None)
/// * `inputs` - Optional list of (broadcast_var, impl_var) tuples for inputs
/// * `outputs` - Optional list of (impl_var, broadcast_var) tuples for outputs
///
/// # Example
///
/// ```ignore
/// let factors = vec![2.0, 3.0, 5.0];
/// graph.variants(
///     factors.iter().map(|&factor| {
///         move |inputs: &HashMap<String, GraphData>| {
///             let mut outputs = HashMap::new();
///             if let Some(val) = inputs.get("x").and_then(|d| d.as_float()) {
///                 outputs.insert("scaled".to_string(), GraphData::float(val * factor));
///             }
///             outputs
///         }
///     }).collect(),
///     Some("Scale"),
///     Some(vec![("data", "x")]),
///     Some(vec![("scaled", "result")])
/// );
/// ```
pub fn variants<F>(
    &mut self,
    functions: Vec<F>,
    label: Option<&str>,
    inputs: Option<Vec<(&str, &str)>>,
    outputs: Option<Vec<(&str, &str)>>,
) -> &mut Self
where
    F: Fn(&HashMap<String, GraphData>) -> HashMap<String, GraphData>
        + Send
        + Sync
        + 'static,
{
    // Snapshot the frontier before it is replaced: it supplies both the
    // attach points and the branch point remembered for later `.branch()` calls.
    let branch_origin: Option<Vec<NodeId>> =
        (!self.frontier.is_empty()).then(|| self.frontier.clone());
    let attach_points: Vec<Option<NodeId>> = match &branch_origin {
        Some(ids) => ids.iter().copied().map(Some).collect(),
        None => vec![None],
    };

    // broadcast_var -> impl_var
    let input_mapping: HashMap<String, String> = inputs
        .into_iter()
        .flatten()
        .map(|(broadcast, impl_var)| (broadcast.to_owned(), impl_var.to_owned()))
        .collect();
    // impl_var -> broadcast_var
    let output_mapping: HashMap<String, String> = outputs
        .into_iter()
        .flatten()
        .map(|(impl_var, broadcast)| (impl_var.to_owned(), broadcast.to_owned()))
        .collect();

    let mut new_frontier: Vec<NodeId> = Vec::new();
    for (idx, variant_fn) in functions.into_iter().enumerate() {
        // One shared function object and one label per variant.
        let shared_fn: crate::node::NodeFunction = Arc::new(variant_fn);
        let variant_label: Option<String> = label.map(|s| format!("{} (v{})", s, idx));

        for &attach in &attach_points {
            let id = self.next_id;
            self.next_id += 1;

            let mut node = Node::new(
                id,
                Arc::clone(&shared_fn),
                variant_label.clone(),
                input_mapping.clone(),
                output_mapping.clone(),
            );
            node.variant_index = Some(idx);

            if self.merge_targets.is_empty() {
                if let Some(parent_id) = attach {
                    node.dependencies.push(parent_id);
                    node.is_branch = true;
                }
            } else {
                // Pending merge targets take precedence over the frontier parent.
                node.dependencies.extend(self.merge_targets.drain(..));
            }

            self.nodes.push(node);
            new_frontier.push(id);
        }
    }

    self.frontier = new_frontier;
    self.last_branch_point = branch_origin;
    self
}
/// Merge multiple branches back together with a merge function
///
/// After branching, use `.merge()` to bring parallel paths back to a single point.
/// The merge function receives outputs from all specified branches and combines them.
///
/// Branch nodes write their outputs with branch-prefixed keys in the global context
/// to avoid conflicts, allowing merge nodes to correctly retrieve branch-specific values.
///
/// # Arguments
///
/// * `merge_fn` - Function that combines outputs from all branches
/// * `label` - Optional label for visualization
/// * `inputs` - List of (branch_id, broadcast_var, impl_var) tuples specifying which branch outputs to merge
/// * `outputs` - Optional list of (impl_var, broadcast_var) tuples for outputs
///
/// # Example
///
/// ```ignore
/// graph.add(source_fn, Some("Source"), None, Some(vec![("src_out", "data")]));
///
/// let mut branch_a = Graph::new();
/// branch_a.add(process_a, Some("Process A"), Some(vec![("data", "input")]), Some(vec![("output", "result")]));
///
/// let mut branch_b = Graph::new();
/// branch_b.add(process_b, Some("Process B"), Some(vec![("data", "input")]), Some(vec![("output", "result")]));
///
/// let branch_a_id = graph.branch(branch_a);
/// let branch_b_id = graph.branch(branch_b);
///
/// // Merge function combines results from both branches
/// // Branches can use same output name "result", merge maps them distinctly
/// graph.merge(
/// combine_fn,
/// Some("Combine"),
/// vec![
/// (branch_a_id, "result", "a_result"), // (branch, broadcast, impl)
/// (branch_b_id, "result", "b_result")
/// ],
/// Some(vec![("combined", "final")]) // (impl, broadcast)
/// );
/// ```
pub fn merge<F>(
&mut self,
merge_fn: F,
label: Option<&str>,
inputs: Vec<(usize, &str, &str)>,
outputs: Option<Vec<(&str, &str)>>,
) -> &mut Self
where
F: Fn(&HashMap<String, GraphData>) -> HashMap<String, GraphData>
+ Send
+ Sync
+ 'static,
{
// First, integrate all pending branches into the main graph
let branches = std::mem::take(&mut self.branches);
let mut branch_terminals = Vec::new();
for (_branch_id, branch) in branches {
let terminals = self.merge_branch(branch);
branch_terminals.extend(terminals);
}
// Create the merge node
let id = self.next_id;
self.next_id += 1;
// Build input_mapping with branch-specific resolution
// For merge, we need special handling: (branch_id, broadcast_var) -> impl_var
// This will be handled in execution by looking at branch_id field of dependency nodes
let input_mapping: HashMap<String, String> = inputs
.iter()
.map(|(branch_id, broadcast_var, impl_var)| {
// Store as "branch_id:broadcast_var" -> impl_var for unique identification
(
format!("{}:{}", branch_id, broadcast_var),
impl_var.to_string(),
)
})
.collect();
// Build output_mapping: impl_var -> broadcast_var
let output_mapping: HashMap<String, String> = outputs
.unwrap_or_default()
.iter()
.map(|(impl_var, broadcast)| (impl_var.to_string(), broadcast.to_string()))
.collect();
let mut node = Node::new(
id,
Arc::new(merge_fn),
label.map(|s| s.to_string()),
input_mapping,
output_mapping,
);
// Connect to all branch terminals
node.dependencies.extend(branch_terminals);
self.nodes.push(node);
// Update frontier to the merge node
self.frontier = vec![id];
// Reset branch point
self.last_branch_point = None;
self
}
/// Register an analytical distribution transfer for every node carrying `label`.
///
/// The transfer is only recorded here. It is attached to matching nodes later,
/// when the builder is consumed by `build()` or passed to `branch()` as a
/// subgraph. Registering a second transfer for the same label replaces the first.
///
/// The transfer function receives distributions keyed by **impl_var** names (the same
/// names the node function sees) and returns distributions keyed by **impl_var** output
/// names, or `None` to signal that Monte Carlo fallback should be used for this node.
///
/// This is optional — nodes without a dist_transfer automatically fall back to Monte
/// Carlo sampling through their deterministic function when `Dag::predict()` is called.
pub fn set_dist_transfer_for(&mut self, label: &str, transfer: DistTransferFn) -> &mut Self {
    let key = label.to_owned();
    self.dist_transfers.insert(key, transfer);
    self
}
/// Consume the builder and produce the final `Dag`.
///
/// Steps, in order:
/// 1. every still-pending branch subgraph is integrated into the node list;
/// 2. dependencies are completed from the input/output mappings (data flow);
/// 3. pending label-keyed distribution transfers are attached to matching
///    nodes that do not already carry one;
/// 4. the node list is handed to `Dag::new`.
///
/// Any further inspection or optimisation of the graph presumably happens
/// inside `Dag::new`; it is not visible from this method.
pub fn build(mut self) -> Dag {
    // 1. Fold queued branches into the main node list (terminals are not needed here).
    for (_, branch) in std::mem::take(&mut self.branches) {
        self.merge_branch(branch);
    }

    // 2. Derive dependencies from which node produces which broadcast variable.
    self.resolve_data_dependencies();

    // 3. Attach transfers by label, leaving alone nodes that already have one
    //    (e.g. copied from a subgraph).
    let pending = std::mem::take(&mut self.dist_transfers);
    for node in self
        .nodes
        .iter_mut()
        .filter(|node| node.dist_transfer.is_none())
    {
        let matched = node.label.as_deref().and_then(|name| pending.get(name));
        if let Some(transfer) = matched {
            node.dist_transfer = Some(Arc::clone(transfer));
        }
    }

    // 4. Hand over the finished node list.
    Dag::new(self.nodes)
}
/// Resolve dependencies based on data flow (input/output mappings)
///
/// For each node, determine which other nodes it depends on by finding
/// nodes that produce the broadcast variables it consumes.
fn resolve_data_dependencies(&mut self) {
// Build a map of which nodes produce which broadcast variables
let mut producers: HashMap<String, Vec<NodeId>> = HashMap::new();
for node in &self.nodes {
for broadcast_var in node.output_mapping.values() {
producers.entry(broadcast_var.clone())
.or_default()
.push(node.id);
}
}
// For each node, find its dependencies based on required inputs
for i in 0..self.nodes.len() {
let node = &self.nodes[i];
let required_inputs: Vec<String> = node.input_mapping.keys().cloned().collect();
let node_id = node.id;
let mut dependencies: HashSet<NodeId> = HashSet::new();
// Keep any existing dependencies (from merge_targets or branches)
dependencies.extend(node.dependencies.iter().copied());
// Add dependencies based on data flow
for input_key in &required_inputs {
// Handle merge node special format: "branch_id:broadcast_var"
let broadcast_var = if input_key.contains(':') {
// For merge nodes, extract the broadcast_var part after ':'
input_key.split(':').nth(1).unwrap_or(input_key.as_str())
} else {
input_key.as_str()
};
if let Some(producer_ids) = producers.get(broadcast_var) {
for &producer_id in producer_ids {
// Don't depend on ourselves
if producer_id != node_id {
dependencies.insert(producer_id);
}
}
}
}
// Update the node's dependencies
self.nodes[i].dependencies = dependencies.into_iter().collect();
}
}
/// Merge a branch builder's nodes into this builder
///
/// Every node of `branch` is moved into this builder under a freshly allocated
/// ID, and its dependencies are rewritten to the new IDs. Nested branches still
/// pending inside `branch` are merged recursively.
///
/// A branch builder numbers its nodes in its own ID space (starting at 0), so
/// its IDs routinely coincide with IDs already used by this builder. For that
/// reason a dependency is remapped whenever it names a node of the branch; it
/// is never compared against this builder's existing IDs. The complete
/// old-to-new map is built before any dependency is rewritten, so dependencies
/// on nodes that appear later in the branch are remapped as well. A dependency
/// that names no node of the branch is left unchanged.
///
/// # Returns
///
/// The new IDs of the branch's terminal nodes (nodes no other node of the
/// branch depends on), followed by the terminals of any nested branches.
fn merge_branch(&mut self, branch: Graph) -> Vec<NodeId> {
    // Terminal nodes, still expressed in the branch's own ID space.
    let branch_deps: HashSet<NodeId> = branch
        .nodes
        .iter()
        .flat_map(|n| n.dependencies.iter().copied())
        .collect();
    let terminal_old_ids: Vec<NodeId> = branch
        .nodes
        .iter()
        .map(|n| n.id)
        .filter(|id| !branch_deps.contains(id))
        .collect();

    // Allocate all new IDs up front, in node order.
    let mut id_mapping: HashMap<NodeId, NodeId> = HashMap::with_capacity(branch.nodes.len());
    let mut new_ids: Vec<NodeId> = Vec::with_capacity(branch.nodes.len());
    for node in &branch.nodes {
        let new_id = self.next_id;
        self.next_id += 1;
        id_mapping.insert(node.id, new_id);
        new_ids.push(new_id);
    }

    // Move the nodes over, renumbering them and their branch-internal dependencies.
    for (mut node, new_id) in branch.nodes.into_iter().zip(new_ids) {
        node.id = new_id;
        for dep in &mut node.dependencies {
            if let Some(&mapped) = id_mapping.get(dep) {
                *dep = mapped;
            }
        }
        self.nodes.push(node);
    }

    // Terminals translated to the new ID space.
    let mut terminals: Vec<NodeId> = terminal_old_ids
        .iter()
        .filter_map(|old| id_mapping.get(old).copied())
        .collect();

    // Recursively merge nested branches and collect their terminals as well.
    for (_branch_id, nested_branch) in branch.branches {
        terminals.extend(self.merge_branch(nested_branch));
    }
    terminals
}
}
impl Default for Graph {
fn default() -> Self {
Self::new()
}
}