range_lock/
rangelock.rs

1// -*- coding: utf-8 -*-
2//
3// Copyright 2021-2025 Michael Büsch <m@bues.ch>
4//
5// Licensed under the Apache License version 2.0
6// or the MIT license, at your option.
7// SPDX-License-Identifier: Apache-2.0 OR MIT
8//
9
10use crate::{lockedranges::LockedRanges, util::get_bounds, vecparts::VecParts};
11use std::{
12    cell::UnsafeCell,
13    marker::PhantomData,
14    mem::size_of,
15    ops::{Deref, DerefMut, Range, RangeBounds},
16    slice,
17    sync::{LockResult, Mutex, PoisonError, TryLockError, TryLockResult},
18};
19
/// General purpose multi-thread range lock for [std::vec::Vec].
///
/// The lock owns a `Vec<T>` and hands out [VecRangeLockGuard]s for sub-ranges of it.
/// Each guard dereferences to the slice covering its range. Guards for non-overlapping
/// ranges may be held at the same time, also on different threads, while an attempt to
/// lock a range that overlaps a currently locked one is rejected.
///
/// Locking never blocks: [VecRangeLock::try_lock] either succeeds right away or reports
/// contention, and the caller decides whether to retry. Empty ranges never conflict with
/// anything. A locked range is released again when its guard is dropped.
///
/// # Example
///
/// ```
/// use range_lock::VecRangeLock;
/// use std::{sync::{Arc, Barrier}, thread};
///
/// let data = vec![10, 11, 12, 13];
///
/// let data_lock0 = Arc::new(VecRangeLock::new(data));
/// let data_lock1 = Arc::clone(&data_lock0);
/// let data_lock2 = Arc::clone(&data_lock0);
///
/// // Thread barrier, only for demonstration purposes.
/// let barrier0 = Arc::new(Barrier::new(2));
/// let barrier1 = Arc::clone(&barrier0);
///
/// thread::scope(|s| {
///     s.spawn(move || {
///         {
///             let mut guard = data_lock0.try_lock(0..2).expect("T0: Failed to lock 0..2");
///             guard[0] = 100; // Write to data[0]
///         }
///         barrier0.wait(); // Synchronize with second thread.
///         {
///             let guard = data_lock0.try_lock(2..4).expect("T0: Failed to lock 2..4");
///             assert_eq!(guard[0], 200); // Read from data[2]
///         }
///     });
///
///     s.spawn(move || {
///         {
///             let mut guard = data_lock1.try_lock(2..4).expect("T1: Failed to lock 2..4");
///             guard[0] = 200; // Write to data[2]
///         }
///         barrier1.wait(); // Synchronize with first thread.
///         {
///             let guard = data_lock1.try_lock(0..2).expect("T1: Failed to lock 0..2");
///             assert_eq!(guard[0], 100); // Read from data[0]
///         }
///     });
/// });
///
/// let data = Arc::try_unwrap(data_lock2).expect("Arc unwrap failed").into_inner();
///
/// assert_eq!(data, vec![100, 11, 200, 13]);
/// ```
#[derive(Debug)]
pub struct VecRangeLock<T> {
    /// Set of the currently locked (non-empty) ranges.
    /// This mutex is only taken while a range is being locked or unlocked,
    /// never for the whole lifetime of a guard.
    ranges: Mutex<LockedRanges>,
    /// The underlying data.
    /// Guards access the elements through a raw element pointer taken from it,
    /// which is why it lives in an `UnsafeCell`.
    data: UnsafeCell<VecParts<T>>,
}
75
76// SAFETY:
77// It is safe to access VecRangeLock and the contained data (via VecRangeLockGuard)
78// from multiple threads simultaneously.
79// The lock ensures that access to the data is strictly serialized.
80// T must be Send-able to other threads.
81unsafe impl<T> Sync for VecRangeLock<T> where T: Send {}
82
83impl<'a, T> VecRangeLock<T> {
84    /// Construct a new [VecRangeLock].
85    ///
86    /// * `data`: The data [Vec] to protect.
87    pub fn new(data: Vec<T>) -> VecRangeLock<T> {
88        VecRangeLock {
89            ranges: Mutex::new(LockedRanges::new()),
90            data: UnsafeCell::new(data.into()),
91        }
92    }
93
94    /// Get the length (in number of elements) of the embedded [Vec].
95    #[inline]
96    pub fn data_len(&self) -> usize {
97        // SAFETY: The UnsafeCell content it always valid.
98        unsafe { (*self.data.get()).len() }
99    }
100
101    /// Unwrap this [VecRangeLock] into the contained data.
102    /// This method consumes self.
103    #[inline]
104    pub fn into_inner(self) -> Vec<T> {
105        debug_assert!(self.ranges.lock().unwrap().is_empty());
106        self.data.into_inner().into()
107    }
108
109    /// Try to lock the given data `range`.
110    ///
111    /// * On success: Returns a [VecRangeLockGuard] that can be used to access the locked region.
112    ///   Dereferencing [VecRangeLockGuard] yields a slice of the `data`.
113    /// * On failure: Returns [TryLockError::WouldBlock], if the range is contended.
114    ///   The locking attempt may be retried by the caller upon contention.
115    ///   Returns [TryLockError::Poisoned], if the lock is poisoned.
116    pub fn try_lock(
117        &'a self,
118        range: impl RangeBounds<usize>,
119    ) -> TryLockResult<VecRangeLockGuard<'a, T>> {
120        let data_len = self.data_len();
121        let (range_start, range_end) = get_bounds(&range, data_len);
122        if range_start >= data_len || range_end > data_len {
123            panic!("Range is out of bounds.");
124        }
125        if range_start > range_end {
126            panic!("Invalid range. Start is bigger than end.");
127        }
128        let range = range_start..range_end;
129
130        if range.is_empty() {
131            TryLockResult::Ok(VecRangeLockGuard::new(self, range))
132        } else if let LockResult::Ok(mut ranges) = self.ranges.lock() {
133            if ranges.insert(&range) {
134                TryLockResult::Ok(VecRangeLockGuard::new(self, range))
135            } else {
136                TryLockResult::Err(TryLockError::WouldBlock)
137            }
138        } else {
139            TryLockResult::Err(TryLockError::Poisoned(PoisonError::new(
140                VecRangeLockGuard::new(self, range),
141            )))
142        }
143    }
144
145    /// Unlock a range.
146    fn unlock(&self, range: &Range<usize>) {
147        if !range.is_empty() {
148            let mut ranges = self
149                .ranges
150                .lock()
151                .expect("VecRangeLock: Failed to take ranges mutex.");
152            ranges.remove(range);
153        }
154    }
155
156    /// Get an immutable slice to the specified range.
157    ///
158    /// # SAFETY
159    ///
160    /// See get_mut_slice().
161    #[inline]
162    unsafe fn get_slice(&self, range: &Range<usize>) -> &[T] {
163        let data = (*self.data.get()).ptr();
164        assert!(range.start <= isize::MAX as usize / size_of::<T>());
165        // SAFETY: The caller is responsible for passing a range that results in a valid slice
166        // and isize overflow has been checked here.
167        unsafe { slice::from_raw_parts(data.add(range.start) as _, range.end - range.start) }
168    }
169
170    /// Get a mutable slice to the specified range.
171    ///
172    /// # SAFETY
173    ///
174    /// The caller must ensure that:
175    /// * No overlapping slices must coexist on multiple threads.
176    /// * Immutable slices to overlapping ranges may only coexist on a single thread.
177    /// * Immutable and mutable slices must not coexist.
178    #[inline]
179    #[allow(clippy::mut_from_ref)]
180    unsafe fn get_mut_slice(&self, range: &Range<usize>) -> &mut [T] {
181        let data = (*self.data.get()).ptr();
182        assert!(range.start <= isize::MAX as usize / size_of::<T>());
183        // SAFETY: The caller is responsible for passing a range that results in a valid slice
184        // and isize overflow has been checked here.
185        unsafe { slice::from_raw_parts_mut(data.add(range.start) as _, range.end - range.start) }
186    }
187}
188
189/// Lock guard variable type for [VecRangeLock].
190///
191/// The [Deref] and [DerefMut] traits are implemented for this struct.
192/// See the documentation of [VecRangeLock] for usage examples of [VecRangeLockGuard].
193#[derive(Debug)]
194pub struct VecRangeLockGuard<'a, T> {
195    /// Reference to the underlying lock.
196    lock: &'a VecRangeLock<T>,
197    /// The locked range.
198    range: Range<usize>,
199
200    /// Suppresses Send and Sync autotraits for VecRangeLockGuard.
201    _p: PhantomData<*mut T>,
202}
203
204impl<'a, T> VecRangeLockGuard<'a, T> {
205    #[inline]
206    fn new(lock: &'a VecRangeLock<T>, range: Range<usize>) -> VecRangeLockGuard<'a, T> {
207        VecRangeLockGuard {
208            lock,
209            range,
210            _p: PhantomData,
211        }
212    }
213}
214
215impl<T> Drop for VecRangeLockGuard<'_, T> {
216    #[inline]
217    fn drop(&mut self) {
218        self.lock.unlock(&self.range);
219    }
220}
221
222impl<T> Deref for VecRangeLockGuard<'_, T> {
223    type Target = [T];
224
225    #[inline]
226    fn deref(&self) -> &Self::Target {
227        // SAFETY: See deref_mut().
228        unsafe { self.lock.get_slice(&self.range) }
229    }
230}
231
232impl<T> DerefMut for VecRangeLockGuard<'_, T> {
233    #[inline]
234    fn deref_mut(&mut self) -> &mut Self::Target {
235        // SAFETY:
236        // The lifetime of the slice is bounded by the lifetime of the guard.
237        // The lifetime of the guard is bounded by the lifetime of the range lock.
238        // The underlying data is owned by the range lock.
239        // Therefore the slice cannot outlive the data.
240        // The range lock ensures that no overlapping/conflicting guards
241        // can be constructed.
242        // The compiler ensures that the DerefMut result cannot be used,
243        // if there's also an immutable Deref result.
244        unsafe { self.lock.get_mut_slice(&self.range) }
245    }
246}
247
#[cfg(test)]
mod tests {
    use super::*;
    use std::cell::RefCell;
    use std::ops::RangeBounds;
    use std::sync::{Arc, Barrier};
    use std::thread;

    /// Fresh six-element lock used as the fixture by most tests.
    fn sample() -> VecRangeLock<i32> {
        VecRangeLock::new(vec![1, 2, 3, 4, 5, 6])
    }

    /// Whether `lock` currently has no range registered as locked.
    fn nothing_locked<T>(lock: &VecRangeLock<T>) -> bool {
        lock.ranges.lock().unwrap().is_empty()
    }

    /// Lock `range` on a fresh fixture and compare the first `expected.len()` elements.
    fn assert_locked_prefix(range: impl RangeBounds<usize>, expected: &[i32]) {
        let lock = sample();
        let guard = lock.try_lock(range).unwrap();
        assert_eq!(guard[..expected.len()], *expected);
    }

    #[test]
    fn test_base() {
        // Range, including the release of the lock when the guard leaves its scope.
        let lock = sample();
        {
            let mut guard = lock.try_lock(2..4).unwrap();
            assert!(!nothing_locked(&lock));
            assert_eq!(guard[0..2], [3, 4]);
            guard[1] = 10;
            assert_eq!(guard[0..2], [3, 10]);
        }
        assert!(nothing_locked(&lock));

        // The remaining RangeBounds flavours, each on its own fixture.
        assert_locked_prefix(2..=4, &[3, 4, 5]); // RangeInclusive
        assert_locked_prefix(..4, &[1, 2, 3, 4]); // RangeTo
        assert_locked_prefix(..=4, &[1, 2, 3, 4, 5]); // RangeToInclusive
        assert_locked_prefix(2.., &[3, 4, 5, 6]); // RangeFrom
        assert_locked_prefix(.., &[1, 2, 3, 4, 5, 6]); // RangeFull
    }

    #[test]
    fn test_empty_range() {
        // Empty ranges are never registered, so they never conflict with each other.
        let lock = sample();
        let first = lock.try_lock(2..2).unwrap();
        assert!(nothing_locked(&lock));
        assert_eq!(first[0..0], []);
        let second = lock.try_lock(2..2).unwrap();
        assert!(nothing_locked(&lock));
        assert_eq!(second[0..0], []);
    }

    #[test]
    #[should_panic(expected = "index out of bounds")]
    fn test_base_oob_read() {
        let lock = sample();
        let guard = lock.try_lock(2..4).unwrap();
        let _ = guard[2];
    }

    #[test]
    #[should_panic(expected = "index out of bounds")]
    fn test_base_oob_write() {
        let lock = sample();
        let mut guard = lock.try_lock(2..4).unwrap();
        guard[2] = 10;
    }

    #[test]
    #[should_panic(expected = "guard 1 panicked")]
    fn test_overlap0() {
        let lock = sample();
        let _first = lock.try_lock(2..4).expect("guard 0 panicked");
        let _second = lock.try_lock(3..5).expect("guard 1 panicked");
    }

    #[test]
    #[should_panic(expected = "guard 0 panicked")]
    fn test_overlap1() {
        let lock = sample();
        let _first = lock.try_lock(3..5).expect("guard 1 panicked");
        let _second = lock.try_lock(2..4).expect("guard 0 panicked");
    }

    #[test]
    fn test_thread_no_overlap() {
        let lock = Arc::new(sample());
        let barrier = Arc::new(Barrier::new(2));

        let writer = {
            let lock = Arc::clone(&lock);
            let barrier = Arc::clone(&barrier);
            thread::spawn(move || {
                {
                    let mut guard = lock.try_lock(2..4).unwrap();
                    assert!(!nothing_locked(&lock));
                    assert_eq!(guard[0..2], [3, 4]);
                    guard[1] = 10;
                    assert_eq!(guard[0..2], [3, 10]);
                }
                barrier.wait();
            })
        };

        let reader = {
            let lock = Arc::clone(&lock);
            let barrier = Arc::clone(&barrier);
            thread::spawn(move || {
                {
                    let guard = lock.try_lock(4..6).unwrap();
                    assert!(!nothing_locked(&lock));
                    assert_eq!(guard[0..2], [5, 6]);
                }
                barrier.wait();
                // The writer released 2..4 before it reached the barrier.
                let guard = lock.try_lock(3..5).unwrap();
                assert_eq!(guard[0..2], [10, 5]);
            })
        };

        reader.join().expect("Thread 1 panicked.");
        writer.join().expect("Thread 0 panicked.");
        assert!(nothing_locked(&lock));
    }

    // The field is never read; the type only exists to lack the Sync auto-trait.
    #[allow(dead_code)]
    struct NoSyncStruct(RefCell<u32>); // No Sync auto-trait.

    #[test]
    fn test_nosync() {
        let items: Vec<NoSyncStruct> = (1..=4).map(|v| NoSyncStruct(RefCell::new(v))).collect();
        let lock = Arc::new(VecRangeLock::new(items));
        let barrier = Arc::new(Barrier::new(2));

        // Each thread keeps a one-element guard while waiting for the other thread,
        // so both guards are alive at the same time on different threads.
        let spawn_holder = |index: usize| {
            let lock = Arc::clone(&lock);
            let barrier = Arc::clone(&barrier);
            thread::spawn(move || {
                let _guard = lock.try_lock(index..index + 1).unwrap();
                assert!(!nothing_locked(&lock));
                barrier.wait();
            })
        };
        let first = spawn_holder(0);
        let second = spawn_holder(1);

        second.join().expect("Thread 1 panicked.");
        first.join().expect("Thread 0 panicked.");
        assert!(nothing_locked(&lock));
    }
}
407
408// vim: ts=4 sw=4 expandtab