cortex_ai/composer/flow.rs
1use super::BranchBuilder;
2use crate::{
3 flow::{condition::Condition, processor::Processor, source::Source, stage::Stage},
4 FlowError,
5};
6use std::{error::Error, sync::Arc};
7use tokio::sync::broadcast;
8use tracing::{debug, error, info, instrument, warn};
9
/// A builder for constructing and executing data processing flows.
///
/// `Flow` represents a sequence of processing stages that data flows through. It supports
/// various operations including data transformation, conditional branching, and async execution.
///
/// A flow is assembled with the builder methods (`source`, `process`, `when`) and then
/// consumed by `run_stream`, which reads items from the source and runs each one through
/// the configured stages.
///
/// # Type Parameters
///
/// * `DataType` - The type of data flowing through the pipeline
/// * `ErrorType` - The error type that can be produced during processing
/// * `OutputType` - The output type produced by conditions in branches
///
/// # Examples
///
/// ```
/// use cortex_ai::composer::Flow;
/// use cortex_ai::flow::source::Source;
/// use cortex_ai::flow::types::SourceOutput;
/// use cortex_ai::flow::condition::Condition;
/// use cortex_ai::flow::processor::Processor;
/// use cortex_ai::FlowComponent;
/// use cortex_ai::FlowError;
/// use std::error::Error;
/// use std::fmt;
/// use std::pin::Pin;
/// use std::future::Future;
/// use flume::{Sender, Receiver};
///
/// #[derive(Clone, Debug)]
/// struct MyData(String);
///
/// #[derive(Clone, Debug)]
/// struct MyError;
///
/// impl fmt::Display for MyError {
///     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
///         write!(f, "MyError")
///     }
/// }
///
/// impl Error for MyError {}
///
/// impl From<FlowError> for MyError {
///     fn from(e: FlowError) -> Self { MyError }
/// }
///
/// struct MySource;
///
/// impl FlowComponent for MySource {
///     type Input = ();
///     type Output = MyData;
///     type Error = MyError;
/// }
///
/// impl Source for MySource {
///     fn stream(&self) -> Pin<Box<dyn Future<Output = Result<SourceOutput<Self::Output, Self::Error>, Self::Error>> + Send>> {
///         Box::pin(async move {
///             // `tx` is dropped when this block ends, so the stream is already closed
///             // when the flow starts reading and the flow finishes without any items.
///             let (tx, rx) = flume::bounded(10);
///             let (feedback_tx, _) = flume::bounded(10);
///             Ok(SourceOutput { receiver: rx, feedback: feedback_tx })
///         })
///     }
/// }
///
/// struct MyProcessor;
/// impl FlowComponent for MyProcessor {
///     type Input = MyData;
///     type Output = MyData;
///     type Error = MyError;
/// }
///
/// impl Processor for MyProcessor {
///     fn process(&self, input: Self::Input) -> Pin<Box<dyn Future<Output = Result<Self::Output, Self::Error>> + Send>> {
///         Box::pin(async move { Ok(input) })
///     }
/// }
///
/// struct MyCondition;
/// impl FlowComponent for MyCondition {
///     type Input = MyData;
///     type Output = bool;
///     type Error = MyError;
/// }
///
/// impl Condition for MyCondition {
///     fn evaluate(&self, input: Self::Input) -> Pin<Box<dyn Future<Output = Result<(bool, Option<Self::Output>), Self::Error>> + Send>> {
///         Box::pin(async move { Ok((true, Some(false))) })
///     }
/// }
///
/// #[tokio::main]
/// async fn main() {
///     // Keep the sender alive for as long as the flow should run: once every sender
///     // is dropped, `recv` on the receiver completes immediately with a `Closed`
///     // error, and `run_stream` treats any result from `recv` as a shutdown request.
///     let (_shutdown_tx, shutdown_rx) = tokio::sync::broadcast::channel::<()>(1);
///
///     let flow = Flow::<MyData, MyError, bool>::new()
///         .source(MySource)
///         .process(MyProcessor)
///         .when(MyCondition)
///         .process(MyProcessor)
///         .otherwise()
///         .process(MyProcessor)
///         .end();
///
///     let _ = flow.run_stream(shutdown_rx).await;
/// }
/// ```
pub struct Flow<DataType, ErrorType, OutputType> {
    /// Source that feeds items into the flow; `None` until `source` has been called.
    pub(crate) source:
        Option<Box<dyn Source<Input = (), Output = DataType, Error = ErrorType> + Send + Sync>>,
    /// Stages applied to every item, in the order they are stored.
    pub(crate) stages: Vec<Stage<DataType, ErrorType, OutputType>>,
}
120
121impl<DataType, ErrorType, OutputType> Flow<DataType, ErrorType, OutputType>
122where
123 DataType: Clone + Send + Sync + 'static,
124 OutputType: Send + Sync + 'static,
125 ErrorType: Error + Send + Sync + Clone + 'static + From<FlowError>,
126{
127 /// Creates a new empty Flow.
128 ///
129 /// # Returns
130 ///
131 /// A new instance of `Flow` with no source or stages configured
132 #[must_use]
133 pub fn new() -> Self {
134 Self {
135 source: None,
136 stages: Vec::new(),
137 }
138 }
139
140 /// Sets the data source for the flow.
141 ///
142 /// # Arguments
143 ///
144 /// * `source` - The source that will produce data for the flow
145 ///
146 /// # Returns
147 ///
148 /// The flow builder for method chaining
149 #[must_use]
150 pub fn source<SourceType>(mut self, source: SourceType) -> Self
151 where
152 SourceType:
153 Source<Input = (), Output = DataType, Error = ErrorType> + Send + Sync + 'static,
154 {
155 self.source = Some(Box::new(source));
156 self
157 }
158
159 /// Adds a processor stage to the flow.
160 ///
161 /// # Arguments
162 ///
163 /// * `processor` - The processor to add to the flow
164 ///
165 /// # Returns
166 ///
167 /// The flow builder for method chaining
168 #[must_use]
169 pub fn process<ProcessorType>(mut self, processor: ProcessorType) -> Self
170 where
171 ProcessorType: Processor<Input = DataType, Output = DataType, Error = ErrorType>
172 + Send
173 + Sync
174 + 'static,
175 {
176 self.stages.push(Stage::Process(Box::new(processor)));
177 self
178 }
179
180 /// Starts building a conditional branch in the flow.
181 ///
182 /// # Arguments
183 ///
184 /// * `condition` - The condition that determines which branch to take
185 ///
186 /// # Returns
187 ///
188 /// A `BranchBuilder` for constructing the conditional branches
189 #[must_use]
190 pub fn when<ConditionType>(
191 self,
192 condition: ConditionType,
193 ) -> BranchBuilder<DataType, OutputType, ErrorType>
194 where
195 ConditionType: Condition<Input = DataType, Output = OutputType, Error = ErrorType>
196 + Send
197 + Sync
198 + 'static,
199 {
200 BranchBuilder::new(Box::new(condition), self)
201 }
202
203 /// Executes the flow asynchronously, processing data from the source through all stages.
204 ///
205 /// # Arguments
206 ///
207 /// * `shutdown` - A broadcast receiver for graceful shutdown signaling
208 ///
209 /// # Returns
210 ///
211 /// A Result containing either a vector of processed data items or an error
212 ///
213 /// # Errors
214 ///
215 /// Returns an error if:
216 /// * The flow source is not set
217 /// * Any stage in the flow returns an error during processing
218 /// * The task execution fails
219 #[instrument(skip(self))]
220 pub async fn run_stream(
221 mut self,
222 shutdown: broadcast::Receiver<()>,
223 ) -> Result<Vec<DataType>, ErrorType> {
224 info!("Starting flow execution");
225 let source = self.source.take().ok_or_else(|| {
226 error!("Flow source not set");
227 ErrorType::from(FlowError::NoSource)
228 })?;
229
230 debug!("Initializing source stream");
231 let source_output = source.stream().await?;
232 let receiver = source_output.receiver;
233 let feedback = source_output.feedback;
234 let mut results = Vec::new();
235 let mut shutdown_rx = shutdown;
236 let mut handles = Vec::new();
237
238 let stages = Arc::new(self.stages);
239 info!("Starting message processing loop");
240
241 loop {
242 tokio::select! {
243 _ = shutdown_rx.recv() => {
244 warn!("Received shutdown signal");
245 break;
246 }
247 item = receiver.recv_async() => {
248 if let Ok(item) = item {
249 let feedback = feedback.clone();
250 let stages = Arc::clone(&stages);
251
252 debug!("Spawning task for data processing");
253 let handle = tokio::spawn(async move {
254 let mut current_item = match item {
255 Ok(data) => {
256 debug!("Processing new item");
257 data
258 },
259 Err(e) => {
260 error!("Source error: {:?}", e);
261 let _ = feedback.send(Err(e.clone()));
262 return Err(e);
263 }
264 };
265
266 for stage in stages.iter() {
267 match stage {
268 Stage::Process(processor) => {
269 debug!("Executing processor stage");
270 current_item = match processor.process(current_item).await {
271 Ok(data) => data,
272 Err(e) => {
273 error!("Processor error: {:?}", e);
274 let _ = feedback.send(Err(e.clone()));
275 return Err(e);
276 }
277 };
278 }
279 Stage::Branch(branch) => {
280 debug!("Evaluating branch condition");
281 let (condition_met, _) = match branch.condition.evaluate(current_item.clone()).await {
282 Ok(result) => result,
283 Err(e) => {
284 error!("Branch condition error: {:?}", e);
285 let _ = feedback.send(Err(e.clone()));
286 return Err(e);
287 }
288 };
289
290 let stages = if condition_met {
291 debug!("Taking then branch");
292 &branch.then_branch
293 } else {
294 debug!("Taking else branch");
295 &branch.else_branch
296 };
297
298 for stage in stages {
299 if let Stage::Process(processor) = stage {
300 current_item = match processor.process(current_item).await {
301 Ok(data) => data,
302 Err(e) => {
303 error!("Branch processor error: {:?}", e);
304 let _ = feedback.send(Err(e.clone()));
305 return Err(e);
306 }
307 };
308 }
309 }
310 }
311 }
312 }
313 debug!("Data processing completed successfully");
314 let _ = feedback.send(Ok(current_item.clone()));
315 Ok(current_item)
316 });
317 handles.push(handle);
318 } else {
319 debug!("Source channel closed");
320 break;
321 }
322 }
323 }
324 }
325
326 debug!("Collecting results from all tasks");
327 for handle in handles {
328 match handle.await {
329 Ok(Ok(result)) => results.push(result),
330 Ok(Err(e)) => {
331 error!("Task error: {:?}", e);
332 return Err(e);
333 }
334 Err(e) => {
335 error!("Task join error: {:?}", e);
336 return Err(ErrorType::from(FlowError::Custom(e.to_string())));
337 }
338 }
339 }
340
341 debug!("Flow execution completed successfully");
342 Ok(results)
343 }
344}
345
346impl<DataType, ErrorType, OutputType> Default for Flow<DataType, ErrorType, OutputType>
347where
348 DataType: Clone + Send + Sync + 'static,
349 OutputType: Send + Sync + 'static,
350 ErrorType: Error + Send + Sync + Clone + 'static + From<FlowError>,
351{
352 fn default() -> Self {
353 Self::new()
354 }
355}