use std::cell::Cell;
use std::future::{Future, IntoFuture};
use std::mem;
use std::ops::{Deref, DerefMut};
use std::pin::Pin;
use std::task::{Context, Poll, Waker};
use futures_lite::{future, Stream};
use slab::Slab;
use crate::sync::{MutexGuard, ThreadSafety, __private::*};
/// Event dispatcher for events of type `T`, parameterised over a thread-safety
/// strategy `TS`.
///
/// The listener state lives behind a boxed mutex that is created lazily the first
/// time it is requested.
pub struct Handler<T, TS>
where
    T: Event,
    TS: ThreadSafety,
{
    /// Lazily-initialised shared state (listeners, direct listeners, current event).
    state: TS::OnceLock<Box<TS::Mutex<State<T>>>>,
}
/// Mutable state shared between a `Handler` and its `Waiter`s, guarded by a mutex.
struct State<T>
where
    T: Event,
{
    /// Slab-allocated nodes of the doubly linked list of waiting listeners.
    listeners: Slab<Listener>,
    /// Direct listeners, run in order before any waiter is notified.
    directs: Vec<DirectListener<T>>,
    /// Slab keys of the first and last listener, or `None` when the list is empty.
    head_and_tail: Option<(usize, usize)>,
    /// Waker of the task dispatching the current event; woken when the chain ends.
    waker: Option<Waker>,
    /// Snapshot of the event currently being handed from listener to listener.
    instance: Option<T::Clonable>,
}
/// Boxed future produced by a direct listener. Resolving to `true` stops the
/// current event from reaching the remaining direct listeners and all waiters.
type DirectFuture = Pin<Box<dyn Future<Output = bool> + Send + 'static>>;
/// Type-erased direct listener: called with mutable access to the event, it
/// returns a [`DirectFuture`].
type DirectListener<T> =
    Box<dyn FnMut(&mut <T as Event>::Unique<'_>) -> DirectFuture + Send + 'static>;
impl<T: Event, TS: ThreadSafety> Handler<T, TS> {
    /// Creates a handler with no listeners.
    ///
    /// The shared state is not allocated here; `Handler::state` creates it on first use.
    pub(crate) fn new() -> Self {
        Self {
            state: TS::OnceLock::new(),
        }
    }

    /// Dispatches `event` to this handler's listeners.
    ///
    /// Direct listeners run first, in registration order. If any of them resolves to
    /// `true`, dispatch stops there. Otherwise a clonable snapshot of the event is
    /// stored and the first `Waiter` in the list is notified; each waiter passes the
    /// notification on to the next. This future completes once the stored snapshot
    /// has been cleared (done by the last waiter in the chain) or the listener list
    /// is empty.
    ///
    /// Returns immediately if the lazily-created state does not exist yet.
    pub(crate) async fn run_with(&self, event: &mut T::Unique<'_>) {
        let state = match self.state.get() {
            Some(state) => state,
            None => return,
        };

        // The guard is lent to `run_direct_listeners`, which releases it before
        // awaiting any direct listener.
        let mut state_lock = Some(state.lock().unwrap());
        if self.run_direct_listeners(&mut state_lock, event).await {
            return;
        }

        {
            // Re-acquire the lock if the direct listeners released it.
            let state = state_lock.get_or_insert_with(|| state.lock().unwrap());
            let head = match state.head_and_tail {
                Some((head, _)) => head,
                None => return,
            };
            state.instance = Some(T::downgrade(event));
            if let Some(waker) = state.notify(head) {
                waker.wake();
            }
        }

        // Wait for the notification chain to run to completion.
        future::poll_fn(|cx| {
            // The first poll reuses the guard still held from the setup above.
            let mut state = state_lock.take().unwrap_or_else(|| state.lock().unwrap());
            if state.head_and_tail.is_none() {
                return Poll::Ready(());
            }
            // The last waiter in the chain clears `instance` when it passes the
            // notification on.
            if state.instance.is_none() {
                return Poll::Ready(());
            }
            if let Some(waker) = &state.waker {
                if waker.will_wake(cx.waker()) {
                    return Poll::Pending;
                }
            }
            state.waker = Some(cx.waker().clone());
            Poll::Pending
        })
        .await
    }

    /// Runs every direct listener against `event`, in registration order.
    ///
    /// `state` must hold the lock on entry (it is unwrapped). When there are direct
    /// listeners, the guard is released (set to `None`) before any of them is called,
    /// so the lock is not held while listener code runs. Returns `true` as soon as one
    /// listener's future resolves to `true`, `false` otherwise.
    async fn run_direct_listeners(
        &self,
        state: &mut Option<MutexGuard<'_, State<T>, TS>>,
        event: &mut T::Unique<'_>,
    ) -> bool {
        /// Puts the taken direct listeners back into the handler when dropped:
        /// on normal completion, early return, unwinding, or cancellation.
        struct RestoreDirects<'a, T: Event, TS: ThreadSafety> {
            state: &'a Handler<T, TS>,
            directs: Vec<DirectListener<T>>,
        }

        impl<T: Event, TS: ThreadSafety> Drop for RestoreDirects<'_, T, TS> {
            fn drop(&mut self) {
                let restored = mem::take(&mut self.directs);
                let mut state = self.state.state().lock().unwrap();
                // Listeners registered while this batch ran were pushed onto the
                // (then empty) vector. Put the original batch back first so the
                // overall registration order is preserved.
                let mut added_meanwhile = mem::replace(&mut state.directs, restored);
                state.directs.append(&mut added_meanwhile);
            }
        }

        let state_ref = state.as_mut().unwrap();
        if state_ref.directs.is_empty() {
            return false;
        }
        let mut directs = RestoreDirects {
            directs: mem::take(&mut state_ref.directs),
            state: self,
        };
        // Release the lock before running listener code.
        *state = None;
        for direct in &mut directs.directs {
            if direct(event).await {
                return true;
            }
        }
        false
    }

    /// Creates a `Waiter` appended to the end of this handler's listener list.
    pub fn wait(&self) -> Waiter<'_, T, TS> {
        Waiter::new(self)
    }

    /// Registers a direct listener that receives mutable access to each event.
    ///
    /// Direct listeners run before any `Waiter` is notified. If the returned future
    /// resolves to `true`, the event is not passed to later direct listeners or to
    /// waiters.
    pub fn wait_direct_async<
        Fut: Future<Output = bool> + Send + 'static,
        F: FnMut(&mut T::Unique<'_>) -> Fut + Send + 'static,
    >(
        &self,
        mut f: F,
    ) {
        let mut state = self.state().lock().unwrap();
        state.directs.push(Box::new(move |u| Box::pin(f(u))))
    }

    /// Synchronous form of `wait_direct_async`: `f` returns `true` to stop dispatch.
    pub fn wait_direct(&self, mut f: impl FnMut(&mut T::Unique<'_>) -> bool + Send + 'static) {
        self.wait_direct_async(move |u| std::future::ready(f(u)))
    }

    /// Returns the shared state, creating it on first access.
    fn state(&self) -> &TS::Mutex<State<T>> {
        self.state
            .get_or_init(|| Box::new(TS::Mutex::new(State::new())))
    }
}
impl<T: Event, TS: ThreadSafety> Unpin for Handler<T, TS> {}
impl<'a, T: Event, TS: ThreadSafety> IntoFuture for &'a Handler<T, TS> {
type IntoFuture = Waiter<'a, T, TS>;
type Output = T::Clonable;
fn into_future(self) -> Self::IntoFuture {
self.wait()
}
}
/// A listener registered on a `Handler`.
///
/// It can be awaited once as a `Future` or polled repeatedly as a `Stream` of
/// events. Its node is unlinked from the handler's list when it is dropped.
pub struct Waiter<'a, T, TS>
where
    T: Event,
    TS: ThreadSafety,
{
    /// Handler this waiter is registered with.
    handler: &'a Handler<T, TS>,
    /// Slab key of this waiter's node in the handler's listener list.
    index: usize,
}

impl<T, TS> Unpin for Waiter<'_, T, TS>
where
    T: Event,
    TS: ThreadSafety,
{
}
impl<'a, T: Event, TS: ThreadSafety> Waiter<'a, T, TS> {
pub(crate) fn new(handler: &'a Handler<T, TS>) -> Self {
let state = handler.state();
let index = state.lock().unwrap().insert();
Self { handler, index }
}
fn notify_next(&mut self, mut state: MutexGuard<'_, State<T>, TS>) {
if let Some(next) = state.listeners[self.index].next.get() {
if let Some(waker) = state.notify(next) {
waker.wake();
}
} else {
state.instance = None;
if let Some(waker) = state.waker.take() {
waker.wake();
}
}
}
pub async fn hold(&mut self) -> HoldGuard<'_, 'a, T, TS> {
let event = future::poll_fn(|cx| {
let mut state = self.handler.state().lock().unwrap();
if state.take_notification(self.index) {
let event = match state.instance.clone() {
Some(event) => event,
None => return Poll::Pending,
};
return Poll::Ready(event);
}
state.register_waker(self.index, cx.waker());
Poll::Pending
})
.await;
HoldGuard {
waiter: self,
event: Some(event),
}
}
}
impl<T: Event, TS: ThreadSafety> Future for Waiter<'_, T, TS> {
type Output = T::Clonable;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.poll_next(cx) {
Poll::Ready(Some(event)) => Poll::Ready(event),
Poll::Ready(None) => panic!("event handler was dropped"),
Poll::Pending => Poll::Pending,
}
}
}
impl<T: Event, TS: ThreadSafety> Stream for Waiter<'_, T, TS> {
    type Item = T::Clonable;

    /// Yields the current event once this waiter has been notified, passing the
    /// notification on to the next listener in the same step. Never yields `None`.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let handler = self.handler;
        let index = self.index;
        let mut state = handler.state.get().unwrap().lock().unwrap();

        if !state.take_notification(index) {
            state.register_waker(index, cx.waker());
            return Poll::Pending;
        }

        let snapshot = state.instance.clone();
        match snapshot {
            Some(event) => {
                self.notify_next(state);
                Poll::Ready(Some(event))
            }
            None => Poll::Pending,
        }
    }

    /// The stream is unbounded.
    fn size_hint(&self) -> (usize, Option<usize>) {
        (usize::MAX, None)
    }
}
impl<'a, T: Event, TS: ThreadSafety> Drop for Waiter<'a, T, TS> {
    /// Unlinks this waiter from the handler's listener list.
    ///
    /// If the waiter had been notified but never consumed that notification, the
    /// notification is forwarded so the chain does not stall. It goes to the removed
    /// node's successor; if there is none, the event snapshot is cleared and the
    /// dispatching task's waker is woken.
    fn drop(&mut self) {
        let mut state = self.handler.state().lock().unwrap();
        let listener = state.remove(self.index);
        if listener.notified.get() {
            // `self.index` is no longer present in the slab, so `notify_next` (which
            // indexes it) must not be used here. Use the removed node's own link.
            match listener.next.get() {
                Some(next) => {
                    if let Some(waker) = state.notify(next) {
                        waker.wake();
                    }
                }
                None => {
                    state.instance = None;
                    if let Some(waker) = state.waker.take() {
                        waker.wake();
                    }
                }
            }
        }
    }
}
/// Guard returned by `Waiter::hold`.
///
/// It gives access to the held event and defers passing the notification on to
/// the next listener until the guard is dropped.
pub struct HoldGuard<'waiter, 'handler, T, TS>
where
    T: Event,
    TS: ThreadSafety,
{
    /// Waiter whose notification is being held.
    waiter: &'waiter mut Waiter<'handler, T, TS>,
    /// Held event; becomes `None` only when `into_inner` moves it out.
    event: Option<T::Clonable>,
}
impl<T, TS> Deref for HoldGuard<'_, '_, T, TS>
where
    T: Event,
    TS: ThreadSafety,
{
    type Target = T::Clonable;

    /// Shared access to the held event.
    fn deref(&self) -> &T::Clonable {
        let held = self.event.as_ref();
        held.unwrap()
    }
}

impl<T, TS> DerefMut for HoldGuard<'_, '_, T, TS>
where
    T: Event,
    TS: ThreadSafety,
{
    /// Mutable access to the held event.
    fn deref_mut(&mut self) -> &mut T::Clonable {
        let held = self.event.as_mut();
        held.unwrap()
    }
}
impl<T: Event, TS: ThreadSafety> HoldGuard<'_, '_, T, TS> {
    /// Moves the event out of the guard.
    ///
    /// The guard is still dropped afterwards, so the notification is passed on as usual.
    pub fn into_inner(mut self) -> T::Clonable {
        let event = self.event.take();
        event.unwrap()
    }
}

impl<T: Event, TS: ThreadSafety> Drop for HoldGuard<'_, '_, T, TS> {
    /// Releases the hold by passing the notification on from the held waiter.
    fn drop(&mut self) {
        let handler = self.waiter.handler;
        let guard = handler.state().lock().unwrap();
        self.waiter.notify_next(guard);
    }
}
impl<T: Event> State<T> {
fn new() -> Self {
Self {
listeners: Slab::new(),
directs: Vec::new(),
head_and_tail: None,
waker: None,
instance: None,
}
}
fn insert(&mut self) -> usize {
let listener = Listener {
next: Cell::new(None),
prev: Cell::new(self.head_and_tail.map(|(_, tail)| tail)),
waker: Cell::new(None),
notified: Cell::new(false),
};
let index = self.listeners.insert(listener);
match &mut self.head_and_tail {
Some((_head, tail)) => {
self.listeners[*tail].next.set(Some(index));
*tail = index;
}
None => {
self.head_and_tail = Some((index, index));
}
}
index
}
fn remove(&mut self, index: usize) -> Listener {
let listener = self.listeners.remove(index);
match &mut self.head_and_tail {
Some((head, tail)) => {
if *head == index && *tail == index {
self.head_and_tail = None;
} else if *head == index {
self.head_and_tail = Some((listener.next.get().unwrap(), *tail));
} else if *tail == index {
self.head_and_tail = Some((*head, listener.prev.get().unwrap()));
}
}
None => panic!("invalid listener list: head and tail are both None"),
}
if let Some(next) = listener.next.get() {
self.listeners[next].prev.set(listener.prev.get());
}
if let Some(prev) = listener.prev.get() {
self.listeners[prev].next.set(listener.next.get());
}
listener
}
fn take_notification(&mut self, index: usize) -> bool {
self.listeners[index].notified.replace(false)
}
fn register_waker(&mut self, index: usize, waker: &Waker) {
let listener = &mut self.listeners[index];
let current_waker = listener.waker.take();
match current_waker {
Some(current_waker) if current_waker.will_wake(waker) => {
listener.waker.replace(Some(current_waker));
}
_ => {
listener.waker.replace(Some(waker.clone()));
}
}
}
fn notify(&mut self, index: usize) -> Option<Waker> {
if self.listeners[index].notified.replace(true) {
return None;
}
self.listeners[index].waker.replace(None)
}
}
/// Node of the handler's doubly linked listener list, stored in a `Slab` and
/// addressed by slab key.
struct Listener {
    /// Slab key of the following listener, if any.
    next: Cell<Option<usize>>,
    /// Slab key of the preceding listener, if any.
    prev: Cell<Option<usize>>,
    /// Waker registered by the task polling this listener.
    waker: Cell<Option<Waker>>,
    /// Set when the notification reaches this listener; cleared when it is consumed.
    notified: Cell<bool>,
}
/// Describes the two forms in which an event is delivered.
///
/// Direct listeners get mutable access to a `Unique` value. Waiters get a
/// `Clonable` snapshot produced by [`Event::downgrade`].
pub trait Event {
    /// Snapshot handed to waiters.
    type Clonable: Clone + 'static;
    /// Form of the event that direct listeners may mutate.
    type Unique<'a>: 'a;

    /// Produces the clonable snapshot of a unique event.
    fn downgrade(unique: &mut Self::Unique<'_>) -> Self::Clonable;
}

/// Every `Clone + 'static` type is its own event: both forms are `T`, and
/// downgrading is a clone.
impl<T> Event for T
where
    T: Clone + 'static,
{
    type Clonable = T;
    type Unique<'a> = T;

    fn downgrade(unique: &mut Self::Unique<'_>) -> Self::Clonable {
        Clone::clone(&*unique)
    }
}
/// Invokes the wrapped closure when the value is dropped.
struct CallOnDrop<F>(F)
where
    F: FnMut();

impl<F> Drop for CallOnDrop<F>
where
    F: FnMut(),
{
    fn drop(&mut self) {
        let Self(callback) = self;
        callback();
    }
}