// os_trait/notifier_impls.rs

1use crate::{Duration, Timeout, notifier::*, prelude::*};
2use core::marker::PhantomData;
3use portable_atomic::{AtomicBool, Ordering};
4
/// Placeholder notifier for code paths that need a notifier/waiter pair but no
/// real signalling.
///
/// It is a zero-sized unit struct; the same type serves as both the notifying
/// and the waiting half of the pair.
#[derive(Clone)]
pub struct FakeNotifier;

impl FakeNotifier {
    /// Returns a `(notifier, waiter)` pair, both plain `FakeNotifier` values.
    pub fn new() -> (Self, Self) {
        let notifier = FakeNotifier;
        let waiter = FakeNotifier;
        (notifier, waiter)
    }
}
13
14impl NotifierInterface for FakeNotifier {
15    fn notify(&self) -> bool {
16        true
17    }
18}
19
20impl<OS: OsInterface> NotifyWaiterInterface<OS> for FakeNotifier {
21    fn wait(&self, _timeout: &Duration<OS>) -> bool {
22        true
23    }
24}
25
26// ------------------------------------------------------------------
27
28/// This [`NotifierInterface`] implementation is for unit test
29pub struct AtomicNotifier<OS> {
30    flag: Arc<AtomicBool>,
31    _os: PhantomData<OS>,
32}
33
34impl<OS: OsInterface> Clone for AtomicNotifier<OS> {
35    fn clone(&self) -> Self {
36        Self {
37            flag: Arc::clone(&self.flag),
38            _os: PhantomData,
39        }
40    }
41}
42
43impl<OS: OsInterface> AtomicNotifier<OS> {
44    pub fn new() -> (Self, AtomicNotifyWaiter<OS>) {
45        let s = Self {
46            flag: Arc::new(AtomicBool::new(false)),
47            _os: PhantomData,
48        };
49        let r = AtomicNotifyWaiter::<OS> {
50            flag: Arc::clone(&s.flag),
51            _os: PhantomData,
52        };
53        (s, r)
54    }
55}
56
57impl<OS: OsInterface> NotifierInterface for AtomicNotifier<OS> {
58    fn notify(&self) -> bool {
59        self.flag.store(true, Ordering::Release);
60        true
61    }
62}
63
64pub struct AtomicNotifyWaiter<OS> {
65    flag: Arc<AtomicBool>,
66    _os: PhantomData<OS>,
67}
68
69impl<OS: OsInterface> NotifyWaiterInterface<OS> for AtomicNotifyWaiter<OS> {
70    fn wait(&self, timeout: &Duration<OS>) -> bool {
71        let mut t = Timeout::<OS>::from(timeout);
72        loop {
73            if self.flag.swap(false, Ordering::AcqRel) {
74                return true;
75            } else if t.timeout() {
76                return false;
77            }
78            OS::yield_thread();
79        }
80    }
81}
82
83// ------------------------------------------------------------------
84
85#[cfg(feature = "std")]
86pub use std_impl::*;
#[cfg(feature = "std")]
mod std_impl {
    use super::*;
    use crate::os_impls::*;
    use std::sync::{
        Arc,
        atomic::{AtomicBool, Ordering},
    };

    /// Notifying half of a `std`-backed notifier pair created by [`StdNotifier::new`].
    ///
    /// Notifying sets a shared flag that the paired [`StdNotifyWaiter`] clears when it
    /// observes it. Clones share the same flag, so several notifications issued before
    /// the waiter consumes the flag collapse into one.
    ///
    /// This implementation is only for unit testing.
    #[derive(Clone)]
    pub struct StdNotifier {
        flag: Arc<AtomicBool>,
    }

    impl StdNotifier {
        /// Creates a connected `(notifier, waiter)` pair sharing one flag, initially unset.
        pub fn new() -> (Self, StdNotifyWaiter) {
            let s = Self {
                flag: Arc::new(AtomicBool::new(false)),
            };
            let r = StdNotifyWaiter {
                flag: Arc::clone(&s.flag),
            };
            (s, r)
        }
    }

    impl NotifierInterface for StdNotifier {
        /// Sets the shared flag with `Release` ordering. Always returns `true`.
        fn notify(&self) -> bool {
            self.flag.store(true, Ordering::Release);
            true
        }
    }

    /// Waiting half of the pair returned by [`StdNotifier::new`].
    ///
    /// This implementation is only for unit testing.
    pub struct StdNotifyWaiter {
        flag: Arc<AtomicBool>,
    }

    impl NotifyWaiterInterface<StdOs> for StdNotifyWaiter {
        /// Waits until a notification is pending or `timeout` elapses.
        ///
        /// Returns `true` (and clears the flag) if a notification was observed, and
        /// `false` on timeout. The flag is checked *before* the deadline on every pass,
        /// matching `AtomicNotifyWaiter::wait`. As a result, a notification that is
        /// already pending, or that arrives during the final sleep, is still reported
        /// even when the deadline has passed (e.g. a zero timeout).
        fn wait(&self, timeout: &Duration<StdOs>) -> bool {
            let mut t = Timeout::<StdOs>::from(timeout);
            loop {
                // Observing `true` and clearing it happen in one atomic step, so a
                // notification is consumed exactly once.
                if self
                    .flag
                    .compare_exchange(true, false, Ordering::SeqCst, Ordering::Acquire)
                    .is_ok()
                {
                    return true;
                }
                if t.timeout() {
                    return false;
                }
                std::thread::sleep(std::time::Duration::from_nanos(1));
            }
        }
    }

    #[cfg(test)]
    mod tests {
        use super::*;
        use std::thread;
        type OsDuration = Duration<StdOs>;

        #[test]
        fn notify() {
            let (n, w) = StdNotifier::new();
            assert!(!w.wait(&OsDuration::millis(1)));
            n.notify();
            assert!(w.wait(&OsDuration::millis(1)));

            let mut handles = vec![];

            let n2 = n.clone();

            // NOTE(review): the flag is a single bool, so if this waiter thread has not
            // consumed the first notification before `n2` fires (10 ms later), the two
            // notifications coalesce and the second 2 s wait fails. Timing-dependent
            // under heavy load.
            handles.push(thread::spawn(move || {
                assert!(w.wait(&OsDuration::millis(2000)));
                assert!(w.wait(&OsDuration::millis(2000)));

                let mut i = 0;
                assert_eq!(
                    w.wait_with(&OsDuration::millis(100), 4, || {
                        i += 1;
                        None::<()>
                    }),
                    None
                );
                assert_eq!(i, 5);
            }));

            handles.push(thread::spawn(move || {
                assert!(n.notify());
            }));

            handles.push(thread::spawn(move || {
                std::thread::sleep(std::time::Duration::from_millis(10));
                assert!(n2.notify());
            }));

            for h in handles {
                h.join().unwrap();
            }
        }

        #[test]
        fn pending_notification_is_seen_with_zero_timeout() {
            let (n, w) = StdNotifier::new();
            assert!(n.notify());
            // The flag is checked before the deadline, so an already-pending
            // notification is reported even with a zero timeout.
            assert!(w.wait(&OsDuration::millis(0)));
            // The previous wait consumed the notification.
            assert!(!w.wait(&OsDuration::millis(0)));
        }
    }
}