// node_flow/flows/parallel_flow/flow.rs

1use std::fmt::Debug;
2
3use super::Builder;
4use super::chain_run::ChainRunParallel as ChainRun;
5use crate::{
6    context::{Fork, Join},
7    describe::{Description, DescriptionBase, Edge, Type, remove_generics_from_name},
8    flows::{
9        NodeResult, chain_debug::ChainDebug, chain_describe::ChainDescribe, parallel_flow::Joiner,
10    },
11    node::{Node, NodeOutput as NodeOutputStruct},
12};
13
14/// `ParallelFlow` executes nodes (branches) **in parallel**.
15///
16/// Nodes (branches) are executed concurrently.
17/// The flow completes when **all** node succeed or **any** node "hard" fails.
18/// - If a node returns [`NodeOutput::Ok`](crate::node::NodeOutput::Ok) or [`NodeOutput::SoftFail`](crate::node::NodeOutput::SoftFail),
19///   the flow continues waiting for other nodes (branches).
20/// - If a node returns an **error**, then that error is returned.
21///
22/// The output of all nodes is then passed into a [`Joiner`],
23/// which decides what should happen and what should this flow return.
24///
25/// # Type Parameters
26/// - `Input`: The type of data accepted by this flow.
27/// - `Output`: The type of data produced by this flow.
28/// - `Error`: The type of error emitted by this flow.
29/// - `Context`: The type of context used during execution.
30///
31/// See also [`Joiner`].
32///
33/// # Examples
34/// ```
35/// use node_flow::node::{Node, NodeOutput};
36/// use node_flow::flows::ParallelFlow;
37/// use node_flow::context::{Fork, Join};
38///
39/// // Example nodes
40/// #[derive(Clone)]
41/// struct A;
42/// #[derive(Clone)]
43/// struct B;
44///
45/// struct ExampleCtx;
46/// impl Fork for ExampleCtx // ...
47/// # { fn fork(&self) -> Self { Self } }
48/// impl Join for ExampleCtx // ...
49/// # { fn join(&mut self, others: Box<[Self]>) {} }
50///
51/// impl<Ctx: Send> Node<(), NodeOutput<i32>, (), Ctx> for A {
52///     async fn run(&mut self, _: (), _: &mut Ctx) -> Result<NodeOutput<i32>, ()> {
53///         Ok(NodeOutput::SoftFail)
54///     }
55/// }
56///
57/// impl<Ctx: Send> Node<(), NodeOutput<i32>, (), Ctx> for B {
58///     async fn run(&mut self, _: (), _: &mut Ctx) -> Result<NodeOutput<i32>, ()> {
59///         Ok(NodeOutput::Ok(5))
60///     }
61/// }
62///
63/// # tokio::runtime::Builder::new_current_thread()
64/// #     .enable_all()
65/// #     .build()
66/// #     .unwrap()
67/// #     .block_on(async {
68/// async fn main() {
69///     let mut flow = ParallelFlow::<(), i32, (), _>::builder()
70///         .add_node(A)
71///         .add_node(B)
72///         .build(async |_input, context: &mut ExampleCtx| {
73///             Ok(NodeOutput::Ok(120))
74///         });
75///
76///     let mut ctx = ExampleCtx;
77///     let result = flow.run((), &mut ctx).await;
78///     assert_eq!(result, Ok(NodeOutput::Ok(120)));
79/// }
80/// # main().await;
81/// # });
82/// ```
83pub struct ParallelFlow<
84    Input,
85    Output,
86    Error,
87    Context,
88    ChainOutput = (),
89    Joiner = (),
90    NodeTypes = (),
91    NodeIOETypes = (),
92> {
93    #[expect(clippy::type_complexity)]
94    pub(super) _ioec: std::marker::PhantomData<fn() -> (Input, Output, Error, Context)>,
95    pub(super) _nodes_io: std::marker::PhantomData<fn() -> NodeIOETypes>,
96    pub(super) nodes: std::sync::Arc<NodeTypes>,
97    pub(super) _joiner_input: std::marker::PhantomData<fn() -> ChainOutput>,
98    pub(super) joiner: Joiner,
99}
100
101impl<Input, Output, Error, Context> ParallelFlow<Input, Output, Error, Context>
102where
103    // Trait bounds for better and nicer errors
104    Input: Send + Clone,
105    Error: Send,
106    Context: Fork + Join + Send,
107{
108    /// Creates a new [`Builder`] for constructing [`ParallelFlow`].
109    ///
110    /// See also [`ParallelFlow`].
111    ///
112    /// # Examples
113    /// ```
114    /// # use node_flow::context::{Fork, Join};
115    /// # struct Ctx;
116    /// # impl Fork for Ctx { fn fork(&self) -> Self { Self } }
117    /// # impl Join for Ctx { fn join(&mut self, other: Box<[Self]>) {} }
118    /// #
119    /// use node_flow::flows::ParallelFlow;
120    ///
121    /// let builder = ParallelFlow::<u8, u16, (), Ctx>::builder();
122    /// ```
123    #[must_use]
124    pub fn builder() -> Builder<Input, Output, Error, Context> {
125        Builder::new()
126    }
127}
128
129impl<Input, Output, Error, Context, ChainRunOutput, J, NodeTypes, NodeIOETypes> Clone
130    for ParallelFlow<Input, Output, Error, Context, ChainRunOutput, J, NodeTypes, NodeIOETypes>
131where
132    J: Clone,
133{
134    fn clone(&self) -> Self {
135        Self {
136            _ioec: std::marker::PhantomData,
137            _nodes_io: std::marker::PhantomData,
138            nodes: self.nodes.clone(),
139            _joiner_input: std::marker::PhantomData,
140            joiner: self.joiner.clone(),
141        }
142    }
143}
144
145impl<Input, Output, Error, Context, ChainRunOutput, J, NodeTypes, NodeIOETypes> Debug
146    for ParallelFlow<Input, Output, Error, Context, ChainRunOutput, J, NodeTypes, NodeIOETypes>
147where
148    NodeTypes: ChainDebug,
149{
150    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
151        f.debug_struct("ParallelFlow")
152            .field("nodes", &self.nodes.as_list())
153            .finish_non_exhaustive()
154    }
155}
156
/// Forwards to `Joiner::join`, calling it on `j` with the branch outputs `i` and the
/// context `s`, and returns the resulting future unchanged.
///
/// This shim exists only as a workaround for rust-lang/rust#100013 (presumably the
/// higher-ranked lifetime / auto-trait inference limitation of `async` blocks tracked there;
/// confirm against the issue). `ParallelFlow`'s `Node::run` only knows
/// `for<'a> J: Joiner<'a, ..>`. Routing the call through this function instantiates `'a` as
/// the lifetime of the `&'a mut Ctx` borrow, so the returned future is tied to one concrete
/// lifetime instead of a higher-ranked one.
// workaround for https://github.com/rust-lang/rust/issues/100013
#[inline(always)]
// `inline(always)` presumably keeps this forwarding shim free at runtime; clippy's
// `inline_always` lint is therefore expected and intentionally silenced.
#[expect(clippy::inline_always)]
fn call_joiner<'a, J, I, O, E, Ctx>(
    // the joiner to invoke
    j: &J,
    // combined outputs of all branches
    i: I,
    // execution context; its borrow lifetime fixes the joiner's `'a`
    s: &'a mut Ctx,
) -> impl Future<Output = NodeResult<O, E>>
where
    J: Joiner<'a, I, O, E, Ctx> + 'a,
{
    j.join(i, s)
}
170
171impl<Input, Output, Error, Context, ChainRunOutput, J, NodeTypes, NodeIOETypes>
172    Node<Input, NodeOutputStruct<Output>, Error, Context>
173    for ParallelFlow<Input, Output, Error, Context, ChainRunOutput, J, NodeTypes, NodeIOETypes>
174where
175    Input: Send,
176    Context: Send,
177    for<'a> J: Joiner<'a, ChainRunOutput, Output, Error, Context>,
178    NodeTypes: ChainRun<Input, Result<ChainRunOutput, Error>, Context, NodeIOETypes>
179        + ChainDescribe<Context, NodeIOETypes>
180        + Send
181        + Sync,
182{
183    fn run(
184        &mut self,
185        input: Input,
186        context: &mut Context,
187    ) -> impl Future<Output = NodeResult<Output, Error>> + Send {
188        let nodes = self.nodes.as_ref();
189        let joiner = &self.joiner;
190        async move {
191            let fut = nodes.run(input, context);
192            let res = fut.await?;
193            // workaround for https://github.com/rust-lang/rust/issues/100013
194            call_joiner::<J, ChainRunOutput, Output, Error, Context>(joiner, res, context).await
195        }
196    }
197
198    fn describe(&self) -> Description {
199        let node_count = <NodeTypes as ChainDescribe<Context, NodeIOETypes>>::COUNT;
200        let mut node_descriptions = Vec::with_capacity(node_count + 1);
201        self.nodes.describe(&mut node_descriptions);
202
203        node_descriptions.push(Description::Node {
204            base: DescriptionBase {
205                r#type: Type {
206                    name: "Joiner".to_owned(),
207                },
208                input: Type {
209                    name: String::new(),
210                },
211                output: Type {
212                    name: String::new(),
213                },
214                error: Type {
215                    name: String::new(),
216                },
217                context: Type {
218                    name: String::new(),
219                },
220                description: None,
221                externals: None,
222            },
223        });
224
225        let mut edges = Vec::with_capacity(node_count * 2 + 1);
226        for i in 0..node_count {
227            edges.push(Edge::flow_to_node(i));
228            edges.push(Edge::node_to_node(i, node_count));
229        }
230        edges.push(Edge::node_to_flow(node_count));
231
232        Description::new_flow(self, node_descriptions, edges).modify_name(remove_generics_from_name)
233    }
234}
235
#[cfg(test)]
mod test {
    use super::{ChainRun, ParallelFlow as Flow};
    use crate::{
        context::storage::local_storage::{LocalStorage, LocalStorageImpl, tests::MyVal},
        flows::tests::{InsertIntoStorageAssertWasNotInStorage, Passer, SoftFailNode},
        node::{Node, NodeOutput},
    };

    /// The joiner receives the branch outputs as a left-nested tuple in insertion order,
    /// and whatever the joiner returns becomes the flow's result.
    #[tokio::test]
    async fn test_flow() {
        let mut storage = LocalStorageImpl::new();
        let mut flow = Flow::<u8, u64, (), _>::builder()
            .add_node(Passer::<u16, u64, ()>::new())
            .add_node(SoftFailNode::<u16, u32, ()>::new())
            .add_node(Passer::<u16, u32, ()>::new())
            .build(async |input, context: &mut LocalStorageImpl| {
                context.insert(MyVal::default());
                assert_eq!(
                    input,
                    (
                        ((NodeOutput::Ok(0u64),), NodeOutput::SoftFail),
                        NodeOutput::Ok(0u32)
                    )
                );
                Ok(NodeOutput::Ok(120))
            });

        let outcome = flow.run(0, &mut storage).await;
        assert_eq!(outcome, Ok(NodeOutput::Ok(120)));
    }

    /// Running the node chain directly yields the outputs nested as `(((a,), b), c)`.
    #[tokio::test]
    async fn test_chain() {
        let mut storage = LocalStorageImpl::new();
        let chain = (
            (
                (Passer::<u16, u64, ()>::new(),),
                SoftFailNode::<u16, u32, ()>::new(),
            ),
            Passer::<u16, u32, ()>::new(),
        );
        let outcome: Result<_, ()> =
            ChainRun::<u8, _, _, _>::run(&chain, 0u8, &mut storage).await;
        assert_eq!(
            outcome,
            Ok((
                ((NodeOutput::Ok(0u64),), NodeOutput::SoftFail),
                NodeOutput::Ok(0u32)
            ))
        );
    }

    /// Context handling across branches.
    ///
    /// Each storage-inserting branch (per its name) asserts that `MyVal` was not already in
    /// its storage. The joiner then observes the merged value `"|||"`, which is presumably
    /// produced by joining the three branch contexts. Afterwards the storage holds the value
    /// the joiner inserted.
    #[tokio::test]
    async fn test_flow_storage() {
        let mut storage = LocalStorageImpl::new();
        let mut flow = Flow::<u8, u64, (), _>::builder()
            .add_node(InsertIntoStorageAssertWasNotInStorage::<u16, u32, (), MyVal>::new())
            .add_node(Passer::<u16, u64, ()>::new())
            .add_node(InsertIntoStorageAssertWasNotInStorage::<u8, u16, (), MyVal>::new())
            .add_node(InsertIntoStorageAssertWasNotInStorage::<u32, u64, (), MyVal>::new())
            .add_node(Passer::<u16, u32, ()>::new())
            .build(async |input, context: &mut LocalStorageImpl| {
                let previous = context.insert(MyVal::default());
                assert_eq!(previous, Some(MyVal("|||".to_owned())));
                assert_eq!(
                    input,
                    (
                        (
                            (
                                ((NodeOutput::SoftFail,), NodeOutput::Ok(5u64)),
                                NodeOutput::SoftFail
                            ),
                            NodeOutput::SoftFail
                        ),
                        NodeOutput::Ok(5u32)
                    )
                );
                Ok(NodeOutput::Ok(120))
            });

        let outcome = flow.run(5, &mut storage).await;
        assert_eq!(outcome, Ok(NodeOutput::Ok(120)));

        assert_eq!(storage.remove::<MyVal>(), Some(MyVal::default()));
    }
}
321}