use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::thread::Thread;
#[cfg(feature = "async")]
use tokio::sync::Notify;
/// A handle that can nudge an event loop (or thread, or task) back into action.
///
/// Implementors must be shareable across threads (`Send + Sync`) and must be
/// able to duplicate themselves behind a trait object via
/// [`EventLoopWaker::clone_box`]. The implementations in this module ignore
/// `wake` once they have been invalidated.
pub trait EventLoopWaker: Send + Sync {
    /// Reports whether this waker is still usable.
    fn is_valid(&self) -> bool;

    /// Signals whatever this waker is attached to.
    fn wake(&self);

    /// Produces a boxed copy of this waker, which lets `Box<dyn EventLoopWaker>`
    /// implement `Clone`.
    fn clone_box(&self) -> Box<dyn EventLoopWaker>;
}
impl Clone for Box<dyn EventLoopWaker> {
fn clone(&self) -> Self {
self.clone_box()
}
}
/// A waker bound to a specific OS thread.
///
/// Clones share the same `Arc<AtomicBool>` validity flag, so invalidating any
/// clone invalidates all of them.
#[derive(Debug, Clone)]
pub struct ThreadWaker {
    thread: Thread,
    valid: Arc<AtomicBool>,
}

impl ThreadWaker {
    /// Creates a valid waker targeting the calling thread.
    pub fn current() -> Self {
        Self::new(std::thread::current())
    }

    /// Creates a valid waker targeting `thread`.
    pub fn new(thread: Thread) -> Self {
        let valid = Arc::new(AtomicBool::new(true));
        Self { thread, valid }
    }

    /// Clears the shared validity flag for this waker and every clone of it.
    pub fn invalidate(&self) {
        let flag: &AtomicBool = &self.valid;
        flag.store(false, Ordering::SeqCst);
    }
}
impl EventLoopWaker for ThreadWaker {
fn wake(&self) {
if self.is_valid() {
self.thread.unpark();
}
}
fn is_valid(&self) -> bool {
self.valid.load(Ordering::SeqCst)
}
fn clone_box(&self) -> Box<dyn EventLoopWaker> {
Box::new(self.clone())
}
}
/// A waker that runs a user-supplied callback.
///
/// Clones get a clone of the callback but share the same `Arc<AtomicBool>`
/// validity flag, so invalidating any clone invalidates all of them.
///
/// The struct itself places no bounds on `F`. The bounds
/// (`Fn() + Send + Sync + Clone + 'static`) live on the impls that actually
/// need them, per the Rust API guidelines.
#[derive(Clone)]
pub struct CallbackWaker<F> {
    callback: F,
    valid: Arc<AtomicBool>,
}

impl<F> CallbackWaker<F>
where
    F: Fn() + Send + Sync + Clone + 'static,
{
    /// Wraps `callback` in a waker that starts out valid.
    pub fn new(callback: F) -> Self {
        Self {
            callback,
            valid: Arc::new(AtomicBool::new(true)),
        }
    }

    /// Clears the shared validity flag for this waker and every clone of it.
    pub fn invalidate(&self) {
        self.valid.store(false, Ordering::SeqCst);
    }
}

// Closures are generally not `Debug`, so only the validity flag is shown.
impl<F> std::fmt::Debug for CallbackWaker<F> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CallbackWaker")
            .field("valid", &self.valid)
            .finish_non_exhaustive()
    }
}
impl<F> EventLoopWaker for CallbackWaker<F>
where
F: Fn() + Send + Sync + Clone + 'static,
{
fn wake(&self) {
if self.is_valid() {
(self.callback)();
}
}
fn is_valid(&self) -> bool {
self.valid.load(Ordering::SeqCst)
}
fn clone_box(&self) -> Box<dyn EventLoopWaker> {
Box::new(self.clone())
}
}
/// A waker backed by a shared tokio [`Notify`].
///
/// Clones share both the `Notify` and the validity flag.
#[cfg(feature = "async")]
#[derive(Debug, Clone)]
pub struct TokioWaker {
    notify: Arc<Notify>,
    valid: Arc<AtomicBool>,
}

#[cfg(feature = "async")]
impl TokioWaker {
    /// Creates a valid waker with its own fresh `Notify`.
    pub fn new() -> Self {
        let notify = Arc::new(Notify::new());
        let valid = Arc::new(AtomicBool::new(true));
        TokioWaker { notify, valid }
    }

    /// Returns the `notified()` future of the shared `Notify`.
    pub fn notified(&self) -> tokio::sync::futures::Notified<'_> {
        Notify::notified(&self.notify)
    }

    /// Clears the shared validity flag for this waker and every clone of it.
    pub fn invalidate(&self) {
        let flag: &AtomicBool = &self.valid;
        flag.store(false, Ordering::SeqCst);
    }
}

#[cfg(feature = "async")]
impl Default for TokioWaker {
    /// Same as [`TokioWaker::new`].
    fn default() -> Self {
        TokioWaker::new()
    }
}
#[cfg(feature = "async")]
impl EventLoopWaker for TokioWaker {
    /// Calls `notify_one` on the shared `Notify`, unless invalidated.
    fn wake(&self) {
        if !self.is_valid() {
            return;
        }
        self.notify.notify_one();
    }

    /// Reads the shared validity flag.
    fn is_valid(&self) -> bool {
        let flag: &AtomicBool = &self.valid;
        flag.load(Ordering::SeqCst)
    }

    /// Boxes a clone that shares this waker's `Notify` and validity flag.
    fn clone_box(&self) -> Box<dyn EventLoopWaker> {
        let copy: TokioWaker = Clone::clone(self);
        Box::new(copy)
    }
}
/// Fans a single wake-up out to every registered waker.
///
/// Cloning a `BroadcastWaker` clones each registered waker through
/// [`EventLoopWaker::clone_box`].
#[derive(Clone, Default)]
pub struct BroadcastWaker {
    wakers: Vec<Box<dyn EventLoopWaker>>,
}

impl BroadcastWaker {
    /// Creates a broadcaster with no registered wakers.
    pub fn new() -> Self {
        Self { wakers: Vec::new() }
    }

    /// Registers another waker to receive broadcasts.
    pub fn add(&mut self, waker: Box<dyn EventLoopWaker>) {
        self.wakers.push(waker);
    }

    /// Drops every registered waker whose `is_valid` returns `false`.
    pub fn cleanup(&mut self) {
        self.wakers.retain(|w| w.is_valid());
    }

    /// Number of registered wakers, valid or not.
    pub fn len(&self) -> usize {
        self.wakers.len()
    }

    /// Returns `true` when no wakers are registered.
    pub fn is_empty(&self) -> bool {
        self.wakers.is_empty()
    }
}

// Boxed trait objects are not `Debug`, so report only how many wakers are registered.
impl std::fmt::Debug for BroadcastWaker {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("BroadcastWaker")
            .field("len", &self.wakers.len())
            .finish_non_exhaustive()
    }
}
impl EventLoopWaker for BroadcastWaker {
fn wake(&self) {
for waker in &self.wakers {
if waker.is_valid() {
waker.wake();
}
}
}
fn is_valid(&self) -> bool {
self.wakers.iter().any(|w| w.is_valid())
}
fn clone_box(&self) -> Box<dyn EventLoopWaker> {
Box::new(self.clone())
}
}
/// A type that can hold an optional [`EventLoopWaker`].
pub trait WakeableChannel {
    /// Returns the currently attached waker, if any.
    fn waker(&self) -> Option<&dyn EventLoopWaker>;

    /// Attaches `waker` to this channel.
    fn set_waker(&mut self, waker: Box<dyn EventLoopWaker>);

    /// Detaches the current waker, if any.
    fn clear_waker(&mut self);
}
/// Pairs an arbitrary channel value `C` with an optional boxed [`EventLoopWaker`].
pub struct WakeableWrapper<C> {
    inner: C,
    waker: Option<Box<dyn EventLoopWaker>>,
}

impl<C> WakeableWrapper<C> {
    /// Wraps `channel` with no waker attached.
    pub fn new(channel: C) -> Self {
        Self {
            inner: channel,
            waker: None,
        }
    }

    /// Borrows the wrapped channel.
    pub fn inner(&self) -> &C {
        &self.inner
    }

    /// Mutably borrows the wrapped channel.
    pub fn inner_mut(&mut self) -> &mut C {
        &mut self.inner
    }

    /// Unwraps the channel, dropping any attached waker.
    pub fn into_inner(self) -> C {
        self.inner
    }

    /// Wakes the attached waker if one is set and it reports itself valid.
    /// Does nothing otherwise.
    pub fn wake(&self) {
        if let Some(waker) = self.waker.as_deref().filter(|w| w.is_valid()) {
            waker.wake();
        }
    }
}

// The boxed waker is not `Debug`, so only its presence is reported.
impl<C: std::fmt::Debug> std::fmt::Debug for WakeableWrapper<C> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("WakeableWrapper")
            .field("inner", &self.inner)
            .field("has_waker", &self.waker.is_some())
            .finish()
    }
}
impl<C> WakeableChannel for WakeableWrapper<C> {
    /// Stores `waker`, dropping whichever waker was attached before.
    fn set_waker(&mut self, waker: Box<dyn EventLoopWaker>) {
        self.waker = Some(waker);
    }

    /// Drops the attached waker, if any.
    fn clear_waker(&mut self) {
        drop(self.waker.take());
    }

    /// Borrows the attached waker as a trait object.
    fn waker(&self) -> Option<&dyn EventLoopWaker> {
        self.waker.as_deref()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::AtomicUsize;
    // Only the async test uses `Duration`; gating the import avoids an
    // unused-import warning when the `async` feature is disabled.
    #[cfg(feature = "async")]
    use std::time::Duration;

    #[test]
    fn test_thread_waker() {
        let waker = ThreadWaker::current();
        assert!(waker.is_valid());
        waker.wake();
        waker.invalidate();
        assert!(!waker.is_valid());
    }

    #[test]
    fn test_callback_waker() {
        let counter = Arc::new(AtomicUsize::new(0));
        let counter_clone = Arc::clone(&counter);
        let waker = CallbackWaker::new(move || {
            counter_clone.fetch_add(1, Ordering::SeqCst);
        });
        assert!(waker.is_valid());
        waker.wake();
        assert_eq!(counter.load(Ordering::SeqCst), 1);
        waker.wake();
        assert_eq!(counter.load(Ordering::SeqCst), 2);
        waker.invalidate();
        waker.wake();
        assert_eq!(counter.load(Ordering::SeqCst), 2);
    }

    #[test]
    fn test_boxed_clone_shares_validity() {
        let waker = ThreadWaker::current();
        let boxed: Box<dyn EventLoopWaker> = Box::new(waker.clone());
        let cloned = boxed.clone();
        waker.invalidate();
        assert!(!boxed.is_valid());
        assert!(!cloned.is_valid());
    }

    #[test]
    fn test_broadcast_waker() {
        let counter1 = Arc::new(AtomicUsize::new(0));
        let counter2 = Arc::new(AtomicUsize::new(0));
        let c1 = Arc::clone(&counter1);
        let c2 = Arc::clone(&counter2);
        let mut broadcast = BroadcastWaker::new();
        broadcast.add(Box::new(CallbackWaker::new(move || {
            c1.fetch_add(1, Ordering::SeqCst);
        })));
        broadcast.add(Box::new(CallbackWaker::new(move || {
            c2.fetch_add(1, Ordering::SeqCst);
        })));
        assert_eq!(broadcast.len(), 2);
        assert!(broadcast.is_valid());
        broadcast.wake();
        assert_eq!(counter1.load(Ordering::SeqCst), 1);
        assert_eq!(counter2.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_broadcast_cleanup() {
        let first = CallbackWaker::new(|| {});
        let second = CallbackWaker::new(|| {});
        let mut broadcast = BroadcastWaker::new();
        broadcast.add(Box::new(first.clone()));
        broadcast.add(Box::new(second.clone()));

        first.invalidate();
        assert!(broadcast.is_valid());
        broadcast.cleanup();
        assert_eq!(broadcast.len(), 1);

        second.invalidate();
        assert!(!broadcast.is_valid());
        broadcast.cleanup();
        assert!(broadcast.is_empty());
    }

    #[cfg(feature = "async")]
    #[tokio::test]
    async fn test_tokio_waker() {
        let waker = TokioWaker::new();
        assert!(waker.is_valid());
        let waker_clone = waker.clone();
        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(10)).await;
            waker_clone.wake();
        });
        tokio::time::timeout(Duration::from_millis(100), waker.notified())
            .await
            .expect("Should be notified");
    }

    #[test]
    fn test_wakeable_wrapper() {
        struct DummyChannel;
        let mut wrapper = WakeableWrapper::new(DummyChannel);
        assert!(wrapper.waker().is_none());
        let counter = Arc::new(AtomicUsize::new(0));
        let c = Arc::clone(&counter);
        wrapper.set_waker(Box::new(CallbackWaker::new(move || {
            c.fetch_add(1, Ordering::SeqCst);
        })));
        assert!(wrapper.waker().is_some());
        wrapper.wake();
        assert_eq!(counter.load(Ordering::SeqCst), 1);
        wrapper.clear_waker();
        assert!(wrapper.waker().is_none());
    }

    #[test]
    fn test_wakeable_wrapper_skips_invalid_waker() {
        let counter = Arc::new(AtomicUsize::new(0));
        let c = Arc::clone(&counter);
        let waker = CallbackWaker::new(move || {
            c.fetch_add(1, Ordering::SeqCst);
        });
        let mut wrapper = WakeableWrapper::new(());
        wrapper.set_waker(Box::new(waker.clone()));
        waker.invalidate();
        wrapper.wake();
        assert_eq!(counter.load(Ordering::SeqCst), 0);
    }
}