Skip to main content

commonware_utils/
futures.rs

1//! Utilities for working with futures.
2
3use core::ops::{Deref, DerefMut};
4use futures::{
5    future::{self, AbortHandle, Abortable, Aborted},
6    stream::{FuturesUnordered, SelectNextSome},
7    StreamExt,
8};
9use pin_project::pin_project;
10use std::{future::Future, pin::Pin, task::Poll};
11
/// Type-erased, heap-pinned future stored inside a `Pool`.
///
/// Boxing lets futures of different concrete types live in the same pool; every stored future
/// must be `Send` and `'static`.
type PooledFuture<T> = Pin<Box<dyn Future<Output = T> + Send + 'static>>;
14
15/// An unordered pool of futures.
16///
17/// Futures can be added to the pool, and removed from the pool as they resolve.
18///
19/// **Note:** This pool is not thread-safe and should not be used across threads without external
20/// synchronization.
21pub struct Pool<T> {
22    pool: FuturesUnordered<PooledFuture<T>>,
23}
24
25impl<T: Send> Default for Pool<T> {
26    fn default() -> Self {
27        // Insert a dummy future (that never resolves) to prevent the stream from being empty.
28        // Else, the `select_next_some()` function returns `None` instantly.
29        let pool = FuturesUnordered::new();
30        pool.push(Self::create_dummy_future());
31        Self { pool }
32    }
33}
34
35impl<T: Send> Pool<T> {
36    /// Returns the number of futures in the pool.
37    pub fn len(&self) -> usize {
38        // Subtract the dummy future.
39        self.pool.len().checked_sub(1).unwrap()
40    }
41
42    /// Returns `true` if the pool is empty.
43    pub fn is_empty(&self) -> bool {
44        self.len() == 0
45    }
46
47    /// Adds a future to the pool.
48    ///
49    /// The future must be `'static` and `Send` to ensure it can be safely stored and executed.
50    pub fn push(&mut self, future: impl Future<Output = T> + Send + 'static) {
51        self.pool.push(Box::pin(future));
52    }
53
54    /// Returns a futures that resolves to the next future in the pool that resolves.
55    ///
56    /// If the pool is empty, the future will never resolve.
57    pub fn next_completed(&mut self) -> SelectNextSome<'_, FuturesUnordered<PooledFuture<T>>> {
58        self.pool.select_next_some()
59    }
60
61    /// Cancels all futures in the pool.
62    ///
63    /// Excludes the dummy future.
64    pub fn cancel_all(&mut self) {
65        self.pool.clear();
66        self.pool.push(Self::create_dummy_future());
67    }
68
69    /// Creates a dummy future that never resolves.
70    fn create_dummy_future() -> PooledFuture<T> {
71        Box::pin(async { future::pending::<T>().await })
72    }
73}
74
75/// A handle that can be used to abort a specific future in an [AbortablePool].
76///
77/// When the aborter is dropped, the associated future is aborted.
78pub struct Aborter {
79    inner: AbortHandle,
80}
81
82impl Drop for Aborter {
83    fn drop(&mut self) {
84        self.inner.abort();
85    }
86}
87
88/// A future type that can be used in [AbortablePool].
89type AbortablePooledFuture<T> = Pin<Box<dyn Future<Output = Result<T, Aborted>> + Send>>;
90
91/// An unordered pool of futures that can be individually aborted.
92///
93/// Each future added to the pool returns an [Aborter]. When the aborter is dropped,
94/// the associated future is aborted.
95///
96/// **Note:** This pool is not thread-safe and should not be used across threads without external
97/// synchronization.
98pub struct AbortablePool<T> {
99    pool: FuturesUnordered<AbortablePooledFuture<T>>,
100}
101
102impl<T: Send> Default for AbortablePool<T> {
103    fn default() -> Self {
104        // Insert a dummy future (that never resolves) to prevent the stream from being empty.
105        // Else, the `select_next_some()` function returns `None` instantly.
106        let pool = FuturesUnordered::new();
107        pool.push(Self::create_dummy_future());
108        Self { pool }
109    }
110}
111
112impl<T: Send> AbortablePool<T> {
113    /// Returns the number of futures in the pool.
114    pub fn len(&self) -> usize {
115        // Subtract the dummy future.
116        self.pool.len().checked_sub(1).unwrap()
117    }
118
119    /// Returns `true` if the pool is empty.
120    pub fn is_empty(&self) -> bool {
121        self.len() == 0
122    }
123
124    /// Adds a future to the pool and returns an [Aborter] that can be used to abort it.
125    ///
126    /// The future must be `'static` and `Send` to ensure it can be safely stored and executed.
127    /// When the returned [Aborter] is dropped, the future will be aborted.
128    pub fn push(&mut self, future: impl Future<Output = T> + Send + 'static) -> Aborter {
129        let (handle, registration) = AbortHandle::new_pair();
130        let abortable_future = Abortable::new(future, registration);
131        self.pool.push(Box::pin(abortable_future));
132        Aborter { inner: handle }
133    }
134
135    /// Returns a future that resolves to the next future in the pool that resolves.
136    ///
137    /// If the pool is empty, the future will never resolve.
138    /// Returns `Ok(T)` for successful completion or `Err(Aborted)` for aborted futures.
139    pub fn next_completed(
140        &mut self,
141    ) -> SelectNextSome<'_, FuturesUnordered<AbortablePooledFuture<T>>> {
142        self.pool.select_next_some()
143    }
144
145    /// Creates a dummy future that never resolves.
146    fn create_dummy_future() -> AbortablePooledFuture<T> {
147        Box::pin(async { Ok(future::pending::<T>().await) })
148    }
149}
150
151/// An optional future that yields [Poll::Pending] when [None]. Useful within `select!` macros,
152/// where a future may be conditionally present.
153///
154/// Not to be confused with [futures::future::OptionFuture], which resolves to [None] immediately
155/// when the inner future is `None`.
156#[pin_project]
157pub struct OptionFuture<F: Future>(#[pin] Option<F>);
158
159impl<F: Future> Default for OptionFuture<F> {
160    fn default() -> Self {
161        Self(None)
162    }
163}
164
165impl<F: Future> From<Option<F>> for OptionFuture<F> {
166    fn from(opt: Option<F>) -> Self {
167        Self(opt)
168    }
169}
170
171impl<F: Future> Deref for OptionFuture<F> {
172    type Target = Option<F>;
173
174    fn deref(&self) -> &Self::Target {
175        &self.0
176    }
177}
178
179impl<F: Future> DerefMut for OptionFuture<F> {
180    fn deref_mut(&mut self) -> &mut Self::Target {
181        &mut self.0
182    }
183}
184
185impl<F: Future> Future for OptionFuture<F> {
186    type Output = F::Output;
187
188    fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
189        let this = self.project();
190        this.0
191            .as_pin_mut()
192            .map_or_else(|| Poll::Pending, |fut| fut.poll(cx))
193    }
194}
195
#[cfg(test)]
mod tests {
    use super::*;
    use crate::channel::oneshot;
    use futures::{
        executor::block_on,
        future::{self, select, Either},
        pin_mut,
    };
    use std::{
        sync::{
            atomic::{AtomicBool, Ordering},
            Arc,
        },
        thread,
        time::Duration,
    };

    /// A future that resolves after a given duration.
    ///
    /// The timer runs on its own thread, so the future completes even under the
    /// single-threaded `block_on` executor.
    fn delay(duration: Duration) -> impl Future<Output = ()> {
        let (sender, receiver) = oneshot::channel();
        thread::spawn(move || {
            thread::sleep(duration);
            // The receiver may already be gone (e.g. the other side of a `select` won); that is
            // not an error for a timer, so don't panic the helper thread.
            let _ = sender.send(());
        });
        async move {
            let _ = receiver.await;
        }
    }

    #[test]
    fn test_initialization() {
        let pool = Pool::<i32>::default();
        assert_eq!(pool.len(), 0);
        assert!(pool.is_empty());
    }

    #[test]
    fn test_dummy_future_doesnt_resolve() {
        block_on(async {
            let mut pool = Pool::<i32>::default();
            let stream_future = pool.next_completed();
            let timeout_future = async {
                delay(Duration::from_millis(100)).await;
            };
            pin_mut!(stream_future);
            pin_mut!(timeout_future);
            let result = select(stream_future, timeout_future).await;
            match result {
                Either::Left((_, _)) => panic!("Stream resolved unexpectedly"),
                Either::Right((_, _)) => {
                    // Timeout occurred, which is expected
                }
            }
        });
    }

    #[test]
    fn test_adding_futures() {
        let mut pool = Pool::<i32>::default();
        assert_eq!(pool.len(), 0);
        assert!(pool.is_empty());

        pool.push(async { 42 });
        assert_eq!(pool.len(), 1);
        assert!(!pool.is_empty());

        pool.push(async { 43 });
        assert_eq!(pool.len(), 2);
    }

    #[test]
    fn test_streaming_resolved_futures() {
        block_on(async move {
            let mut pool = Pool::<i32>::default();
            pool.push(future::ready(42));
            let result = pool.next_completed().await;
            assert_eq!(result, 42);
            assert!(pool.is_empty());
        });
    }

    #[test]
    fn test_multiple_futures() {
        block_on(async move {
            let mut pool = Pool::<i32>::default();

            // Futures resolve in order of completion, not addition order
            let (finisher_1, finished_1) = oneshot::channel();
            let (finisher_3, finished_3) = oneshot::channel();
            pool.push(async move {
                finished_1.await.unwrap();
                finisher_3.send(()).unwrap();
                1
            });
            pool.push(async move {
                finisher_1.send(()).unwrap();
                2
            });
            pool.push(async move {
                finished_3.await.unwrap();
                3
            });

            let first = pool.next_completed().await;
            assert_eq!(first, 2, "First resolved should be 2");
            let second = pool.next_completed().await;
            assert_eq!(second, 1, "Second resolved should be 1");
            let third = pool.next_completed().await;
            assert_eq!(third, 3, "Third resolved should be 3");
            assert!(pool.is_empty());
        });
    }

    #[test]
    fn test_cancel_all() {
        block_on(async move {
            let flag = Arc::new(AtomicBool::new(false));
            let flag_clone = flag.clone();
            let mut pool = Pool::<i32>::default();

            // Push a future that will set the flag to true when it resolves.
            let (finisher, finished) = oneshot::channel();
            pool.push(async move {
                finished.await.unwrap();
                flag_clone.store(true, Ordering::SeqCst);
                42
            });
            assert_eq!(pool.len(), 1);

            // Cancel all futures.
            pool.cancel_all();
            assert!(pool.is_empty());
            assert!(!flag.load(Ordering::SeqCst));

            // Send the finisher signal (should be ignored).
            let _ = finisher.send(());

            // Stream should not resolve future after cancellation.
            let stream_future = pool.next_completed();
            let timeout_future = async {
                delay(Duration::from_millis(100)).await;
            };
            pin_mut!(stream_future);
            pin_mut!(timeout_future);
            let result = select(stream_future, timeout_future).await;
            match result {
                Either::Left((_, _)) => panic!("Stream resolved after cancellation"),
                Either::Right((_, _)) => {
                    // Wait for the timeout to trigger.
                }
            }
            assert!(!flag.load(Ordering::SeqCst));

            // Push and await a new future.
            pool.push(future::ready(42));
            assert_eq!(pool.len(), 1);
            let result = pool.next_completed().await;
            assert_eq!(result, 42);
            assert!(pool.is_empty());
        });
    }

    #[test]
    fn test_many_futures() {
        block_on(async move {
            let mut pool = Pool::<i32>::default();
            let num_futures = 1000;
            for i in 0..num_futures {
                pool.push(future::ready(i));
            }
            assert_eq!(pool.len(), num_futures as usize);

            let mut sum = 0;
            for _ in 0..num_futures {
                let value = pool.next_completed().await;
                sum += value;
            }
            let expected_sum = (0..num_futures).sum::<i32>();
            assert_eq!(
                sum, expected_sum,
                "Sum of resolved values should match expected"
            );
            assert!(
                pool.is_empty(),
                "Pool should be empty after all futures resolve"
            );
        });
    }

    #[test]
    fn test_abortable_pool_initialization() {
        let pool = AbortablePool::<i32>::default();
        assert_eq!(pool.len(), 0);
        assert!(pool.is_empty());
    }

    #[test]
    fn test_abortable_pool_dummy_future_doesnt_resolve() {
        let mut pool = AbortablePool::<i32>::default();
        let waker = futures::task::noop_waker();
        let mut cx = std::task::Context::from_waker(&waker);
        let mut next = pool.next_completed();
        assert!(Pin::new(&mut next).poll(&mut cx).is_pending());
    }

    #[test]
    fn test_abortable_pool_adding_futures() {
        let mut pool = AbortablePool::<i32>::default();
        assert_eq!(pool.len(), 0);
        assert!(pool.is_empty());

        let _hook1 = pool.push(async { 42 });
        assert_eq!(pool.len(), 1);
        assert!(!pool.is_empty());

        let _hook2 = pool.push(async { 43 });
        assert_eq!(pool.len(), 2);
    }

    #[test]
    fn test_abortable_pool_successful_completion() {
        block_on(async move {
            let mut pool = AbortablePool::<i32>::default();
            let _hook = pool.push(future::ready(42));
            let result = pool.next_completed().await;
            assert_eq!(result, Ok(42));
            assert!(pool.is_empty());
        });
    }

    #[test]
    fn test_abortable_pool_drop_abort() {
        block_on(async move {
            let mut pool = AbortablePool::<i32>::default();

            let (sender, receiver) = oneshot::channel();
            let hook = pool.push(async move {
                receiver.await.unwrap();
                42
            });

            drop(hook);

            let result = pool.next_completed().await;
            assert!(result.is_err());
            assert!(pool.is_empty());

            let _ = sender.send(());
        });
    }

    #[test]
    fn test_abortable_pool_partial_abort() {
        block_on(async move {
            let mut pool = AbortablePool::<i32>::default();

            let _hook1 = pool.push(future::ready(1));
            let (sender, receiver) = oneshot::channel();
            let hook2 = pool.push(async move {
                receiver.await.unwrap();
                2
            });
            let _hook3 = pool.push(future::ready(3));

            assert_eq!(pool.len(), 3);

            drop(hook2);

            let mut results = Vec::new();
            for _ in 0..3 {
                let result = pool.next_completed().await;
                results.push(result);
            }

            let successful: Vec<_> = results.iter().filter_map(|r| r.as_ref().ok()).collect();
            let aborted: Vec<_> = results.iter().filter(|r| r.is_err()).collect();

            assert_eq!(successful.len(), 2);
            assert_eq!(aborted.len(), 1);
            assert!(successful.contains(&&1));
            assert!(successful.contains(&&3));
            assert!(pool.is_empty());

            let _ = sender.send(());
        });
    }

    #[test]
    fn test_option_future() {
        block_on(async {
            let option_future = OptionFuture::<oneshot::Receiver<()>>::from(None);
            pin_mut!(option_future);

            let waker = futures::task::noop_waker();
            let mut cx = std::task::Context::from_waker(&waker);
            assert!(option_future.poll(&mut cx).is_pending());

            let (tx, rx) = oneshot::channel();
            let option_future: OptionFuture<_> = Some(rx).into();
            pin_mut!(option_future);

            tx.send(1usize).unwrap();
            assert_eq!(option_future.poll(&mut cx), Poll::Ready(Ok(1)));
        });
    }

    #[test]
    fn test_option_future_deref() {
        let mut option_future = OptionFuture::<oneshot::Receiver<u8>>::default();
        assert!(option_future.is_none());

        // Install the inner future through `DerefMut`.
        let (tx, rx) = oneshot::channel();
        *option_future = Some(rx);
        assert!(option_future.is_some());

        tx.send(7u8).unwrap();
        let waker = futures::task::noop_waker();
        let mut cx = std::task::Context::from_waker(&waker);
        pin_mut!(option_future);
        assert_eq!(option_future.poll(&mut cx), Poll::Ready(Ok(7)));
    }
}