fluxion_exec/
subscribe_async.rs

1// Copyright 2025 Umberto Gotti <umberto.gotti@umbertogotti.dev>
2// Licensed under the Apache License, Version 2.0
3// http://www.apache.org/licenses/LICENSE-2.0
4
5use async_trait::async_trait;
6use futures::stream::Stream;
7use futures::stream::StreamExt;
8use std::future::Future;
9use tokio::sync::mpsc::unbounded_channel;
10use tokio_util::sync::CancellationToken;
11
12use fluxion_core::{FluxionError, Result};
13
14/// Extension trait providing async subscription capabilities for streams.
15///
16/// This trait enables processing stream items with async handlers in a sequential manner.
17#[async_trait]
18pub trait SubscribeAsyncExt<T>: Stream<Item = T> + Sized {
19    /// Subscribes to the stream with an async handler, processing items sequentially.
20    ///
21    /// This method consumes the stream and spawns async tasks to process each item.
22    /// Items are processed in the order they arrive, with each item's handler running
23    /// to completion before the next item is processed (though handlers run concurrently
24    /// via tokio spawn).
25    ///
26    /// # Behavior
27    ///
28    /// - Processes each stream item with the provided async handler
29    /// - Spawns a new task for each item (non-blocking)
30    /// - Continues until stream ends or cancellation token is triggered
31    /// - Errors from handlers are passed to the error callback if provided
32    /// - If no error callback provided, errors are collected and returned on completion
33    ///
34    /// # Arguments
35    ///
36    /// * `on_next_func` - Async function called for each stream item. Receives the item
37    ///                    and a cancellation token. Returns `Result<(), E>`.
38    /// * `cancellation_token` - Optional token to stop processing. If `None`, a default
39    ///                          token is created that never cancels.
40    /// * `on_error_callback` - Optional error handler called when `on_next_func` returns
41    ///                         an error. If `None`, errors are collected and returned.
42    ///
43    /// # Type Parameters
44    ///
45    /// * `F` - Function type for the item handler
46    /// * `Fut` - Future type returned by the handler
47    /// * `E` - Error type that implements `std::error::Error`
48    /// * `OnError` - Function type for error handling
49    ///
50    /// # Errors
51    ///
52    /// Returns `Err(FluxionError::MultipleErrors)` if any items failed to process and
53    /// no error callback was provided. If an error callback is provided, errors are
54    /// passed to it and the function returns `Ok(())` on stream completion.
55    ///
56    /// The subscription continues processing subsequent items even if individual items
57    /// fail, unless the cancellation token is triggered.
58    ///
59    /// # See Also
60    ///
61    /// - [`subscribe_latest_async`](crate::SubscribeLatestAsyncExt::subscribe_latest_async) - Cancels old work for new items
62    ///
63    /// # Examples
64    ///
65    /// ## Basic Usage
66    ///
67    /// Process all items sequentially:
68    ///
69    /// ```
70    /// use fluxion_exec::SubscribeAsyncExt;
71    /// use futures::stream;
72    /// use std::sync::Arc;
73    /// use tokio::sync::Mutex;
74    ///
75    /// # #[tokio::main]
76    /// # async fn main() {
77    /// let results = Arc::new(Mutex::new(Vec::new()));
78    /// let results_clone = results.clone();
79    ///
80    /// let stream = stream::iter(vec![1, 2, 3, 4, 5]);
81    ///
82    /// // Subscribe and process each item
83    /// stream.subscribe_async(
84    ///     move |item, _token| {
85    ///         let results = results_clone.clone();
86    ///         async move {
87    ///             results.lock().await.push(item * 2);
88    ///             Ok::<(), std::io::Error>(())
89    ///         }
90    ///     },
91    ///     None, // No cancellation
92    ///     None::<fn(std::io::Error)>  // No error callback
93    /// ).await.unwrap();
94    ///
95    /// // Wait a bit for spawned tasks to complete
96    /// tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
97    ///
98    /// let processed = results.lock().await;
99    /// assert!(processed.contains(&2));
100    /// assert!(processed.contains(&4));
101    /// # }
102    /// ```
103    ///
104    /// ## With Error Handling
105    ///
106    /// Use an error callback to handle errors without stopping the stream:
107    ///
108    /// ```
109    /// use fluxion_exec::SubscribeAsyncExt;
110    /// use futures::stream;
111    /// use std::sync::Arc;
112    /// use tokio::sync::Mutex;
113    ///
114    /// # #[tokio::main]
115    /// # async fn main() {
116    /// #[derive(Debug)]
117    /// struct MyError(String);
118    /// impl std::fmt::Display for MyError {
119    ///     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
120    ///         write!(f, "MyError: {}", self.0)
121    ///     }
122    /// }
123    /// impl std::error::Error for MyError {}
124    ///
125    /// let error_count = Arc::new(Mutex::new(0));
126    /// let error_count_clone = error_count.clone();
127    ///
128    /// let stream = stream::iter(vec![1, 2, 3, 4, 5]);
129    ///
130    /// stream.subscribe_async(
131    ///     |item, _token| async move {
132    ///         if item % 2 == 0 {
133    ///             Err(MyError(format!("Even number: {}", item)))
134    ///         } else {
135    ///             Ok(())
136    ///         }
137    ///     },
138    ///     None,
139    ///     Some(move |_err| {
140    ///         let count = error_count_clone.clone();
141    ///         tokio::spawn(async move {
142    ///             *count.lock().await += 1;
143    ///         });
144    ///     })
145    /// ).await.unwrap();
146    ///
147    /// tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
148    /// assert_eq!(*error_count.lock().await, 2); // Items 2 and 4 errored
149    /// # }
150    /// ```
151    ///
152    /// ## With Cancellation
153    ///
154    /// Use a cancellation token to stop processing:
155    ///
156    /// ```
157    /// use fluxion_exec::SubscribeAsyncExt;
158    /// use tokio::sync::mpsc::unbounded_channel;
159    /// use tokio_stream::wrappers::UnboundedReceiverStream;
160    /// use futures::StreamExt;
161    /// use tokio_util::sync::CancellationToken;
162    /// use std::sync::Arc;
163    /// use tokio::sync::Mutex;
164    ///
165    /// # #[tokio::main]
166    /// # async fn main() {
167    /// let (tx, rx) = unbounded_channel();
168    /// let stream = UnboundedReceiverStream::new(rx);
169    ///
170    /// let cancel_token = CancellationToken::new();
171    /// let cancel_clone = cancel_token.clone();
172    ///
173    /// let processed = Arc::new(Mutex::new(Vec::new()));
174    /// let processed_clone = processed.clone();
175    ///
176    /// let handle = tokio::spawn(async move {
177    ///     stream.subscribe_async(
178    ///         move |item, token| {
179    ///             let vec = processed_clone.clone();
180    ///             async move {
181    ///                 if token.is_cancelled() {
182    ///                     return Ok(());
183    ///                 }
184    ///                 tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
185    ///                 vec.lock().await.push(item);
186    ///                 Ok::<(), std::io::Error>(())
187    ///             }
188    ///         },
189    ///         Some(cancel_token),
190    ///         None::<fn(std::io::Error)>
191    ///     ).await
192    /// });
193    ///
194    /// // Send a few items
195    /// tx.send(1).unwrap();
196    /// tx.send(2).unwrap();
197    /// tx.send(3).unwrap();
198    ///
199    /// // Wait a bit then cancel
200    /// tokio::time::sleep(tokio::time::Duration::from_millis(15)).await;
201    /// cancel_clone.cancel();
202    /// drop(tx);
203    ///
204    /// handle.await.unwrap().unwrap();
205    ///
206    /// // At least one item should be processed before cancellation
207    /// assert!(!processed.lock().await.is_empty());
208    /// # }
209    /// ```
210    ///
211    /// ## Database Write Pattern
212    ///
213    /// Process events and persist to a database:
214    ///
215    /// ```
216    /// use fluxion_exec::SubscribeAsyncExt;
217    /// use futures::stream;
218    /// use std::sync::Arc;
219    /// use tokio::sync::Mutex;
220    ///
221    /// # #[tokio::main]
222    /// # async fn main() {
223    /// #[derive(Clone, Debug)]
224    /// struct Event { id: u32, data: String }
225    ///
226    /// // Simulated database
227    /// let db = Arc::new(Mutex::new(Vec::new()));
228    /// let db_clone = db.clone();
229    ///
230    /// let events = vec![
231    ///     Event { id: 1, data: "event1".to_string() },
232    ///     Event { id: 2, data: "event2".to_string() },
233    /// ];
234    ///
235    /// let stream = stream::iter(events);
236    ///
237    /// stream.subscribe_async(
238    ///     move |event, _token| {
239    ///         let db = db_clone.clone();
240    ///         async move {
241    ///             // Simulate database write
242    ///             tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
243    ///             db.lock().await.push(event);
244    ///             Ok::<(), std::io::Error>(())
245    ///         }
246    ///     },
247    ///     None,
248    ///     Some(|err| eprintln!("DB Error: {}", err))
249    /// ).await.unwrap();
250    ///
251    /// tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
252    /// assert_eq!(db.lock().await.len(), 2);
253    /// # }
254    /// ```
255    ///
256    /// # Thread Safety
257    ///
258    /// All spawned tasks run on the tokio runtime. The subscription completes
259    /// when the stream ends, not when all spawned tasks complete.
260    async fn subscribe_async<F, Fut, E, OnError>(
261        self,
262        on_next_func: F,
263        cancellation_token: Option<CancellationToken>,
264        on_error_callback: Option<OnError>,
265    ) -> Result<()>
266    where
267        F: Fn(T, CancellationToken) -> Fut + Clone + Send + Sync + 'static,
268        Fut: Future<Output = std::result::Result<(), E>> + Send + 'static,
269        OnError: Fn(E) + Clone + Send + Sync + 'static,
270        T: std::fmt::Debug + Send + Clone + 'static,
271        E: std::error::Error + Send + Sync + 'static;
272}
273
274#[async_trait]
275impl<S, T> SubscribeAsyncExt<T> for S
276where
277    S: Stream<Item = T> + Send + Unpin + 'static,
278    T: Send + 'static,
279{
280    async fn subscribe_async<F, Fut, E, OnError>(
281        mut self,
282        on_next_func: F,
283        cancellation_token: Option<CancellationToken>,
284        on_error_callback: Option<OnError>,
285    ) -> Result<()>
286    where
287        F: Fn(T, CancellationToken) -> Fut + Clone + Send + Sync + 'static,
288        Fut: Future<Output = std::result::Result<(), E>> + Send + 'static,
289        OnError: Fn(E) + Clone + Send + Sync + 'static,
290        T: std::fmt::Debug + Send + Clone + 'static,
291        E: std::error::Error + Send + Sync + 'static,
292    {
293        let cancellation_token = cancellation_token.unwrap_or_default();
294        let (error_tx, mut error_rx) = unbounded_channel();
295
296        while let Some(item) = self.next().await {
297            if cancellation_token.is_cancelled() {
298                break;
299            }
300
301            let on_next_func = on_next_func.clone();
302            let cancellation_token = cancellation_token.clone();
303            let on_error_callback = on_error_callback.clone();
304            let error_tx = error_tx.clone();
305
306            tokio::spawn(async move {
307                let result = on_next_func(item.clone(), cancellation_token).await;
308
309                if let Err(error) = result {
310                    if let Some(on_error_callback) = on_error_callback {
311                        on_error_callback(error);
312                    } else {
313                        // Collect error for later aggregation
314                        let _ = error_tx.send(error);
315                    }
316                }
317            });
318        }
319
320        // Drop the original sender so the channel closes
321        drop(error_tx);
322
323        // Collect all errors from the channel
324        let mut collected_errors = Vec::new();
325        while let Some(error) = error_rx.recv().await {
326            collected_errors.push(error);
327        }
328
329        if !collected_errors.is_empty() {
330            Err(FluxionError::from_user_errors(collected_errors))
331        } else {
332            Ok(())
333        }
334    }
335}