// graph_sp/node.rs

//! Node representation and execution.

use std::collections::HashMap;
use std::fmt;
use std::sync::Arc;

/// Identifier that uniquely names a node.
pub type NodeId = usize;

/// Shared, thread-safe (`Send + Sync`) callable implementing a node's behavior.
///
/// Arguments, in order:
/// 1. the node's input variables — as called by `Node::execute`, these are keyed by the
///    function's own (implementation-side) names after `input_mapping` has been applied;
/// 2. the node's variant parameters (`param_name -> value`).
///
/// Returns the output variables, keyed by implementation-side names.
pub type NodeFunction = Arc<
    dyn Fn(&HashMap<String, String>, &HashMap<String, String>) -> HashMap<String, String>
        + Send
        + Sync,
>;

13/// Represents a node in the graph
14#[derive(Clone)]
15pub struct Node {
16    /// Unique identifier
17    pub id: NodeId,
18    /// Optional label for visualization
19    pub label: Option<String>,
20    /// Function to execute
21    pub function: NodeFunction,
22    /// Input mapping: broadcast_var -> impl_var (what the function sees)
23    pub input_mapping: HashMap<String, String>,
24    /// Output mapping: impl_var -> broadcast_var (where function output goes in context)
25    pub output_mapping: HashMap<String, String>,
26    /// Branch ID for branch-specific variable resolution (None for main graph nodes)
27    pub branch_id: Option<usize>,
28    /// Nodes that this node depends on (connected from)
29    pub dependencies: Vec<NodeId>,
30    /// Whether this node is part of a branch
31    pub is_branch: bool,
32    /// Variant index if this is part of a variant sweep
33    pub variant_index: Option<usize>,
34    /// Variant parameters for this node (param_name -> value)
35    pub variant_params: HashMap<String, String>,
36}
37
38impl Node {
39    /// Create a new node
40    pub fn new(
41        id: NodeId,
42        function: NodeFunction,
43        label: Option<String>,
44        input_mapping: HashMap<String, String>,
45        output_mapping: HashMap<String, String>,
46    ) -> Self {
47        Self {
48            id,
49            label,
50            function,
51            input_mapping,
52            output_mapping,
53            branch_id: None,
54            dependencies: Vec::new(),
55            is_branch: false,
56            variant_index: None,
57            variant_params: HashMap::new(),
58        }
59    }
60
61    /// Execute this node with the given context
62    pub fn execute(&self, context: &HashMap<String, String>) -> HashMap<String, String> {
63        // Map broadcast context vars to impl vars using input_mapping
64        // input_mapping: broadcast_var -> impl_var
65        let inputs: HashMap<String, String> = self
66            .input_mapping
67            .iter()
68            .filter_map(|(broadcast_var, impl_var)| {
69                context.get(broadcast_var).map(|val| (impl_var.clone(), val.clone()))
70            })
71            .collect();
72
73        // Execute function with both inputs and variant parameters
74        let func_outputs = (self.function)(&inputs, &self.variant_params);
75        
76        // Map function outputs to broadcast vars using output_mapping
77        // output_mapping: impl_var -> broadcast_var
78        let mut context_outputs = HashMap::new();
79        for (impl_var, broadcast_var) in &self.output_mapping {
80            if let Some(value) = func_outputs.get(impl_var) {
81                context_outputs.insert(broadcast_var.clone(), value.clone());
82            }
83        }
84        
85        context_outputs
86    }
87
88    /// Get display name for this node
89    pub fn display_name(&self) -> String {
90        self.label
91            .as_ref()
92            .map(|l| l.clone())
93            .unwrap_or_else(|| format!("Node {}", self.id))
94    }
95}