use std::sync::Arc;
use crate::error::GraphError;
use crate::node::{FlowNode, NextStep, NodeOutput};
use crate::state::State;
/// A node that groups several named branches under one optional label.
///
/// Each branch pairs a name with a shared [`FlowNode`] handle. Cloning this
/// struct clones the names and bumps the `Arc` reference counts; the branch
/// nodes themselves are shared, not duplicated.
#[derive(Clone)]
pub struct ParallelNode {
    /// Optional human-readable label; `None` when no label was supplied.
    label: Option<String>,
    /// `(name, node)` pairs, kept in the order they were added.
    branches: Vec<(String, Arc<dyn FlowNode>)>,
    /// Error-handling strategy selected for this node.
    error_strategy: ParallelErrorStrategy,
}
/// Error-handling mode that can be attached to a parallel node.
///
/// The default is [`ParallelErrorStrategy::FailFast`].
///
/// NOTE(review): nothing in this file interprets the variants; the
/// descriptions below are inferred from their names and should be confirmed
/// against whichever component consumes the strategy.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum ParallelErrorStrategy {
    /// Presumably: stop as soon as one branch fails.
    #[default]
    FailFast,
    /// Presumably: let every branch finish and gather all failures.
    CollectAll,
}
impl ParallelNode {
    /// Starts a builder with no label, no branches and the default error
    /// strategy.
    #[must_use]
    pub fn builder() -> ParallelNodeBuilder {
        ParallelNodeBuilder::new()
    }

    /// Returns this node with its label replaced by `label`.
    #[must_use]
    pub fn with_label(mut self, label: impl Into<String>) -> Self {
        self.label = Some(label.into());
        self
    }

    /// Number of branches held by this node.
    pub fn branch_count(&self) -> usize {
        self.branches.len()
    }

    /// Names of all branches, in insertion order.
    pub fn branch_names(&self) -> Vec<&str> {
        self.branches
            .iter()
            .map(|(name, _)| name.as_str())
            .collect()
    }

    /// Iterates over `(name, node)` pairs in insertion order.
    pub fn branches_iter(&self) -> impl Iterator<Item = (&str, &Arc<dyn FlowNode>)> {
        self.branches
            .iter()
            .map(|(name, node)| (name.as_str(), node))
    }

    /// The error strategy this node was configured with.
    pub fn error_strategy(&self) -> ParallelErrorStrategy {
        self.error_strategy
    }

    /// The label, if one was set.
    pub fn label(&self) -> Option<&str> {
        self.label.as_deref()
    }

    /// Runs every branch one after another against the same `state` and
    /// concatenates their deltas in branch order.
    ///
    /// Only the `deltas` of each branch output are kept; the per-branch
    /// `next` and `metadata` values are discarded. The combined output always
    /// carries `NextStep::GoToNext` and no metadata.
    ///
    /// NOTE(review): the configured `error_strategy` is not consulted here.
    /// The first failing branch aborts the run regardless of the strategy.
    ///
    /// # Errors
    ///
    /// Returns `GraphError::Terminal` wrapping
    /// `TerminalError::NodeExecutionFailed` for the first branch that fails.
    /// The reported node name is `"<label-or-parallel>/<branch name>"`.
    /// Branches after the failing one are not executed.
    pub async fn execute_sequential(&self, state: &State) -> Result<NodeOutput, GraphError> {
        let mut all_deltas = Vec::new();
        for (name, branch) in &self.branches {
            let output = branch.execute(state).await.map_err(|e| {
                GraphError::Terminal(crate::error::TerminalError::NodeExecutionFailed {
                    node: format!("{}/{}", self.display_name(), name),
                    source: e.into(),
                })
            })?;
            all_deltas.extend(output.deltas);
        }
        Ok(NodeOutput {
            deltas: all_deltas,
            next: NextStep::GoToNext,
            metadata: None,
        })
    }

    /// Label used in error messages; falls back to `"parallel"` when unset.
    ///
    /// Borrows from `self` so callers that only format the name do not pay
    /// for an intermediate `String`.
    fn display_name(&self) -> &str {
        self.label.as_deref().unwrap_or("parallel")
    }
}
/// Accumulates the pieces of a [`ParallelNode`] before it is built.
///
/// Holds the same three pieces of data as the finished node.
pub struct ParallelNodeBuilder {
    /// Label to give the node; `None` leaves it unlabeled.
    label: Option<String>,
    /// `(name, node)` pairs collected so far, in insertion order.
    branches: Vec<(String, Arc<dyn FlowNode>)>,
    /// Error strategy to hand to the built node.
    error_strategy: ParallelErrorStrategy,
}
impl ParallelNodeBuilder {
    /// Creates an empty builder: no label, no branches, default strategy.
    fn new() -> Self {
        Self {
            label: None,
            branches: Vec::new(),
            error_strategy: ParallelErrorStrategy::default(),
        }
    }

    /// Sets the label of the node being built, replacing any earlier one.
    #[must_use]
    pub fn label(mut self, label: impl Into<String>) -> Self {
        self.label = Some(label.into());
        self
    }

    /// Appends a branch called `name` that runs `node`.
    ///
    /// Branches keep the order in which they were added. Names are not
    /// checked for uniqueness.
    #[must_use]
    pub fn branch(mut self, name: impl Into<String>, node: Arc<dyn FlowNode>) -> Self {
        self.branches.push((name.into(), node));
        self
    }

    /// Selects the error strategy, replacing the current one.
    #[must_use]
    pub fn error_strategy(mut self, strategy: ParallelErrorStrategy) -> Self {
        self.error_strategy = strategy;
        self
    }

    /// Consumes the builder and produces the configured [`ParallelNode`].
    ///
    /// # Panics
    ///
    /// Panics with `"ParallelNode must have at least one branch"` if no
    /// branch was added.
    #[must_use]
    pub fn build(self) -> ParallelNode {
        assert!(
            !self.branches.is_empty(),
            "ParallelNode must have at least one branch"
        );
        let Self {
            label,
            branches,
            error_strategy,
        } = self;
        ParallelNode {
            label,
            branches,
            error_strategy,
        }
    }
}
/// Manual `Debug` implementation that prints branch names only, not the
/// branch nodes themselves.
impl std::fmt::Debug for ParallelNode {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Only the names are shown for the `branches` field.
        let shown_branches: Vec<&str> = self
            .branches
            .iter()
            .map(|(branch_name, _)| branch_name.as_str())
            .collect();

        let mut repr = f.debug_struct("ParallelNode");
        repr.field("label", &self.label);
        repr.field("branches", &shown_branches);
        repr.field("error_strategy", &self.error_strategy);
        repr.finish()
    }
}