fluxion_exec/subscribe_async.rs
1// Copyright 2025 Umberto Gotti <umberto.gotti@umbertogotti.dev>
2// Licensed under the Apache License, Version 2.0
3// http://www.apache.org/licenses/LICENSE-2.0
4
5use async_trait::async_trait;
6use futures::stream::Stream;
7use futures::stream::StreamExt;
8use std::error::Error;
9use std::fmt::Debug;
10use std::future::Future;
11use tokio::sync::mpsc::unbounded_channel;
12use tokio_util::sync::CancellationToken;
13
14use fluxion_core::{FluxionError, Result};
15
16/// Extension trait providing async subscription capabilities for streams.
17///
18/// This trait enables processing stream items with async handlers in a sequential manner.
19#[async_trait]
20pub trait SubscribeAsyncExt<T>: Stream<Item = T> + Sized {
21 /// Subscribes to the stream with an async handler, processing items sequentially.
22 ///
23 /// This method consumes the stream and spawns async tasks to process each item.
24 /// Items are processed in the order they arrive, with each item's handler running
25 /// to completion before the next item is processed (though handlers run concurrently
26 /// via tokio spawn).
27 ///
28 /// # Behavior
29 ///
30 /// - Processes each stream item with the provided async handler
31 /// - Spawns a new task for each item (non-blocking)
32 /// - Continues until stream ends or cancellation token is triggered
33 /// - Errors from handlers are passed to the error callback if provided
34 /// - If no error callback provided, errors are collected and returned on completion
35 ///
36 /// # Arguments
37 ///
38 /// * `on_next_func` - Async function called for each stream item. Receives the item
39 /// and a cancellation token. Returns `Result<(), E>`.
40 /// * `cancellation_token` - Optional token to stop processing. If `None`, a default
41 /// token is created that never cancels.
42 /// * `on_error_callback` - Optional error handler called when `on_next_func` returns
43 /// an error. If `None`, errors are collected and returned.
44 ///
45 /// # Type Parameters
46 ///
47 /// * `F` - Function type for the item handler
48 /// * `Fut` - Future type returned by the handler
49 /// * `E` - Error type that implements `std::error::Error`
50 /// * `OnError` - Function type for error handling
51 ///
52 /// # Errors
53 ///
54 /// Returns `Err(FluxionError::MultipleErrors)` if any items failed to process and
55 /// no error callback was provided. If an error callback is provided, errors are
56 /// passed to it and the function returns `Ok(())` on stream completion.
57 ///
58 /// The subscription continues processing subsequent items even if individual items
59 /// fail, unless the cancellation token is triggered.
60 ///
61 /// # See Also
62 ///
63 /// - [`subscribe_latest_async`](crate::SubscribeLatestAsyncExt::subscribe_latest_async) - Cancels old work for new items
64 ///
65 /// # Examples
66 ///
67 /// ## Basic Usage
68 ///
69 /// Process all items sequentially:
70 ///
71 /// ```
72 /// use fluxion_exec::SubscribeAsyncExt;
73 /// use futures::stream;
74 /// use std::sync::Arc;
75 /// use tokio::sync::Mutex;
76 ///
77 /// # #[tokio::main]
78 /// # async fn main() {
79 /// let results = Arc::new(Mutex::new(Vec::new()));
80 /// let results_clone = results.clone();
81 ///
82 /// let stream = stream::iter(vec![1, 2, 3, 4, 5]);
83 ///
84 /// // Subscribe and process each item
85 /// stream.subscribe_async(
86 /// move |item, _token| {
87 /// let results = results_clone.clone();
88 /// async move {
89 /// results.lock().await.push(item * 2);
90 /// Ok::<(), std::io::Error>(())
91 /// }
92 /// },
93 /// None, // No cancellation
94 /// None::<fn(std::io::Error)> // No error callback
95 /// ).await.unwrap();
96 ///
97 /// // Wait a bit for spawned tasks to complete
98 /// tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
99 ///
100 /// let processed = results.lock().await;
101 /// assert!(processed.contains(&2));
102 /// assert!(processed.contains(&4));
103 /// # }
104 /// ```
105 ///
106 /// ## With Error Handling
107 ///
108 /// Use an error callback to handle errors without stopping the stream:
109 ///
110 /// ```
111 /// use fluxion_exec::SubscribeAsyncExt;
112 /// use futures::stream;
113 /// use std::sync::Arc;
114 /// use tokio::sync::Mutex;
115 ///
116 /// # #[tokio::main]
117 /// # async fn main() {
118 /// #[derive(Debug)]
119 /// struct MyError(String);
120 /// impl std::fmt::Display for MyError {
121 /// fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
122 /// write!(f, "MyError: {}", self.0)
123 /// }
124 /// }
125 /// impl std::error::Error for MyError {}
126 ///
127 /// let error_count = Arc::new(Mutex::new(0));
128 /// let error_count_clone = error_count.clone();
129 ///
130 /// let stream = stream::iter(vec![1, 2, 3, 4, 5]);
131 ///
132 /// stream.subscribe_async(
133 /// |item, _token| async move {
134 /// if item % 2 == 0 {
135 /// Err(MyError(format!("Even number: {}", item)))
136 /// } else {
137 /// Ok(())
138 /// }
139 /// },
140 /// None,
141 /// Some(move |_err| {
142 /// let count = error_count_clone.clone();
143 /// tokio::spawn(async move {
144 /// *count.lock().await += 1;
145 /// });
146 /// })
147 /// ).await.unwrap();
148 ///
149 /// tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
150 /// assert_eq!(*error_count.lock().await, 2); // Items 2 and 4 errored
151 /// # }
152 /// ```
153 ///
154 /// ## With Cancellation
155 ///
156 /// Use a cancellation token to stop processing:
157 ///
158 /// ```
159 /// use fluxion_exec::SubscribeAsyncExt;
160 /// use tokio::sync::mpsc::unbounded_channel;
161 /// use tokio_stream::wrappers::UnboundedReceiverStream;
162 /// use futures::StreamExt;
163 /// use tokio_util::sync::CancellationToken;
164 /// use std::sync::Arc;
165 /// use tokio::sync::Mutex;
166 ///
167 /// # #[tokio::main]
168 /// # async fn main() {
169 /// let (tx, rx) = unbounded_channel();
170 /// let stream = UnboundedReceiverStream::new(rx);
171 ///
172 /// let cancel_token = CancellationToken::new();
173 /// let cancel_clone = cancel_token.clone();
174 ///
175 /// let processed = Arc::new(Mutex::new(Vec::new()));
176 /// let processed_clone = processed.clone();
177 ///
178 /// let handle = tokio::spawn(async move {
179 /// stream.subscribe_async(
180 /// move |item, token| {
181 /// let vec = processed_clone.clone();
182 /// async move {
183 /// if token.is_cancelled() {
184 /// return Ok(());
185 /// }
186 /// tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
187 /// vec.lock().await.push(item);
188 /// Ok::<(), std::io::Error>(())
189 /// }
190 /// },
191 /// Some(cancel_token),
192 /// None::<fn(std::io::Error)>
193 /// ).await
194 /// });
195 ///
196 /// // Send a few items
197 /// tx.send(1).unwrap();
198 /// tx.send(2).unwrap();
199 /// tx.send(3).unwrap();
200 ///
201 /// // Wait a bit then cancel
202 /// tokio::time::sleep(tokio::time::Duration::from_millis(15)).await;
203 /// cancel_clone.cancel();
204 /// drop(tx);
205 ///
206 /// handle.await.unwrap().unwrap();
207 ///
208 /// // At least one item should be processed before cancellation
209 /// assert!(!processed.lock().await.is_empty());
210 /// # }
211 /// ```
212 ///
213 /// ## Database Write Pattern
214 ///
215 /// Process events and persist to a database:
216 ///
217 /// ```
218 /// use fluxion_exec::SubscribeAsyncExt;
219 /// use futures::stream;
220 /// use std::sync::Arc;
221 /// use tokio::sync::Mutex;
222 ///
223 /// # #[tokio::main]
224 /// # async fn main() {
225 /// #[derive(Clone, Debug)]
226 /// struct Event { id: u32, data: String }
227 ///
228 /// // Simulated database
229 /// let db = Arc::new(Mutex::new(Vec::new()));
230 /// let db_clone = db.clone();
231 ///
232 /// let events = vec![
233 /// Event { id: 1, data: "event1".to_string() },
234 /// Event { id: 2, data: "event2".to_string() },
235 /// ];
236 ///
237 /// let stream = stream::iter(events);
238 ///
239 /// stream.subscribe_async(
240 /// move |event, _token| {
241 /// let db = db_clone.clone();
242 /// async move {
243 /// // Simulate database write
244 /// tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
245 /// db.lock().await.push(event);
246 /// Ok::<(), std::io::Error>(())
247 /// }
248 /// },
249 /// None,
250 /// Some(|err| eprintln!("DB Error: {}", err))
251 /// ).await.unwrap();
252 ///
253 /// tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
254 /// assert_eq!(db.lock().await.len(), 2);
255 /// # }
256 /// ```
257 ///
258 /// # Thread Safety
259 ///
260 /// All spawned tasks run on the tokio runtime. The subscription completes
261 /// when the stream ends, not when all spawned tasks complete.
262 async fn subscribe_async<F, Fut, E, OnError>(
263 self,
264 on_next_func: F,
265 cancellation_token: Option<CancellationToken>,
266 on_error_callback: Option<OnError>,
267 ) -> Result<()>
268 where
269 F: Fn(T, CancellationToken) -> Fut + Clone + Send + Sync + 'static,
270 Fut: Future<Output = std::result::Result<(), E>> + Send + 'static,
271 OnError: Fn(E) + Clone + Send + Sync + 'static,
272 T: Debug + Send + Clone + 'static,
273 E: Error + Send + Sync + 'static;
274}
275
276#[async_trait]
277impl<S, T> SubscribeAsyncExt<T> for S
278where
279 S: Stream<Item = T> + Send + Unpin + 'static,
280 T: Send + 'static,
281{
282 async fn subscribe_async<F, Fut, E, OnError>(
283 mut self,
284 on_next_func: F,
285 cancellation_token: Option<CancellationToken>,
286 on_error_callback: Option<OnError>,
287 ) -> Result<()>
288 where
289 F: Fn(T, CancellationToken) -> Fut + Clone + Send + Sync + 'static,
290 Fut: Future<Output = std::result::Result<(), E>> + Send + 'static,
291 OnError: Fn(E) + Clone + Send + Sync + 'static,
292 T: Debug + Send + Clone + 'static,
293 E: Error + Send + Sync + 'static,
294 {
295 let cancellation_token = cancellation_token.unwrap_or_default();
296 let (error_tx, mut error_rx) = unbounded_channel();
297
298 while let Some(item) = self.next().await {
299 if cancellation_token.is_cancelled() {
300 break;
301 }
302
303 let on_next_func = on_next_func.clone();
304 let cancellation_token = cancellation_token.clone();
305 let on_error_callback = on_error_callback.clone();
306 let error_tx = error_tx.clone();
307
308 tokio::spawn(async move {
309 let result = on_next_func(item.clone(), cancellation_token).await;
310
311 if let Err(error) = result {
312 if let Some(on_error_callback) = on_error_callback {
313 on_error_callback(error);
314 } else {
315 // Collect error for later aggregation
316 let _ = error_tx.send(error);
317 }
318 }
319 });
320 }
321
322 // Drop the original sender so the channel closes
323 drop(error_tx);
324
325 // Collect all errors from the channel
326 let mut collected_errors = Vec::new();
327 while let Some(error) = error_rx.recv().await {
328 collected_errors.push(error);
329 }
330
331 if !collected_errors.is_empty() {
332 Err(FluxionError::from_user_errors(collected_errors))
333 } else {
334 Ok(())
335 }
336 }
337}