use std::cell::UnsafeCell;
use std::ops::{Deref, DerefMut};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::task::{Context, Waker};
use crossbeam_utils::Backoff;
use slab::Slab;
/// Bit set in `WakerSet::flag` while the entry list is locked.
const LOCKED: usize = 0b001;
/// Bit set in `WakerSet::flag` when at least one entry has already been notified
/// (its waker was taken) but is still present in the set.
const NOTIFIED: usize = 0b010;
/// Bit set in `WakerSet::flag` when at least one entry can still be notified.
const NOTIFIABLE: usize = 0b100;
/// Lock-protected state of a `WakerSet`.
struct Inner {
/// The entries in the set, keyed by the value returned from `Slab::insert`.
///
/// An entry holding `Some(waker)` is still waiting to be notified. An entry holding
/// `None` has had its waker taken (it was woken) but has not been removed from the
/// set yet.
entries: Slab<Option<Waker>>,
/// The number of entries that can still be notified: incremented on insertion and
/// decremented whenever a waker is taken or an un-notified entry is removed.
notifiable: usize,
}
/// A set holding the wakers of blocked operations.
pub struct WakerSet {
/// Holds up to three bits: `LOCKED`, `NOTIFIED`, and `NOTIFIABLE`.
flag: AtomicUsize,
/// The entries and their bookkeeping; accessed only through a `Lock` guard.
inner: UnsafeCell<Inner>,
}
impl WakerSet {
    /// Creates a new, empty `WakerSet` with no flag bits set.
    #[inline]
    pub fn new() -> WakerSet {
        WakerSet {
            flag: AtomicUsize::new(0),
            inner: UnsafeCell::new(Inner {
                entries: Slab::new(),
                notifiable: 0,
            }),
        }
    }

    /// Inserts the waker of a blocked operation and returns the key associated with it.
    ///
    /// The returned key is what `cancel` and `remove_if_notified` expect.
    #[cold]
    pub fn insert(&self, cx: &Context<'_>) -> usize {
        // Clone the waker before taking the lock so the clone happens outside the
        // critical section.
        let w = cx.waker().clone();
        let mut inner = self.lock();
        let key = inner.entries.insert(Some(w));
        inner.notifiable += 1;
        key
    }

    /// Checks whether the entry for `key` has been notified.
    ///
    /// If the entry's waker has already been taken (the entry was notified), the entry
    /// is removed from the set and `true` is returned. Otherwise the stored waker is
    /// replaced with the one from `cx` (unless it would wake the same task) and `false`
    /// is returned.
    ///
    /// # Panics
    ///
    /// Indexes the slab with `key`, so `key` must refer to an entry currently in the set.
    #[cfg(feature = "unstable")]
    pub fn remove_if_notified(&self, key: usize, cx: &Context<'_>) -> bool {
        let mut inner = self.lock();
        match &mut inner.entries[key] {
            None => {
                // Already notified: `notifiable` was decremented when the waker was
                // taken, so only the entry itself needs to go.
                inner.entries.remove(key);
                true
            }
            Some(w) => {
                // Not notified yet: keep the entry, but make sure it holds the waker
                // of the task that is polling right now.
                if !w.will_wake(cx.waker()) {
                    *w = cx.waker().clone();
                }
                false
            }
        }
    }

    /// Removes the entry of a cancelled operation.
    ///
    /// If the removed entry had not been notified yet, nothing else happens and `false`
    /// is returned. If it had already been notified, the notification is passed on to
    /// the first entry that still holds a waker.
    ///
    /// Returns `true` if another blocked operation from the set was notified.
    ///
    /// # Panics
    ///
    /// Removes `key` from the slab, so `key` must refer to an entry currently in the set.
    #[cold]
    pub fn cancel(&self, key: usize) -> bool {
        let mut inner = self.lock();
        match inner.entries.remove(key) {
            // The entry was still waiting: it simply stops being notifiable.
            Some(_) => inner.notifiable -= 1,
            None => {
                // The operation was notified and then cancelled, so hand the
                // notification to another operation instead.
                for (_, opt_waker) in inner.entries.iter_mut() {
                    // An entry without a waker has already been woken; skip it.
                    if let Some(w) = opt_waker.take() {
                        w.wake();
                        inner.notifiable -= 1;
                        return true;
                    }
                }
            }
        }
        false
    }

    /// Notifies one blocked operation.
    ///
    /// Returns `true` if an operation was notified. Returns `false` without taking the
    /// lock when the `NOTIFIABLE` bit is not set.
    #[inline]
    #[cfg(feature = "unstable")]
    pub fn notify_one(&self) -> bool {
        // `SeqCst` pairs with the `SeqCst` store performed in `Lock::drop`.
        if self.flag.load(Ordering::SeqCst) & NOTIFIABLE != 0 {
            self.notify(Notify::One)
        } else {
            false
        }
    }

    /// Notifies all blocked operations.
    ///
    /// Returns `true` if at least one operation was notified. Returns `false` without
    /// taking the lock when the `NOTIFIABLE` bit is not set.
    #[inline]
    pub fn notify_all(&self) -> bool {
        // `SeqCst` pairs with the `SeqCst` store performed in `Lock::drop`.
        if self.flag.load(Ordering::SeqCst) & NOTIFIABLE != 0 {
            self.notify(Notify::All)
        } else {
            false
        }
    }

    /// Notifies blocked operations according to the strategy `n`.
    ///
    /// Returns `true` if at least one operation was notified.
    #[cold]
    fn notify(&self, n: Notify) -> bool {
        // Reborrow the guard as `&mut Inner` so that `entries` and `notifiable` can be
        // borrowed as disjoint fields inside the loop. The guard itself is a temporary
        // borrowed in this `let` initializer, so it stays alive (and the set stays
        // locked) until the end of the function.
        let inner = &mut *self.lock();
        let mut notified = false;
        for (_, opt_waker) in inner.entries.iter_mut() {
            // An entry without a waker has already been woken; skip it.
            if let Some(w) = opt_waker.take() {
                w.wake();
                inner.notifiable -= 1;
                notified = true;
                if n == Notify::One {
                    break;
                }
            }
            // `Any` inspects only the first entry, whether or not it was woken here.
            if n == Notify::Any {
                break;
            }
        }
        notified
    }

    /// Locks the list of entries, spinning with backoff until the `LOCKED` bit is
    /// acquired, and returns a guard that releases the lock when dropped.
    fn lock(&self) -> Lock<'_> {
        let backoff = Backoff::new();
        // `fetch_or` returns the previous value: the lock is ours only if `LOCKED` was
        // clear before this call set it.
        while self.flag.fetch_or(LOCKED, Ordering::Acquire) & LOCKED != 0 {
            backoff.snooze();
        }
        Lock { waker_set: self }
    }
}
/// A guard holding a `WakerSet` locked; dropping it rewrites the flag bits and
/// thereby releases the lock.
struct Lock<'a> {
/// The set this guard has locked.
waker_set: &'a WakerSet,
}
impl Drop for Lock<'_> {
    /// Publishes the current state of the set in the flag word and releases the lock.
    #[inline]
    fn drop(&mut self) {
        // Entries that are no longer notifiable have been notified but not removed yet.
        let notified_bit = if self.entries.len() - self.notifiable > 0 {
            NOTIFIED
        } else {
            0
        };
        // At least one entry is still waiting for a notification.
        let notifiable_bit = if self.notifiable > 0 { NOTIFIABLE } else { 0 };
        // The stored value never contains `LOCKED`, so this store also unlocks the set.
        // `SeqCst` pairs with the `SeqCst` loads in `notify_one` / `notify_all`.
        self.waker_set
            .flag
            .store(notified_bit | notifiable_bit, Ordering::SeqCst);
    }
}
impl Deref for Lock<'_> {
    type Target = Inner;

    /// Gives shared access to the `Inner` state of the locked set.
    #[inline]
    fn deref(&self) -> &Inner {
        let inner = self.waker_set.inner.get();
        // SAFETY: the pointer comes from an `UnsafeCell` owned by the `WakerSet` that
        // this guard borrows, so it is valid for the guard's lifetime. In the code
        // visible here a `Lock` is only created by `WakerSet::lock` after it acquired
        // the `LOCKED` bit, and that bit is cleared only in `Lock::drop`, so no other
        // guard accesses `Inner` while this one is alive.
        unsafe { &*inner }
    }
}
impl DerefMut for Lock<'_> {
    /// Gives exclusive access to the `Inner` state of the locked set.
    #[inline]
    fn deref_mut(&mut self) -> &mut Inner {
        let inner = self.waker_set.inner.get();
        // SAFETY: the pointer comes from an `UnsafeCell` owned by the `WakerSet` that
        // this guard borrows, so it is valid for the guard's lifetime. In the code
        // visible here a `Lock` is only created by `WakerSet::lock` after it acquired
        // the `LOCKED` bit, and that bit is cleared only in `Lock::drop`, so this guard
        // is the only path to `Inner`; borrowing the guard mutably makes the returned
        // reference unique.
        unsafe { &mut *inner }
    }
}
/// Notification strategy passed to `WakerSet::notify`.
#[derive(Clone, Copy, Eq, PartialEq)]
enum Notify {
/// Inspect only the first entry: wake it if it still holds a waker, then stop.
Any,
/// Wake the first entry that still holds a waker, then stop.
One,
/// Wake every entry that still holds a waker.
All,
}