fluxion_exec/subscribe.rs
1// Copyright 2025 Umberto Gotti <umberto.gotti@umbertogotti.dev>
2// Licensed under the Apache License, Version 2.0
3// http://www.apache.org/licenses/LICENSE-2.0
4
5use async_trait::async_trait;
6use futures::stream::Stream;
7use futures::stream::StreamExt;
8use std::error::Error;
9use std::fmt::Debug;
10use std::future::Future;
11use tokio::sync::mpsc::unbounded_channel;
12use tokio_util::sync::CancellationToken;
13
14use fluxion_core::{FluxionError, Result};
15
16/// Extension trait providing async subscription capabilities for streams.
17///
18/// This trait enables processing stream items with async handlers in a sequential manner.
19#[async_trait]
20pub trait SubscribeExt<T>: Stream<Item = T> + Sized {
21 /// Subscribes to the stream with an async handler, processing items sequentially.
22 ///
23 /// This method consumes the stream and spawns async tasks to process each item.
24 /// Items are processed in the order they arrive, with each item's handler running
25 /// to completion before the next item is processed (though handlers run concurrently
26 /// via tokio spawn).
27 ///
28 /// # Behavior
29 ///
30 /// - Processes each stream item with the provided async handler
31 /// - Spawns a new task for each item (non-blocking)
32 /// - Continues until stream ends or cancellation token is triggered
33 /// - Errors from handlers are passed to the error callback if provided
34 /// - If no error callback provided, errors are collected and returned on completion
35 ///
36 /// # Arguments
37 ///
38 /// * `on_next_func` - Async function called for each stream item. Receives the item
39 /// and a cancellation token. Returns `Result<(), E>`.
40 /// * `cancellation_token` - Optional token to stop processing. If `None`, a default
41 /// token is created that never cancels.
42 /// * `on_error_callback` - Optional error handler called when `on_next_func` returns
43 /// an error. If `None`, errors are collected and returned.
44 ///
45 /// # Type Parameters
46 ///
47 /// * `F` - Function type for the item handler
48 /// * `Fut` - Future type returned by the handler
49 /// * `E` - Error type that implements `std::error::Error`
50 /// * `OnError` - Function type for error handling
51 ///
52 /// # Errors
53 ///
54 /// Returns `Err(FluxionError::MultipleErrors)` if any items failed to process and
55 /// no error callback was provided. If an error callback is provided, errors are
56 /// passed to it and the function returns `Ok(())` on stream completion.
57 ///
58 /// The subscription continues processing subsequent items even if individual items
59 /// fail, unless the cancellation token is triggered.
60 ///
61 /// # See Also
62 ///
63 /// - [`subscribe_latest`](crate::SubscribeLatestExt::subscribe_latest) - Cancels old work for new items
64 ///
65 /// # Examples
66 ///
67 /// ## Basic Usage
68 ///
69 /// Process all items sequentially:
70 ///
71 /// ```
72 /// use fluxion_exec::SubscribeExt;
73 /// use futures::stream;
74 /// use std::sync::Arc;
75 /// use tokio::sync::Mutex;
76 /// use tokio::sync::mpsc::unbounded_channel;
77 ///
78 /// # #[tokio::main]
79 /// # async fn main() {
80 /// let results = Arc::new(Mutex::new(Vec::new()));
81 /// let results_clone = results.clone();
82 /// let (notify_tx, mut notify_rx) = unbounded_channel();
83 ///
84 /// let stream = stream::iter(vec![1, 2, 3, 4, 5]);
85 ///
86 /// // Subscribe and process each item
87 /// stream.subscribe(
88 /// move |item, _token| {
89 /// let results = results_clone.clone();
90 /// let notify_tx = notify_tx.clone();
91 /// async move {
92 /// results.lock().await.push(item * 2);
93 /// let _ = notify_tx.send(());
94 /// Ok::<(), std::io::Error>(())
95 /// }
96 /// },
97 /// None, // No cancellation
98 /// None::<fn(std::io::Error)> // No error callback
99 /// ).await.unwrap();
100 ///
101 /// // Wait for all 5 items to be processed
102 /// for _ in 0..5 {
103 /// notify_rx.recv().await.unwrap();
104 /// }
105 ///
106 /// let processed = results.lock().await;
107 /// assert!(processed.contains(&2));
108 /// assert!(processed.contains(&4));
109 /// # }
110 /// ```
111 ///
112 /// ## With Error Handling
113 ///
114 /// Use an error callback to handle errors without stopping the stream:
115 ///
116 /// ```
117 /// use fluxion_exec::SubscribeExt;
118 /// use futures::stream;
119 /// use std::sync::Arc;
120 /// use tokio::sync::Mutex;
121 /// use tokio::sync::mpsc::unbounded_channel;
122 ///
123 /// # #[tokio::main]
124 /// # async fn main() {
125 /// #[derive(Debug)]
126 /// struct MyError(String);
127 /// impl std::fmt::Display for MyError {
128 /// fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
129 /// write!(f, "MyError: {}", self.0)
130 /// }
131 /// }
132 /// impl std::error::Error for MyError {}
133 ///
134 /// let error_count = Arc::new(Mutex::new(0));
135 /// let error_count_clone = error_count.clone();
136 /// let (notify_tx, mut notify_rx) = unbounded_channel();
137 ///
138 /// let stream = stream::iter(vec![1, 2, 3, 4, 5]);
139 ///
140 /// stream.subscribe(
141 /// move |item, _token| {
142 /// let notify_tx = notify_tx.clone();
143 /// async move {
144 /// let res = if item % 2 == 0 {
145 /// Err(MyError(format!("Even number: {}", item)))
146 /// } else {
147 /// Ok(())
148 /// };
149 /// // Signal completion regardless of success/failure
150 /// // Note: In real code, you might signal in the error callback too
151 /// // but here we just want to know the handler finished.
152 /// // However, subscribe spawns the handler. If it errors,
153 /// // the error callback is called.
154 /// // We need to signal completion in both paths.
155 /// // Since the handler returns the error, we can't signal *after* returning Err.
156 /// // So we signal before returning.
157 /// let _ = notify_tx.send(());
158 /// res
159 /// }
160 /// },
161 /// None,
162 /// Some(move |_err| {
163 /// let count = error_count_clone.clone();
164 /// tokio::spawn(async move {
165 /// *count.lock().await += 1;
166 /// });
167 /// })
168 /// ).await.unwrap();
169 ///
170 /// // Wait for 5 items
171 /// for _ in 0..5 {
172 /// notify_rx.recv().await.unwrap();
173 /// }
174 ///
175 /// // Give a tiny bit of time for the error callback spawn to finish updating the count
176 /// // (Since the callback spawns another task)
177 /// // Alternatively, we could use a channel in the error callback too.
178 /// tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
179 ///
180 /// assert_eq!(*error_count.lock().await, 2); // Items 2 and 4 errored
181 /// # }
182 /// ```
183 ///
184 /// ## With Cancellation
185 ///
186 /// Use a cancellation token to stop processing:
187 ///
188 /// ```
189 /// use fluxion_exec::SubscribeExt;
190 /// use tokio::sync::mpsc::unbounded_channel;
191 /// use tokio_stream::wrappers::UnboundedReceiverStream;
192 /// use futures::StreamExt;
193 /// use tokio_util::sync::CancellationToken;
194 /// use std::sync::Arc;
195 /// use tokio::sync::Mutex;
196 ///
197 /// # #[tokio::main]
198 /// # async fn main() {
199 /// let (tx, rx) = unbounded_channel();
200 /// let stream = UnboundedReceiverStream::new(rx);
201 ///
202 /// let cancel_token = CancellationToken::new();
203 /// let cancel_clone = cancel_token.clone();
204 ///
205 /// let processed = Arc::new(Mutex::new(Vec::new()));
206 /// let processed_clone = processed.clone();
207 /// let (notify_tx, mut notify_rx) = unbounded_channel();
208 ///
209 /// let handle = tokio::spawn(async move {
210 /// stream.subscribe(
211 /// move |item, token| {
212 /// let vec = processed_clone.clone();
213 /// let notify_tx = notify_tx.clone();
214 /// async move {
215 /// if token.is_cancelled() {
216 /// return Ok(());
217 /// }
218 /// vec.lock().await.push(item);
219 /// let _ = notify_tx.send(());
220 /// Ok::<(), std::io::Error>(())
221 /// }
222 /// },
223 /// Some(cancel_token),
224 /// None::<fn(std::io::Error)>
225 /// ).await
226 /// });
227 ///
228 /// // Send items
229 /// tx.send(1).unwrap();
230 /// tx.send(2).unwrap();
231 /// tx.send(3).unwrap();
232 ///
233 /// // Wait for first item to be processed
234 /// notify_rx.recv().await.unwrap();
235 ///
236 /// // Cancel now
237 /// cancel_clone.cancel();
238 /// drop(tx);
239 ///
240 /// handle.await.unwrap().unwrap();
241 ///
242 /// // At least one item should be processed before cancellation
243 /// assert!(!processed.lock().await.is_empty());
244 /// # }
245 /// ```
246 ///
247 /// ## Database Write Pattern
248 ///
249 /// Process events and persist to a database:
250 ///
251 /// ```
252 /// use fluxion_exec::SubscribeExt;
253 /// use futures::stream;
254 /// use std::sync::Arc;
255 /// use tokio::sync::Mutex;
256 /// use tokio::sync::mpsc::unbounded_channel;
257 ///
258 /// # #[tokio::main]
259 /// # async fn main() {
260 /// #[derive(Clone, Debug)]
261 /// struct Event { id: u32, data: String }
262 ///
263 /// // Simulated database
264 /// let db = Arc::new(Mutex::new(Vec::new()));
265 /// let db_clone = db.clone();
266 /// let (notify_tx, mut notify_rx) = unbounded_channel();
267 ///
268 /// let events = vec![
269 /// Event { id: 1, data: "event1".to_string() },
270 /// Event { id: 2, data: "event2".to_string() },
271 /// ];
272 ///
273 /// let stream = stream::iter(events);
274 ///
275 /// stream.subscribe(
276 /// move |event, _token| {
277 /// let db = db_clone.clone();
278 /// let notify_tx = notify_tx.clone();
279 /// async move {
280 /// // Simulate database write
281 /// db.lock().await.push(event);
282 /// let _ = notify_tx.send(());
283 /// Ok::<(), std::io::Error>(())
284 /// }
285 /// },
286 /// None,
287 /// Some(|err| eprintln!("DB Error: {}", err))
288 /// ).await.unwrap();
289 ///
290 /// // Wait for 2 events
291 /// notify_rx.recv().await.unwrap();
292 /// notify_rx.recv().await.unwrap();
293 ///
294 /// assert_eq!(db.lock().await.len(), 2);
295 /// # }
296 /// ```
297 ///
298 /// # Thread Safety
299 ///
300 /// All spawned tasks run on the tokio runtime. The subscription completes
301 /// when the stream ends, not when all spawned tasks complete.
302 async fn subscribe<F, Fut, E, OnError>(
303 self,
304 on_next_func: F,
305 cancellation_token: Option<CancellationToken>,
306 on_error_callback: Option<OnError>,
307 ) -> Result<()>
308 where
309 F: Fn(T, CancellationToken) -> Fut + Clone + Send + Sync + 'static,
310 Fut: Future<Output = std::result::Result<(), E>> + Send + 'static,
311 OnError: Fn(E) + Clone + Send + Sync + 'static,
312 T: Debug + Send + Clone + 'static,
313 E: Error + Send + Sync + 'static;
314}
315
316#[async_trait]
317impl<S, T> SubscribeExt<T> for S
318where
319 S: Stream<Item = T> + Send + Unpin + 'static,
320 T: Send + 'static,
321{
322 async fn subscribe<F, Fut, E, OnError>(
323 mut self,
324 on_next_func: F,
325 cancellation_token: Option<CancellationToken>,
326 on_error_callback: Option<OnError>,
327 ) -> Result<()>
328 where
329 F: Fn(T, CancellationToken) -> Fut + Clone + Send + Sync + 'static,
330 Fut: Future<Output = std::result::Result<(), E>> + Send + 'static,
331 OnError: Fn(E) + Clone + Send + Sync + 'static,
332 T: Debug + Send + Clone + 'static,
333 E: Error + Send + Sync + 'static,
334 {
335 let cancellation_token = cancellation_token.unwrap_or_default();
336 let (error_tx, mut error_rx) = unbounded_channel();
337
338 while let Some(item) = self.next().await {
339 if cancellation_token.is_cancelled() {
340 break;
341 }
342
343 let on_next_func = on_next_func.clone();
344 let cancellation_token = cancellation_token.clone();
345 let on_error_callback = on_error_callback.clone();
346 let error_tx = error_tx.clone();
347
348 tokio::spawn(async move {
349 let result = on_next_func(item.clone(), cancellation_token).await;
350
351 if let Err(error) = result {
352 if let Some(on_error_callback) = on_error_callback {
353 on_error_callback(error);
354 } else {
355 // Collect error for later aggregation
356 let _ = error_tx.send(error);
357 }
358 }
359 });
360 }
361
362 // Drop the original sender so the channel closes
363 drop(error_tx);
364
365 // Collect all errors from the channel
366 let mut collected_errors = Vec::new();
367 while let Some(error) = error_rx.recv().await {
368 collected_errors.push(error);
369 }
370
371 if !collected_errors.is_empty() {
372 Err(FluxionError::from_user_errors(collected_errors))
373 } else {
374 Ok(())
375 }
376 }
377}