1use std::fmt::Debug;
2
3use super::Builder;
4use super::chain_run::ChainRunParallel as ChainRun;
5use crate::{
6 context::{Fork, Join},
7 describe::{Description, DescriptionBase, Edge, Type, remove_generics_from_name},
8 flows::{
9 NodeResult, chain_debug::ChainDebug, chain_describe::ChainDescribe, parallel_flow::Joiner,
10 },
11 node::{Node, NodeOutput as NodeOutputStruct},
12};
13
/// A flow that runs a chain of nodes through `ChainRunParallel` (imported in
/// this file as `ChainRun`) and hands the chain's combined output to a joiner;
/// the joiner's result is the flow's result.
///
/// All fields are `pub(super)`, so outside the parent module a `ParallelFlow`
/// is obtained through `ParallelFlow::builder()`.
///
/// Type parameters:
/// - `Input`, `Output`, `Error`, `Context`: the flow's own input, output,
///   error and context types.
/// - `ChainOutput`: the value produced by running `NodeTypes`, which is the
///   input handed to the joiner.
/// - `Joiner`: the type of the stored `joiner`.
/// - `NodeTypes`: the node chain (a nested tuple such as `((A,), B)` in this
///   module's tests).
/// - `NodeIOETypes`: only used as a type-level marker (see `_nodes_io`);
///   presumably describes the nodes' input/output/error types.
///
/// NOTE(review): the `()` defaults of the last four parameters look like
/// placeholders that the builder replaces as nodes and a joiner are added —
/// TODO confirm against the builder.
pub struct ParallelFlow<
    Input,
    Output,
    Error,
    Context,
    ChainOutput = (),
    Joiner = (),
    NodeTypes = (),
    NodeIOETypes = (),
> {
    // `fn() -> T` markers tie the type parameters to the struct without the
    // struct storing (or dropping) any values of those types.
    #[expect(clippy::type_complexity)]
    pub(super) _ioec: std::marker::PhantomData<fn() -> (Input, Output, Error, Context)>,
    // Type-level marker for the nodes' I/O/error description; never stored.
    pub(super) _nodes_io: std::marker::PhantomData<fn() -> NodeIOETypes>,
    // Held behind an `Arc` so that cloning the flow shares the nodes instead
    // of cloning them.
    pub(super) nodes: std::sync::Arc<NodeTypes>,
    // Marker for the chain output type consumed by the joiner.
    pub(super) _joiner_input: std::marker::PhantomData<fn() -> ChainOutput>,
    // Combines the chain output into the flow's output.
    pub(super) joiner: Joiner,
}
100
101impl<Input, Output, Error, Context> ParallelFlow<Input, Output, Error, Context>
102where
103 Input: Send + Clone,
105 Error: Send,
106 Context: Fork + Join + Send,
107{
108 #[must_use]
124 pub fn builder() -> Builder<Input, Output, Error, Context> {
125 Builder::new()
126 }
127}
128
129impl<Input, Output, Error, Context, ChainRunOutput, J, NodeTypes, NodeIOETypes> Clone
130 for ParallelFlow<Input, Output, Error, Context, ChainRunOutput, J, NodeTypes, NodeIOETypes>
131where
132 J: Clone,
133{
134 fn clone(&self) -> Self {
135 Self {
136 _ioec: std::marker::PhantomData,
137 _nodes_io: std::marker::PhantomData,
138 nodes: self.nodes.clone(),
139 _joiner_input: std::marker::PhantomData,
140 joiner: self.joiner.clone(),
141 }
142 }
143}
144
145impl<Input, Output, Error, Context, ChainRunOutput, J, NodeTypes, NodeIOETypes> Debug
146 for ParallelFlow<Input, Output, Error, Context, ChainRunOutput, J, NodeTypes, NodeIOETypes>
147where
148 NodeTypes: ChainDebug,
149{
150 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
151 f.debug_struct("ParallelFlow")
152 .field("nodes", &self.nodes.as_list())
153 .finish_non_exhaustive()
154 }
155}
156
157#[inline(always)]
159#[expect(clippy::inline_always)]
160fn call_joiner<'a, J, I, O, E, Ctx>(
161 j: &J,
162 i: I,
163 s: &'a mut Ctx,
164) -> impl Future<Output = NodeResult<O, E>>
165where
166 J: Joiner<'a, I, O, E, Ctx> + 'a,
167{
168 j.join(i, s)
169}
170
171impl<Input, Output, Error, Context, ChainRunOutput, J, NodeTypes, NodeIOETypes>
172 Node<Input, NodeOutputStruct<Output>, Error, Context>
173 for ParallelFlow<Input, Output, Error, Context, ChainRunOutput, J, NodeTypes, NodeIOETypes>
174where
175 Input: Send,
176 Context: Send,
177 for<'a> J: Joiner<'a, ChainRunOutput, Output, Error, Context>,
178 NodeTypes: ChainRun<Input, Result<ChainRunOutput, Error>, Context, NodeIOETypes>
179 + ChainDescribe<Context, NodeIOETypes>
180 + Send
181 + Sync,
182{
183 fn run(
184 &mut self,
185 input: Input,
186 context: &mut Context,
187 ) -> impl Future<Output = NodeResult<Output, Error>> + Send {
188 let nodes = self.nodes.as_ref();
189 let joiner = &self.joiner;
190 async move {
191 let fut = nodes.run(input, context);
192 let res = fut.await?;
193 call_joiner::<J, ChainRunOutput, Output, Error, Context>(joiner, res, context).await
195 }
196 }
197
198 fn describe(&self) -> Description {
199 let node_count = <NodeTypes as ChainDescribe<Context, NodeIOETypes>>::COUNT;
200 let mut node_descriptions = Vec::with_capacity(node_count + 1);
201 self.nodes.describe(&mut node_descriptions);
202
203 node_descriptions.push(Description::Node {
204 base: DescriptionBase {
205 r#type: Type {
206 name: "Joiner".to_owned(),
207 },
208 input: Type {
209 name: String::new(),
210 },
211 output: Type {
212 name: String::new(),
213 },
214 error: Type {
215 name: String::new(),
216 },
217 context: Type {
218 name: String::new(),
219 },
220 description: None,
221 externals: None,
222 },
223 });
224
225 let mut edges = Vec::with_capacity(node_count * 2 + 1);
226 for i in 0..node_count {
227 edges.push(Edge::flow_to_node(i));
228 edges.push(Edge::node_to_node(i, node_count));
229 }
230 edges.push(Edge::node_to_flow(node_count));
231
232 Description::new_flow(self, node_descriptions, edges).modify_name(remove_generics_from_name)
233 }
234}
235
#[cfg(test)]
mod test {
    use super::{ChainRun, ParallelFlow as Flow};
    use crate::{
        context::storage::local_storage::{LocalStorage, LocalStorageImpl, tests::MyVal},
        flows::tests::{InsertIntoStorageAssertWasNotInStorage, Passer, SoftFailNode},
        node::{Node, NodeOutput},
    };

    /// Three nodes, one of which soft-fails, feed a joiner. The joiner checks
    /// the nested tuple of node outputs and returns a fixed value.
    #[tokio::test]
    async fn test_flow() {
        let mut storage = LocalStorageImpl::new();
        let mut flow = Flow::<u8, u64, (), _>::builder()
            .add_node(Passer::<u16, u64, ()>::new())
            .add_node(SoftFailNode::<u16, u32, ()>::new())
            .add_node(Passer::<u16, u32, ()>::new())
            .build(async |outputs, context: &mut LocalStorageImpl| {
                context.insert(MyVal::default());
                assert_eq!(
                    outputs,
                    (
                        ((NodeOutput::Ok(0u64),), NodeOutput::SoftFail),
                        NodeOutput::Ok(0u32)
                    )
                );
                Ok(NodeOutput::Ok(120))
            });

        let result = flow.run(0, &mut storage).await;
        assert_eq!(result, Ok(NodeOutput::Ok(120)));
    }

    /// Running the nested node tuple directly through `ChainRun` yields the
    /// same nested tuple of outputs that the joiner sees in `test_flow`.
    #[tokio::test]
    async fn test_chain() {
        let mut storage = LocalStorageImpl::new();
        let chain = (
            (
                (Passer::<u16, u64, ()>::new(),),
                SoftFailNode::<u16, u32, ()>::new(),
            ),
            Passer::<u16, u32, ()>::new(),
        );
        let result: Result<_, ()> = ChainRun::<u8, _, _, _>::run(&chain, 0u8, &mut storage).await;
        assert_eq!(
            result,
            Ok((
                ((NodeOutput::Ok(0u64),), NodeOutput::SoftFail),
                NodeOutput::Ok(0u32)
            ))
        );
    }

    /// Three of the five nodes insert `MyVal` into the storage, and each
    /// asserts that the value was not already present. The joiner then finds
    /// `MyVal("|||")` in the context.
    #[tokio::test]
    async fn test_flow_storage() {
        let mut storage = LocalStorageImpl::new();
        let mut flow = Flow::<u8, u64, (), _>::builder()
            .add_node(InsertIntoStorageAssertWasNotInStorage::<u16, u32, (), MyVal>::new())
            .add_node(Passer::<u16, u64, ()>::new())
            .add_node(InsertIntoStorageAssertWasNotInStorage::<u8, u16, (), MyVal>::new())
            .add_node(InsertIntoStorageAssertWasNotInStorage::<u32, u64, (), MyVal>::new())
            .add_node(Passer::<u16, u32, ()>::new())
            .build(async |outputs, context: &mut LocalStorageImpl| {
                // NOTE(review): presumably one `|` per inserting node, produced
                // when the per-node contexts are joined back together.
                let previous = context.insert(MyVal::default());
                assert_eq!(previous, Some(MyVal("|||".to_owned())));
                assert_eq!(
                    outputs,
                    (
                        (
                            (
                                ((NodeOutput::SoftFail,), NodeOutput::Ok(5u64)),
                                NodeOutput::SoftFail
                            ),
                            NodeOutput::SoftFail
                        ),
                        NodeOutput::Ok(5u32)
                    )
                );
                Ok(NodeOutput::Ok(120))
            });

        let result = flow.run(5, &mut storage).await;
        assert_eq!(result, Ok(NodeOutput::Ok(120)));

        // What remains in the storage is the value the joiner inserted.
        assert_eq!(storage.remove::<MyVal>(), Some(MyVal::default()));
    }
}