run-loop 0.1.1

A run loop for a thread that provides message posting, deadline timers, and a future executor.

use std::future::Future;
use std::task::{Waker, RawWakerVTable, RawWaker, Context, Poll};
use std::num::Wrapping;
use std::sync::{Arc, Weak};
use std::marker::PhantomData;
use std::cell::UnsafeCell;
use std::pin::Pin;

use crate::linked_list::{Node as ListNode, NodePtr as ListNodePtr, List};

/// Outcome of running a task once through [`Executor::exec`].
#[derive(Eq, PartialEq, Copy, Clone)]
pub(crate) enum Exec {
    /// The task's future returned `Poll::Pending`.
    Pending,
    /// The task's future completed.
    Ready,
}

impl Exec {
    /// Returns `true` when the value is [`Exec::Pending`].
    pub(crate) fn is_pending(self) -> bool {
        match self {
            Exec::Pending => true,
            Exec::Ready => false,
        }
    }
}

/// Output slot of a task, as observed by whoever awaits it.
enum TaskResult<T> {
    /// The future has not produced its output yet.
    Pending,
    /// The future completed; the output is waiting to be collected.
    Ready(T),
    /// The output has already been handed out.
    Polled,
}

impl<T> TaskResult<T> {
    /// Takes the output if it is available.
    ///
    /// Returns `Poll::Pending` while no output has been stored; this may be
    /// called any number of times in that state. Returns `Poll::Ready` exactly
    /// once after an output was stored, leaving the slot in `Polled`.
    ///
    /// # Panics
    ///
    /// Panics if called again after the output was taken.
    fn poll(&mut self) -> Poll<T> {
        match std::mem::replace(self, TaskResult::Polled) {
            TaskResult::Pending => {
                // `replace` overwrote the slot with `Polled`; put `Pending`
                // back so that polling a still-running task again does not
                // hit the panic arm below.
                *self = TaskResult::Pending;
                Poll::Pending
            }
            TaskResult::Ready(t) => Poll::Ready(t),
            TaskResult::Polled => panic!("task polled again after its output was taken"),
        }
    }
}

/// Type-erased view of a task used by the scheduling code in this module.
///
/// All methods are `unsafe`: the implementation in this file reaches the task
/// through an `UnsafeCell` without any synchronization.
/// NOTE(review): presumably every call must happen on the thread that owns the
/// run loop — confirm against the callers.
pub(crate) trait Executor {
    /// Tries to accept a wake-up carrying `ticket`.
    ///
    /// The implementation in this file returns `true` (and advances the task's
    /// ticket) only when `ticket` equals the task's current ticket.
    unsafe fn wake(&self, ticket: usize) -> bool;
    /// Polls the task's future once; `node` is the list node holding the task
    /// and is used to build the waker handed to the future.
    unsafe fn exec(&self, node: &Arc<ListNode<dyn Executor>>) -> Exec;
    /// Gives mutable access to the stored raw pointer to the list that
    /// currently contains the task (null when there is none).
    unsafe fn list_mut(&self) -> &mut *mut List<dyn Executor>;
    /// Invalidates outstanding tickets and clears the list pointer.
    /// NOTE(review): by its name, meant to be called when the run loop is
    /// dropped — the caller is not visible here.
    unsafe fn run_loop_drop(&self);
}

/// View of a task used by the `Task<T>` handle, erasing the concrete future
/// type and keeping only its output type `T`.
trait TaskBase<T> {
    /// Returns the task's output if it is available; otherwise records the
    /// waker from `cx` so it can be woken when the task completes.
    unsafe fn poll(&self, cx: &mut Context<'_>) -> Poll<T>;
    /// Returns a pointer to the list node containing this task, typed as a
    /// node of `dyn Executor`.
    unsafe fn to_executor_node_ptr(&self) -> ListNodePtr<dyn Executor>;
}

/// State of one spawned future.
struct RawTask<T: Future> {
    /// The wrapped future; polled in place through `Pin::new_unchecked`.
    future: T,
    /// Slot for the future's output.
    result: TaskResult<T::Output>,
    /// Waker of whoever is awaiting the task handle; woken on completion.
    waker: Option<Waker>,
    /// Counter advanced on every accepted wake-up, on completion and by
    /// `run_loop_drop`; a waker carrying a different value is rejected.
    ticket: Wrapping<usize>,
    /// Raw pointer to the list that currently contains this task's node, or
    /// null when there is none.
    list: *mut List<dyn Executor>,
}

impl<T: Future> Executor for UnsafeCell<RawTask<T>> {
    /// Accepts the wake-up only if `ticket` is the task's current ticket, in
    /// which case the ticket is advanced so the same wake-up cannot be
    /// accepted twice.
    unsafe fn wake(&self, ticket: usize) -> bool {
        let raw = &mut *self.get();
        let is_current = raw.ticket.0 == ticket;
        if is_current {
            raw.ticket += Wrapping(1);
        }
        is_current
    }

    /// Polls the future once with a waker bound to `node` and the current
    /// ticket.
    ///
    /// On completion the ticket is advanced, the output is stored and the
    /// waker registered by the awaiting side (if any) is woken.
    unsafe fn exec(&self, node: &Arc<ListNode<dyn Executor>>) -> Exec {
        let raw = &mut *self.get();
        // Borrowed waker: it only allocates a real `TaskWaker` when cloned.
        let borrowed = DummyWaker { node, ticket: raw.ticket.0 };
        let waker = Waker::from_raw(borrowed.to_raw_waker());
        let mut cx = Context::from_waker(&waker);

        let output = match Pin::new_unchecked(&mut raw.future).poll(&mut cx) {
            Poll::Ready(output) => output,
            Poll::Pending => return Exec::Pending,
        };

        // Invalidate wakers handed out for the poll that just finished.
        raw.ticket += Wrapping(1);
        raw.result = TaskResult::Ready(output);
        if let Some(awaiting) = raw.waker.take() {
            awaiting.wake();
        }
        Exec::Ready
    }

    /// Exposes the stored list pointer for in-place updates.
    unsafe fn list_mut(&self) -> &mut *mut List<dyn Executor> {
        &mut (*self.get()).list
    }

    /// Advances the ticket (rejecting wakers issued so far) and clears the
    /// list pointer.
    unsafe fn run_loop_drop(&self) {
        let raw = &mut *self.get();
        raw.ticket += Wrapping(1);
        raw.list = std::ptr::null_mut();
    }
}

impl<T: Future> TaskBase<T::Output> for UnsafeCell<RawTask<T>> {
    /// Hands out the stored output, or remembers the caller's waker when the
    /// output is not available yet.
    unsafe fn poll(&self, cx: &mut Context<'_>) -> Poll<T::Output> {
        let raw = &mut *self.get();
        let outcome = raw.result.poll();
        if outcome.is_pending() {
            match raw.waker {
                // The registered waker already wakes the same task: keep it.
                Some(ref registered) if cx.waker().will_wake(registered) => {}
                _ => raw.waker = Some(cx.waker().clone()),
            }
        }
        outcome
    }

    /// Recovers the pointer to the list node that embeds this task, viewed as
    /// a node of `dyn Executor`.
    unsafe fn to_executor_node_ptr(&self) -> ListNodePtr<dyn Executor> {
        let payload = self as *const Self as *const dyn Executor;
        ListNodePtr::new_unchecked(ListNode::get_ptr_from_raw(payload) as *mut _)
    }
}

/// Borrowed waker data handed to a future while it is being polled; it is
/// turned into a heap-allocated `TaskWaker` only when the waker is cloned.
struct DummyWaker<'a> {
    /// Node of the task being polled.
    node: &'a Arc<ListNode<dyn Executor>>,
    /// Ticket of the task at the time of the poll.
    ticket: usize,
}

impl<'a> DummyWaker<'a> {
    /// Waker vtable: cloning produces a real `TaskWaker`; wake, wake-by-ref
    /// and drop all do nothing.
    // NOTE(review): since `wake`/`wake_by_ref` are no-ops, a future that calls
    // `cx.waker().wake_by_ref()` without cloning the waker first gets no
    // wake-up from this waker — confirm this is intended.
    const VTABLE: RawWakerVTable = RawWakerVTable::new(
        Self::clone,
        Self::nop,
        Self::nop,
        Self::nop,
    );

    /// Builds a `RawWaker` whose data pointer is `self`; the resulting waker
    /// must therefore not be used after `self` goes away.
    fn to_raw_waker(&self) -> RawWaker {
        RawWaker::new( self as *const Self as *const _, &Self::VTABLE)
    }

    /// Vtable `clone`: allocates a `TaskWaker` for the same node and ticket.
    unsafe fn clone(this: *const ()) -> RawWaker {
        let this = &*(this as *const DummyWaker);
        TaskWaker::new_raw_waker(this.node, this.ticket)
    }

    /// Vtable `wake` / `wake_by_ref` / `drop`: intentionally empty, the data
    /// is borrowed and there is nothing to release.
    unsafe fn nop(_this: *const ()) {

    }
}

/// Heap-allocated waker data for a task, created when a `DummyWaker` is cloned.
struct TaskWaker {
    /// Run-loop handle obtained from `super::Handle::new()` when the waker was
    /// first created; used to test `is_local()` and to `post` the wake-up.
    handle: super::Handle,
    /// Weak reference to the task's node, so the waker does not keep the task
    /// alive.
    node: Weak<ListNode<dyn Executor>>,
    /// Ticket the task had when the waker was created.
    ticket: usize,
}

/// Wake-up payload (weak node reference and ticket) moved into the closure
/// passed to `Handle::post`.
struct TWData (Weak<ListNode<dyn Executor>>, usize);
// NOTE(review): asserted manually so the payload can be captured by a posted
// closure; presumably sound because the closure runs on the thread that owns
// the node — confirm against `Handle::post`.
unsafe impl Send for TWData {}

impl TaskWaker {
    const VTABLE: RawWakerVTable = RawWakerVTable::new(
        Self::clone,
        Self::wake,
        Self::wake_by_ref,
        Self::drop,
    );

    fn new_raw_waker(node: &Arc<ListNode<dyn Executor>>, ticket: usize) -> RawWaker {
        let waker = Box::new(TaskWaker {
            handle: super::Handle::new(),
            node: Arc::downgrade(node),
            ticket,
        });
        RawWaker::new(Box::into_raw(waker) as *const _, &Self::VTABLE)
    }

    unsafe fn clone(this: *const ()) -> RawWaker {
        let waker = &*(this as *const TaskWaker);
        let cloned = Box::new(TaskWaker {
            handle: waker.handle.clone(),
            node: waker.node.clone(),
            ticket: waker.ticket,
        });
        RawWaker::new(Box::into_raw(cloned) as *const _, &Self::VTABLE)
    }

    unsafe fn wake_by_ref(this: *const ()) {
        let waker = &*(this as *const TaskWaker);
        let data = TWData(waker.node.clone(), waker.ticket);
        let _ = waker.handle.post(move || {
            if let Some(node) = data.0.upgrade() {
                node.wake(data.1);
            }
        });
    }

    unsafe fn wake(this: *const ()) {
        let waker = Box::from_raw(this as *mut TaskWaker);

        let data = TWData(waker.node, waker.ticket);
        if waker.handle.is_local() {
            if let Some(node) = data.0.upgrade() {
                if node.wake(data.1) {
                    pending_to_waken(node);
                }
            }
        } else {
            let _ = waker.handle.post(move || {
                if let Some(node) = data.0.upgrade() {
                    if node.wake(data.1) {
                        pending_to_waken(node);
                    }
                }
            });
        }
    }

    unsafe fn drop(this: *const ()) {
        Box::from_raw(this as *mut TaskWaker);
    }
}

/// Handle to a future scheduled on the current thread's run loop.
///
/// Awaiting the handle yields the future's output. Dropping it removes the
/// task from the list that currently holds it, if any.
pub struct Task<'a, T> {
    /// Shared task node, viewed through its output type only.
    node: Arc<ListNode<dyn TaskBase<T>>>,
    /// Ties the handle to the lifetime `'a` of the future given to `new`.
    phantom: PhantomData<&'a ()>,
}

impl<'a, T> Task<'a, T> {
    /// Wraps `future` in a task and appends it to the run loop's
    /// `waken_tasks` list of the current thread.
    ///
    /// The task starts with a pending result, no registered waker and
    /// ticket 0.
    pub fn new<F: 'a + Future<Output = T>>(future: F) -> Task<'a, T> {
        let task = Arc::new(ListNode::new(UnsafeCell::new(RawTask {
            future,
            result: TaskResult::Pending,
            waker: None,
            ticket: Wrapping(0),
            list: 0 as *mut _,
        })));

        unsafe {
            // Leak one extra strong count: it stands for the reference held by
            // the list through the raw node pointer pushed below. `Task::drop`
            // reclaims it with `Arc::from_raw` when it unlinks the node.
            std::mem::forget(task.clone());

            // The same allocation is viewed through both trait objects: the
            // handle keeps the `TaskBase` view, the list gets the `Executor`
            // view.
            let task_ptr = Arc::into_raw(task);
            let base_ptr = task_ptr as *const ListNode<dyn TaskBase<T>>;
            let exec_ptr = task_ptr as *const ListNode<dyn Executor>;
            super::RUN_LOOP.with(|rl| {
                let list = rl.waken_tasks.get();
                // Remember which list holds the node so it can be unlinked later.
                (*(*task_ptr).get()).list = list;
                (*list).push_back(ListNodePtr::new_unchecked(exec_ptr as *mut _))
            });

            // Rebuild the handle's `Arc` from the pointer produced by `into_raw`.
            Task {
                node: Arc::from_raw(base_ptr as *mut ListNode<dyn TaskBase<T>>),
                phantom: PhantomData,
            }
        }
    }
}

impl<T> Task<'static, T> {
    /// Gives up the handle without unlinking the task from its list.
    ///
    /// The handle's own reference to the task node is released, but
    /// `Task::drop` (which would remove the node from the list) is not run.
    /// Only available for `'static` tasks.
    pub fn detach(self) {
        // Disarm `Task::drop` first: if releasing the reference below unwinds,
        // the handle must not be dropped again on a reference that was
        // already given up.
        let this = std::mem::ManuallyDrop::new(self);
        // SAFETY: `this` is never dropped or used afterwards, so the `Arc`
        // read out of it is the only owner of that strong count.
        unsafe {
            std::mem::drop(std::ptr::read(&this.node));
        }
    }
}

impl<'a, T> Drop for Task<'a, T> {
    /// Unlinks the task from the list that currently holds it, if any, and
    /// releases the strong count that was leaked on behalf of that list.
    fn drop(&mut self) {
        unsafe {
            let exec_node = self.node.to_executor_node_ptr();
            let owner = exec_node.as_ref().list_mut();
            if owner.is_null() {
                // Not linked into any list: nothing to undo.
                return;
            }
            (**owner).remove(exec_node);
            *owner = std::ptr::null_mut();
            // Reclaim the reference that stood for the list's raw pointer.
            std::mem::drop(Arc::from_raw(exec_node.as_ptr()));
        }
    }
}

impl<'a, T> Future for Task<'a, T> {
    type Output = T;

    /// Resolves to the wrapped future's output once the task has completed;
    /// until then the waker from `cx` is registered with the task.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let shared = &self.node;
        unsafe { shared.poll(cx) }
    }
}

/// Moves `node` from the run loop's `pending_tasks` list to the back of its
/// `waken_tasks` list and records the new list in the task.
///
/// NOTE(review): the node is removed from the pending list unconditionally, so
/// callers presumably must guarantee it is linked there — verify against
/// `List::remove`.
unsafe fn pending_to_waken(node: Arc<ListNode<dyn Executor>>) {
    super::RUN_LOOP.with(|rl| {
        let entry = ListNodePtr::new_unchecked(
            node.as_ref() as *const _ as *mut ListNode<dyn Executor>);
        let waken = rl.waken_tasks.get();

        (*rl.pending_tasks.get()).remove(entry);
        (*waken).push_back(entry);
        *node.list_mut() = waken;
    })
}

/// Future returned by [`yield_now`]; the flag records whether it has already
/// been polled once.
pub struct Yield (bool);

impl Future for Yield {
    type Output = ();

    /// Returns `Pending` on the first poll after posting a closure that wakes
    /// the current waker, and `Ready` on every later poll.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
        if self.0 {
            return Poll::Ready(());
        }

        // First poll: mark as yielded and ask to be woken via a posted message.
        self.0 = true;
        let waker = cx.waker().clone();
        super::post(move || waker.wake());
        Poll::Pending
    }
}

/// Creates a future that is pending on its first poll and ready afterwards.
pub fn yield_now() -> Yield {
    Yield (false)
}