use crate::polling::AtomicStatus;
use crate::{
EngineActivity, HeapSlots, SharedStorage, SharedStorageHeader, SharedStorageHeaderInner, Slot,
StackSlots, StateNode,
};
use cache_padded::CachePadded;
use diatomic_waker::DiatomicWaker;
use std::cell::Cell;
use std::fmt::{Debug, Formatter};
use std::marker::PhantomData;
use std::mem::MaybeUninit;
use std::pin::Pin;
use std::sync::MutexGuard;
use std::sync::atomic::{AtomicU32, Ordering};
use triomphe::Arc;
/// Zero-sized marker token, presumably proving the holder is on the engine's
/// polling thread (NOTE(review): inferred from the name — confirm).
///
/// The `PhantomData` payload makes the type neither `Send` (because of
/// `MutexGuard`) nor `Sync` (because of `Cell`), so a `PollThread` cannot be
/// moved to or shared with another thread.
#[derive(Copy, Clone)]
pub(super) struct PollThread(PhantomData<(Cell<usize>, MutexGuard<'static, ()>)>);
impl Debug for PollThread {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        // The marker has no fields, so its debug form is just the type name.
        f.write_str("PollThread")
    }
}
/// Types that can mint a [`PollThread`] marker.
///
/// # Safety
/// NOTE(review): the contract is not spelled out in this file. `PollThread`
/// is `!Send`/`!Sync` and is created here without any runtime check, so
/// presumably an implementor must only ever be reachable from the thread that
/// polls the slots — confirm and document precisely.
pub(super) unsafe trait AnchorPollThread {
    /// Returns the zero-sized thread marker.
    fn poll_thread(&self) -> PollThread;
}

// SAFETY: NOTE(review) — relies on pinned stack slots being confined to the
// polling thread; confirm.
unsafe impl<const N: usize, F> AnchorPollThread for Pin<&'_ mut StackSlots<N, F>>
where
    F: Future,
{
    fn poll_thread(&self) -> PollThread {
        PollThread(PhantomData)
    }
}

// SAFETY: NOTE(review) — relies on heap slots being confined to the polling
// thread; confirm.
unsafe impl<F> AnchorPollThread for HeapSlots<F>
where
    F: Future,
{
    fn poll_thread(&self) -> PollThread {
        PollThread(PhantomData)
    }
}
/// Index of a slot in the slot array.
///
/// Valid indices are intended to stay below `u16::MAX`, which `OptSlotIdx`
/// reserves as its "unset" sentinel.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub(super) struct SlotIdx(u16);
impl SlotIdx {
    /// Returns the raw `u16` index.
    pub(super) fn value(self) -> u16 {
        let SlotIdx(raw) = self;
        raw
    }
}
/// An optional [`SlotIdx`] packed into a single `u16`; the value `u16::MAX`
/// (`OptSlotIdx::UNSET`) represents "no slot".
#[derive(Clone, Copy, PartialEq, Eq)]
pub(super) struct OptSlotIdx(u16);
impl Debug for OptSlotIdx {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        // The sentinel gets a symbolic name instead of its raw number.
        if !self.is_set() {
            return f.write_str("OptSlotIdx::UNSET");
        }
        f.debug_tuple("OptSlotIdx").field(&self.0).finish()
    }
}
impl OptSlotIdx {
pub(super) const UNSET: Self = Self(u16::MAX);
pub(super) fn is_set(&self) -> bool {
*self != Self::UNSET
}
pub(super) fn into_slot_idx(self) -> SlotIdx {
assert_ne!(self, Self::UNSET, "Self is unset!");
SlotIdx(self.0)
}
}
impl From<SlotIdx> for OptSlotIdx {
fn from(value: SlotIdx) -> Self {
Self(value.0)
}
}
impl SlotIdx {
    /// Projects a pinned slice of slots onto the pinned slot at this index.
    ///
    /// The bounds check is a `debug_assert!` only; release builds index
    /// unchecked. NOTE(review): this is a safe fn, so soundness relies on every
    /// `SlotIdx` being in range for the `slots` it is used with — confirm that
    /// callers uphold this.
    pub(super) fn get_slot<F>(self, slots: Pin<&mut [Slot<F>]>) -> Pin<&mut Slot<F>> {
        let position = usize::from(self.0);
        // SAFETY: elements of a pinned slice are structurally pinned, and the
        // projection never moves the element. In-bounds access is assumed (see
        // the NOTE above) and verified only in debug builds.
        unsafe {
            slots.map_unchecked_mut(|all| {
                debug_assert!(position < all.len());
                all.get_unchecked_mut(position)
            })
        }
    }
}
impl Default for EngineActivity {
    /// The initial engine state:
    /// - no active slots;
    /// - the free list starts at slot 0 (matching `slot_init`, which links slot
    ///   `i` to `i + 1`);
    /// - the poll-queue head is the stub node;
    /// - the poll-loop counter is 1.
    ///
    /// `slot_init` sets each slot's `last_poll_loop_idx` to 0, so presumably
    /// starting the counter at 1 keeps fresh slots from looking "already
    /// polled" in the first loop.
    fn default() -> Self {
        let first_free_slot = SlotIdx(0);
        Self {
            poll_loop_idx: 1,
            poll_queue_head: StateNodeIdx::STUB,
            slots_active: 0,
            empty_head: first_free_slot.into(),
        }
    }
}
/// Initialises every slot as empty and threads the slots into a free list.
///
/// Slot `i`'s `empty_link` points at slot `i + 1`, and the last slot's link is
/// [`OptSlotIdx::UNSET`]. Each slot's future is left uninitialised and its
/// `last_poll_loop_idx` is set to 0.
///
/// # Panics
/// Panics if `slots` is empty or longer than `u16::MAX`.
pub(super) fn slot_init<F>(slots: &mut [MaybeUninit<Slot<F>>]) {
    let slots_len = slots.len();
    // These are hard asserts rather than debug-only ones. This runs once per
    // initialisation, so the cost is negligible, and it matches
    // `shared_storage_init`'s `assert_ne!`. In a release build:
    // - an empty slice would leave `EngineActivity::default()`'s `empty_head`
    //   (slot 0) pointing past the end, and `SlotIdx::get_slot` indexes without
    //   a release-mode bounds check;
    // - more than `u16::MAX` slots would make the `as u16` below truncate,
    //   corrupting the free list (a link to index `u16::MAX` reads as UNSET).
    assert_ne!(0, slots_len, "Length must be non-zero");
    assert!(
        slots_len <= u16::MAX as usize,
        "Length can be at most u16::MAX"
    );
    let last = slots_len - 1;
    for (idx, slot) in slots.iter_mut().enumerate() {
        let empty_link = if idx == last {
            OptSlotIdx::UNSET
        } else {
            // Lossless: idx + 1 <= last < u16::MAX.
            OptSlotIdx::from(SlotIdx((idx + 1) as u16))
        };
        slot.write(Slot {
            future: MaybeUninit::uninit(),
            empty_link,
            last_poll_loop_idx: 0,
        });
    }
}
/// Allocates the shared storage for `len` slots.
///
/// The storage holds `len + 1` state nodes:
/// - node 0 is the stub node (`StateNodeIdx::STUB`), which the poll-queue tail
///   initially points at;
/// - node `i + 1` belongs to slot `i`.
///
/// Every node starts with a default status and an unset poll-queue link.
///
/// # Panics
/// Panics if `len` is 0.
pub(super) fn shared_storage_init(len: u16) -> Arc<SharedStorage> {
    assert_ne!(len, 0, "Length cannot be 0");
    // A range is already an iterator, so no `.into_iter()` is needed.
    let state_node_iter = (0..=len).map(|idx| StateNode {
        status: AtomicStatus::default(),
        poll_queue_link: AtomicStateNodeIdx::new(StateNodeIdx::UNSET),
        // Node `idx` belongs to slot `idx - 1`. The stub (idx 0) has no slot of
        // its own and is given SlotIdx(0), presumably as a placeholder.
        slot_idx: SlotIdx(idx.saturating_sub(1)),
    });
    Arc::from_header_and_iter(
        SharedStorageHeader(CachePadded::new(SharedStorageHeaderInner {
            poll_queue_tail: AtomicStateNodeIdx::new(StateNodeIdx::STUB),
            // `len` slot nodes plus the stub; widened first so it cannot overflow.
            nodes_len: u32::from(len) + 1,
            main_waker: DiatomicWaker::new(),
        })),
        state_node_iter,
    )
}
/// Yields each slot's state node together with that slot's [`SlotIdx`].
///
/// Node 0 (the stub) is skipped, and node `i + 1` is reported as slot `i`.
pub(super) fn state_node_iter(
    shared_storage: &SharedStorage,
) -> impl IntoIterator<Item = (SlotIdx, &StateNode)> {
    let per_slot_nodes = &shared_storage.slice[1..];
    per_slot_nodes.iter().enumerate().map(|(slot, node)| {
        // Storage built by `shared_storage_init` (len: u16) has at most
        // u16::MAX slot nodes, so the position fits in u16.
        let slot_idx = SlotIdx(slot as u16);
        (slot_idx, node)
    })
}
impl SlotIdx {
    /// Returns this slot's state node, stored at index `self + 1` in the shared
    /// storage (index 0 is the stub node).
    ///
    /// The bounds check is a `debug_assert!` only. NOTE(review): this is a safe
    /// fn relying on the index being in range for `shared_storage` — confirm
    /// that callers uphold this.
    pub(super) fn get_state_node(self, shared_storage: &SharedStorage) -> &StateNode {
        // Widen before adding. In u16, `u16::MAX + 1` would panic in debug or
        // wrap to the stub node in release; in usize the bounds check below
        // sees the real index.
        let idx = usize::from(self.0) + 1;
        debug_assert!(idx < shared_storage.slice.len());
        // SAFETY: `idx` is assumed in bounds (see the NOTE above).
        unsafe { shared_storage.slice.get_unchecked(idx) }
    }
}
/// Index into the shared storage's state nodes.
///
/// - `StateNodeIdx::STUB` (0) is the stub node.
/// - Slot `i` owns node `i + 1`.
/// - `StateNodeIdx::UNSET` (`u32::MAX`) is the "no node" link value.
#[derive(Copy, Clone, PartialEq, Eq)]
pub(super) struct StateNodeIdx(u32);
impl Debug for StateNodeIdx {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        // The two special values print symbolically; anything else as a tuple.
        match *self {
            Self::STUB => f.write_str("StateNodeIdx::STUB"),
            Self::UNSET => f.write_str("StateNodeIdx::UNSET"),
            Self(raw) => f.debug_tuple("StateNodeIdx").field(&raw).finish(),
        }
    }
}
impl StateNodeIdx {
    /// "No node" sentinel; every node's `poll_queue_link` starts with this value.
    pub(super) const UNSET: Self = Self(u32::MAX);
    /// Index of the stub node (node 0), which does not belong to any slot.
    pub(super) const STUB: Self = Self(0);

    /// Returns the state node at this index.
    ///
    /// # Safety
    /// `self` must be in bounds for `shared_storage.slice`; in particular it
    /// must not be [`StateNodeIdx::UNSET`]. This is only checked by a
    /// `debug_assert!`.
    pub(super) unsafe fn get_state_node(self, shared_storage: &SharedStorage) -> &StateNode {
        let idx = self.0 as usize;
        // SAFETY: the caller guarantees `idx` is in bounds.
        unsafe {
            debug_assert!(idx < shared_storage.slice.len());
            shared_storage.slice.get_unchecked(idx)
        }
    }

    /// Returns the `poll_queue_link` of the node at this index.
    ///
    /// # Safety
    /// Same contract as [`StateNodeIdx::get_state_node`].
    pub(super) unsafe fn get_poll_queue_link(
        self,
        shared_storage: &SharedStorage,
    ) -> &AtomicStateNodeIdx {
        // SAFETY: forwarded to the caller's contract.
        unsafe { &self.get_state_node(shared_storage).poll_queue_link }
    }

    /// Maps a node index back to the slot that owns it (node `i + 1` is slot `i`).
    ///
    /// # Safety
    /// `self` must be a slot's node: neither [`StateNodeIdx::STUB`] (0) nor
    /// [`StateNodeIdx::UNSET`].
    /// - Passing 0 is undefined behaviour, because the non-zero fact is handed
    ///   to the optimizer.
    /// - An index whose slot would not fit below `u16::MAX` trips a
    ///   `debug_assert!`.
    pub(super) unsafe fn into_slot_idx(self) -> SlotIdx {
        let val = self.0;
        // SAFETY: the caller guarantees `val != 0`.
        unsafe { std::hint::assert_unchecked(val != 0) };
        // Subtract in u32 *before* narrowing, so the hint above actually proves
        // the subtraction cannot underflow. `(val as u16) - 1` could still
        // underflow (e.g. val == 65536 truncates to 0), and silently turned
        // UNSET into slot 65534.
        let slot = val - 1;
        debug_assert!(
            slot < u32::from(u16::MAX),
            "StateNodeIdx {val} has no corresponding slot"
        );
        SlotIdx(slot as u16)
    }
}
/// Maps slot `i` to its state node `i + 1` (node 0 is the stub).
impl From<SlotIdx> for StateNodeIdx {
    fn from(value: SlotIdx) -> Self {
        // Widen before adding. `value.0 + 1` in u16 overflows for `u16::MAX`:
        // it panics in debug and wraps to the STUB index 0 in release.
        Self(u32::from(value.0) + 1)
    }
}
/// Atomic cell holding a `StateNodeIdx`, stored as its raw `u32`.
pub(super) struct AtomicStateNodeIdx(AtomicU32);
impl Debug for AtomicStateNodeIdx {
    /// Formats a `Relaxed` snapshot of the current value, using
    /// `StateNodeIdx`'s own `Debug` output and forwarding the formatter flags.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let snapshot = StateNodeIdx(self.0.load(Ordering::Relaxed));
        Debug::fmt(&snapshot, f)
    }
}
impl AtomicStateNodeIdx {
pub(super) const fn new(value: StateNodeIdx) -> Self {
Self(AtomicU32::new(value.0))
}
}
impl AtomicStateNodeIdx {
pub(super) fn load(&self, ordering: Ordering) -> StateNodeIdx {
StateNodeIdx(self.0.load(ordering))
}
pub(super) fn store(&self, new_value: StateNodeIdx, ordering: Ordering) {
self.0.store(new_value.0, ordering);
}
pub(super) fn swap(&self, new_value: StateNodeIdx, ordering: Ordering) -> StateNodeIdx {
let previous = self.0.swap(new_value.0, ordering);
StateNodeIdx(previous)
}
}