range_lock/
rangelock.rs

1// -*- coding: utf-8 -*-
2//
3// Copyright 2021-2023 Michael Büsch <m@bues.ch>
4//
5// Licensed under the Apache License version 2.0
6// or the MIT license, at your option.
7// SPDX-License-Identifier: Apache-2.0 OR MIT
8//
9
10use crate::{lockedranges::LockedRanges, util::get_bounds};
11use std::{
12    cell::UnsafeCell,
13    hint::unreachable_unchecked,
14    marker::PhantomData,
15    ops::{Deref, DerefMut, Range, RangeBounds},
16    rc::Rc,
17    sync::{LockResult, Mutex, PoisonError, TryLockError, TryLockResult},
18};
19
20/// General purpose multi-thread range lock for [std::vec::Vec].
21///
22/// # Example
23///
24/// ```
25/// use range_lock::VecRangeLock;
26/// use std::{sync::{Arc, Barrier}, thread};
27///
28/// let data = vec![10, 11, 12, 13];
29///
30/// let data_lock0 = Arc::new(VecRangeLock::new(data));
31/// let data_lock1 = Arc::clone(&data_lock0);
32/// let data_lock2 = Arc::clone(&data_lock0);
33///
34/// // Thread barrier, only for demonstration purposes.
35/// let barrier0 = Arc::new(Barrier::new(2));
36/// let barrier1 = Arc::clone(&barrier0);
37///
38/// thread::scope(|s| {
39///     s.spawn(move || {
40///         {
41///             let mut guard = data_lock0.try_lock(0..2).expect("T0: Failed to lock 0..2");
42///             guard[0] = 100; // Write to data[0]
43///         }
44///         barrier0.wait(); // Synchronize with second thread.
45///         {
46///             let guard = data_lock0.try_lock(2..4).expect("T0: Failed to lock 2..4");
47///             assert_eq!(guard[0], 200); // Read from data[2]
48///         }
49///     });
50///
51///     s.spawn(move || {
52///         {
53///             let mut guard = data_lock1.try_lock(2..4).expect("T1: Failed to lock 2..4");
54///             guard[0] = 200; // Write to data[2]
55///         }
56///         barrier1.wait(); // Synchronize with first thread.
57///         {
58///             let guard = data_lock1.try_lock(0..2).expect("T1: Failed to lock 0..2");
59///             assert_eq!(guard[0], 100); // Read from data[0]
60///         }
61///     });
62/// });
63///
64/// let data = Arc::try_unwrap(data_lock2).expect("Arc unwrap failed").into_inner();
65///
66/// assert_eq!(data, vec![100, 11, 200, 13]);
67/// ```
68#[derive(Debug)]
69pub struct VecRangeLock<T> {
70    /// Set of the currently locked ranges.
71    ranges: Mutex<LockedRanges>,
72    /// The underlying data.
73    data: UnsafeCell<Vec<T>>,
74}
75
76// SAFETY:
77// It is safe to access VecRangeLock and the contained data (via VecRangeLockGuard)
78// from multiple threads simultaneously.
79// The lock ensures that access to the data is strictly serialized.
80// T must be Send-able to other threads.
81unsafe impl<T> Sync for VecRangeLock<T> where T: Send {}
82
83impl<'a, T> VecRangeLock<T> {
84    /// Construct a new [VecRangeLock].
85    ///
86    /// * `data`: The data [Vec] to protect.
87    pub fn new(data: Vec<T>) -> VecRangeLock<T> {
88        VecRangeLock {
89            ranges: Mutex::new(LockedRanges::new()),
90            data: UnsafeCell::new(data),
91        }
92    }
93
94    /// Get the length (in number of elements) of the embedded [Vec].
95    #[inline]
96    pub fn data_len(&self) -> usize {
97        // SAFETY: Multithreaded access is safe. len cannot change.
98        unsafe { (*self.data.get()).len() }
99    }
100
101    /// Unwrap this [VecRangeLock] into the contained data.
102    /// This method consumes self.
103    #[inline]
104    pub fn into_inner(self) -> Vec<T> {
105        debug_assert!(self.ranges.lock().unwrap().is_empty());
106        self.data.into_inner()
107    }
108
109    /// Try to lock the given data `range`.
110    ///
111    /// * On success: Returns a [VecRangeLockGuard] that can be used to access the locked region.
112    ///               Dereferencing [VecRangeLockGuard] yields a slice of the `data`.
113    /// * On failure: Returns [TryLockError::WouldBlock], if the range is contended.
114    ///               The locking attempt may be retried by the caller upon contention.
115    ///               Returns [TryLockError::Poisoned], if the lock is poisoned.
116    pub fn try_lock(
117        &'a self,
118        range: impl RangeBounds<usize>,
119    ) -> TryLockResult<VecRangeLockGuard<'a, T>> {
120        let data_len = self.data_len();
121        let (range_start, range_end) = get_bounds(&range, data_len);
122        if range_start >= data_len || range_end > data_len {
123            panic!("Range is out of bounds.");
124        }
125        if range_start > range_end {
126            panic!("Invalid range. Start is bigger than end.");
127        }
128        let range = range_start..range_end;
129
130        if range.is_empty() {
131            TryLockResult::Ok(VecRangeLockGuard::new(self, range))
132        } else if let LockResult::Ok(mut ranges) = self.ranges.lock() {
133            if ranges.insert(&range) {
134                TryLockResult::Ok(VecRangeLockGuard::new(self, range))
135            } else {
136                TryLockResult::Err(TryLockError::WouldBlock)
137            }
138        } else {
139            TryLockResult::Err(TryLockError::Poisoned(PoisonError::new(
140                VecRangeLockGuard::new(self, range),
141            )))
142        }
143    }
144
145    /// Unlock a range.
146    fn unlock(&self, range: &Range<usize>) {
147        if !range.is_empty() {
148            let mut ranges = self
149                .ranges
150                .lock()
151                .expect("VecRangeLock: Failed to take ranges mutex.");
152            ranges.remove(range);
153        }
154    }
155
156    /// Get an immutable slice to the specified range.
157    ///
158    /// # SAFETY
159    ///
160    /// See get_mut_slice().
161    #[inline]
162    unsafe fn get_slice(&self, range: &Range<usize>) -> &[T] {
163        // SAFETY: We trust the slicing machinery of Vec to work correctly.
164        //         It must return the slice range that we requested.
165        //         Otherwise our non-overlap guarantees are gone.
166        &(*self.data.get())[range.clone()]
167    }
168
169    /// Get a mutable slice to the specified range.
170    ///
171    /// # SAFETY
172    ///
173    /// The caller must ensure that:
174    /// * No overlapping slices must coexist on multiple threads.
175    /// * Immutable slices to overlapping ranges may only coexist on a single thread.
176    /// * Immutable and mutable slices must not coexist.
177    #[inline]
178    #[allow(clippy::mut_from_ref)] // Slices won't overlap. See SAFETY.
179    unsafe fn get_mut_slice(&self, range: &Range<usize>) -> &mut [T] {
180        let cptr = self.get_slice(range) as *const [T];
181        let mut_slice = (cptr as *mut [T]).as_mut();
182        // SAFETY: The pointer is never null, because it has been casted from a slice.
183        mut_slice.unwrap_or_else(|| unreachable_unchecked())
184    }
185}
186
/// Lock guard variable type for [VecRangeLock].
///
/// A guard is returned by [VecRangeLock::try_lock] and gives access to the locked range.
/// The [Deref] and [DerefMut] traits are implemented for this struct. They yield a slice
/// covering exactly the locked range, so index `0` of the guard is the first element of
/// that range, not of the whole data.
/// The range is unlocked again when the guard is dropped.
/// See the documentation of [VecRangeLock] for usage examples of [VecRangeLockGuard].
#[derive(Debug)]
pub struct VecRangeLockGuard<'a, T> {
    /// Reference to the underlying lock.
    lock: &'a VecRangeLock<T>,
    /// The locked range, in element indices of the underlying data.
    range: Range<usize>,

    /// Suppresses the Send and Sync autotraits for VecRangeLockGuard.
    /// `Rc<_>` is neither Send nor Sync, so this marker removes both, independent of `T`.
    /// (A `&mut T` alone would not do that: it is Sync whenever `T` is Sync.)
    #[allow(clippy::redundant_allocation)] // Type-level marker only; nothing is allocated.
    _p: PhantomData<Rc<&'a mut T>>,
}
203
204impl<'a, T> VecRangeLockGuard<'a, T> {
205    #[inline]
206    fn new(lock: &'a VecRangeLock<T>, range: Range<usize>) -> VecRangeLockGuard<'a, T> {
207        VecRangeLockGuard {
208            lock,
209            range,
210            _p: PhantomData,
211        }
212    }
213}
214
215impl<'a, T> Drop for VecRangeLockGuard<'a, T> {
216    #[inline]
217    fn drop(&mut self) {
218        self.lock.unlock(&self.range);
219    }
220}
221
222impl<'a, T> Deref for VecRangeLockGuard<'a, T> {
223    type Target = [T];
224
225    #[inline]
226    fn deref(&self) -> &Self::Target {
227        // SAFETY: See deref_mut().
228        unsafe { self.lock.get_slice(&self.range) }
229    }
230}
231
232impl<'a, T> DerefMut for VecRangeLockGuard<'a, T> {
233    #[inline]
234    fn deref_mut(&mut self) -> &mut Self::Target {
235        // SAFETY:
236        // The lifetime of the slice is bounded by the lifetime of the guard.
237        // The lifetime of the guard is bounded by the lifetime of the range lock.
238        // The underlying data is owned by the range lock.
239        // Therefore the slice cannot outlive the data.
240        // The range lock ensures that no overlapping/conflicting guards
241        // can be constructed.
242        // The compiler ensures that the DerefMut result cannot be used,
243        // if there's also an immutable Deref result.
244        unsafe { self.lock.get_mut_slice(&self.range) }
245    }
246}
247
#[cfg(test)]
mod tests {
    use super::*;
    use std::cell::RefCell;
    use std::sync::{Arc, Barrier};
    use std::thread;

    /// Fresh lock around `[1, 2, 3, 4, 5, 6]`, the fixture most tests use.
    fn sample() -> VecRangeLock<i32> {
        VecRangeLock::new(vec![1_i32, 2, 3, 4, 5, 6])
    }

    /// True if `lock` currently has no range registered.
    fn nothing_locked<T>(lock: &VecRangeLock<T>) -> bool {
        lock.ranges.lock().unwrap().is_empty()
    }

    #[test]
    fn test_base() {
        // Range: registered while the guard lives, released when it drops.
        let a = sample();
        {
            let mut g = a.try_lock(2..4).unwrap();
            assert!(!nothing_locked(&a));
            assert_eq!(g[0..2], [3, 4]);
            g[1] = 10;
            assert_eq!(g[0..2], [3, 10]);
        }
        assert!(nothing_locked(&a));

        // All other RangeBounds flavours, each against a fresh lock.
        assert_eq!(sample().try_lock(2..=4).unwrap()[0..3], [3, 4, 5]); // RangeInclusive
        assert_eq!(sample().try_lock(..4).unwrap()[0..4], [1, 2, 3, 4]); // RangeTo
        assert_eq!(sample().try_lock(..=4).unwrap()[0..5], [1, 2, 3, 4, 5]); // RangeToInclusive
        assert_eq!(sample().try_lock(2..).unwrap()[0..4], [3, 4, 5, 6]); // RangeFrom
        assert_eq!(sample().try_lock(..).unwrap()[0..6], [1, 2, 3, 4, 5, 6]); // RangeFull
    }

    #[test]
    fn test_empty_range() {
        // Empty ranges never register anything, so they never conflict.
        let a = sample();
        let mut held = Vec::new();
        for _ in 0..2 {
            let g = a.try_lock(2..2).unwrap();
            assert!(nothing_locked(&a));
            assert_eq!(g[0..0], []);
            // Keep each guard alive while the next one is taken.
            held.push(g);
        }
    }

    #[test]
    #[should_panic(expected = "index out of bounds")]
    fn test_base_oob_read() {
        let a = sample();
        let g = a.try_lock(2..4).unwrap();
        let _ = g[2];
    }

    #[test]
    #[should_panic(expected = "index out of bounds")]
    fn test_base_oob_write() {
        let a = sample();
        let mut g = a.try_lock(2..4).unwrap();
        g[2] = 10;
    }

    #[test]
    #[should_panic(expected = "guard 1 panicked")]
    fn test_overlap0() {
        let a = sample();
        let _first = a.try_lock(2..4).expect("guard 0 panicked");
        let _second = a.try_lock(3..5).expect("guard 1 panicked");
    }

    #[test]
    #[should_panic(expected = "guard 0 panicked")]
    fn test_overlap1() {
        let a = sample();
        let _first = a.try_lock(3..5).expect("guard 1 panicked");
        let _second = a.try_lock(2..4).expect("guard 0 panicked");
    }

    #[test]
    fn test_thread_no_overlap() {
        let shared = Arc::new(sample());
        let lock0 = Arc::clone(&shared);
        let lock1 = Arc::clone(&shared);
        let sync0 = Arc::new(Barrier::new(2));
        let sync1 = Arc::clone(&sync0);

        let writer = thread::spawn(move || {
            {
                let mut g = lock0.try_lock(2..4).unwrap();
                assert!(!nothing_locked(&lock0));
                assert_eq!(g[0..2], [3, 4]);
                g[1] = 10;
                assert_eq!(g[0..2], [3, 10]);
            }
            sync0.wait();
        });
        let reader = thread::spawn(move || {
            {
                let g = lock1.try_lock(4..6).unwrap();
                assert!(!nothing_locked(&lock1));
                assert_eq!(g[0..2], [5, 6]);
            }
            // After the barrier the writer has released 2..4 and stored 10 at index 3.
            sync1.wait();
            let g = lock1.try_lock(3..5).unwrap();
            assert_eq!(g[0..2], [10, 5]);
        });

        reader.join().expect("Thread 1 panicked.");
        writer.join().expect("Thread 0 panicked.");
        assert!(nothing_locked(&shared));
    }

    #[allow(dead_code)] // The RefCell is never read; it only makes the type !Sync.
    struct NoSyncStruct(RefCell<u32>); // No Sync auto-trait.

    #[test]
    fn test_nosync() {
        let shared = Arc::new(VecRangeLock::new(
            (1..=4)
                .map(|v| NoSyncStruct(RefCell::new(v)))
                .collect::<Vec<_>>(),
        ));
        let lock0 = Arc::clone(&shared);
        let lock1 = Arc::clone(&shared);
        let sync0 = Arc::new(Barrier::new(2));
        let sync1 = Arc::clone(&sync0);

        let first = thread::spawn(move || {
            let _g = lock0.try_lock(0..1).unwrap();
            assert!(!nothing_locked(&lock0));
            sync0.wait();
        });
        let second = thread::spawn(move || {
            let _g = lock1.try_lock(1..2).unwrap();
            assert!(!nothing_locked(&lock1));
            sync1.wait();
        });

        second.join().expect("Thread 1 panicked.");
        first.join().expect("Thread 0 panicked.");
        assert!(nothing_locked(&shared));
    }
}
406
407// vim: ts=4 sw=4 expandtab