1use std::fmt::Debug;
2
3use super::Builder;
4use super::chain_run::ChainRunParallel as ChainRun;
5use crate::{
6 context::{Fork, Join},
7 describe::{Description, DescriptionBase, Edge, Type, remove_generics_from_name},
8 flows::{
9 NodeResult, chain_debug::ChainDebug, chain_describe::ChainDescribe, parallel_flow::Joiner,
10 },
11 node::{Node, NodeOutput as NodeOutputStruct},
12};
13
14pub struct ParallelFlow<
84 Input,
85 Output,
86 Error,
87 Context,
88 ChainOutput = (),
89 Joiner = (),
90 NodeTypes = (),
91 NodeIOETypes = (),
92> {
93 #[expect(clippy::type_complexity)]
94 pub(super) _ioec: std::marker::PhantomData<fn() -> (Input, Output, Error, Context)>,
95 pub(super) _nodes_io: std::marker::PhantomData<fn() -> NodeIOETypes>,
96 pub(super) nodes: NodeTypes,
97 pub(super) _joiner_input: std::marker::PhantomData<fn() -> ChainOutput>,
98 pub(super) joiner: Joiner,
99}
100
101impl<Input, Output, Error, Context> ParallelFlow<Input, Output, Error, Context>
102where
103 Input: Send + Clone,
105 Error: Send,
106 Context: Fork + Join + Send,
107{
108 #[must_use]
124 pub fn builder() -> Builder<Input, Output, Error, Context> {
125 Builder::new()
126 }
127}
128
129impl<Input, Output, Error, Context, ChainRunOutput, J, NodeTypes, NodeIOETypes> Clone
130 for ParallelFlow<Input, Output, Error, Context, ChainRunOutput, J, NodeTypes, NodeIOETypes>
131where
132 J: Clone,
133 NodeTypes: Clone,
134{
135 fn clone(&self) -> Self {
136 Self {
137 _ioec: std::marker::PhantomData,
138 _nodes_io: std::marker::PhantomData,
139 nodes: self.nodes.clone(),
140 _joiner_input: std::marker::PhantomData,
141 joiner: self.joiner.clone(),
142 }
143 }
144}
145
146impl<Input, Output, Error, Context, ChainRunOutput, J, NodeTypes, NodeIOETypes> Debug
147 for ParallelFlow<Input, Output, Error, Context, ChainRunOutput, J, NodeTypes, NodeIOETypes>
148where
149 NodeTypes: ChainDebug,
150{
151 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
152 f.debug_struct("ParallelFlow")
153 .field("nodes", &self.nodes.as_list())
154 .finish_non_exhaustive()
155 }
156}
157
158#[inline(always)]
160#[expect(clippy::inline_always)]
161fn call_joiner<'a, J, I, O, E, Ctx>(
162 j: &J,
163 i: I,
164 s: &'a mut Ctx,
165) -> impl Future<Output = NodeResult<O, E>>
166where
167 J: Joiner<'a, I, O, E, Ctx> + 'a,
168{
169 j.join(i, s)
170}
171
172impl<Input, Output, Error, Context, ChainRunOutput, J, NodeTypes, NodeIOETypes>
173 Node<Input, NodeOutputStruct<Output>, Error, Context>
174 for ParallelFlow<Input, Output, Error, Context, ChainRunOutput, J, NodeTypes, NodeIOETypes>
175where
176 Input: Send,
177 Context: Send,
178 for<'a> J: Joiner<'a, ChainRunOutput, Output, Error, Context>,
179 NodeTypes: ChainRun<Input, Result<ChainRunOutput, Error>, Context, NodeIOETypes>
180 + ChainDescribe<Context, NodeIOETypes>
181 + Send
182 + Sync,
183{
184 fn run(
185 &mut self,
186 input: Input,
187 context: &mut Context,
188 ) -> impl Future<Output = NodeResult<Output, Error>> + Send {
189 let nodes = &self.nodes;
190 let joiner = &self.joiner;
191 async move {
192 let fut = nodes.run(input, context);
193 let res = fut.await?;
194 call_joiner::<J, ChainRunOutput, Output, Error, Context>(joiner, res, context).await
196 }
197 }
198
199 fn describe(&self) -> Description {
200 let node_count = <NodeTypes as ChainDescribe<Context, NodeIOETypes>>::COUNT;
201 let mut node_descriptions = Vec::with_capacity(node_count + 1);
202 self.nodes.describe(&mut node_descriptions);
203
204 node_descriptions.push(Description::Node {
205 base: DescriptionBase {
206 r#type: Type {
207 name: "Joiner".to_owned(),
208 },
209 input: Type {
210 name: String::new(),
211 },
212 output: Type {
213 name: String::new(),
214 },
215 error: Type {
216 name: String::new(),
217 },
218 context: Type {
219 name: String::new(),
220 },
221 description: None,
222 externals: None,
223 },
224 });
225
226 let mut edges = Vec::with_capacity(node_count * 2 + 1);
227 for i in 0..node_count {
228 edges.push(Edge::flow_to_node(i));
229 edges.push(Edge::node_to_node(i, node_count));
230 }
231 edges.push(Edge::node_to_flow(node_count));
232
233 Description::new_flow(self, node_descriptions, edges).modify_name(remove_generics_from_name)
234 }
235}
236
#[cfg(test)]
mod test {
    use super::{ChainRun, ParallelFlow as Flow};
    use crate::{
        context::storage::local_storage::{LocalStorage, LocalStorageImpl, tests::MyVal},
        flows::tests::{InsertIntoStorageAssertWasNotInStorage, Passer, SoftFailNode},
        node::{Node, NodeOutput},
    };

    /// A flow of three nodes hands the joiner a left-nested tuple of the
    /// individual node outputs, and the joiner's result is the flow's result.
    #[tokio::test]
    async fn test_flow() {
        let mut storage = LocalStorageImpl::new();
        let mut flow = Flow::<u8, u64, (), _>::builder()
            .add_node(Passer::<u16, u64, ()>::new())
            .add_node(SoftFailNode::<u16, u32, ()>::new())
            .add_node(Passer::<u16, u32, ()>::new())
            .build(async |input, context: &mut LocalStorageImpl| {
                context.insert(MyVal::default());
                // Outputs arrive in the order the nodes were added.
                assert_eq!(
                    input,
                    (
                        ((NodeOutput::Ok(0u64),), NodeOutput::SoftFail),
                        NodeOutput::Ok(0u32)
                    )
                );
                Ok(NodeOutput::Ok(120))
            });

        let outcome = flow.run(0, &mut storage).await;

        assert_eq!(outcome, Result::Ok(NodeOutput::Ok(120)));
    }

    /// Running the nested node tuple directly through `ChainRun` yields the
    /// same left-nested tuple of outputs, without any joiner involved.
    #[tokio::test]
    async fn test_chain() {
        let mut storage = LocalStorageImpl::new();
        let chain = (
            (
                (Passer::<u16, u64, ()>::new(),),
                SoftFailNode::<u16, u32, ()>::new(),
            ),
            Passer::<u16, u32, ()>::new(),
        );

        let outcome: Result<_, ()> =
            ChainRun::<u8, _, _, _>::run(&chain, 0u8, &mut storage).await;

        assert_eq!(
            outcome,
            Ok((
                ((NodeOutput::Ok(0u64),), NodeOutput::SoftFail),
                NodeOutput::Ok(0u32)
            ))
        );
    }

    /// Checks what the joiner sees in the context after nodes that write to
    /// storage have run, and that the joiner's own write reaches the caller's
    /// storage.
    #[tokio::test]
    async fn test_flow_storage() {
        let mut storage = LocalStorageImpl::new();
        let mut flow = Flow::<u8, u64, (), _>::builder()
            .add_node(InsertIntoStorageAssertWasNotInStorage::<u16, u32, (), MyVal>::new())
            .add_node(Passer::<u16, u64, ()>::new())
            .add_node(InsertIntoStorageAssertWasNotInStorage::<u8, u16, (), MyVal>::new())
            .add_node(InsertIntoStorageAssertWasNotInStorage::<u32, u64, (), MyVal>::new())
            .add_node(Passer::<u16, u32, ()>::new())
            .build(async |input, context: &mut LocalStorageImpl| {
                // `insert` returns the value that was stored before.
                // NOTE(review): "|||" is presumably the result of joining the
                // contexts of the three inserting nodes — confirm against the
                // `Join` implementation of the storage.
                let previous = context.insert(MyVal::default());
                assert_eq!(previous, Some(MyVal("|||".to_owned())));
                assert_eq!(
                    input,
                    (
                        (
                            (
                                ((NodeOutput::SoftFail,), NodeOutput::Ok(5u64)),
                                NodeOutput::SoftFail
                            ),
                            NodeOutput::SoftFail
                        ),
                        NodeOutput::Ok(5u32)
                    )
                );
                Ok(NodeOutput::Ok(120))
            });

        let outcome = flow.run(5, &mut storage).await;
        assert_eq!(outcome, Result::Ok(NodeOutput::Ok(120)));

        // The value written by the joiner is present in the outer storage.
        assert_eq!(storage.remove::<MyVal>(), Some(MyVal::default()));
    }
}