os_trait/
notifier_impls.rs

1use crate::{notifier::*, *};
2use core::{
3    marker::PhantomData,
4    sync::atomic::{AtomicBool, Ordering},
5};
6
/// A notifier/waiter that performs no signalling at all.
///
/// Its [`Notifier::notify`] always reports success and its
/// [`NotifyWaiter::wait`] returns `true` immediately, ignoring the timeout.
#[derive(Clone, Debug)]
pub struct FakeNotifier;
9
10impl FakeNotifier {
11    pub fn new() -> (Self, Self) {
12        (Self {}, Self {})
13    }
14}
15
16impl Notifier for FakeNotifier {
17    fn notify(&self) -> bool {
18        true
19    }
20}
21
22impl<OS: OsInterface> NotifyWaiter<OS> for FakeNotifier {
23    fn wait(&self, _timeout: &Duration<OS>) -> bool {
24        true
25    }
26}
27
28// ------------------------------------------------------------------
29
/// A [`Notifier`] backed by a shared [`AtomicBool`], intended for unit tests.
///
/// Created together with its [`AtomicNotifyWaiter`] by [`AtomicNotifier::new`];
/// both halves hold the same flag.
pub struct AtomicNotifier<OS> {
    /// Raised by `notify`; shared with the paired waiter.
    flag: Arc<AtomicBool>,
    /// Ties the notifier to an OS implementation without storing one.
    _os: PhantomData<OS>,
}
35
36impl<OS: OsInterface> Clone for AtomicNotifier<OS> {
37    fn clone(&self) -> Self {
38        Self {
39            flag: Arc::clone(&self.flag),
40            _os: PhantomData,
41        }
42    }
43}
44
45impl<OS: OsInterface> AtomicNotifier<OS> {
46    pub fn new() -> (Self, AtomicNotifyWaiter<OS>) {
47        let s = Self {
48            flag: Arc::new(AtomicBool::new(false)),
49            _os: PhantomData,
50        };
51        let r = AtomicNotifyWaiter::<OS> {
52            flag: Arc::clone(&s.flag),
53            _os: PhantomData,
54        };
55        (s, r)
56    }
57}
58
59impl<OS: OsInterface> Notifier for AtomicNotifier<OS> {
60    fn notify(&self) -> bool {
61        self.flag.store(true, Ordering::Release);
62        true
63    }
64}
65
/// Waiting half of an [`AtomicNotifier`] pair.
///
/// Obtained from [`AtomicNotifier::new`]. Waiting polls the shared flag and
/// clears it when a notification is found.
pub struct AtomicNotifyWaiter<OS> {
    /// Shared with the notifier; `true` while a notification is pending.
    flag: Arc<AtomicBool>,
    /// Ties the waiter to an OS implementation without storing one.
    _os: PhantomData<OS>,
}
70
71impl<OS: OsInterface> NotifyWaiter<OS> for AtomicNotifyWaiter<OS> {
72    fn wait(&self, timeout: &Duration<OS>) -> bool {
73        let mut t = Timeout::<OS>::from_duration(timeout);
74        while !t.timeout() {
75            if self
76                .flag
77                .compare_exchange(true, false, Ordering::SeqCst, Ordering::Acquire)
78                .is_ok()
79            {
80                return true;
81            }
82            OS::yield_thread();
83        }
84        false
85    }
86}
87
88// ------------------------------------------------------------------
89
#[cfg(feature = "std")]
pub use std_impl::*;
#[cfg(feature = "std")]
mod std_impl {
    use super::*;
    use std::sync::{
        Arc,
        atomic::{AtomicBool, Ordering},
    };

    /// Notifying half of a [`StdNotifier`] / [`StdNotifyWaiter`] pair, backed by
    /// a shared [`AtomicBool`].
    ///
    /// This implementation is only for unit testing.
    #[derive(Clone, Debug)]
    pub struct StdNotifier {
        /// Raised by `notify`; shared with the paired waiter.
        flag: Arc<AtomicBool>,
    }

    impl StdNotifier {
        /// Creates a notifier and its paired waiter sharing one initially cleared flag.
        pub fn new() -> (Self, StdNotifyWaiter) {
            let s = Self {
                flag: Arc::new(AtomicBool::new(false)),
            };
            let r = StdNotifyWaiter {
                flag: Arc::clone(&s.flag),
            };
            (s, r)
        }
    }

    impl Notifier for StdNotifier {
        /// Raises the shared flag. Notifications not yet consumed by the waiter
        /// coalesce into one. Always returns `true`.
        fn notify(&self) -> bool {
            self.flag.store(true, Ordering::Release);
            true
        }
    }

    /// Waiting half of a [`StdNotifier`] pair.
    ///
    /// This implementation is only for unit testing.
    #[derive(Debug)]
    pub struct StdNotifyWaiter {
        /// Shared with the notifier; `true` while a notification is pending.
        flag: Arc<AtomicBool>,
    }

    impl NotifyWaiter<StdOs> for StdNotifyWaiter {
        /// Polls until a notification is consumed or `timeout` elapses.
        ///
        /// The flag is checked before the deadline on every iteration. A pending
        /// notification is therefore consumed even with a zero timeout, and one
        /// that arrives during the final back-off is not missed.
        ///
        /// Returns `true` if a notification was consumed, `false` on timeout.
        fn wait(&self, timeout: &Duration<StdOs>) -> bool {
            let mut t = Timeout::<StdOs>::from_duration(timeout);
            loop {
                // Atomically consume a pending notification (true -> false).
                if self
                    .flag
                    .compare_exchange(true, false, Ordering::SeqCst, Ordering::Acquire)
                    .is_ok()
                {
                    return true;
                }
                if t.timeout() {
                    return false;
                }
                // Back off briefly between polls rather than spinning flat out.
                std::thread::sleep(std::time::Duration::from_nanos(1));
            }
        }
    }

    #[cfg(test)]
    mod tests {
        use super::*;
        use std::{sync::mpsc, thread};
        type OsDuration = Duration<StdOs>;

        #[test]
        fn notify() {
            let (n, w) = StdNotifier::new();
            assert!(!w.wait(&OsDuration::from_millis(1)));
            n.notify();
            assert!(w.wait(&OsDuration::from_millis(1)));

            let mut handles = vec![];

            let n2 = n.clone();
            // Handshake: the second notifier fires only after the waiter has
            // consumed the first notification. Otherwise the two could coalesce
            // into one and the second 2 s wait would spuriously time out.
            let (consumed_tx, consumed_rx) = mpsc::channel();

            handles.push(thread::spawn(move || {
                assert!(w.wait(&OsDuration::from_millis(2000)));
                consumed_tx.send(()).unwrap();
                assert!(w.wait(&OsDuration::from_millis(2000)));

                let mut i = 0;
                assert_eq!(
                    w.wait_with(&OsDuration::from_millis(100), 4, || {
                        i += 1;
                        None::<()>
                    }),
                    None
                );
                assert_eq!(i, 5);
            }));

            handles.push(thread::spawn(move || {
                assert!(n.notify());
            }));

            handles.push(thread::spawn(move || {
                consumed_rx.recv().unwrap();
                assert!(n2.notify());
            }));

            for h in handles {
                h.join().unwrap();
            }
        }

        #[test]
        fn zero_timeout_consumes_pending_notification() {
            let (n, w) = StdNotifier::new();
            assert!(n.notify());
            assert!(w.wait(&OsDuration::from_millis(0)));
            // The notification was consumed, so nothing is pending any more.
            assert!(!w.wait(&OsDuration::from_millis(0)));
        }
    }
}