fluxion_exec/
subscribe_async.rs

// Copyright 2025 Umberto Gotti <umberto.gotti@umbertogotti.dev>
// Licensed under the Apache License, Version 2.0
// http://www.apache.org/licenses/LICENSE-2.0
4
5use async_trait::async_trait;
6use futures::stream::Stream;
7use futures::stream::StreamExt;
8use std::error::Error;
9use std::fmt::Debug;
10use std::future::Future;
11use tokio::sync::mpsc::unbounded_channel;
12use tokio_util::sync::CancellationToken;
13
14use fluxion_core::{FluxionError, Result};
15
16/// Extension trait providing async subscription capabilities for streams.
17///
18/// This trait enables processing stream items with async handlers in a sequential manner.
19#[async_trait]
20pub trait SubscribeAsyncExt<T>: Stream<Item = T> + Sized {
21    /// Subscribes to the stream with an async handler, processing items sequentially.
22    ///
23    /// This method consumes the stream and spawns async tasks to process each item.
24    /// Items are processed in the order they arrive, with each item's handler running
25    /// to completion before the next item is processed (though handlers run concurrently
26    /// via tokio spawn).
27    ///
28    /// # Behavior
29    ///
30    /// - Processes each stream item with the provided async handler
31    /// - Spawns a new task for each item (non-blocking)
32    /// - Continues until stream ends or cancellation token is triggered
33    /// - Errors from handlers are passed to the error callback if provided
34    /// - If no error callback provided, errors are collected and returned on completion
35    ///
36    /// # Arguments
37    ///
38    /// * `on_next_func` - Async function called for each stream item. Receives the item
39    ///                    and a cancellation token. Returns `Result<(), E>`.
40    /// * `cancellation_token` - Optional token to stop processing. If `None`, a default
41    ///                          token is created that never cancels.
42    /// * `on_error_callback` - Optional error handler called when `on_next_func` returns
43    ///                         an error. If `None`, errors are collected and returned.
44    ///
45    /// # Type Parameters
46    ///
47    /// * `F` - Function type for the item handler
48    /// * `Fut` - Future type returned by the handler
49    /// * `E` - Error type that implements `std::error::Error`
50    /// * `OnError` - Function type for error handling
51    ///
52    /// # Errors
53    ///
54    /// Returns `Err(FluxionError::MultipleErrors)` if any items failed to process and
55    /// no error callback was provided. If an error callback is provided, errors are
56    /// passed to it and the function returns `Ok(())` on stream completion.
57    ///
58    /// The subscription continues processing subsequent items even if individual items
59    /// fail, unless the cancellation token is triggered.
60    ///
61    /// # See Also
62    ///
63    /// - [`subscribe_latest_async`](crate::SubscribeLatestAsyncExt::subscribe_latest_async) - Cancels old work for new items
64    ///
65    /// # Examples
66    ///
67    /// ## Basic Usage
68    ///
69    /// Process all items sequentially:
70    ///
71    /// ```
72    /// use fluxion_exec::SubscribeAsyncExt;
73    /// use futures::stream;
74    /// use std::sync::Arc;
75    /// use tokio::sync::Mutex;
76    ///
77    /// # #[tokio::main]
78    /// # async fn main() {
79    /// let results = Arc::new(Mutex::new(Vec::new()));
80    /// let results_clone = results.clone();
81    ///
82    /// let stream = stream::iter(vec![1, 2, 3, 4, 5]);
83    ///
84    /// // Subscribe and process each item
85    /// stream.subscribe_async(
86    ///     move |item, _token| {
87    ///         let results = results_clone.clone();
88    ///         async move {
89    ///             results.lock().await.push(item * 2);
90    ///             Ok::<(), std::io::Error>(())
91    ///         }
92    ///     },
93    ///     None, // No cancellation
94    ///     None::<fn(std::io::Error)>  // No error callback
95    /// ).await.unwrap();
96    ///
97    /// // Wait a bit for spawned tasks to complete
98    /// tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
99    ///
100    /// let processed = results.lock().await;
101    /// assert!(processed.contains(&2));
102    /// assert!(processed.contains(&4));
103    /// # }
104    /// ```
105    ///
106    /// ## With Error Handling
107    ///
108    /// Use an error callback to handle errors without stopping the stream:
109    ///
110    /// ```
111    /// use fluxion_exec::SubscribeAsyncExt;
112    /// use futures::stream;
113    /// use std::sync::Arc;
114    /// use tokio::sync::Mutex;
115    ///
116    /// # #[tokio::main]
117    /// # async fn main() {
118    /// #[derive(Debug)]
119    /// struct MyError(String);
120    /// impl std::fmt::Display for MyError {
121    ///     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
122    ///         write!(f, "MyError: {}", self.0)
123    ///     }
124    /// }
125    /// impl std::error::Error for MyError {}
126    ///
127    /// let error_count = Arc::new(Mutex::new(0));
128    /// let error_count_clone = error_count.clone();
129    ///
130    /// let stream = stream::iter(vec![1, 2, 3, 4, 5]);
131    ///
132    /// stream.subscribe_async(
133    ///     |item, _token| async move {
134    ///         if item % 2 == 0 {
135    ///             Err(MyError(format!("Even number: {}", item)))
136    ///         } else {
137    ///             Ok(())
138    ///         }
139    ///     },
140    ///     None,
141    ///     Some(move |_err| {
142    ///         let count = error_count_clone.clone();
143    ///         tokio::spawn(async move {
144    ///             *count.lock().await += 1;
145    ///         });
146    ///     })
147    /// ).await.unwrap();
148    ///
149    /// tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
150    /// assert_eq!(*error_count.lock().await, 2); // Items 2 and 4 errored
151    /// # }
152    /// ```
153    ///
154    /// ## With Cancellation
155    ///
156    /// Use a cancellation token to stop processing:
157    ///
158    /// ```
159    /// use fluxion_exec::SubscribeAsyncExt;
160    /// use tokio::sync::mpsc::unbounded_channel;
161    /// use tokio_stream::wrappers::UnboundedReceiverStream;
162    /// use futures::StreamExt;
163    /// use tokio_util::sync::CancellationToken;
164    /// use std::sync::Arc;
165    /// use tokio::sync::Mutex;
166    ///
167    /// # #[tokio::main]
168    /// # async fn main() {
169    /// let (tx, rx) = unbounded_channel();
170    /// let stream = UnboundedReceiverStream::new(rx);
171    ///
172    /// let cancel_token = CancellationToken::new();
173    /// let cancel_clone = cancel_token.clone();
174    ///
175    /// let processed = Arc::new(Mutex::new(Vec::new()));
176    /// let processed_clone = processed.clone();
177    ///
178    /// let handle = tokio::spawn(async move {
179    ///     stream.subscribe_async(
180    ///         move |item, token| {
181    ///             let vec = processed_clone.clone();
182    ///             async move {
183    ///                 if token.is_cancelled() {
184    ///                     return Ok(());
185    ///                 }
186    ///                 tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
187    ///                 vec.lock().await.push(item);
188    ///                 Ok::<(), std::io::Error>(())
189    ///             }
190    ///         },
191    ///         Some(cancel_token),
192    ///         None::<fn(std::io::Error)>
193    ///     ).await
194    /// });
195    ///
196    /// // Send a few items
197    /// tx.send(1).unwrap();
198    /// tx.send(2).unwrap();
199    /// tx.send(3).unwrap();
200    ///
201    /// // Wait a bit then cancel
202    /// tokio::time::sleep(tokio::time::Duration::from_millis(15)).await;
203    /// cancel_clone.cancel();
204    /// drop(tx);
205    ///
206    /// handle.await.unwrap().unwrap();
207    ///
208    /// // At least one item should be processed before cancellation
209    /// assert!(!processed.lock().await.is_empty());
210    /// # }
211    /// ```
212    ///
213    /// ## Database Write Pattern
214    ///
215    /// Process events and persist to a database:
216    ///
217    /// ```
218    /// use fluxion_exec::SubscribeAsyncExt;
219    /// use futures::stream;
220    /// use std::sync::Arc;
221    /// use tokio::sync::Mutex;
222    ///
223    /// # #[tokio::main]
224    /// # async fn main() {
225    /// #[derive(Clone, Debug)]
226    /// struct Event { id: u32, data: String }
227    ///
228    /// // Simulated database
229    /// let db = Arc::new(Mutex::new(Vec::new()));
230    /// let db_clone = db.clone();
231    ///
232    /// let events = vec![
233    ///     Event { id: 1, data: "event1".to_string() },
234    ///     Event { id: 2, data: "event2".to_string() },
235    /// ];
236    ///
237    /// let stream = stream::iter(events);
238    ///
239    /// stream.subscribe_async(
240    ///     move |event, _token| {
241    ///         let db = db_clone.clone();
242    ///         async move {
243    ///             // Simulate database write
244    ///             tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
245    ///             db.lock().await.push(event);
246    ///             Ok::<(), std::io::Error>(())
247    ///         }
248    ///     },
249    ///     None,
250    ///     Some(|err| eprintln!("DB Error: {}", err))
251    /// ).await.unwrap();
252    ///
253    /// tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
254    /// assert_eq!(db.lock().await.len(), 2);
255    /// # }
256    /// ```
257    ///
258    /// # Thread Safety
259    ///
260    /// All spawned tasks run on the tokio runtime. The subscription completes
261    /// when the stream ends, not when all spawned tasks complete.
262    async fn subscribe_async<F, Fut, E, OnError>(
263        self,
264        on_next_func: F,
265        cancellation_token: Option<CancellationToken>,
266        on_error_callback: Option<OnError>,
267    ) -> Result<()>
268    where
269        F: Fn(T, CancellationToken) -> Fut + Clone + Send + Sync + 'static,
270        Fut: Future<Output = std::result::Result<(), E>> + Send + 'static,
271        OnError: Fn(E) + Clone + Send + Sync + 'static,
272        T: Debug + Send + Clone + 'static,
273        E: Error + Send + Sync + 'static;
274}
275
276#[async_trait]
277impl<S, T> SubscribeAsyncExt<T> for S
278where
279    S: Stream<Item = T> + Send + Unpin + 'static,
280    T: Send + 'static,
281{
282    async fn subscribe_async<F, Fut, E, OnError>(
283        mut self,
284        on_next_func: F,
285        cancellation_token: Option<CancellationToken>,
286        on_error_callback: Option<OnError>,
287    ) -> Result<()>
288    where
289        F: Fn(T, CancellationToken) -> Fut + Clone + Send + Sync + 'static,
290        Fut: Future<Output = std::result::Result<(), E>> + Send + 'static,
291        OnError: Fn(E) + Clone + Send + Sync + 'static,
292        T: Debug + Send + Clone + 'static,
293        E: Error + Send + Sync + 'static,
294    {
295        let cancellation_token = cancellation_token.unwrap_or_default();
296        let (error_tx, mut error_rx) = unbounded_channel();
297
298        while let Some(item) = self.next().await {
299            if cancellation_token.is_cancelled() {
300                break;
301            }
302
303            let on_next_func = on_next_func.clone();
304            let cancellation_token = cancellation_token.clone();
305            let on_error_callback = on_error_callback.clone();
306            let error_tx = error_tx.clone();
307
308            tokio::spawn(async move {
309                let result = on_next_func(item.clone(), cancellation_token).await;
310
311                if let Err(error) = result {
312                    if let Some(on_error_callback) = on_error_callback {
313                        on_error_callback(error);
314                    } else {
315                        // Collect error for later aggregation
316                        let _ = error_tx.send(error);
317                    }
318                }
319            });
320        }
321
322        // Drop the original sender so the channel closes
323        drop(error_tx);
324
325        // Collect all errors from the channel
326        let mut collected_errors = Vec::new();
327        while let Some(error) = error_rx.recv().await {
328            collected_errors.push(error);
329        }
330
331        if !collected_errors.is_empty() {
332            Err(FluxionError::from_user_errors(collected_errors))
333        } else {
334            Ok(())
335        }
336    }
337}