// fluxion_exec/subscribe.rs

// Copyright 2025 Umberto Gotti <umberto.gotti@umbertogotti.dev>
// Licensed under the Apache License, Version 2.0
// http://www.apache.org/licenses/LICENSE-2.0

use async_trait::async_trait;
use futures::stream::Stream;
use futures::stream::StreamExt;
use std::error::Error;
use std::fmt::Debug;
use std::future::Future;
use tokio::sync::mpsc::unbounded_channel;
use tokio_util::sync::CancellationToken;

use fluxion_core::{FluxionError, Result};

16/// Extension trait providing async subscription capabilities for streams.
17///
18/// This trait enables processing stream items with async handlers in a sequential manner.
19#[async_trait]
20pub trait SubscribeExt<T>: Stream<Item = T> + Sized {
21    /// Subscribes to the stream with an async handler, processing items sequentially.
22    ///
23    /// This method consumes the stream and spawns async tasks to process each item.
24    /// Items are processed in the order they arrive, with each item's handler running
25    /// to completion before the next item is processed (though handlers run concurrently
26    /// via tokio spawn).
27    ///
28    /// # Behavior
29    ///
30    /// - Processes each stream item with the provided async handler
31    /// - Spawns a new task for each item (non-blocking)
32    /// - Continues until stream ends or cancellation token is triggered
33    /// - Errors from handlers are passed to the error callback if provided
34    /// - If no error callback provided, errors are collected and returned on completion
35    ///
36    /// # Arguments
37    ///
38    /// * `on_next_func` - Async function called for each stream item. Receives the item
39    ///                    and a cancellation token. Returns `Result<(), E>`.
40    /// * `cancellation_token` - Optional token to stop processing. If `None`, a default
41    ///                          token is created that never cancels.
42    /// * `on_error_callback` - Optional error handler called when `on_next_func` returns
43    ///                         an error. If `None`, errors are collected and returned.
44    ///
45    /// # Type Parameters
46    ///
47    /// * `F` - Function type for the item handler
48    /// * `Fut` - Future type returned by the handler
49    /// * `E` - Error type that implements `std::error::Error`
50    /// * `OnError` - Function type for error handling
51    ///
52    /// # Errors
53    ///
54    /// Returns `Err(FluxionError::MultipleErrors)` if any items failed to process and
55    /// no error callback was provided. If an error callback is provided, errors are
56    /// passed to it and the function returns `Ok(())` on stream completion.
57    ///
58    /// The subscription continues processing subsequent items even if individual items
59    /// fail, unless the cancellation token is triggered.
60    ///
61    /// # See Also
62    ///
63    /// - [`subscribe_latest`](crate::SubscribeLatestExt::subscribe_latest) - Cancels old work for new items
64    ///
65    /// # Examples
66    ///
67    /// ## Basic Usage
68    ///
69    /// Process all items sequentially:
70    ///
71    /// ```
72    /// use fluxion_exec::SubscribeExt;
73    /// use futures::stream;
74    /// use std::sync::Arc;
75    /// use tokio::sync::Mutex;
76    /// use tokio::sync::mpsc::unbounded_channel;
77    ///
78    /// # #[tokio::main]
79    /// # async fn main() {
80    /// let results = Arc::new(Mutex::new(Vec::new()));
81    /// let results_clone = results.clone();
82    /// let (notify_tx, mut notify_rx) = unbounded_channel();
83    ///
84    /// let stream = stream::iter(vec![1, 2, 3, 4, 5]);
85    ///
86    /// // Subscribe and process each item
87    /// stream.subscribe(
88    ///     move |item, _token| {
89    ///         let results = results_clone.clone();
90    ///         let notify_tx = notify_tx.clone();
91    ///         async move {
92    ///             results.lock().await.push(item * 2);
93    ///             let _ = notify_tx.send(());
94    ///             Ok::<(), std::io::Error>(())
95    ///         }
96    ///     },
97    ///     None, // No cancellation
98    ///     None::<fn(std::io::Error)>  // No error callback
99    /// ).await.unwrap();
100    ///
101    /// // Wait for all 5 items to be processed
102    /// for _ in 0..5 {
103    ///     notify_rx.recv().await.unwrap();
104    /// }
105    ///
106    /// let processed = results.lock().await;
107    /// assert!(processed.contains(&2));
108    /// assert!(processed.contains(&4));
109    /// # }
110    /// ```
111    ///
112    /// ## With Error Handling
113    ///
114    /// Use an error callback to handle errors without stopping the stream:
115    ///
116    /// ```
117    /// use fluxion_exec::SubscribeExt;
118    /// use futures::stream;
119    /// use std::sync::Arc;
120    /// use tokio::sync::Mutex;
121    /// use tokio::sync::mpsc::unbounded_channel;
122    ///
123    /// # #[tokio::main]
124    /// # async fn main() {
125    /// #[derive(Debug)]
126    /// struct MyError(String);
127    /// impl std::fmt::Display for MyError {
128    ///     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
129    ///         write!(f, "MyError: {}", self.0)
130    ///     }
131    /// }
132    /// impl std::error::Error for MyError {}
133    ///
134    /// let error_count = Arc::new(Mutex::new(0));
135    /// let error_count_clone = error_count.clone();
136    /// let (notify_tx, mut notify_rx) = unbounded_channel();
137    ///
138    /// let stream = stream::iter(vec![1, 2, 3, 4, 5]);
139    ///
140    /// stream.subscribe(
141    ///     move |item, _token| {
142    ///         let notify_tx = notify_tx.clone();
143    ///         async move {
144    ///             let res = if item % 2 == 0 {
145    ///                 Err(MyError(format!("Even number: {}", item)))
146    ///             } else {
147    ///                 Ok(())
148    ///             };
149    ///             // Signal completion regardless of success/failure
150    ///             // Note: In real code, you might signal in the error callback too
151    ///             // but here we just want to know the handler finished.
152    ///             // However, subscribe spawns the handler. If it errors,
153    ///             // the error callback is called.
154    ///             // We need to signal completion in both paths.
155    ///             // Since the handler returns the error, we can't signal *after* returning Err.
156    ///             // So we signal before returning.
157    ///             let _ = notify_tx.send(());
158    ///             res
159    ///         }
160    ///     },
161    ///     None,
162    ///     Some(move |_err| {
163    ///         let count = error_count_clone.clone();
164    ///         tokio::spawn(async move {
165    ///             *count.lock().await += 1;
166    ///         });
167    ///     })
168    /// ).await.unwrap();
169    ///
170    /// // Wait for 5 items
171    /// for _ in 0..5 {
172    ///     notify_rx.recv().await.unwrap();
173    /// }
174    ///
175    /// // Give a tiny bit of time for the error callback spawn to finish updating the count
176    /// // (Since the callback spawns another task)
177    /// // Alternatively, we could use a channel in the error callback too.
178    /// tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
179    ///
180    /// assert_eq!(*error_count.lock().await, 2); // Items 2 and 4 errored
181    /// # }
182    /// ```
183    ///
184    /// ## With Cancellation
185    ///
186    /// Use a cancellation token to stop processing:
187    ///
188    /// ```
189    /// use fluxion_exec::SubscribeExt;
190    /// use tokio::sync::mpsc::unbounded_channel;
191    /// use tokio_stream::wrappers::UnboundedReceiverStream;
192    /// use futures::StreamExt;
193    /// use tokio_util::sync::CancellationToken;
194    /// use std::sync::Arc;
195    /// use tokio::sync::Mutex;
196    ///
197    /// # #[tokio::main]
198    /// # async fn main() {
199    /// let (tx, rx) = unbounded_channel();
200    /// let stream = UnboundedReceiverStream::new(rx);
201    ///
202    /// let cancel_token = CancellationToken::new();
203    /// let cancel_clone = cancel_token.clone();
204    ///
205    /// let processed = Arc::new(Mutex::new(Vec::new()));
206    /// let processed_clone = processed.clone();
207    /// let (notify_tx, mut notify_rx) = unbounded_channel();
208    ///
209    /// let handle = tokio::spawn(async move {
210    ///     stream.subscribe(
211    ///         move |item, token| {
212    ///             let vec = processed_clone.clone();
213    ///             let notify_tx = notify_tx.clone();
214    ///             async move {
215    ///                 if token.is_cancelled() {
216    ///                     return Ok(());
217    ///                 }
218    ///                 vec.lock().await.push(item);
219    ///                 let _ = notify_tx.send(());
220    ///                 Ok::<(), std::io::Error>(())
221    ///             }
222    ///         },
223    ///         Some(cancel_token),
224    ///         None::<fn(std::io::Error)>
225    ///     ).await
226    /// });
227    ///
228    /// // Send items
229    /// tx.send(1).unwrap();
230    /// tx.send(2).unwrap();
231    /// tx.send(3).unwrap();
232    ///
233    /// // Wait for first item to be processed
234    /// notify_rx.recv().await.unwrap();
235    ///
236    /// // Cancel now
237    /// cancel_clone.cancel();
238    /// drop(tx);
239    ///
240    /// handle.await.unwrap().unwrap();
241    ///
242    /// // At least one item should be processed before cancellation
243    /// assert!(!processed.lock().await.is_empty());
244    /// # }
245    /// ```
246    ///
247    /// ## Database Write Pattern
248    ///
249    /// Process events and persist to a database:
250    ///
251    /// ```
252    /// use fluxion_exec::SubscribeExt;
253    /// use futures::stream;
254    /// use std::sync::Arc;
255    /// use tokio::sync::Mutex;
256    /// use tokio::sync::mpsc::unbounded_channel;
257    ///
258    /// # #[tokio::main]
259    /// # async fn main() {
260    /// #[derive(Clone, Debug)]
261    /// struct Event { id: u32, data: String }
262    ///
263    /// // Simulated database
264    /// let db = Arc::new(Mutex::new(Vec::new()));
265    /// let db_clone = db.clone();
266    /// let (notify_tx, mut notify_rx) = unbounded_channel();
267    ///
268    /// let events = vec![
269    ///     Event { id: 1, data: "event1".to_string() },
270    ///     Event { id: 2, data: "event2".to_string() },
271    /// ];
272    ///
273    /// let stream = stream::iter(events);
274    ///
275    /// stream.subscribe(
276    ///     move |event, _token| {
277    ///         let db = db_clone.clone();
278    ///         let notify_tx = notify_tx.clone();
279    ///         async move {
280    ///             // Simulate database write
281    ///             db.lock().await.push(event);
282    ///             let _ = notify_tx.send(());
283    ///             Ok::<(), std::io::Error>(())
284    ///         }
285    ///     },
286    ///     None,
287    ///     Some(|err| eprintln!("DB Error: {}", err))
288    /// ).await.unwrap();
289    ///
290    /// // Wait for 2 events
291    /// notify_rx.recv().await.unwrap();
292    /// notify_rx.recv().await.unwrap();
293    ///
294    /// assert_eq!(db.lock().await.len(), 2);
295    /// # }
296    /// ```
297    ///
298    /// # Thread Safety
299    ///
300    /// All spawned tasks run on the tokio runtime. The subscription completes
301    /// when the stream ends, not when all spawned tasks complete.
302    async fn subscribe<F, Fut, E, OnError>(
303        self,
304        on_next_func: F,
305        cancellation_token: Option<CancellationToken>,
306        on_error_callback: Option<OnError>,
307    ) -> Result<()>
308    where
309        F: Fn(T, CancellationToken) -> Fut + Clone + Send + Sync + 'static,
310        Fut: Future<Output = std::result::Result<(), E>> + Send + 'static,
311        OnError: Fn(E) + Clone + Send + Sync + 'static,
312        T: Debug + Send + Clone + 'static,
313        E: Error + Send + Sync + 'static;
314}
315
316#[async_trait]
317impl<S, T> SubscribeExt<T> for S
318where
319    S: Stream<Item = T> + Send + Unpin + 'static,
320    T: Send + 'static,
321{
322    async fn subscribe<F, Fut, E, OnError>(
323        mut self,
324        on_next_func: F,
325        cancellation_token: Option<CancellationToken>,
326        on_error_callback: Option<OnError>,
327    ) -> Result<()>
328    where
329        F: Fn(T, CancellationToken) -> Fut + Clone + Send + Sync + 'static,
330        Fut: Future<Output = std::result::Result<(), E>> + Send + 'static,
331        OnError: Fn(E) + Clone + Send + Sync + 'static,
332        T: Debug + Send + Clone + 'static,
333        E: Error + Send + Sync + 'static,
334    {
335        let cancellation_token = cancellation_token.unwrap_or_default();
336        let (error_tx, mut error_rx) = unbounded_channel();
337
338        while let Some(item) = self.next().await {
339            if cancellation_token.is_cancelled() {
340                break;
341            }
342
343            let on_next_func = on_next_func.clone();
344            let cancellation_token = cancellation_token.clone();
345            let on_error_callback = on_error_callback.clone();
346            let error_tx = error_tx.clone();
347
348            tokio::spawn(async move {
349                let result = on_next_func(item.clone(), cancellation_token).await;
350
351                if let Err(error) = result {
352                    if let Some(on_error_callback) = on_error_callback {
353                        on_error_callback(error);
354                    } else {
355                        // Collect error for later aggregation
356                        let _ = error_tx.send(error);
357                    }
358                }
359            });
360        }
361
362        // Drop the original sender so the channel closes
363        drop(error_tx);
364
365        // Collect all errors from the channel
366        let mut collected_errors = Vec::new();
367        while let Some(error) = error_rx.recv().await {
368            collected_errors.push(error);
369        }
370
371        if !collected_errors.is_empty() {
372            Err(FluxionError::from_user_errors(collected_errors))
373        } else {
374            Ok(())
375        }
376    }
377}