mod task;
pub use task::JoinHandle;
use task::Task;
mod tasks_to_add;
use tasks_to_add::TasksToAdd;
use std::cell::RefCell;
use std::future::Future;
use std::mem;
use std::pin::Pin;
use std::task::{Poll, Waker};
/// A single-threaded pool of futures that do not need to be `Send`.
///
/// Futures are queued with [`LocalSpawnPool::spawn`], or with the free function [`spawn`]
/// from inside a task this pool is currently polling. They are driven by awaiting the pool
/// itself or through [`LocalSpawnPool::run_until`].
pub struct LocalSpawnPool(
    // Interior mutability lets `spawn` and `run_until` take `&self`; the inner state is
    // pinned on the heap so it can be polled as a `Future` through the `RefCell`.
    RefCell<Pin<Box<LocalSpawnPoolInner>>>,
);
#[cfg(not(test))]
impl Default for LocalSpawnPool {
fn default() -> Self {
Self::new()
}
}
impl LocalSpawnPool {
    /// Creates an empty pool.
    ///
    /// In test builds the pool carries a `name`. The name is handed to the thread-local
    /// spawn context while the pool polls its tasks.
    pub fn new(#[cfg(test)] name: &'static str) -> Self {
        Self(RefCell::new(Box::pin(LocalSpawnPoolInner::new(
            #[cfg(test)]
            name,
        ))))
    }

    /// Spawns `future` onto this pool and drives the pool until that future completes,
    /// returning its output.
    ///
    /// Other tasks in the pool are polled alongside it. Any that are still pending when
    /// `future` finishes stay queued in the pool; await the pool itself to finish them.
    ///
    /// # Panics
    ///
    /// Panics if this pool is already being polled when the returned future runs (see
    /// [`LocalSpawnPool::spawn`]).
    pub async fn run_until<F>(&self, future: F) -> F::Output
    where
        F: Future + 'static,
    {
        let join_handle = self.spawn(future);
        RunUntil::new(&self.0, join_handle).await
    }

    /// Queues `future` on this pool and returns a handle to its eventual output.
    ///
    /// The future is not polled here; it first runs the next time the pool is polled. If
    /// the pool has been polled before, the waker from that poll is woken.
    ///
    /// # Panics
    ///
    /// Panics if called while this pool is being polled, for example from inside one of
    /// its own tasks. Use the free function [`spawn`] there instead.
    #[track_caller]
    pub fn spawn<F>(&self, future: F) -> JoinHandle<F::Output>
    where
        F: Future + 'static,
    {
        // `try_borrow_mut` + `expect` (both inside this `#[track_caller]` fn, not in a
        // closure) reports the caller's location with an actionable message instead of
        // an opaque `BorrowMutError` pointing at this line.
        self.0
            .try_borrow_mut()
            .expect(
                "`LocalSpawnPool::spawn` was called while the pool is being polled; \
                 use `local_spawn_pool::spawn` from inside its tasks instead",
            )
            .spawn(future)
    }
}
impl Future for LocalSpawnPool {
    type Output = ();

    /// Polls every queued task once. Resolves when the pool has no tasks left.
    ///
    /// The inner state is mutably borrowed for the whole poll.
    fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
        let mut inner = self.0.borrow_mut();
        inner.as_mut().poll(cx)
    }
}
/// Mutable state behind a [`LocalSpawnPool`]: its queued tasks and the waker from its
/// most recent poll.
struct LocalSpawnPoolInner {
    /// Name installed in the thread-local spawn context while this pool polls its tasks
    /// (test builds only).
    #[cfg(test)]
    name: &'static str,
    /// Tasks not yet completed, in the order they will next be polled.
    tasks: Vec<Task>,
    /// Waker from the most recent poll; `None` until the pool is first polled.
    waker: Option<Waker>,
}
impl LocalSpawnPoolInner {
fn new(#[cfg(test)] name: &'static str) -> Self {
Self {
#[cfg(test)]
name,
tasks: Vec::new(),
waker: None,
}
}
fn spawn<F>(&mut self, future: F) -> JoinHandle<F::Output>
where
F: Future + 'static,
{
let (task, join_handle) = task::create_task(future);
self.tasks.push(task);
if let Some(waker) = &self.waker {
waker.wake_by_ref();
}
join_handle
}
}
impl Future for LocalSpawnPoolInner {
    type Output = ();

    /// Polls every queued task once, then queues any tasks they spawned.
    ///
    /// Returns `Ready(())` when no tasks remain queued, otherwise `Pending`. Tasks spawned
    /// through the free function [`spawn`] during this pass are only queued, not polled.
    /// The current waker is therefore woken right away so they run on the next pass.
    fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
        // Clears the thread-local spawn context when dropped. This keeps the context from
        // staying installed if polling a task panics and the panic is caught further up.
        struct UnsetThreadLocalOnDrop;

        impl Drop for UnsetThreadLocalOnDrop {
            fn drop(&mut self) {
                tasks_to_add::unset_thread_local();
            }
        }

        // Remembered so `LocalSpawnPoolInner::spawn` can wake the pool later.
        self.waker = Some(cx.waker().clone());
        let tasks_snapshot = mem::take::<Vec<_>>(&mut self.tasks);
        if tasks_snapshot.is_empty() {
            return Poll::Ready(());
        }

        let tasks_to_add = TasksToAdd::new();
        // Declared after `tasks_to_add`, so the context is cleared before `tasks_to_add`
        // goes away, on both the normal path and the unwinding path.
        let unset_guard = UnsetThreadLocalOnDrop;
        for mut task in tasks_snapshot {
            // Re-install the context before every task: if a task polls a nested
            // `LocalSpawnPool`, that pool's poll clears the context on its way out.
            // NOTE(review): within one task poll, code running *after* a nested pool was
            // polled therefore sees no context, and a `spawn` there panics. Restoring the
            // previous context would need support from `tasks_to_add`.
            tasks_to_add::set_thread_local(
                &tasks_to_add,
                #[cfg(test)]
                self.name,
            );
            if Future::poll(task.as_mut(), cx).is_pending() {
                self.tasks.push(task);
            }
        }
        drop(unset_guard);

        tasks_to_add.access_mut(|tasks_to_add_vec| {
            if !tasks_to_add_vec.is_empty() {
                // Newly spawned tasks have not been polled yet; request another pass.
                cx.waker().wake_by_ref();
            }
            self.tasks.append(tasks_to_add_vec);
        });

        if self.tasks.is_empty() {
            Poll::Ready(())
        } else {
            Poll::Pending
        }
    }
}
/// Spawns `future` onto the [`LocalSpawnPool`] that is currently polling the caller.
///
/// The new task is queued and first polled on that pool's next pass, after the task that
/// called `spawn` has been polled.
///
/// # Panics
///
/// Panics if called outside a task that a `LocalSpawnPool` is currently polling.
#[track_caller]
pub fn spawn<F>(future: F) -> JoinHandle<F::Output>
where
    F: Future + 'static,
{
    let (task, join_handle) = task::create_task(future);
    // `#[track_caller]` does not reach into closures, so the closure only reports whether
    // the task was queued. The panic happens out here, where it points at the caller.
    let queued = tasks_to_add::access_thread_local(|tasks_to_add| match tasks_to_add {
        #[cfg(not(test))]
        Some(tasks_to_add) => {
            tasks_to_add.add(task);
            true
        }
        #[cfg(test)]
        Some((tasks_to_add, _)) => {
            tasks_to_add.add(task);
            true
        }
        None => false,
    });
    if !queued {
        panic!("`local_spawn_pool::spawn` was called outside the context of a `LocalSpawnPool`")
    }
    join_handle
}
/// Future awaited by [`LocalSpawnPool::run_until`]: keeps polling a pool until one
/// particular task has finished.
struct RunUntil<'a, T> {
    /// The pool's inner state. Set to `None` once the pool reports it has no tasks left.
    local_spawn_pool: Option<&'a RefCell<Pin<Box<LocalSpawnPoolInner>>>>,
    /// Handle to the task whose output this future resolves with.
    join_handle: Pin<Box<JoinHandle<T>>>,
}
impl<'a, T> RunUntil<'a, T> {
fn new(
local_spawn_pool: &'a RefCell<Pin<Box<LocalSpawnPoolInner>>>,
join_handle: JoinHandle<T>,
) -> Self {
RunUntil {
local_spawn_pool: Some(local_spawn_pool),
join_handle: Box::pin(join_handle),
}
}
}
impl<'a, T> Future for RunUntil<'a, T> {
    type Output = T;

    /// Polls the pool (until it reports no tasks left), then the awaited task's handle.
    /// Resolves with that task's output.
    ///
    /// # Panics
    ///
    /// Panics if the handle resolves without an output. NOTE(review): `None` from a
    /// `JoinHandle` presumably means the task was aborted. That should be impossible here,
    /// because this handle never leaves `run_until`.
    fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
        if let Some(local_spawn_pool) = self.local_spawn_pool {
            if local_spawn_pool.borrow_mut().as_mut().poll(cx).is_ready() {
                // The pool has no tasks left, so the awaited task has finished as well;
                // there is nothing more to drive.
                self.local_spawn_pool = None;
            }
        }
        self.join_handle.as_mut().poll(cx).map(|output| {
            output.expect("the task driven by `LocalSpawnPool::run_until` produced no output")
        })
    }
}
/// End-to-end check of spawning across two pools, with pool "b" driven from inside a task
/// of pool "a".
///
/// Each recorded entry is `(value, name of the pool polling the recording task)`. Given the
/// sleeps below, the expected timeline is roughly:
/// - 0 ("a", 50 ms)
/// - 1 ("b", ~70 ms; pool "b" is only polled once pool "a"'s task awaits it)
/// - 2 ("a", 150 ms)
/// - 3 ("a", 500 ms)
/// - 4 ("b", ~580 ms)
///
/// 100 never appears because its task is aborted.
#[cfg(test)]
#[tokio::test]
async fn test() {
    use std::rc::Rc;
    use std::time::Duration;
    use tokio::time;
    let results: Rc<RefCell<Vec<(u8, &'static str)>>> = Rc::new(RefCell::new(Vec::new()));
    // Records `result` with the name of the pool that is currently polling the caller, as
    // reported by the thread-local spawn context.
    #[track_caller]
    fn push_result(results: &Rc<RefCell<Vec<(u8, &'static str)>>>, result: u8) {
        let pool_name = tasks_to_add::access_thread_local(|tasks_to_add_and_name| {
            tasks_to_add_and_name.map(|&(_, name)| name)
        });
        // Panic outside the closure: `#[track_caller]` does not reach into closures.
        match pool_name {
            Some(name) => results.borrow_mut().push((result, name)),
            None => {
                panic!("`push_result` was called outside the context of a `LocalSpawnPool`")
            }
        }
    }
    let local_spawn_pool_a = LocalSpawnPool::new("a");
    let output = local_spawn_pool_a
        .run_until({
            let results = Rc::clone(&results);
            async move {
                spawn({
                    let results = Rc::clone(&results);
                    async move {
                        time::sleep(Duration::from_millis(500)).await;
                        push_result(&results, 3);
                    }
                });
                spawn({
                    let results = Rc::clone(&results);
                    async move {
                        let local_spawn_pool_b = LocalSpawnPool::new("b");
                        local_spawn_pool_b.spawn({
                            let results = Rc::clone(&results);
                            async move {
                                let join_handle = spawn({
                                    let results = Rc::clone(&results);
                                    async move {
                                        time::sleep(Duration::from_millis(20)).await;
                                        push_result(&results, 1);
                                        "this is another output"
                                    }
                                });
                                assert_eq!(join_handle.await, Some("this is another output"));
                                spawn({
                                    let results = Rc::clone(&results);
                                    async move {
                                        time::sleep(Duration::from_millis(510)).await;
                                        push_result(&results, 4);
                                    }
                                });
                                let join_handle = spawn({
                                    let results = Rc::clone(&results);
                                    async move {
                                        time::sleep(Duration::from_millis(515)).await;
                                        push_result(&results, 100);
                                    }
                                });
                                // An aborted task yields `None` and never records 100.
                                join_handle.abort();
                                assert_eq!(join_handle.await, None);
                            }
                        });
                        time::sleep(Duration::from_millis(50)).await;
                        push_result(&results, 0);
                        local_spawn_pool_b.await;
                    }
                });
                spawn({
                    let results = Rc::clone(&results);
                    async move {
                        time::sleep(Duration::from_millis(150)).await;
                        push_result(&results, 2);
                    }
                });
                "this is the output"
            }
        })
        .await;
    assert_eq!(output, "this is the output");
    // `run_until` returns as soon as its own future is done; the spawned tasks have only
    // been queued, and none has recorded anything yet.
    assert_eq!(&*results.borrow(), &[]);
    local_spawn_pool_a.await;
    assert_eq!(
        &*results.borrow(),
        &[(0, "a"), (1, "b"), (2, "a"), (3, "a"), (4, "b")]
    );
}