use core::future::Future;
use core::pin::Pin;
use core::sync::atomic::Ordering;
use core::task::{Context, Poll, Waker};
use crate::loom_exports::cell::UnsafeCell;
use crate::loom_exports::sync::atomic::AtomicUsize;
use crate::WakeSinkRef;
// Bit flags packed into `DiatomicWaker::state`.
//
// Selects which of the two waker slots a notifier reads (`state & INDEX`).
const INDEX: usize = 0b00001;
// Set by `register` after a new waker has been stored in the slot that is NOT
// selected by `INDEX`. When a notifier sees it in `try_lock`/`try_unlock`, it
// clears the flag and flips `INDEX` so that the new waker becomes current.
const UPDATE: usize = 0b00010;
// Set by `register`, cleared by `unregister`. `try_lock` only succeeds when
// this flag is set, and clears it while taking the lock.
const REGISTERED: usize = 0b00100;
// Held by the notifier between a successful `try_lock` and a successful
// `try_unlock`.
const LOCKED: usize = 0b01000;
// Set by a notifier whose `try_lock` failed while `REGISTERED` was set; makes
// the lock holder's `try_unlock` fail once so that it wakes again.
// Must stay equal to `REGISTERED << 2` (see the shift in `try_lock`).
const NOTIFICATION: usize = 0b10000;
/// A waker cell made of two waker slots and one atomic word of state flags.
///
/// `notify` wakes the waker in the slot selected by the `INDEX` bit of
/// `state`, while `register` writes new wakers into the other slot and then
/// publishes them by setting the `UPDATE` flag.
#[derive(Debug)]
pub struct DiatomicWaker {
// Bit set of `INDEX`, `UPDATE`, `REGISTERED`, `LOCKED` and `NOTIFICATION`.
state: AtomicUsize,
// The two waker slots; a slot is `None` until a waker is stored in it.
waker: [UnsafeCell<Option<Waker>>; 2],
}
impl DiatomicWaker {
/// Creates a new `DiatomicWaker` with no registered waker.
///
/// Both waker slots start out empty and all state flags are cleared.
#[cfg(not(all(test, diatomic_waker_loom)))]
pub const fn new() -> Self {
Self {
state: AtomicUsize::new(0),
waker: [UnsafeCell::new(None), UnsafeCell::new(None)],
}
}
/// Creates a new `DiatomicWaker` with no registered waker.
///
/// Non-`const` variant, compiled only when both `test` and the
/// `diatomic_waker_loom` cfg are active (presumably because the loom
/// constructors are not `const` — confirm against `loom_exports`).
#[cfg(all(test, diatomic_waker_loom))]
pub fn new() -> Self {
Self {
state: AtomicUsize::new(0),
waker: [UnsafeCell::new(None), UnsafeCell::new(None)],
}
}
/// Wraps an exclusive borrow of this waker in a [`WakeSinkRef`].
pub fn sink_ref(&mut self) -> WakeSinkRef<'_> {
WakeSinkRef { inner: self }
}
/// Wakes the registered waker, if any.
///
/// If no waker is registered, or if another notifier currently holds the
/// lock, this returns without waking; in the latter case `try_lock` has
/// flagged a pending notification (when a waker was registered) so that the
/// lock holder wakes again before releasing the lock.
pub fn notify(&self) {
// Take the notifier lock; on failure there is nothing more to do here.
let mut state = if let Ok(s) = try_lock(&self.state) {
s
} else {
return;
};
loop {
// Slot selected by the state returned from the last lock/unlock attempt.
let idx = state & INDEX;
// SAFETY (as far as this file shows): `LOCKED` is held, and `register`
// only ever writes to the slot that is not selected by `INDEX`.
unsafe {
self.wake_by_ref(idx);
}
// `try_unlock` fails when a notification was flagged while the lock was
// held; it then returns the updated state (lock still held) and the
// waker selected by that state is woken on the next iteration.
if let Err(s) = try_unlock(&self.state, state) {
state = s;
} else {
return;
}
}
}
/// Registers `waker` as the waker to be woken by [`notify`](Self::notify).
///
/// If the most recently stored waker already wakes the same task, only the
/// `REGISTERED` flag is set again; otherwise a clone of `waker` is stored in
/// the slot not currently used by notifiers and published with `UPDATE`.
///
/// # Safety
///
/// This method reads and overwrites the waker slots without holding
/// `LOCKED`, so it is only sound if the caller guarantees that no other call
/// that may write to the slots (another `register`) runs concurrently.
/// NOTE(review): the full contract (e.g. with respect to `unregister` and
/// `wait_until`) is not visible in this file — confirm against the crate
/// documentation.
pub unsafe fn register(&self, waker: &Waker) {
let state = self.state.load(Ordering::Acquire);
let mut idx = state & INDEX;
// Slot holding the most recently stored waker: the current slot if no
// update is pending, the other one otherwise (`INDEX` is 1, so
// `INDEX - idx` toggles between 0 and 1).
let recent_idx = if state & UPDATE == 0 {
idx
} else {
INDEX - idx
};
let is_up_to_date = self.will_wake(recent_idx, waker);
// Fast path: the stored waker is equivalent, so it only needs to be marked
// as registered again. `NOTIFICATION` is left untouched here.
if is_up_to_date {
self.state.fetch_or(REGISTERED, Ordering::Acquire);
return;
}
// With both `UPDATE` and `REGISTERED` set, a notifier could still flip
// `INDEX` onto the slot about to be written (`try_lock` success path, or
// `try_unlock` failure path when `NOTIFICATION` is set). Clearing
// `REGISTERED` and `NOTIFICATION` rules both out.
if state & (UPDATE | REGISTERED) == (UPDATE | REGISTERED) {
let state = self
.state
.fetch_and(!(REGISTERED | NOTIFICATION), Ordering::Acquire);
// `INDEX` may have been flipped since the initial load; use the value
// observed by the RMW above.
idx = state & INDEX;
}
// Always write to the slot that notifiers are not reading from.
let redundant_idx = 1 - idx;
self.set_waker(redundant_idx, waker.clone());
// Publish the new waker and mark it as registered.
self.state.fetch_or(UPDATE | REGISTERED, Ordering::AcqRel);
}
/// Unregisters the waker: clears the `REGISTERED` and `NOTIFICATION` flags.
///
/// The stored wakers themselves are left in their slots.
///
/// # Safety
///
/// NOTE(review): the body performs a single atomic RMW; the reason this is
/// `unsafe` is not visible here. Presumably it must be called from the same
/// single registering context as `register` — confirm against the crate
/// documentation.
pub unsafe fn unregister(&self) {
self.state
.fetch_and(!(REGISTERED | NOTIFICATION), Ordering::Relaxed);
}
/// Returns a future that completes with the value of `predicate` once it
/// returns `Some`.
///
/// When polled, the returned [`WaitUntil`] evaluates `predicate`, registers
/// the task's waker if the result is `None`, and evaluates `predicate` once
/// more before returning `Poll::Pending`.
///
/// # Safety
///
/// The returned future calls `register` and `unregister` when polled, so
/// the requirements of those methods apply for as long as it is in use.
pub unsafe fn wait_until<P, T>(&self, predicate: P) -> WaitUntil<'_, P, T>
where
P: FnMut() -> Option<T>,
{
WaitUntil::new(self, predicate)
}
/// Stores `new` in slot `idx`, replacing (and dropping) any previous waker.
///
/// Safety: the caller must guarantee that nothing else accesses slot `idx`
/// during the call. Panics if `idx` is not 0 or 1.
unsafe fn set_waker(&self, idx: usize, new: Waker) {
self.waker[idx].with_mut(|waker| (*waker) = Some(new));
}
/// Wakes the waker stored in slot `idx` by reference; does nothing if the
/// slot is empty.
///
/// Safety: the caller must guarantee that slot `idx` is not written to
/// during the call. Panics if `idx` is not 0 or 1.
unsafe fn wake_by_ref(&self, idx: usize) {
self.waker[idx].with(|waker| {
if let Some(waker) = &*waker {
waker.wake_by_ref();
}
});
}
/// Returns `true` if slot `idx` holds a waker for which
/// `Waker::will_wake(other)` is `true`, and `false` if the slot is empty.
///
/// Safety: the caller must guarantee that slot `idx` is not written to
/// during the call. Panics if `idx` is not 0 or 1.
unsafe fn will_wake(&self, idx: usize, other: &Waker) -> bool {
self.waker[idx].with(|waker| match &*waker {
Some(waker) => waker.will_wake(other),
None => false,
})
}
}
impl Default for DiatomicWaker {
fn default() -> Self {
Self::new()
}
}
// The waker slots are `UnsafeCell`s, so sharing across threads has to be
// asserted manually. Access to the slots is coordinated through the flags in
// `state` (`notify` only reads the slot selected by `INDEX` while holding
// `LOCKED`; `register` only writes to the other slot).
// NOTE(review): soundness also relies on callers upholding the contracts of
// the `unsafe` methods — confirm against the crate documentation.
unsafe impl Send for DiatomicWaker {}
unsafe impl Sync for DiatomicWaker {}
/// Attempts to take the notifier lock.
///
/// The lock can be taken only when `REGISTERED` is set and `LOCKED` is clear.
/// On success, `LOCKED` is set, `REGISTERED` is cleared and, if `UPDATE` was
/// set, `UPDATE` is cleared and `INDEX` is flipped; the resulting state is
/// returned as `Ok`.
///
/// On failure `Err(())` is returned. If `REGISTERED` was set at that point
/// (which implies that `LOCKED` was set too), `NOTIFICATION` is raised before
/// returning; otherwise the state is left unchanged.
fn try_lock(state: &AtomicUsize) -> Result<usize, ()> {
    let mut observed = state.load(Ordering::Relaxed);

    loop {
        let lockable = observed & (LOCKED | REGISTERED) == REGISTERED;

        // Outer `Result`: outcome of the CAS. Inner `Result`: value to return
        // once the CAS has gone through.
        let attempt: Result<Result<usize, ()>, usize> = if lockable {
            // `UPDATE` sits one bit above `INDEX`, so this mask is
            // `UPDATE | INDEX` when an update is pending and 0 otherwise.
            let pending_update = observed & UPDATE;
            let switch_slot = pending_update | (pending_update >> 1);
            // `LOCKED` is known to be clear and `REGISTERED` to be set, so the
            // XOR sets the former and clears the latter.
            let locked = observed ^ (switch_slot | LOCKED | REGISTERED);

            state
                .compare_exchange_weak(observed, locked, Ordering::AcqRel, Ordering::Relaxed)
                .map(|_| Ok(locked))
        } else {
            // `REGISTERED << 2` is `NOTIFICATION`: the notification flag is
            // raised only if a waker is registered.
            let flagged = observed | ((observed & REGISTERED) << 2);

            state
                .compare_exchange_weak(observed, flagged, Ordering::Release, Ordering::Relaxed)
                .map(|_| Err(()))
        };

        match attempt {
            Ok(verdict) => return verdict,
            // The state changed (or the weak CAS failed spuriously): retry.
            Err(actual) => observed = actual,
        }
    }
}
/// Attempts to release the notifier lock.
///
/// `old_state` is the caller's last known value of `state`; it is refreshed
/// from the atomic whenever a compare-exchange fails.
///
/// If `NOTIFICATION` is clear, `LOCKED` is cleared and `Ok(())` is returned.
/// Otherwise the lock is kept: `NOTIFICATION` and `REGISTERED` are cleared
/// and, if `UPDATE` was set, `UPDATE` is cleared and `INDEX` is flipped; the
/// resulting state is returned as `Err` so that the caller can wake again.
fn try_unlock(state: &AtomicUsize, mut old_state: usize) -> Result<(), usize> {
    loop {
        let notified = old_state & NOTIFICATION != 0;

        // Outer `Result`: outcome of the CAS. Inner `Result`: value to return
        // once the CAS has gone through.
        let attempt: Result<Result<(), usize>, usize> = if notified {
            // `UPDATE` sits one bit above `INDEX`, so this mask is
            // `UPDATE | INDEX` when an update is pending and 0 otherwise.
            let pending_update = old_state & UPDATE;
            let switch_slot = pending_update | (pending_update >> 1);
            // The XOR clears `NOTIFICATION` (known to be set) and toggles
            // `REGISTERED`, which is expected to be set at this point.
            let still_locked = old_state ^ (switch_slot | NOTIFICATION | REGISTERED);

            state
                .compare_exchange_weak(
                    old_state,
                    still_locked,
                    Ordering::AcqRel,
                    Ordering::Relaxed,
                )
                .map(|_| Err(still_locked))
        } else {
            let released = old_state & !LOCKED;

            state
                .compare_exchange_weak(old_state, released, Ordering::Release, Ordering::Relaxed)
                .map(|_| Ok(()))
        };

        match attempt {
            Ok(verdict) => return verdict,
            // The state changed (or the weak CAS failed spuriously): retry.
            Err(actual) => old_state = actual,
        }
    }
}
/// A future that resolves to the value of a predicate once the predicate
/// returns `Some`.
///
/// Created by `DiatomicWaker::wait_until`. Each poll evaluates the predicate
/// and, while it returns `None`, registers the polling task's waker with the
/// borrowed `DiatomicWaker`.
#[derive(Debug)]
pub struct WaitUntil<'a, P, T>
where
P: FnMut() -> Option<T>,
{
// Evaluated on every poll; `Some(value)` completes the future with `value`.
predicate: P,
// Waker cell with which the task's waker is registered.
wake: &'a DiatomicWaker,
}
impl<'a, P, T> WaitUntil<'a, P, T>
where
P: FnMut() -> Option<T>,
{
fn new(wake: &'a DiatomicWaker, predicate: P) -> Self {
Self { predicate, wake }
}
}
// `WaitUntil` is declared `Unpin` regardless of `P` and `T`; `poll` relies on
// this to reach the fields mutably through `Pin<&mut Self>`.
impl<P: FnMut() -> Option<T>, T> Unpin for WaitUntil<'_, P, T> {}
impl<'a, P, T> Future for WaitUntil<'a, P, T>
where
    P: FnMut() -> Option<T>,
{
    type Output = T;

    /// Evaluates the predicate and completes as soon as it returns `Some`.
    ///
    /// If the first evaluation returns `None`, the task's waker is registered
    /// and the predicate is evaluated a second time, so that a value that
    /// became available between the first check and the registration is not
    /// missed.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<T> {
        // Fast path: the value is already available, no registration needed.
        if let Some(value) = (self.predicate)() {
            return Poll::Ready(value);
        }

        // SAFETY: contingent on the contract of `DiatomicWaker::wait_until`,
        // the `unsafe` constructor of this future.
        unsafe {
            self.wake.register(cx.waker());
        }

        // Check again now that the waker is registered.
        match (self.predicate)() {
            Some(value) => {
                // The waker is no longer needed.
                // SAFETY: same contract as for the `register` call above.
                unsafe {
                    self.wake.unregister();
                }
                Poll::Ready(value)
            }
            None => Poll::Pending,
        }
    }
}