use std::sync::Arc;
use std::time::Instant;
use crate::error::GraphError;
use crate::event::FlowEvent;
use crate::execution_engine::{ExecutionEngine, ExecutorState};
use crate::ids::SpanId;
use crate::node::{ExecutorOperation, FlowNode};
use crate::state::{State, StateMerge};
use crate::workflow_state::{MergeStrategy, WorkflowState};
/// A flow node that runs several named branches concurrently and merges the
/// resulting branch states with the merge strategy `M`.
///
/// Usually constructed through [`ParallelNodeBuilder`]. Branch nodes are held
/// behind `Arc`, so cloning a `ParallelNode` shares the branch nodes instead of
/// copying them.
pub struct ParallelNode<S: WorkflowState = State, M: MergeStrategy<S> = StateMerge> {
    /// Optional human-readable name; `"parallel"` is used in events and errors
    /// when it is absent.
    label: Option<String>,
    /// Branches in the order they were added, each stored as `(name, node)`.
    branches: Vec<(String, Arc<dyn FlowNode<S>>)>,
    /// How branch failures are handled.
    error_strategy: ParallelErrorStrategy,
    /// Type-level marker for the merge strategy; no value of `M` is stored.
    _merge_strategy: std::marker::PhantomData<M>,
}
// Written by hand because `#[derive(Clone)]` would add `S: Clone` and
// `M: Clone` bounds, even though neither type parameter is stored by value.
impl<S: WorkflowState, M: MergeStrategy<S>> Clone for ParallelNode<S, M> {
    /// Copies the label and error strategy; branch nodes are shared through
    /// their `Arc`s rather than deep-copied.
    fn clone(&self) -> Self {
        let Self {
            label,
            branches,
            error_strategy,
            ..
        } = self;
        Self {
            label: label.clone(),
            branches: branches.clone(),
            error_strategy: *error_strategy,
            _merge_strategy: std::marker::PhantomData,
        }
    }
}
/// How a [`ParallelNode`] reacts when one or more of its branches fail.
///
/// With either strategy, all branches are awaited before their outcomes are
/// inspected.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum ParallelErrorStrategy {
    /// Any branch failure fails the whole parallel node; the returned error
    /// names the failing branch that was declared first. This is the default.
    #[default]
    FailFast,
    /// Branch failures are logged at `warn` level when at least one sibling
    /// branch succeeded.
    CollectAll,
}
impl ParallelNode {
pub fn builder() -> ParallelNodeBuilder {
ParallelNodeBuilder::new()
}
}
impl<S: WorkflowState, M: MergeStrategy<S>> ParallelNode<S, M> {
    /// Consumes the node and returns it with `label` as its label, replacing
    /// any label it already had.
    pub fn with_label(self, label: impl Into<String>) -> Self {
        Self {
            label: Some(label.into()),
            ..self
        }
    }

    /// Number of branches this node runs.
    pub fn branch_count(&self) -> usize {
        self.branches.len()
    }

    /// Branch names, in the order the branches were added.
    pub fn branch_names(&self) -> Vec<&str> {
        self.branches_iter().map(|(name, _)| name).collect()
    }

    /// Iterates over `(name, node)` pairs in the order the branches were added.
    pub fn branches_iter(&self) -> impl Iterator<Item = (&str, &Arc<dyn FlowNode<S>>)> {
        self.branches
            .iter()
            .map(|entry| (entry.0.as_str(), &entry.1))
    }

    /// The strategy applied when branches fail.
    pub fn error_strategy(&self) -> ParallelErrorStrategy {
        self.error_strategy
    }

    /// The configured label, if one was set.
    pub fn label(&self) -> Option<&str> {
        self.label.as_deref()
    }

    /// Name used for this node in flow events and error messages: the label,
    /// or `"parallel"` when no label is set.
    fn display_name(&self) -> String {
        self.label().unwrap_or("parallel").to_owned()
    }
}
/// Builder for [`ParallelNode`].
///
/// Collects an optional label, the branches (kept in the order they are
/// added) and the error strategy; `build` produces the node. The merge
/// strategy type `M` can be swapped with `merge_strategy`.
pub struct ParallelNodeBuilder<S: WorkflowState = State, M: MergeStrategy<S> = StateMerge> {
    /// Label to give the built node, if any.
    label: Option<String>,
    /// Branches collected so far, as `(name, node)` pairs.
    branches: Vec<(String, Arc<dyn FlowNode<S>>)>,
    /// Failure-handling strategy for the built node.
    error_strategy: ParallelErrorStrategy,
    /// Type-level marker for the merge strategy; no value of `M` is stored.
    _phantom: std::marker::PhantomData<M>,
}
impl<S: WorkflowState, M: MergeStrategy<S>> ParallelNodeBuilder<S, M> {
    /// Creates an empty builder: no label, no branches, and the default
    /// [`ParallelErrorStrategy`].
    ///
    /// This is public so that a `ParallelNode` can be built for any state type
    /// `S` and merge strategy `M`. `ParallelNode::builder()` only covers the
    /// default type parameters.
    pub fn new() -> Self {
        Self {
            label: None,
            branches: Vec::new(),
            error_strategy: ParallelErrorStrategy::default(),
            _phantom: std::marker::PhantomData,
        }
    }

    /// Sets the label the built node uses in events and error messages.
    pub fn label(mut self, label: impl Into<String>) -> Self {
        self.label = Some(label.into());
        self
    }

    /// Appends a named branch. Branches are kept in the order they are added.
    pub fn branch(mut self, name: impl Into<String>, node: Arc<dyn FlowNode<S>>) -> Self {
        self.branches.push((name.into(), node));
        self
    }

    /// Sets how the built node handles branch failures.
    pub fn error_strategy(mut self, strategy: ParallelErrorStrategy) -> Self {
        self.error_strategy = strategy;
        self
    }

    /// Finishes the builder and returns the configured [`ParallelNode`].
    ///
    /// # Panics
    ///
    /// Panics if no branch was added.
    pub fn build(self) -> ParallelNode<S, M> {
        if self.branches.is_empty() {
            panic!("ParallelNode must have at least one branch");
        }
        ParallelNode {
            label: self.label,
            branches: self.branches,
            error_strategy: self.error_strategy,
            _merge_strategy: std::marker::PhantomData,
        }
    }

    /// Switches the merge strategy type to `NM`, keeping the label, branches
    /// and error strategy configured so far.
    pub fn merge_strategy<NM>(self) -> ParallelNodeBuilder<S, NM>
    where
        NM: MergeStrategy<S>,
    {
        ParallelNodeBuilder {
            label: self.label,
            branches: self.branches,
            error_strategy: self.error_strategy,
            _phantom: std::marker::PhantomData,
        }
    }
}

impl<S: WorkflowState, M: MergeStrategy<S>> Default for ParallelNodeBuilder<S, M> {
    /// Equivalent to [`ParallelNodeBuilder::new`].
    fn default() -> Self {
        Self::new()
    }
}
impl<S: WorkflowState, M: MergeStrategy<S>> std::fmt::Debug for ParallelNode<S, M> {
    /// Shows the label, the branch names (not the branch nodes themselves)
    /// and the error strategy.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let branch_names: Vec<&str> = self
            .branches
            .iter()
            .map(|(branch_name, _)| branch_name.as_str())
            .collect();
        f.debug_struct("ParallelNode")
            .field("label", &self.label)
            .field("branches", &branch_names)
            .field("error_strategy", &self.error_strategy)
            .finish()
    }
}
#[async_trait::async_trait]
impl<S: WorkflowState + Clone + Send + Sync, M: MergeStrategy<S>> ExecutorOperation<S>
    for ParallelNode<S, M>
{
    /// Runs every branch concurrently on its own copy of the current state,
    /// then merges the surviving branch states back into `engine` with
    /// `M::merge`.
    ///
    /// Each branch gets a fresh child `ExecutionEngine` built from a clone of
    /// the parent state, a child of the parent cancellation token and a clone
    /// of the parent stream sink. All branches are awaited (`join_all`) before
    /// any outcome is inspected, and a `BranchCompleted` event is emitted for
    /// every branch, with `success` reflecting whether it failed.
    ///
    /// # Errors
    ///
    /// * `FailFast`: any branch failure returns
    ///   `TerminalError::NodeExecutionFailed` for the first failing branch in
    ///   declaration order, carrying that branch's own error message.
    /// * `CollectAll`: failures are logged at `warn` level and the successful
    ///   branches are merged; the same error as above is returned only when
    ///   every branch failed.
    /// * `TerminalError::StateError` if `M::merge` rejects the branch states.
    async fn execute(&self, engine: &mut ExecutionEngine<S>) -> Result<(), GraphError> {
        let start_time = Instant::now();
        let span_id = SpanId::new();
        let branch_count = self.branches.len();
        let display_name = self.display_name();
        engine.emit_flow_event(FlowEvent::ParallelStarted {
            node_id: display_name.clone(),
            branch_count,
            span_id,
        });

        let base_state = engine.clone_state();
        let parent_cancel = engine.cancel_token().clone();
        let parent_stream = engine.stream_sink();

        // Each future owns everything it touches (branch name, node handle,
        // state copy, token, sink), so none of them borrows the parent engine.
        let branch_futures: Vec<_> = self
            .branches
            .iter()
            .map(|(name, node)| {
                let branch_name = name.clone();
                let node = Arc::clone(node);
                let state = base_state.clone();
                let child_cancel = parent_cancel.child_token();
                let child_stream = parent_stream.clone();
                async move {
                    let branch_start = Instant::now();
                    let mut child_engine = ExecutionEngine::new(state, child_stream, child_cancel);
                    let mut branch_ctx = child_engine.build_node_context();
                    // Convert the error into an owned message right away: this
                    // ends any borrow of `branch_ctx` held by the result and keeps
                    // the real failure cause for the error we report.
                    let failure = node
                        .execute(&mut branch_ctx)
                        .await
                        .err()
                        .map(|e| e.to_string());
                    drop(branch_ctx);
                    let outcome = match failure {
                        Some(reason) => Err(reason),
                        None => {
                            // Only successful branches commit their changes.
                            child_engine.commit();
                            Ok(child_engine.into_state())
                        }
                    };
                    (branch_name, branch_start.elapsed(), outcome)
                }
            })
            .collect();

        let raw_results: Vec<(String, std::time::Duration, Result<S, String>)> =
            futures::future::join_all(branch_futures).await;

        let mut branch_states: Vec<S> = Vec::with_capacity(branch_count);
        let mut errors: Vec<(String, String)> = Vec::new();
        for (branch_name, branch_duration, outcome) in raw_results {
            // Report every branch, failed ones included, so observers can see
            // which branch failed and how long it ran.
            engine.emit_flow_event(FlowEvent::BranchCompleted {
                branch_name: branch_name.clone(),
                node_id: display_name.clone(),
                span_id: SpanId::new(),
                success: outcome.is_ok(),
                duration: branch_duration,
            });
            match outcome {
                Ok(state) => branch_states.push(state),
                Err(reason) => errors.push((branch_name, reason)),
            }
        }

        if let Some((failed_name, failed_reason)) = errors.first() {
            // CollectAll tolerates partial failure as long as at least one
            // branch produced a state to merge; FailFast never does.
            let tolerate_partial = match self.error_strategy {
                ParallelErrorStrategy::FailFast => false,
                ParallelErrorStrategy::CollectAll => !branch_states.is_empty(),
            };
            if !tolerate_partial {
                return Err(GraphError::Terminal(
                    crate::error::TerminalError::NodeExecutionFailed {
                        node: format!("{display_name}/{failed_name}"),
                        source: failed_reason.clone().into(),
                    },
                ));
            }
            for (name, reason) in &errors {
                tracing::warn!(
                    parallel = %display_name,
                    branch = %name,
                    error = %reason,
                    "branch failed (CollectAll strategy)"
                );
            }
        }

        let merged = M::merge(branch_states).map_err(|e| {
            GraphError::Terminal(crate::error::TerminalError::StateError(format!(
                "parallel merge conflict: {e}",
            )))
        })?;
        engine.replace_state(merged);
        engine.emit_flow_event(FlowEvent::ParallelCompleted {
            node_id: display_name,
            span_id,
            duration: start_time.elapsed(),
        });
        Ok(())
    }
}