fluxion_exec/subscribe_async.rs
1// Copyright 2025 Umberto Gotti <umberto.gotti@umbertogotti.dev>
2// Licensed under the Apache License, Version 2.0
3// http://www.apache.org/licenses/LICENSE-2.0
4
5use async_trait::async_trait;
6use futures::stream::Stream;
7use futures::stream::StreamExt;
8use std::future::Future;
9use tokio::sync::mpsc::unbounded_channel;
10use tokio_util::sync::CancellationToken;
11
12use fluxion_core::{FluxionError, Result};
13
14/// Extension trait providing async subscription capabilities for streams.
15///
16/// This trait enables processing stream items with async handlers in a sequential manner.
17#[async_trait]
18pub trait SubscribeAsyncExt<T>: Stream<Item = T> + Sized {
19 /// Subscribes to the stream with an async handler, processing items sequentially.
20 ///
21 /// This method consumes the stream and spawns async tasks to process each item.
22 /// Items are processed in the order they arrive, with each item's handler running
23 /// to completion before the next item is processed (though handlers run concurrently
24 /// via tokio spawn).
25 ///
26 /// # Behavior
27 ///
28 /// - Processes each stream item with the provided async handler
29 /// - Spawns a new task for each item (non-blocking)
30 /// - Continues until stream ends or cancellation token is triggered
31 /// - Errors from handlers are passed to the error callback if provided
32 /// - If no error callback provided, errors are collected and returned on completion
33 ///
34 /// # Arguments
35 ///
36 /// * `on_next_func` - Async function called for each stream item. Receives the item
37 /// and a cancellation token. Returns `Result<(), E>`.
38 /// * `cancellation_token` - Optional token to stop processing. If `None`, a default
39 /// token is created that never cancels.
40 /// * `on_error_callback` - Optional error handler called when `on_next_func` returns
41 /// an error. If `None`, errors are collected and returned.
42 ///
43 /// # Type Parameters
44 ///
45 /// * `F` - Function type for the item handler
46 /// * `Fut` - Future type returned by the handler
47 /// * `E` - Error type that implements `std::error::Error`
48 /// * `OnError` - Function type for error handling
49 ///
50 /// # Errors
51 ///
52 /// Returns `Err(FluxionError::MultipleErrors)` if any items failed to process and
53 /// no error callback was provided. If an error callback is provided, errors are
54 /// passed to it and the function returns `Ok(())` on stream completion.
55 ///
56 /// The subscription continues processing subsequent items even if individual items
57 /// fail, unless the cancellation token is triggered.
58 ///
59 /// # See Also
60 ///
61 /// - [`subscribe_latest_async`](crate::SubscribeLatestAsyncExt::subscribe_latest_async) - Cancels old work for new items
62 ///
63 /// # Examples
64 ///
65 /// ## Basic Usage
66 ///
67 /// Process all items sequentially:
68 ///
69 /// ```
70 /// use fluxion_exec::SubscribeAsyncExt;
71 /// use futures::stream;
72 /// use std::sync::Arc;
73 /// use tokio::sync::Mutex;
74 ///
75 /// # #[tokio::main]
76 /// # async fn main() {
77 /// let results = Arc::new(Mutex::new(Vec::new()));
78 /// let results_clone = results.clone();
79 ///
80 /// let stream = stream::iter(vec![1, 2, 3, 4, 5]);
81 ///
82 /// // Subscribe and process each item
83 /// stream.subscribe_async(
84 /// move |item, _token| {
85 /// let results = results_clone.clone();
86 /// async move {
87 /// results.lock().await.push(item * 2);
88 /// Ok::<(), std::io::Error>(())
89 /// }
90 /// },
91 /// None, // No cancellation
92 /// None::<fn(std::io::Error)> // No error callback
93 /// ).await.unwrap();
94 ///
95 /// // Wait a bit for spawned tasks to complete
96 /// tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
97 ///
98 /// let processed = results.lock().await;
99 /// assert!(processed.contains(&2));
100 /// assert!(processed.contains(&4));
101 /// # }
102 /// ```
103 ///
104 /// ## With Error Handling
105 ///
106 /// Use an error callback to handle errors without stopping the stream:
107 ///
108 /// ```
109 /// use fluxion_exec::SubscribeAsyncExt;
110 /// use futures::stream;
111 /// use std::sync::Arc;
112 /// use tokio::sync::Mutex;
113 ///
114 /// # #[tokio::main]
115 /// # async fn main() {
116 /// #[derive(Debug)]
117 /// struct MyError(String);
118 /// impl std::fmt::Display for MyError {
119 /// fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
120 /// write!(f, "MyError: {}", self.0)
121 /// }
122 /// }
123 /// impl std::error::Error for MyError {}
124 ///
125 /// let error_count = Arc::new(Mutex::new(0));
126 /// let error_count_clone = error_count.clone();
127 ///
128 /// let stream = stream::iter(vec![1, 2, 3, 4, 5]);
129 ///
130 /// stream.subscribe_async(
131 /// |item, _token| async move {
132 /// if item % 2 == 0 {
133 /// Err(MyError(format!("Even number: {}", item)))
134 /// } else {
135 /// Ok(())
136 /// }
137 /// },
138 /// None,
139 /// Some(move |_err| {
140 /// let count = error_count_clone.clone();
141 /// tokio::spawn(async move {
142 /// *count.lock().await += 1;
143 /// });
144 /// })
145 /// ).await.unwrap();
146 ///
147 /// tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
148 /// assert_eq!(*error_count.lock().await, 2); // Items 2 and 4 errored
149 /// # }
150 /// ```
151 ///
152 /// ## With Cancellation
153 ///
154 /// Use a cancellation token to stop processing:
155 ///
156 /// ```
157 /// use fluxion_exec::SubscribeAsyncExt;
158 /// use tokio::sync::mpsc::unbounded_channel;
159 /// use tokio_stream::wrappers::UnboundedReceiverStream;
160 /// use futures::StreamExt;
161 /// use tokio_util::sync::CancellationToken;
162 /// use std::sync::Arc;
163 /// use tokio::sync::Mutex;
164 ///
165 /// # #[tokio::main]
166 /// # async fn main() {
167 /// let (tx, rx) = unbounded_channel();
168 /// let stream = UnboundedReceiverStream::new(rx);
169 ///
170 /// let cancel_token = CancellationToken::new();
171 /// let cancel_clone = cancel_token.clone();
172 ///
173 /// let processed = Arc::new(Mutex::new(Vec::new()));
174 /// let processed_clone = processed.clone();
175 ///
176 /// let handle = tokio::spawn(async move {
177 /// stream.subscribe_async(
178 /// move |item, token| {
179 /// let vec = processed_clone.clone();
180 /// async move {
181 /// if token.is_cancelled() {
182 /// return Ok(());
183 /// }
184 /// tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
185 /// vec.lock().await.push(item);
186 /// Ok::<(), std::io::Error>(())
187 /// }
188 /// },
189 /// Some(cancel_token),
190 /// None::<fn(std::io::Error)>
191 /// ).await
192 /// });
193 ///
194 /// // Send a few items
195 /// tx.send(1).unwrap();
196 /// tx.send(2).unwrap();
197 /// tx.send(3).unwrap();
198 ///
199 /// // Wait a bit then cancel
200 /// tokio::time::sleep(tokio::time::Duration::from_millis(15)).await;
201 /// cancel_clone.cancel();
202 /// drop(tx);
203 ///
204 /// handle.await.unwrap().unwrap();
205 ///
206 /// // At least one item should be processed before cancellation
207 /// assert!(!processed.lock().await.is_empty());
208 /// # }
209 /// ```
210 ///
211 /// ## Database Write Pattern
212 ///
213 /// Process events and persist to a database:
214 ///
215 /// ```
216 /// use fluxion_exec::SubscribeAsyncExt;
217 /// use futures::stream;
218 /// use std::sync::Arc;
219 /// use tokio::sync::Mutex;
220 ///
221 /// # #[tokio::main]
222 /// # async fn main() {
223 /// #[derive(Clone, Debug)]
224 /// struct Event { id: u32, data: String }
225 ///
226 /// // Simulated database
227 /// let db = Arc::new(Mutex::new(Vec::new()));
228 /// let db_clone = db.clone();
229 ///
230 /// let events = vec![
231 /// Event { id: 1, data: "event1".to_string() },
232 /// Event { id: 2, data: "event2".to_string() },
233 /// ];
234 ///
235 /// let stream = stream::iter(events);
236 ///
237 /// stream.subscribe_async(
238 /// move |event, _token| {
239 /// let db = db_clone.clone();
240 /// async move {
241 /// // Simulate database write
242 /// tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
243 /// db.lock().await.push(event);
244 /// Ok::<(), std::io::Error>(())
245 /// }
246 /// },
247 /// None,
248 /// Some(|err| eprintln!("DB Error: {}", err))
249 /// ).await.unwrap();
250 ///
251 /// tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
252 /// assert_eq!(db.lock().await.len(), 2);
253 /// # }
254 /// ```
255 ///
256 /// # Thread Safety
257 ///
258 /// All spawned tasks run on the tokio runtime. The subscription completes
259 /// when the stream ends, not when all spawned tasks complete.
260 async fn subscribe_async<F, Fut, E, OnError>(
261 self,
262 on_next_func: F,
263 cancellation_token: Option<CancellationToken>,
264 on_error_callback: Option<OnError>,
265 ) -> Result<()>
266 where
267 F: Fn(T, CancellationToken) -> Fut + Clone + Send + Sync + 'static,
268 Fut: Future<Output = std::result::Result<(), E>> + Send + 'static,
269 OnError: Fn(E) + Clone + Send + Sync + 'static,
270 T: std::fmt::Debug + Send + Clone + 'static,
271 E: std::error::Error + Send + Sync + 'static;
272}
273
274#[async_trait]
275impl<S, T> SubscribeAsyncExt<T> for S
276where
277 S: Stream<Item = T> + Send + Unpin + 'static,
278 T: Send + 'static,
279{
280 async fn subscribe_async<F, Fut, E, OnError>(
281 mut self,
282 on_next_func: F,
283 cancellation_token: Option<CancellationToken>,
284 on_error_callback: Option<OnError>,
285 ) -> Result<()>
286 where
287 F: Fn(T, CancellationToken) -> Fut + Clone + Send + Sync + 'static,
288 Fut: Future<Output = std::result::Result<(), E>> + Send + 'static,
289 OnError: Fn(E) + Clone + Send + Sync + 'static,
290 T: std::fmt::Debug + Send + Clone + 'static,
291 E: std::error::Error + Send + Sync + 'static,
292 {
293 let cancellation_token = cancellation_token.unwrap_or_default();
294 let (error_tx, mut error_rx) = unbounded_channel();
295
296 while let Some(item) = self.next().await {
297 if cancellation_token.is_cancelled() {
298 break;
299 }
300
301 let on_next_func = on_next_func.clone();
302 let cancellation_token = cancellation_token.clone();
303 let on_error_callback = on_error_callback.clone();
304 let error_tx = error_tx.clone();
305
306 tokio::spawn(async move {
307 let result = on_next_func(item.clone(), cancellation_token).await;
308
309 if let Err(error) = result {
310 if let Some(on_error_callback) = on_error_callback {
311 on_error_callback(error);
312 } else {
313 // Collect error for later aggregation
314 let _ = error_tx.send(error);
315 }
316 }
317 });
318 }
319
320 // Drop the original sender so the channel closes
321 drop(error_tx);
322
323 // Collect all errors from the channel
324 let mut collected_errors = Vec::new();
325 while let Some(error) = error_rx.recv().await {
326 collected_errors.push(error);
327 }
328
329 if !collected_errors.is_empty() {
330 Err(FluxionError::from_user_errors(collected_errors))
331 } else {
332 Ok(())
333 }
334 }
335}