1use core::ops::{Deref, DerefMut};
4use futures::{
5 channel::oneshot,
6 future::{self, AbortHandle, Abortable, Aborted},
7 stream::{FuturesUnordered, SelectNextSome},
8 StreamExt,
9};
10use pin_project::pin_project;
11use std::{future::Future, pin::Pin, task::Poll};
12
/// Boxed, type-erased future stored in a [`Pool`].
type PooledFuture<T> = Pin<Box<dyn Future<Output = T> + Send>>;

/// A set of futures whose outputs are consumed one completion at a time via
/// [`Pool::next_completed`].
///
/// Besides the pushed futures, the inner set always holds one extra sentinel
/// future that never resolves, so awaiting [`Pool::next_completed`] on an
/// empty pool stays pending. [`Pool::len`] does not count the sentinel.
pub struct Pool<T> {
    // Pushed futures plus the never-resolving sentinel.
    pool: FuturesUnordered<PooledFuture<T>>,
}
25
26impl<T: Send> Default for Pool<T> {
27 fn default() -> Self {
28 let pool = FuturesUnordered::new();
31 pool.push(Self::create_dummy_future());
32 Self { pool }
33 }
34}
35
36impl<T: Send> Pool<T> {
37 pub fn len(&self) -> usize {
39 self.pool.len().checked_sub(1).unwrap()
41 }
42
43 pub fn is_empty(&self) -> bool {
45 self.len() == 0
46 }
47
48 pub fn push(&mut self, future: impl Future<Output = T> + Send + 'static) {
52 self.pool.push(Box::pin(future));
53 }
54
55 pub fn next_completed(&mut self) -> SelectNextSome<'_, FuturesUnordered<PooledFuture<T>>> {
59 self.pool.select_next_some()
60 }
61
62 pub fn cancel_all(&mut self) {
66 self.pool.clear();
67 self.pool.push(Self::create_dummy_future());
68 }
69
70 fn create_dummy_future() -> PooledFuture<T> {
72 Box::pin(async { future::pending::<T>().await })
73 }
74}
75
76pub struct Aborter {
80 inner: AbortHandle,
81}
82
83impl Drop for Aborter {
84 fn drop(&mut self) {
85 self.inner.abort();
86 }
87}
88
/// Boxed, type-erased abortable future stored in an [`AbortablePool`].
type AbortablePooledFuture<T> = Pin<Box<dyn Future<Output = Result<T, Aborted>> + Send>>;

/// A set of futures, each individually abortable through the [`Aborter`]
/// returned when it is pushed.
///
/// The inner set always holds one extra sentinel future that never resolves,
/// so awaiting [`AbortablePool::next_completed`] on an empty pool stays pending.
/// [`AbortablePool::len`] does not count the sentinel.
pub struct AbortablePool<T> {
    // Pushed (abortable) futures plus the never-resolving sentinel.
    pool: FuturesUnordered<AbortablePooledFuture<T>>,
}
102
103impl<T: Send> Default for AbortablePool<T> {
104 fn default() -> Self {
105 let pool = FuturesUnordered::new();
108 pool.push(Self::create_dummy_future());
109 Self { pool }
110 }
111}
112
113impl<T: Send> AbortablePool<T> {
114 pub fn len(&self) -> usize {
116 self.pool.len().checked_sub(1).unwrap()
118 }
119
120 pub fn is_empty(&self) -> bool {
122 self.len() == 0
123 }
124
125 pub fn push(&mut self, future: impl Future<Output = T> + Send + 'static) -> Aborter {
130 let (handle, registration) = AbortHandle::new_pair();
131 let abortable_future = Abortable::new(future, registration);
132 self.pool.push(Box::pin(abortable_future));
133 Aborter { inner: handle }
134 }
135
136 pub fn next_completed(
141 &mut self,
142 ) -> SelectNextSome<'_, FuturesUnordered<AbortablePooledFuture<T>>> {
143 self.pool.select_next_some()
144 }
145
146 fn create_dummy_future() -> AbortablePooledFuture<T> {
148 Box::pin(async { Ok(future::pending::<T>().await) })
149 }
150}
151
152pub struct Closed<'a, T> {
158 sender: &'a mut oneshot::Sender<T>,
159}
160
161impl<'a, T> Closed<'a, T> {
162 pub const fn new(sender: &'a mut oneshot::Sender<T>) -> Self {
164 Self { sender }
165 }
166}
167
168impl<T> Future for Closed<'_, T> {
169 type Output = ();
170
171 fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
172 match self.sender.poll_canceled(cx) {
173 Poll::Ready(()) => Poll::Ready(()),
174 Poll::Pending => Poll::Pending,
175 }
176 }
177}
178
179pub trait ClosedExt<T> {
181 fn closed(&mut self) -> Closed<'_, T>;
198}
199
200impl<T> ClosedExt<T> for oneshot::Sender<T> {
201 fn closed(&mut self) -> Closed<'_, T> {
202 Closed::new(self)
203 }
204}
205
206#[pin_project]
212pub struct OptionFuture<F: Future>(#[pin] Option<F>);
213
214impl<F: Future> Default for OptionFuture<F> {
215 fn default() -> Self {
216 Self(None)
217 }
218}
219
220impl<F: Future> From<Option<F>> for OptionFuture<F> {
221 fn from(opt: Option<F>) -> Self {
222 Self(opt)
223 }
224}
225
226impl<F: Future> Deref for OptionFuture<F> {
227 type Target = Option<F>;
228
229 fn deref(&self) -> &Self::Target {
230 &self.0
231 }
232}
233
234impl<F: Future> DerefMut for OptionFuture<F> {
235 fn deref_mut(&mut self) -> &mut Self::Target {
236 &mut self.0
237 }
238}
239
240impl<F: Future> Future for OptionFuture<F> {
241 type Output = F::Output;
242
243 fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
244 let this = self.project();
245 this.0
246 .as_pin_mut()
247 .map_or_else(|| Poll::Pending, |fut| fut.poll(cx))
248 }
249}
250
#[cfg(test)]
mod tests {
    use super::*;
    use futures::{
        channel::oneshot,
        executor::block_on,
        future::{self, select, Either},
        pin_mut, FutureExt,
    };
    use std::{
        sync::{
            atomic::{AtomicBool, Ordering},
            Arc,
        },
        thread,
        time::Duration,
    };

    /// Returns a future that resolves after `duration`, driven by a helper thread.
    fn delay(duration: Duration) -> impl Future<Output = ()> {
        let (sender, receiver) = oneshot::channel();
        thread::spawn(move || {
            thread::sleep(duration);
            // The receiver may already be gone (e.g. the other side of a
            // `select` won); that is not an error, so don't panic this thread.
            let _ = sender.send(());
        });
        receiver.map(|_| ())
    }

    /// Asserts that `fut` does not resolve within `timeout`, panicking with `message` otherwise.
    async fn assert_stays_pending<F: Future>(fut: F, timeout: Duration, message: &str) {
        let timer = delay(timeout);
        pin_mut!(fut);
        pin_mut!(timer);
        if let Either::Left(_) = select(fut, timer).await {
            panic!("{}", message);
        }
    }

    #[test]
    fn test_initialization() {
        let pool = Pool::<i32>::default();
        assert_eq!(pool.len(), 0);
        assert!(pool.is_empty());
    }

    #[test]
    fn test_dummy_future_doesnt_resolve() {
        block_on(async {
            let mut pool = Pool::<i32>::default();
            assert_stays_pending(
                pool.next_completed(),
                Duration::from_millis(100),
                "Stream resolved unexpectedly",
            )
            .await;
        });
    }

    #[test]
    fn test_adding_futures() {
        let mut pool = Pool::<i32>::default();
        assert_eq!(pool.len(), 0);
        assert!(pool.is_empty());

        pool.push(async { 42 });
        assert_eq!(pool.len(), 1);
        assert!(!pool.is_empty());

        pool.push(async { 43 });
        assert_eq!(pool.len(), 2);
    }

    #[test]
    fn test_streaming_resolved_futures() {
        block_on(async move {
            let mut pool = Pool::<i32>::default();
            pool.push(future::ready(42));
            let result = pool.next_completed().await;
            assert_eq!(result, 42);
            assert!(pool.is_empty());
        });
    }

    #[test]
    fn test_multiple_futures() {
        block_on(async move {
            let mut pool = Pool::<i32>::default();

            let (finisher_1, finished_1) = oneshot::channel();
            let (finisher_3, finished_3) = oneshot::channel();
            pool.push(async move {
                finished_1.await.unwrap();
                finisher_3.send(()).unwrap();
                1
            });
            pool.push(async move {
                finisher_1.send(()).unwrap();
                2
            });
            pool.push(async move {
                finished_3.await.unwrap();
                3
            });

            let first = pool.next_completed().await;
            assert_eq!(first, 2, "First resolved should be 2");
            let second = pool.next_completed().await;
            assert_eq!(second, 1, "Second resolved should be 1");
            let third = pool.next_completed().await;
            assert_eq!(third, 3, "Third resolved should be 3");
            assert!(pool.is_empty());
        });
    }

    #[test]
    fn test_cancel_all() {
        block_on(async move {
            let flag = Arc::new(AtomicBool::new(false));
            let flag_clone = flag.clone();
            let mut pool = Pool::<i32>::default();

            let (finisher, finished) = oneshot::channel();
            pool.push(async move {
                finished.await.unwrap();
                flag_clone.store(true, Ordering::SeqCst);
                42
            });
            assert_eq!(pool.len(), 1);

            pool.cancel_all();
            assert!(pool.is_empty());
            assert!(!flag.load(Ordering::SeqCst));

            // The cancelled future was dropped, so this send is expected to fail.
            let _ = finisher.send(());

            assert_stays_pending(
                pool.next_completed(),
                Duration::from_millis(100),
                "Stream resolved after cancellation",
            )
            .await;
            assert!(!flag.load(Ordering::SeqCst));

            // The pool remains usable after cancellation.
            pool.push(future::ready(42));
            assert_eq!(pool.len(), 1);
            let result = pool.next_completed().await;
            assert_eq!(result, 42);
            assert!(pool.is_empty());
        });
    }

    #[test]
    fn test_many_futures() {
        block_on(async move {
            let mut pool = Pool::<i32>::default();
            let num_futures = 1000;
            for i in 0..num_futures {
                pool.push(future::ready(i));
            }
            assert_eq!(pool.len(), num_futures as usize);

            let mut sum = 0;
            for _ in 0..num_futures {
                let value = pool.next_completed().await;
                sum += value;
            }
            let expected_sum = (0..num_futures).sum::<i32>();
            assert_eq!(
                sum, expected_sum,
                "Sum of resolved values should match expected"
            );
            assert!(
                pool.is_empty(),
                "Pool should be empty after all futures resolve"
            );
        });
    }

    #[test]
    fn test_abortable_pool_initialization() {
        let pool = AbortablePool::<i32>::default();
        assert_eq!(pool.len(), 0);
        assert!(pool.is_empty());
    }

    #[test]
    fn test_abortable_pool_adding_futures() {
        let mut pool = AbortablePool::<i32>::default();
        assert_eq!(pool.len(), 0);
        assert!(pool.is_empty());

        let _hook1 = pool.push(async { 42 });
        assert_eq!(pool.len(), 1);
        assert!(!pool.is_empty());

        let _hook2 = pool.push(async { 43 });
        assert_eq!(pool.len(), 2);
    }

    #[test]
    fn test_abortable_pool_successful_completion() {
        block_on(async move {
            let mut pool = AbortablePool::<i32>::default();
            let _hook = pool.push(future::ready(42));
            let result = pool.next_completed().await;
            assert_eq!(result, Ok(42));
            assert!(pool.is_empty());
        });
    }

    #[test]
    fn test_abortable_pool_drop_abort() {
        block_on(async move {
            let mut pool = AbortablePool::<i32>::default();

            let (sender, receiver) = oneshot::channel();
            let hook = pool.push(async move {
                receiver.await.unwrap();
                42
            });

            drop(hook);

            let result = pool.next_completed().await;
            assert!(result.is_err());
            assert!(pool.is_empty());

            let _ = sender.send(());
        });
    }

    #[test]
    fn test_abortable_pool_partial_abort() {
        block_on(async move {
            let mut pool = AbortablePool::<i32>::default();

            let _hook1 = pool.push(future::ready(1));
            let (sender, receiver) = oneshot::channel();
            let hook2 = pool.push(async move {
                receiver.await.unwrap();
                2
            });
            let _hook3 = pool.push(future::ready(3));

            assert_eq!(pool.len(), 3);

            drop(hook2);

            let mut results = Vec::new();
            for _ in 0..3 {
                let result = pool.next_completed().await;
                results.push(result);
            }

            let successful: Vec<_> = results.iter().filter_map(|r| r.as_ref().ok()).collect();
            let aborted: Vec<_> = results.iter().filter(|r| r.is_err()).collect();

            assert_eq!(successful.len(), 2);
            assert_eq!(aborted.len(), 1);
            assert!(successful.contains(&&1));
            assert!(successful.contains(&&3));
            assert!(pool.is_empty());

            let _ = sender.send(());
        });
    }

    #[test]
    fn test_closed_on_receiver_drop() {
        block_on(async {
            let (mut tx, rx) = oneshot::channel::<i32>();

            let closed = tx.closed();
            drop(rx);

            closed.await;
        });
    }

    #[test]
    fn test_closed_pending_when_receiver_alive() {
        block_on(async {
            let (mut tx, rx) = oneshot::channel::<i32>();

            assert_stays_pending(
                tx.closed(),
                Duration::from_millis(100),
                "Closed resolved while receiver still alive",
            )
            .await;

            drop(rx);
        });
    }

    #[test]
    fn test_closed_multiple_polls() {
        block_on(async {
            let (mut tx, rx) = oneshot::channel::<i32>();

            let closed = tx.closed();
            pin_mut!(closed);

            let waker = futures::task::noop_waker();
            let mut cx = std::task::Context::from_waker(&waker);
            assert!(closed.as_mut().poll(&mut cx).is_pending());

            drop(rx);

            assert!(closed.as_mut().poll(&mut cx).is_ready());
        });
    }

    #[test]
    fn test_option_future() {
        block_on(async {
            let option_future = OptionFuture::<oneshot::Receiver<()>>::from(None);
            pin_mut!(option_future);

            let waker = futures::task::noop_waker();
            let mut cx = std::task::Context::from_waker(&waker);
            assert!(option_future.poll(&mut cx).is_pending());

            let (tx, rx) = oneshot::channel();
            let option_future: OptionFuture<_> = Some(rx).into();
            pin_mut!(option_future);

            tx.send(1usize).unwrap();
            assert_eq!(option_future.poll(&mut cx), Poll::Ready(Ok(1)));
        });
    }
}