commonware_utils/
futures.rs1use futures::{
4 future::{self, AbortHandle, Abortable, Aborted},
5 stream::{FuturesUnordered, SelectNextSome},
6 StreamExt,
7};
8use std::{future::Future, pin::Pin};
9
/// Heap-allocated, pinned, type-erased future that is `Send`; the element type stored by [`Pool`].
type PooledFuture<T> = Pin<Box<dyn Send + Future<Output = T>>>;
12
13pub struct Pool<T> {
20 pool: FuturesUnordered<PooledFuture<T>>,
21}
22
23impl<T: Send> Default for Pool<T> {
24 fn default() -> Self {
25 let pool = FuturesUnordered::new();
28 pool.push(Self::create_dummy_future());
29 Self { pool }
30 }
31}
32
33impl<T: Send> Pool<T> {
34 pub fn len(&self) -> usize {
36 self.pool.len().checked_sub(1).unwrap()
38 }
39
40 pub fn is_empty(&self) -> bool {
42 self.len() == 0
43 }
44
45 pub fn push(&mut self, future: impl Future<Output = T> + Send + 'static) {
49 self.pool.push(Box::pin(future));
50 }
51
52 pub fn next_completed(&mut self) -> SelectNextSome<'_, FuturesUnordered<PooledFuture<T>>> {
56 self.pool.select_next_some()
57 }
58
59 pub fn cancel_all(&mut self) {
63 self.pool.clear();
64 self.pool.push(Self::create_dummy_future());
65 }
66
67 fn create_dummy_future() -> PooledFuture<T> {
69 Box::pin(async { future::pending::<T>().await })
70 }
71}
72
73pub struct Aborter {
77 inner: AbortHandle,
78}
79
80impl Drop for Aborter {
81 fn drop(&mut self) {
82 self.inner.abort();
83 }
84}
85
86type AbortablePooledFuture<T> = Pin<Box<dyn Future<Output = Result<T, Aborted>> + Send>>;
88
89pub struct AbortablePool<T> {
97 pool: FuturesUnordered<AbortablePooledFuture<T>>,
98}
99
100impl<T: Send> Default for AbortablePool<T> {
101 fn default() -> Self {
102 let pool = FuturesUnordered::new();
105 pool.push(Self::create_dummy_future());
106 Self { pool }
107 }
108}
109
110impl<T: Send> AbortablePool<T> {
111 pub fn len(&self) -> usize {
113 self.pool.len().checked_sub(1).unwrap()
115 }
116
117 pub fn is_empty(&self) -> bool {
119 self.len() == 0
120 }
121
122 pub fn push(&mut self, future: impl Future<Output = T> + Send + 'static) -> Aborter {
127 let (handle, registration) = AbortHandle::new_pair();
128 let abortable_future = Abortable::new(future, registration);
129 self.pool.push(Box::pin(abortable_future));
130 Aborter { inner: handle }
131 }
132
133 pub fn next_completed(
138 &mut self,
139 ) -> SelectNextSome<'_, FuturesUnordered<AbortablePooledFuture<T>>> {
140 self.pool.select_next_some()
141 }
142
143 fn create_dummy_future() -> AbortablePooledFuture<T> {
145 Box::pin(async { Ok(future::pending::<T>().await) })
146 }
147}
148
#[cfg(test)]
mod tests {
    use super::*;
    use futures::{
        channel::oneshot,
        executor::block_on,
        future::{self, select, Either},
        pin_mut, FutureExt,
    };
    use std::{
        sync::{
            atomic::{AtomicBool, Ordering},
            Arc,
        },
        thread,
        time::Duration,
    };

    /// Returns a future that resolves once `duration` has elapsed, timed by a helper OS thread.
    fn delay(duration: Duration) -> impl Future<Output = ()> {
        let (sender, receiver) = oneshot::channel();
        thread::spawn(move || {
            thread::sleep(duration);
            // The receiver may already be dropped (the raced future won, or the test finished).
            // A failed send is harmless, so don't panic in this detached thread.
            let _ = sender.send(());
        });
        receiver.map(|_| ())
    }

    #[test]
    fn test_initialization() {
        let pool = Pool::<i32>::default();
        assert_eq!(pool.len(), 0);
        assert!(pool.is_empty());
    }

    #[test]
    fn test_dummy_future_doesnt_resolve() {
        block_on(async {
            let mut pool = Pool::<i32>::default();
            let stream_future = pool.next_completed();
            let timeout_future = delay(Duration::from_millis(100));
            pin_mut!(stream_future);
            pin_mut!(timeout_future);
            // Only the placeholder future is present, so the timeout must win the race.
            match select(stream_future, timeout_future).await {
                Either::Left(_) => panic!("Stream resolved unexpectedly"),
                Either::Right(_) => {}
            }
        });
    }

    #[test]
    fn test_adding_futures() {
        let mut pool = Pool::<i32>::default();
        assert_eq!(pool.len(), 0);
        assert!(pool.is_empty());

        pool.push(async { 42 });
        assert_eq!(pool.len(), 1);
        assert!(!pool.is_empty());

        pool.push(async { 43 });
        assert_eq!(pool.len(), 2);
    }

    #[test]
    fn test_streaming_resolved_futures() {
        block_on(async move {
            let mut pool = Pool::<i32>::default();
            pool.push(future::ready(42));
            assert_eq!(pool.next_completed().await, 42);
            assert!(pool.is_empty());
        });
    }

    #[test]
    fn test_multiple_futures() {
        block_on(async move {
            let mut pool = Pool::<i32>::default();

            // Chain of dependencies: future 2 unblocks future 1, which unblocks future 3.
            let (finisher_1, finished_1) = oneshot::channel();
            let (finisher_3, finished_3) = oneshot::channel();
            pool.push(async move {
                finished_1.await.unwrap();
                finisher_3.send(()).unwrap();
                1
            });
            pool.push(async move {
                finisher_1.send(()).unwrap();
                2
            });
            pool.push(async move {
                finished_3.await.unwrap();
                3
            });

            assert_eq!(pool.next_completed().await, 2, "First resolved should be 2");
            assert_eq!(pool.next_completed().await, 1, "Second resolved should be 1");
            assert_eq!(pool.next_completed().await, 3, "Third resolved should be 3");
            assert!(pool.is_empty());
        });
    }

    #[test]
    fn test_cancel_all() {
        block_on(async move {
            let flag = Arc::new(AtomicBool::new(false));
            let flag_clone = flag.clone();
            let mut pool = Pool::<i32>::default();

            // A future that sets `flag` only once it is unblocked.
            let (finisher, finished) = oneshot::channel();
            pool.push(async move {
                finished.await.unwrap();
                flag_clone.store(true, Ordering::SeqCst);
                42
            });
            assert_eq!(pool.len(), 1);

            pool.cancel_all();
            assert!(pool.is_empty());
            assert!(!flag.load(Ordering::SeqCst));

            // The cancelled future was dropped, so this send may fail and must have no effect.
            let _ = finisher.send(());

            let stream_future = pool.next_completed();
            let timeout_future = delay(Duration::from_millis(100));
            pin_mut!(stream_future);
            pin_mut!(timeout_future);
            match select(stream_future, timeout_future).await {
                Either::Left(_) => panic!("Stream resolved after cancellation"),
                Either::Right(_) => {}
            }
            assert!(!flag.load(Ordering::SeqCst));

            // The pool must still be usable after cancellation.
            pool.push(future::ready(42));
            assert_eq!(pool.len(), 1);
            assert_eq!(pool.next_completed().await, 42);
            assert!(pool.is_empty());
        });
    }

    #[test]
    fn test_many_futures() {
        block_on(async move {
            let mut pool = Pool::<i32>::default();
            let num_futures = 1000;
            for i in 0..num_futures {
                pool.push(future::ready(i));
            }
            assert_eq!(pool.len(), num_futures as usize);

            let mut sum = 0;
            for _ in 0..num_futures {
                sum += pool.next_completed().await;
            }
            let expected_sum = (0..num_futures).sum::<i32>();
            assert_eq!(
                sum, expected_sum,
                "Sum of resolved values should match expected"
            );
            assert!(
                pool.is_empty(),
                "Pool should be empty after all futures resolve"
            );
        });
    }

    #[test]
    fn test_abortable_pool_initialization() {
        let pool = AbortablePool::<i32>::default();
        assert_eq!(pool.len(), 0);
        assert!(pool.is_empty());
    }

    #[test]
    fn test_abortable_pool_adding_futures() {
        let mut pool = AbortablePool::<i32>::default();
        assert_eq!(pool.len(), 0);
        assert!(pool.is_empty());

        let _hook1 = pool.push(async { 42 });
        assert_eq!(pool.len(), 1);
        assert!(!pool.is_empty());

        let _hook2 = pool.push(async { 43 });
        assert_eq!(pool.len(), 2);
    }

    #[test]
    fn test_abortable_pool_successful_completion() {
        block_on(async move {
            let mut pool = AbortablePool::<i32>::default();
            let _hook = pool.push(future::ready(42));
            assert_eq!(pool.next_completed().await, Ok(42));
            assert!(pool.is_empty());
        });
    }

    #[test]
    fn test_abortable_pool_drop_abort() {
        block_on(async move {
            let mut pool = AbortablePool::<i32>::default();

            let (sender, receiver) = oneshot::channel();
            let hook = pool.push(async move {
                receiver.await.unwrap();
                42
            });

            // Dropping the handle aborts the still-blocked future.
            drop(hook);

            let result = pool.next_completed().await;
            assert!(result.is_err());
            assert!(pool.is_empty());

            let _ = sender.send(());
        });
    }

    #[test]
    fn test_abortable_pool_partial_abort() {
        block_on(async move {
            let mut pool = AbortablePool::<i32>::default();

            let _hook1 = pool.push(future::ready(1));
            let (sender, receiver) = oneshot::channel();
            let hook2 = pool.push(async move {
                receiver.await.unwrap();
                2
            });
            let _hook3 = pool.push(future::ready(3));

            assert_eq!(pool.len(), 3);

            // Abort only the middle future; the other two must still complete.
            drop(hook2);

            let mut results = Vec::new();
            for _ in 0..3 {
                results.push(pool.next_completed().await);
            }

            let successful: Vec<_> = results.iter().filter_map(|r| r.as_ref().ok()).collect();
            let aborted: Vec<_> = results.iter().filter(|r| r.is_err()).collect();

            assert_eq!(successful.len(), 2);
            assert_eq!(aborted.len(), 1);
            assert!(successful.contains(&&1));
            assert!(successful.contains(&&3));
            assert!(pool.is_empty());

            let _ = sender.send(());
        });
    }
}