authenticator_ctap2_2021/
statecallback.rs1use std::sync::{Arc, Condvar, Mutex};
6
7pub struct StateCallback<T> {
8 callback: Arc<Mutex<Option<Box<dyn Fn(T) + Send>>>>,
9 observer: Arc<Mutex<Option<Box<dyn Fn() + Send>>>>,
10 condition: Arc<(Mutex<bool>, Condvar)>,
11}
12
13impl<T> StateCallback<T> {
14 #[allow(clippy::mutex_atomic)]
16 pub fn new(cb: Box<dyn Fn(T) + Send>) -> Self {
17 Self {
18 callback: Arc::new(Mutex::new(Some(cb))),
19 observer: Arc::new(Mutex::new(None)),
20 condition: Arc::new((Mutex::new(true), Condvar::new())),
21 }
22 }
23
24 pub fn add_uncloneable_observer(&mut self, obs: Box<dyn Fn() + Send>) {
25 let mut opt = self.observer.lock().unwrap();
26 if opt.is_some() {
27 error!("Replacing an already-set observer.")
28 }
29 opt.replace(obs);
30 }
31
32 pub fn call(&self, rv: T) {
33 if let Some(cb) = self.callback.lock().unwrap().take() {
34 cb(rv);
35
36 if let Some(obs) = self.observer.lock().unwrap().take() {
37 obs();
38 }
39 }
40
41 let (lock, cvar) = &*self.condition;
42 let mut pending = lock.lock().unwrap();
43 *pending = false;
44 cvar.notify_all();
45 }
46
47 pub fn wait(&self) {
48 let (lock, cvar) = &*self.condition;
49 let _useless_guard = cvar
50 .wait_while(lock.lock().unwrap(), |pending| *pending)
51 .unwrap();
52 }
53}
54
55impl<T> Clone for StateCallback<T> {
56 fn clone(&self) -> Self {
57 Self {
58 callback: self.callback.clone(),
59 observer: Arc::new(Mutex::new(None)),
60 condition: self.condition.clone(),
61 }
62 }
63}
64
#[cfg(test)]
mod tests {
    use super::StateCallback;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::{Arc, Barrier};
    use std::thread;

    /// Returns two handles to one zeroed counter: one to inspect, one to move into a closure.
    fn counter() -> (Arc<AtomicUsize>, Arc<AtomicUsize>) {
        let shared = Arc::new(AtomicUsize::new(0));
        (Arc::clone(&shared), shared)
    }

    /// Reads a counter with the same ordering the closures use to bump it.
    fn count(c: &AtomicUsize) -> usize {
        c.load(Ordering::SeqCst)
    }

    #[test]
    fn test_statecallback_is_single_use() {
        let (hits, hits_in_cb) = counter();
        let sc = StateCallback::<()>::new(Box::new(move |_| {
            hits_in_cb.fetch_add(1, Ordering::SeqCst);
        }));

        assert_eq!(count(&hits), 0);
        // Repeated calls on the original instance run the callback only once...
        (0..10).for_each(|_| {
            sc.call(());
            assert_eq!(count(&hits), 1);
        });
        // ...and clones share the already-consumed callback.
        (0..10).for_each(|_| {
            sc.clone().call(());
            assert_eq!(count(&hits), 1);
        });
    }

    #[test]
    fn test_statecallback_observer_is_single_use() {
        let (seen, seen_in_obs) = counter();
        let mut sc = StateCallback::<()>::new(Box::new(move |_| {}));
        sc.add_uncloneable_observer(Box::new(move || {
            seen_in_obs.fetch_add(1, Ordering::SeqCst);
        }));

        assert_eq!(count(&seen), 0);
        // First ten rounds call the original, the last ten call throwaway clones.
        for round in 0..20 {
            if round < 10 {
                sc.call(());
            } else {
                sc.clone().call(());
            }
            assert_eq!(count(&seen), 1);
        }
    }

    #[test]
    fn test_statecallback_observer_only_runs_for_completing_callback() {
        let (cb_hits, cb_hits_in_cb) = counter();
        let sc = StateCallback::<()>::new(Box::new(move |_| {
            cb_hits_in_cb.fetch_add(1, Ordering::SeqCst);
        }));
        let obs_hits = Arc::new(AtomicUsize::new(0));

        for _ in 0..10 {
            let mut instance = sc.clone();
            let obs_hits_in_obs = Arc::clone(&obs_hits);
            instance.add_uncloneable_observer(Box::new(move || {
                obs_hits_in_obs.fetch_add(1, Ordering::SeqCst);
            }));
            instance.call(());

            // Only the first clone consumes the callback, so only its observer fires.
            assert_eq!(count(&cb_hits), 1);
            assert_eq!(count(&obs_hits), 1);
        }
    }

    #[test]
    #[allow(clippy::redundant_clone)] // the clone itself is what this test inspects
    fn test_statecallback_observer_unclonable() {
        let mut sc = StateCallback::<()>::new(Box::new(move |_| {}));
        sc.add_uncloneable_observer(Box::new(move || {}));

        let original_has_observer = sc.observer.lock().unwrap().is_some();
        let clone_has_observer = sc.clone().observer.lock().unwrap().is_some();
        assert!(original_has_observer);
        assert!(!clone_has_observer);
    }

    #[test]
    fn test_statecallback_wait() {
        let sc = StateCallback::<()>::new(Box::new(move |_| {}));
        let barrier = Arc::new(Barrier::new(2));

        let caller = sc.clone();
        let caller_barrier = Arc::clone(&barrier);
        thread::spawn(move || {
            caller_barrier.wait();
            caller.call(());
        });

        // Rendezvous with the spawned thread, then block until its `call` lands.
        barrier.wait();
        sc.wait();
    }
}