hiasync 0.2.2

Supports only a single-threaded asynchronous runtime.
Documentation
use crate::{LocalWaker, SchedInfo, TaskImpl, TaskStat, Waker, HashMap, BTreeMap};
use core::cell::{Cell, RefCell};
use alloc::vec;
use alloc::vec::Vec;
use core::future::Future;
use core::pin::Pin;
use alloc::rc::{Rc, Weak};
use core::task::{Context, Poll};

/// Status sink for one spawned task, implemented by the shared state behind a
/// `JoinHandle` and a `JoinSet`.
///
/// Instances are handed out as `Weak<dyn Join<T>>` (presumably consumed by
/// `TaskImpl` / the scheduler, which reports task progress through it — confirm
/// against the runtime).
pub(crate) trait Join<T> {
    /// Reports that task `task_id` has started running.
    fn set_running(&self, task_id: u64);

    /// Reports that task `task_id` completed and produced `val`.
    fn set_finished(&self, val: T, task_id: u64);

    /// Reports that task `task_id` was aborted without producing a value.
    fn set_aborted(&self, task_id: u64);
}

/// Handle to a spawned async task, used to retrieve the task's return value.
///
/// Either await the handle (it implements `Future<Output = Option<T>>`) or call
/// `join` once `is_finished` reports `true`.
pub struct JoinHandle<T> {
    // State shared with the running task; the task side only holds a
    // `Weak<dyn Join<T>>` to it (created by `JoinHandle::weak`).
    inner: Rc<JoinHandleInner<T>>,
}

impl<T: 'static> JoinHandle<T> {
    /// Returns `true` once the task has either completed or been aborted.
    pub fn is_finished(&self) -> bool {
        let stat = self.inner.stat();
        matches!(stat, TaskStat::End | TaskStat::Aborted)
    }

    /// Returns `true` while the task is in the `Running` state, i.e. it has been
    /// scheduled for the first time and has not ended or been aborted yet.
    pub fn is_running(&self) -> bool {
        let stat = self.inner.stat();
        matches!(stat, TaskStat::Running)
    }

    /// Returns `true` if the task was forcibly terminated.
    pub fn is_aborted(&self) -> bool {
        let stat = self.inner.stat();
        matches!(stat, TaskStat::Aborted)
    }

    /// Takes the task's output without blocking.
    ///
    /// Returns `None` if the task has not finished yet or was aborted; call
    /// `is_finished` first to tell an unfinished task apart from an aborted one.
    pub fn join(self) -> Option<T> {
        self.get()
    }

    /// Gives up scheduling the task. Only takes effect if the task has not been
    /// scheduled for the first time yet (state `Init`); otherwise it is a no-op.
    /// The task's resources are released when `sched` ends.
    pub fn abort(self) {
        if let TaskStat::Init = self.inner.stat() {
            let task_id = self.task_id();
            self.inner.info.borrow_mut().task_abort(task_id);
        }
    }

    /// Forcibly terminates the task if it is in the `Init` or `Running` state.
    /// The task's resources are released when `sched` ends.
    ///
    /// May leak system resources; see `Runtime::task_force_abort`.
    pub fn force_abort(self) {
        let abortable = matches!(self.inner.stat(), TaskStat::Init | TaskStat::Running);
        if abortable {
            let task_id = self.task_id();
            self.inner.info.borrow_mut().task_abort(task_id);
        }
    }

    /// Returns the id of the task this handle refers to.
    pub fn task_id(&self) -> u64 {
        self.inner.task_id()
    }

    /// Moves the stored output out, leaving `None` behind.
    fn get(&self) -> Option<T> {
        self.inner.output.take()
    }

    /// Creates a handle for task `task_id` scheduled on `info`.
    pub(crate) fn new(info: Rc<RefCell<SchedInfo>>, task_id: u64) -> Self {
        let inner = JoinHandleInner::new(info, task_id);
        Self {
            inner: Rc::new(inner),
        }
    }

    /// Returns a type-erased weak reference to the shared state. The task side
    /// reports progress through it without keeping the state alive.
    pub(crate) fn weak(&self) -> Weak<dyn Join<T>> {
        Rc::<JoinHandleInner<T>>::downgrade(&self.inner)
    }
}

impl<T: 'static> Future for JoinHandle<T> {
    type Output = Option<T>;

    /// Resolves once the task has ended or been aborted.
    ///
    /// The result is the stored output, or `None` if the task was aborted or the
    /// output was already returned by an earlier poll. Until then, the current waker
    /// is registered so that the task's completion or abort wakes this future.
    fn poll(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
        let done = matches!(self.inner.stat(), TaskStat::End | TaskStat::Aborted);
        if !done {
            let waker = LocalWaker::waker(ctx).clone();
            self.inner.waker.replace(Some(waker));
            return Poll::Pending;
        }
        Poll::Ready(self.get())
    }
}

/// State shared between a `JoinHandle` and the task it refers to.
pub(crate) struct JoinHandleInner<T> {
    // `Waker` supports neither `Send` nor `Sync`, hence the single-threaded
    // `RefCell`/`Cell` fields below.
    // Task output: stored by `set_finished`, taken by `join` or by polling the handle.
    output: RefCell<Option<T>>,
    // Waker registered by the most recent pending poll of the handle; woken on
    // finish or abort.
    waker: RefCell<Option<Waker>>,
    // Scheduler state, used by `abort` / `force_abort` to call `task_abort`.
    info: Rc<RefCell<SchedInfo>>,
    task_id: u64,
    // Starts at `TaskStat::Init`; moved to `Running`, `End` or `Aborted` by the
    // `set_*` callbacks.
    stat: Cell<TaskStat>,
}

impl<T: 'static> JoinHandleInner<T> {
    fn new(info: Rc<RefCell<SchedInfo>>, task_id: u64) -> Self {
        Self {
            output: RefCell::new(None),
            waker: RefCell::new(None),
            info,
            task_id,
            stat: Cell::new(TaskStat::Init),
        }
    }

    fn set_finished(&self, val: T) {
        self.stat.set(TaskStat::End);
        self.output.replace(Some(val));
        if let Some(waker) = self.waker.take() {
            waker.wake();
        }
    }

    fn set_running(&self, _task_id: u64) {
        self.stat.set(TaskStat::Running);
    }

    fn set_aborted(&self, _task_id: u64) {
        self.stat.set(TaskStat::Aborted);
        if let Some(waker) = self.waker.take() {
            waker.wake();
        }
    }

    fn stat(&self) -> TaskStat {
        self.stat.get()
    }

    fn task_id(&self) -> u64 {
        self.task_id
    }
}

/// Forwards task status reports to the inherent methods of `JoinHandleInner`.
///
/// `Self::method` paths resolve to the inherent methods first, which is what these
/// forwarders rely on.
impl<T: 'static> Join<T> for JoinHandleInner<T> {
    fn set_running(&self, task_id: u64) {
        Self::set_running(self, task_id);
    }

    fn set_aborted(&self, task_id: u64) {
        Self::set_aborted(self, task_id);
    }

    /// A handle tracks exactly one task, so the id is not needed here.
    fn set_finished(&self, val: T, _task_id: u64) {
        Self::set_finished(self, val);
    }
}

/// Manages a group of child async tasks that may run concurrently.
///
/// Results are collected with `wait_all` / `wait_any`. Dropping the set aborts
/// every child task it still tracks.
pub struct JoinSet<T> {
    // Shared bookkeeping; child tasks only hold a `Weak<dyn Join<T>>` to it.
    inner: Rc<JoinSetInner<T>>,
}

impl<T> Drop for JoinSet<T> {
    /// Aborts every child task this set still tracks, whether or not it has started.
    ///
    /// Tasks that already finished or were aborted are no longer tracked and are
    /// left untouched.
    fn drop(&mut self) {
        // Move the bookkeeping map out first, so that no `tasks` borrow is held while
        // calling into the scheduler. A re-entrant `Join::set_aborted` on this set
        // (still alive during `drop`) would otherwise hit a `RefCell` double-borrow
        // panic. With the map emptied, such a callback finds no entry and returns early.
        let tasks = self.inner.tasks.take();
        // Nothing to abort: do not touch the scheduler state at all.
        if tasks.is_empty() {
            return;
        }
        let mut info = self.inner.info.borrow_mut();
        for task_id in tasks.keys() {
            info.task_abort(*task_id);
        }
    }
}

impl<T> JoinSet<T> {
    pub async fn new() -> Self {
        struct Info;
        impl Future for Info {
            type Output = Rc<RefCell<SchedInfo>>;
            fn poll(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
                let waker = LocalWaker::waker(ctx);
                Poll::Ready(waker.info().clone())
            }
        }
        let info = Info.await;
        Self {
            inner: Rc::new(JoinSetInner::new(info)),
        }
    }
}

impl<T: 'static> JoinSet<T> {
    /// Spawns `future` as a child task of this set and returns its task id.
    ///
    /// The id is the key under which the task's result is reported by `wait_all` /
    /// `wait_any`.
    pub fn spawn<F>(&mut self, future: F) -> u64
    where
        F: Future<Output = T> + 'static,
    {
        let task = TaskImpl::with_join(future, &self.inner.info, self.weak());
        let task_id = task.get_id();
        // Register the id before handing the task to the scheduler. Every status
        // callback (`set_running` / `set_finished` / `set_aborted`) looks the id up in
        // `tasks` and silently drops the report when it is missing, so the entry must
        // exist before the scheduler can possibly report on the task.
        self.inner.tasks.borrow_mut().insert(task_id, false);
        self.inner.info.borrow_mut().task_push(task);
        task_id
    }

    /// Returns a type-erased weak reference to the set's shared state. Child tasks
    /// report through it without keeping the set alive.
    fn weak(&self) -> Weak<dyn Join<T>> {
        Rc::<JoinSetInner<T>>::downgrade(&self.inner)
    }
}

impl<T> JoinSet<T> {
    /// Aborts every child task that has not been polled for the first time yet.
    ///
    /// Tasks that are already running are not affected. Aborted tasks are removed from
    /// this set immediately, so they contribute no entry to the results of `wait_all` /
    /// `wait_any`.
    pub fn abort(&mut self) {
        // Detach the not-yet-started tasks and release the `tasks` borrow *before*
        // calling into the scheduler. Should `task_abort` ever report back through
        // `Join::set_aborted`, that callback borrows `tasks` again. With the borrow
        // released, it finds no entry and returns instead of panicking.
        let pending: Vec<u64> = {
            let mut tasks = self.inner.tasks.borrow_mut();
            let ids: Vec<u64> = tasks
                .iter()
                .filter(|&(_, &is_running)| !is_running)
                .map(|(&task_id, _)| task_id)
                .collect();
            for task_id in &ids {
                tasks.remove(task_id);
            }
            ids
        };
        if pending.is_empty() {
            return;
        }
        // One scheduler borrow for the whole batch instead of one per task.
        let mut info = self.inner.info.borrow_mut();
        for task_id in pending {
            info.task_abort(task_id);
        }
    }

    /// Aborts the child tasks that have not started yet, then waits until every
    /// already-running child task has ended.
    ///
    /// All buffered results, including those of the awaited tasks, are discarded.
    pub async fn abort_wait(&mut self) {
        self.abort();
        if !self.inner.tasks.borrow().is_empty() {
            self.inner.wait_all.set(true);
            Wait { set: self }.await;
        }
        self.inner.outputs.borrow_mut().clear();
    }

    /// Waits until no child task is tracked any more, then returns all buffered
    /// `(task_id, result)` pairs.
    ///
    /// `result` is `None` for tasks that were aborted.
    pub async fn wait_all(&mut self) -> impl IntoIterator<Item = (u64, Option<T>)> {
        self.inner.wait_all.set(true);
        Wait { set: self }.await;
        self.inner.outputs.replace(vec![])
    }

    /// Returns as soon as any child task has ended.
    ///
    /// Yields a `(task_id, result)` pair (`result` is `None` for an aborted task). If
    /// several results are already buffered, the most recently recorded one is
    /// returned first. Returns `None` when nothing is buffered and no child task is
    /// tracked.
    pub async fn wait_any(&mut self) -> Option<(u64, Option<T>)> {
        self.inner.wait_all.set(false);
        loop {
            if let Some(val) = self.inner.outputs.borrow_mut().pop() {
                return Some(val);
            }
            if self.inner.tasks.borrow().is_empty() {
                return None;
            }
            Wait { set: self }.await;
        }
    }
}

/// Future that resolves when a `JoinSet` reaches the condition selected by its
/// `wait_all` flag.
///
/// In wait-all mode, that condition is "no tracked tasks left"; in wait-any mode,
/// it is "at least one buffered result".
struct Wait<'a, T> {
    // Exclusive borrow: the set cannot be modified or dropped while this is pending.
    set: &'a mut JoinSet<T>,
}

impl<T> Future for Wait<'_, T> {
    type Output = ();

    /// Ready when the set's wait condition holds; otherwise stores the current waker
    /// so that `JoinSetInner::notify` can wake this future later.
    fn poll(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
        let inner = &self.set.inner;
        let wait_all = inner.wait_all.get();
        let no_tasks_left = inner.tasks.borrow().is_empty();
        let has_output = !inner.outputs.borrow().is_empty();

        // wait-all: every tracked task has ended; wait-any: some result is buffered.
        let done = if wait_all { no_tasks_left } else { has_output };
        if done {
            return Poll::Ready(());
        }

        inner.waker.replace(Some(LocalWaker::waker(ctx).clone()));
        Poll::Pending
    }
}

/// Bookkeeping shared between a `JoinSet` and its child tasks.
pub(crate) struct JoinSetInner<T> {
    // Tracked child tasks: task id -> whether `set_running` has been reported.
    // Entries are inserted as `false` by `spawn` and removed when the task finishes,
    // is aborted, or is dropped by `JoinSet::abort`.
    tasks: RefCell<HashMap<u64, bool>>,
    // Buffered results in the order they were recorded: `Some(value)` for finished
    // tasks, `None` for aborted ones.
    outputs: RefCell<Vec<(u64, Option<T>)>>,
    // Waker of the currently pending `Wait` future, if any.
    waker: RefCell<Option<Waker>>,
    // Scheduler state used to push and abort child tasks.
    info: Rc<RefCell<SchedInfo>>,
    // `true`: wake only once no task is tracked (wait_all / abort_wait).
    // `false`: wake on every recorded result (wait_any).
    wait_all: Cell<bool>,
}

impl<T> Join<T> for JoinSetInner<T> {
    /// Records the output of a finished child task and wakes the waiter if its
    /// condition may now hold.
    ///
    /// Ignored when `task_id` is not (or no longer) tracked by this set.
    fn set_finished(&self, val: T, task_id: u64) {
        let tracked = self.tasks.borrow_mut().remove(&task_id).is_some();
        if tracked {
            self.outputs.borrow_mut().push((task_id, Some(val)));
            self.notify();
        }
    }

    /// Flags a tracked child task as running; unknown ids are ignored.
    fn set_running(&self, task_id: u64) {
        let mut tasks = self.tasks.borrow_mut();
        if let Some(flag) = tasks.get_mut(&task_id) {
            *flag = true;
        }
    }

    /// Records a `None` result for an aborted child task and wakes the waiter if its
    /// condition may now hold.
    ///
    /// Ignored when `task_id` is not (or no longer) tracked by this set.
    fn set_aborted(&self, task_id: u64) {
        let tracked = self.tasks.borrow_mut().remove(&task_id).is_some();
        if tracked {
            self.outputs.borrow_mut().push((task_id, None));
            self.notify();
        }
    }
}

impl<T> JoinSetInner<T> {
    /// Creates empty bookkeeping bound to `info`, starting in wait-all mode.
    fn new(info: Rc<RefCell<SchedInfo>>) -> Self {
        let tasks = RefCell::new(BTreeMap::new());
        let outputs = RefCell::new(Vec::new());
        Self {
            info,
            tasks,
            outputs,
            waker: RefCell::new(None),
            wait_all: Cell::new(true),
        }
    }

    /// Wakes the pending `Wait` future when its condition may now hold.
    ///
    /// In wait-any mode this always wakes it; in wait-all mode it wakes it only once
    /// no task is tracked any more.
    fn notify(&self) {
        let should_wake = !self.wait_all.get() || self.tasks.borrow().is_empty();
        if !should_wake {
            return;
        }
        if let Some(waker) = self.waker.take() {
            waker.wake();
        }
    }
}