Skip to main content

node_flow/flows/parallel_flow/
flow.rs

1use std::fmt::Debug;
2
3use super::Builder;
4use super::chain_run::ChainRunParallel as ChainRun;
5use crate::{
6    context::{Fork, Join},
7    describe::{Description, DescriptionBase, Edge, Type, remove_generics_from_name},
8    flows::{
9        NodeResult, chain_debug::ChainDebug, chain_describe::ChainDescribe, parallel_flow::Joiner,
10    },
11    node::{Node, NodeOutput as NodeOutputStruct},
12};
13
14/// `ParallelFlow` executes nodes (branches) **in parallel**.
15///
16/// Nodes (branches) are executed concurrently.
17/// The flow completes when **all** node succeed or **any** node "hard" fails.
18/// - If a node returns [`NodeOutput::Ok`](crate::node::NodeOutput::Ok) or [`NodeOutput::SoftFail`](crate::node::NodeOutput::SoftFail),
19///   the flow continues waiting for other nodes (branches).
20/// - If a node returns an **error**, then that error is returned.
21///
22/// The output of all nodes is then passed into a [`Joiner`],
23/// which decides what should happen and what should this flow return.
24///
25/// # Type Parameters
26/// - `Input`: The type of data accepted by this flow.
27/// - `Output`: The type of data produced by this flow.
28/// - `Error`: The type of error emitted by this flow.
29/// - `Context`: The type of context used during execution.
30///
31/// See also [`Joiner`].
32///
33/// # Examples
34/// ```
35/// use node_flow::node::{Node, NodeOutput};
36/// use node_flow::flows::ParallelFlow;
37/// use node_flow::context::{Fork, Join};
38///
39/// // Example nodes
40/// #[derive(Clone)]
41/// struct A;
42/// #[derive(Clone)]
43/// struct B;
44///
45/// struct ExampleCtx;
46/// impl Fork for ExampleCtx // ...
47/// # { fn fork(&self) -> Self { Self } }
48/// impl Join for ExampleCtx // ...
49/// # { fn join(&mut self, others: Box<[Self]>) {} }
50///
51/// impl<Ctx: Send> Node<(), NodeOutput<i32>, (), Ctx> for A {
52///     async fn run(&mut self, _: (), _: &mut Ctx) -> Result<NodeOutput<i32>, ()> {
53///         Ok(NodeOutput::SoftFail)
54///     }
55/// }
56///
57/// impl<Ctx: Send> Node<(), NodeOutput<i32>, (), Ctx> for B {
58///     async fn run(&mut self, _: (), _: &mut Ctx) -> Result<NodeOutput<i32>, ()> {
59///         Ok(NodeOutput::Ok(5))
60///     }
61/// }
62///
63/// # tokio::runtime::Builder::new_current_thread()
64/// #     .enable_all()
65/// #     .build()
66/// #     .unwrap()
67/// #     .block_on(async {
68/// async fn main() {
69///     let mut flow = ParallelFlow::<(), i32, (), _>::builder()
70///         .add_node(A)
71///         .add_node(B)
72///         .build(async |_input, context: &mut ExampleCtx| {
73///             Ok(NodeOutput::Ok(120))
74///         });
75///
76///     let mut ctx = ExampleCtx;
77///     let result = flow.run((), &mut ctx).await;
78///     assert_eq!(result, Ok(NodeOutput::Ok(120)));
79/// }
80/// # main().await;
81/// # });
82/// ```
83pub struct ParallelFlow<
84    Input,
85    Output,
86    Error,
87    Context,
88    ChainOutput = (),
89    Joiner = (),
90    NodeTypes = (),
91    NodeIOETypes = (),
92> {
93    #[expect(clippy::type_complexity)]
94    pub(super) _ioec: std::marker::PhantomData<fn() -> (Input, Output, Error, Context)>,
95    pub(super) _nodes_io: std::marker::PhantomData<fn() -> NodeIOETypes>,
96    pub(super) nodes: NodeTypes,
97    pub(super) _joiner_input: std::marker::PhantomData<fn() -> ChainOutput>,
98    pub(super) joiner: Joiner,
99}
100
101impl<Input, Output, Error, Context> ParallelFlow<Input, Output, Error, Context>
102where
103    // Trait bounds for better and nicer errors
104    Input: Send + Clone,
105    Error: Send,
106    Context: Fork + Join + Send,
107{
108    /// Creates a new [`Builder`] for constructing [`ParallelFlow`].
109    ///
110    /// See also [`ParallelFlow`].
111    ///
112    /// # Examples
113    /// ```
114    /// # use node_flow::context::{Fork, Join};
115    /// # struct Ctx;
116    /// # impl Fork for Ctx { fn fork(&self) -> Self { Self } }
117    /// # impl Join for Ctx { fn join(&mut self, other: Box<[Self]>) {} }
118    /// #
119    /// use node_flow::flows::ParallelFlow;
120    ///
121    /// let builder = ParallelFlow::<u8, u16, (), Ctx>::builder();
122    /// ```
123    #[must_use]
124    pub fn builder() -> Builder<Input, Output, Error, Context> {
125        Builder::new()
126    }
127}
128
129impl<Input, Output, Error, Context, ChainRunOutput, J, NodeTypes, NodeIOETypes> Clone
130    for ParallelFlow<Input, Output, Error, Context, ChainRunOutput, J, NodeTypes, NodeIOETypes>
131where
132    J: Clone,
133    NodeTypes: Clone,
134{
135    fn clone(&self) -> Self {
136        Self {
137            _ioec: std::marker::PhantomData,
138            _nodes_io: std::marker::PhantomData,
139            nodes: self.nodes.clone(),
140            _joiner_input: std::marker::PhantomData,
141            joiner: self.joiner.clone(),
142        }
143    }
144}
145
146impl<Input, Output, Error, Context, ChainRunOutput, J, NodeTypes, NodeIOETypes> Debug
147    for ParallelFlow<Input, Output, Error, Context, ChainRunOutput, J, NodeTypes, NodeIOETypes>
148where
149    NodeTypes: ChainDebug,
150{
151    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
152        f.debug_struct("ParallelFlow")
153            .field("nodes", &self.nodes.as_list())
154            .finish_non_exhaustive()
155    }
156}
157
158// workaround for https://github.com/rust-lang/rust/issues/100013
159#[inline(always)]
160#[expect(clippy::inline_always)]
161fn call_joiner<'a, J, I, O, E, Ctx>(
162    j: &J,
163    i: I,
164    s: &'a mut Ctx,
165) -> impl Future<Output = NodeResult<O, E>>
166where
167    J: Joiner<'a, I, O, E, Ctx> + 'a,
168{
169    j.join(i, s)
170}
171
172impl<Input, Output, Error, Context, ChainRunOutput, J, NodeTypes, NodeIOETypes>
173    Node<Input, NodeOutputStruct<Output>, Error, Context>
174    for ParallelFlow<Input, Output, Error, Context, ChainRunOutput, J, NodeTypes, NodeIOETypes>
175where
176    Input: Send,
177    Context: Send,
178    for<'a> J: Joiner<'a, ChainRunOutput, Output, Error, Context>,
179    NodeTypes: ChainRun<Input, Result<ChainRunOutput, Error>, Context, NodeIOETypes>
180        + ChainDescribe<Context, NodeIOETypes>
181        + Send
182        + Sync,
183{
184    fn run(
185        &mut self,
186        input: Input,
187        context: &mut Context,
188    ) -> impl Future<Output = NodeResult<Output, Error>> + Send {
189        let nodes = &self.nodes;
190        let joiner = &self.joiner;
191        async move {
192            let fut = nodes.run(input, context);
193            let res = fut.await?;
194            // workaround for https://github.com/rust-lang/rust/issues/100013
195            call_joiner::<J, ChainRunOutput, Output, Error, Context>(joiner, res, context).await
196        }
197    }
198
199    fn describe(&self) -> Description {
200        let node_count = <NodeTypes as ChainDescribe<Context, NodeIOETypes>>::COUNT;
201        let mut node_descriptions = Vec::with_capacity(node_count + 1);
202        self.nodes.describe(&mut node_descriptions);
203
204        node_descriptions.push(Description::Node {
205            base: DescriptionBase {
206                r#type: Type {
207                    name: "Joiner".to_owned(),
208                },
209                input: Type {
210                    name: String::new(),
211                },
212                output: Type {
213                    name: String::new(),
214                },
215                error: Type {
216                    name: String::new(),
217                },
218                context: Type {
219                    name: String::new(),
220                },
221                description: None,
222                externals: None,
223            },
224        });
225
226        let mut edges = Vec::with_capacity(node_count * 2 + 1);
227        for i in 0..node_count {
228            edges.push(Edge::flow_to_node(i));
229            edges.push(Edge::node_to_node(i, node_count));
230        }
231        edges.push(Edge::node_to_flow(node_count));
232
233        Description::new_flow(self, node_descriptions, edges).modify_name(remove_generics_from_name)
234    }
235}
236
#[cfg(test)]
mod test {
    use super::{ChainRun, ParallelFlow as Flow};
    use crate::{
        context::storage::local_storage::{LocalStorage, LocalStorageImpl, tests::MyVal},
        flows::tests::{InsertIntoStorageAssertWasNotInStorage, Passer, SoftFailNode},
        node::{Node, NodeOutput},
    };

    /// A soft-failing branch still lets the flow reach the joiner, and the
    /// joiner's output becomes the flow's output.
    #[tokio::test]
    async fn test_flow() {
        let mut storage = LocalStorageImpl::new();
        let mut flow = Flow::<u8, u64, (), _>::builder()
            .add_node(Passer::<u16, u64, ()>::new())
            .add_node(SoftFailNode::<u16, u32, ()>::new())
            .add_node(Passer::<u16, u32, ()>::new())
            .build(async |input, context: &mut LocalStorageImpl| {
                context.insert(MyVal::default());
                // Branch outputs arrive as a left-nested tuple in insertion order.
                assert_eq!(
                    input,
                    (
                        ((NodeOutput::Ok(0u64),), NodeOutput::SoftFail),
                        NodeOutput::Ok(0u32)
                    )
                );
                Ok(NodeOutput::Ok(120))
            });

        let outcome = flow.run(0, &mut storage).await;
        assert_eq!(outcome, Result::Ok(NodeOutput::Ok(120)));
    }

    /// Running the node chain directly yields the raw nested branch outputs.
    #[tokio::test]
    async fn test_chain() {
        let mut storage = LocalStorageImpl::new();
        let first = Passer::<u16, u64, ()>::new();
        let second = SoftFailNode::<u16, u32, ()>::new();
        let third = Passer::<u16, u32, ()>::new();
        let chain = (((first,), second), third);

        let outcome: Result<_, ()> =
            ChainRun::<u8, _, _, _>::run(&chain, 0u8, &mut storage).await;
        assert_eq!(
            outcome,
            Ok((
                ((NodeOutput::Ok(0u64),), NodeOutput::SoftFail),
                NodeOutput::Ok(0u32)
            ))
        );
    }

    /// Three branches insert into storage. The joiner observes the value that
    /// resulted from them (presumably merged by the context's `Join` impl), and
    /// what the joiner inserts is visible in the caller's storage afterwards.
    #[tokio::test]
    async fn test_flow_storage() {
        let mut storage = LocalStorageImpl::new();
        let mut flow = Flow::<u8, u64, (), _>::builder()
            .add_node(InsertIntoStorageAssertWasNotInStorage::<u16, u32, (), MyVal>::new())
            .add_node(Passer::<u16, u64, ()>::new())
            .add_node(InsertIntoStorageAssertWasNotInStorage::<u8, u16, (), MyVal>::new())
            .add_node(InsertIntoStorageAssertWasNotInStorage::<u32, u64, (), MyVal>::new())
            .add_node(Passer::<u16, u32, ()>::new())
            .build(async |input, context: &mut LocalStorageImpl| {
                let previous = context.insert(MyVal::default());
                assert_eq!(previous, Some(MyVal("|||".to_owned())));
                assert_eq!(
                    input,
                    (
                        (
                            (
                                ((NodeOutput::SoftFail,), NodeOutput::Ok(5u64)),
                                NodeOutput::SoftFail
                            ),
                            NodeOutput::SoftFail
                        ),
                        NodeOutput::Ok(5u32)
                    )
                );
                Ok(NodeOutput::Ok(120))
            });

        let outcome = flow.run(5, &mut storage).await;
        assert_eq!(outcome, Result::Ok(NodeOutput::Ok(120)));

        assert_eq!(storage.remove::<MyVal>(), Some(MyVal::default()));
    }
}