os_trait/
notifier_impls.rs

1use crate::{fugit::MicrosDurationU32, notifier::*, *};
2use core::{
3    marker::PhantomData,
4    sync::atomic::{AtomicBool, Ordering},
5};
6
/// A stateless stand-in for a notifier/waiter pair.
///
/// Its [`Notifier::notify`] and [`NotifyWaiter::wait`] implementations both
/// return `true` immediately, so the two values returned by
/// [`FakeNotifier::new`] are not actually connected to each other.
#[derive(Clone, Debug, Default)]
pub struct FakeNotifier;
9
10impl FakeNotifier {
11    pub fn new() -> (Self, Self) {
12        (Self {}, Self {})
13    }
14}
15
16impl Notifier for FakeNotifier {
17    fn notify(&self) -> bool {
18        true
19    }
20}
21
22impl NotifyWaiter for FakeNotifier {
23    fn wait(&self, _timeout: MicrosDurationU32) -> bool {
24        true
25    }
26}
27
28// ------------------------------------------------------------------
29
/// A [`Notifier`] backed by a shared [`AtomicBool`] flag, intended for unit tests.
///
/// It is created together with its [`AtomicNotifyWaiter`] by
/// [`AtomicNotifier::new`]; both halves hold the same flag.
pub struct AtomicNotifier<OS> {
    /// Flag shared with the waiter half; `notify` stores `true` into it.
    flag: Arc<AtomicBool>,
    /// Ties the notifier to an `OS` type without storing a value of it.
    _os: PhantomData<OS>,
}
35
36impl<OS: OsInterface> Clone for AtomicNotifier<OS> {
37    fn clone(&self) -> Self {
38        Self {
39            flag: Arc::clone(&self.flag),
40            _os: PhantomData,
41        }
42    }
43}
44
45impl<OS: OsInterface> AtomicNotifier<OS> {
46    pub fn new() -> (Self, AtomicNotifyWaiter<OS>) {
47        let s = Self {
48            flag: Arc::new(AtomicBool::new(false)),
49            _os: PhantomData,
50        };
51        let r = AtomicNotifyWaiter::<OS> {
52            flag: Arc::clone(&s.flag),
53            _os: PhantomData,
54        };
55        (s, r)
56    }
57}
58
59impl<OS: OsInterface> Notifier for AtomicNotifier<OS> {
60    fn notify(&self) -> bool {
61        self.flag.store(true, Ordering::Release);
62        true
63    }
64}
65
/// The waiting half created by [`AtomicNotifier::new`].
pub struct AtomicNotifyWaiter<OS> {
    /// Flag shared with the notifier half; `wait` resets it to `false`
    /// when it consumes a notification.
    flag: Arc<AtomicBool>,
    /// Selects the `OS` whose timeout and thread-yield functions `wait` uses.
    _os: PhantomData<OS>,
}
70
71impl<OS: OsInterface> NotifyWaiter for AtomicNotifyWaiter<OS> {
72    fn wait(&self, timeout: MicrosDurationU32) -> bool {
73        let mut t = OS::timeout().start_us(timeout.to_micros());
74        while !t.timeout() {
75            if self
76                .flag
77                .compare_exchange(true, false, Ordering::SeqCst, Ordering::Acquire)
78                .is_ok()
79            {
80                return true;
81            }
82            OS::yield_thread();
83        }
84        false
85    }
86}
87
88// ------------------------------------------------------------------
89
90#[cfg(feature = "std")]
91pub use std_impl::*;
92#[cfg(feature = "std")]
93mod std_impl {
94    use super::*;
95    use std::sync::{
96        Arc,
97        atomic::{AtomicBool, Ordering},
98    };
99    use std::time::Instant;
100
    /// A [`Notifier`] backed by a shared [`AtomicBool`] flag.
    ///
    /// This implementation is only for unit testing.
    #[derive(Clone)]
    pub struct StdNotifier {
        /// Flag shared with the paired [`StdNotifyWaiter`]; `notify` stores
        /// `true` into it.
        flag: Arc<AtomicBool>,
    }
106
107    impl StdNotifier {
108        pub fn new() -> (Self, StdNotifyWaiter) {
109            let s = Self {
110                flag: Arc::new(AtomicBool::new(false)),
111            };
112            let r = StdNotifyWaiter {
113                flag: Arc::clone(&s.flag),
114            };
115            (s, r)
116        }
117    }
118
119    impl Notifier for StdNotifier {
120        fn notify(&self) -> bool {
121            self.flag.store(true, Ordering::Release);
122            true
123        }
124    }
125
    /// The waiting half paired with [`StdNotifier`].
    ///
    /// This implementation is only for unit testing.
    pub struct StdNotifyWaiter {
        /// Flag shared with the notifier; `wait` resets it to `false` when it
        /// consumes a notification.
        flag: Arc<AtomicBool>,
    }
130
131    impl NotifyWaiter for StdNotifyWaiter {
132        fn wait(&self, timeout: MicrosDurationU32) -> bool {
133            let now = Instant::now();
134            while now.elapsed().as_micros() < timeout.to_micros().into() {
135                if self
136                    .flag
137                    .compare_exchange(true, false, Ordering::SeqCst, Ordering::Acquire)
138                    .is_ok()
139                {
140                    return true;
141                }
142                std::thread::yield_now();
143            }
144            false
145        }
146    }
147
148    #[cfg(test)]
149    mod tests {
150        use super::*;
151        use fugit::ExtU32;
152        use std::thread;
153
154        #[test]
155        fn notify() {
156            let (n, w) = StdNotifier::new();
157            assert!(!w.wait(1.millis()));
158            n.notify();
159            assert!(w.wait(1.millis()));
160
161            let mut handles = vec![];
162
163            let n2 = n.clone();
164
165            handles.push(thread::spawn(move || {
166                assert!(w.wait(2000.millis()));
167                assert!(w.wait(2000.millis()));
168
169                let mut i = 0;
170                assert_eq!(
171                    w.wait_with(StdOs::O, 100.millis(), 4, || {
172                        i += 1;
173                        None::<()>
174                    }),
175                    None
176                );
177                assert_eq!(i, 5);
178            }));
179
180            handles.push(thread::spawn(move || {
181                assert!(n.notify());
182            }));
183
184            handles.push(thread::spawn(move || {
185                assert!(n2.notify());
186            }));
187
188            for h in handles {
189                h.join().unwrap();
190            }
191        }
192    }
193}