use std::{
cell::{Cell, UnsafeCell},
fmt,
future::Future,
marker::PhantomPinned,
mem,
ops::{Deref, DerefMut},
panic::{RefUnwindSafe, UnwindSafe},
pin::Pin,
ptr::{self, NonNull},
sync::atomic::{AtomicUsize, Ordering},
task::{Context, Poll, Waker},
};
use crate::sync::{
atomic::full_fence,
mutex::{Mutex, MutexFamily, MutexGuardType, StdMutex},
};
/// A notification primitive that wakes [`EventListener`] futures registered with it.
///
/// `M` selects the mutex implementation that guards the listener list (defaults to `StdMutex`).
pub struct Event<M: MutexFamily = StdMutex> {
    /// Cached number of notified listeners still linked in the list, or `usize::MAX` when every
    /// linked listener is notified (including the empty list). Refreshed whenever a `ListGuard`
    /// is dropped and read without locking by the `notify*_relaxed` fast paths.
    notified: AtomicUsize,
    /// Intrusive list of registered listener entries, guarded by `M`'s mutex.
    list: M::Mutex<List>,
}
// NOTE(review): `Send`/`Sync` are asserted for every `M`. `List` holds raw pointers and is not
// `Sync` on its own, so this is only sound if `M::Mutex<List>` is a thread-safe lock — confirm
// for every `MutexFamily` implementor.
unsafe impl<M: MutexFamily> Send for Event<M> {}
unsafe impl<M: MutexFamily> Sync for Event<M> {}
// NOTE(review): unwind safety is asserted unconditionally; confirm the list stays consistent if a
// waker panics while notifications are being delivered.
impl<M: MutexFamily> UnwindSafe for Event<M> {}
impl<M: MutexFamily> RefUnwindSafe for Event<M> {}
impl<M: MutexFamily> Event<M> {
    /// Creates a new event with no registered listeners.
    #[inline]
    pub fn new() -> Self {
        Self {
            // `usize::MAX` is the "every linked listener is already notified" marker (vacuously
            // true for an empty list), which lets the `notify*_relaxed` fast paths skip locking.
            notified: AtomicUsize::new(usize::MAX),
            list: M::new(List::new()),
        }
    }

    /// Locks the listener list.
    ///
    /// Dropping the returned guard republishes the list's notification count to
    /// `self.notified` before the lock is released.
    #[inline]
    fn lock(&self) -> ListGuard<'_, M> {
        ListGuard {
            event: self,
            guard: self.list.lock(),
        }
    }

    /// Creates a listener for this event.
    ///
    /// The listener is not registered yet. Its entry is inserted into the event's list only
    /// after it has been pinned and then passed to `EventListener::listen` or polled for the
    /// first time. Notifications issued before that point cannot reach it.
    // A plain constructor (no allocation, no locking), so it is `#[inline]` like the other cheap
    // methods here. A `#[cold]` hint would tell the optimizer every call site is unlikely and
    // pessimize hot paths that create listeners.
    #[inline]
    pub fn listener(&self) -> EventListener<'_, M> {
        EventListener {
            event: self,
            state: ListenerState::Init,
            entry: UnsafeCell::new(Entry::new()),
        }
    }

    /// Makes sure up to `n` listeners in total are notified.
    ///
    /// Listeners that are already notified and still registered count toward `n`. If `n` or
    /// more are already notified, this does nothing. A full fence is issued first; use
    /// `notify_relaxed` to skip it.
    #[inline]
    pub fn notify(&self, n: usize) {
        full_fence();
        self.notify_relaxed(n);
    }

    /// Same as `notify`, but without the leading full fence.
    #[inline]
    pub fn notify_relaxed(&self, n: usize) {
        // Fast path: the cached count already satisfies `n` (`usize::MAX` when every linked
        // listener is notified), so the lock is not taken.
        if self.notified.load(Ordering::Acquire) < n {
            self.lock().notify(n);
        }
    }

    /// Notifies up to `n` more listeners, on top of any that are already notified.
    ///
    /// A full fence is issued first; use `notify_additional_relaxed` to skip it.
    #[inline]
    pub fn notify_additional(&self, n: usize) {
        full_fence();
        self.notify_additional_relaxed(n);
    }

    /// Same as `notify_additional`, but without the leading full fence.
    #[inline]
    pub fn notify_additional_relaxed(&self, n: usize) {
        // Fast path: `usize::MAX` means no un-notified listener is linked, so there is nobody
        // left to notify and the lock is not taken.
        if self.notified.load(Ordering::Acquire) < usize::MAX {
            self.lock().notify_additional(n);
        }
    }
}
impl<M: MutexFamily> fmt::Debug for Event<M> {
    /// Renders an opaque `Event { .. }`; the listener list is never inspected or locked.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        const OPAQUE: &str = "Event { .. }";
        fmt::Formatter::pad(f, OPAQUE)
    }
}
impl<M: MutexFamily> Default for Event<M> {
    /// Equivalent to `Event::new()`: an event with no registered listeners.
    fn default() -> Self {
        Self::new()
    }
}
/// Registration lifecycle of an `EventListener`'s list entry.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum ListenerState {
    /// Created but not yet inserted into the event's list.
    Init,
    /// Inserted into the event's list.
    Listening,
    /// Removed from the event's list; the listener is finished.
    Done,
}
/// A listener for an [`Event`], created by `Event::listener`.
///
/// The listener must be pinned before use. Once it starts listening (via `listen` or its first
/// `poll`), its `entry` is linked into the event's intrusive list by raw pointer. As a `Future`
/// it completes once it has been notified.
pub struct EventListener<'a, M: MutexFamily> {
    /// The event this listener belongs to.
    event: &'a Event<M>,
    /// Whether `entry` is unregistered, linked into the list, or finished.
    state: ListenerState,
    /// This listener's list node; once linked it is only accessed under the event's list lock.
    entry: UnsafeCell<Entry>,
}
// NOTE(review): without these impls the listener would be neither `Send` nor `Sync` (`Entry`
// holds `NonNull` pointers inside an `UnsafeCell`). They rely on the entry only being touched
// while the event's list lock is held (`listen`/`poll`/`Drop`) and on the `&self` methods only
// comparing pointers. Confirm this stays true as the API grows.
unsafe impl<M: MutexFamily> Send for EventListener<'_, M> {}
unsafe impl<M: MutexFamily> Sync for EventListener<'_, M> {}
impl<M: MutexFamily> UnwindSafe for EventListener<'_, M> {}
impl<M: MutexFamily> RefUnwindSafe for EventListener<'_, M> {}
impl<M: MutexFamily> EventListener<'_, M> {
    /// Pin-projects the listener into its fields.
    ///
    /// `event` and `state` are not structurally pinned; both are `Unpin`, which is checked at
    /// compile time below. `entry` is structurally pinned because the event's list keeps raw
    /// pointers to it.
    #[inline]
    fn project(self: Pin<&mut Self>) -> (&Event<M>, &mut ListenerState, Pin<&UnsafeCell<Entry>>) {
        fn require_unpin<T: Unpin>() {}
        require_unpin::<&Event<M>>();
        require_unpin::<ListenerState>();
        // SAFETY: nothing is moved out of `this`. Only `entry` is handed out, and it stays
        // pinned; the other two fields are `Unpin`, so exposing them unpinned is fine.
        let this = unsafe { self.get_unchecked_mut() };
        // SAFETY: `entry` lives inside the pinned listener and is never moved out of it.
        let pinned_entry = unsafe { Pin::new_unchecked(&this.entry) };
        (this.event, &mut this.state, pinned_entry)
    }

    /// Registers this listener's entry with its event (without a waker) if that has not
    /// happened yet, then issues a full fence.
    ///
    /// Calling it again, or after the listener finished, only repeats the fence.
    pub fn listen(self: Pin<&mut Self>) {
        let (event, state, entry) = self.project();
        if matches!(*state, ListenerState::Init) {
            event.lock().insert(entry, None);
            *state = ListenerState::Listening;
        }
        full_fence();
    }

    /// Returns `true` if this listener was created from `event`.
    #[inline]
    pub fn listens_to(&self, event: &Event<M>) -> bool {
        ptr::eq(event, self.event)
    }

    /// Returns `true` if `self` and `other` were created from the same event.
    #[inline]
    pub fn same_event(&self, other: &EventListener<'_, M>) -> bool {
        ptr::eq(other.event, self.event)
    }
}
impl<M: MutexFamily> Future for EventListener<'_, M> {
    type Output = ();
    /// Completes once this listener has been notified.
    ///
    /// - First poll of an unregistered listener: the entry is inserted into the event's list
    ///   with the current waker, and `Pending` is returned.
    /// - Later polls: a notified entry is unlinked and `Ready(())` is returned. Otherwise the
    ///   stored waker is refreshed and `Pending` is returned.
    /// - After `Ready` has been returned once, further polls return `Ready` immediately.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let (inner, listener_state, entry) = self.project();
        // Already completed: do not touch the list again.
        if *listener_state == ListenerState::Done {
            return Poll::Ready(());
        }
        let mut list = inner.lock();
        // Not registered yet: link the entry together with the current waker and wait.
        if *listener_state == ListenerState::Init {
            list.insert(entry, Some(cx.waker().clone()));
            *listener_state = ListenerState::Listening;
            full_fence();
            return Poll::Pending;
        }
        // SAFETY: the entry is linked into `list` and we hold its lock, so nothing else accesses
        // it concurrently; all of `Entry`'s fields are `Cell`s.
        let entry = unsafe { &mut *entry.get_ref().get() };
        let state = &mut entry.state;
        // Take the current state out. The `Notified(false)` placeholder is overwritten below
        // unless the entry was notified. No one can observe it, because the lock is held.
        match state.replace(State::Notified(false)) {
            State::Notified(_) => {
                // SAFETY: derived from a reference, so the pointer is non-null.
                let entry = unsafe { NonNull::new_unchecked(entry as *mut _) };
                // The state is still `Notified`, so `remove` decrements `list.notified`.
                list.remove(entry);
                *listener_state = ListenerState::Done;
                // Dropping the guard publishes the updated count to `Event::notified`.
                drop(list);
                return Poll::Ready(());
            }
            State::Created => {
                // Registered via `listen` (no waker yet): store this task's waker.
                state.set(State::Polling(cx.waker().clone()));
            }
            State::Polling(w) => {
                // Keep the existing waker when it would wake the same task; avoids a clone.
                if w.will_wake(cx.waker()) {
                    state.set(State::Polling(w));
                } else {
                    state.set(State::Polling(cx.waker().clone()));
                }
            }
        }
        Poll::Pending
    }
}
impl<M: MutexFamily> Drop for EventListener<'_, M> {
    /// Unlinks a still-registered entry from the event's list.
    ///
    /// If the entry had been notified but the listener never consumed that notification, the
    /// notification is not lost. A `notify`-style one is re-issued as `notify(1)`, and a
    /// `notify_additional`-style one as `notify_additional(1)`.
    fn drop(&mut self) {
        if self.state != ListenerState::Listening {
            return;
        }
        let mut list = self.event.lock();
        // SAFETY: `UnsafeCell::get` returns a pointer to our own field, which is never null.
        let node = unsafe { NonNull::new_unchecked(self.entry.get()) };
        match list.remove(node) {
            State::Notified(true) => list.notify_additional(1),
            State::Notified(false) => list.notify(1),
            State::Created | State::Polling(_) => {}
        }
        self.state = ListenerState::Done;
    }
}
/// Lock guard over an event's listener list.
///
/// Dereferences to the `List`. Its `Drop` impl publishes the list's notification count back to
/// `Event::notified` while the lock is still held; the inner `guard` field is dropped afterwards.
struct ListGuard<'a, M: MutexFamily> {
    /// The event whose `notified` cache is refreshed when this guard is dropped.
    event: &'a Event<M>,
    /// The underlying mutex guard over the list.
    guard: MutexGuardType<'a, M, List>,
}
impl<M: MutexFamily> Drop for ListGuard<'_, M> {
    /// Publishes the list's notification count to `Event::notified`.
    ///
    /// This runs before the `guard` field is dropped, so the store happens while the lock is
    /// still held. `usize::MAX` is published when every linked listener is notified.
    #[inline]
    fn drop(&mut self) {
        let all_notified = self.notified >= self.len;
        let snapshot = if all_notified { usize::MAX } else { self.notified };
        self.event.notified.store(snapshot, Ordering::Release);
    }
}
/// Gives shared access to the locked `List` through the inner mutex guard.
impl<M: MutexFamily> Deref for ListGuard<'_, M> {
    type Target = List;
    #[inline]
    fn deref(&self) -> &List {
        &self.guard
    }
}
/// Gives exclusive access to the locked `List` through the inner mutex guard.
impl<M: MutexFamily> DerefMut for ListGuard<'_, M> {
    #[inline]
    fn deref_mut(&mut self) -> &mut List {
        &mut self.guard
    }
}
/// Notification state stored in each list `Entry`.
enum State {
    /// No waker stored and not notified. This is the initial state, and also the state
    /// `List::remove` leaves behind.
    Created,
    /// Notified. The flag is `true` when the notification came from `notify_additional`.
    Notified(bool),
    /// Waiting, holding the waker to call once notified.
    Polling(Waker),
}
impl State {
    /// Returns `true` for any `State::Notified`, regardless of its flag.
    #[inline]
    fn is_notified(&self) -> bool {
        matches!(self, State::Notified(_))
    }
}
/// A listener's node in the event's intrusive doubly linked list.
///
/// `PhantomPinned` makes `Entry` `!Unpin`: the list stores raw pointers to it, so it must not
/// move while linked.
struct Entry {
    /// Notification state of this listener.
    state: Cell<State>,
    /// Previous node, or `None` when this entry is the head.
    prev: Cell<Option<NonNull<Entry>>>,
    /// Next node, or `None` when this entry is the tail.
    next: Cell<Option<NonNull<Entry>>>,
    _pinned: PhantomPinned,
}
impl Entry {
    /// Creates an unlinked entry in the `State::Created` state.
    #[inline]
    const fn new() -> Self {
        Self {
            prev: Cell::new(None),
            next: Cell::new(None),
            state: Cell::new(State::Created),
            _pinned: PhantomPinned,
        }
    }
}
/// Intrusive doubly linked list of listener `Entry` nodes, in insertion order.
struct List {
    /// First linked entry.
    head: Option<NonNull<Entry>>,
    /// Last linked entry; new entries are appended here.
    tail: Option<NonNull<Entry>>,
    /// First entry that has not been notified yet; entries before it are notified.
    /// `None` when every linked entry is notified or the list is empty.
    start: Option<NonNull<Entry>>,
    /// Number of linked entries.
    len: usize,
    /// Number of linked entries currently in the notified state.
    notified: usize,
}
// NOTE(review): the raw pointers make `List` `!Send` by default. This impl presumably relies on
// the list only being reached through the event's mutex, with linked entries kept alive by
// their pinned `EventListener`s — confirm.
unsafe impl Send for List {}
impl List {
    /// Creates an empty list.
    #[inline]
    const fn new() -> Self {
        Self {
            head: None,
            tail: None,
            start: None,
            len: 0,
            notified: 0,
        }
    }
    /// Appends `entry` at the tail of the list.
    ///
    /// The entry's previous contents are overwritten. Its state becomes `State::Polling(waker)`
    /// if a waker is given, and `State::Created` otherwise. If no un-notified entry was linked
    /// (`start` is `None`), the new entry becomes `start`.
    fn insert(&mut self, entry: Pin<&UnsafeCell<Entry>>, waker: Option<Waker>) {
        // SAFETY: `&mut self` is only reachable through the event's lock guard, so no one else
        // touches linked entries concurrently. `entry` is pinned, so the pointer stored below
        // stays valid while it is linked; the owning `EventListener` unlinks it on `Drop`.
        unsafe {
            let state = waker.map(State::Polling).unwrap_or(State::Created);
            let entry = &mut *entry.get_ref().get();
            *entry = Entry {
                state: Cell::new(state),
                prev: Cell::new(self.tail),
                next: Cell::new(None),
                _pinned: PhantomPinned,
            };
            let entry = NonNull::new_unchecked(entry as *mut _);
            // Make the entry the new tail: link it after the old tail, or as head if empty.
            match mem::replace(&mut self.tail, Some(entry)) {
                None => self.head = Some(entry),
                Some(t) => t.as_ref().next.set(Some(entry)),
            }
            // Every previously linked entry is notified (or there are none), so this is the
            // first entry a future notification should reach.
            if self.start.is_none() {
                self.start = self.tail;
            }
            self.len += 1;
        }
    }
    /// Unlinks `entry`, resets its state to `State::Created`, and returns the previous state
    /// (including any stored waker).
    ///
    /// Keeps `head`, `tail`, `start`, `len` and `notified` consistent. `notified` is decremented
    /// when the removed entry was in the notified state.
    fn remove(&mut self, entry: NonNull<Entry>) -> State {
        // SAFETY: NOTE(review) — callers must pass an entry currently linked into *this* list;
        // that is not checked here. Its neighbours are linked entries too, and the lock is held
        // via `&mut self`.
        unsafe {
            let prev = entry.as_ref().prev.get();
            let next = entry.as_ref().next.get();
            // Bypass `entry` in the forward direction (or move `head` past it).
            match prev {
                None => self.head = next,
                Some(p) => p.as_ref().next.set(next),
            }
            // Bypass `entry` in the backward direction (or move `tail` before it).
            match next {
                None => self.tail = prev,
                Some(n) => n.as_ref().prev.set(prev),
            }
            // If `entry` was the first un-notified node, its successor takes that role.
            if self.start == Some(entry) {
                self.start = next;
            }
            let state = entry.as_ref().state.replace(State::Created);
            if state.is_notified() {
                self.notified -= 1;
            }
            self.len -= 1;
            state
        }
    }
    /// Notifies entries from `start` onward until `n` entries in total are notified.
    ///
    /// Already-notified entries count toward `n`, and the loop stops early when no un-notified
    /// entry is left. Each entry reached is marked `State::Notified(false)`; a stored waker is
    /// woken.
    #[cold]
    fn notify(&mut self, mut n: usize) {
        // Enough entries are already notified.
        if n <= self.notified {
            return;
        }
        n -= self.notified;
        while n > 0 {
            n -= 1;
            match self.start {
                // Every linked entry is already notified.
                None => break,
                Some(e) => {
                    // SAFETY: `start` points at a linked entry, and the lock is held.
                    let e = unsafe { e.as_ref() };
                    self.start = e.next.get();
                    match e.state.replace(State::Notified(false)) {
                        State::Notified(_) => {}
                        State::Created => {}
                        // NOTE(review): if `wake` panics, the `notified += 1` below is skipped
                        // even though the entry is already marked notified.
                        State::Polling(w) => w.wake(),
                    }
                    self.notified += 1;
                }
            }
        }
    }
    /// Notifies up to `n` more entries from `start` onward, ignoring how many are already
    /// notified.
    ///
    /// Entries are marked `State::Notified(true)`. `EventListener`'s `Drop` uses that flag to
    /// re-issue an unconsumed notification as `notify_additional(1)`.
    #[cold]
    fn notify_additional(&mut self, mut n: usize) {
        while n > 0 {
            n -= 1;
            match self.start {
                // Every linked entry is already notified.
                None => break,
                Some(e) => {
                    // SAFETY: `start` points at a linked entry, and the lock is held.
                    let e = unsafe { e.as_ref() };
                    self.start = e.next.get();
                    match e.state.replace(State::Notified(true)) {
                        State::Notified(_) => {}
                        State::Created => {}
                        State::Polling(w) => w.wake(),
                    }
                    self.notified += 1;
                }
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use crate::sync::notify::Event;

    /// Compiles only when `T` is both `Send` and `Sync`.
    fn assert_send_sync<T: Send + Sync>() {}

    #[test]
    fn event_send_sync() {
        assert_send_sync::<Event>();
    }
}