// commonware_utils/futures.rs

//! Utilities for working with futures.
2
3use futures::{
4    future,
5    stream::{FuturesUnordered, SelectNextSome},
6    StreamExt,
7};
8use std::{future::Future, pin::Pin};
9
/// The boxed, pinned, type-erased future stored inside a `Pool`.
///
/// Stored futures must be `Send`. The `'static` bound is written out explicitly here; it is the
/// same bound Rust applies by default to a boxed trait object.
type PooledFuture<T> = Pin<Box<dyn Future<Output = T> + Send + 'static>>;
12
13/// An unordered pool of futures.
14///
15/// Futures can be added to the pool, and removed from the pool as they resolve.
16///
17/// **Note:** This pool is not thread-safe and should not be used across threads without external
18/// synchronization.
19pub struct Pool<T> {
20    pool: FuturesUnordered<PooledFuture<T>>,
21}
22
23impl<T: Send> Default for Pool<T> {
24    fn default() -> Self {
25        // Insert a dummy future (that never resolves) to prevent the stream from being empty.
26        // Else, the `select_next_some()` function returns `None` instantly.
27        let pool = FuturesUnordered::new();
28        pool.push(Self::create_dummy_future());
29        Self { pool }
30    }
31}
32
33impl<T: Send> Pool<T> {
34    /// Returns the number of futures in the pool.
35    pub fn len(&self) -> usize {
36        // Subtract the dummy future.
37        self.pool.len().checked_sub(1).unwrap()
38    }
39
40    /// Returns `true` if the pool is empty.
41    pub fn is_empty(&self) -> bool {
42        self.len() == 0
43    }
44
45    /// Adds a future to the pool.
46    ///
47    /// The future must be `'static` and `Send` to ensure it can be safely stored and executed.
48    pub fn push(&mut self, future: impl Future<Output = T> + Send + 'static) {
49        self.pool.push(Box::pin(future));
50    }
51
52    /// Returns a futures that resolves to the next future in the pool that resolves.
53    ///
54    /// If the pool is empty, the future will never resolve.
55    pub fn next_completed(&mut self) -> SelectNextSome<FuturesUnordered<PooledFuture<T>>> {
56        self.pool.select_next_some()
57    }
58
59    /// Cancels all futures in the pool.
60    ///
61    /// Excludes the dummy future.
62    pub fn cancel_all(&mut self) {
63        self.pool.clear();
64        self.pool.push(Self::create_dummy_future());
65    }
66
67    /// Creates a dummy future that never resolves.
68    fn create_dummy_future() -> PooledFuture<T> {
69        Box::pin(async { future::pending::<T>().await })
70    }
71}
72
#[cfg(test)]
mod tests {
    use super::*;
    use futures::{
        channel::oneshot,
        executor::block_on,
        future::{self, select, Either},
        pin_mut, FutureExt,
    };
    use std::{
        sync::{
            atomic::{AtomicBool, Ordering},
            Arc,
        },
        thread,
        time::Duration,
    };

    /// Returns a future that resolves after roughly `duration`.
    ///
    /// The timer runs on a dedicated thread, so it works under the `block_on` executor without
    /// an async runtime.
    fn delay(duration: Duration) -> impl Future<Output = ()> {
        let (sender, receiver) = oneshot::channel();
        thread::spawn(move || {
            thread::sleep(duration);
            // The receiver may already be gone (e.g. the awaiting future was dropped or
            // cancelled). That is not an error, so don't panic the timer thread.
            let _ = sender.send(());
        });
        receiver.map(|_| ())
    }

    /// Asserts that `pool.next_completed()` does not resolve before `timeout` elapses.
    ///
    /// Panics with `msg` if the pool yields a value first.
    async fn assert_no_completion(pool: &mut Pool<i32>, timeout: Duration, msg: &str) {
        let next = pool.next_completed();
        let timer = delay(timeout);
        pin_mut!(next);
        pin_mut!(timer);
        match select(next, timer).await {
            Either::Left((_, _)) => panic!("{}", msg),
            Either::Right((_, _)) => {
                // Timeout occurred, which is expected.
            }
        }
    }

    #[test]
    fn test_initialization() {
        let pool = Pool::<i32>::default();
        assert_eq!(pool.len(), 0);
        assert!(pool.is_empty());
    }

    #[test]
    fn test_dummy_future_doesnt_resolve() {
        block_on(async {
            let mut pool = Pool::<i32>::default();
            assert_no_completion(
                &mut pool,
                Duration::from_millis(100),
                "Stream resolved unexpectedly",
            )
            .await;
        });
    }

    #[test]
    fn test_adding_futures() {
        let mut pool = Pool::<i32>::default();
        assert_eq!(pool.len(), 0);
        assert!(pool.is_empty());

        pool.push(async { 42 });
        assert_eq!(pool.len(), 1);
        assert!(!pool.is_empty());

        pool.push(async { 43 });
        assert_eq!(pool.len(), 2);
    }

    #[test]
    fn test_streaming_resolved_futures() {
        block_on(async move {
            let mut pool = Pool::<i32>::default();
            pool.push(future::ready(42));
            let result = pool.next_completed().await;
            assert_eq!(result, 42);
            assert!(pool.is_empty());
        });
    }

    #[test]
    fn test_multiple_futures() {
        block_on(async move {
            let mut pool = Pool::<i32>::default();

            // Futures resolve in order of completion time, not addition order.
            pool.push(async move {
                delay(Duration::from_millis(100)).await;
                1
            });
            pool.push(async move {
                delay(Duration::from_millis(50)).await;
                2
            });
            pool.push(async move {
                delay(Duration::from_millis(150)).await;
                3
            });

            let first = pool.next_completed().await;
            assert_eq!(first, 2, "First resolved should be 2 (50ms)");
            let second = pool.next_completed().await;
            assert_eq!(second, 1, "Second resolved should be 1 (100ms)");
            let third = pool.next_completed().await;
            assert_eq!(third, 3, "Third resolved should be 3 (150ms)");
            assert!(pool.is_empty());
        });
    }

    #[test]
    fn test_cancel_all() {
        block_on(async move {
            let flag = Arc::new(AtomicBool::new(false));
            let flag_clone = flag.clone();
            let mut pool = Pool::<i32>::default();

            pool.push(async move {
                delay(Duration::from_millis(100)).await;
                flag_clone.store(true, Ordering::SeqCst);
                42
            });
            assert_eq!(pool.len(), 1);

            pool.cancel_all();
            assert!(pool.is_empty());

            delay(Duration::from_millis(150)).await; // Wait longer than the future's delay.
            assert!(!flag.load(Ordering::SeqCst));

            // The stream should not resolve the future after cancellation.
            assert_no_completion(
                &mut pool,
                Duration::from_millis(100),
                "Stream resolved after cancellation",
            )
            .await;

            // Push and await a new future.
            pool.push(future::ready(42));
            assert_eq!(pool.len(), 1);
            let result = pool.next_completed().await;
            assert_eq!(result, 42);
            assert!(pool.is_empty());
        });
    }

    #[test]
    fn test_cancel_all_empty_pool() {
        let mut pool = Pool::<i32>::default();

        // Cancelling an already-empty pool (repeatedly) must keep it usable and empty.
        pool.cancel_all();
        assert!(pool.is_empty());
        pool.cancel_all();
        assert_eq!(pool.len(), 0);

        pool.push(future::ready(7));
        assert_eq!(pool.len(), 1);
        let value = block_on(pool.next_completed());
        assert_eq!(value, 7);
        assert!(pool.is_empty());
    }

    #[test]
    fn test_many_futures() {
        block_on(async move {
            let mut pool = Pool::<i32>::default();
            let num_futures = 1000;
            for i in 0..num_futures {
                pool.push(future::ready(i));
            }
            assert_eq!(pool.len(), num_futures as usize);

            let mut sum = 0;
            for _ in 0..num_futures {
                let value = pool.next_completed().await;
                sum += value;
            }
            let expected_sum = (0..num_futures).sum::<i32>();
            assert_eq!(
                sum, expected_sum,
                "Sum of resolved values should match expected"
            );
            assert!(
                pool.is_empty(),
                "Pool should be empty after all futures resolve"
            );
        });
    }
}