use crate::{lockedranges::LockedRanges, util::get_bounds};

use std::{
    cell::UnsafeCell,
    hint::unreachable_unchecked,
    marker::PhantomData,
    ops::{Deref, DerefMut, Range, RangeBounds},
    rc::Rc,
    sync::{LockResult, Mutex, PoisonError, TryLockError, TryLockResult},
};

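/// A lock over a `Vec<T>` that hands out independent locks for index
/// ranges instead of locking the whole vector.
///
/// Guards for disjoint, non-overlapping ranges can be held and used
/// concurrently by different threads; an attempt to lock an overlapping
/// range fails with `TryLockError::WouldBlock`.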
#[derive(Debug)]
pub struct VecRangeLock<T> {
    /// The set of currently locked index ranges.
    ranges: Mutex<LockedRanges>,
    /// The protected data.
    data: UnsafeCell<Vec<T>>,
}

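// SAFETY: Guards are only handed out for non-overlapping ranges, so every
// element is reachable through at most one guard at a time. Sharing the
// lock between threads therefore only requires `T: Send`.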
unsafe impl<T> Sync for VecRangeLock<T> where T: Send {}

impl<'a, T> VecRangeLock<T> {
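    /// Creates a new `VecRangeLock` wrapping the given vector.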
    pub fn new(data: Vec<T>) -> VecRangeLock<T> {
        VecRangeLock {
            ranges: Mutex::new(LockedRanges::new()),
            data: UnsafeCell::new(data),
        }
    }

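    /// Returns the length of the wrapped vector.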
    #[inline]
    pub fn data_len(&self) -> usize {
        unsafe { (*self.data.get()).len() }
    }

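    /// Consumes the lock and returns the wrapped vector.
    ///
    /// No range must be locked when this is called; this is checked with a
    /// `debug_assert`.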
    #[inline]
    pub fn into_inner(self) -> Vec<T> {
        debug_assert!(self.ranges.lock().unwrap().is_empty());
        self.data.into_inner()
    }

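    /// Attempts to lock the index range `range` without blocking.
    ///
    /// On success, returns a [`VecRangeLockGuard`] that dereferences to the
    /// locked slice. If `range` overlaps an already locked range,
    /// `TryLockError::WouldBlock` is returned. Empty ranges always succeed.
    ///
    /// # Panics
    ///
    /// Panics if the range is out of bounds for the data, or if its start is
    /// greater than its end.
    ///
    /// A minimal usage sketch (the `range_lock` crate name in the import is
    /// an assumption; adjust it to the actual crate path):
    ///
    /// ```ignore
    /// use range_lock::VecRangeLock;
    ///
    /// let lock = VecRangeLock::new(vec![1, 2, 3, 4]);
    /// {
    ///     let mut guard = lock.try_lock(0..2).expect("range is free");
    ///     guard[0] = 10; // Elements 0..2 are exclusively ours.
    /// }
    /// // The guard has been dropped, so the range can be locked again.
    /// assert!(lock.try_lock(0..2).is_ok());
    /// ```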
    pub fn try_lock(
        &'a self,
        range: impl RangeBounds<usize>,
    ) -> TryLockResult<VecRangeLockGuard<'a, T>> {
        let data_len = self.data_len();
        let (range_start, range_end) = get_bounds(&range, data_len);
        if range_start >= data_len || range_end > data_len {
            panic!("Range is out of bounds.");
        }
        if range_start > range_end {
            panic!("Invalid range. Start is bigger than end.");
        }
        let range = range_start..range_end;

        if range.is_empty() {
            // An empty range cannot conflict with anything.
            // Hand out a guard without touching the range set.
            TryLockResult::Ok(VecRangeLockGuard::new(self, range))
        } else if let LockResult::Ok(mut ranges) = self.ranges.lock() {
            if ranges.insert(&range) {
                TryLockResult::Ok(VecRangeLockGuard::new(self, range))
            } else {
                // The range overlaps an already locked range.
                TryLockResult::Err(TryLockError::WouldBlock)
            }
        } else {
            // The ranges mutex is poisoned.
            TryLockResult::Err(TryLockError::Poisoned(PoisonError::new(
                VecRangeLockGuard::new(self, range),
            )))
        }
    }

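    /// Releases the lock on `range`. Called by `VecRangeLockGuard::drop`.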
    fn unlock(&self, range: &Range<usize>) {
        if !range.is_empty() {
            let mut ranges = self
                .ranges
                .lock()
                .expect("VecRangeLock: Failed to take ranges mutex.");
            ranges.remove(range);
        }
    }

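    /// Returns a shared slice covering `range`.
    ///
    /// # Safety
    ///
    /// The caller must hold the lock for `range`.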
    #[inline]
    unsafe fn get_slice(&self, range: &Range<usize>) -> &[T] {
        &(*self.data.get())[range.clone()]
    }

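    /// Returns a mutable slice covering `range`.
    ///
    /// # Safety
    ///
    /// The caller must hold the lock for `range`, and no other slice of
    /// `range` may be alive.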
    #[inline]
    #[allow(clippy::mut_from_ref)]
    unsafe fn get_mut_slice(&self, range: &Range<usize>) -> &mut [T] {
        // Cast the shared slice to a mutable one. The pointer is never null,
        // so `as_mut()` always returns `Some`.
        let cptr = self.get_slice(range) as *const [T];
        let mut_slice = (cptr as *mut [T]).as_mut();
        mut_slice.unwrap_or_else(|| unreachable_unchecked())
    }
}

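/// A guard for a locked range of a [`VecRangeLock`].
///
/// The guard dereferences to the locked slice (`[T]`) and releases the
/// range when it is dropped.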
#[derive(Debug)]
pub struct VecRangeLockGuard<'a, T> {
    /// The lock this guard belongs to.
    lock: &'a VecRangeLock<T>,
    /// The locked range.
    range: Range<usize>,
    /// `Rc<&mut T>` is neither `Send` nor `Sync`, so this marker keeps the
    /// guard from being sent to or shared with other threads.
    #[allow(clippy::redundant_allocation)]
    _p: PhantomData<Rc<&'a mut T>>,
}

impl<'a, T> VecRangeLockGuard<'a, T> {
    #[inline]
    fn new(lock: &'a VecRangeLock<T>, range: Range<usize>) -> VecRangeLockGuard<'a, T> {
        VecRangeLockGuard {
            lock,
            range,
            _p: PhantomData,
        }
    }
}

impl<'a, T> Drop for VecRangeLockGuard<'a, T> {
    #[inline]
    fn drop(&mut self) {
        self.lock.unlock(&self.range);
    }
}

impl<'a, T> Deref for VecRangeLockGuard<'a, T> {
    type Target = [T];

    #[inline]
    fn deref(&self) -> &Self::Target {
        unsafe { self.lock.get_slice(&self.range) }
    }
}

impl<'a, T> DerefMut for VecRangeLockGuard<'a, T> {
    #[inline]
    fn deref_mut(&mut self) -> &mut Self::Target {
        unsafe { self.lock.get_mut_slice(&self.range) }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::cell::RefCell;
    use std::sync::{Arc, Barrier};
    use std::thread;

    #[test]
    fn test_base() {
        {
            let a = VecRangeLock::new(vec![1_i32, 2, 3, 4, 5, 6]);
            {
                let mut g = a.try_lock(2..4).unwrap();
                assert!(!a.ranges.lock().unwrap().is_empty());
                assert_eq!(g[0..2], [3, 4]);
                g[1] = 10;
                assert_eq!(g[0..2], [3, 10]);
            }
            assert!(a.ranges.lock().unwrap().is_empty());
        }
        {
            let a = VecRangeLock::new(vec![1_i32, 2, 3, 4, 5, 6]);
            let g = a.try_lock(2..=4).unwrap();
            assert_eq!(g[0..3], [3, 4, 5]);
        }
        {
            let a = VecRangeLock::new(vec![1_i32, 2, 3, 4, 5, 6]);
            let g = a.try_lock(..4).unwrap();
            assert_eq!(g[0..4], [1, 2, 3, 4]);
        }
        {
            let a = VecRangeLock::new(vec![1_i32, 2, 3, 4, 5, 6]);
            let g = a.try_lock(..=4).unwrap();
            assert_eq!(g[0..5], [1, 2, 3, 4, 5]);
        }
        {
            let a = VecRangeLock::new(vec![1_i32, 2, 3, 4, 5, 6]);
            let g = a.try_lock(2..).unwrap();
            assert_eq!(g[0..4], [3, 4, 5, 6]);
        }
        {
            let a = VecRangeLock::new(vec![1_i32, 2, 3, 4, 5, 6]);
            let g = a.try_lock(..).unwrap();
            assert_eq!(g[0..6], [1, 2, 3, 4, 5, 6]);
        }
    }

    #[test]
    fn test_empty_range() {
        let a = VecRangeLock::new(vec![1_i32, 2, 3, 4, 5, 6]);
        let g0 = a.try_lock(2..2).unwrap();
        assert!(a.ranges.lock().unwrap().is_empty());
        assert_eq!(g0[0..0], []);
        let g1 = a.try_lock(2..2).unwrap();
        assert!(a.ranges.lock().unwrap().is_empty());
        assert_eq!(g1[0..0], []);
    }

    #[test]
    #[should_panic(expected = "index out of bounds")]
    fn test_base_oob_read() {
        let a = VecRangeLock::new(vec![1_i32, 2, 3, 4, 5, 6]);
        let g = a.try_lock(2..4).unwrap();
        let _ = g[2];
    }

    #[test]
    #[should_panic(expected = "index out of bounds")]
    fn test_base_oob_write() {
        let a = VecRangeLock::new(vec![1_i32, 2, 3, 4, 5, 6]);
        let mut g = a.try_lock(2..4).unwrap();
        g[2] = 10;
    }

    #[test]
    #[should_panic(expected = "guard 1 panicked")]
    fn test_overlap0() {
        let a = VecRangeLock::new(vec![1_i32, 2, 3, 4, 5, 6]);
        let _g0 = a.try_lock(2..4).expect("guard 0 panicked");
        let _g1 = a.try_lock(3..5).expect("guard 1 panicked");
    }

    #[test]
    #[should_panic(expected = "guard 0 panicked")]
    fn test_overlap1() {
        let a = VecRangeLock::new(vec![1_i32, 2, 3, 4, 5, 6]);
        let _g1 = a.try_lock(3..5).expect("guard 1 panicked");
        let _g0 = a.try_lock(2..4).expect("guard 0 panicked");
    }

    #[test]
    fn test_thread_no_overlap() {
        let a = Arc::new(VecRangeLock::new(vec![1_i32, 2, 3, 4, 5, 6]));
        let b = Arc::clone(&a);
        let c = Arc::clone(&a);
        let ba0 = Arc::new(Barrier::new(2));
        let ba1 = Arc::clone(&ba0);
        let j0 = thread::spawn(move || {
            {
                let mut g = b.try_lock(2..4).unwrap();
                assert!(!b.ranges.lock().unwrap().is_empty());
                assert_eq!(g[0..2], [3, 4]);
                g[1] = 10;
                assert_eq!(g[0..2], [3, 10]);
            }
            ba0.wait();
        });
        let j1 = thread::spawn(move || {
            {
                let g = c.try_lock(4..6).unwrap();
                assert!(!c.ranges.lock().unwrap().is_empty());
                assert_eq!(g[0..2], [5, 6]);
            }
            ba1.wait();
            let g = c.try_lock(3..5).unwrap();
            assert_eq!(g[0..2], [10, 5]);
        });
        j1.join().expect("Thread 1 panicked.");
        j0.join().expect("Thread 0 panicked.");
        assert!(a.ranges.lock().unwrap().is_empty());
    }

    /// A type that is `Send` but not `Sync` (`RefCell` is not `Sync`).
    struct NoSyncStruct(RefCell<u32>);

    #[test]
    fn test_nosync() {
        let a = Arc::new(VecRangeLock::new(vec![
            NoSyncStruct(RefCell::new(1)),
            NoSyncStruct(RefCell::new(2)),
            NoSyncStruct(RefCell::new(3)),
            NoSyncStruct(RefCell::new(4)),
        ]));
        let b = Arc::clone(&a);
        let c = Arc::clone(&a);
        let ba0 = Arc::new(Barrier::new(2));
        let ba1 = Arc::clone(&ba0);
        let j0 = thread::spawn(move || {
            let _g = b.try_lock(0..1).unwrap();
            assert!(!b.ranges.lock().unwrap().is_empty());
            ba0.wait();
        });
        let j1 = thread::spawn(move || {
            let _g = c.try_lock(1..2).unwrap();
            assert!(!c.ranges.lock().unwrap().is_empty());
            ba1.wait();
        });
        j1.join().expect("Thread 1 panicked.");
        j0.join().expect("Thread 0 panicked.");
        assert!(a.ranges.lock().unwrap().is_empty());
    }
}