commonware_utils/
futures.rs1use futures::{
4 future,
5 stream::{FuturesUnordered, SelectNextSome},
6 StreamExt,
7};
8use std::{future::Future, pin::Pin};
9
/// A pinned, heap-allocated, `Send` future resolving to `T`.
///
/// This is the element type held by a [`Pool`].
type PooledFuture<T> = Pin<Box<dyn Send + Future<Output = T>>>;
13pub struct Pool<T> {
20 pool: FuturesUnordered<PooledFuture<T>>,
21}
22
23impl<T: Send> Default for Pool<T> {
24 fn default() -> Self {
25 let pool = FuturesUnordered::new();
28 pool.push(Self::create_dummy_future());
29 Self { pool }
30 }
31}
32
33impl<T: Send> Pool<T> {
34 pub fn len(&self) -> usize {
36 self.pool.len().checked_sub(1).unwrap()
38 }
39
40 pub fn is_empty(&self) -> bool {
42 self.len() == 0
43 }
44
45 pub fn push(&mut self, future: impl Future<Output = T> + Send + 'static) {
49 self.pool.push(Box::pin(future));
50 }
51
52 pub fn next_completed(&mut self) -> SelectNextSome<'_, FuturesUnordered<PooledFuture<T>>> {
56 self.pool.select_next_some()
57 }
58
59 pub fn cancel_all(&mut self) {
63 self.pool.clear();
64 self.pool.push(Self::create_dummy_future());
65 }
66
67 fn create_dummy_future() -> PooledFuture<T> {
69 Box::pin(async { future::pending::<T>().await })
70 }
71}
72
#[cfg(test)]
mod tests {
    use super::*;
    use futures::{
        channel::oneshot,
        executor::block_on,
        future::{self, select, Either},
        pin_mut, FutureExt,
    };
    use std::{
        sync::{
            atomic::{AtomicBool, Ordering},
            Arc,
        },
        thread,
        time::Duration,
    };

    /// Returns a future that completes after `duration`.
    ///
    /// A helper thread drives it, so no timer runtime is needed under `block_on`.
    fn delay(duration: Duration) -> impl Future<Output = ()> {
        let (tx, rx) = oneshot::channel();
        thread::spawn(move || {
            thread::sleep(duration);
            tx.send(()).unwrap();
        });
        rx.map(|_| ())
    }

    /// Races `pool.next_completed()` against a 100 ms timeout and panics with
    /// `failure` if the pool yields a value first.
    async fn expect_no_completion(pool: &mut Pool<i32>, failure: &str) {
        let next = pool.next_completed();
        let timeout = async { delay(Duration::from_millis(100)).await };
        pin_mut!(next);
        pin_mut!(timeout);
        if let Either::Left(_) = select(next, timeout).await {
            panic!("{}", failure);
        }
    }

    #[test]
    fn test_initialization() {
        let pool = Pool::<i32>::default();
        assert!(pool.is_empty());
        assert_eq!(pool.len(), 0);
    }

    #[test]
    fn test_dummy_future_doesnt_resolve() {
        let mut pool = Pool::<i32>::default();
        block_on(expect_no_completion(&mut pool, "Stream resolved unexpectedly"));
    }

    #[test]
    fn test_adding_futures() {
        let mut pool = Pool::<i32>::default();
        assert_eq!(pool.len(), 0);
        assert!(pool.is_empty());

        for (expected_len, value) in [(1, 42), (2, 43)] {
            pool.push(async move { value });
            assert_eq!(pool.len(), expected_len);
            assert!(!pool.is_empty());
        }
    }

    #[test]
    fn test_streaming_resolved_futures() {
        let mut pool = Pool::<i32>::default();
        pool.push(future::ready(42));
        let value = block_on(pool.next_completed());
        assert_eq!(value, 42);
        assert!(pool.is_empty());
    }

    #[test]
    fn test_multiple_futures() {
        block_on(async move {
            let mut pool = Pool::<i32>::default();

            // Future 2 unblocks future 1, which in turn unblocks future 3.
            let (unblock_1, blocked_1) = oneshot::channel();
            let (unblock_3, blocked_3) = oneshot::channel();
            pool.push(async move {
                blocked_1.await.unwrap();
                unblock_3.send(()).unwrap();
                1
            });
            pool.push(async move {
                unblock_1.send(()).unwrap();
                2
            });
            pool.push(async move {
                blocked_3.await.unwrap();
                3
            });

            for (expected, message) in [
                (2, "First resolved should be 2"),
                (1, "Second resolved should be 1"),
                (3, "Third resolved should be 3"),
            ] {
                assert_eq!(pool.next_completed().await, expected, "{}", message);
            }
            assert!(pool.is_empty());
        });
    }

    #[test]
    fn test_cancel_all() {
        block_on(async move {
            let ran = Arc::new(AtomicBool::new(false));
            let ran_in_future = Arc::clone(&ran);
            let mut pool = Pool::<i32>::default();

            let (release_tx, release_rx) = oneshot::channel();
            pool.push(async move {
                release_rx.await.unwrap();
                ran_in_future.store(true, Ordering::SeqCst);
                42
            });
            assert_eq!(pool.len(), 1);

            // Cancel before the future has been released.
            pool.cancel_all();
            assert!(pool.is_empty());
            assert!(!ran.load(Ordering::SeqCst));

            // The receiver was dropped with the cancelled future, so the send
            // result is intentionally ignored.
            let _ = release_tx.send(());

            expect_no_completion(&mut pool, "Stream resolved after cancellation").await;
            assert!(!ran.load(Ordering::SeqCst));

            // The pool is still usable after cancellation.
            pool.push(future::ready(42));
            assert_eq!(pool.len(), 1);
            assert_eq!(pool.next_completed().await, 42);
            assert!(pool.is_empty());
        });
    }

    #[test]
    fn test_many_futures() {
        block_on(async move {
            let mut pool = Pool::<i32>::default();
            let count = 1000;
            (0..count).for_each(|i| pool.push(future::ready(i)));
            assert_eq!(pool.len(), count as usize);

            let mut total = 0;
            for _ in 0..count {
                total += pool.next_completed().await;
            }
            assert_eq!(
                total,
                (0..count).sum::<i32>(),
                "Sum of resolved values should match expected"
            );
            assert!(
                pool.is_empty(),
                "Pool should be empty after all futures resolve"
            );
        });
    }
}