os_trait/
notifier_impls.rs

1use crate::{fugit::MicrosDurationU32, notifier::*, *};
2use core::{
3    marker::PhantomData,
4    sync::atomic::{AtomicBool, Ordering},
5};
6
/// A stand-in notifier/waiter that performs no signalling at all.
///
/// The same unit type plays both roles: its [`Notifier`] and [`NotifyWaiter`]
/// implementations simply report success.
#[derive(Clone, Default)]
pub struct FakeNotifier;
9
10impl NotifyBuilder for FakeNotifier {
11    fn build() -> (impl Notifier, impl NotifyWaiter) {
12        (Self {}, Self {})
13    }
14}
15
16impl Notifier for FakeNotifier {
17    fn notify(&self) -> bool {
18        true
19    }
20}
21
22impl NotifyWaiter for FakeNotifier {
23    fn wait(&self, _timeout: MicrosDurationU32) -> bool {
24        true
25    }
26}
27
28// ------------------------------------------------------------------
29
/// A [`Notifier`] built on a shared atomic flag, intended for unit tests.
///
/// `OS` is carried only at the type level (via [`PhantomData`]); no value of
/// that type is stored.
pub struct AtomicNotifier<OS> {
    /// Flag shared with the paired waiter; `true` means a notification is pending.
    flag: Arc<AtomicBool>,
    /// Zero-sized marker binding this notifier to its `OS` type parameter.
    _os: PhantomData<OS>,
}
35
36impl<OS: OsInterface> Clone for AtomicNotifier<OS> {
37    fn clone(&self) -> Self {
38        Self {
39            flag: Arc::clone(&self.flag),
40            _os: PhantomData,
41        }
42    }
43}
44
45impl<OS: OsInterface> AtomicNotifier<OS> {
46    pub fn new() -> (Self, impl NotifyWaiter) {
47        let s = Self {
48            flag: Arc::new(AtomicBool::new(false)),
49            _os: PhantomData,
50        };
51        let r = AtomicNotifyReceiver::<OS> {
52            flag: Arc::clone(&s.flag),
53            _os: PhantomData,
54        };
55        (s, r)
56    }
57}
58
59impl<OS: OsInterface> NotifyBuilder for AtomicNotifier<OS> {
60    fn build() -> (impl Notifier, impl NotifyWaiter) {
61        Self::new()
62    }
63}
64
65impl<OS: OsInterface> Notifier for AtomicNotifier<OS> {
66    fn notify(&self) -> bool {
67        self.flag.store(true, Ordering::Release);
68        true
69    }
70}
71
/// Waiting half of an [`AtomicNotifier`] pair.
///
/// Holds a handle to the same atomic flag the notifier sets.
pub struct AtomicNotifyReceiver<OS> {
    /// Flag shared with the notifier; `true` means a notification is pending.
    flag: Arc<AtomicBool>,
    /// Zero-sized marker binding this receiver to its `OS` type parameter.
    _os: PhantomData<OS>,
}
76
77impl<OS: OsInterface> NotifyWaiter for AtomicNotifyReceiver<OS> {
78    fn wait(&self, timeout: MicrosDurationU32) -> bool {
79        let mut t = OS::Timeout::start_us(timeout.to_micros());
80        while !t.timeout() {
81            if self
82                .flag
83                .compare_exchange(true, false, Ordering::SeqCst, Ordering::Acquire)
84                .is_ok()
85            {
86                return true;
87            }
88            OS::yield_thread();
89        }
90        false
91    }
92}
93
94// ------------------------------------------------------------------
95
96#[cfg(feature = "std")]
97pub use std_impl::*;
98#[cfg(feature = "std")]
99mod std_impl {
100    use super::*;
101    use std::sync::{
102        Arc,
103        atomic::{AtomicBool, Ordering},
104    };
105    use std::time::Instant;
106
107    /// This implementation is only for unit testing.
108    #[derive(Clone)]
109    pub struct StdNotifier {
110        flag: Arc<AtomicBool>,
111    }
112
113    impl StdNotifier {
114        pub fn new() -> (Self, StdNotifyWaiter) {
115            let s = Self {
116                flag: Arc::new(AtomicBool::new(false)),
117            };
118            let r = StdNotifyWaiter {
119                flag: Arc::clone(&s.flag),
120            };
121            (s, r)
122        }
123    }
124
125    impl NotifyBuilder for StdNotifier {
126        fn build() -> (impl Notifier, impl NotifyWaiter) {
127            Self::new()
128        }
129    }
130
131    impl Notifier for StdNotifier {
132        fn notify(&self) -> bool {
133            self.flag.store(true, Ordering::Release);
134            true
135        }
136    }
137
138    /// This implementation is only for unit testing.
139    pub struct StdNotifyWaiter {
140        flag: Arc<AtomicBool>,
141    }
142
143    impl NotifyWaiter for StdNotifyWaiter {
144        fn wait(&self, timeout: MicrosDurationU32) -> bool {
145            let now = Instant::now();
146            while now.elapsed().as_micros() < timeout.to_micros().into() {
147                if self
148                    .flag
149                    .compare_exchange(true, false, Ordering::SeqCst, Ordering::Acquire)
150                    .is_ok()
151                {
152                    return true;
153                }
154                std::thread::yield_now();
155            }
156            false
157        }
158    }
159
160    #[cfg(test)]
161    mod tests {
162        use super::*;
163        use fugit::ExtU32;
164        use std::thread;
165
166        #[test]
167        fn notify() {
168            let (n, w) = StdNotifier::new();
169            assert!(!w.wait(1.millis()));
170            n.notify();
171            assert!(w.wait(1.millis()));
172
173            let mut handles = vec![];
174
175            let n2 = n.clone();
176
177            handles.push(thread::spawn(move || {
178                assert!(w.wait(1000.millis()));
179                assert!(w.wait(1000.millis()));
180            }));
181
182            handles.push(thread::spawn(move || {
183                assert!(n.notify());
184            }));
185
186            handles.push(thread::spawn(move || {
187                assert!(n2.notify());
188            }));
189
190            for h in handles {
191                h.join().unwrap();
192            }
193        }
194    }
195}