#![allow(clippy::needless_lifetimes)]
use std::{
marker::PhantomData,
pin::Pin,
task::{Context, Poll},
};
use futures::{
future::{BoxFuture, Future, FutureExt},
stream::{FuturesUnordered, Stream},
};
use pin_project::{pin_project, pinned_drop};
use tokio::task::JoinHandle;
/// A set of tokio tasks whose futures may borrow data that lives for `'a`.
///
/// Task outputs are yielded through the [`Stream`] impl. If the scope is dropped before its
/// stream has been exhausted, its `Drop` impl blocks the current thread to await the
/// outstanding tasks.
#[pin_project(PinnedDrop)]
pub struct Scope<'a, T> {
    /// Join handles of spawned tasks whose outputs have not been yielded yet.
    #[pin]
    futs: FuturesUnordered<JoinHandle<T>>,
    /// Total number of tasks ever spawned on this scope.
    len: usize,
    /// Number of spawned tasks whose output has not been yielded yet.
    remaining: usize,
    /// Set once the underlying `FuturesUnordered` has reported exhaustion (`Ready(None)`).
    done: bool,
    /// `&'a mut &'a ()` makes `'a` invariant, so the borrow region cannot be
    /// shortened or lengthened through variance.
    _marker: PhantomData<&'a mut &'a ()>,
}
impl<'a, T: Send + 'static> Scope<'a, T> {
    /// Creates an empty scope with no spawned tasks.
    ///
    /// # Safety
    ///
    /// Futures passed to [`Scope::spawn`] may borrow data for `'a`, but they are handed to
    /// `tokio::spawn` as `'static` futures. The caller must guarantee that the returned scope is
    /// never leaked (e.g. via `std::mem::forget`). It must be fully drained or dropped (dropping
    /// blocks on outstanding tasks) before any data borrowed for `'a` is invalidated.
    pub unsafe fn create() -> Self {
        Scope {
            done: false,
            len: 0,
            remaining: 0,
            futs: FuturesUnordered::new(),
            _marker: PhantomData,
        }
    }

    /// Spawns `f` as a tokio task tracked by this scope.
    ///
    /// The task's output is later yielded by the scope's stream.
    ///
    /// # Panics
    ///
    /// Propagates `tokio::spawn`'s panic when called outside a tokio runtime context.
    pub fn spawn<F: Future<Output = T> + Send + 'a>(&mut self, f: F) {
        let fut: BoxFuture<'a, T> = f.boxed();
        // SAFETY: only the lifetime is erased; the layout of `BoxFuture<'a, T>` and
        // `BoxFuture<'static, T>` is identical. Soundness relies on the contract of
        // `Scope::create`: the scope is never leaked and always awaits its tasks before `'a`
        // ends, so the task never observes a dangling borrow.
        let fut = unsafe { std::mem::transmute::<BoxFuture<'a, T>, BoxFuture<'static, T>>(fut) };
        self.futs.push(tokio::spawn(fut));
        self.len += 1;
        self.remaining += 1;
    }
}
impl<'a, T> Scope<'a, T> {
    /// Returns the total number of tasks spawned on this scope so far.
    #[inline]
    // Kept even if the including crate does not call it.
    #[allow(dead_code)]
    pub fn len(&self) -> usize {
        self.len
    }

    /// Returns `true` if no task has been spawned on this scope yet.
    ///
    /// Companion to [`Scope::len`] (clippy `len_without_is_empty`).
    #[inline]
    // Kept even if the including crate does not call it.
    #[allow(dead_code)]
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// Returns the number of spawned tasks whose output has not been yielded yet.
    #[inline]
    // Kept even if the including crate does not call it.
    #[allow(dead_code)]
    pub fn remaining(&self) -> usize {
        self.remaining
    }

    /// Awaits every task not yet yielded and returns their outputs.
    ///
    /// Outputs appear in the order the scope's stream yields them (completion order, not spawn
    /// order).
    ///
    /// # Panics
    ///
    /// Panics if any awaited task panicked or was cancelled.
    pub async fn collect(&mut self) -> Vec<T> {
        use futures::stream::StreamExt;

        let mut proc_outputs = Vec::with_capacity(self.remaining);
        while let Some(item) = self.next().await {
            proc_outputs.push(item);
        }
        proc_outputs
    }
}
/// Yields the output of each spawned task as it completes.
impl<'a, T> Stream for Scope<'a, T> {
    type Item = T;

    /// Polls the underlying set of join handles.
    ///
    /// Returns `Ready(None)` once every spawned task has been yielded. That state is recorded in
    /// `done` so that dropping the scope does not wait again.
    ///
    /// # Panics
    ///
    /// - If a task panicked, its original panic payload is re-raised via
    ///   [`std::panic::resume_unwind`], so the task's own message is preserved.
    /// - If a task was cancelled, this panics with `"task not driven to completion: …"`.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.project();
        match this.futs.poll_next(cx) {
            Poll::Pending => Poll::Pending,
            Poll::Ready(None) => {
                *this.done = true;
                Poll::Ready(None)
            }
            Poll::Ready(Some(joined)) => {
                // Decrement before possibly unwinding so the count stays accurate.
                *this.remaining -= 1;
                match joined {
                    Ok(output) => Poll::Ready(Some(output)),
                    Err(err) => match err.try_into_panic() {
                        Ok(payload) => std::panic::resume_unwind(payload),
                        Err(err) => panic!("task not driven to completion: {err:?}"),
                    },
                }
            }
        }
    }

    /// Exact: one item per spawned task not yet yielded.
    fn size_hint(&self) -> (usize, Option<usize>) {
        (self.remaining, Some(self.remaining))
    }
}
/// Waits for every outstanding task before the scope (and the `'a` borrows its tasks may hold)
/// goes away.
#[pinned_drop]
impl<T> PinnedDrop for Scope<'_, T> {
    fn drop(self: Pin<&mut Self>) {
        use futures::stream::StreamExt;

        let this = self.project();
        if *this.done {
            return;
        }
        let mut futs = this.futs;
        // Drain *every* handle, even after a task panicked. Stopping early would detach the
        // remaining tasks while they may still borrow data that is about to be freed.
        // NOTE(review): `futures::executor::block_on` parks this thread; on a tokio
        // current-thread runtime the spawned tasks presumably cannot make progress — confirm.
        let first_panic = futures::executor::block_on(async {
            let mut first_panic = None;
            while let Some(joined) = futs.next().await {
                if let Err(err) = joined {
                    if let Ok(payload) = err.try_into_panic() {
                        if first_panic.is_none() {
                            first_panic = Some(payload);
                        }
                    }
                }
            }
            first_panic
        });
        *this.done = true;
        // Surface a task panic only once all tasks are finished. While already unwinding, a
        // second panic would abort the process, so it is dropped in that case.
        if let Some(payload) = first_panic {
            if !std::thread::panicking() {
                std::panic::resume_unwind(payload);
            }
        }
    }
}
/// Creates a [`Scope`], lets `f` spawn tasks on it, and returns the scope together with `f`'s
/// result.
///
/// # Safety
///
/// Tasks spawned by `f` may borrow data for `'a` but run as `'static` tokio tasks. The caller
/// must never leak the returned scope (e.g. via `std::mem::forget`). It must be drained or
/// dropped before any data borrowed for `'a` is invalidated.
pub unsafe fn scope<'a, T: Send + 'static, R, F: FnOnce(&mut Scope<'a, T>) -> R>(
    f: F,
) -> (Scope<'a, T>, R) {
    // SAFETY: the obligations of `Scope::create` are passed through unchanged to the caller of
    // this function, per the `# Safety` section above.
    let mut scope = unsafe { Scope::create() };
    let op = f(&mut scope);
    (scope, op)
}
/// Runs `f` to spawn tasks on a fresh [`Scope`], then blocks the current thread until the
/// scope's stream is exhausted.
///
/// Returns `f`'s result together with the collected task outputs.
///
/// NOTE(review): blocking goes through `futures::executor::block_on`, which parks the calling
/// thread. Calling this from a tokio worker, especially on a current-thread runtime, presumably
/// risks starving the spawned tasks — confirm the intended call sites.
#[allow(dead_code)]
pub fn scope_and_block<'a, T, R, F>(f: F) -> (R, Vec<T>)
where
    T: Send + 'static,
    F: FnOnce(&mut Scope<'a, T>) -> R,
{
    // SAFETY: the scope never escapes this function. It is either collected by the `block_on`
    // below or dropped while unwinding, and this relies on the scope's `Drop` impl awaiting
    // any outstanding tasks before the `'a` borrows end.
    let (mut tasks, result) = unsafe { scope(f) };
    let outputs = futures::executor::block_on(tasks.collect());
    (result, outputs)
}
/// Async counterpart of `scope_and_block`.
///
/// Runs `f` to spawn tasks on a fresh [`Scope`], awaits every task, and returns `f`'s result
/// together with the task outputs.
///
/// # Safety
///
/// Tasks spawned by `f` may borrow data for `'a` but run as `'static` tokio tasks. The returned
/// future must therefore never be leaked (e.g. via `std::mem::forget`) after it has been
/// polled. It must either run to completion or be dropped. In the latter case the scope it
/// owns is dropped, and its `Drop` impl blocks on the outstanding tasks.
#[allow(dead_code)]
pub async unsafe fn scope_and_collect<
    'a,
    T: Send + 'static,
    R,
    F: FnOnce(&mut Scope<'a, T>) -> R,
>(
    f: F,
) -> (R, Vec<T>) {
    // SAFETY: the caller guarantees this future is not leaked (see `# Safety`). `stream` is
    // therefore either fully collected below or dropped, and dropping it awaits the remaining
    // tasks.
    let (mut stream, block_output) = unsafe { scope(f) };
    let proc_outputs = stream.collect().await;
    (block_output, proc_outputs)
}