commonware_utils/
futures.rs

1//! Utilities for working with futures.
2
3use futures::{
4    future::{self, AbortHandle, Abortable, Aborted},
5    stream::{FuturesUnordered, SelectNextSome},
6    StreamExt,
7};
8use std::{future::Future, pin::Pin};
9
10/// A future type that can be used in `Pool`.
11type PooledFuture<T> = Pin<Box<dyn Future<Output = T> + Send>>;
12
13/// An unordered pool of futures.
14///
15/// Futures can be added to the pool, and removed from the pool as they resolve.
16///
17/// **Note:** This pool is not thread-safe and should not be used across threads without external
18/// synchronization.
19pub struct Pool<T> {
20    pool: FuturesUnordered<PooledFuture<T>>,
21}
22
23impl<T: Send> Default for Pool<T> {
24    fn default() -> Self {
25        // Insert a dummy future (that never resolves) to prevent the stream from being empty.
26        // Else, the `select_next_some()` function returns `None` instantly.
27        let pool = FuturesUnordered::new();
28        pool.push(Self::create_dummy_future());
29        Self { pool }
30    }
31}
32
33impl<T: Send> Pool<T> {
34    /// Returns the number of futures in the pool.
35    pub fn len(&self) -> usize {
36        // Subtract the dummy future.
37        self.pool.len().checked_sub(1).unwrap()
38    }
39
40    /// Returns `true` if the pool is empty.
41    pub fn is_empty(&self) -> bool {
42        self.len() == 0
43    }
44
45    /// Adds a future to the pool.
46    ///
47    /// The future must be `'static` and `Send` to ensure it can be safely stored and executed.
48    pub fn push(&mut self, future: impl Future<Output = T> + Send + 'static) {
49        self.pool.push(Box::pin(future));
50    }
51
52    /// Returns a futures that resolves to the next future in the pool that resolves.
53    ///
54    /// If the pool is empty, the future will never resolve.
55    pub fn next_completed(&mut self) -> SelectNextSome<'_, FuturesUnordered<PooledFuture<T>>> {
56        self.pool.select_next_some()
57    }
58
59    /// Cancels all futures in the pool.
60    ///
61    /// Excludes the dummy future.
62    pub fn cancel_all(&mut self) {
63        self.pool.clear();
64        self.pool.push(Self::create_dummy_future());
65    }
66
67    /// Creates a dummy future that never resolves.
68    fn create_dummy_future() -> PooledFuture<T> {
69        Box::pin(async { future::pending::<T>().await })
70    }
71}
72
73/// A handle that can be used to abort a specific future in an [AbortablePool].
74///
75/// When the aborter is dropped, the associated future is aborted.
76pub struct Aborter {
77    inner: AbortHandle,
78}
79
80impl Drop for Aborter {
81    fn drop(&mut self) {
82        self.inner.abort();
83    }
84}
85
86/// A future type that can be used in [AbortablePool].
87type AbortablePooledFuture<T> = Pin<Box<dyn Future<Output = Result<T, Aborted>> + Send>>;
88
89/// An unordered pool of futures that can be individually aborted.
90///
91/// Each future added to the pool returns an [Aborter]. When the aborter is dropped,
92/// the associated future is aborted.
93///
94/// **Note:** This pool is not thread-safe and should not be used across threads without external
95/// synchronization.
96pub struct AbortablePool<T> {
97    pool: FuturesUnordered<AbortablePooledFuture<T>>,
98}
99
100impl<T: Send> Default for AbortablePool<T> {
101    fn default() -> Self {
102        // Insert a dummy future (that never resolves) to prevent the stream from being empty.
103        // Else, the `select_next_some()` function returns `None` instantly.
104        let pool = FuturesUnordered::new();
105        pool.push(Self::create_dummy_future());
106        Self { pool }
107    }
108}
109
110impl<T: Send> AbortablePool<T> {
111    /// Returns the number of futures in the pool.
112    pub fn len(&self) -> usize {
113        // Subtract the dummy future.
114        self.pool.len().checked_sub(1).unwrap()
115    }
116
117    /// Returns `true` if the pool is empty.
118    pub fn is_empty(&self) -> bool {
119        self.len() == 0
120    }
121
122    /// Adds a future to the pool and returns an [Aborter] that can be used to abort it.
123    ///
124    /// The future must be `'static` and `Send` to ensure it can be safely stored and executed.
125    /// When the returned [Aborter] is dropped, the future will be aborted.
126    pub fn push(&mut self, future: impl Future<Output = T> + Send + 'static) -> Aborter {
127        let (handle, registration) = AbortHandle::new_pair();
128        let abortable_future = Abortable::new(future, registration);
129        self.pool.push(Box::pin(abortable_future));
130        Aborter { inner: handle }
131    }
132
133    /// Returns a future that resolves to the next future in the pool that resolves.
134    ///
135    /// If the pool is empty, the future will never resolve.
136    /// Returns `Ok(T)` for successful completion or `Err(Aborted)` for aborted futures.
137    pub fn next_completed(
138        &mut self,
139    ) -> SelectNextSome<'_, FuturesUnordered<AbortablePooledFuture<T>>> {
140        self.pool.select_next_some()
141    }
142
143    /// Creates a dummy future that never resolves.
144    fn create_dummy_future() -> AbortablePooledFuture<T> {
145        Box::pin(async { Ok(future::pending::<T>().await) })
146    }
147}
148
#[cfg(test)]
mod tests {
    use super::*;
    use futures::{
        channel::oneshot,
        executor::block_on,
        future::{self, select, Either},
        pin_mut, FutureExt,
    };
    use std::{
        sync::{
            atomic::{AtomicBool, Ordering},
            Arc,
        },
        thread,
        time::Duration,
    };

    /// A future that resolves after a given duration.
    ///
    /// Uses a helper OS thread so it works with the simple `block_on` executor.
    fn delay(duration: Duration) -> impl Future<Output = ()> {
        let (sender, receiver) = oneshot::channel();
        thread::spawn(move || {
            thread::sleep(duration);
            sender.send(()).unwrap();
        });
        receiver.map(|_| ())
    }

    #[test]
    fn test_initialization() {
        let pool = Pool::<i32>::default();
        assert_eq!(pool.len(), 0);
        assert!(pool.is_empty());
    }

    #[test]
    fn test_dummy_future_doesnt_resolve() {
        block_on(async {
            let mut pool = Pool::<i32>::default();
            let stream_future = pool.next_completed();
            let timeout_future = async {
                delay(Duration::from_millis(100)).await;
            };
            pin_mut!(stream_future);
            pin_mut!(timeout_future);
            let result = select(stream_future, timeout_future).await;
            match result {
                Either::Left((_, _)) => panic!("Stream resolved unexpectedly"),
                Either::Right((_, _)) => {
                    // Timeout occurred, which is expected
                }
            }
        });
    }

    #[test]
    fn test_adding_futures() {
        let mut pool = Pool::<i32>::default();
        assert_eq!(pool.len(), 0);
        assert!(pool.is_empty());

        pool.push(async { 42 });
        assert_eq!(pool.len(), 1);
        assert!(!pool.is_empty());

        pool.push(async { 43 });
        assert_eq!(pool.len(), 2);
    }

    #[test]
    fn test_streaming_resolved_futures() {
        block_on(async move {
            let mut pool = Pool::<i32>::default();
            pool.push(future::ready(42));
            let result = pool.next_completed().await;
            assert_eq!(result, 42);
            assert!(pool.is_empty());
        });
    }

    #[test]
    fn test_multiple_futures() {
        block_on(async move {
            let mut pool = Pool::<i32>::default();

            // Futures resolve in order of completion, not addition order
            let (finisher_1, finished_1) = oneshot::channel();
            let (finisher_3, finished_3) = oneshot::channel();
            pool.push(async move {
                finished_1.await.unwrap();
                finisher_3.send(()).unwrap();
                1
            });
            pool.push(async move {
                finisher_1.send(()).unwrap();
                2
            });
            pool.push(async move {
                finished_3.await.unwrap();
                3
            });

            let first = pool.next_completed().await;
            assert_eq!(first, 2, "First resolved should be 2");
            let second = pool.next_completed().await;
            assert_eq!(second, 1, "Second resolved should be 1");
            let third = pool.next_completed().await;
            assert_eq!(third, 3, "Third resolved should be 3");
            assert!(pool.is_empty());
        });
    }

    #[test]
    fn test_cancel_all() {
        block_on(async move {
            let flag = Arc::new(AtomicBool::new(false));
            let flag_clone = flag.clone();
            let mut pool = Pool::<i32>::default();

            // Push a future that will set the flag to true when it resolves.
            let (finisher, finished) = oneshot::channel();
            pool.push(async move {
                finished.await.unwrap();
                flag_clone.store(true, Ordering::SeqCst);
                42
            });
            assert_eq!(pool.len(), 1);

            // Cancel all futures.
            pool.cancel_all();
            assert!(pool.is_empty());
            assert!(!flag.load(Ordering::SeqCst));

            // Send the finisher signal (should be ignored).
            let _ = finisher.send(());

            // Stream should not resolve future after cancellation.
            let stream_future = pool.next_completed();
            let timeout_future = async {
                delay(Duration::from_millis(100)).await;
            };
            pin_mut!(stream_future);
            pin_mut!(timeout_future);
            let result = select(stream_future, timeout_future).await;
            match result {
                Either::Left((_, _)) => panic!("Stream resolved after cancellation"),
                Either::Right((_, _)) => {
                    // Wait for the timeout to trigger.
                }
            }
            assert!(!flag.load(Ordering::SeqCst));

            // Push and await a new future.
            pool.push(future::ready(42));
            assert_eq!(pool.len(), 1);
            let result = pool.next_completed().await;
            assert_eq!(result, 42);
            assert!(pool.is_empty());
        });
    }

    #[test]
    fn test_many_futures() {
        block_on(async move {
            let mut pool = Pool::<i32>::default();
            let num_futures = 1000;
            for i in 0..num_futures {
                pool.push(future::ready(i));
            }
            assert_eq!(pool.len(), num_futures as usize);

            let mut sum = 0;
            for _ in 0..num_futures {
                let value = pool.next_completed().await;
                sum += value;
            }
            let expected_sum = (0..num_futures).sum::<i32>();
            assert_eq!(
                sum, expected_sum,
                "Sum of resolved values should match expected"
            );
            assert!(
                pool.is_empty(),
                "Pool should be empty after all futures resolve"
            );
        });
    }

    #[test]
    fn test_abortable_pool_initialization() {
        let pool = AbortablePool::<i32>::default();
        assert_eq!(pool.len(), 0);
        assert!(pool.is_empty());
    }

    #[test]
    fn test_abortable_pool_dummy_future_doesnt_resolve() {
        block_on(async {
            let mut pool = AbortablePool::<i32>::default();
            let stream_future = pool.next_completed();
            let timeout_future = async {
                delay(Duration::from_millis(100)).await;
            };
            pin_mut!(stream_future);
            pin_mut!(timeout_future);
            let result = select(stream_future, timeout_future).await;
            match result {
                Either::Left((_, _)) => panic!("Stream resolved unexpectedly"),
                Either::Right((_, _)) => {
                    // Timeout occurred, which is expected
                }
            }
        });
    }

    #[test]
    fn test_abortable_pool_adding_futures() {
        let mut pool = AbortablePool::<i32>::default();
        assert_eq!(pool.len(), 0);
        assert!(pool.is_empty());

        let _hook1 = pool.push(async { 42 });
        assert_eq!(pool.len(), 1);
        assert!(!pool.is_empty());

        let _hook2 = pool.push(async { 43 });
        assert_eq!(pool.len(), 2);
    }

    #[test]
    fn test_abortable_pool_successful_completion() {
        block_on(async move {
            let mut pool = AbortablePool::<i32>::default();
            let _hook = pool.push(future::ready(42));
            let result = pool.next_completed().await;
            assert_eq!(result, Ok(42));
            assert!(pool.is_empty());
        });
    }

    #[test]
    fn test_abortable_pool_drop_abort() {
        block_on(async move {
            let mut pool = AbortablePool::<i32>::default();

            let (sender, receiver) = oneshot::channel();
            let hook = pool.push(async move {
                receiver.await.unwrap();
                42
            });

            drop(hook);

            let result = pool.next_completed().await;
            assert!(result.is_err());
            assert!(pool.is_empty());

            let _ = sender.send(());
        });
    }

    #[test]
    fn test_abortable_pool_len_includes_unreported_aborts() {
        block_on(async move {
            let mut pool = AbortablePool::<i32>::default();
            let hook = pool.push(future::pending::<i32>());
            drop(hook);

            // The aborted future stays in the pool until its `Err(Aborted)` is yielded.
            assert_eq!(pool.len(), 1);
            assert!(pool.next_completed().await.is_err());
            assert!(pool.is_empty());
        });
    }

    #[test]
    fn test_abortable_pool_abort_after_completion() {
        block_on(async move {
            let mut pool = AbortablePool::<i32>::default();
            let hook = pool.push(future::ready(7));
            let result = pool.next_completed().await;
            assert_eq!(result, Ok(7));

            // Dropping the aborter of an already-completed future must be harmless.
            drop(hook);
            assert!(pool.is_empty());

            // The pool remains usable afterwards.
            let _hook = pool.push(future::ready(8));
            assert_eq!(pool.next_completed().await, Ok(8));
            assert!(pool.is_empty());
        });
    }

    #[test]
    fn test_abortable_pool_partial_abort() {
        block_on(async move {
            let mut pool = AbortablePool::<i32>::default();

            let _hook1 = pool.push(future::ready(1));
            let (sender, receiver) = oneshot::channel();
            let hook2 = pool.push(async move {
                receiver.await.unwrap();
                2
            });
            let _hook3 = pool.push(future::ready(3));

            assert_eq!(pool.len(), 3);

            drop(hook2);

            let mut results = Vec::new();
            for _ in 0..3 {
                let result = pool.next_completed().await;
                results.push(result);
            }

            let successful: Vec<_> = results.iter().filter_map(|r| r.as_ref().ok()).collect();
            let aborted: Vec<_> = results.iter().filter(|r| r.is_err()).collect();

            assert_eq!(successful.len(), 2);
            assert_eq!(aborted.len(), 1);
            assert!(successful.contains(&&1));
            assert!(successful.contains(&&3));
            assert!(pool.is_empty());

            let _ = sender.send(());
        });
    }
}