1use futures::{
4 channel::oneshot,
5 future::{self, AbortHandle, Abortable, Aborted},
6 stream::{FuturesUnordered, SelectNextSome},
7 StreamExt,
8};
9use std::{future::Future, pin::Pin, task::Poll};
10
/// A heap-allocated, pinned, `Send` future producing `T`; the element type stored by [`Pool`].
type PooledFuture<T> = Pin<Box<dyn Future<Output = T> + Send>>;
13
/// A collection of boxed futures that all produce `T`, driven together through a
/// [`FuturesUnordered`].
pub struct Pool<T> {
    /// The underlying set of pending futures.
    pool: FuturesUnordered<PooledFuture<T>>,
}
23
24impl<T: Send> Default for Pool<T> {
25 fn default() -> Self {
26 let pool = FuturesUnordered::new();
29 pool.push(Self::create_dummy_future());
30 Self { pool }
31 }
32}
33
34impl<T: Send> Pool<T> {
35 pub fn len(&self) -> usize {
37 self.pool.len().checked_sub(1).unwrap()
39 }
40
41 pub fn is_empty(&self) -> bool {
43 self.len() == 0
44 }
45
46 pub fn push(&mut self, future: impl Future<Output = T> + Send + 'static) {
50 self.pool.push(Box::pin(future));
51 }
52
53 pub fn next_completed(&mut self) -> SelectNextSome<'_, FuturesUnordered<PooledFuture<T>>> {
57 self.pool.select_next_some()
58 }
59
60 pub fn cancel_all(&mut self) {
64 self.pool.clear();
65 self.pool.push(Self::create_dummy_future());
66 }
67
68 fn create_dummy_future() -> PooledFuture<T> {
70 Box::pin(async { future::pending::<T>().await })
71 }
72}
73
/// Guard returned by [`AbortablePool::push`]; dropping it aborts the future it was
/// created for.
pub struct Aborter {
    /// Handle used to abort the associated future.
    inner: AbortHandle,
}
80
81impl Drop for Aborter {
82 fn drop(&mut self) {
83 self.inner.abort();
84 }
85}
86
/// A heap-allocated, pinned, `Send` future producing `Ok(T)` or `Err(Aborted)`; the
/// element type stored by [`AbortablePool`].
type AbortablePooledFuture<T> = Pin<Box<dyn Future<Output = Result<T, Aborted>> + Send>>;
89
/// A collection of boxed futures, each wrapped so it can be aborted individually
/// through the [`Aborter`] handed out when it was pushed.
pub struct AbortablePool<T> {
    /// The underlying set of pending abortable futures.
    pool: FuturesUnordered<AbortablePooledFuture<T>>,
}
100
101impl<T: Send> Default for AbortablePool<T> {
102 fn default() -> Self {
103 let pool = FuturesUnordered::new();
106 pool.push(Self::create_dummy_future());
107 Self { pool }
108 }
109}
110
111impl<T: Send> AbortablePool<T> {
112 pub fn len(&self) -> usize {
114 self.pool.len().checked_sub(1).unwrap()
116 }
117
118 pub fn is_empty(&self) -> bool {
120 self.len() == 0
121 }
122
123 pub fn push(&mut self, future: impl Future<Output = T> + Send + 'static) -> Aborter {
128 let (handle, registration) = AbortHandle::new_pair();
129 let abortable_future = Abortable::new(future, registration);
130 self.pool.push(Box::pin(abortable_future));
131 Aborter { inner: handle }
132 }
133
134 pub fn next_completed(
139 &mut self,
140 ) -> SelectNextSome<'_, FuturesUnordered<AbortablePooledFuture<T>>> {
141 self.pool.select_next_some()
142 }
143
144 fn create_dummy_future() -> AbortablePooledFuture<T> {
146 Box::pin(async { Ok(future::pending::<T>().await) })
147 }
148}
149
/// Future that completes once the borrowed oneshot sender reports cancellation,
/// i.e. after its receiver has been dropped.
pub struct Closed<'a, T> {
    /// The sender whose cancellation state is polled.
    sender: &'a mut oneshot::Sender<T>,
}
158
159impl<'a, T> Closed<'a, T> {
160 pub fn new(sender: &'a mut oneshot::Sender<T>) -> Self {
162 Self { sender }
163 }
164}
165
166impl<T> Future for Closed<'_, T> {
167 type Output = ();
168
169 fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
170 match self.sender.poll_canceled(cx) {
171 Poll::Ready(()) => Poll::Ready(()),
172 Poll::Pending => Poll::Pending,
173 }
174 }
175}
176
/// Extension trait that adds a `closed` future to a sender type.
pub trait ClosedExt<T> {
    /// Returns a [`Closed`] future that mutably borrows this sender for as long as
    /// the future is alive.
    fn closed(&mut self) -> Closed<'_, T>;
}
197
198impl<T> ClosedExt<T> for oneshot::Sender<T> {
199 fn closed(&mut self) -> Closed<'_, T> {
200 Closed::new(self)
201 }
202}
203
#[cfg(test)]
mod tests {
    use super::*;
    use futures::{
        channel::oneshot,
        executor::block_on,
        future::{self, select, Either},
        pin_mut, FutureExt,
    };
    use std::{
        sync::{
            atomic::{AtomicBool, Ordering},
            Arc,
        },
        thread,
        time::Duration,
    };

    /// Returns a future that resolves after roughly `duration`, driven by a
    /// helper thread so that no runtime timer is required.
    fn delay(duration: Duration) -> impl Future<Output = ()> {
        let (sender, receiver) = oneshot::channel();
        thread::spawn(move || {
            thread::sleep(duration);
            // The receiver may already be gone if the waiting side finished
            // first; that must not panic the helper thread.
            let _ = sender.send(());
        });
        receiver.map(|_| ())
    }

    /// Races `fut` against a timer and returns `true` if `fut` is still
    /// pending when `timeout` elapses.
    async fn stays_pending<F: Future>(fut: F, timeout: Duration) -> bool {
        let timer = delay(timeout);
        pin_mut!(fut);
        pin_mut!(timer);
        matches!(select(fut, timer).await, Either::Right(_))
    }

    #[test]
    fn test_initialization() {
        let pool = Pool::<i32>::default();
        assert_eq!(pool.len(), 0);
        assert!(pool.is_empty());
    }

    #[test]
    fn test_dummy_future_doesnt_resolve() {
        block_on(async {
            let mut pool = Pool::<i32>::default();
            assert!(
                stays_pending(pool.next_completed(), Duration::from_millis(100)).await,
                "Stream resolved unexpectedly"
            );
        });
    }

    #[test]
    fn test_adding_futures() {
        let mut pool = Pool::<i32>::default();
        assert_eq!(pool.len(), 0);
        assert!(pool.is_empty());

        pool.push(async { 42 });
        assert_eq!(pool.len(), 1);
        assert!(!pool.is_empty());

        pool.push(async { 43 });
        assert_eq!(pool.len(), 2);
    }

    #[test]
    fn test_streaming_resolved_futures() {
        block_on(async move {
            let mut pool = Pool::<i32>::default();
            pool.push(future::ready(42));
            let result = pool.next_completed().await;
            assert_eq!(result, 42);
            assert!(pool.is_empty());
        });
    }

    #[test]
    fn test_multiple_futures() {
        block_on(async move {
            let mut pool = Pool::<i32>::default();

            // Chain the futures so they can only finish in the order 2, 1, 3.
            let (finisher_1, finished_1) = oneshot::channel();
            let (finisher_3, finished_3) = oneshot::channel();
            pool.push(async move {
                finished_1.await.unwrap();
                finisher_3.send(()).unwrap();
                1
            });
            pool.push(async move {
                finisher_1.send(()).unwrap();
                2
            });
            pool.push(async move {
                finished_3.await.unwrap();
                3
            });

            let first = pool.next_completed().await;
            assert_eq!(first, 2, "First resolved should be 2");
            let second = pool.next_completed().await;
            assert_eq!(second, 1, "Second resolved should be 1");
            let third = pool.next_completed().await;
            assert_eq!(third, 3, "Third resolved should be 3");
            assert!(pool.is_empty());
        });
    }

    #[test]
    fn test_cancel_all() {
        block_on(async move {
            let flag = Arc::new(AtomicBool::new(false));
            let flag_clone = Arc::clone(&flag);
            let mut pool = Pool::<i32>::default();

            let (finisher, finished) = oneshot::channel();
            pool.push(async move {
                finished.await.unwrap();
                flag_clone.store(true, Ordering::SeqCst);
                42
            });
            assert_eq!(pool.len(), 1);

            pool.cancel_all();
            assert!(pool.is_empty());
            assert!(!flag.load(Ordering::SeqCst));

            // Unblocking the cancelled future must have no observable effect.
            let _ = finisher.send(());
            assert!(
                stays_pending(pool.next_completed(), Duration::from_millis(100)).await,
                "Stream resolved after cancellation"
            );
            assert!(!flag.load(Ordering::SeqCst));

            // The pool stays usable after cancellation.
            pool.push(future::ready(42));
            assert_eq!(pool.len(), 1);
            let result = pool.next_completed().await;
            assert_eq!(result, 42);
            assert!(pool.is_empty());
        });
    }

    #[test]
    fn test_many_futures() {
        block_on(async move {
            let mut pool = Pool::<i32>::default();
            let values: Vec<i32> = (0..1000).collect();
            for &value in &values {
                pool.push(future::ready(value));
            }
            assert_eq!(pool.len(), values.len());

            let mut sum = 0;
            for _ in 0..values.len() {
                sum += pool.next_completed().await;
            }
            let expected_sum = values.iter().sum::<i32>();
            assert_eq!(
                sum, expected_sum,
                "Sum of resolved values should match expected"
            );
            assert!(
                pool.is_empty(),
                "Pool should be empty after all futures resolve"
            );
        });
    }

    #[test]
    fn test_abortable_pool_initialization() {
        let pool = AbortablePool::<i32>::default();
        assert_eq!(pool.len(), 0);
        assert!(pool.is_empty());
    }

    #[test]
    fn test_abortable_pool_adding_futures() {
        let mut pool = AbortablePool::<i32>::default();
        assert_eq!(pool.len(), 0);
        assert!(pool.is_empty());

        let _hook1 = pool.push(async { 42 });
        assert_eq!(pool.len(), 1);
        assert!(!pool.is_empty());

        let _hook2 = pool.push(async { 43 });
        assert_eq!(pool.len(), 2);
    }

    #[test]
    fn test_abortable_pool_successful_completion() {
        block_on(async move {
            let mut pool = AbortablePool::<i32>::default();
            let _hook = pool.push(future::ready(42));
            let result = pool.next_completed().await;
            assert_eq!(result, Ok(42));
            assert!(pool.is_empty());
        });
    }

    #[test]
    fn test_abortable_pool_drop_abort() {
        block_on(async move {
            let mut pool = AbortablePool::<i32>::default();

            let (sender, receiver) = oneshot::channel();
            let hook = pool.push(async move {
                receiver.await.unwrap();
                42
            });

            // Dropping the guard aborts the still-blocked future.
            drop(hook);

            let result = pool.next_completed().await;
            assert_eq!(result, Err(Aborted));
            assert!(pool.is_empty());

            let _ = sender.send(());
        });
    }

    #[test]
    fn test_abortable_pool_partial_abort() {
        block_on(async move {
            let mut pool = AbortablePool::<i32>::default();

            let _hook1 = pool.push(future::ready(1));
            let (sender, receiver) = oneshot::channel();
            let hook2 = pool.push(async move {
                receiver.await.unwrap();
                2
            });
            let _hook3 = pool.push(future::ready(3));

            assert_eq!(pool.len(), 3);

            drop(hook2);

            let mut results = Vec::new();
            for _ in 0..3 {
                results.push(pool.next_completed().await);
            }

            let successful: Vec<_> = results.iter().filter_map(|r| r.as_ref().ok()).collect();
            let aborted = results.iter().filter(|r| r.is_err()).count();

            assert_eq!(successful.len(), 2);
            assert_eq!(aborted, 1);
            assert!(successful.contains(&&1));
            assert!(successful.contains(&&3));
            assert!(pool.is_empty());

            let _ = sender.send(());
        });
    }

    #[test]
    fn test_closed_on_receiver_drop() {
        block_on(async {
            let (mut tx, rx) = oneshot::channel::<i32>();

            let closed = tx.closed();
            drop(rx);

            closed.await;
        });
    }

    #[test]
    fn test_closed_pending_when_receiver_alive() {
        block_on(async {
            let (mut tx, rx) = oneshot::channel::<i32>();

            assert!(
                stays_pending(tx.closed(), Duration::from_millis(500)).await,
                "Closed resolved while receiver still alive"
            );

            drop(rx);
        });
    }

    #[test]
    fn test_closed_multiple_polls() {
        block_on(async {
            let (mut tx, rx) = oneshot::channel::<i32>();

            let closed = tx.closed();
            pin_mut!(closed);

            // Poll by hand with a no-op waker to observe both states.
            let waker = futures::task::noop_waker();
            let mut cx = std::task::Context::from_waker(&waker);
            assert!(closed.as_mut().poll(&mut cx).is_pending());

            drop(rx);

            assert!(closed.as_mut().poll(&mut cx).is_ready());
        });
    }
}