// cortex_ai/composer/flow.rs

use super::BranchBuilder;
use crate::{
    flow::{condition::Condition, processor::Processor, source::Source, stage::Stage},
    FlowError,
};
use std::{error::Error, future::Future, pin::Pin, sync::Arc};
use tokio::sync::broadcast;
use tracing::{debug, error, info, instrument, warn};

10/// A builder for constructing and executing data processing flows.
11///
12/// `Flow` represents a sequence of processing stages that data flows through. It supports
13/// various operations including data transformation, conditional branching, and async execution.
14///
15/// # Type Parameters
16///
17/// * `DataType` - The type of data flowing through the pipeline
18/// * `ErrorType` - The error type that can be produced during processing
19/// * `OutputType` - The output type produced by conditions in branches
20///
21/// # Examples
22///
23/// ```
24/// use cortex_ai::composer::Flow;
25/// use cortex_ai::flow::source::Source;
26/// use cortex_ai::flow::types::SourceOutput;
27/// use cortex_ai::flow::condition::Condition;
28/// use cortex_ai::flow::processor::Processor;
29/// use cortex_ai::FlowComponent;
30/// use cortex_ai::FlowError;
31/// use std::error::Error;
32/// use std::fmt;
33/// use std::pin::Pin;
34/// use std::future::Future;
35/// use flume::{Sender, Receiver};
36///
37/// #[derive(Clone, Debug)]
38/// struct MyData(String);
39///
40/// #[derive(Clone, Debug)]
41/// struct MyError;
42///
43/// impl fmt::Display for MyError {
44///     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
45///         write!(f, "MyError")
46///     }
47/// }
48///
49/// impl Error for MyError {}
50///
51/// impl From<FlowError> for MyError {
52///     fn from(e: FlowError) -> Self { MyError }
53/// }
54///
55/// struct MySource;
56///
57/// impl FlowComponent for MySource {
58///     type Input = ();
59///     type Output = MyData;
60///     type Error = MyError;
61/// }
62///
63/// impl Source for MySource {
64///     fn stream(&self) -> Pin<Box<dyn Future<Output = Result<SourceOutput<Self::Output, Self::Error>, Self::Error>> + Send>> {
65///         Box::pin(async move {
66///             let (tx, rx) = flume::bounded(10);
67///             let (feedback_tx, _) = flume::bounded(10);
68///             Ok(SourceOutput { receiver: rx, feedback: feedback_tx })
69///         })
70///     }
71/// }
72///
73/// struct MyProcessor;
74/// impl FlowComponent for MyProcessor {
75///     type Input = MyData;
76///     type Output = MyData;
77///     type Error = MyError;
78/// }
79///
80/// impl Processor for MyProcessor {
81///     fn process(&self, input: Self::Input) -> Pin<Box<dyn Future<Output = Result<Self::Output, Self::Error>> + Send>> {
82///         Box::pin(async move { Ok(input) })
83///     }
84/// }
85///
86/// struct MyCondition;
87/// impl FlowComponent for MyCondition {
88///     type Input = MyData;
89///     type Output = bool;
90///     type Error = MyError;
91/// }
92///
93/// impl Condition for MyCondition {
94///     fn evaluate(&self, input: Self::Input) -> Pin<Box<dyn Future<Output = Result<(bool, Option<Self::Output>), Self::Error>> + Send>> {
95///         Box::pin(async move { Ok((true, Some(false))) })
96///     }
97/// }
98///
99/// #[tokio::main]
100/// async fn main() {
101///     let (_, shutdown_rx) = tokio::sync::broadcast::channel::<()>(1);
102///     
103///     let flow = Flow::<MyData, MyError, bool>::new()
104///         .source(MySource)
105///         .process(MyProcessor)
106///         .when(MyCondition)
107///         .process(MyProcessor)
108///         .otherwise()
109///         .process(MyProcessor)
110///         .end();
111///
112///     let _ = flow.run_stream(shutdown_rx).await;
113/// }
114/// ```
115pub struct Flow<DataType, ErrorType, OutputType> {
116    pub(crate) source:
117        Option<Box<dyn Source<Input = (), Output = DataType, Error = ErrorType> + Send + Sync>>,
118    pub(crate) stages: Vec<Stage<DataType, ErrorType, OutputType>>,
119}
120
121impl<DataType, ErrorType, OutputType> Flow<DataType, ErrorType, OutputType>
122where
123    DataType: Clone + Send + Sync + 'static,
124    OutputType: Send + Sync + 'static,
125    ErrorType: Error + Send + Sync + Clone + 'static + From<FlowError>,
126{
127    /// Creates a new empty Flow.
128    ///
129    /// # Returns
130    ///
131    /// A new instance of `Flow` with no source or stages configured
132    #[must_use]
133    pub fn new() -> Self {
134        Self {
135            source: None,
136            stages: Vec::new(),
137        }
138    }
139
140    /// Sets the data source for the flow.
141    ///
142    /// # Arguments
143    ///
144    /// * `source` - The source that will produce data for the flow
145    ///
146    /// # Returns
147    ///
148    /// The flow builder for method chaining
149    #[must_use]
150    pub fn source<SourceType>(mut self, source: SourceType) -> Self
151    where
152        SourceType:
153            Source<Input = (), Output = DataType, Error = ErrorType> + Send + Sync + 'static,
154    {
155        self.source = Some(Box::new(source));
156        self
157    }
158
159    /// Adds a processor stage to the flow.
160    ///
161    /// # Arguments
162    ///
163    /// * `processor` - The processor to add to the flow
164    ///
165    /// # Returns
166    ///
167    /// The flow builder for method chaining
168    #[must_use]
169    pub fn process<ProcessorType>(mut self, processor: ProcessorType) -> Self
170    where
171        ProcessorType: Processor<Input = DataType, Output = DataType, Error = ErrorType>
172            + Send
173            + Sync
174            + 'static,
175    {
176        self.stages.push(Stage::Process(Box::new(processor)));
177        self
178    }
179
180    /// Starts building a conditional branch in the flow.
181    ///
182    /// # Arguments
183    ///
184    /// * `condition` - The condition that determines which branch to take
185    ///
186    /// # Returns
187    ///
188    /// A `BranchBuilder` for constructing the conditional branches
189    #[must_use]
190    pub fn when<ConditionType>(
191        self,
192        condition: ConditionType,
193    ) -> BranchBuilder<DataType, OutputType, ErrorType>
194    where
195        ConditionType: Condition<Input = DataType, Output = OutputType, Error = ErrorType>
196            + Send
197            + Sync
198            + 'static,
199    {
200        BranchBuilder::new(Box::new(condition), self)
201    }
202
203    /// Executes the flow asynchronously, processing data from the source through all stages.
204    ///
205    /// # Arguments
206    ///
207    /// * `shutdown` - A broadcast receiver for graceful shutdown signaling
208    ///
209    /// # Returns
210    ///
211    /// A Result containing either a vector of processed data items or an error
212    ///
213    /// # Errors
214    ///
215    /// Returns an error if:
216    /// * The flow source is not set
217    /// * Any stage in the flow returns an error during processing
218    /// * The task execution fails
219    #[instrument(skip(self))]
220    pub async fn run_stream(
221        mut self,
222        shutdown: broadcast::Receiver<()>,
223    ) -> Result<Vec<DataType>, ErrorType> {
224        info!("Starting flow execution");
225        let source = self.source.take().ok_or_else(|| {
226            error!("Flow source not set");
227            ErrorType::from(FlowError::NoSource)
228        })?;
229
230        debug!("Initializing source stream");
231        let source_output = source.stream().await?;
232        let receiver = source_output.receiver;
233        let feedback = source_output.feedback;
234        let mut results = Vec::new();
235        let mut shutdown_rx = shutdown;
236        let mut handles = Vec::new();
237
238        let stages = Arc::new(self.stages);
239        info!("Starting message processing loop");
240
241        loop {
242            tokio::select! {
243                _ = shutdown_rx.recv() => {
244                    warn!("Received shutdown signal");
245                    break;
246                }
247                item = receiver.recv_async() => {
248                    if let Ok(item) = item {
249                        let feedback = feedback.clone();
250                        let stages = Arc::clone(&stages);
251
252                        debug!("Spawning task for data processing");
253                        let handle = tokio::spawn(async move {
254                            let mut current_item = match item {
255                                Ok(data) => {
256                                    debug!("Processing new item");
257                                    data
258                                },
259                                Err(e) => {
260                                    error!("Source error: {:?}", e);
261                                    let _ = feedback.send(Err(e.clone()));
262                                    return Err(e);
263                                }
264                            };
265
266                            for stage in stages.iter() {
267                                match stage {
268                                    Stage::Process(processor) => {
269                                        debug!("Executing processor stage");
270                                        current_item = match processor.process(current_item).await {
271                                            Ok(data) => data,
272                                            Err(e) => {
273                                                error!("Processor error: {:?}", e);
274                                                let _ = feedback.send(Err(e.clone()));
275                                                return Err(e);
276                                            }
277                                        };
278                                    }
279                                    Stage::Branch(branch) => {
280                                        debug!("Evaluating branch condition");
281                                        let (condition_met, _) = match branch.condition.evaluate(current_item.clone()).await {
282                                            Ok(result) => result,
283                                            Err(e) => {
284                                                error!("Branch condition error: {:?}", e);
285                                                let _ = feedback.send(Err(e.clone()));
286                                                return Err(e);
287                                            }
288                                        };
289
290                                        let stages = if condition_met {
291                                            debug!("Taking then branch");
292                                            &branch.then_branch
293                                        } else {
294                                            debug!("Taking else branch");
295                                            &branch.else_branch
296                                        };
297
298                                        for stage in stages {
299                                            if let Stage::Process(processor) = stage {
300                                                current_item = match processor.process(current_item).await {
301                                                    Ok(data) => data,
302                                                    Err(e) => {
303                                                        error!("Branch processor error: {:?}", e);
304                                                        let _ = feedback.send(Err(e.clone()));
305                                                        return Err(e);
306                                                    }
307                                                };
308                                            }
309                                        }
310                                    }
311                                }
312                            }
313                            debug!("Data processing completed successfully");
314                            let _ = feedback.send(Ok(current_item.clone()));
315                            Ok(current_item)
316                        });
317                        handles.push(handle);
318                    } else {
319                        debug!("Source channel closed");
320                        break;
321                    }
322                }
323            }
324        }
325
326        debug!("Collecting results from all tasks");
327        for handle in handles {
328            match handle.await {
329                Ok(Ok(result)) => results.push(result),
330                Ok(Err(e)) => {
331                    error!("Task error: {:?}", e);
332                    return Err(e);
333                }
334                Err(e) => {
335                    error!("Task join error: {:?}", e);
336                    return Err(ErrorType::from(FlowError::Custom(e.to_string())));
337                }
338            }
339        }
340
341        debug!("Flow execution completed successfully");
342        Ok(results)
343    }
344}
345
346impl<DataType, ErrorType, OutputType> Default for Flow<DataType, ErrorType, OutputType>
347where
348    DataType: Clone + Send + Sync + 'static,
349    OutputType: Send + Sync + 'static,
350    ErrorType: Error + Send + Sync + Clone + 'static + From<FlowError>,
351{
352    fn default() -> Self {
353        Self::new()
354    }
355}