more_sync/versioned_parker/
mod.rs1use std::ops::{Deref, DerefMut};
2use std::sync::atomic::{AtomicUsize, Ordering};
3use std::sync::{Arc, Condvar, Mutex, MutexGuard, WaitTimeoutResult};
4use std::time::Duration;
5
6#[derive(Default, Clone, Debug)]
39pub struct VersionedParker<T> {
40 inner: Arc<Inner<T>>,
41}
42
/// Shared state behind every clone of a `VersionedParker`.
#[derive(Debug, Default)]
struct Inner<T> {
    /// Count of notifications issued so far; only bumped while `data` is locked.
    version: AtomicUsize,
    /// The user value protected by the parker.
    data: Mutex<T>,
    /// Condition variable that waiters block on, used together with `data`.
    condvar: Condvar,
}
49
50impl<T> Inner<T> {
51 fn version(&self) -> usize {
52 self.version.load(Ordering::Acquire)
53 }
54}
55
56impl<T> VersionedParker<T> {
57 pub fn new(data: T) -> Self {
60 Self {
61 inner: Arc::new(Inner {
62 version: AtomicUsize::new(0),
63 data: Mutex::new(data),
64 condvar: Condvar::new(),
65 }),
66 }
67 }
68
69 pub fn lock(&self) -> VersionedGuard<T> {
74 let guard = self.inner.data.lock().unwrap();
75 VersionedGuard {
76 parker: self.inner.as_ref(),
77 guard: Some(guard),
78 notified_count: 0,
79 }
80 }
81
82 fn do_notify(
83 &self,
84 expected_version: Option<usize>,
85 mutate: fn(&mut T),
86 notify: fn(&Condvar),
87 ) -> bool {
88 let mut guard = self.inner.data.lock().unwrap();
89 if expected_version
90 .map(|v| v == self.version())
91 .unwrap_or(true)
92 {
93 self.inner.version.fetch_add(1, Ordering::AcqRel);
94 mutate(guard.deref_mut());
95 notify(&self.inner.condvar);
96 return true;
97 }
98 false
99 }
100
101 pub fn notify_one(&self) {
103 self.do_notify(None, |_| {}, Condvar::notify_one);
104 }
105
106 pub fn notify_one_mutate(&self, mutate: fn(&mut T)) {
109 self.do_notify(None, mutate, Condvar::notify_one);
110 }
111
112 pub fn try_notify_one(&self, expected_version: usize) -> bool {
117 self.do_notify(Some(expected_version), |_| {}, Condvar::notify_one)
118 }
119
120 pub fn try_notify_one_mutate(
125 &self,
126 expected_version: usize,
127 mutate: fn(&mut T),
128 ) -> bool {
129 self.do_notify(Some(expected_version), mutate, Condvar::notify_one)
130 }
131
132 pub fn notify_all(&self) {
134 self.do_notify(None, |_| {}, Condvar::notify_all);
135 }
136
137 pub fn notify_all_mutate(&self, mutate: fn(&mut T)) {
140 self.do_notify(None, mutate, Condvar::notify_all);
141 }
142
143 pub fn try_notify_all(&self, expected_version: usize) -> bool {
148 self.do_notify(Some(expected_version), |_| {}, Condvar::notify_all)
149 }
150
151 pub fn try_notify_all_mutate(
156 &self,
157 expected_version: usize,
158 mutate: fn(&mut T),
159 ) -> bool {
160 self.do_notify(Some(expected_version), mutate, Condvar::notify_all)
161 }
162
163 pub fn version(&self) -> usize {
165 self.inner.version()
166 }
167}
168
169#[derive(Debug)]
171pub struct VersionedGuard<'a, T> {
172 parker: &'a Inner<T>,
173 guard: Option<MutexGuard<'a, T>>,
174 notified_count: usize,
175}
176
177impl<'a, T> VersionedGuard<'a, T> {
178 pub fn version(&self) -> usize {
183 self.parker.version()
184 }
185
186 pub fn notified(&self) -> bool {
190 self.notified_count != 0
191 }
192
193 pub fn notified_count(&self) -> usize {
197 self.notified_count
198 }
199
200 pub fn wait(&mut self) {
204 let guard = self.guard.take().unwrap();
205 let version = self.parker.version();
206
207 self.guard = Some(self.parker.condvar.wait(guard).unwrap());
208 self.notified_count = self.parker.version() - version;
209 }
210
211 pub fn wait_timeout(&mut self, timeout: Duration) -> WaitTimeoutResult {
215 let guard = self.guard.take().unwrap();
216 let version = self.parker.version();
217 let (guard_result, wait_result) =
218 self.parker.condvar.wait_timeout(guard, timeout).unwrap();
219
220 self.guard = Some(guard_result);
221 self.notified_count = self.parker.version() - version;
222
223 wait_result
224 }
225}
226
227impl<'a, T> Deref for VersionedGuard<'a, T> {
228 type Target = T;
229
230 fn deref(&self) -> &Self::Target {
231 self.guard.as_deref().unwrap()
232 }
233}
234
235impl<'a, T> DerefMut for VersionedGuard<'a, T> {
236 fn deref_mut(&mut self) -> &mut Self::Target {
237 self.guard.as_deref_mut().unwrap()
238 }
239}
240
#[cfg(test)]
mod tests {
    use super::*;

    /// Waits until at least one notification has been observed, retrying on
    /// spurious wakeups (which a condition variable is allowed to produce).
    fn wait_until_notified<T>(guard: &mut VersionedGuard<'_, T>) {
        loop {
            guard.wait();
            if guard.notified() {
                break;
            }
        }
    }

    #[test]
    fn test_basics() {
        let versioned_parker = VersionedParker::new(0);
        assert_eq!(versioned_parker.version(), 0);

        versioned_parker.notify_one();
        assert_eq!(versioned_parker.version(), 1);

        versioned_parker.notify_one_mutate(|i| *i = 32);
        let mut guard = versioned_parker.lock();
        assert_eq!(versioned_parker.version(), 2);
        assert_eq!(*guard, 32);

        let parker_clone = versioned_parker.clone();
        let handle = std::thread::spawn(move || parker_clone.notify_one_mutate(|i| *i = 64));
        wait_until_notified(&mut guard);

        assert_eq!(guard.notified_count(), 1);
        assert_eq!(*guard, 64);

        // Release the lock before joining so the notifier is never blocked on
        // it, and so a panic in the spawned thread fails this test.
        drop(guard);
        handle.join().unwrap();
    }

    #[test]
    fn test_multiple_notify() {
        let versioned_parker = VersionedParker::new(0);
        let mut guard = versioned_parker.lock();

        let parker_clone = versioned_parker.clone();
        let handle = std::thread::spawn(move || {
            parker_clone.notify_all();
            parker_clone.notify_all_mutate(|i| *i = 128);
            parker_clone.notify_one_mutate(|i| *i = 256);
            parker_clone.notify_one_mutate(|i| *i = 512);
        });

        wait_until_notified(&mut guard);
        let expected_value = match guard.notified_count() {
            1 => 0,
            2 => 128,
            3 => 256,
            4 => 512,
            other => panic!("notify count should be between 1 and 4, got {}", other),
        };
        assert_eq!(*guard, expected_value);

        // The notifier may still be waiting for the lock; release it first.
        drop(guard);
        handle.join().unwrap();
    }

    #[test]
    fn test_try_notify() {
        let versioned_parker = VersionedParker::new(0);
        let mut guard = versioned_parker.lock();

        let parker_clone = versioned_parker.clone();
        let handle = std::thread::spawn(move || {
            assert!(parker_clone.try_notify_one(0));
            assert!(!parker_clone.try_notify_all(0));
        });

        wait_until_notified(&mut guard);
        assert_eq!(guard.notified_count(), 1);
        assert_eq!(*guard, 0);

        // Joining surfaces the assertions made inside the spawned thread.
        drop(guard);
        handle.join().unwrap();
    }

    #[test]
    fn test_try_notify_mutate() {
        let versioned_parker = VersionedParker::new(0);
        let mut guard = versioned_parker.lock();

        let parker_clone = versioned_parker.clone();
        let handle = std::thread::spawn(move || {
            assert!(parker_clone.try_notify_one_mutate(0, |i| *i = 1024));
            assert!(!parker_clone.try_notify_all_mutate(0, |i| *i = 2048));
        });

        wait_until_notified(&mut guard);
        assert_eq!(guard.notified_count(), 1);
        assert_eq!(*guard, 1024);

        // Joining surfaces the assertions made inside the spawned thread.
        drop(guard);
        handle.join().unwrap();
    }
}