commonware_utils/
futures.rs

1//! Utilities for working with futures.
2
3use futures::{
4    future,
5    stream::{FuturesUnordered, SelectNextSome},
6    StreamExt,
7};
8use std::{future::Future, pin::Pin};
9
/// Boxed, pinned, type-erased future stored inside a [`Pool`].
///
/// Pooled futures must be `Send` and `'static`. The `'static` bound is written out here for
/// clarity; it is also the default object lifetime for `Box<dyn ...>`.
type PooledFuture<T> = Pin<Box<dyn Future<Output = T> + Send + 'static>>;
12
13/// An unordered pool of futures.
14///
15/// Futures can be added to the pool, and removed from the pool as they resolve.
16///
17/// **Note:** This pool is not thread-safe and should not be used across threads without external
18/// synchronization.
19pub struct Pool<T> {
20    pool: FuturesUnordered<PooledFuture<T>>,
21}
22
23impl<T: Send> Default for Pool<T> {
24    fn default() -> Self {
25        // Insert a dummy future (that never resolves) to prevent the stream from being empty.
26        // Else, the `select_next_some()` function returns `None` instantly.
27        let pool = FuturesUnordered::new();
28        pool.push(Self::create_dummy_future());
29        Self { pool }
30    }
31}
32
33impl<T: Send> Pool<T> {
34    /// Returns the number of futures in the pool.
35    pub fn len(&self) -> usize {
36        // Subtract the dummy future.
37        self.pool.len().checked_sub(1).unwrap()
38    }
39
40    /// Returns `true` if the pool is empty.
41    pub fn is_empty(&self) -> bool {
42        self.len() == 0
43    }
44
45    /// Adds a future to the pool.
46    ///
47    /// The future must be `'static` and `Send` to ensure it can be safely stored and executed.
48    pub fn push(&mut self, future: impl Future<Output = T> + Send + 'static) {
49        self.pool.push(Box::pin(future));
50    }
51
52    /// Returns a futures that resolves to the next future in the pool that resolves.
53    ///
54    /// If the pool is empty, the future will never resolve.
55    pub fn next_completed(&mut self) -> SelectNextSome<'_, FuturesUnordered<PooledFuture<T>>> {
56        self.pool.select_next_some()
57    }
58
59    /// Cancels all futures in the pool.
60    ///
61    /// Excludes the dummy future.
62    pub fn cancel_all(&mut self) {
63        self.pool.clear();
64        self.pool.push(Self::create_dummy_future());
65    }
66
67    /// Creates a dummy future that never resolves.
68    fn create_dummy_future() -> PooledFuture<T> {
69        Box::pin(async { future::pending::<T>().await })
70    }
71}
72
#[cfg(test)]
mod tests {
    use super::*;
    use futures::{
        channel::oneshot,
        executor::block_on,
        future::{self, select, Either},
        pin_mut, FutureExt,
    };
    use std::{
        sync::{
            atomic::{AtomicBool, Ordering},
            Arc,
        },
        thread,
        time::Duration,
    };

    /// A future that resolves after a given duration.
    ///
    /// A helper thread sleeps for `duration` and then signals a oneshot channel. The send
    /// result is deliberately ignored. If the returned future is dropped early (e.g. it lost a
    /// `select`), the receiver is gone and the send fails; that must not panic the helper
    /// thread.
    fn delay(duration: Duration) -> impl Future<Output = ()> {
        let (sender, receiver) = oneshot::channel();
        thread::spawn(move || {
            thread::sleep(duration);
            let _ = sender.send(());
        });
        receiver.map(|_| ())
    }

    #[test]
    fn test_initialization() {
        let pool = Pool::<i32>::default();
        assert_eq!(pool.len(), 0);
        assert!(pool.is_empty());
    }

    #[test]
    fn test_dummy_future_doesnt_resolve() {
        block_on(async {
            let mut pool = Pool::<i32>::default();
            let stream_future = pool.next_completed();
            let timeout_future = async {
                delay(Duration::from_millis(100)).await;
            };
            pin_mut!(stream_future);
            pin_mut!(timeout_future);
            let result = select(stream_future, timeout_future).await;
            match result {
                Either::Left((_, _)) => panic!("Stream resolved unexpectedly"),
                Either::Right((_, _)) => {
                    // Timeout occurred, which is expected
                }
            }
        });
    }

    #[test]
    fn test_adding_futures() {
        let mut pool = Pool::<i32>::default();
        assert_eq!(pool.len(), 0);
        assert!(pool.is_empty());

        pool.push(async { 42 });
        assert_eq!(pool.len(), 1);
        assert!(!pool.is_empty());

        pool.push(async { 43 });
        assert_eq!(pool.len(), 2);
    }

    #[test]
    fn test_streaming_resolved_futures() {
        block_on(async move {
            let mut pool = Pool::<i32>::default();
            pool.push(future::ready(42));
            let result = pool.next_completed().await;
            assert_eq!(result, 42);
            assert!(pool.is_empty());
        });
    }

    #[test]
    fn test_multiple_futures() {
        block_on(async move {
            let mut pool = Pool::<i32>::default();

            // Futures resolve in order of completion, not addition order
            let (finisher_1, finished_1) = oneshot::channel();
            let (finisher_3, finished_3) = oneshot::channel();
            pool.push(async move {
                finished_1.await.unwrap();
                finisher_3.send(()).unwrap();
                1
            });
            pool.push(async move {
                finisher_1.send(()).unwrap();
                2
            });
            pool.push(async move {
                finished_3.await.unwrap();
                3
            });

            let first = pool.next_completed().await;
            assert_eq!(first, 2, "First resolved should be 2");
            let second = pool.next_completed().await;
            assert_eq!(second, 1, "Second resolved should be 1");
            let third = pool.next_completed().await;
            assert_eq!(third, 3, "Third resolved should be 3");
            assert!(pool.is_empty());
        });
    }

    #[test]
    fn test_cancel_all() {
        block_on(async move {
            let flag = Arc::new(AtomicBool::new(false));
            let flag_clone = flag.clone();
            let mut pool = Pool::<i32>::default();

            // Push a future that will set the flag to true when it resolves.
            let (finisher, finished) = oneshot::channel();
            pool.push(async move {
                finished.await.unwrap();
                flag_clone.store(true, Ordering::SeqCst);
                42
            });
            assert_eq!(pool.len(), 1);

            // Cancel all futures.
            pool.cancel_all();
            assert!(pool.is_empty());
            assert!(!flag.load(Ordering::SeqCst));

            // Send the finisher signal (should be ignored).
            let _ = finisher.send(());

            // Stream should not resolve future after cancellation.
            let stream_future = pool.next_completed();
            let timeout_future = async {
                delay(Duration::from_millis(100)).await;
            };
            pin_mut!(stream_future);
            pin_mut!(timeout_future);
            let result = select(stream_future, timeout_future).await;
            match result {
                Either::Left((_, _)) => panic!("Stream resolved after cancellation"),
                Either::Right((_, _)) => {
                    // Wait for the timeout to trigger.
                }
            }
            assert!(!flag.load(Ordering::SeqCst));

            // Push and await a new future.
            pool.push(future::ready(42));
            assert_eq!(pool.len(), 1);
            let result = pool.next_completed().await;
            assert_eq!(result, 42);
            assert!(pool.is_empty());
        });
    }

    #[test]
    fn test_many_futures() {
        block_on(async move {
            let mut pool = Pool::<i32>::default();
            let num_futures = 1000;
            for i in 0..num_futures {
                pool.push(future::ready(i));
            }
            assert_eq!(pool.len(), num_futures as usize);

            let mut sum = 0;
            for _ in 0..num_futures {
                let value = pool.next_completed().await;
                sum += value;
            }
            let expected_sum = (0..num_futures).sum::<i32>();
            assert_eq!(
                sum, expected_sum,
                "Sum of resolved values should match expected"
            );
            assert!(
                pool.is_empty(),
                "Pool should be empty after all futures resolve"
            );
        });
    }
}