lellm_graph/
parallel_node.rs1use std::sync::Arc;
23
24use crate::error::GraphError;
25use crate::node::{FlowNode, NextStep, NodeOutput};
26use crate::state::State;
27
28#[derive(Clone)]
44pub struct ParallelNode {
45 label: Option<String>,
47 branches: Vec<(String, Arc<dyn FlowNode>)>,
49 error_strategy: ParallelErrorStrategy,
51}
52
/// How a [`ParallelNode`] is configured to react when one of its branches
/// fails.
///
/// The value is stored on the node and returned by
/// `ParallelNode::error_strategy()`.
///
/// NOTE(review): `ParallelNode::execute_sequential` does not read this setting
/// and always stops at the first failing branch. Confirm where `CollectAll` is
/// actually honoured (presumably a concurrent executor elsewhere).
///
/// `Hash` is derived alongside `Eq` so the strategy can be used as a map/set
/// key, like the other small `Copy` configuration values it sits next to.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum ParallelErrorStrategy {
    /// Stop at the first branch failure. This is the default.
    #[default]
    FailFast,
    /// Presumably: keep running the remaining branches when one fails.
    /// TODO confirm the exact semantics against the executor that consumes it.
    CollectAll,
}
62
63impl ParallelNode {
64 pub fn builder() -> ParallelNodeBuilder {
66 ParallelNodeBuilder::new()
67 }
68
69 pub fn with_label(mut self, label: impl Into<String>) -> Self {
71 self.label = Some(label.into());
72 self
73 }
74
75 pub fn branch_count(&self) -> usize {
77 self.branches.len()
78 }
79
80 pub fn branch_names(&self) -> Vec<&str> {
82 self.branches
83 .iter()
84 .map(|(name, _)| name.as_str())
85 .collect()
86 }
87
88 pub fn branches_iter(&self) -> impl Iterator<Item = (&str, &Arc<dyn FlowNode>)> {
90 self.branches
91 .iter()
92 .map(|(name, node)| (name.as_str(), node))
93 }
94
95 pub fn error_strategy(&self) -> ParallelErrorStrategy {
97 self.error_strategy
98 }
99
100 pub fn label(&self) -> Option<&str> {
102 self.label.as_deref()
103 }
104
105 pub async fn execute_sequential(&self, state: &State) -> Result<NodeOutput, GraphError> {
110 let mut all_deltas = Vec::new();
111
112 for (name, branch) in &self.branches {
113 let output = branch.execute(state).await.map_err(|e| {
114 GraphError::Terminal(crate::error::TerminalError::NodeExecutionFailed {
115 node: format!("{}/{}", self.display_name(), name),
116 source: e.into(),
117 })
118 })?;
119 all_deltas.extend(output.deltas);
120 }
121
122 Ok(NodeOutput {
123 deltas: all_deltas,
124 next: NextStep::GoToNext,
125 metadata: None,
126 })
127 }
128
129 fn display_name(&self) -> String {
130 self.label.clone().unwrap_or_else(|| "parallel".to_string())
131 }
132}
133
134pub struct ParallelNodeBuilder {
136 label: Option<String>,
137 branches: Vec<(String, Arc<dyn FlowNode>)>,
138 error_strategy: ParallelErrorStrategy,
139}
140
141impl ParallelNodeBuilder {
142 fn new() -> Self {
143 Self {
144 label: None,
145 branches: Vec::new(),
146 error_strategy: ParallelErrorStrategy::default(),
147 }
148 }
149
150 pub fn label(mut self, label: impl Into<String>) -> Self {
152 self.label = Some(label.into());
153 self
154 }
155
156 pub fn branch(mut self, name: impl Into<String>, node: Arc<dyn FlowNode>) -> Self {
161 self.branches.push((name.into(), node));
162 self
163 }
164
165 pub fn error_strategy(mut self, strategy: ParallelErrorStrategy) -> Self {
167 self.error_strategy = strategy;
168 self
169 }
170
171 pub fn build(self) -> ParallelNode {
177 if self.branches.is_empty() {
178 panic!("ParallelNode must have at least one branch");
179 }
180 ParallelNode {
181 label: self.label,
182 branches: self.branches,
183 error_strategy: self.error_strategy,
184 }
185 }
186}
187
188impl std::fmt::Debug for ParallelNode {
189 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
190 f.debug_struct("ParallelNode")
191 .field("label", &self.label)
192 .field(
193 "branches",
194 &self
195 .branches
196 .iter()
197 .map(|(n, _)| n.as_str())
198 .collect::<Vec<_>>(),
199 )
200 .field("error_strategy", &self.error_strategy)
201 .finish()
202 }
203}