#![deny(unsafe_code)]
#![deny(rust_2018_idioms)]
use std::collections::HashMap;
use std::mem;
use std::pin::Pin;
use std::sync::{
atomic::{AtomicBool, AtomicUsize, Ordering},
Arc, Condvar, Mutex,
};
use std::task::{Context, Poll, Waker};
use std::time::Duration;
/// Creates a connected [`Trigger`]/[`Listener`] pair over fresh, untriggered shared state.
///
/// The returned listener uses id `0`; clones of it draw ids starting from `1`.
pub fn trigger() -> (Trigger, Listener) {
    let shared = Arc::new(Inner {
        complete: AtomicBool::new(false),
        tasks: Mutex::new(HashMap::new()),
        condvar: Condvar::new(),
        next_listener_id: AtomicUsize::new(1),
    });
    (
        Trigger {
            inner: Arc::clone(&shared),
        },
        Listener {
            inner: shared,
            id: 0,
        },
    )
}
/// The firing half of a [`trigger`] pair.
///
/// Clones share the same state (the derived `Clone` clones the `Arc`), so firing any
/// clone releases every [`Listener`] of the pair.
#[derive(Clone, Debug)]
pub struct Trigger {
    inner: Arc<Inner>,
}
/// The waiting half of a [`trigger`] pair.
///
/// Completes once the paired [`Trigger`] fires. It can be `.await`ed (it implements
/// `Future`) or waited on from a blocking thread. Cloning yields another listener on the
/// same state with a new `id`.
#[derive(Debug)]
pub struct Listener {
    inner: Arc<Inner>,
    // Key for this listener's waker in `Inner::tasks`: `0` for the listener returned by
    // `trigger()`, and a fresh value from `next_listener_id` for each clone.
    id: usize,
}
impl Clone for Listener {
fn clone(&self) -> Self {
Listener {
inner: self.inner.clone(),
id: self.inner.next_listener_id.fetch_add(1, Ordering::SeqCst),
}
}
}
impl Drop for Listener {
    /// Unregisters this listener's waker (if any) from the shared `tasks` map.
    ///
    /// Unlike the other methods, this tolerates a poisoned mutex instead of panicking.
    /// Panicking inside `drop` while the thread is already unwinding aborts the process.
    /// The mutex can be poisoned because `poll` clones a caller-supplied waker while
    /// holding the lock. The map is still a valid `HashMap` after such a panic, so
    /// removing this listener's entry remains meaningful.
    fn drop(&mut self) {
        // The guard is a temporary of this `let` statement and is released at its end.
        // The removed waker is bound to `_removed`, so its destructor (foreign code) runs
        // afterwards, outside the lock.
        let _removed = self
            .inner
            .tasks
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .remove(&self.id);
    }
}
/// State shared by a `Trigger` and all of its `Listener`s.
#[derive(Debug)]
struct Inner {
    // Whether the trigger has fired. In this file it is only ever written by
    // `Trigger::trigger`, which sets it to `true`; nothing resets it.
    complete: AtomicBool,
    // Waker stored by each pending `poll`, keyed by listener id. This mutex is also the
    // one `condvar` waits on in `Listener::wait`/`wait_timeout`.
    tasks: Mutex<HashMap<usize, Waker>>,
    // `notify_all`-ed when the trigger fires, to release blocked `wait`/`wait_timeout` calls.
    condvar: Condvar,
    // Next id handed out by `Listener::clone`. Starts at 1 because the listener returned
    // by `trigger()` uses id 0.
    next_listener_id: AtomicUsize,
}
// `Arc<Inner>` and `usize` are already `Unpin`, so the compiler would derive these anyway.
// Stating them explicitly makes `Unpin` a guarantee that survives future field changes.
// `Listener: Unpin` is what lets callers poll it via `Pin::new(&mut listener)`.
impl Unpin for Trigger {}
impl Unpin for Listener {}
impl Trigger {
    /// Fires the trigger, releasing every current and future listener of the pair.
    ///
    /// Only the first call has an effect; later calls return immediately. All stored
    /// wakers are removed and woken, and threads blocked in `wait`/`wait_timeout` are
    /// notified.
    ///
    /// # Panics
    /// Panics if the shared mutex is poisoned.
    pub fn trigger(&self) {
        let already_fired = self.inner.complete.swap(true, Ordering::SeqCst);
        if already_fired {
            return;
        }
        // The flag is set before the lock is taken. A concurrent `poll` or `wait` either
        // stored its waker (drained here) or will see the flag under the lock.
        // The guard is a statement temporary, so the lock is released right after the
        // map is emptied. Wakers run foreign code and are only woken once it is free.
        let pending = mem::take(
            &mut *self
                .inner
                .tasks
                .lock()
                .expect("Some Trigger/Listener has panicked"),
        );
        pending
            .into_iter()
            .for_each(|(_listener_id, waker)| waker.wake());
        self.inner.condvar.notify_all();
    }

    /// Returns `true` once any clone of this trigger has fired.
    pub fn is_triggered(&self) -> bool {
        self.inner.complete.load(Ordering::SeqCst)
    }
}
impl std::future::Future for Listener {
    type Output = ();

    /// Resolves once the paired trigger has fired.
    ///
    /// While pending, the waker from `cx` is stored under this listener's id. It replaces
    /// any waker stored by an earlier poll of the same listener.
    ///
    /// # Panics
    /// Panics if the shared mutex is poisoned.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let shared = &self.inner;
        let fired = || shared.complete.load(Ordering::SeqCst);
        if fired() {
            return Poll::Ready(());
        }
        let mut wakers = shared
            .tasks
            .lock()
            .expect("Some Trigger/Listener has panicked");
        // Re-check under the lock. The trigger sets the flag before it takes this lock to
        // drain the wakers, so either it sees the waker stored below or we see the flag.
        if fired() {
            return Poll::Ready(());
        }
        wakers.insert(self.id, cx.waker().clone());
        Poll::Pending
    }
}
impl Listener {
    /// Blocks the current thread until the paired trigger fires.
    ///
    /// Returns immediately if it already has.
    ///
    /// # Panics
    /// Panics if the shared mutex is poisoned.
    pub fn wait(&self) {
        if self.is_triggered() {
            return;
        }
        let guard = self
            .inner
            .tasks
            .lock()
            .expect("Some Trigger/Listener has panicked");
        // `wait_while` evaluates the predicate under the lock. Because the trigger sets the
        // flag before taking that lock, a concurrent trigger cannot be missed.
        let _reacquired = self
            .inner
            .condvar
            .wait_while(guard, |_| !self.is_triggered())
            .expect("Some Trigger/Listener has panicked");
    }

    /// Blocks for at most `duration` waiting for the paired trigger to fire.
    ///
    /// Returns `true` if it has fired (including before the call), `false` on timeout.
    ///
    /// # Panics
    /// Panics if the shared mutex is poisoned.
    pub fn wait_timeout(&self, duration: Duration) -> bool {
        if self.is_triggered() {
            return true;
        }
        let guard = self
            .inner
            .tasks
            .lock()
            .expect("Some Trigger/Listener has panicked");
        let (reacquired, _timeout) = self
            .inner
            .condvar
            .wait_timeout_while(guard, duration, |_| !self.is_triggered())
            .expect("Some Trigger/Listener has panicked");
        // Release the lock before the final check; only the flag matters from here on.
        drop(reacquired);
        self.is_triggered()
    }

    /// Returns `true` once the paired trigger has fired.
    pub fn is_triggered(&self) -> bool {
        self.inner.complete.load(Ordering::SeqCst)
    }
}
#[allow(unsafe_code)]
#[cfg(test)]
mod tests {
    use super::*;
    use std::future::Future;
    use std::sync::atomic::AtomicU8;
    use std::task::{RawWaker, RawWakerVTable};

    /// Re-polling with a different waker must replace (and drop) the previously stored one.
    #[test]
    fn polling_listener_keeps_only_last_waker() {
        let (_trigger, mut listener) = trigger();
        let (waker1, waker_handle1) = create_waker();
        {
            let mut context = Context::from_waker(&waker1);
            let listener = Pin::new(&mut listener);
            assert_eq!(listener.poll(&mut context), Poll::Pending);
        }
        assert!(waker_handle1.data.load(Ordering::SeqCst) & CLONED != 0);
        assert!(waker_handle1.data.load(Ordering::SeqCst) & DROPPED == 0);
        let (waker2, waker_handle2) = create_waker();
        {
            let mut context = Context::from_waker(&waker2);
            let listener = Pin::new(&mut listener);
            assert_eq!(listener.poll(&mut context), Poll::Pending);
        }
        assert!(waker_handle2.data.load(Ordering::SeqCst) & CLONED != 0);
        assert!(waker_handle2.data.load(Ordering::SeqCst) & DROPPED == 0);
        assert!(waker_handle1.data.load(Ordering::SeqCst) & DROPPED != 0);
    }

    /// Firing the trigger must wake the stored waker through `wake` (consuming it and
    /// releasing its reference); a later poll must complete immediately.
    #[test]
    fn trigger_wakes_and_releases_registered_waker() {
        let (trig, mut listener) = trigger();
        let (waker, waker_handle) = create_waker();
        {
            let mut context = Context::from_waker(&waker);
            let listener = Pin::new(&mut listener);
            assert_eq!(listener.poll(&mut context), Poll::Pending);
        }
        // Strong references: `waker_handle`, `waker`, and the clone stored by `poll`.
        assert_eq!(Arc::strong_count(&waker_handle), 3);
        trig.trigger();
        assert!(waker_handle.data.load(Ordering::SeqCst) & WOKE != 0);
        // `wake` consumed the stored clone, so its reference must have been released.
        assert_eq!(Arc::strong_count(&waker_handle), 2);
        let mut context = Context::from_waker(&waker);
        assert_eq!(Pin::new(&mut listener).poll(&mut context), Poll::Ready(()));
    }

    // Flag bits recorded in `WakerHandle::data` by the vtable functions below.
    const CLONED: u8 = 0b0001;
    const WOKE: u8 = 0b0010;
    const DROPPED: u8 = 0b0100;

    /// Builds a `Waker` backed by a `WakerHandle`. The returned `Arc` lets a test observe
    /// which vtable operations ran. The waker owns one strong reference of its own.
    fn create_waker() -> (Waker, Arc<WakerHandle>) {
        let waker_handle = Arc::new(WakerHandle {
            data: AtomicU8::new(0),
        });
        let data = Arc::into_raw(Arc::clone(&waker_handle)) as *const ();
        let raw_waker = RawWaker::new(data, &VTABLE);
        // SAFETY: `data` is an `Arc<WakerHandle>` pointer owning one strong reference, and
        // every `VTABLE` function below treats it exactly that way.
        (unsafe { Waker::from_raw(raw_waker) }, waker_handle)
    }

    struct WakerHandle {
        data: AtomicU8,
    }

    impl Drop for WakerHandle {
        fn drop(&mut self) {
            println!("WakerHandle dropped");
        }
    }

    const VTABLE: RawWakerVTable = RawWakerVTable::new(clone, wake, wake_by_ref, drop);

    // In every function below, `data` came from `Arc::into_raw::<WakerHandle>` and points
    // to a live allocation in which the calling waker owns one strong reference.

    /// Creates a new waker sharing the handle; it gets its own strong reference.
    unsafe fn clone(data: *const ()) -> RawWaker {
        let handle = data as *const WakerHandle;
        // SAFETY: `handle` is live (see above); incrementing gives the new waker its own
        // reference, which its `wake`/`drop` will release.
        unsafe {
            (*handle).data.fetch_or(CLONED, Ordering::SeqCst);
            Arc::increment_strong_count(handle);
        }
        RawWaker::new(data, &VTABLE)
    }

    /// Consumes the waker: records the wake-up, then releases the waker's reference, as
    /// the `RawWakerVTable` contract requires of `wake`.
    unsafe fn wake(data: *const ()) {
        // SAFETY: `data` is live (see above), and this waker's reference is released
        // exactly once, here.
        unsafe {
            wake_by_ref(data);
            Arc::decrement_strong_count(data as *const WakerHandle);
        }
    }

    /// Records the wake-up without consuming the waker (no reference-count change).
    unsafe fn wake_by_ref(data: *const ()) {
        // SAFETY: `data` is live (see above); only a shared borrow is taken.
        let handle = unsafe { &*(data as *const WakerHandle) };
        handle.data.fetch_or(WOKE, Ordering::SeqCst);
    }

    /// Drops the waker without waking it: records the drop and releases its reference.
    unsafe fn drop(data: *const ()) {
        let handle = data as *const WakerHandle;
        // SAFETY: `handle` is live (see above), and this waker's reference is released
        // exactly once, here.
        unsafe {
            (*handle).data.fetch_or(DROPPED, Ordering::SeqCst);
            Arc::decrement_strong_count(handle);
        }
    }
}