// os_trait/notifier_impls.rs

use crate::{Duration, Timeout, notifier::*, prelude::*};
use core::{
    marker::PhantomData,
    sync::atomic::{AtomicBool, Ordering},
};

7#[derive(Clone)]
8pub struct FakeNotifier;
9
10impl FakeNotifier {
11    pub fn new() -> (Self, Self) {
12        (Self {}, Self {})
13    }
14}
15
16impl NotifierInterface for FakeNotifier {
17    fn notify(&self) -> bool {
18        true
19    }
20}
21
22impl<OS: OsInterface> NotifyWaiterInterface<OS> for FakeNotifier {
23    fn wait(&self, _timeout: &Duration<OS>) -> bool {
24        true
25    }
26}

// ------------------------------------------------------------------

30/// This [`NotifierInterface`] implementation is for unit test
31pub struct AtomicNotifier<OS> {
32    flag: Arc<AtomicBool>,
33    _os: PhantomData<OS>,
34}
35
36impl<OS: OsInterface> Clone for AtomicNotifier<OS> {
37    fn clone(&self) -> Self {
38        Self {
39            flag: Arc::clone(&self.flag),
40            _os: PhantomData,
41        }
42    }
43}
44
45impl<OS: OsInterface> AtomicNotifier<OS> {
46    pub fn new() -> (Self, AtomicNotifyWaiter<OS>) {
47        let s = Self {
48            flag: Arc::new(AtomicBool::new(false)),
49            _os: PhantomData,
50        };
51        let r = AtomicNotifyWaiter::<OS> {
52            flag: Arc::clone(&s.flag),
53            _os: PhantomData,
54        };
55        (s, r)
56    }
57}
58
59impl<OS: OsInterface> NotifierInterface for AtomicNotifier<OS> {
60    fn notify(&self) -> bool {
61        self.flag.store(true, Ordering::Release);
62        true
63    }
64}
65
66pub struct AtomicNotifyWaiter<OS> {
67    flag: Arc<AtomicBool>,
68    _os: PhantomData<OS>,
69}
70
71impl<OS: OsInterface> NotifyWaiterInterface<OS> for AtomicNotifyWaiter<OS> {
72    fn wait(&self, timeout: &Duration<OS>) -> bool {
73        let mut t = Timeout::<OS>::from(timeout);
74        while !t.timeout() {
75            if self
76                .flag
77                .compare_exchange(true, false, Ordering::SeqCst, Ordering::Acquire)
78                .is_ok()
79            {
80                return true;
81            }
82            OS::yield_thread();
83        }
84        false
85    }
86}

// ------------------------------------------------------------------

#[cfg(feature = "std")]
pub use std_impl::*;
#[cfg(feature = "std")]
mod std_impl {
    use super::*;
    use crate::os_impls::*;
    use std::sync::{
        Arc,
        atomic::{AtomicBool, Ordering},
    };

    /// Notifier half of a flag-based notifier pair for [`StdOs`].
    ///
    /// This implementation is only for unit testing. `notify` sets a shared
    /// [`AtomicBool`]; the paired [`StdNotifyWaiter`] polls and clears it.
    /// Notifications are not counted: repeated `notify` calls before the
    /// waiter observes the flag collapse into one.
    #[derive(Clone)]
    pub struct StdNotifier {
        /// Pending-notification flag shared with the paired waiter.
        flag: Arc<AtomicBool>,
    }

    impl StdNotifier {
        /// Creates a connected `(notifier, waiter)` pair sharing one flag,
        /// initially clear (no notification pending).
        pub fn new() -> (Self, StdNotifyWaiter) {
            let s = Self {
                flag: Arc::new(AtomicBool::new(false)),
            };
            let r = StdNotifyWaiter {
                flag: Arc::clone(&s.flag),
            };
            (s, r)
        }
    }

    impl NotifierInterface for StdNotifier {
        /// Marks a notification as pending. Always returns `true`.
        fn notify(&self) -> bool {
            // Release pairs with the acquire side of the waiter's compare-exchange.
            self.flag.store(true, Ordering::Release);
            true
        }
    }

    /// Waiting half of a [`StdNotifier`] pair.
    ///
    /// This implementation is only for unit testing.
    pub struct StdNotifyWaiter {
        /// Pending-notification flag shared with the paired notifier(s).
        flag: Arc<AtomicBool>,
    }

    impl NotifyWaiterInterface<StdOs> for StdNotifyWaiter {
        /// Waits for a notification and consumes it.
        ///
        /// Returns `true` as soon as a pending notification is taken, or
        /// `false` once `timeout` has elapsed. The flag is always checked
        /// before the deadline, so an already-pending notification is taken
        /// even when `timeout` is zero.
        fn wait(&self, timeout: &Duration<StdOs>) -> bool {
            let mut t = Timeout::<StdOs>::from(timeout);
            loop {
                // Atomically consume the notification: true -> false.
                if self
                    .flag
                    .compare_exchange(true, false, Ordering::SeqCst, Ordering::Acquire)
                    .is_ok()
                {
                    return true;
                }
                if t.timeout() {
                    return false;
                }
                // Give up the CPU briefly between polls instead of spinning hot.
                std::thread::sleep(std::time::Duration::from_nanos(1));
            }
        }
    }

    #[cfg(test)]
    mod tests {
        use super::*;
        use std::thread;
        type OsDuration = Duration<StdOs>;

        #[test]
        fn notify() {
            let (n, w) = StdNotifier::new();
            assert!(!w.wait(&OsDuration::millis(1)));
            n.notify();
            assert!(w.wait(&OsDuration::millis(1)));

            let mut handles = vec![];

            let n2 = n.clone();

            handles.push(thread::spawn(move || {
                assert!(w.wait(&OsDuration::millis(2000)));
                assert!(w.wait(&OsDuration::millis(2000)));

                let mut i = 0;
                assert_eq!(
                    w.wait_with(&OsDuration::millis(100), 4, || {
                        i += 1;
                        None::<()>
                    }),
                    None
                );
                assert_eq!(i, 5);
            }));

            handles.push(thread::spawn(move || {
                assert!(n.notify());
            }));

            handles.push(thread::spawn(move || {
                // NOTE(review): relies on the waiter consuming the first
                // notification within this delay; notifications coalesce.
                thread::sleep(std::time::Duration::from_millis(10));
                assert!(n2.notify());
            }));

            for h in handles {
                h.join().unwrap();
            }
        }

        #[test]
        fn pending_notification_is_taken_with_zero_timeout() {
            let (n, w) = StdNotifier::new();
            assert!(n.notify());
            assert!(w.wait(&OsDuration::millis(0)));
            // The previous wait consumed the only notification.
            assert!(!w.wait(&OsDuration::millis(1)));
        }
    }
}