use std::sync::Arc;
use std::time::Instant;
use crate::error::GraphError;
use crate::event::FlowEvent;
use crate::ids::SpanId;
use crate::node::FlowNode;
use crate::node_context::NodeContext;
use crate::state::{State, StateMerge};
use crate::workflow_state::{MergeStrategy, WorkflowState};
/// A flow node that owns a list of named branch nodes.
///
/// When executed, every branch is run against its own clone of the incoming
/// state and the resulting per-branch states are combined through `M::merge`.
/// `S` is the workflow state type and `M` the merge strategy; they default to
/// [`State`] and [`StateMerge`] respectively.
pub struct ParallelNode<S: WorkflowState = State, M: MergeStrategy<S> = StateMerge> {
/// Optional human-readable label; the node's display name falls back to
/// `"parallel"` when this is `None`.
label: Option<String>,
/// Branches as `(name, node)` pairs, kept in the order they were added.
branches: Vec<(String, Arc<dyn FlowNode<S>>)>,
/// Error strategy selected through the builder (see [`ParallelErrorStrategy`]).
error_strategy: ParallelErrorStrategy,
/// Zero-sized marker that ties the node to its merge strategy type `M`.
_merge_strategy: std::marker::PhantomData<M>,
}
/// Manual `Clone` so that neither `S` nor `M` has to be `Clone` itself: only
/// the label, the `Arc`-held branches and the `Copy` strategy are duplicated.
impl<S: WorkflowState, M: MergeStrategy<S>> Clone for ParallelNode<S, M> {
    fn clone(&self) -> Self {
        let Self {
            label,
            branches,
            error_strategy,
            ..
        } = self;

        Self {
            label: label.clone(),
            // Cloning the vector bumps each branch's `Arc` refcount; the
            // branch nodes themselves are shared, not copied.
            branches: branches.clone(),
            error_strategy: *error_strategy,
            _merge_strategy: std::marker::PhantomData,
        }
    }
}
/// Error-handling strategy that can be configured on a `ParallelNode`.
///
/// NOTE(review): the variant semantics below are inferred from the names;
/// confirm them against the node's `execute` implementation.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum ParallelErrorStrategy {
/// The default strategy. Presumably: stop at the first failing branch.
#[default]
FailFast,
/// Presumably: keep running the remaining branches after a failure.
CollectAll,
}
impl ParallelNode {
pub fn builder() -> ParallelNodeBuilder {
ParallelNodeBuilder::new()
}
}
impl<S: WorkflowState, M: MergeStrategy<S>> ParallelNode<S, M> {
    /// Returns the node with its label replaced by `label`.
    pub fn with_label(mut self, label: impl Into<String>) -> Self {
        let label: String = label.into();
        self.label = Some(label);
        self
    }

    /// Number of branches held by this node.
    pub fn branch_count(&self) -> usize {
        let Self { branches, .. } = self;
        branches.len()
    }

    /// Branch names, in the order the branches were added.
    pub fn branch_names(&self) -> Vec<&str> {
        self.branches_iter().map(|(name, _)| name).collect()
    }

    /// Iterates over `(name, node)` pairs in insertion order.
    pub fn branches_iter(&self) -> impl Iterator<Item = (&str, &Arc<dyn FlowNode<S>>)> {
        self.branches
            .iter()
            .map(|entry| (entry.0.as_str(), &entry.1))
    }

    /// The configured error strategy.
    pub fn error_strategy(&self) -> ParallelErrorStrategy {
        self.error_strategy
    }

    /// The label, if one was set.
    pub fn label(&self) -> Option<&str> {
        self.label.as_deref()
    }

    /// Name used to identify this node in events and error messages: the
    /// label when present, otherwise the literal `"parallel"`.
    fn display_name(&self) -> String {
        self.label().unwrap_or("parallel").to_owned()
    }
}
/// Builder that accumulates the label, branches and error strategy of a
/// `ParallelNode` before it is constructed with `build`.
///
/// `S` is the workflow state type and `M` the merge strategy the built node
/// will use; they default to [`State`] and [`StateMerge`].
pub struct ParallelNodeBuilder<S: WorkflowState = State, M: MergeStrategy<S> = StateMerge> {
/// Label to give the node, if any.
label: Option<String>,
/// Branches added so far as `(name, node)` pairs, in insertion order.
branches: Vec<(String, Arc<dyn FlowNode<S>>)>,
/// Error strategy to hand to the node; starts at the enum's default.
error_strategy: ParallelErrorStrategy,
/// Zero-sized marker carrying the merge strategy type `M`.
_phantom: std::marker::PhantomData<M>,
}
impl<S: WorkflowState, M: MergeStrategy<S>> ParallelNodeBuilder<S, M> {
    /// Creates an empty builder: no label, no branches, default error strategy.
    fn new() -> Self {
        Self {
            label: None,
            branches: vec![],
            error_strategy: Default::default(),
            _phantom: std::marker::PhantomData,
        }
    }

    /// Sets the label of the node being built.
    pub fn label(mut self, label: impl Into<String>) -> Self {
        let label: String = label.into();
        self.label = Some(label);
        self
    }

    /// Appends a named branch; branches keep the order in which they are added.
    pub fn branch(mut self, name: impl Into<String>, node: Arc<dyn FlowNode<S>>) -> Self {
        let entry = (name.into(), node);
        self.branches.push(entry);
        self
    }

    /// Selects the error strategy of the node being built.
    pub fn error_strategy(mut self, strategy: ParallelErrorStrategy) -> Self {
        self.error_strategy = strategy;
        self
    }

    /// Consumes the builder and produces the configured `ParallelNode`.
    ///
    /// # Panics
    ///
    /// Panics if no branch has been added.
    pub fn build(self) -> ParallelNode<S, M> {
        let Self {
            label,
            branches,
            error_strategy,
            _phantom: _,
        } = self;

        assert!(
            !branches.is_empty(),
            "ParallelNode must have at least one branch"
        );

        ParallelNode {
            label,
            branches,
            error_strategy,
            _merge_strategy: std::marker::PhantomData,
        }
    }

    /// Switches the merge strategy type to `NM`, carrying over the label,
    /// branches and error strategy collected so far.
    pub fn merge_strategy<NM>(self) -> ParallelNodeBuilder<S, NM>
    where
        NM: MergeStrategy<S>,
    {
        let Self {
            label,
            branches,
            error_strategy,
            _phantom: _,
        } = self;

        ParallelNodeBuilder {
            label,
            branches,
            error_strategy,
            _phantom: std::marker::PhantomData,
        }
    }
}
/// Newtype wrapper around a [`ParallelNodeBuilder`] with the same type
/// parameters and defaults.
///
/// NOTE(review): nothing in the visible part of this file constructs or
/// consumes this type — confirm whether it is still needed.
pub struct ParallelNodeBuilderWithMerge<S: WorkflowState = State, M: MergeStrategy<S> = StateMerge>(
/// The wrapped builder.
pub ParallelNodeBuilder<S, M>,
);
/// Manual `Debug`: branch nodes are trait objects, so only their names are
/// printed alongside the label and error strategy.
impl<S: WorkflowState, M: MergeStrategy<S>> std::fmt::Debug for ParallelNode<S, M> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let branch_names: Vec<&str> = self
            .branches
            .iter()
            .map(|entry| entry.0.as_str())
            .collect();

        let mut out = f.debug_struct("ParallelNode");
        out.field("label", &self.label);
        out.field("branches", &branch_names);
        out.field("error_strategy", &self.error_strategy);
        out.finish()
    }
}
#[async_trait::async_trait]
impl<S: WorkflowState, M: MergeStrategy<S>> FlowNode<S> for ParallelNode<S, M> {
async fn execute(&self, ctx: &mut NodeContext<'_, S>) -> Result<(), GraphError> {
let start_time = Instant::now();
let span_id = SpanId::new();
let branch_count = self.branches.len();
ctx.emit_flow_event(FlowEvent::ParallelStarted {
node_id: self.display_name(),
branch_count,
span_id,
});
let base_state = ctx.state().clone();
let mut branch_results: Vec<S> = Vec::with_capacity(self.branches.len());
for (name, node) in &self.branches {
let branch_start = Instant::now();
let branch_span = SpanId::new();
let mut branch_state = base_state.clone();
let mut branch_bs = ctx.branch().fork();
let mut branch_ctx = NodeContext::new(&mut branch_state, &mut branch_bs, None);
let result = node.execute(&mut branch_ctx).await.map_err(|e| {
GraphError::Terminal(crate::error::TerminalError::NodeExecutionFailed {
node: format!("{}/{}", self.display_name(), name),
source: e.into(),
})
});
let effects = branch_ctx.consume_effects();
branch_state.apply_batch(effects);
let branch_duration = branch_start.elapsed();
let success = result.is_ok();
ctx.emit_flow_event(FlowEvent::BranchCompleted {
branch_name: name.clone(),
node_id: self.display_name(),
span_id: branch_span,
success,
duration: branch_duration,
});
if !success {
return result;
}
branch_results.push(branch_state);
}
let merged = M::merge(branch_results).map_err(|e| {
GraphError::Terminal(crate::error::TerminalError::StateError(format!(
"parallel merge conflict: {e}",
)))
})?;
*ctx.state_mut() = merged;
ctx.emit_flow_event(FlowEvent::ParallelCompleted {
node_id: self.display_name(),
span_id,
duration: start_time.elapsed(),
});
Ok(())
}
}