1use std::{any::Any, marker::PhantomData, sync::Arc};
2
3use async_trait::async_trait;
4
5use crate::{
6 internal::internal_node::{InternalNode, InternalNodeStruct},
7 node::{Node, NodeOutput},
8 storage::Storage,
9};
10
11pub struct SequentialFlow<Input, Output, Error> {
12 _input: PhantomData<Input>,
13 _output: PhantomData<Output>,
14 last_node_output_converter: Arc<Box<dyn ConvertTo<Output>>>,
15 #[cfg(not(all(doc, not(doctest))))]
16 nodes: Arc<Vec<Box<dyn InternalNode<Error> + Sync>>>,
17 #[cfg(all(doc, not(doctest)))]
18 __: PhantomData<Error>,
19}
20
21impl<Input, Output, Error> Clone for SequentialFlow<Input, Output, Error> {
22 fn clone(&self) -> Self {
23 Self {
24 _input: PhantomData,
25 _output: PhantomData,
26 last_node_output_converter: Arc::clone(&self.last_node_output_converter),
27 nodes: Arc::clone(&self.nodes),
28 }
29 }
30}
31
32impl<Input, Output, Error> SequentialFlow<Input, Output, Error>
33where
34 Input: Send + 'static,
35 Output: Send + 'static,
36 Error: Send + 'static,
37{
38 #[must_use]
40 pub fn builder() -> SequentialFlowBuilder<Input, Output, Error, Input> {
41 SequentialFlowBuilder::new()
42 }
43}
44
45pub struct SequentialFlowBuilder<Input, Output, Error, NextNodeInput> {
47 _input: PhantomData<Input>,
48 _output: PhantomData<Output>,
49 _next_node_input: PhantomData<NextNodeInput>,
50 #[cfg(not(all(doc, not(doctest))))]
51 nodes: Vec<Box<dyn InternalNode<Error> + Sync>>,
52 #[cfg(all(doc, not(doctest)))]
53 __: PhantomData<Error>,
54}
55
56#[allow(clippy::mismatching_type_param_order)]
57impl<Input, Output, Error> SequentialFlowBuilder<Input, Output, Error, Input> {
58 #[must_use]
60 pub fn new() -> Self {
61 Self {
62 _input: PhantomData,
63 _output: PhantomData,
64 _next_node_input: PhantomData,
65 nodes: Vec::new(),
66 }
67 }
68}
69
70#[allow(clippy::mismatching_type_param_order)]
71impl<Input, Output, Error> Default for SequentialFlowBuilder<Input, Output, Error, Input> {
72 fn default() -> Self {
73 Self::new()
74 }
75}
76
77#[cfg_attr(not(all(doc, not(doctest))), async_trait)]
78impl<Input, Output, Error> Node<Input, NodeOutput<Output>, Error>
79 for SequentialFlow<Input, Output, Error>
80where
81 Input: Send + 'static,
82 Output: Send,
83 Error: Send,
84{
85 async fn run_with_storage<'a>(
86 &mut self,
87 input: Input,
88 storage: &mut Storage,
89 ) -> Result<NodeOutput<Output>, Error> {
90 let mut data: Box<dyn Any + Send> = Box::new(input);
91 for mut node in self.nodes.iter().map(|node| node.duplicate()) {
92 match node.run_with_storage(data, storage).await? {
93 NodeOutput::Ok(output) => data = output,
94 NodeOutput::SoftFail => return Ok(NodeOutput::SoftFail),
95 }
96 }
97 let output = self
98 .last_node_output_converter
99 .convert(data)
100 .expect("Converting data to sequence output type failed");
101 return Ok(NodeOutput::Ok(output));
102 }
103}
104
105impl<Input, Output, Error, LastNodeOutput>
106 SequentialFlowBuilder<Input, Output, Error, LastNodeOutput>
107where
108 Input: Send + 'static,
109 Output: Send + 'static,
110 Error: Send + 'static,
111{
112 pub fn add_node<NodeType, NodeInput, NodeOutput_, NodeError>(
114 mut self,
115 node: NodeType,
116 ) -> SequentialFlowBuilder<Input, Output, Error, NodeOutput_>
117 where
118 LastNodeOutput: Send + Sync + Into<NodeInput> + 'static,
119 NodeInput: Send + Sync + 'static,
120 NodeOutput_: Send + Sync + 'static,
121 NodeError: Send + Sync + Into<Error> + 'static,
122 NodeType:
123 Node<NodeInput, NodeOutput<NodeOutput_>, NodeError> + Clone + Send + Sync + 'static,
124 {
125 self.nodes.push(Box::new(InternalNodeStruct::<
126 NodeInput,
127 NodeOutput_,
128 NodeError,
129 NodeType,
130 LastNodeOutput,
131 >::new(node)));
132 SequentialFlowBuilder {
133 _input: PhantomData,
134 _output: PhantomData,
135 _next_node_input: PhantomData,
136 nodes: self.nodes,
137 }
138 }
139}
140
141impl<Input, Output, Error, LastNodeOutput>
142 SequentialFlowBuilder<Input, Output, Error, LastNodeOutput>
143where
144 Output: Send + Sync + 'static,
145 LastNodeOutput: Into<Output> + Send + Sync + 'static,
146{
147 #[must_use]
149 pub fn build(self) -> SequentialFlow<Input, Output, Error> {
150 SequentialFlow {
151 _input: PhantomData,
152 _output: PhantomData,
153 last_node_output_converter: Arc::new(Box::new(DowncastConverter::<
154 LastNodeOutput,
155 Output,
156 >::new())),
157 nodes: Arc::new(self.nodes),
158 }
159 }
160}
161
/// Converts a type-erased, boxed value into a concrete `T`.
///
/// `None` signals that the value could not be converted. For `DowncastConverter`, this means
/// the box did not hold the expected type.
trait ConvertTo<T>: Send + Sync {
    fn convert(&self, value: Box<dyn Any>) -> Option<T>;
}
165
/// Stateless converter that downcasts a `Box<dyn Any>` to `Input` and then converts it
/// into `Output` via `Into`.
///
/// No trait bounds are placed on the struct itself. The `Input: Into<Output>` requirement
/// lives on the impls that need it, so merely naming the type does not force callers to
/// repeat the bound.
struct DowncastConverter<Input, Output> {
    _node_output_type: PhantomData<Input>,
    _output_type: PhantomData<Output>,
}
173
174impl<Input, Output> DowncastConverter<Input, Output>
175where
176 Input: Into<Output>,
177{
178 fn new() -> Self {
179 Self {
180 _node_output_type: PhantomData,
181 _output_type: PhantomData,
182 }
183 }
184}
185
186impl<FromType, IntoType> ConvertTo<IntoType> for DowncastConverter<FromType, IntoType>
187where
188 FromType: Into<IntoType> + Send + Sync + 'static,
189 IntoType: Send + Sync,
190{
191 fn convert(&self, data: Box<dyn Any>) -> Option<IntoType> {
192 let box_from = data.downcast::<FromType>().ok()?;
193 let from = *box_from;
194 let into = from.into();
195 Some(into)
196 }
197}