use crate::distribution::DistTransferFn;
use crate::graph_data::GraphData;
use std::collections::HashMap;
use std::sync::Arc;
/// Identifier used to refer to a node within a graph.
pub type NodeId = usize;

/// The computation performed by a node.
///
/// It receives a map of named input values and returns a map of named output
/// values. It is stored behind an `Arc` with `Send + Sync` bounds so one
/// function can be shared cheaply between clones and across threads.
pub type NodeFunction =
    Arc<dyn Fn(&HashMap<String, GraphData>) -> HashMap<String, GraphData> + Send + Sync>;
/// A single vertex of the computation graph.
///
/// A node pairs a [`NodeFunction`] with two key mappings. One translates
/// context keys into the function's own input names; the other translates the
/// function's output names back into context keys. The remaining fields record
/// graph placement and variant information.
#[derive(Clone)]
pub struct Node {
    // --- identity ---
    /// Identifier of this node.
    pub id: NodeId,
    /// Optional human-readable label (`None` when no label was given).
    pub label: Option<String>,

    // --- computation ---
    /// The function this node runs.
    pub function: NodeFunction,
    /// Context key -> name under which the value is passed to `function`.
    pub input_mapping: HashMap<String, String>,
    /// Name produced by `function` -> context key it is published under.
    pub output_mapping: HashMap<String, String>,

    // --- graph placement ---
    /// Node ids recorded as dependencies of this node.
    pub dependencies: Vec<NodeId>,
    /// Flag marking this node as a branch node (semantics set by the graph builder).
    pub is_branch: bool,
    /// Optional branch identifier associated with this node.
    pub branch_id: Option<usize>,

    // --- variants ---
    /// Optional index of the variant this node represents.
    pub variant_index: Option<usize>,
    /// Parameter values attached to this variant.
    pub variant_params: HashMap<String, GraphData>,
    /// Optional distribution transfer function associated with this node.
    pub dist_transfer: Option<DistTransferFn>,
}
impl Node {
    /// Creates a node from its id, function, optional label and key mappings.
    ///
    /// Every other field starts empty:
    /// - `branch_id` and `variant_index` are `None`.
    /// - `dependencies` and `variant_params` are empty.
    /// - `is_branch` is `false`.
    /// - `dist_transfer` is `None`.
    pub fn new(
        id: NodeId,
        function: NodeFunction,
        label: Option<String>,
        input_mapping: HashMap<String, String>,
        output_mapping: HashMap<String, String>,
    ) -> Self {
        Self {
            id,
            label,
            function,
            input_mapping,
            output_mapping,
            branch_id: None,
            dependencies: Vec::new(),
            is_branch: false,
            variant_index: None,
            variant_params: HashMap::new(),
            dist_transfer: None,
        }
    }

    /// Runs the node's function against `context` and returns its renamed outputs.
    ///
    /// Inputs are gathered through `input_mapping` (context key -> name seen by
    /// the function):
    /// - A plain key is looked up in `context` verbatim.
    /// - A key with exactly one `:` (`"a:b"`) is looked up as `"__branch_a__b"`.
    /// - A key with two or more `:` is skipped.
    /// - A key absent from `context` is left out of the function's inputs; no
    ///   error is raised.
    ///
    /// Outputs are renamed through `output_mapping` (function output name ->
    /// context key). Unmapped function outputs are dropped. Mapped names the
    /// function did not produce are also dropped.
    pub fn execute(&self, context: &HashMap<String, GraphData>) -> HashMap<String, GraphData> {
        let inputs: HashMap<String, GraphData> = self
            .input_mapping
            .iter()
            .filter_map(|(broadcast_key, impl_var)| {
                // `split_once` avoids allocating a Vec just to count the parts.
                let value = match broadcast_key.split_once(':') {
                    None => context.get(broadcast_key),
                    Some((branch, var)) if !var.contains(':') => {
                        context.get(&format!("__branch_{}__{}", branch, var))
                    }
                    // Two or more separators: not a recognised key form.
                    Some(_) => None,
                }?;
                Some((impl_var.clone(), value.clone()))
            })
            .collect();

        let func_outputs = (self.function)(&inputs);

        self.output_mapping
            .iter()
            .filter_map(|(impl_var, broadcast_var)| {
                func_outputs
                    .get(impl_var)
                    .map(|value| (broadcast_var.clone(), value.clone()))
            })
            .collect()
    }

    /// Returns the node's label, or `"Node <id>"` when it has none.
    pub fn display_name(&self) -> String {
        self.label
            .clone()
            .unwrap_or_else(|| format!("Node {}", self.id))
    }
}