use crate::loom::alloc::Track;
use crate::loom::cell::CausalCell;
use crate::task::raw::{self, Vtable};
use crate::task::state::State;
use crate::task::waker::waker_ref;
use crate::task::Schedule;
use std::cell::UnsafeCell;
use std::future::Future;
use std::mem::MaybeUninit;
use std::pin::Pin;
use std::ptr::{self, NonNull};
use std::task::{Context, Poll, Waker};
/// The complete allocation backing a single task: a `Header`, the `Core`
/// holding the future (or its output), and a `Trailer`.
///
/// `#[repr(C)]` lays the fields out in declaration order, so `header` is
/// always the first field of the struct.
// NOTE(review): presumably the fixed layout lets a pointer to the cell be
// used as a pointer to its `Header` — confirm against the raw task code.
#[repr(C)]
pub(super) struct Cell<T: Future> {
/// Bookkeeping fields that do not depend on the future type `T`.
pub(super) header: Header,
/// The future being driven, or its output once it has finished.
pub(super) core: Core<T>,
/// Storage for an optional `Waker`.
pub(super) trailer: Trailer,
}
/// The part of the task that depends on the future type `T`.
///
/// Holds either the future itself or its output, depending on the current
/// `Stage`.
pub(super) struct Core<T: Future> {
/// Current stage of execution; changed by the methods on `Core`.
stage: Stage<T>,
}
/// Per-task bookkeeping that does not depend on the future type.
///
/// `#[repr(C)]` keeps the fields in declaration order.
#[repr(C)]
pub(crate) struct Header {
/// State of the task; supplied by the caller of `Cell::new`.
pub(super) state: State,
/// Type-erased executor pointer. `Cell::new` initialises it to `None`.
// NOTE(review): presumably set once the task is bound to an executor —
// confirm where it is written.
pub(super) executor: CausalCell<Option<NonNull<()>>>,
/// Raw pointer to another `Header`; `Cell::new` initialises it to null.
// NOTE(review): by its name, the "next" link of a task queue — confirm.
pub(crate) queue_next: UnsafeCell<*const Header>,
/// Optional pointer to another `Header`; starts out as `None`.
// NOTE(review): by its name, the forward link of an ownership list — confirm.
pub(crate) owned_next: UnsafeCell<Option<NonNull<Header>>>,
/// Optional pointer to another `Header`; starts out as `None`.
// NOTE(review): by its name, the backward link of an ownership list — confirm.
pub(crate) owned_prev: UnsafeCell<Option<NonNull<Header>>>,
/// Vtable obtained from `raw::vtable::<T, S>()` in `Cell::new`, i.e. chosen
/// for the task's concrete future type and scheduler type.
pub(super) vtable: &'static Vtable,
/// A `CausalCell` that carries no data.
// NOTE(review): presumably exists only so loom can track causality of
// accesses to the future — confirm.
pub(super) future_causality: CausalCell<()>,
}
/// Final section of a task `Cell`; holds the slot for an optional `Waker`.
pub(super) struct Trailer {
/// Waker slot. `Cell::new` initialises it with `MaybeUninit::new(None)`,
/// so it starts out initialised and empty.
// NOTE(review): presumably the waker of whoever awaits the task's output —
// confirm against the code that writes this field.
pub(super) waker: CausalCell<MaybeUninit<Option<Waker>>>,
}
/// What a task's `Core` currently holds.
enum Stage<T: Future> {
/// The future itself. This is the initial stage set by `Cell::new`, and the
/// only stage in which `Core::poll` may be called.
Running(Track<T>),
/// The task's result, as stored by `Core::store_output`.
Finished(Track<super::Result<T::Output>>),
/// Nothing is held: the future has been dropped or the output taken.
Consumed,
}
impl<T: Future> Cell<T> {
    /// Allocates a new task cell on the heap holding `future`.
    ///
    /// `state` becomes the header's initial state. The scheduler type `S` is
    /// used only to pick the vtable via `raw::vtable::<T, S>()`. Every link
    /// and pointer in the header starts out empty (`None` / null), the core
    /// starts in the `Running` stage, and the trailer's waker slot starts as
    /// an initialised `None`.
    pub(super) fn new<S>(future: T, state: State) -> Box<Cell<T>>
    where
        S: Schedule,
    {
        // The three parts are built in the order header, core, trailer.
        let header = Header {
            state,
            executor: CausalCell::new(None),
            queue_next: UnsafeCell::new(ptr::null()),
            owned_next: UnsafeCell::new(None),
            owned_prev: UnsafeCell::new(None),
            vtable: raw::vtable::<T, S>(),
            future_causality: CausalCell::new(()),
        };

        let core = Core {
            stage: Stage::Running(Track::new(future)),
        };

        let trailer = Trailer {
            waker: CausalCell::new(MaybeUninit::new(None)),
        };

        Box::new(Cell {
            header,
            core,
            trailer,
        })
    }
}
impl<T: Future> Core<T> {
    /// Moves the core to the `Consumed` stage, dropping whatever the previous
    /// stage held (the future or its output).
    pub(super) fn transition_to_consumed(&mut self) {
        self.stage = Stage::Consumed
    }

    /// Polls the stored future once.
    ///
    /// The waker handed to the future is derived from `header` via
    /// `waker_ref::<T, S>`. When the future returns `Poll::Ready`, it is
    /// dropped and the stage becomes `Consumed`; the output is returned to the
    /// caller and is not stored here.
    ///
    /// # Panics
    ///
    /// Panics if the core is not in the `Running` stage.
    pub(super) fn poll<S>(&mut self, header: &Header) -> Poll<T::Output>
    where
        S: Schedule,
    {
        let outcome = match &mut self.stage {
            Stage::Running(tracked) => {
                // SAFETY: the future is never moved out of `self.stage` while
                // it is `Running`; it is only dropped in place when the stage
                // is overwritten.
                // NOTE(review): this also relies on the enclosing `Cell`
                // staying at a fixed address (it is boxed by `Cell::new`) —
                // confirm no caller moves it.
                let pinned = unsafe { Pin::new_unchecked(tracked.get_mut()) };

                let task_waker = waker_ref::<T, S>(header);
                let mut cx = Context::from_waker(&*task_waker);

                pinned.poll(&mut cx)
            }
            _ => unreachable!("unexpected stage"),
        };

        // A completed future must not be polled again, so drop it right away.
        if let Poll::Ready(_) = outcome {
            self.transition_to_consumed();
        }

        outcome
    }

    /// Stores `output` and moves the core to the `Finished` stage, dropping
    /// whatever the previous stage held.
    pub(super) fn store_output(&mut self, output: super::Result<T::Output>) {
        self.stage = Stage::Finished(Track::new(output));
    }

    /// Moves the stored output into `dst`, leaving the core `Consumed`.
    ///
    /// # Safety
    ///
    /// `dst` must be valid for writes and properly aligned. The value
    /// previously at `dst` is overwritten without being dropped.
    ///
    /// # Panics
    ///
    /// Panics if the core is not in the `Finished` stage. The stage has
    /// already been replaced with `Consumed` by the time the panic fires.
    pub(super) unsafe fn read_output(&mut self, dst: *mut Track<super::Result<T::Output>>) {
        let previous = std::mem::replace(&mut self.stage, Stage::Consumed);

        let output = match previous {
            Stage::Finished(output) => output,
            _ => unreachable!("unexpected state"),
        };

        dst.write(output);
    }
}
impl Header {
    /// Returns a copy of the executor pointer currently stored in the header,
    /// or `None` if none is stored.
    pub(super) fn executor(&self) -> Option<NonNull<()>> {
        // NOTE(review): soundness presumably depends on no other thread
        // writing `executor` while it is read here — confirm against the
        // code that sets this field.
        unsafe {
            self.executor.with(|slot| *slot)
        }
    }
}