commonware_utils/
futures.rs1use core::ops::{Deref, DerefMut};
4use futures::{
5 future::{self, AbortHandle, Abortable, Aborted},
6 stream::{FuturesUnordered, SelectNextSome},
7 StreamExt,
8};
9use pin_project::pin_project;
10use std::{future::Future, pin::Pin, task::Poll};
11
12type PooledFuture<T> = Pin<Box<dyn Future<Output = T> + Send>>;
14
15pub struct Pool<T> {
22 pool: FuturesUnordered<PooledFuture<T>>,
23}
24
25impl<T: Send> Default for Pool<T> {
26 fn default() -> Self {
27 let pool = FuturesUnordered::new();
30 pool.push(Self::create_dummy_future());
31 Self { pool }
32 }
33}
34
35impl<T: Send> Pool<T> {
36 pub fn len(&self) -> usize {
38 self.pool.len().checked_sub(1).unwrap()
40 }
41
42 pub fn is_empty(&self) -> bool {
44 self.len() == 0
45 }
46
47 pub fn push(&mut self, future: impl Future<Output = T> + Send + 'static) {
51 self.pool.push(Box::pin(future));
52 }
53
54 pub fn next_completed(&mut self) -> SelectNextSome<'_, FuturesUnordered<PooledFuture<T>>> {
58 self.pool.select_next_some()
59 }
60
61 pub fn cancel_all(&mut self) {
65 self.pool.clear();
66 self.pool.push(Self::create_dummy_future());
67 }
68
69 fn create_dummy_future() -> PooledFuture<T> {
71 Box::pin(async { future::pending::<T>().await })
72 }
73}
74
75pub struct Aborter {
79 inner: AbortHandle,
80}
81
82impl Drop for Aborter {
83 fn drop(&mut self) {
84 self.inner.abort();
85 }
86}
87
88type AbortablePooledFuture<T> = Pin<Box<dyn Future<Output = Result<T, Aborted>> + Send>>;
90
91pub struct AbortablePool<T> {
99 pool: FuturesUnordered<AbortablePooledFuture<T>>,
100}
101
102impl<T: Send> Default for AbortablePool<T> {
103 fn default() -> Self {
104 let pool = FuturesUnordered::new();
107 pool.push(Self::create_dummy_future());
108 Self { pool }
109 }
110}
111
112impl<T: Send> AbortablePool<T> {
113 pub fn len(&self) -> usize {
115 self.pool.len().checked_sub(1).unwrap()
117 }
118
119 pub fn is_empty(&self) -> bool {
121 self.len() == 0
122 }
123
124 pub fn push(&mut self, future: impl Future<Output = T> + Send + 'static) -> Aborter {
129 let (handle, registration) = AbortHandle::new_pair();
130 let abortable_future = Abortable::new(future, registration);
131 self.pool.push(Box::pin(abortable_future));
132 Aborter { inner: handle }
133 }
134
135 pub fn next_completed(
140 &mut self,
141 ) -> SelectNextSome<'_, FuturesUnordered<AbortablePooledFuture<T>>> {
142 self.pool.select_next_some()
143 }
144
145 fn create_dummy_future() -> AbortablePooledFuture<T> {
147 Box::pin(async { Ok(future::pending::<T>().await) })
148 }
149}
150
151#[pin_project]
157pub struct OptionFuture<F: Future>(#[pin] Option<F>);
158
159impl<F: Future> Default for OptionFuture<F> {
160 fn default() -> Self {
161 Self(None)
162 }
163}
164
165impl<F: Future> From<Option<F>> for OptionFuture<F> {
166 fn from(opt: Option<F>) -> Self {
167 Self(opt)
168 }
169}
170
171impl<F: Future> Deref for OptionFuture<F> {
172 type Target = Option<F>;
173
174 fn deref(&self) -> &Self::Target {
175 &self.0
176 }
177}
178
179impl<F: Future> DerefMut for OptionFuture<F> {
180 fn deref_mut(&mut self) -> &mut Self::Target {
181 &mut self.0
182 }
183}
184
185impl<F: Future> Future for OptionFuture<F> {
186 type Output = F::Output;
187
188 fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
189 let this = self.project();
190 this.0
191 .as_pin_mut()
192 .map_or_else(|| Poll::Pending, |fut| fut.poll(cx))
193 }
194}
195
#[cfg(test)]
mod tests {
    use super::*;
    use crate::channel::oneshot;
    use futures::{
        executor::block_on,
        future::{self, select, Either},
        pin_mut,
    };
    use std::{
        sync::{
            atomic::{AtomicBool, Ordering},
            Arc,
        },
        thread,
        time::Duration,
    };

    /// Returns a future that resolves after `duration`.
    ///
    /// Uses a helper thread instead of a runtime timer so the tests can run on `block_on`.
    fn delay(duration: Duration) -> impl Future<Output = ()> {
        let (sender, receiver) = oneshot::channel();
        thread::spawn(move || {
            thread::sleep(duration);
            // The receiver is gone if the race already finished; nothing to report then.
            let _ = sender.send(());
        });
        async move {
            let _ = receiver.await;
        }
    }

    /// Races `fut` against a 100ms timeout and panics with `message` if `fut` wins.
    async fn assert_stays_pending<F: Future>(fut: F, message: &str) {
        let timeout = delay(Duration::from_millis(100));
        pin_mut!(fut);
        pin_mut!(timeout);
        assert!(
            matches!(select(fut, timeout).await, Either::Right(_)),
            "{message}"
        );
    }

    #[test]
    fn test_initialization() {
        let pool = Pool::<i32>::default();
        assert_eq!(pool.len(), 0);
        assert!(pool.is_empty());
    }

    #[test]
    fn test_dummy_future_doesnt_resolve() {
        block_on(async {
            let mut pool = Pool::<i32>::default();
            assert_stays_pending(pool.next_completed(), "Stream resolved unexpectedly").await;
        });
    }

    #[test]
    fn test_adding_futures() {
        let mut pool = Pool::<i32>::default();
        assert_eq!(pool.len(), 0);
        assert!(pool.is_empty());

        pool.push(async { 42 });
        assert_eq!(pool.len(), 1);
        assert!(!pool.is_empty());

        pool.push(async { 43 });
        assert_eq!(pool.len(), 2);
    }

    #[test]
    fn test_streaming_resolved_futures() {
        block_on(async move {
            let mut pool = Pool::<i32>::default();
            pool.push(future::ready(42));
            let result = pool.next_completed().await;
            assert_eq!(result, 42);
            assert!(pool.is_empty());
        });
    }

    #[test]
    fn test_multiple_futures() {
        block_on(async move {
            let mut pool = Pool::<i32>::default();

            // Completion chain: future 2 unblocks future 1, which in turn unblocks future 3.
            let (finisher_1, finished_1) = oneshot::channel();
            let (finisher_3, finished_3) = oneshot::channel();
            pool.push(async move {
                finished_1.await.unwrap();
                finisher_3.send(()).unwrap();
                1
            });
            pool.push(async move {
                finisher_1.send(()).unwrap();
                2
            });
            pool.push(async move {
                finished_3.await.unwrap();
                3
            });

            let first = pool.next_completed().await;
            assert_eq!(first, 2, "First resolved should be 2");
            let second = pool.next_completed().await;
            assert_eq!(second, 1, "Second resolved should be 1");
            let third = pool.next_completed().await;
            assert_eq!(third, 3, "Third resolved should be 3");
            assert!(pool.is_empty());
        });
    }

    #[test]
    fn test_cancel_all() {
        block_on(async move {
            let flag = Arc::new(AtomicBool::new(false));
            let flag_clone = flag.clone();
            let mut pool = Pool::<i32>::default();

            // The future can only finish (and set the flag) once `finisher` fires.
            let (finisher, finished) = oneshot::channel();
            pool.push(async move {
                finished.await.unwrap();
                flag_clone.store(true, Ordering::SeqCst);
                42
            });
            assert_eq!(pool.len(), 1);

            // Cancelling drops the future before it was ever polled.
            pool.cancel_all();
            assert!(pool.is_empty());
            assert!(!flag.load(Ordering::SeqCst));

            // Firing the trigger now must not revive the dropped future.
            let _ = finisher.send(());
            assert_stays_pending(
                pool.next_completed(),
                "Stream resolved after cancellation",
            )
            .await;
            assert!(!flag.load(Ordering::SeqCst));

            // The pool remains usable after cancellation.
            pool.push(future::ready(42));
            assert_eq!(pool.len(), 1);
            let result = pool.next_completed().await;
            assert_eq!(result, 42);
            assert!(pool.is_empty());
        });
    }

    #[test]
    fn test_many_futures() {
        block_on(async move {
            let mut pool = Pool::<i32>::default();
            let num_futures = 1000;
            for i in 0..num_futures {
                pool.push(future::ready(i));
            }
            assert_eq!(pool.len(), num_futures as usize);

            let mut sum = 0;
            for _ in 0..num_futures {
                let value = pool.next_completed().await;
                sum += value;
            }
            let expected_sum = (0..num_futures).sum::<i32>();
            assert_eq!(
                sum, expected_sum,
                "Sum of resolved values should match expected"
            );
            assert!(
                pool.is_empty(),
                "Pool should be empty after all futures resolve"
            );
        });
    }

    #[test]
    fn test_abortable_pool_initialization() {
        let pool = AbortablePool::<i32>::default();
        assert_eq!(pool.len(), 0);
        assert!(pool.is_empty());
    }

    #[test]
    fn test_abortable_pool_dummy_future_doesnt_resolve() {
        block_on(async {
            let mut pool = AbortablePool::<i32>::default();
            assert_stays_pending(pool.next_completed(), "Stream resolved unexpectedly").await;
        });
    }

    #[test]
    fn test_abortable_pool_adding_futures() {
        let mut pool = AbortablePool::<i32>::default();
        assert_eq!(pool.len(), 0);
        assert!(pool.is_empty());

        let _hook1 = pool.push(async { 42 });
        assert_eq!(pool.len(), 1);
        assert!(!pool.is_empty());

        let _hook2 = pool.push(async { 43 });
        assert_eq!(pool.len(), 2);
    }

    #[test]
    fn test_abortable_pool_successful_completion() {
        block_on(async move {
            let mut pool = AbortablePool::<i32>::default();
            let _hook = pool.push(future::ready(42));
            let result = pool.next_completed().await;
            assert_eq!(result, Ok(42));
            assert!(pool.is_empty());
        });
    }

    #[test]
    fn test_abortable_pool_drop_abort() {
        block_on(async move {
            let mut pool = AbortablePool::<i32>::default();

            let (sender, receiver) = oneshot::channel();
            let hook = pool.push(async move {
                receiver.await.unwrap();
                42
            });

            // Dropping the hook aborts the still-blocked future.
            drop(hook);

            let result = pool.next_completed().await;
            assert!(result.is_err());
            assert!(pool.is_empty());

            let _ = sender.send(());
        });
    }

    #[test]
    fn test_abortable_pool_partial_abort() {
        block_on(async move {
            let mut pool = AbortablePool::<i32>::default();

            let _hook1 = pool.push(future::ready(1));
            let (sender, receiver) = oneshot::channel();
            let hook2 = pool.push(async move {
                receiver.await.unwrap();
                2
            });
            let _hook3 = pool.push(future::ready(3));

            assert_eq!(pool.len(), 3);

            drop(hook2);

            let mut results = Vec::new();
            for _ in 0..3 {
                let result = pool.next_completed().await;
                results.push(result);
            }

            let successful: Vec<_> = results.iter().filter_map(|r| r.as_ref().ok()).collect();
            let aborted: Vec<_> = results.iter().filter(|r| r.is_err()).collect();

            assert_eq!(successful.len(), 2);
            assert_eq!(aborted.len(), 1);
            assert!(successful.contains(&&1));
            assert!(successful.contains(&&3));
            assert!(pool.is_empty());

            let _ = sender.send(());
        });
    }

    #[test]
    fn test_option_future() {
        block_on(async {
            let option_future = OptionFuture::<oneshot::Receiver<()>>::from(None);
            pin_mut!(option_future);

            let waker = futures::task::noop_waker();
            let mut cx = std::task::Context::from_waker(&waker);
            assert!(option_future.poll(&mut cx).is_pending());

            let (tx, rx) = oneshot::channel();
            let option_future: OptionFuture<_> = Some(rx).into();
            pin_mut!(option_future);

            tx.send(1usize).unwrap();
            assert_eq!(option_future.poll(&mut cx), Poll::Ready(Ok(1)));
        });
    }

    #[test]
    fn test_option_future_default_is_pending() {
        let option_future = OptionFuture::<oneshot::Receiver<()>>::default();
        assert!(option_future.is_none());
        pin_mut!(option_future);

        let waker = futures::task::noop_waker();
        let mut cx = std::task::Context::from_waker(&waker);
        assert!(option_future.poll(&mut cx).is_pending());
    }

    #[test]
    fn test_option_future_replace_inner() {
        let mut option_future = OptionFuture::<oneshot::Receiver<usize>>::default();
        let (tx, rx) = oneshot::channel();

        // Install the inner future through `DerefMut`.
        *option_future = Some(rx);
        assert!(option_future.is_some());
        pin_mut!(option_future);

        let waker = futures::task::noop_waker();
        let mut cx = std::task::Context::from_waker(&waker);
        tx.send(7usize).unwrap();
        assert_eq!(option_future.poll(&mut cx), Poll::Ready(Ok(7)));
    }
}