use crate::collections::ArcCell;
use crate::flavor::FlavorImpl;
use std::cell::UnsafeCell;
use std::fmt;
use std::ops::Deref;
use std::sync::{
atomic::{AtomicU32, AtomicU8, Ordering},
Arc, Weak,
};
use std::task::*;
use std::thread;
/// Lifecycle state of a waker, stored as its `u8` discriminant in an `AtomicU8`.
///
/// Code in this module compares these discriminants numerically
/// (e.g. `state >= WakerState::Woken as u8`), so the explicit values matter.
/// No variant uses the value `2`.
#[derive(Debug, Clone, Copy, PartialEq)]
#[repr(u8)]
pub enum WakerState {
    /// Set at construction and by `reset_init`.
    Init = 0,
    /// Set by `commit_waiting` (CAS from `Init`).
    Waiting = 1,
    /// Set when a waking thread notifies the waker (`wake` / `wake_or_copy`).
    Woken = 3,
    /// Set by `abandon` / `close_wake` from any state `<= Waiting`.
    Closed = 4,
    /// Set by `wake_or_copy` when the flavor's `try_send_oneshot` returns `Some(true)`.
    Done = 5,
}
/// Outcome of an attempt to wake a waker.
///
/// Bit `0x1` is set for the two "done" outcomes (`Woken`, `Sent`).
#[derive(PartialEq, Debug, Clone, Copy)]
#[repr(u8)]
pub enum WakeResult {
    /// The waker was in `Waiting` and has been notified.
    Woken = 0x1,
    /// The payload was handed to the flavor and the waker moved to `Done`.
    Sent = 0x3,
    /// The waker was still in `Init`; it was moved to `Woken` and notified.
    Next = 0x2,
    /// The waker was already `Woken` or beyond; nothing was done.
    Skip = 0x4,
}
impl WakeResult {
    /// Returns `true` for [`WakeResult::Woken`] and [`WakeResult::Sent`] — the two
    /// variants whose discriminant has bit `0x1` set — and `false` otherwise.
    #[inline(always)]
    pub fn is_done(&self) -> bool {
        matches!(self, WakeResult::Woken | WakeResult::Sent)
    }
}
/// Owning handle to a reference-counted [`WakerInner`]; it dereferences to the inner value.
pub struct ArcWaker<P>(Arc<WakerInner<P>>);
impl<P> fmt::Debug for ArcWaker<P> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
self.0.fmt(f)
}
}
impl<P> fmt::Debug for WakerInner<P> {
    /// Renders as `waker(<seq>)`, using the current sequence number.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let seq = self.get_seq();
        f.write_fmt(format_args!("waker({})", seq))
    }
}
impl<P> Deref for ArcWaker<P> {
    type Target = WakerInner<P>;

    /// Borrows the shared [`WakerInner`] that lives behind the `Arc`.
    #[inline]
    fn deref(&self) -> &Self::Target {
        // `&Arc<WakerInner<P>>` coerces to `&WakerInner<P>` at the return site.
        &self.0
    }
}
impl<P> ArcWaker<P> {
#[inline(always)]
pub fn new_async(ctx: &Context, payload: P) -> Self {
Self(Arc::new(WakerInner {
seq: AtomicU32::new(0),
state: AtomicU8::new(WakerState::Init as u8),
waker: UnsafeCell::new(ThinWaker::Async(ctx.waker().clone())),
payload: UnsafeCell::new(payload),
}))
}
#[inline(always)]
pub fn new_blocking(payload: P) -> Self {
Self(Arc::new(WakerInner {
seq: AtomicU32::new(0),
state: AtomicU8::new(WakerState::Init as u8),
waker: UnsafeCell::new(ThinWaker::Blocking(thread::current())),
payload: UnsafeCell::new(payload),
}))
}
}
impl<P> ArcWaker<P> {
#[inline(always)]
pub fn from_arc(inner: Arc<WakerInner<P>>) -> Self {
Self(inner)
}
#[allow(clippy::wrong_self_convention)]
#[inline(always)]
pub fn to_arc(self) -> Arc<WakerInner<P>> {
self.0
}
#[inline(always)]
pub fn weak(&self) -> Weak<WakerInner<P>> {
Arc::downgrade(&self.0)
}
}
/// Notification target of a waker: an async task's [`Waker`] or a blocking thread handle.
#[derive(Debug)]
pub(crate) enum ThinWaker {
    /// Notified through `Waker::wake_by_ref` / `Waker::wake`.
    Async(Waker),
    /// Notified through `Thread::unpark`.
    Blocking(thread::Thread),
}
impl ThinWaker {
#[inline(always)]
pub fn wake_by_ref(&self) {
match self {
Self::Async(w) => w.wake_by_ref(),
Self::Blocking(th) => th.unpark(),
}
}
#[allow(dead_code)]
#[inline(always)]
pub fn wake(self) {
match self {
Self::Async(w) => w.wake(),
Self::Blocking(th) => th.unpark(),
}
}
#[inline(always)]
pub fn will_wake(&self, ctx: &mut Context) -> bool {
if let Self::Async(_waker) = self {
_waker.will_wake(ctx.waker())
} else {
unreachable!();
}
}
}
/// Shared state behind an [`ArcWaker`]: atomic lifecycle state, a sequence number,
/// the notification target, and a caller-supplied payload.
pub struct WakerInner<P> {
    /// Current [`WakerState`], stored as its `u8` discriminant.
    state: AtomicU8,
    /// Sequence number; read/written with `Relaxed` ordering and shown in `Debug` output.
    seq: AtomicU32,
    /// Notification target; mutated through `&self` via `UnsafeCell`.
    waker: UnsafeCell<ThinWaker>,
    /// Caller-supplied payload; mutated through `&self` via `UnsafeCell`.
    // NOTE(review): the field is accessed through `self.payload.get()`, so this
    // `dead_code` allow may be stale — confirm before removing.
    #[allow(dead_code)]
    payload: UnsafeCell<P>,
}
// SAFETY: NOTE(review): these impls hold for every `P`, including payloads that are not
// `Send`/`Sync` themselves (e.g. the `*const T` payload used by `wake_or_copy`). Both
// `waker` and `payload` are mutated through `&self` via `UnsafeCell`, so soundness relies
// on the channel protocol giving exclusive access during those writes — confirm against
// the callers before relying on it.
unsafe impl<P> Send for WakerInner<P> {}
unsafe impl<P> Sync for WakerInner<P> {}
impl<P> WakerInner<P> {
    /// Returns a shared reference to the stored [`ThinWaker`].
    #[inline(always)]
    fn get_waker(&self) -> &ThinWaker {
        // SAFETY: the waker cell is only written by `update_thread_handle`, which is reached
        // from `WakerCache::new_blocking` on an `Arc` that was cached only while it had no
        // other strong or weak references, so no reader can run concurrently with that write.
        unsafe { &*self.waker.get() }
    }

    /// Returns a mutable reference to the stored [`ThinWaker`] through `&self`.
    ///
    /// The caller must have exclusive access to this `WakerInner`: no other reference to the
    /// waker may be used while the returned `&mut` is alive.
    #[inline(always)]
    fn get_waker_mut(&self) -> &mut ThinWaker {
        // SAFETY: exclusivity is the caller's obligation (see the doc comment).
        unsafe { &mut *self.waker.get() }
    }

    /// Returns a mutable reference to the payload through `&self`.
    ///
    /// Writes require the same exclusive access as [`Self::get_waker_mut`].
    #[inline(always)]
    fn get_payload_mut(&self) -> &mut P {
        // SAFETY: callers must not create overlapping borrows of the payload.
        unsafe { &mut *self.payload.get() }
    }

    /// Stores a new payload and moves the state back to [`WakerState::Init`] for reuse.
    ///
    /// NOTE(review): the state store is `Relaxed` and the payload write is unsynchronised;
    /// this assumes the waker is exclusively owned here (as when popped from `WakerCache`).
    #[inline(always)]
    pub fn reset(&self, payload: P) {
        *self.get_payload_mut() = payload;
        self.reset_init();
    }

    /// Returns the sequence number (`Relaxed` load).
    #[inline(always)]
    pub fn get_seq(&self) -> u32 {
        self.seq.load(Ordering::Relaxed)
    }

    /// Sets the sequence number (`Relaxed` store).
    #[inline(always)]
    pub fn set_seq(&self, seq: u32) {
        self.seq.store(seq, Ordering::Relaxed);
    }

    /// Re-targets the waker at the calling thread as a blocking [`ThinWaker`].
    ///
    /// Requires exclusive access; see [`Self::get_waker_mut`].
    #[inline(always)]
    fn update_thread_handle(&self) {
        *self.get_waker_mut() = ThinWaker::Blocking(thread::current());
    }

    /// Attempts `Init -> Waiting`.
    ///
    /// Returns `WakerState::Waiting as u8` on success. Otherwise it returns the state
    /// observed by the failed compare-exchange.
    #[inline(always)]
    pub fn commit_waiting(&self) -> u8 {
        match self.try_change_state(WakerState::Init, WakerState::Waiting) {
            Ok(()) => WakerState::Waiting as u8,
            Err(observed) => observed,
        }
    }

    /// Single compare-exchange from `cur` to `new_state` (`SeqCst` on success, `Acquire` on
    /// failure).
    ///
    /// # Errors
    /// Returns `Err(observed)` with the current state when it was not `cur`.
    #[inline(always)]
    pub fn try_change_state(&self, cur: WakerState, new_state: WakerState) -> Result<(), u8> {
        self.state
            .compare_exchange(cur as u8, new_state as u8, Ordering::SeqCst, Ordering::Acquire)
            .map(|_| ())
    }

    /// Unconditionally sets the state to [`WakerState::Init`] (`Relaxed` store).
    #[inline(always)]
    pub fn reset_init(&self) {
        self.state.store(WakerState::Init as u8, Ordering::Relaxed);
    }

    /// Moves the state to [`WakerState::Closed`] if it is currently `Init` or `Waiting`.
    ///
    /// # Errors
    /// Returns `Err(observed)` when the state was already beyond `Waiting`.
    #[inline(always)]
    pub fn abandon(&self) -> Result<(), u8> {
        self.change_state_smaller_eq(WakerState::Waiting, WakerState::Closed)
            .map(|_| ())
    }

    /// Like [`Self::abandon`], but on success it also notifies the waker.
    ///
    /// Returns whether the transition to `Closed` happened.
    #[inline(always)]
    pub fn close_wake(&self) -> bool {
        let closed = self
            .change_state_smaller_eq(WakerState::Waiting, WakerState::Closed)
            .is_ok();
        if closed {
            self.get_waker().wake_by_ref();
        }
        closed
    }

    /// CAS loop that moves the state to `target` while it is numerically `<= condition`.
    ///
    /// Returns `Ok(previous)` with the state that was replaced. Returns `Err(observed)` as
    /// soon as a state greater than `condition` is seen. Debug builds assert
    /// `condition < target`.
    #[inline(always)]
    pub fn change_state_smaller_eq(
        &self, condition: WakerState, target: WakerState,
    ) -> Result<u8, u8> {
        debug_assert!((condition as u8) < (target as u8));
        // Optimistically guess the state equals `condition`; each failed CAS supplies the
        // real value to retry with.
        let mut state = condition as u8;
        loop {
            match self.state.compare_exchange_weak(
                state,
                target as u8,
                Ordering::SeqCst,
                Ordering::Acquire,
            ) {
                Ok(_) => return Ok(state),
                Err(s) if s > condition as u8 => return Err(s),
                Err(s) => state = s,
            }
        }
    }

    /// Loads the raw state with a caller-chosen ordering.
    #[inline(always)]
    pub fn _get_state(&self, order: Ordering) -> u8 {
        self.state.load(order)
    }

    /// Loads the raw state (`SeqCst`).
    #[inline(always)]
    pub fn get_state(&self) -> u8 {
        self.state.load(Ordering::SeqCst)
    }

    /// Loads the raw state (`Relaxed`).
    #[inline(always)]
    pub fn get_state_relaxed(&self) -> u8 {
        self.state.load(Ordering::Relaxed)
    }

    /// Moves the waker to [`WakerState::Woken`] and notifies it.
    ///
    /// Returns:
    /// - [`WakeResult::Woken`] if the state was `Waiting`;
    /// - [`WakeResult::Next`] if the state was `Init`;
    /// - [`WakeResult::Skip`] if the state was already `Woken` or beyond (including when
    ///   another thread closed or woke it concurrently). The waker is not notified then.
    #[inline(always)]
    pub fn wake(&self) -> WakeResult {
        let mut state = self.get_state_relaxed();
        loop {
            if state >= WakerState::Woken as u8 {
                return WakeResult::Skip;
            }
            // Both transitions are CASes. A plain store in the `Waiting` case would
            // overwrite a concurrent `Waiting -> Closed` (`abandon` / `close_wake`), which
            // reports `Woken` for a waiter that has already given up.
            let (expected, outcome) = if state == WakerState::Waiting as u8 {
                (WakerState::Waiting, WakeResult::Woken)
            } else {
                (WakerState::Init, WakeResult::Next)
            };
            match self.state.compare_exchange_weak(
                expected as u8,
                WakerState::Woken as u8,
                Ordering::SeqCst,
                Ordering::Acquire,
            ) {
                Ok(_) => {
                    self.get_waker().wake_by_ref();
                    return outcome;
                }
                // Re-evaluate with the freshly observed state (also covers spurious failure).
                Err(s) => state = s,
            }
        }
    }

    /// Returns whether the stored async waker would wake the same task as `ctx`'s waker.
    ///
    /// # Panics
    /// Panics if the stored waker is a blocking one (`ThinWaker::will_wake` reaches
    /// `unreachable!`).
    #[inline(always)]
    pub fn will_wake(&self, ctx: &mut Context) -> bool {
        self.get_waker().will_wake(ctx)
    }
}
/// Extra operations for wakers whose payload is a raw pointer to an item of type `T`.
impl<T> WakerInner<*const T> {
    /// Returns the stored payload pointer (the pointer value is copied, not dereferenced).
    #[inline(always)]
    fn get_payload(&self) -> *const T {
        *self.get_payload_mut()
    }
    /// Wakes this waker, first offering its payload pointer to `flavor` when possible.
    ///
    /// Outcomes, by the state observed:
    /// - `Woken` or beyond: nothing happens; returns `WakeResult::Skip`.
    /// - `Waiting` with a null payload: the state becomes `Woken` and the waker is notified;
    ///   returns `WakeResult::Woken`.
    /// - `Waiting` with a non-null payload: calls `flavor.try_send_oneshot(p)`. If it returns
    ///   `Some(true)`, the state becomes `Done` and `WakeResult::Sent` is returned. Otherwise
    ///   the state becomes `Woken` and `WakeResult::Woken` is returned. The waker is notified
    ///   in both cases.
    /// - `Init`: CAS `Init -> Woken`, notify, return `WakeResult::Next`. A failed CAS is
    ///   retried with the freshly observed state.
    ///
    /// NOTE(review): in the `Waiting` branch the state was read with a `Relaxed` load and is
    /// then overwritten with an unconditional `store`. A concurrent `Waiting -> Closed`
    /// (`abandon` / `close_wake`) is therefore not detected and can be overwritten, and
    /// `try_send_oneshot(p)` may run after the owner has given up on the waiter (and possibly
    /// on the item behind `p`). Confirm the caller's protocol rules this out.
    #[inline(always)]
    pub fn wake_or_copy<F: FlavorImpl<Item = T>>(&self, flavor: &F) -> WakeResult {
        let mut state = self.get_state_relaxed();
        loop {
            if state >= WakerState::Woken as u8 {
                return WakeResult::Skip;
            } else if state == WakerState::Waiting as u8 {
                let p = self.get_payload();
                if p.is_null() {
                    // No item attached: plain wake-up.
                    self.state.store(WakerState::Woken as u8, Ordering::SeqCst);
                    self.get_waker().wake_by_ref();
                    return WakeResult::Woken;
                }
                // Only `Some(true)` from the flavor counts as delivered (`Done`); any other
                // result falls back to an ordinary wake-up.
                state = if let Some(true) = flavor.try_send_oneshot(p) {
                    WakerState::Done as u8
                } else {
                    WakerState::Woken as u8
                };
                self.state.store(state, Ordering::SeqCst);
                self.get_waker().wake_by_ref();
                if state == WakerState::Done as u8 {
                    return WakeResult::Sent;
                } else {
                    return WakeResult::Woken;
                }
            } else {
                // Not yet committed to waiting: claim it directly from `Init`.
                match self.state.compare_exchange_weak(
                    WakerState::Init as u8,
                    WakerState::Woken as u8,
                    Ordering::SeqCst,
                    Ordering::Acquire,
                ) {
                    Ok(_) => {
                        self.get_waker().wake_by_ref();
                        return WakeResult::Next;
                    }
                    Err(s) => {
                        // Retry with the state actually observed (covers spurious failure).
                        state = s;
                    }
                }
            }
        }
    }
}
/// Reuse slot for `WakerInner` allocations, backed by the crate's `ArcCell`
/// (presumably a single-entry cell — confirm). `new_blocking` reuses a cached inner when
/// one is available instead of allocating a new `Arc`.
pub struct WakerCache<P: Copy>(ArcCell<WakerInner<P>>);
impl<P: Copy> WakerCache<P> {
#[inline(always)]
pub(crate) fn new() -> Self {
Self(ArcCell::new())
}
#[inline(always)]
pub fn new_blocking(&self, payload: P) -> ArcWaker<P> {
if let Some(inner) = self.0.pop() {
inner.update_thread_handle();
inner.reset(payload);
return ArcWaker::<P>::from_arc(inner);
}
ArcWaker::new_blocking(payload)
}
#[inline(always)]
pub(crate) fn push(&self, waker: ArcWaker<P>) {
debug_assert!(waker.get_state() >= WakerState::Woken as u8);
let a = waker.to_arc();
if Arc::weak_count(&a) == 0 && Arc::strong_count(&a) == 1 {
self.0.try_put(a);
}
}
#[allow(dead_code)]
#[inline(always)]
pub(crate) fn is_empty(&self) -> bool {
!self.0.exists()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::mem::size_of;

    /// Prints the in-memory sizes of the waker types. This is informational only and
    /// asserts nothing.
    #[test]
    fn test_waker_size() {
        let thin_waker_bytes = size_of::<ThinWaker>();
        let inner_bytes = size_of::<WakerInner<()>>();
        println!("wakertype {}", thin_waker_bytes);
        println!("waker inner {}", inner_bytes);
    }
}