// more_sync/versioned_parker/mod.rs

1use std::ops::{Deref, DerefMut};
2use std::sync::atomic::{AtomicUsize, Ordering};
3use std::sync::{Arc, Condvar, Mutex, MutexGuard, WaitTimeoutResult};
4use std::time::Duration;
5
6/// A thread parking and locking primitive that provide version numbers.
7///
8/// Like an [`std::sync::Condvar`], `VersionedParker` provides a `wait`
9/// method and several `notify` methods. The `wait` method blocks the current
10/// thread, while the `notify` methods unblocks waiting threads. Each time
11/// `notify` is called, the parker version is increased. When a blocked thread
12/// wakes up, it can check the internal counter and learn how many times it has
13/// been notified. The version can be obtained by calling method
14/// [`VersionedParker::version()`].
15///
16/// `VersionedParker` holds a piece of data that can be modified during `notify`
17/// and `wait` operations. The data is versioned also versioned by the same
18/// parker version.
19///
20/// ```
21/// use more_sync::VersionedParker;
22///
23/// let versioned_parker = VersionedParker::new(0);
24/// let mut guard = versioned_parker.lock();
25///
26/// let parker_clone = versioned_parker.clone();
27/// std::thread::spawn(move || {
28///     parker_clone.notify_one_mutate(|i| *i = 16);
29///     assert_eq!(parker_clone.version(), 1);
30///     // Version is 1, try_notify_all() should fail.
31///     assert!(!parker_clone.try_notify_all(0));
32/// });
33///
34/// guard.wait();
35/// assert_eq!(guard.notified_count(), 1);
36/// assert_eq!(*guard, 16);
37/// ```
38#[derive(Default, Clone, Debug)]
39pub struct VersionedParker<T> {
40    inner: Arc<Inner<T>>,
41}
42
/// State shared by every clone of a `VersionedParker`.
#[derive(Default, Debug)]
struct Inner<T> {
    // Number of successful notifications so far (starts at 0).
    version: AtomicUsize,
    // The user data; its mutex is also the lock paired with `condvar`.
    data: Mutex<T>,
    // Condition variable that waiting guards block on.
    condvar: Condvar,
}

impl<T> Inner<T> {
    /// Reads the current version using `Acquire` ordering.
    fn version(&self) -> usize {
        AtomicUsize::load(&self.version, Ordering::Acquire)
    }
}
55
56impl<T> VersionedParker<T> {
57    /// Creates a new `VersionedParker`, with the initial version being `0`, and
58    /// the shared data being `data`.
59    pub fn new(data: T) -> Self {
60        Self {
61            inner: Arc::new(Inner {
62                version: AtomicUsize::new(0),
63                data: Mutex::new(data),
64                condvar: Condvar::new(),
65            }),
66        }
67    }
68
69    /// Locks the shared data and the version.
70    ///
71    /// A thread can then call [`VersionedGuard::wait()`] to wait for version
72    /// changes.
73    pub fn lock(&self) -> VersionedGuard<T> {
74        let guard = self.inner.data.lock().unwrap();
75        VersionedGuard {
76            parker: self.inner.as_ref(),
77            guard: Some(guard),
78            notified_count: 0,
79        }
80    }
81
82    fn do_notify(
83        &self,
84        expected_version: Option<usize>,
85        mutate: fn(&mut T),
86        notify: fn(&Condvar),
87    ) -> bool {
88        let mut guard = self.inner.data.lock().unwrap();
89        if expected_version
90            .map(|v| v == self.version())
91            .unwrap_or(true)
92        {
93            self.inner.version.fetch_add(1, Ordering::AcqRel);
94            mutate(guard.deref_mut());
95            notify(&self.inner.condvar);
96            return true;
97        }
98        false
99    }
100
101    /// Increases the version and notifies one blocked thread.
102    pub fn notify_one(&self) {
103        self.do_notify(None, |_| {}, Condvar::notify_one);
104    }
105
106    /// Increases the version, mutates the shared data and notifies one blocked
107    /// thread.
108    pub fn notify_one_mutate(&self, mutate: fn(&mut T)) {
109        self.do_notify(None, mutate, Condvar::notify_one);
110    }
111
112    /// Increases the version and notifies one blocked thread, if the current
113    /// version is `expected_version`.
114    ///
115    /// Returns `true` if the version matches.
116    pub fn try_notify_one(&self, expected_version: usize) -> bool {
117        self.do_notify(Some(expected_version), |_| {}, Condvar::notify_one)
118    }
119
120    /// Increases the version, modifies the shared data and notifies one blocked
121    /// thread, if the current version is `expected_version`.
122    ///
123    /// Returns `true` if the version matches.
124    pub fn try_notify_one_mutate(
125        &self,
126        expected_version: usize,
127        mutate: fn(&mut T),
128    ) -> bool {
129        self.do_notify(Some(expected_version), mutate, Condvar::notify_one)
130    }
131
132    /// Increases the version and notifies all blocked threads.
133    pub fn notify_all(&self) {
134        self.do_notify(None, |_| {}, Condvar::notify_all);
135    }
136
137    /// Increases the version, modifies the shared data and notifies all blocked
138    /// threads.
139    pub fn notify_all_mutate(&self, mutate: fn(&mut T)) {
140        self.do_notify(None, mutate, Condvar::notify_all);
141    }
142
143    /// Increases the version and notifies all blocked threads, if the current
144    /// version is `expected_version`.
145    ///
146    /// Returns `true` if the version matches.
147    pub fn try_notify_all(&self, expected_version: usize) -> bool {
148        self.do_notify(Some(expected_version), |_| {}, Condvar::notify_all)
149    }
150
151    /// Increases the version, modifies the shared data and notifies all blocked
152    /// threads, if the current version is `expected_version`.
153    ///
154    /// Returns `true` if the version matches.
155    pub fn try_notify_all_mutate(
156        &self,
157        expected_version: usize,
158        mutate: fn(&mut T),
159    ) -> bool {
160        self.do_notify(Some(expected_version), mutate, Condvar::notify_all)
161    }
162
163    /// Returns the current version.
164    pub fn version(&self) -> usize {
165        self.inner.version()
166    }
167}
168
169/// Mutex guard returned by [`VersionedParker::lock`].
170#[derive(Debug)]
171pub struct VersionedGuard<'a, T> {
172    parker: &'a Inner<T>,
173    guard: Option<MutexGuard<'a, T>>,
174    notified_count: usize,
175}
176
177impl<'a, T> VersionedGuard<'a, T> {
178    /// Returns the current version.
179    ///
180    /// The version will not change unless [`wait()`](`VersionedGuard::wait`) or
181    /// [`wait_timeout()`](`VersionedGuard::wait_timeout`) is called.
182    pub fn version(&self) -> usize {
183        self.parker.version()
184    }
185
186    /// Returns if we were notified during last period.
187    ///
188    /// If we never waited, `notified()` returns false.
189    pub fn notified(&self) -> bool {
190        self.notified_count != 0
191    }
192
193    /// Returns the number of times we were notified during last wait.
194    ///
195    /// If we never waited, `notification_count()` returns 0.
196    pub fn notified_count(&self) -> usize {
197        self.notified_count
198    }
199
200    /// Blocks the current thread until notified.
201    ///
202    /// `wait()` updates the version stored in this guard.
203    pub fn wait(&mut self) {
204        let guard = self.guard.take().unwrap();
205        let version = self.parker.version();
206
207        self.guard = Some(self.parker.condvar.wait(guard).unwrap());
208        self.notified_count = self.parker.version() - version;
209    }
210
211    /// Blocks the current thread until notified, for up to `timeout`.
212    ///
213    /// `wait_timeout()` updates the version stored in this guard.
214    pub fn wait_timeout(&mut self, timeout: Duration) -> WaitTimeoutResult {
215        let guard = self.guard.take().unwrap();
216        let version = self.parker.version();
217        let (guard_result, wait_result) =
218            self.parker.condvar.wait_timeout(guard, timeout).unwrap();
219
220        self.guard = Some(guard_result);
221        self.notified_count = self.parker.version() - version;
222
223        wait_result
224    }
225}
226
227impl<'a, T> Deref for VersionedGuard<'a, T> {
228    type Target = T;
229
230    fn deref(&self) -> &Self::Target {
231        self.guard.as_deref().unwrap()
232    }
233}
234
235impl<'a, T> DerefMut for VersionedGuard<'a, T> {
236    fn deref_mut(&mut self) -> &mut Self::Target {
237        self.guard.as_deref_mut().unwrap()
238    }
239}
240
#[cfg(test)]
mod tests {
    use super::*;

    // Every test that spawns a notifier joins it at the end, so a failed
    // assertion inside that thread fails the test instead of being silently
    // dropped. The guard is released before joining because the notifier may
    // still need the lock to finish.

    #[test]
    fn test_basics() {
        let versioned_parker = VersionedParker::new(0);
        assert_eq!(versioned_parker.version(), 0);

        versioned_parker.notify_one();
        assert_eq!(versioned_parker.version(), 1);

        versioned_parker.notify_one_mutate(|i| *i = 32);
        let mut guard = versioned_parker.lock();
        assert_eq!(versioned_parker.version(), 2);
        assert_eq!(*guard, 32);

        let parker_clone = versioned_parker.clone();
        let handle =
            std::thread::spawn(move || parker_clone.notify_one_mutate(|i| *i = 64));
        guard.wait();

        assert_eq!(guard.notified_count(), 1);
        assert_eq!(*guard, 64);

        drop(guard);
        handle.join().unwrap();
    }

    #[test]
    fn test_multiple_notify() {
        let versioned_parker = VersionedParker::new(0);
        let mut guard = versioned_parker.lock();

        let parker_clone = versioned_parker.clone();
        let handle = std::thread::spawn(move || {
            parker_clone.notify_all();
            parker_clone.notify_all_mutate(|i| *i = 128);
            parker_clone.notify_one_mutate(|i| *i = 256);
            parker_clone.notify_one_mutate(|i| *i = 512);
        });

        guard.wait();
        let expected_value = match guard.notified_count() {
            1 => 0,
            2 => 128,
            3 => 256,
            4 => 512,
            n => panic!("notify count should be between 1 and 4, got {}", n),
        };
        assert_eq!(*guard, expected_value);

        // The notifier may still need the lock for its remaining calls.
        drop(guard);
        handle.join().unwrap();
    }

    #[test]
    fn test_try_notify() {
        let versioned_parker = VersionedParker::new(0);
        let mut guard = versioned_parker.lock();

        let parker_clone = versioned_parker.clone();
        let handle = std::thread::spawn(move || {
            assert!(parker_clone.try_notify_one(0));
            assert!(!parker_clone.try_notify_all(0));
        });

        guard.wait();
        assert_eq!(guard.notified_count(), 1);
        assert_eq!(*guard, 0);

        drop(guard);
        handle.join().unwrap();
    }

    #[test]
    fn test_try_notify_mutate() {
        let versioned_parker = VersionedParker::new(0);
        let mut guard = versioned_parker.lock();

        let parker_clone = versioned_parker.clone();
        let handle = std::thread::spawn(move || {
            assert!(parker_clone.try_notify_one_mutate(0, |i| *i = 1024));
            assert!(!parker_clone.try_notify_all_mutate(0, |i| *i = 2048));
        });

        guard.wait();
        assert_eq!(guard.notified_count(), 1);
        assert_eq!(*guard, 1024);

        drop(guard);
        handle.join().unwrap();
    }

    #[test]
    fn test_wait_timeout_without_notify() {
        let versioned_parker = VersionedParker::new(0);
        let mut guard = versioned_parker.lock();

        let _ = guard.wait_timeout(std::time::Duration::from_millis(10));
        assert!(!guard.notified());
        assert_eq!(guard.notified_count(), 0);
        assert_eq!(guard.version(), 0);
        assert_eq!(*guard, 0);
    }
}