use crate::{LocalWaker, SchedInfo, TaskImpl, TaskStat, Waker, HashMap, BTreeMap};
use core::cell::{Cell, RefCell};
use alloc::vec;
use alloc::vec::Vec;
use core::future::Future;
use core::pin::Pin;
use alloc::rc::{Rc, Weak};
use core::task::{Context, Poll};
/// Callbacks through which a task's lifecycle is reported back to whoever
/// is observing it (presumably invoked by the scheduler via `TaskImpl`).
///
/// Implemented by [`JoinHandleInner`] (one task) and [`JoinSetInner`]
/// (many tasks, distinguished by `task_id`).
pub(crate) trait Join<T> {
    /// The task identified by `task_id` is now running.
    fn set_running(&self, task_id: u64);
    /// The task identified by `task_id` completed and produced `val`.
    fn set_finished(&self, val: T, task_id: u64);
    /// The task identified by `task_id` was aborted; it delivers no value.
    fn set_aborted(&self, task_id: u64);
}
/// Handle to a single spawned task.
///
/// It can query the task's state, abort it, or be `.await`ed. Awaiting
/// yields `Some(value)` after normal completion, and `None` after an abort or
/// when the output has already been taken.
pub struct JoinHandle<T> {
/// Shared state; the task side receives only a `Weak<dyn Join<T>>` to it (see `weak`).
inner: Rc<JoinHandleInner<T>>,
}
impl<T: 'static> JoinHandle<T> {
    /// Returns `true` once the task has ended, either normally (`End`) or by
    /// being aborted (`Aborted`).
    pub fn is_finished(&self) -> bool {
        let stat = self.inner.stat();
        matches!(stat, TaskStat::End) || matches!(stat, TaskStat::Aborted)
    }

    /// Returns `true` while the task is in the `Running` state.
    pub fn is_running(&self) -> bool {
        let stat = self.inner.stat();
        matches!(stat, TaskStat::Running)
    }

    /// Returns `true` if the task was aborted.
    pub fn is_aborted(&self) -> bool {
        let stat = self.inner.stat();
        matches!(stat, TaskStat::Aborted)
    }

    /// Consumes the handle and returns the output if one is already stored.
    ///
    /// This does not wait. It returns `None` if the task has not finished,
    /// was aborted, or its output was already taken by polling the handle.
    pub fn join(self) -> Option<T> {
        self.get()
    }

    /// Consumes the handle and aborts the task, but only while it is still in
    /// the `Init` state; a task that has started running is left alone.
    pub fn abort(self) {
        let inner = &self.inner;
        if let TaskStat::Init = inner.stat() {
            inner.info.borrow_mut().task_abort(inner.task_id);
        }
    }

    /// Consumes the handle and aborts the task if its state is `Init` or
    /// `Running`, i.e. even when it has already started.
    pub fn force_abort(self) {
        let inner = &self.inner;
        let abortable = matches!(inner.stat(), TaskStat::Init | TaskStat::Running);
        if abortable {
            inner.info.borrow_mut().task_abort(inner.task_id);
        }
    }

    /// Id of the task this handle observes.
    pub fn task_id(&self) -> u64 {
        self.inner.task_id()
    }

    /// Moves the stored output out of the shared state, leaving `None` behind.
    fn get(&self) -> Option<T> {
        self.inner.output.borrow_mut().take()
    }

    /// Creates a handle for task `task_id`; the shared state starts as `Init`.
    pub(crate) fn new(info: Rc<RefCell<SchedInfo>>, task_id: u64) -> Self {
        let inner = JoinHandleInner::new(info, task_id);
        Self {
            inner: Rc::new(inner),
        }
    }

    /// Returns a weak, type-erased reference to the shared state.
    ///
    /// It does not keep the handle's state alive. Presumably it is handed to
    /// the task so that it can report progress.
    pub(crate) fn weak(&self) -> Weak<dyn Join<T>> {
        let weak: Weak<JoinHandleInner<T>> = Rc::downgrade(&self.inner);
        weak
    }
}
/// Awaiting a `JoinHandle` resolves with the task's output once the task has
/// ended or been aborted (`None` for an abort or an already-taken output).
impl<T: 'static> Future for JoinHandle<T> {
    type Output = Option<T>;

    fn poll(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
        let done = matches!(self.inner.stat(), TaskStat::End | TaskStat::Aborted);
        if done {
            Poll::Ready(self.get())
        } else {
            // Only the most recent waker is kept; it is woken by
            // `set_finished` / `set_aborted`.
            let waker = LocalWaker::waker(ctx).clone();
            self.inner.waker.replace(Some(waker));
            Poll::Pending
        }
    }
}
/// State shared between a [`JoinHandle`] and the task it observes; the task
/// side updates it through the [`Join`] trait.
pub(crate) struct JoinHandleInner<T> {
/// Output stored by `set_finished`; moved out by `JoinHandle::join` or by polling the handle.
output: RefCell<Option<T>>,
/// Waker registered by the most recent pending poll; woken on finish or abort.
waker: RefCell<Option<Waker>>,
/// Scheduler state, used to request `task_abort` for this task.
info: Rc<RefCell<SchedInfo>>,
/// Id of the observed task.
task_id: u64,
/// Current lifecycle state; starts as `TaskStat::Init`.
stat: Cell<TaskStat>,
}
impl<T: 'static> JoinHandleInner<T> {
    /// Fresh state for `task_id`: no output, no waiter, status `Init`.
    fn new(info: Rc<RefCell<SchedInfo>>, task_id: u64) -> Self {
        let stat = Cell::new(TaskStat::Init);
        Self {
            task_id,
            info,
            stat,
            waker: RefCell::new(None),
            output: RefCell::new(None),
        }
    }

    /// Marks the task `End`, stores its output, and wakes a pending waiter.
    fn set_finished(&self, val: T) {
        self.stat.set(TaskStat::End);
        self.output.replace(Some(val));
        self.wake_waiter();
    }

    /// Marks the task `Running`; the id is implied by this handle.
    fn set_running(&self, _task_id: u64) {
        self.stat.set(TaskStat::Running);
    }

    /// Marks the task `Aborted` and wakes a pending waiter (no output is stored).
    fn set_aborted(&self, _task_id: u64) {
        self.stat.set(TaskStat::Aborted);
        self.wake_waiter();
    }

    /// Current lifecycle state.
    fn stat(&self) -> TaskStat {
        self.stat.get()
    }

    /// Id of the observed task.
    fn task_id(&self) -> u64 {
        self.task_id
    }

    /// Takes the registered waker, if any, and wakes it (one-shot).
    fn wake_waiter(&self) {
        let registered = self.waker.take();
        if let Some(waker) = registered {
            waker.wake();
        }
    }
}
/// Forwards each callback to the same-named inherent method of
/// [`JoinHandleInner`]; a type path resolves inherent items before trait
/// items, so these calls do not recurse.
impl<T: 'static> Join<T> for JoinHandleInner<T> {
    fn set_finished(&self, val: T, _task_id: u64) {
        // A handle tracks exactly one task, so the id is not needed here.
        JoinHandleInner::set_finished(self, val)
    }

    fn set_running(&self, task_id: u64) {
        JoinHandleInner::set_running(self, task_id)
    }

    fn set_aborted(&self, task_id: u64) {
        JoinHandleInner::set_aborted(self, task_id)
    }
}
/// A group of spawned tasks whose results are collected via `wait_any` /
/// `wait_all`. Dropping the set aborts every task it still tracks.
pub struct JoinSet<T> {
/// Shared state; spawned tasks receive only `Weak` references to it (see `spawn`).
inner: Rc<JoinSetInner<T>>,
}
impl<T> Drop for JoinSet<T> {
    /// Aborts every task still tracked by the set, whether or not it has
    /// started running; their results can no longer be collected.
    fn drop(&mut self) {
        // Move the map out so no `tasks` borrow is held while the scheduler
        // is called. A synchronous `Join::set_aborted` callback re-borrows
        // `tasks` and would otherwise panic with `BorrowMutError`. The set is
        // being destroyed, so nothing else needs these entries.
        let tasks = self.inner.tasks.take();
        for &task_id in tasks.keys() {
            self.inner.info.borrow_mut().task_abort(task_id);
        }
    }
}
impl<T> JoinSet<T> {
    /// Creates an empty set bound to the scheduler running the caller.
    ///
    /// This is `async` so it can read the scheduler info from the waker of
    /// the `Context` the calling task is polled with.
    pub async fn new() -> Self {
        /// Resolves immediately with the scheduler info carried by the waker.
        struct CurrentSchedInfo;

        impl Future for CurrentSchedInfo {
            type Output = Rc<RefCell<SchedInfo>>;

            fn poll(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
                let shared = LocalWaker::waker(ctx).info().clone();
                Poll::Ready(shared)
            }
        }

        let inner = JoinSetInner::new(CurrentSchedInfo.await);
        Self {
            inner: Rc::new(inner),
        }
    }
}
impl<T: 'static> JoinSet<T> {
    /// Spawns `future` as a task tracked by this set and returns its id.
    ///
    /// The task's result is later delivered through `wait_any` / `wait_all`
    /// as `(task_id, Some(output))`.
    pub fn spawn<F>(&mut self, future: F) -> u64
    where
        F: Future<Output = T> + 'static,
    {
        let task = TaskImpl::with_join(future, &self.inner.info, self.weak());
        let task_id = task.get_id();
        // Register the id BEFORE handing the task to the scheduler. If the
        // scheduler reported completion during `task_push`, `set_finished`
        // would find no entry and drop the output. A stale `false` entry
        // inserted afterwards would then make `wait_all` wait forever.
        self.inner.tasks.borrow_mut().insert(task_id, false);
        self.inner.info.borrow_mut().task_push(task);
        task_id
    }

    /// Weak, type-erased reference to the shared state, given to each task
    /// so that it can report back without keeping the set alive.
    fn weak(&self) -> Weak<dyn Join<T>> {
        Rc::<JoinSetInner<T>>::downgrade(&self.inner)
    }
}
impl<T> JoinSet<T> {
    /// Aborts every task in the set that has not been marked running yet.
    ///
    /// Tasks already marked running (via `Join::set_running`) are left alone
    /// and stay in the set. Aborted tasks are removed from the set
    /// immediately, so they never produce an entry for `wait_any` /
    /// `wait_all`.
    pub fn abort(&mut self) {
        // Snapshot the ids first so that no `tasks` borrow is held while the
        // scheduler is called. A synchronous `Join::set_aborted` callback
        // re-borrows `tasks` and would otherwise panic.
        let not_started: Vec<u64> = self
            .inner
            .tasks
            .borrow()
            .iter()
            .filter(|&(_, &is_running)| !is_running)
            .map(|(&task_id, _)| task_id)
            .collect();
        for task_id in not_started {
            // Remove before aborting: a later `set_aborted` for this id then
            // finds no entry and queues no `(task_id, None)` output.
            self.inner.tasks.borrow_mut().remove(&task_id);
            self.inner.info.borrow_mut().task_abort(task_id);
        }
    }

    /// Aborts the not-yet-running tasks, waits for the running ones to
    /// finish, then discards every buffered result.
    pub async fn abort_wait(&mut self) {
        self.abort();
        if !self.inner.tasks.borrow().is_empty() {
            self.inner.wait_all.set(true);
            Wait { set: self }.await;
        }
        self.inner.outputs.borrow_mut().clear();
    }

    /// Waits until no tracked task remains, then returns all buffered results.
    ///
    /// Results come in the order they were recorded, including any not yet
    /// collected by earlier `wait_any` calls. Aborted tasks appear as
    /// `(task_id, None)`.
    pub async fn wait_all(&mut self) -> impl IntoIterator<Item = (u64, Option<T>)> {
        self.inner.wait_all.set(true);
        Wait { set: self }.await;
        self.inner.outputs.take()
    }

    /// Returns one buffered result, waiting for a task to finish if needed.
    ///
    /// Buffered results are popped from the end, so the most recently
    /// recorded one is returned first. Returns `None` once no tasks remain
    /// and nothing is buffered.
    pub async fn wait_any(&mut self) -> Option<(u64, Option<T>)> {
        self.inner.wait_all.set(false);
        loop {
            let next = self.inner.outputs.borrow_mut().pop();
            if next.is_some() {
                return next;
            }
            if self.inner.tasks.borrow().is_empty() {
                return None;
            }
            Wait { set: self }.await;
        }
    }
}
/// Future used by `JoinSet`'s wait methods. Its readiness depends on the
/// set's `wait_all` flag (see its `Future` impl).
struct Wait<'a, T> {
/// Set being waited on, borrowed mutably for the duration of the wait.
set: &'a mut JoinSet<T>,
}
/// Readiness rule:
/// - wait-all mode: ready once no tracked task remains;
/// - wait-any mode: ready once at least one result is buffered.
///
/// Otherwise the current waker is registered for `JoinSetInner::notify`.
impl<T> Future for Wait<'_, T> {
    type Output = ();

    fn poll(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
        let inner = &self.set.inner;
        let ready = if inner.wait_all.get() {
            inner.tasks.borrow().is_empty()
        } else {
            !inner.outputs.borrow().is_empty()
        };
        if ready {
            return Poll::Ready(());
        }
        inner.waker.replace(Some(LocalWaker::waker(ctx).clone()));
        Poll::Pending
    }
}
/// State shared between a [`JoinSet`] and its tasks; the task side updates it
/// through the [`Join`] trait.
pub(crate) struct JoinSetInner<T> {
/// Tracked task ids mapped to "has `set_running` been reported".
/// Entries are removed on finish/abort or by `JoinSet::abort`.
tasks: RefCell<HashMap<u64, bool>>,
/// Results not yet handed out; `None` marks a task aborted while tracked.
outputs: RefCell<Vec<(u64, Option<T>)>>,
/// Waker registered by a pending `Wait`.
waker: RefCell<Option<Waker>>,
/// Scheduler state, used to push and abort tasks.
info: Rc<RefCell<SchedInfo>>,
/// `true`: waiter wants all tasks done; `false`: wake on any buffered result.
wait_all: Cell<bool>,
}
/// Task callbacks for a set. Callbacks for ids the set no longer tracks
/// (e.g. removed by `JoinSet::abort`) are ignored.
impl<T> Join<T> for JoinSetInner<T> {
    fn set_finished(&self, val: T, task_id: u64) {
        let tracked = self.tasks.borrow_mut().remove(&task_id).is_some();
        if tracked {
            self.outputs.borrow_mut().push((task_id, Some(val)));
            self.notify();
        }
    }

    fn set_running(&self, task_id: u64) {
        let mut tasks = self.tasks.borrow_mut();
        if let Some(running_flag) = tasks.get_mut(&task_id) {
            *running_flag = true;
        }
    }

    fn set_aborted(&self, task_id: u64) {
        let tracked = self.tasks.borrow_mut().remove(&task_id).is_some();
        if tracked {
            self.outputs.borrow_mut().push((task_id, None));
            self.notify();
        }
    }
}
impl<T> JoinSetInner<T> {
fn new(info: Rc<RefCell<SchedInfo>>) -> Self {
Self {
tasks: RefCell::new(BTreeMap::new()),
outputs: RefCell::new(vec![]),
waker: RefCell::new(None),
info,
wait_all: Cell::new(true),
}
}
fn notify(&self) {
if !self.wait_all.get() || self.tasks.borrow().is_empty() {
if let Some(waker) = self.waker.take() {
waker.wake();
}
}
}
}