use crate::guard::{OptSlotIdx, PollThread, SlotIdx, StateNodeIdx};
use crate::{EngineActivity, SharedStorage, SharedStorageHeader, Slot, StateNode};
use std::fmt::Debug;
use std::hint::unreachable_unchecked;
use std::marker::PhantomData;
use std::mem::ManuallyDrop;
use std::ops::Deref;
use std::pin::Pin;
use std::ptr;
use std::ptr::NonNull;
use std::sync::atomic::Ordering;
use std::task::{Context, Poll, RawWaker, RawWakerVTable, Waker};
use triomphe::Arc;
/// Pops the next node off the intrusive MPSC poll queue (consumer side).
///
/// The `PollThread` token restricts this to the poll thread. `head_ref` is the consumer-owned
/// cursor; producers only touch `storage.header.poll_queue_tail` and per-node links via
/// `poll_queue_push_node_at_index`. The structure appears to follow Vyukov's intrusive
/// MPSC node-based queue, with `StateNodeIdx::STUB` as the permanent stub node.
///
/// Returns the popped node's slot index, or `OptSlotIdx::UNSET` when the queue is empty or a
/// producer has claimed the tail but not yet linked its node.
fn poll_queue_pop_head(
    _: PollThread,
    head_ref: &mut StateNodeIdx,
    storage: &SharedStorage,
) -> OptSlotIdx {
    unsafe {
        let mut head = *head_ref;
        let mut head_next = head.get_poll_queue_link(storage).load(Ordering::Acquire);
        // The stub is never handed out: step past it, or report empty if nothing follows it.
        if head == StateNodeIdx::STUB {
            if head_next == StateNodeIdx::UNSET {
                return OptSlotIdx::UNSET;
            }
            *head_ref = head_next;
            head = head_next;
            head_next = head.get_poll_queue_link(storage).load(Ordering::Acquire);
        }
        // Fast path: `head` already has a visible successor, so it can be unlinked directly.
        if head_next != StateNodeIdx::UNSET {
            *head_ref = head_next;
            debug_assert_ne!(StateNodeIdx::STUB, head, "cannot return stub");
            return head.into_slot_idx().into();
        }
        // No visible successor. If `head` is not the tail, a producer has swapped the tail but
        // not yet stored its link; report empty instead of spinning.
        let tail = storage.header.poll_queue_tail.load(Ordering::Acquire);
        if head != tail {
            return OptSlotIdx::UNSET;
        }
        // `head` is the last node: push the stub behind it so `head` gains a successor and can
        // be unlinked without leaving the queue with zero nodes.
        let stub_state_node = StateNodeIdx::STUB.get_state_node(storage);
        poll_queue_push_node_at_index(storage, stub_state_node, StateNodeIdx::STUB);
        // A producer may have linked its own node after `head` before the stub did; if its link
        // store is still pending, `head_next` is UNSET and we report empty for now.
        head_next = head.get_poll_queue_link(storage).load(Ordering::Acquire);
        if head_next != StateNodeIdx::UNSET {
            *head_ref = head_next;
            debug_assert_ne!(StateNodeIdx::STUB, head, "cannot return stub");
            return head.into_slot_idx().into();
        }
        OptSlotIdx::UNSET
    }
}
/// Puts `state_node` back at the consumer end of the poll queue so it is the next one popped.
///
/// # Safety
/// NOTE(review): presumably `state_node` must not currently be linked into the queue and
/// `state_node_idx` must be its own index — confirm against the callers.
pub(super) unsafe fn poll_queue_reinsert_head(
    _: PollThread,
    activity: &mut EngineActivity,
    state_node: &StateNode,
    state_node_idx: StateNodeIdx,
) {
    // The head cursor is exclusively borrowed here, so swapping it first is unobservable.
    let displaced_head = std::mem::replace(&mut activity.poll_queue_head, state_node_idx);
    // The reinserted node now points at whatever used to be first.
    state_node
        .poll_queue_link
        .store(displaced_head, Ordering::Relaxed);
}
/// Appends `state_node` (found at `state_node_idx`) to the producer end of the poll queue.
///
/// Concurrent producers serialize on the atomic swap of `poll_queue_tail`. The previous tail's
/// link is then published with a release store, which `poll_queue_pop_head` reads with acquire
/// ordering.
///
/// # Safety
/// NOTE(review): presumably `state_node` must not already be in the queue and `state_node_idx`
/// must refer to `state_node` — confirm with callers.
pub(super) unsafe fn poll_queue_push_node_at_index(
    storage: &SharedStorage,
    state_node: &StateNode,
    state_node_idx: StateNodeIdx,
) {
    // The incoming node becomes the new tail, so it starts without a successor.
    state_node
        .poll_queue_link
        .store(StateNodeIdx::UNSET, Ordering::Relaxed);
    let old_tail = storage
        .header
        .poll_queue_tail
        .swap(state_node_idx, Ordering::AcqRel);
    // Until this store lands, the consumer sees no successor after `old_tail` and reports empty.
    unsafe { old_tail.get_poll_queue_link(storage) }.store(state_node_idx, Ordering::Release);
}
/// Pops the next queued slot and decides whether it should be polled.
///
/// The first element is always the raw popped index. The second is a [`PollableSlot`] when the
/// slot holds a live future (`Init` or `Woken`). It is `None` when the queue is empty or when
/// the popped entry is stale (`UninitButEnqueued`, i.e. its future was dropped while queued).
/// In the stale case the status is reset to `Uninit`, so a later `init_future` enqueues the
/// node again.
pub(super) fn pop_pollable<'s>(
    token: PollThread,
    activity: &mut EngineActivity,
    shared_storage: &'s Arc<SharedStorage>,
) -> PollableSlotQuery<'s> {
    let popped = poll_queue_pop_head(token, &mut activity.poll_queue_head, shared_storage);
    if !popped.is_set() {
        return (popped, None);
    }

    let candidate = PollableSlot::new(shared_storage, popped.into_slot_idx());
    let status = &candidate.state_node().status;
    let pollable = match status.load(Ordering::Relaxed) {
        SlotStatus::Init | SlotStatus::Woken => Some(candidate),
        SlotStatus::UninitButEnqueued => {
            // This pop retires the queue entry left behind by the dropped future.
            status.store(SlotStatus::Uninit, Ordering::Relaxed);
            None
        }
        // SAFETY: only `init_future` (Init) and `wake_up_node` (Woken) enqueue nodes, and
        // `call_drop` turns a queued Woken slot into UninitButEnqueued.
        SlotStatus::Uninit | SlotStatus::Waiting => unsafe {
            debug_assert!(false, "Slot with Uninit/Waiting status popped off queue");
            unreachable_unchecked()
        },
    };
    (popped, pollable)
}
/// Wakes a parked slot: moves it from `Waiting` to `Woken`, enqueues its node, and notifies
/// the main waker.
///
/// If the slot is in any other state, the CAS fails and nothing happens. Only the winner of
/// the `Waiting -> Woken` transition pushes the node.
fn wake_up_node(storage: &SharedStorage, state_node: &StateNode) {
    let won_transition = state_node.status.compare_and_swap(
        SlotStatus::Waiting,
        SlotStatus::Woken,
        Ordering::Acquire,
        Ordering::Relaxed,
    );
    if !won_transition {
        return;
    }

    let node_idx = StateNodeIdx::from(state_node.slot_idx);
    // SAFETY: a `Waiting` slot's node is not in the queue, and this thread alone won the right
    // to enqueue it.
    unsafe { poll_queue_push_node_at_index(storage, state_node, node_idx) };
    storage.header.main_waker.notify();
}
/// Handle to the [`StateNode`] of a slot popped from the poll queue and eligible for polling.
///
/// Stores a raw pointer into the shared storage. `'s` ties the handle to the borrow of the
/// `Arc<SharedStorage>` it was derived from.
pub(super) struct PollableSlot<'s> {
    state_node_ptr: NonNull<StateNode>,
    lifetime: PhantomData<&'s ()>,
}

/// Result of `pop_pollable`: the raw popped index plus, when it should be polled, its handle.
pub(super) type PollableSlotQuery<'s> = (OptSlotIdx, Option<PollableSlot<'s>>);
impl<'s> PollableSlot<'s> {
    /// Returns the state node this handle refers to, valid for the storage borrow `'s`.
    pub(super) fn state_node(&self) -> &'s StateNode {
        // SAFETY: `state_node_ptr` was derived in `new` from a live `&'s Arc<SharedStorage>`.
        unsafe { self.state_node_ptr.as_ref() }
    }

    /// Locates the state node of `slot_idx` inside `shared_storage`.
    ///
    /// The node array starts `SharedStorageHeader::BYTE_OFFSET` bytes into the storage, and
    /// slot `i` sits at node entry `i + 1`. Entry 0 is presumably the queue stub — TODO confirm
    /// against `StateNodeIdx::STUB`.
    fn new(shared_storage: &'s Arc<SharedStorage>, slot_idx: SlotIdx) -> Self {
        // SAFETY: the storage is kept alive by the borrowed `Arc`, and `slot_idx` is assumed to be
        // a valid slot, keeping the offset inside the node array.
        unsafe {
            let node_ptr = Arc::as_ptr(shared_storage)
                .byte_add(SharedStorageHeader::BYTE_OFFSET)
                .cast::<StateNode>()
                .add(1 + slot_idx.value() as usize);
            Self {
                state_node_ptr: NonNull::new_unchecked(node_ptr.cast_mut()),
                lifetime: PhantomData,
            }
        }
    }
}
/// Reconstructs the fat `*const SharedStorage` that contains the node at `state_node_ptr`.
///
/// This inverts the pointer walk in `PollableSlot::new`. It steps back `1 + slot_idx` nodes
/// and then `SharedStorageHeader::BYTE_OFFSET` bytes to reach the storage start. It then reads
/// `nodes_len` from the header to rebuild the unsized tail's length metadata.
///
/// # Safety
/// `state_node_ptr` must point at a state node inside a live `SharedStorage`.
unsafe fn recover_shared_storage(state_node_ptr: *const StateNode) -> *const SharedStorage {
    unsafe {
        let own_idx = (*state_node_ptr).slot_idx;
        let storage_start = state_node_ptr
            .sub(1 + own_idx.value() as usize)
            .byte_sub(SharedStorageHeader::BYTE_OFFSET);
        let header = &*storage_start.cast::<SharedStorageHeader>();
        let tail_len = header.nodes_len as usize;
        ptr::slice_from_raw_parts(storage_start, tail_len) as *const SharedStorage
    }
}
impl PollableSlot<'_> {
    /// Polls the slot's future once, using a waker bound to this slot's state node.
    ///
    /// The status is set to `Waiting` before polling, so any wake during or after the poll can
    /// flip it to `Woken` and re-enqueue the node. On `Ready`, the output is returned with a
    /// [`Droppable`] that the caller uses to drop the finished future. If the future panics,
    /// the `FutOwner` guard drops it (via `call_drop`) while the panic unwinds.
    pub(super) fn call_poll<F: Future>(
        self,
        token: PollThread,
        slot: Pin<&mut Slot<F>>,
    ) -> Poll<(F::Output, Droppable)> {
        // Borrowed waker: it owns no storage reference, and `ManuallyDrop` stops its drop from
        // releasing one. Clones the future takes via the vtable do own a reference each.
        let our_waker = unsafe {
            waker_from_state_node(self.state_node_ptr.as_ptr())
        };
        // Panic guard: if `poll` unwinds, dropping this drops the future and updates its status.
        struct FutOwner<'poll, F: Future> {
            token: PollThread,
            slot: Pin<&'poll mut Slot<F>>,
            status: &'poll AtomicStatus,
        }
        impl<'poll, F: Future> Drop for FutOwner<'poll, F> {
            fn drop(&mut self) {
                unsafe {
                    call_drop(self.token, self.slot.as_mut(), self.status);
                }
            }
        }
        let mut fut_owner = FutOwner {
            token,
            slot,
            status: &self.state_node().status,
        };
        // Publish `Waiting` before polling so wakers can claim the node (Waiting -> Woken).
        fut_owner
            .status
            .store(SlotStatus::Waiting, Ordering::Release);
        let poll_res = {
            let future = fut_owner.slot.as_mut().project().future;
            unsafe {
                // Assumes the slot holds an initialized future; `pop_pollable` only yields
                // Init/Woken slots — NOTE(review): confirm no other path constructs `self`.
                let future = future.map_unchecked_mut(|fut| fut.assume_init_mut());
                future.poll(&mut Context::from_waker(&our_waker))
            }
        };
        // `poll` returned normally: disarm the guard. A pending future stays in place; a ready
        // one is dropped later through the returned `Droppable`.
        std::mem::forget(fut_owner);
        poll_res.map(|val| (val, Droppable(token)))
    }
}
pub(super) unsafe fn call_drop<F: Future>(
_: PollThread,
slot: Pin<&mut Slot<F>>,
status: &AtomicStatus,
) {
let previous_status = status.swap(SlotStatus::Uninit, Ordering::Relaxed);
if previous_status == SlotStatus::Woken {
status.store(SlotStatus::UninitButEnqueued, Ordering::Relaxed);
}
unsafe {
let slot_future = slot.project().future.get_unchecked_mut();
ptr::drop_in_place(slot_future.as_mut_ptr()); }
}
/// Token that lets the poll thread drop a future that has just completed.
///
/// Within this module it is constructed only by `PollableSlot::call_poll` when the future
/// returns `Ready`.
#[must_use]
pub(super) struct Droppable(PollThread);

impl Droppable {
    /// Consumes the token and drops the finished future in `slot`, updating `state_node`'s status.
    pub(super) fn drop_future<F: Future>(self, slot: Pin<&mut Slot<F>>, state_node: &StateNode) {
        let Self(poll_thread) = self;
        // SAFETY: the token is only issued for a future that just completed, so the slot still
        // holds an initialized future that nothing else will touch.
        unsafe { call_drop(poll_thread, slot, &state_node.status) }
    }
}
impl<F> Slot<F> {
pub(super) unsafe fn init_future(
mut self: Pin<&mut Self>,
_: PollThread,
slot_idx: SlotIdx,
state_node: &StateNode,
shared_storage: &SharedStorage,
new_future: F,
) {
let previous_status = state_node.status.swap(SlotStatus::Init, Ordering::Relaxed);
unsafe {
let fut_ptr = self.as_mut().get_unchecked_mut().future.as_mut_ptr();
ptr::write(fut_ptr, new_future);
match previous_status {
SlotStatus::Uninit => {
poll_queue_push_node_at_index(
shared_storage,
state_node,
StateNodeIdx::from(slot_idx),
);
}
SlotStatus::UninitButEnqueued => {
}
SlotStatus::Waiting | SlotStatus::Woken | SlotStatus::Init => {
debug_assert!(false, "Pinned future overwritten before drop");
unreachable_unchecked()
}
}
}
}
}
/// Builds a borrowed [`Waker`] whose data pointer is `state_node`.
///
/// The returned waker owns no reference to the shared storage. It is wrapped in
/// [`ManuallyDrop`] so its vtable `drop` never runs. Each clone made through the vtable owns
/// one strong count on the storage `Arc`, which is released by `wake` or `drop`.
///
/// # Safety
/// `state_node` must point at a state node inside a live, `Arc`-managed `SharedStorage`, and
/// that storage must outlive every use of the returned (non-cloned) waker.
unsafe fn waker_from_state_node(state_node: *const StateNode) -> ManuallyDrop<Waker> {
    const VTABLE: &RawWakerVTable =
        &RawWakerVTable::new(clone_impl, wake_impl, wake_by_ref_impl, drop_impl);

    fn clone_impl(data: *const ()) -> RawWaker {
        // SAFETY: `data` always points at a state node inside Arc-managed storage.
        unsafe {
            let storage = recover_shared_storage(data.cast::<StateNode>());
            // View the storage as an `Arc` without consuming a count...
            let borrowed = ManuallyDrop::new(Arc::from_raw(storage));
            // ...then leak a clone: that extra strong count now belongs to the new waker.
            std::mem::forget(Arc::clone(ManuallyDrop::deref(&borrowed)));
        }
        RawWaker::new(data, VTABLE)
    }

    fn wake_by_ref_impl(data: *const ()) {
        let node = data.cast::<StateNode>();
        // SAFETY: the storage containing `node` is alive for the duration of this call.
        unsafe {
            let storage = &*recover_shared_storage(node);
            wake_up_node(storage, &*node);
        }
    }

    fn drop_impl(data: *const ()) {
        // SAFETY: this waker owns one strong count (taken in `clone_impl`); rebuilding the
        // `Arc` and dropping it releases exactly that count.
        unsafe {
            drop(Arc::from_raw(recover_shared_storage(data.cast::<StateNode>())));
        }
    }

    fn wake_impl(data: *const ()) {
        // Waking by value = wake by reference, then release the count this waker owned.
        wake_by_ref_impl(data);
        drop_impl(data);
    }

    // Round-trip check: the storage recovered from this node must map its index back to it.
    debug_assert!(unsafe {
        let own_idx = (*state_node).slot_idx;
        let storage = &*recover_shared_storage(state_node);
        ptr::eq(state_node, own_idx.get_state_node(storage))
    });

    // SAFETY: the vtable functions above uphold the `RawWaker` contract for this data pointer.
    unsafe {
        ManuallyDrop::new(Waker::from_raw(RawWaker::new(
            state_node.cast::<()>(),
            VTABLE,
        )))
    }
}
pub(super) use slot_status::AtomicStatus;
/// Lifecycle state of a slot, stored in its state node through `AtomicStatus`.
///
/// The discriminants are written out because `AtomicStatus` stores them as raw `u8` values.
#[derive(PartialEq, Eq, Debug)]
#[repr(u8)]
pub(super) enum SlotStatus {
    /// No future is stored and the node is not in the poll queue.
    Uninit = 0,
    /// No future is stored, but the node is still linked in the poll queue from a dropped future.
    UninitButEnqueued = 1,
    /// A future was just stored and its node queued for the first poll.
    Init = 2,
    /// The future is being polled, or is parked waiting for a wake.
    Waiting = 3,
    /// A wake arrived and the node was pushed onto the poll queue.
    Woken = 4,
}
mod slot_status {
    use super::SlotStatus;
    use std::fmt::{Debug, Formatter};
    use std::hint::unreachable_unchecked;
    use std::sync::atomic::{AtomicU8, Ordering};

    /// A [`SlotStatus`] packed into an [`AtomicU8`] so it can be read and updated across threads.
    pub(crate) struct AtomicStatus(AtomicU8);

    impl Debug for AtomicStatus {
        /// Formats a relaxed snapshot of the current status.
        fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
            let snapshot = self.load(Ordering::Relaxed);
            Debug::fmt(&snapshot, f)
        }
    }

    impl Default for AtomicStatus {
        /// Starts out as [`SlotStatus::Uninit`].
        fn default() -> Self {
            Self(AtomicU8::new(SlotStatus::Uninit as u8))
        }
    }

    impl AtomicStatus {
        /// Decodes a byte previously written by this type back into a [`SlotStatus`].
        fn extract(raw: u8) -> SlotStatus {
            const UNINIT: u8 = SlotStatus::Uninit as u8;
            const UNINIT_BUT_ENQUEUED: u8 = SlotStatus::UninitButEnqueued as u8;
            const INIT: u8 = SlotStatus::Init as u8;
            const WAITING: u8 = SlotStatus::Waiting as u8;
            const WOKEN: u8 = SlotStatus::Woken as u8;

            match raw {
                UNINIT => SlotStatus::Uninit,
                UNINIT_BUT_ENQUEUED => SlotStatus::UninitButEnqueued,
                INIT => SlotStatus::Init,
                WAITING => SlotStatus::Waiting,
                WOKEN => SlotStatus::Woken,
                // SAFETY: every write (`default`, `store`, `compare_and_swap`, `swap`) stores a
                // `SlotStatus` discriminant, so no other byte can be observed.
                _ => unsafe { unreachable_unchecked() },
            }
        }

        /// Reads the current status with the given ordering.
        pub(crate) fn load(&self, ordering: Ordering) -> SlotStatus {
            let raw = self.0.load(ordering);
            Self::extract(raw)
        }

        /// Overwrites the current status.
        pub(super) fn store(&self, val: SlotStatus, ordering: Ordering) {
            let raw = val as u8;
            self.0.store(raw, ordering);
        }

        /// Replaces `current` with `new` atomically; returns whether the exchange happened.
        pub(super) fn compare_and_swap(
            &self,
            current: SlotStatus,
            new: SlotStatus,
            success: Ordering,
            failure: Ordering,
        ) -> bool {
            let outcome = self
                .0
                .compare_exchange(current as u8, new as u8, success, failure);
            matches!(outcome, Ok(_))
        }

        /// Stores `new` and returns the status it replaced.
        pub(super) fn swap(&self, new: SlotStatus, ordering: Ordering) -> SlotStatus {
            let previous = self.0.swap(new as u8, ordering);
            Self::extract(previous)
        }
    }
}