Skip to main content

authenticator_ctap2_2021/
statecallback.rs

1/* This Source Code Form is subject to the terms of the Mozilla Public
2 * License, v. 2.0. If a copy of the MPL was not distributed with this
3 * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
4
5use std::sync::{Arc, Condvar, Mutex};
6
/// Boxed completion callback; it receives the value passed to `StateCallback::call`.
type Callback<T> = Box<dyn Fn(T) + Send>;

/// Boxed observer; it takes no arguments.
type Observer = Box<dyn Fn() + Send>;

/// A callback wrapper whose callback runs at most once and whose completion
/// can be waited on.
///
/// The `callback` and `condition` fields sit behind `Arc`s that `Clone`
/// shares between handles; `observer` is replaced by a fresh, empty slot in
/// every clone.
pub struct StateCallback<T> {
    /// Slot holding the callback. `call` takes the callback out of the slot,
    /// so later calls find it empty.
    callback: Arc<Mutex<Option<Callback<T>>>>,
    /// Slot holding this handle's optional observer. `call` runs and removes
    /// it right after it has run the callback.
    observer: Arc<Mutex<Option<Observer>>>,
    /// "Still pending" flag plus the condition variable used to wake waiters.
    /// The flag starts as `true`; `call` sets it to `false` and notifies.
    condition: Arc<(Mutex<bool>, Condvar)>,
}
12
13impl<T> StateCallback<T> {
14    // This is used for the Condvar, which requires this kind of construction
15    #[allow(clippy::mutex_atomic)]
16    pub fn new(cb: Box<dyn Fn(T) + Send>) -> Self {
17        Self {
18            callback: Arc::new(Mutex::new(Some(cb))),
19            observer: Arc::new(Mutex::new(None)),
20            condition: Arc::new((Mutex::new(true), Condvar::new())),
21        }
22    }
23
24    pub fn add_uncloneable_observer(&mut self, obs: Box<dyn Fn() + Send>) {
25        let mut opt = self.observer.lock().unwrap();
26        if opt.is_some() {
27            error!("Replacing an already-set observer.")
28        }
29        opt.replace(obs);
30    }
31
32    pub fn call(&self, rv: T) {
33        if let Some(cb) = self.callback.lock().unwrap().take() {
34            cb(rv);
35
36            if let Some(obs) = self.observer.lock().unwrap().take() {
37                obs();
38            }
39        }
40
41        let (lock, cvar) = &*self.condition;
42        let mut pending = lock.lock().unwrap();
43        *pending = false;
44        cvar.notify_all();
45    }
46
47    pub fn wait(&self) {
48        let (lock, cvar) = &*self.condition;
49        let _useless_guard = cvar
50            .wait_while(lock.lock().unwrap(), |pending| *pending)
51            .unwrap();
52    }
53}
54
55impl<T> Clone for StateCallback<T> {
56    fn clone(&self) -> Self {
57        Self {
58            callback: self.callback.clone(),
59            observer: Arc::new(Mutex::new(None)),
60            condition: self.condition.clone(),
61        }
62    }
63}
64
#[cfg(test)]
mod tests {
    use super::StateCallback;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::{Arc, Barrier};
    use std::thread;

    /// Number of repeated invocations used to show that nothing fires twice.
    const ROUNDS: usize = 10;

    /// Reads the current value of a shared hit counter.
    fn hits(counter: &AtomicUsize) -> usize {
        counter.load(Ordering::SeqCst)
    }

    /// The wrapped callback fires on the first `call` only, whether later
    /// calls go through the original handle or through clones of it.
    #[test]
    fn test_statecallback_is_single_use() {
        let fired = Arc::new(AtomicUsize::new(0));
        let fired_in_cb = Arc::clone(&fired);
        let handle = StateCallback::<()>::new(Box::new(move |_| {
            fired_in_cb.fetch_add(1, Ordering::SeqCst);
        }));

        assert_eq!(hits(&fired), 0);

        for _ in 0..ROUNDS {
            handle.call(());
            assert_eq!(hits(&fired), 1);
        }

        for _ in 0..ROUNDS {
            handle.clone().call(());
            assert_eq!(hits(&fired), 1);
        }
    }

    /// The observer fires together with the first `call` and never again.
    #[test]
    fn test_statecallback_observer_is_single_use() {
        let observed = Arc::new(AtomicUsize::new(0));
        let observed_in_obs = Arc::clone(&observed);
        let mut handle = StateCallback::<()>::new(Box::new(move |_| {}));

        handle.add_uncloneable_observer(Box::new(move || {
            observed_in_obs.fetch_add(1, Ordering::SeqCst);
        }));

        assert_eq!(hits(&observed), 0);

        for _ in 0..ROUNDS {
            handle.call(());
            assert_eq!(hits(&observed), 1);
        }

        for _ in 0..ROUNDS {
            handle.clone().call(());
            assert_eq!(hits(&observed), 1);
        }
    }

    /// Every clone gets its own observer, but only the clone whose `call`
    /// actually consumes the callback gets its observer run.
    #[test]
    fn test_statecallback_observer_only_runs_for_completing_callback() {
        let callback_hits = Arc::new(AtomicUsize::new(0));
        let callback_hits_in_cb = Arc::clone(&callback_hits);
        let root = StateCallback::<()>::new(Box::new(move |_| {
            callback_hits_in_cb.fetch_add(1, Ordering::SeqCst);
        }));

        let observer_hits = Arc::new(AtomicUsize::new(0));

        for _ in 0..ROUNDS {
            let observer_hits_in_obs = Arc::clone(&observer_hits);
            let mut branch = root.clone();
            branch.add_uncloneable_observer(Box::new(move || {
                observer_hits_in_obs.fetch_add(1, Ordering::SeqCst);
            }));

            branch.call(());

            assert_eq!(hits(&callback_hits), 1);
            assert_eq!(hits(&observer_hits), 1);
        }
    }

    /// Cloning a handle must not copy its observer.
    #[test]
    // The clone below is the behaviour under test, not an oversight.
    #[allow(clippy::redundant_clone)]
    fn test_statecallback_observer_unclonable() {
        let mut handle = StateCallback::<()>::new(Box::new(move |_| {}));
        handle.add_uncloneable_observer(Box::new(move || {}));

        assert!(handle.observer.lock().unwrap().is_some());
        // Deliberate extra clone: the copy must start without an observer.
        assert!(handle.clone().observer.lock().unwrap().is_none());
    }

    /// `wait` returns once another thread has gone through `call` on a clone.
    #[test]
    fn test_statecallback_wait() {
        let waiter = StateCallback::<()>::new(Box::new(move |_| {}));
        let rendezvous = Arc::new(Barrier::new(2));

        let caller = waiter.clone();
        let caller_side = Arc::clone(&rendezvous);
        thread::spawn(move || {
            caller_side.wait();
            caller.call(());
        });

        rendezvous.wait();
        waiter.wait();
    }
}