1use std::sync::atomic::{AtomicBool, Ordering};
22use std::sync::Arc;
23use std::thread::Thread;
24
25#[cfg(feature = "async")]
26use tokio::sync::Notify;
27
/// Something that can nudge a waiting consumer (a parked thread, an async
/// task, a callback, ...) awake.
///
/// Wakers must be `Send + Sync` and are usually passed around as
/// `Box<dyn EventLoopWaker>`; because `Clone` is not object-safe, cloning a
/// boxed waker goes through [`EventLoopWaker::clone_box`].
pub trait EventLoopWaker: Send + Sync {
    /// Delivers a wake-up signal to the target.
    ///
    /// The implementations in this module do nothing once `is_valid` returns
    /// `false`.
    fn wake(&self);

    /// Reports whether this waker should still deliver wake-ups.
    fn is_valid(&self) -> bool;

    /// Produces a boxed copy of `self` (an object-safe stand-in for `Clone`).
    fn clone_box(&self) -> Box<dyn EventLoopWaker>;
}

/// `Clone` for boxed wakers, delegating to [`EventLoopWaker::clone_box`].
impl Clone for Box<dyn EventLoopWaker> {
    fn clone(&self) -> Self {
        // Deref all the way to the trait object so the call dispatches through
        // the vtable rather than ever resolving on the `Box` itself.
        (**self).clone_box()
    }
}
54
55#[derive(Debug, Clone)]
60pub struct ThreadWaker {
61 thread: Thread,
62 valid: Arc<AtomicBool>,
63}
64
65impl ThreadWaker {
66 pub fn current() -> Self {
68 Self {
69 thread: std::thread::current(),
70 valid: Arc::new(AtomicBool::new(true)),
71 }
72 }
73
74 pub fn new(thread: Thread) -> Self {
76 Self {
77 thread,
78 valid: Arc::new(AtomicBool::new(true)),
79 }
80 }
81
82 pub fn invalidate(&self) {
86 self.valid.store(false, Ordering::SeqCst);
87 }
88}
89
90impl EventLoopWaker for ThreadWaker {
91 fn wake(&self) {
92 if self.is_valid() {
93 self.thread.unpark();
94 }
95 }
96
97 fn is_valid(&self) -> bool {
98 self.valid.load(Ordering::SeqCst)
99 }
100
101 fn clone_box(&self) -> Box<dyn EventLoopWaker> {
102 Box::new(self.clone())
103 }
104}
105
106pub struct CallbackWaker<F>
110where
111 F: Fn() + Send + Sync + Clone + 'static,
112{
113 callback: F,
114 valid: Arc<AtomicBool>,
115}
116
117impl<F> CallbackWaker<F>
118where
119 F: Fn() + Send + Sync + Clone + 'static,
120{
121 pub fn new(callback: F) -> Self {
123 Self {
124 callback,
125 valid: Arc::new(AtomicBool::new(true)),
126 }
127 }
128
129 pub fn invalidate(&self) {
131 self.valid.store(false, Ordering::SeqCst);
132 }
133}
134
135impl<F> Clone for CallbackWaker<F>
136where
137 F: Fn() + Send + Sync + Clone + 'static,
138{
139 fn clone(&self) -> Self {
140 Self {
141 callback: self.callback.clone(),
142 valid: Arc::clone(&self.valid),
143 }
144 }
145}
146
147impl<F> EventLoopWaker for CallbackWaker<F>
148where
149 F: Fn() + Send + Sync + Clone + 'static,
150{
151 fn wake(&self) {
152 if self.is_valid() {
153 (self.callback)();
154 }
155 }
156
157 fn is_valid(&self) -> bool {
158 self.valid.load(Ordering::SeqCst)
159 }
160
161 fn clone_box(&self) -> Box<dyn EventLoopWaker> {
162 Box::new(self.clone())
163 }
164}
165
/// Wakes an async task through a shared `tokio::sync::Notify`.
///
/// Clones share both the `Notify` and the validity flag.
#[cfg(feature = "async")]
#[derive(Debug, Clone)]
pub struct TokioWaker {
    /// Signalled by `wake`; awaited via `notified`.
    notify: Arc<Notify>,
    /// Shared across clones; `false` once `invalidate` has been called.
    valid: Arc<AtomicBool>,
}

#[cfg(feature = "async")]
impl TokioWaker {
    /// Creates a valid waker backed by a fresh `Notify`.
    pub fn new() -> Self {
        let notify = Arc::new(Notify::new());
        let valid = Arc::new(AtomicBool::new(true));
        Self { notify, valid }
    }

    /// Returns the future produced by `Notify::notified` on the shared
    /// `Notify`; it is signalled by `wake` (via `notify_one`, see tokio docs
    /// for permit semantics).
    pub fn notified(&self) -> tokio::sync::futures::Notified<'_> {
        self.notify.notified()
    }

    /// Disables this waker and every clone of it.
    pub fn invalidate(&self) {
        self.valid.store(false, Ordering::SeqCst);
    }
}

#[cfg(feature = "async")]
impl Default for TokioWaker {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(feature = "async")]
impl EventLoopWaker for TokioWaker {
    fn wake(&self) {
        // Once invalidated, wake-ups are silently dropped.
        if !self.is_valid() {
            return;
        }
        self.notify.notify_one();
    }

    fn is_valid(&self) -> bool {
        self.valid.load(Ordering::SeqCst)
    }

    fn clone_box(&self) -> Box<dyn EventLoopWaker> {
        Box::new(Self::clone(self))
    }
}
225
226#[derive(Clone, Default)]
230pub struct BroadcastWaker {
231 wakers: Vec<Box<dyn EventLoopWaker>>,
232}
233
234impl BroadcastWaker {
235 pub fn new() -> Self {
237 Self { wakers: Vec::new() }
238 }
239
240 pub fn add(&mut self, waker: Box<dyn EventLoopWaker>) {
242 self.wakers.push(waker);
243 }
244
245 pub fn cleanup(&mut self) {
247 self.wakers.retain(|w| w.is_valid());
248 }
249
250 pub fn len(&self) -> usize {
252 self.wakers.len()
253 }
254
255 pub fn is_empty(&self) -> bool {
257 self.wakers.is_empty()
258 }
259}
260
261impl EventLoopWaker for BroadcastWaker {
262 fn wake(&self) {
263 for waker in &self.wakers {
264 if waker.is_valid() {
265 waker.wake();
266 }
267 }
268 }
269
270 fn is_valid(&self) -> bool {
271 self.wakers.iter().any(|w| w.is_valid())
272 }
273
274 fn clone_box(&self) -> Box<dyn EventLoopWaker> {
275 Box::new(self.clone())
276 }
277}
278
279pub trait WakeableChannel {
281 fn set_waker(&mut self, waker: Box<dyn EventLoopWaker>);
286
287 fn clear_waker(&mut self);
289
290 fn waker(&self) -> Option<&dyn EventLoopWaker>;
292}
293
294pub struct WakeableWrapper<C> {
296 inner: C,
297 waker: Option<Box<dyn EventLoopWaker>>,
298}
299
300impl<C> WakeableWrapper<C> {
301 pub fn new(channel: C) -> Self {
303 Self {
304 inner: channel,
305 waker: None,
306 }
307 }
308
309 pub fn inner(&self) -> &C {
311 &self.inner
312 }
313
314 pub fn inner_mut(&mut self) -> &mut C {
316 &mut self.inner
317 }
318
319 pub fn into_inner(self) -> C {
321 self.inner
322 }
323
324 pub fn wake(&self) {
326 if let Some(ref waker) = self.waker {
327 if waker.is_valid() {
328 waker.wake();
329 }
330 }
331 }
332}
333
334impl<C> WakeableChannel for WakeableWrapper<C> {
335 fn set_waker(&mut self, waker: Box<dyn EventLoopWaker>) {
336 self.waker = Some(waker);
337 }
338
339 fn clear_waker(&mut self) {
340 self.waker = None;
341 }
342
343 fn waker(&self) -> Option<&dyn EventLoopWaker> {
344 self.waker.as_deref()
345 }
346}
347
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::AtomicUsize;

    /// Returns a wake callback that increments `counter` each time it runs.
    fn counting(counter: &Arc<AtomicUsize>) -> impl Fn() + Send + Sync + Clone + 'static {
        let counter = Arc::clone(counter);
        move || {
            counter.fetch_add(1, Ordering::SeqCst);
        }
    }

    #[test]
    fn test_thread_waker() {
        let waker = ThreadWaker::current();
        assert!(waker.is_valid());

        // Unparking the current (test) thread only leaves a park token behind.
        waker.wake();
        waker.invalidate();
        assert!(!waker.is_valid());
    }

    #[test]
    fn test_thread_waker_clones_share_validity() {
        let waker = ThreadWaker::current();
        let clone = waker.clone();
        let boxed = waker.clone_box();

        clone.invalidate();
        assert!(!waker.is_valid());
        assert!(!boxed.is_valid());
    }

    #[test]
    fn test_callback_waker() {
        let counter = Arc::new(AtomicUsize::new(0));
        let counter_clone = Arc::clone(&counter);

        let waker = CallbackWaker::new(move || {
            counter_clone.fetch_add(1, Ordering::SeqCst);
        });

        assert!(waker.is_valid());
        waker.wake();
        assert_eq!(counter.load(Ordering::SeqCst), 1);

        waker.wake();
        assert_eq!(counter.load(Ordering::SeqCst), 2);

        waker.invalidate();
        waker.wake();
        assert_eq!(counter.load(Ordering::SeqCst), 2);
    }

    #[test]
    fn test_callback_waker_clones_share_validity() {
        let counter = Arc::new(AtomicUsize::new(0));
        let waker = CallbackWaker::new(counting(&counter));
        let boxed = waker.clone_box();

        boxed.wake();
        assert_eq!(counter.load(Ordering::SeqCst), 1);

        // Invalidating the original must silence the boxed clone as well.
        waker.invalidate();
        boxed.wake();
        assert!(!boxed.is_valid());
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_broadcast_waker() {
        let counter1 = Arc::new(AtomicUsize::new(0));
        let counter2 = Arc::new(AtomicUsize::new(0));

        let c1 = Arc::clone(&counter1);
        let c2 = Arc::clone(&counter2);

        let mut broadcast = BroadcastWaker::new();
        broadcast.add(Box::new(CallbackWaker::new(move || {
            c1.fetch_add(1, Ordering::SeqCst);
        })));
        broadcast.add(Box::new(CallbackWaker::new(move || {
            c2.fetch_add(1, Ordering::SeqCst);
        })));

        assert_eq!(broadcast.len(), 2);
        assert!(broadcast.is_valid());

        broadcast.wake();
        assert_eq!(counter1.load(Ordering::SeqCst), 1);
        assert_eq!(counter2.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_broadcast_waker_cleanup() {
        let counter = Arc::new(AtomicUsize::new(0));
        let live = CallbackWaker::new(counting(&counter));
        let dead = CallbackWaker::new(counting(&counter));

        let mut broadcast = BroadcastWaker::new();
        assert!(broadcast.is_empty());
        // An empty broadcaster has no valid targets.
        assert!(!broadcast.is_valid());

        broadcast.add(Box::new(live.clone()));
        broadcast.add(Box::new(dead.clone()));
        dead.invalidate();

        broadcast.wake();
        assert_eq!(counter.load(Ordering::SeqCst), 1);
        assert!(broadcast.is_valid());

        broadcast.cleanup();
        assert_eq!(broadcast.len(), 1);

        live.invalidate();
        assert!(!broadcast.is_valid());
        broadcast.cleanup();
        assert!(broadcast.is_empty());
    }

    #[test]
    fn test_boxed_waker_clone() {
        let counter = Arc::new(AtomicUsize::new(0));
        let boxed: Box<dyn EventLoopWaker> = Box::new(CallbackWaker::new(counting(&counter)));
        let copy = boxed.clone();

        boxed.wake();
        copy.wake();
        assert_eq!(counter.load(Ordering::SeqCst), 2);
    }

    #[cfg(feature = "async")]
    #[tokio::test]
    async fn test_tokio_waker() {
        // Imported here so builds without the "async" feature don't warn
        // about an unused import.
        use std::time::Duration;

        let waker = TokioWaker::new();
        assert!(waker.is_valid());

        let waker_clone = waker.clone();
        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(10)).await;
            waker_clone.wake();
        });

        tokio::time::timeout(Duration::from_millis(100), waker.notified())
            .await
            .expect("Should be notified");
    }

    #[test]
    fn test_wakeable_wrapper() {
        struct DummyChannel;

        let mut wrapper = WakeableWrapper::new(DummyChannel);
        assert!(wrapper.waker().is_none());

        let counter = Arc::new(AtomicUsize::new(0));
        let c = Arc::clone(&counter);
        wrapper.set_waker(Box::new(CallbackWaker::new(move || {
            c.fetch_add(1, Ordering::SeqCst);
        })));

        assert!(wrapper.waker().is_some());
        wrapper.wake();
        assert_eq!(counter.load(Ordering::SeqCst), 1);

        wrapper.clear_waker();
        assert!(wrapper.waker().is_none());
    }

    #[test]
    fn test_wakeable_wrapper_skips_invalid_waker() {
        let counter = Arc::new(AtomicUsize::new(0));
        let waker = CallbackWaker::new(counting(&counter));

        let mut wrapper = WakeableWrapper::new(());
        wrapper.set_waker(Box::new(waker.clone()));
        waker.invalidate();

        wrapper.wake();
        assert_eq!(counter.load(Ordering::SeqCst), 0);
    }
}