utils_atomics/
notify.rs

1use crate::{
2    locks::{lock, Lock},
3    FillQueue,
4};
5use alloc::sync::{Arc, Weak};
6
7/// Creates a new notifier and a listener to it.
8pub fn notify() -> (Notify, Listener) {
9    let inner = Arc::new(Inner {
10        wakers: FillQueue::new(),
11    });
12
13    let listener = Listener {
14        inner: Arc::downgrade(&inner),
15    };
16    return (Notify { inner }, listener);
17}
18
/// State shared between every [`Notify`] clone and the [`Listener`]s attached to them.
#[derive(Debug)]
struct Inner {
    /// Locks pushed by listeners that are waiting to be woken.
    wakers: FillQueue<Lock>,
}
23
/// Synchronous notifier. This structure can be used to block threads until desired,
/// at which point all waiting threads can be awakened with [`notify_all`](Notify::notify_all).
///
/// This structure drops loudly by default (i.e. it will awake blocked threads when dropped),
/// but can be dropped silently via [`silent_drop`](Notify::silent_drop).
#[derive(Debug, Clone)]
pub struct Notify {
    /// Strong handle to the shared queue of waiting locks.
    inner: Arc<Inner>,
}
33
/// Receiving half of a [`Notify`]: lets a thread block until the notifier fires.
///
/// Only a weak reference to the shared state is stored, which allows the listener
/// to detect that the notifier's state no longer exists.
#[derive(Debug, Clone)]
pub struct Listener {
    /// Weak handle to the notifier's queue of waiting locks.
    inner: Weak<Inner>,
}
38
39impl Notify {
40    pub unsafe fn into_raw(self) -> *const () {
41        Arc::into_raw(self.inner).cast()
42    }
43
44    pub unsafe fn from_raw(ptr: *const ()) -> Self {
45        Self {
46            inner: Arc::from_raw(ptr.cast()),
47        }
48    }
49
50    #[inline]
51    pub fn listeners(&self) -> usize {
52        return Arc::weak_count(&self.inner);
53    }
54
55    #[inline]
56    pub fn notify_all(&self) {
57        self.inner.wakers.chop().for_each(Lock::wake)
58    }
59
60    #[inline]
61    pub fn listen(&self) -> Listener {
62        return Listener {
63            inner: Arc::downgrade(&self.inner),
64        };
65    }
66
67    /// Drops the notifier without awaking blocked threads.
68    /// This method may leak memory.
69    #[inline]
70    pub fn silent_drop(self) {
71        if let Ok(mut inner) = Arc::try_unwrap(self.inner) {
72            inner.wakers.chop_mut().for_each(Lock::silent_drop);
73        }
74    }
75}
76
77impl Listener {
78    pub unsafe fn into_raw(self) -> *const () {
79        Weak::into_raw(self.inner).cast()
80    }
81
82    pub unsafe fn from_raw(ptr: *const ()) -> Self {
83        Self {
84            inner: Weak::from_raw(ptr.cast()),
85        }
86    }
87
88    #[inline]
89    pub fn listeners(&self) -> usize {
90        return Weak::weak_count(&self.inner);
91    }
92
93    #[inline]
94    pub fn recv(&self) {
95        let _: bool = self.try_recv();
96    }
97
98    #[inline]
99    pub fn try_recv(&self) -> bool {
100        if let Some(inner) = self.inner.upgrade() {
101            let (lock, sub) = lock();
102            inner.wakers.push(lock);
103            sub.wait();
104            return true;
105        }
106        return false;
107    }
108}
109
110cfg_if::cfg_if! {
111    if #[cfg(feature = "futures")] {
112        use futures::{FutureExt, Stream};
113        use crate::flag::mpsc::{AsyncFlag, AsyncSubscribe, async_flag};
114        use core::task::Poll;
115        use futures::stream::FusedStream;
116
117        /// Creates a new async notifier and a listener to it.
118        pub fn async_notify() -> (AsyncNotify, AsyncListener) {
119            let inner = Arc::new(AsyncInner {
120                wakers: FillQueue::new(),
121            });
122
123            let listener = AsyncListener {
124                inner: Some(Arc::downgrade(&inner)),
125                sub: None
126            };
127
128            return (AsyncNotify { inner }, listener);
129        }
130
        /// State shared between every [`AsyncNotify`] clone and the
        /// [`AsyncListener`]s attached to them.
        #[derive(Debug)]
        struct AsyncInner {
            /// Flags pushed by listeners that are waiting to be marked.
            wakers: FillQueue<AsyncFlag>,
        }
135
        /// Asynchronous notifier. This structure can be used to block tasks until desired,
        /// at which point all waiting tasks can be awakened with [`notify_all`](AsyncNotify::notify_all).
        ///
        /// This structure drops loudly by default (i.e. it will awake blocked tasks when dropped),
        /// but can be dropped silently via [`silent_drop`](AsyncNotify::silent_drop).
        #[derive(Debug, Clone)]
        pub struct AsyncNotify {
            /// Strong handle to the shared queue of waiting flags.
            inner: Arc<AsyncInner>,
        }
145
        /// Receiving half of an [`AsyncNotify`], exposed as a stream that yields `()`
        /// each time its pending subscription completes.
        #[derive(Debug)]
        pub struct AsyncListener {
            /// Weak handle to the notifier's flag queue; cleared once the stream has
            /// observed that the shared state is gone.
            inner: Option<Weak<AsyncInner>>,
            /// Subscription registered for the next notification, if one is pending.
            sub: Option<AsyncSubscribe>
        }
151
152        impl AsyncNotify {
153            pub unsafe fn into_raw(self) -> *const () {
154                Arc::into_raw(self.inner).cast()
155            }
156
157            pub unsafe fn from_raw(ptr: *const ()) -> Self {
158                Self {
159                    inner: Arc::from_raw(ptr.cast()),
160                }
161            }
162
163            #[inline]
164            pub fn listeners(&self) -> usize {
165                return Arc::weak_count(&self.inner);
166            }
167
168            #[inline]
169            pub fn notify_all(&self) {
170                self.inner.wakers.chop().for_each(AsyncFlag::mark)
171            }
172
173            #[inline]
174            pub fn listen(&self) -> AsyncListener {
175                return AsyncListener {
176                    inner: Some(Arc::downgrade(&self.inner)),
177                    sub: None
178                };
179            }
180
181            /// Drops the notifier without awaking blocked tasks.
182            /// This method may leak memory.
183            #[inline]
184            pub fn silent_drop (self) {
185                if let Ok(mut inner) = Arc::try_unwrap(self.inner) {
186                    inner.wakers.chop_mut().for_each(AsyncFlag::silent_drop);
187                }
188            }
189        }
190
191        impl AsyncListener {
192            #[inline]
193            pub fn listeners(&self) -> usize {
194                return match self.inner {
195                    Some(ref inner) => Weak::weak_count(inner),
196                    None => 0
197                }
198            }
199        }
200
201        impl Stream for AsyncListener {
202            type Item = ();
203
204            fn poll_next(mut self: core::pin::Pin<&mut Self>, cx: &mut core::task::Context<'_>) -> core::task::Poll<Option<Self::Item>> {
205                if let Some(ref mut sub) = self.sub {
206                    return match sub.poll_unpin(cx) {
207                        Poll::Ready(_) => {
208                            self.sub = None;
209                            Poll::Ready(Some(()))
210                        },
211                        Poll::Pending => Poll::Pending
212                    }
213                } else if let Some(inner) = self.inner.as_ref().and_then(Weak::upgrade) {
214                    let (flag, sub) = async_flag();
215                    inner.wakers.push(flag);
216                    self.sub = Some(sub);
217                    return self.poll_next(cx)
218                }
219
220                self.inner = None;
221                return core::task::Poll::Ready(None)
222            }
223
224            #[inline]
225            fn size_hint(&self) -> (usize, Option<usize>) {
226                match (&self.inner, &self.sub) {
227                    (None, None) => (0, Some(0)),
228                    (Some(inner), None) if inner.upgrade().is_none() => (0, Some(0)),
229                    (None, Some(_)) => (1, Some(1)),
230                    (Some(inner), Some(_)) if inner.upgrade().is_none() => (1, Some(1)),
231                    (Some(_), Some(_)) => (1, None),
232                    _ => (0, None)
233                }
234            }
235        }
236
237        impl FusedStream for AsyncListener {
238            #[inline]
239            fn is_terminated(&self) -> bool {
240                match (&self.inner, &self.sub) {
241                    (_, Some(_)) => false,
242                    (None, None) => true,
243                    (Some(inner), None) => inner.upgrade().is_none(),
244                }
245            }
246        }
247
248        impl Clone for AsyncListener {
249            #[inline]
250            fn clone(&self) -> Self {
251                return Self {
252                    inner: self.inner.clone(),
253                    sub: None
254                }
255            }
256        }
257    }
258}
259
260// Thanks ChatGPT!
#[cfg(all(feature = "std", test))]
mod tests {
    use super::notify;
    use std::{thread, time::Duration};

    /// One blocked listener is released by `notify_all`, and the listener count
    /// follows the creation and destruction of listeners.
    #[test]
    fn test_basic_functionality() {
        let (notifier, first) = notify();
        assert_eq!(notifier.listeners(), 1);

        let second = notifier.listen();
        assert_eq!(notifier.listeners(), 2);

        let waiter = thread::spawn(move || second.recv());

        // Give the spawned thread time to start waiting before notifying.
        thread::sleep(Duration::from_millis(100));
        notifier.notify_all();
        waiter.join().unwrap();

        assert_eq!(notifier.listeners(), 1);
        drop(first);
    }

    /// Ten threads blocked on clones of the same listener are all released by a
    /// single `notify_all`.
    #[test]
    fn test_multi_threaded() {
        use std::sync::{Arc, Barrier};

        let (notifier, listener) = notify();
        // Ten workers plus the main thread meet at the barrier.
        let barrier = Arc::new(Barrier::new(11));

        let workers: Vec<_> = (0..10)
            .map(|_| {
                let gate = Arc::clone(&barrier);
                let waiting = listener.clone();
                thread::spawn(move || {
                    gate.wait();
                    waiting.recv();
                })
            })
            .collect();

        barrier.wait();
        // Give the workers time to start waiting before notifying.
        thread::sleep(Duration::from_millis(100));
        notifier.notify_all();

        for worker in workers {
            worker.join().unwrap();
        }

        assert_eq!(listener.listeners(), 1);
    }
}
319
#[cfg(all(feature = "futures", test))]
mod async_tests {
    use crate::notify::async_notify;
    use core::time::Duration;
    use futures::stream::StreamExt;

    /// A single task awaiting the stream receives `Some(())` after `notify_all`.
    #[tokio::test]
    async fn test_basic_functionality_async_tokio() {
        let (notifier, first) = async_notify();
        assert_eq!(notifier.listeners(), 1);

        let mut second = notifier.listen();
        let task = tokio::spawn(async move {
            assert_eq!(second.next().await, Some(()));
        });

        // Give the spawned task time to register before notifying.
        tokio::time::sleep(Duration::from_millis(100)).await;
        notifier.notify_all();

        drop(first);
        task.await.unwrap();
        assert_eq!(notifier.listeners(), 0);
    }

    /// Ten tasks awaiting clones of the same listener are all released by a single
    /// `notify_all`.
    #[tokio::test]
    async fn test_multi_task_async_tokio() {
        let (notifier, listener) = async_notify();

        let tasks: Vec<_> = (0..10)
            .map(|_| {
                let mut waiting = listener.clone();
                tokio::spawn(async move {
                    assert_eq!(waiting.next().await, Some(()));
                })
            })
            .collect();

        drop(listener);
        // Give the spawned tasks time to register before notifying.
        tokio::time::sleep(Duration::from_millis(100)).await;
        notifier.notify_all();

        let _finished = futures::future::try_join_all(tasks).await.unwrap();
        assert_eq!(notifier.listeners(), 0);
    }
}