commonware_utils/
futures.rs

1//! Utilities for working with futures.
2
3use core::ops::{Deref, DerefMut};
4use futures::{
5    channel::oneshot,
6    future::{self, AbortHandle, Abortable, Aborted},
7    stream::{FuturesUnordered, SelectNextSome},
8    StreamExt,
9};
10use pin_project::pin_project;
11use std::{future::Future, pin::Pin, task::Poll};
12
13/// A future type that can be used in `Pool`.
14type PooledFuture<T> = Pin<Box<dyn Future<Output = T> + Send>>;
15
16/// An unordered pool of futures.
17///
18/// Futures can be added to the pool, and removed from the pool as they resolve.
19///
20/// **Note:** This pool is not thread-safe and should not be used across threads without external
21/// synchronization.
22pub struct Pool<T> {
23    pool: FuturesUnordered<PooledFuture<T>>,
24}
25
26impl<T: Send> Default for Pool<T> {
27    fn default() -> Self {
28        // Insert a dummy future (that never resolves) to prevent the stream from being empty.
29        // Else, the `select_next_some()` function returns `None` instantly.
30        let pool = FuturesUnordered::new();
31        pool.push(Self::create_dummy_future());
32        Self { pool }
33    }
34}
35
36impl<T: Send> Pool<T> {
37    /// Returns the number of futures in the pool.
38    pub fn len(&self) -> usize {
39        // Subtract the dummy future.
40        self.pool.len().checked_sub(1).unwrap()
41    }
42
43    /// Returns `true` if the pool is empty.
44    pub fn is_empty(&self) -> bool {
45        self.len() == 0
46    }
47
48    /// Adds a future to the pool.
49    ///
50    /// The future must be `'static` and `Send` to ensure it can be safely stored and executed.
51    pub fn push(&mut self, future: impl Future<Output = T> + Send + 'static) {
52        self.pool.push(Box::pin(future));
53    }
54
55    /// Returns a futures that resolves to the next future in the pool that resolves.
56    ///
57    /// If the pool is empty, the future will never resolve.
58    pub fn next_completed(&mut self) -> SelectNextSome<'_, FuturesUnordered<PooledFuture<T>>> {
59        self.pool.select_next_some()
60    }
61
62    /// Cancels all futures in the pool.
63    ///
64    /// Excludes the dummy future.
65    pub fn cancel_all(&mut self) {
66        self.pool.clear();
67        self.pool.push(Self::create_dummy_future());
68    }
69
70    /// Creates a dummy future that never resolves.
71    fn create_dummy_future() -> PooledFuture<T> {
72        Box::pin(async { future::pending::<T>().await })
73    }
74}
75
76/// A handle that can be used to abort a specific future in an [AbortablePool].
77///
78/// When the aborter is dropped, the associated future is aborted.
79pub struct Aborter {
80    inner: AbortHandle,
81}
82
83impl Drop for Aborter {
84    fn drop(&mut self) {
85        self.inner.abort();
86    }
87}
88
89/// A future type that can be used in [AbortablePool].
90type AbortablePooledFuture<T> = Pin<Box<dyn Future<Output = Result<T, Aborted>> + Send>>;
91
92/// An unordered pool of futures that can be individually aborted.
93///
94/// Each future added to the pool returns an [Aborter]. When the aborter is dropped,
95/// the associated future is aborted.
96///
97/// **Note:** This pool is not thread-safe and should not be used across threads without external
98/// synchronization.
99pub struct AbortablePool<T> {
100    pool: FuturesUnordered<AbortablePooledFuture<T>>,
101}
102
103impl<T: Send> Default for AbortablePool<T> {
104    fn default() -> Self {
105        // Insert a dummy future (that never resolves) to prevent the stream from being empty.
106        // Else, the `select_next_some()` function returns `None` instantly.
107        let pool = FuturesUnordered::new();
108        pool.push(Self::create_dummy_future());
109        Self { pool }
110    }
111}
112
113impl<T: Send> AbortablePool<T> {
114    /// Returns the number of futures in the pool.
115    pub fn len(&self) -> usize {
116        // Subtract the dummy future.
117        self.pool.len().checked_sub(1).unwrap()
118    }
119
120    /// Returns `true` if the pool is empty.
121    pub fn is_empty(&self) -> bool {
122        self.len() == 0
123    }
124
125    /// Adds a future to the pool and returns an [Aborter] that can be used to abort it.
126    ///
127    /// The future must be `'static` and `Send` to ensure it can be safely stored and executed.
128    /// When the returned [Aborter] is dropped, the future will be aborted.
129    pub fn push(&mut self, future: impl Future<Output = T> + Send + 'static) -> Aborter {
130        let (handle, registration) = AbortHandle::new_pair();
131        let abortable_future = Abortable::new(future, registration);
132        self.pool.push(Box::pin(abortable_future));
133        Aborter { inner: handle }
134    }
135
136    /// Returns a future that resolves to the next future in the pool that resolves.
137    ///
138    /// If the pool is empty, the future will never resolve.
139    /// Returns `Ok(T)` for successful completion or `Err(Aborted)` for aborted futures.
140    pub fn next_completed(
141        &mut self,
142    ) -> SelectNextSome<'_, FuturesUnordered<AbortablePooledFuture<T>>> {
143        self.pool.select_next_some()
144    }
145
146    /// Creates a dummy future that never resolves.
147    fn create_dummy_future() -> AbortablePooledFuture<T> {
148        Box::pin(async { Ok(future::pending::<T>().await) })
149    }
150}
151
152/// A future that resolves when a [oneshot::Receiver] is dropped.
153///
154/// This future completes when the receiver end of the channel is dropped,
155/// allowing the caller to detect when the other side is no longer interested
156/// in the result.
157pub struct Closed<'a, T> {
158    sender: &'a mut oneshot::Sender<T>,
159}
160
161impl<'a, T> Closed<'a, T> {
162    /// Creates a new future that resolves when the receiver is dropped.
163    pub const fn new(sender: &'a mut oneshot::Sender<T>) -> Self {
164        Self { sender }
165    }
166}
167
168impl<T> Future for Closed<'_, T> {
169    type Output = ();
170
171    fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
172        match self.sender.poll_canceled(cx) {
173            Poll::Ready(()) => Poll::Ready(()),
174            Poll::Pending => Poll::Pending,
175        }
176    }
177}
178
179/// Extension trait to detect when a [oneshot::Receiver] is dropped.
180pub trait ClosedExt<T> {
181    /// Returns a future that resolves when the receiver is dropped.
182    ///
183    /// # Examples
184    ///
185    /// ```
186    /// use futures::channel::oneshot;
187    /// use commonware_utils::futures::ClosedExt;
188    ///
189    /// # futures::executor::block_on(async {
190    /// let (mut tx, rx) = oneshot::channel::<i32>();
191    ///
192    /// let closed = tx.closed();
193    /// drop(rx);
194    /// closed.await;
195    /// # });
196    /// ```
197    fn closed(&mut self) -> Closed<'_, T>;
198}
199
200impl<T> ClosedExt<T> for oneshot::Sender<T> {
201    fn closed(&mut self) -> Closed<'_, T> {
202        Closed::new(self)
203    }
204}
205
206/// An optional future that yields [Poll::Pending] when [None]. Useful within `select!` macros,
207/// where a future may be conditionally present.
208///
209/// Not to be confused with [futures::future::OptionFuture], which resolves to [None] immediately
210/// when the inner future is `None`.
211#[pin_project]
212pub struct OptionFuture<F: Future>(#[pin] Option<F>);
213
214impl<F: Future> Default for OptionFuture<F> {
215    fn default() -> Self {
216        Self(None)
217    }
218}
219
220impl<F: Future> From<Option<F>> for OptionFuture<F> {
221    fn from(opt: Option<F>) -> Self {
222        Self(opt)
223    }
224}
225
226impl<F: Future> Deref for OptionFuture<F> {
227    type Target = Option<F>;
228
229    fn deref(&self) -> &Self::Target {
230        &self.0
231    }
232}
233
234impl<F: Future> DerefMut for OptionFuture<F> {
235    fn deref_mut(&mut self) -> &mut Self::Target {
236        &mut self.0
237    }
238}
239
240impl<F: Future> Future for OptionFuture<F> {
241    type Output = F::Output;
242
243    fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
244        let this = self.project();
245        this.0
246            .as_pin_mut()
247            .map_or_else(|| Poll::Pending, |fut| fut.poll(cx))
248    }
249}
250
#[cfg(test)]
mod tests {
    use super::*;
    use futures::{
        channel::oneshot,
        executor::block_on,
        future::{self, select, Either},
        pin_mut, FutureExt,
    };
    use std::{
        sync::{
            atomic::{AtomicBool, Ordering},
            Arc,
        },
        thread,
        time::Duration,
    };

    /// A future that resolves after a given duration.
    ///
    /// The sleep runs on a helper thread so the single-threaded `block_on` executor is not
    /// blocked.
    fn delay(duration: Duration) -> impl Future<Output = ()> {
        let (sender, receiver) = oneshot::channel();
        thread::spawn(move || {
            thread::sleep(duration);
            // The receiver may already be gone (e.g. the test failed first). A failed send is
            // not an error for this helper, so don't panic the helper thread.
            let _ = sender.send(());
        });
        receiver.map(|_| ())
    }

    #[test]
    fn test_initialization() {
        let pool = Pool::<i32>::default();
        assert_eq!(pool.len(), 0);
        assert!(pool.is_empty());
    }

    #[test]
    fn test_dummy_future_doesnt_resolve() {
        block_on(async {
            let mut pool = Pool::<i32>::default();
            let stream_future = pool.next_completed();
            let timeout_future = async {
                delay(Duration::from_millis(100)).await;
            };
            pin_mut!(stream_future);
            pin_mut!(timeout_future);
            let result = select(stream_future, timeout_future).await;
            match result {
                Either::Left((_, _)) => panic!("Stream resolved unexpectedly"),
                Either::Right((_, _)) => {
                    // Timeout occurred, which is expected
                }
            }
        });
    }

    #[test]
    fn test_adding_futures() {
        let mut pool = Pool::<i32>::default();
        assert_eq!(pool.len(), 0);
        assert!(pool.is_empty());

        pool.push(async { 42 });
        assert_eq!(pool.len(), 1);
        assert!(!pool.is_empty());

        pool.push(async { 43 });
        assert_eq!(pool.len(), 2);
    }

    #[test]
    fn test_streaming_resolved_futures() {
        block_on(async move {
            let mut pool = Pool::<i32>::default();
            pool.push(future::ready(42));
            let result = pool.next_completed().await;
            assert_eq!(result, 42);
            assert!(pool.is_empty());
        });
    }

    #[test]
    fn test_multiple_futures() {
        block_on(async move {
            let mut pool = Pool::<i32>::default();

            // Futures resolve in order of completion, not addition order
            let (finisher_1, finished_1) = oneshot::channel();
            let (finisher_3, finished_3) = oneshot::channel();
            pool.push(async move {
                finished_1.await.unwrap();
                finisher_3.send(()).unwrap();
                1
            });
            pool.push(async move {
                finisher_1.send(()).unwrap();
                2
            });
            pool.push(async move {
                finished_3.await.unwrap();
                3
            });

            let first = pool.next_completed().await;
            assert_eq!(first, 2, "First resolved should be 2");
            let second = pool.next_completed().await;
            assert_eq!(second, 1, "Second resolved should be 1");
            let third = pool.next_completed().await;
            assert_eq!(third, 3, "Third resolved should be 3");
            assert!(pool.is_empty());
        });
    }

    #[test]
    fn test_cancel_all() {
        block_on(async move {
            let flag = Arc::new(AtomicBool::new(false));
            let flag_clone = flag.clone();
            let mut pool = Pool::<i32>::default();

            // Push a future that will set the flag to true when it resolves.
            let (finisher, finished) = oneshot::channel();
            pool.push(async move {
                finished.await.unwrap();
                flag_clone.store(true, Ordering::SeqCst);
                42
            });
            assert_eq!(pool.len(), 1);

            // Cancel all futures.
            pool.cancel_all();
            assert!(pool.is_empty());
            assert!(!flag.load(Ordering::SeqCst));

            // Send the finisher signal (should be ignored).
            let _ = finisher.send(());

            // Stream should not resolve future after cancellation.
            let stream_future = pool.next_completed();
            let timeout_future = async {
                delay(Duration::from_millis(100)).await;
            };
            pin_mut!(stream_future);
            pin_mut!(timeout_future);
            let result = select(stream_future, timeout_future).await;
            match result {
                Either::Left((_, _)) => panic!("Stream resolved after cancellation"),
                Either::Right((_, _)) => {
                    // Wait for the timeout to trigger.
                }
            }
            assert!(!flag.load(Ordering::SeqCst));

            // Push and await a new future.
            pool.push(future::ready(42));
            assert_eq!(pool.len(), 1);
            let result = pool.next_completed().await;
            assert_eq!(result, 42);
            assert!(pool.is_empty());
        });
    }

    #[test]
    fn test_many_futures() {
        block_on(async move {
            let mut pool = Pool::<i32>::default();
            let num_futures = 1000;
            for i in 0..num_futures {
                pool.push(future::ready(i));
            }
            assert_eq!(pool.len(), num_futures as usize);

            let mut sum = 0;
            for _ in 0..num_futures {
                let value = pool.next_completed().await;
                sum += value;
            }
            let expected_sum = (0..num_futures).sum::<i32>();
            assert_eq!(
                sum, expected_sum,
                "Sum of resolved values should match expected"
            );
            assert!(
                pool.is_empty(),
                "Pool should be empty after all futures resolve"
            );
        });
    }

    #[test]
    fn test_abortable_pool_initialization() {
        let pool = AbortablePool::<i32>::default();
        assert_eq!(pool.len(), 0);
        assert!(pool.is_empty());
    }

    #[test]
    fn test_abortable_pool_adding_futures() {
        let mut pool = AbortablePool::<i32>::default();
        assert_eq!(pool.len(), 0);
        assert!(pool.is_empty());

        let _hook1 = pool.push(async { 42 });
        assert_eq!(pool.len(), 1);
        assert!(!pool.is_empty());

        let _hook2 = pool.push(async { 43 });
        assert_eq!(pool.len(), 2);
    }

    #[test]
    fn test_abortable_pool_successful_completion() {
        block_on(async move {
            let mut pool = AbortablePool::<i32>::default();
            let _hook = pool.push(future::ready(42));
            let result = pool.next_completed().await;
            assert_eq!(result, Ok(42));
            assert!(pool.is_empty());
        });
    }

    #[test]
    fn test_abortable_pool_drop_abort() {
        block_on(async move {
            let mut pool = AbortablePool::<i32>::default();

            let (sender, receiver) = oneshot::channel();
            let hook = pool.push(async move {
                receiver.await.unwrap();
                42
            });

            // Dropping the aborter aborts the future; the pool then yields `Err(Aborted)`.
            drop(hook);

            let result = pool.next_completed().await;
            assert!(result.is_err());
            assert!(pool.is_empty());

            let _ = sender.send(());
        });
    }

    #[test]
    fn test_abortable_pool_partial_abort() {
        block_on(async move {
            let mut pool = AbortablePool::<i32>::default();

            let _hook1 = pool.push(future::ready(1));
            let (sender, receiver) = oneshot::channel();
            let hook2 = pool.push(async move {
                receiver.await.unwrap();
                2
            });
            let _hook3 = pool.push(future::ready(3));

            assert_eq!(pool.len(), 3);

            drop(hook2);

            let mut results = Vec::new();
            for _ in 0..3 {
                let result = pool.next_completed().await;
                results.push(result);
            }

            let successful: Vec<_> = results.iter().filter_map(|r| r.as_ref().ok()).collect();
            let aborted: Vec<_> = results.iter().filter(|r| r.is_err()).collect();

            assert_eq!(successful.len(), 2);
            assert_eq!(aborted.len(), 1);
            assert!(successful.contains(&&1));
            assert!(successful.contains(&&3));
            assert!(pool.is_empty());

            let _ = sender.send(());
        });
    }

    #[test]
    fn test_closed_on_receiver_drop() {
        block_on(async {
            let (mut tx, rx) = oneshot::channel::<i32>();

            let closed = tx.closed();
            drop(rx);

            closed.await;
        });
    }

    #[test]
    fn test_closed_pending_when_receiver_alive() {
        block_on(async {
            let (mut tx, rx) = oneshot::channel::<i32>();

            let closed = tx.closed();
            // Same timeout as the other pending-checks; long enough to observe non-resolution.
            let timeout = delay(Duration::from_millis(100));

            pin_mut!(closed);
            pin_mut!(timeout);

            match select(closed, timeout).await {
                Either::Left(_) => panic!("Closed resolved while receiver still alive"),
                Either::Right(_) => {}
            }

            drop(rx);
        });
    }

    #[test]
    fn test_closed_multiple_polls() {
        block_on(async {
            let (mut tx, rx) = oneshot::channel::<i32>();

            // Setup the closed future
            let closed = tx.closed();
            pin_mut!(closed);

            // Poll the closed future
            let waker = futures::task::noop_waker();
            let mut cx = std::task::Context::from_waker(&waker);
            assert!(closed.as_mut().poll(&mut cx).is_pending());

            // Drop receiver
            drop(rx);

            // Now poll should be ready
            assert!(closed.as_mut().poll(&mut cx).is_ready());
        });
    }

    #[test]
    fn test_option_future() {
        block_on(async {
            // `None` stays pending.
            let option_future = OptionFuture::<oneshot::Receiver<()>>::from(None);
            pin_mut!(option_future);

            let waker = futures::task::noop_waker();
            let mut cx = std::task::Context::from_waker(&waker);
            assert!(option_future.poll(&mut cx).is_pending());

            // `Some` forwards to the inner future.
            let (tx, rx) = oneshot::channel();
            let option_future: OptionFuture<_> = Some(rx).into();
            pin_mut!(option_future);

            tx.send(1usize).unwrap();
            assert_eq!(option_future.poll(&mut cx), Poll::Ready(Ok(1)));
        });
    }
}