use alloc::{boxed::Box, vec::Vec};
use core::{
fmt::{Debug, Formatter, Result},
future::Future,
pin::Pin,
task::{Context, Poll},
};
use crate::SharedCell;
#[doc = include_str!("../examples/actor.rs")]
// Spawns `$callback` as a new task on a `TaskGroup`.
//
// Forms accepted:
//   spawn!(group, callback(arg, ...))   -- callback receives extra arguments
//   spawn!(group, callback())           -- callback receives only the cell
//
// In both forms the callback is called as `cb(&mut data, args...)`, where
// `data` is the task's pinned copy of the group's shared cell, and the
// returned future is awaited as the body of the task.  The argument
// expressions live inside the task's `async move` block, so they are
// evaluated when the task first runs, not when `spawn!` is invoked.
//
// Caller-supplied tokens (`$callback`, `$args`) are deliberately expanded
// OUTSIDE the macro's `unsafe` block: unsafe context is lexical and reaches
// into closures and async blocks, so expanding them inside would let callers
// perform unsafe operations without ever writing `unsafe` themselves.
#[macro_export]
macro_rules! spawn {
    // Callback with one or more extra arguments (trailing commas allowed).
    ($tasks: expr, $callback: ident ( $($args: expr),+ $(,)? ) $(,)?) => {{
        let tasks: &mut $crate::TaskGroup<'_, _, _> = &mut $tasks;
        let cb = $callback;
        // Built in safe context; only the `spawn` call below needs `unsafe`.
        let task = |data| async move {
            let mut data = core::pin::pin!(data);
            cb(&mut data, $($args),+).await
        };
        // SAFETY: `TaskGroup::spawn` is unsafe because it duplicates the
        // group's shared cell.  The duplicate is pinned inside the task and
        // only ever lent to `cb` as `&mut Pin<&mut _>`.
        // NOTE(review): the exact contract belongs to `SharedCell::duplicate`,
        // which is not visible from this file -- confirm this is sufficient.
        unsafe {
            tasks.spawn(task);
        }
    }};
    // Callback with no extra arguments.
    ($tasks: expr, $callback: ident ( ) $(,)?) => {{
        let tasks: &mut $crate::TaskGroup<'_, _, _> = &mut $tasks;
        let cb = $callback;
        // Built in safe context; only the `spawn` call below needs `unsafe`.
        let task = |data| async move {
            let mut data = core::pin::pin!(data);
            cb(&mut data).await
        };
        // SAFETY: same reasoning as the arm above -- the duplicated cell is
        // pinned inside the task and only lent to `cb` by mutable reference.
        // NOTE(review): confirm against `SharedCell::duplicate`'s contract.
        unsafe {
            tasks.spawn(task);
        }
    }};
}
#[doc = include_str!("../examples/task_group.rs")]
pub struct TaskGroup<'a, T: ?Sized, R> {
    // Spawned futures that have not completed yet.  Each resolves to `R` and
    // may borrow for `'a`.  Declared first so the tasks are dropped before
    // the cell below (fields drop in declaration order).
    tasks: Vec<Pin<Box<dyn Future<Output = R> + 'a>>>,
    // Cell built from the `&'a mut T` handed to `new`; every spawned task is
    // given a duplicate of it.
    shared_cell: SharedCell<'a, T>,
}
/// Debug output shows the shared cell and how many tasks are still queued;
/// the task futures themselves are opaque and are not printed.
impl<T, R> Debug for TaskGroup<'_, T, R>
where
    T: Debug + ?Sized,
{
    fn fmt(&self, formatter: &mut Formatter<'_>) -> Result {
        let queued = self.tasks.len();
        let mut builder = formatter.debug_struct("TaskGroup");
        builder.field("shared_cell", &self.shared_cell);
        builder.field("tasks.len", &queued);
        // Non-exhaustive: signals that the futures were left out on purpose.
        builder.finish_non_exhaustive()
    }
}
impl<'a, T, R> TaskGroup<'a, T, R>
where
T: ?Sized,
{
pub fn new(value: &'a mut T) -> Self {
let shared_cell = SharedCell::new(value);
let tasks = Vec::new();
Self { shared_cell, tasks }
}
pub async fn advance(&mut self) -> R {
Tasks(self, 0).await
}
pub fn is_empty(&self) -> bool {
self.tasks.is_empty()
}
pub async fn cancel(mut self) -> &'a mut T {
self.tasks.clear();
self.shared_cell.into_inner()
}
pub async fn finish(mut self) -> &'a mut T {
while !self.is_empty() {
drop(self.advance().await);
}
self.shared_cell.into_inner()
}
pub unsafe fn spawn<A>(&mut self, f: impl FnOnce(SharedCell<'a, T>) -> A)
where
A: Future<Output = R> + 'a,
{
self.tasks
.push(Box::pin(f(unsafe { self.shared_cell.duplicate() })));
}
}
/// Future behind `TaskGroup::advance`: an exclusive borrow of the group plus
/// the index at which the next scan over its tasks begins.
struct Tasks<'a, 'b, T: ?Sized, R>(&'b mut TaskGroup<'a, T, R>, usize);

impl<T: ?Sized, R> Future for Tasks<'_, '_, T, R> {
    type Output = R;

    /// Polls every queued task once, starting at the stored index and
    /// wrapping around.  The first task found ready is removed from the group
    /// and its output returned.  If none is ready, the start index moves
    /// forward by one so the next poll begins with a different task.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<R> {
        let Tasks(group, cursor) = self.get_mut();
        let queue = &mut group.tasks;
        let count = queue.len();
        let first = *cursor;

        // Visit `first..count`, then wrap around to `0..first`.
        let visit_order = (first..count).chain(0..first);
        for index in visit_order {
            match queue[index].as_mut().poll(cx) {
                Poll::Ready(output) => {
                    // Order in the queue is not preserved: the last task
                    // takes the finished task's slot.
                    queue.swap_remove(index);
                    return Poll::Ready(output);
                }
                Poll::Pending => {}
            }
        }

        // Nothing was ready; rotate the starting point.  The zero case avoids
        // a remainder-by-zero on an empty group.
        *cursor = match count {
            0 => 0,
            n => (first + 1) % n,
        };
        Poll::Pending
    }
}