commonware_utils/
futures.rs

1//! Utilities for working with futures.
2
3use futures::{
4    channel::oneshot,
5    future::{self, AbortHandle, Abortable, Aborted},
6    stream::{FuturesUnordered, SelectNextSome},
7    StreamExt,
8};
9use std::{future::Future, pin::Pin, task::Poll};
10
/// Boxed, pinned, type-erased future stored inside a [`Pool`].
///
/// Every stored future must be `Send` and produce a value of type `T`.
type PooledFuture<T> = Pin<Box<dyn Future<Output = T> + Send>>;
13
14/// An unordered pool of futures.
15///
16/// Futures can be added to the pool, and removed from the pool as they resolve.
17///
18/// **Note:** This pool is not thread-safe and should not be used across threads without external
19/// synchronization.
20pub struct Pool<T> {
21    pool: FuturesUnordered<PooledFuture<T>>,
22}
23
24impl<T: Send> Default for Pool<T> {
25    fn default() -> Self {
26        // Insert a dummy future (that never resolves) to prevent the stream from being empty.
27        // Else, the `select_next_some()` function returns `None` instantly.
28        let pool = FuturesUnordered::new();
29        pool.push(Self::create_dummy_future());
30        Self { pool }
31    }
32}
33
34impl<T: Send> Pool<T> {
35    /// Returns the number of futures in the pool.
36    pub fn len(&self) -> usize {
37        // Subtract the dummy future.
38        self.pool.len().checked_sub(1).unwrap()
39    }
40
41    /// Returns `true` if the pool is empty.
42    pub fn is_empty(&self) -> bool {
43        self.len() == 0
44    }
45
46    /// Adds a future to the pool.
47    ///
48    /// The future must be `'static` and `Send` to ensure it can be safely stored and executed.
49    pub fn push(&mut self, future: impl Future<Output = T> + Send + 'static) {
50        self.pool.push(Box::pin(future));
51    }
52
53    /// Returns a futures that resolves to the next future in the pool that resolves.
54    ///
55    /// If the pool is empty, the future will never resolve.
56    pub fn next_completed(&mut self) -> SelectNextSome<'_, FuturesUnordered<PooledFuture<T>>> {
57        self.pool.select_next_some()
58    }
59
60    /// Cancels all futures in the pool.
61    ///
62    /// Excludes the dummy future.
63    pub fn cancel_all(&mut self) {
64        self.pool.clear();
65        self.pool.push(Self::create_dummy_future());
66    }
67
68    /// Creates a dummy future that never resolves.
69    fn create_dummy_future() -> PooledFuture<T> {
70        Box::pin(async { future::pending::<T>().await })
71    }
72}
73
74/// A handle that can be used to abort a specific future in an [AbortablePool].
75///
76/// When the aborter is dropped, the associated future is aborted.
77pub struct Aborter {
78    inner: AbortHandle,
79}
80
81impl Drop for Aborter {
82    fn drop(&mut self) {
83        self.inner.abort();
84    }
85}
86
87/// A future type that can be used in [AbortablePool].
88type AbortablePooledFuture<T> = Pin<Box<dyn Future<Output = Result<T, Aborted>> + Send>>;
89
90/// An unordered pool of futures that can be individually aborted.
91///
92/// Each future added to the pool returns an [Aborter]. When the aborter is dropped,
93/// the associated future is aborted.
94///
95/// **Note:** This pool is not thread-safe and should not be used across threads without external
96/// synchronization.
97pub struct AbortablePool<T> {
98    pool: FuturesUnordered<AbortablePooledFuture<T>>,
99}
100
101impl<T: Send> Default for AbortablePool<T> {
102    fn default() -> Self {
103        // Insert a dummy future (that never resolves) to prevent the stream from being empty.
104        // Else, the `select_next_some()` function returns `None` instantly.
105        let pool = FuturesUnordered::new();
106        pool.push(Self::create_dummy_future());
107        Self { pool }
108    }
109}
110
111impl<T: Send> AbortablePool<T> {
112    /// Returns the number of futures in the pool.
113    pub fn len(&self) -> usize {
114        // Subtract the dummy future.
115        self.pool.len().checked_sub(1).unwrap()
116    }
117
118    /// Returns `true` if the pool is empty.
119    pub fn is_empty(&self) -> bool {
120        self.len() == 0
121    }
122
123    /// Adds a future to the pool and returns an [Aborter] that can be used to abort it.
124    ///
125    /// The future must be `'static` and `Send` to ensure it can be safely stored and executed.
126    /// When the returned [Aborter] is dropped, the future will be aborted.
127    pub fn push(&mut self, future: impl Future<Output = T> + Send + 'static) -> Aborter {
128        let (handle, registration) = AbortHandle::new_pair();
129        let abortable_future = Abortable::new(future, registration);
130        self.pool.push(Box::pin(abortable_future));
131        Aborter { inner: handle }
132    }
133
134    /// Returns a future that resolves to the next future in the pool that resolves.
135    ///
136    /// If the pool is empty, the future will never resolve.
137    /// Returns `Ok(T)` for successful completion or `Err(Aborted)` for aborted futures.
138    pub fn next_completed(
139        &mut self,
140    ) -> SelectNextSome<'_, FuturesUnordered<AbortablePooledFuture<T>>> {
141        self.pool.select_next_some()
142    }
143
144    /// Creates a dummy future that never resolves.
145    fn create_dummy_future() -> AbortablePooledFuture<T> {
146        Box::pin(async { Ok(future::pending::<T>().await) })
147    }
148}
149
150/// A future that resolves when a [oneshot::Receiver] is dropped.
151///
152/// This future completes when the receiver end of the channel is dropped,
153/// allowing the caller to detect when the other side is no longer interested
154/// in the result.
155pub struct Closed<'a, T> {
156    sender: &'a mut oneshot::Sender<T>,
157}
158
159impl<'a, T> Closed<'a, T> {
160    /// Creates a new future that resolves when the receiver is dropped.
161    pub fn new(sender: &'a mut oneshot::Sender<T>) -> Self {
162        Self { sender }
163    }
164}
165
166impl<T> Future for Closed<'_, T> {
167    type Output = ();
168
169    fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
170        match self.sender.poll_canceled(cx) {
171            Poll::Ready(()) => Poll::Ready(()),
172            Poll::Pending => Poll::Pending,
173        }
174    }
175}
176
177/// Extension trait to detect when a [oneshot::Receiver] is dropped.
178pub trait ClosedExt<T> {
179    /// Returns a future that resolves when the receiver is dropped.
180    ///
181    /// # Examples
182    ///
183    /// ```
184    /// use futures::channel::oneshot;
185    /// use commonware_utils::futures::ClosedExt;
186    ///
187    /// # futures::executor::block_on(async {
188    /// let (mut tx, rx) = oneshot::channel::<i32>();
189    ///
190    /// let closed = tx.closed();
191    /// drop(rx);
192    /// closed.await;
193    /// # });
194    /// ```
195    fn closed(&mut self) -> Closed<'_, T>;
196}
197
198impl<T> ClosedExt<T> for oneshot::Sender<T> {
199    fn closed(&mut self) -> Closed<'_, T> {
200        Closed::new(self)
201    }
202}
203
#[cfg(test)]
mod tests {
    use super::*;
    use futures::{
        channel::oneshot,
        executor::block_on,
        future::{self, select, Either},
        pin_mut, FutureExt,
    };
    use std::{
        sync::{
            atomic::{AtomicBool, Ordering},
            Arc,
        },
        thread,
        time::Duration,
    };

    /// A future that resolves after a given duration.
    ///
    /// A helper OS thread sleeps for `duration` and then sends on a oneshot channel. The
    /// returned future completes when the receiving end observes that message.
    fn delay(duration: Duration) -> impl Future<Output = ()> {
        let (sender, receiver) = oneshot::channel();
        thread::spawn(move || {
            thread::sleep(duration);
            sender.send(()).unwrap();
        });
        // Discard the receiver's `Result` so the future's output is `()`.
        receiver.map(|_| ())
    }

    /// A freshly created pool reports zero futures (the placeholder is not counted).
    #[test]
    fn test_initialization() {
        let pool = Pool::<i32>::default();
        assert_eq!(pool.len(), 0);
        assert!(pool.is_empty());
    }

    /// Awaiting an empty pool stays pending: a 100ms timer must win the race.
    #[test]
    fn test_dummy_future_doesnt_resolve() {
        block_on(async {
            let mut pool = Pool::<i32>::default();
            let stream_future = pool.next_completed();
            let timeout_future = async {
                delay(Duration::from_millis(100)).await;
            };
            pin_mut!(stream_future);
            pin_mut!(timeout_future);
            let result = select(stream_future, timeout_future).await;
            match result {
                Either::Left((_, _)) => panic!("Stream resolved unexpectedly"),
                Either::Right((_, _)) => {
                    // Timeout occurred, which is expected
                }
            }
        });
    }

    /// `len`/`is_empty` track pushed futures, even before they are polled.
    #[test]
    fn test_adding_futures() {
        let mut pool = Pool::<i32>::default();
        assert_eq!(pool.len(), 0);
        assert!(pool.is_empty());

        pool.push(async { 42 });
        assert_eq!(pool.len(), 1);
        assert!(!pool.is_empty(),);

        pool.push(async { 43 });
        assert_eq!(pool.len(), 2,);
    }

    /// A ready future is yielded by `next_completed` and then removed from the pool.
    #[test]
    fn test_streaming_resolved_futures() {
        block_on(async move {
            let mut pool = Pool::<i32>::default();
            pool.push(future::ready(42));
            let result = pool.next_completed().await;
            assert_eq!(result, 42,);
            assert!(pool.is_empty(),);
        });
    }

    /// Outputs follow completion order. Future 2 unblocks future 1, which unblocks future 3,
    /// so the expected order is 2, 1, 3 regardless of insertion order.
    #[test]
    fn test_multiple_futures() {
        block_on(async move {
            let mut pool = Pool::<i32>::default();

            // Futures resolve in order of completion, not addition order
            let (finisher_1, finished_1) = oneshot::channel();
            let (finisher_3, finished_3) = oneshot::channel();
            pool.push(async move {
                finished_1.await.unwrap();
                finisher_3.send(()).unwrap();
                1
            });
            pool.push(async move {
                finisher_1.send(()).unwrap();
                2
            });
            pool.push(async move {
                finished_3.await.unwrap();
                3
            });

            let first = pool.next_completed().await;
            assert_eq!(first, 2, "First resolved should be 2");
            let second = pool.next_completed().await;
            assert_eq!(second, 1, "Second resolved should be 1");
            let third = pool.next_completed().await;
            assert_eq!(third, 3, "Third resolved should be 3");
            assert!(pool.is_empty(),);
        });
    }

    /// `cancel_all` drops pending futures (their bodies never finish) and leaves the pool
    /// usable for new futures.
    #[test]
    fn test_cancel_all() {
        block_on(async move {
            let flag = Arc::new(AtomicBool::new(false));
            let flag_clone = flag.clone();
            let mut pool = Pool::<i32>::default();

            // Push a future that will set the flag to true when it resolves.
            let (finisher, finished) = oneshot::channel();
            pool.push(async move {
                finished.await.unwrap();
                flag_clone.store(true, Ordering::SeqCst);
                42
            });
            assert_eq!(pool.len(), 1);

            // Cancel all futures.
            pool.cancel_all();
            assert!(pool.is_empty());
            assert!(!flag.load(Ordering::SeqCst));

            // Send the finisher signal (should be ignored).
            let _ = finisher.send(());

            // Stream should not resolve future after cancellation.
            let stream_future = pool.next_completed();
            let timeout_future = async {
                delay(Duration::from_millis(100)).await;
            };
            pin_mut!(stream_future);
            pin_mut!(timeout_future);
            let result = select(stream_future, timeout_future).await;
            match result {
                Either::Left((_, _)) => panic!("Stream resolved after cancellation"),
                Either::Right((_, _)) => {
                    // Wait for the timeout to trigger.
                }
            }
            assert!(!flag.load(Ordering::SeqCst));

            // Push and await a new future.
            pool.push(future::ready(42));
            assert_eq!(pool.len(), 1);
            let result = pool.next_completed().await;
            assert_eq!(result, 42);
            assert!(pool.is_empty());
        });
    }

    /// Every one of 1000 ready futures is yielded exactly once (checked via their sum).
    #[test]
    fn test_many_futures() {
        block_on(async move {
            let mut pool = Pool::<i32>::default();
            let num_futures = 1000;
            for i in 0..num_futures {
                pool.push(future::ready(i));
            }
            assert_eq!(pool.len(), num_futures as usize);

            let mut sum = 0;
            for _ in 0..num_futures {
                let value = pool.next_completed().await;
                sum += value;
            }
            let expected_sum = (0..num_futures).sum::<i32>();
            assert_eq!(
                sum, expected_sum,
                "Sum of resolved values should match expected"
            );
            assert!(
                pool.is_empty(),
                "Pool should be empty after all futures resolve"
            );
        });
    }

    /// A freshly created abortable pool reports zero futures.
    #[test]
    fn test_abortable_pool_initialization() {
        let pool = AbortablePool::<i32>::default();
        assert_eq!(pool.len(), 0);
        assert!(pool.is_empty());
    }

    /// Pushed futures are counted while their `Aborter`s are kept alive.
    #[test]
    fn test_abortable_pool_adding_futures() {
        let mut pool = AbortablePool::<i32>::default();
        assert_eq!(pool.len(), 0);
        assert!(pool.is_empty());

        let _hook1 = pool.push(async { 42 });
        assert_eq!(pool.len(), 1);
        assert!(!pool.is_empty());

        let _hook2 = pool.push(async { 43 });
        assert_eq!(pool.len(), 2);
    }

    /// A future whose `Aborter` is still alive completes with `Ok`.
    #[test]
    fn test_abortable_pool_successful_completion() {
        block_on(async move {
            let mut pool = AbortablePool::<i32>::default();
            let _hook = pool.push(future::ready(42));
            let result = pool.next_completed().await;
            assert_eq!(result, Ok(42));
            assert!(pool.is_empty());
        });
    }

    /// Dropping the `Aborter` of a blocked future makes it yield `Err(Aborted)`.
    #[test]
    fn test_abortable_pool_drop_abort() {
        block_on(async move {
            let mut pool = AbortablePool::<i32>::default();

            let (sender, receiver) = oneshot::channel();
            let hook = pool.push(async move {
                receiver.await.unwrap();
                42
            });

            drop(hook);

            let result = pool.next_completed().await;
            assert!(result.is_err());
            assert!(pool.is_empty());

            // The future is already gone, so this send has no effect on the pool.
            let _ = sender.send(());
        });
    }

    /// Aborting one of three futures affects only that one; the others still yield `Ok`.
    #[test]
    fn test_abortable_pool_partial_abort() {
        block_on(async move {
            let mut pool = AbortablePool::<i32>::default();

            let _hook1 = pool.push(future::ready(1));
            let (sender, receiver) = oneshot::channel();
            let hook2 = pool.push(async move {
                receiver.await.unwrap();
                2
            });
            let _hook3 = pool.push(future::ready(3));

            assert_eq!(pool.len(), 3);

            drop(hook2);

            let mut results = Vec::new();
            for _ in 0..3 {
                let result = pool.next_completed().await;
                results.push(result);
            }

            let successful: Vec<_> = results.iter().filter_map(|r| r.as_ref().ok()).collect();
            let aborted: Vec<_> = results.iter().filter(|r| r.is_err()).collect();

            assert_eq!(successful.len(), 2);
            assert_eq!(aborted.len(), 1);
            assert!(successful.contains(&&1));
            assert!(successful.contains(&&3));
            assert!(pool.is_empty());

            let _ = sender.send(());
        });
    }

    /// `closed()` resolves once the receiver is dropped.
    #[test]
    fn test_closed_on_receiver_drop() {
        block_on(async {
            let (mut tx, rx) = oneshot::channel::<i32>();

            let closed = tx.closed();
            drop(rx);

            closed.await;
        });
    }

    /// `closed()` stays pending while the receiver is alive: a 500ms timer must win the race.
    #[test]
    fn test_closed_pending_when_receiver_alive() {
        block_on(async {
            let (mut tx, rx) = oneshot::channel::<i32>();

            let closed = tx.closed();
            let timeout = delay(Duration::from_millis(500));

            pin_mut!(closed);
            pin_mut!(timeout);

            match select(closed, timeout).await {
                Either::Left(_) => panic!("Closed resolved while receiver still alive"),
                Either::Right(_) => {}
            }

            drop(rx);
        });
    }

    /// Polling manually: pending before the receiver is dropped, ready afterwards.
    #[test]
    fn test_closed_multiple_polls() {
        block_on(async {
            let (mut tx, rx) = oneshot::channel::<i32>();

            // Setup the closed future
            let closed = tx.closed();
            pin_mut!(closed);

            // Poll the closed future
            let waker = futures::task::noop_waker();
            let mut cx = std::task::Context::from_waker(&waker);
            assert!(closed.as_mut().poll(&mut cx).is_pending());

            // Drop receiver
            drop(rx);

            // Now poll should be ready
            assert!(closed.as_mut().poll(&mut cx).is_ready());
        });
    }
}