use crate::util::{
get_bounds,
overlaps_any,
};
use std::{
cell::UnsafeCell,
collections::HashSet,
ops::{
Deref,
DerefMut,
Range,
RangeBounds,
},
sync::{
LockResult,
Mutex,
PoisonError,
TryLockError,
TryLockResult,
}
};
/// A vector whose elements can be locked for exclusive access in disjoint
/// index ranges, so that different threads can mutate different parts of it
/// at the same time.
///
/// Every non-empty range handed out by [`RangeLock::try_lock`] is recorded in
/// `ranges` until its [`RangeLockGuard`] is dropped; requests that overlap a
/// recorded range are refused.
#[derive(Debug)]
pub struct RangeLock<T> {
    /// Non-empty ranges currently held by live guards.
    ranges: Mutex<HashSet<Range<usize>>>,
    /// The elements, each wrapped in its own `UnsafeCell` so that guards on
    /// disjoint ranges can mutate them through a shared `&RangeLock` without
    /// any reference ever covering elements owned by another guard.
    data: Vec<UnsafeCell<T>>,
}

// SAFETY: a guard only touches the elements of its own range, and `try_lock`
// never hands out overlapping non-empty ranges, so each element is accessed
// by at most one thread at a time. As with `Mutex<T>`, that only requires
// moving `T` between threads, hence the `T: Send` bound (without it, e.g. an
// `Rc` could be mutated from two threads).
unsafe impl<T: Send> Sync for RangeLock<T> {}

impl<T> RangeLock<T> {
    /// Creates a lock over `data` with no ranges held.
    pub fn new(data: Vec<T>) -> RangeLock<T> {
        RangeLock {
            ranges: Mutex::new(HashSet::new()),
            data: data.into_iter().map(UnsafeCell::new).collect(),
        }
    }

    /// Number of elements in the protected vector.
    #[inline]
    fn data_len(&self) -> usize {
        self.data.len()
    }

    /// Consumes the lock and returns the underlying vector.
    pub fn into_inner(self) -> Vec<T> {
        let RangeLock { ranges, data } = self;
        // Guards borrow `self`, so none can be alive once it is moved here.
        debug_assert!(ranges
            .into_inner()
            .unwrap_or_else(PoisonError::into_inner)
            .is_empty());
        data.into_iter().map(UnsafeCell::into_inner).collect()
    }

    /// Tries to lock `range` for exclusive access without blocking.
    ///
    /// The returned guard dereferences to the sub-slice `data[range]`,
    /// re-indexed from 0. Empty ranges always succeed and are not recorded,
    /// since they cannot conflict with anything.
    ///
    /// # Errors
    ///
    /// * `TryLockError::WouldBlock` if `range` overlaps a range held by a live guard.
    /// * `TryLockError::Poisoned` if the internal bookkeeping mutex was poisoned.
    ///   The contained guard has still been checked and recorded, so recovering
    ///   it with `PoisonError::into_inner` is safe.
    ///
    /// # Panics
    ///
    /// Panics if `range` extends past the end of the data, or if its start is
    /// greater than its end.
    pub fn try_lock<'a>(
        &'a self,
        range: impl RangeBounds<usize>,
    ) -> TryLockResult<RangeLockGuard<'a, T>> {
        let data_len = self.data_len();
        let (range_start, range_end) = get_bounds(&range, data_len);
        // `start == data_len` is accepted so that an empty range at the very end
        // (e.g. `len..len`, or `..` on an empty vector) works, like slice indexing.
        if range_start > data_len || range_end > data_len {
            panic!("Range is out of bounds.");
        }
        if range_start > range_end {
            panic!("Invalid range. Start is bigger than end.");
        }
        let range = range_start..range_end;
        // An empty range cannot conflict with anything, so it is not recorded.
        if range.is_empty() {
            return Ok(RangeLockGuard::new(self, range));
        }

        // In this type the set is only modified by single `insert`/`remove`
        // calls, so it stays consistent even if a thread panicked while holding
        // the mutex. Recover it and run the full overlap check, so a guard
        // returned inside `Poisoned` is exactly as exclusive as an `Ok` one.
        let (mut ranges, poisoned) = match self.ranges.lock() {
            LockResult::Ok(guard) => (guard, false),
            LockResult::Err(err) => (err.into_inner(), true),
        };
        if overlaps_any(&*ranges, &range) {
            return Err(TryLockError::WouldBlock);
        }
        ranges.insert(range.clone());
        // Release the bookkeeping mutex before building the guard, whose `Drop`
        // needs to take it again.
        drop(ranges);

        let guard = RangeLockGuard::new(self, range);
        if poisoned {
            Err(TryLockError::Poisoned(PoisonError::new(guard)))
        } else {
            Ok(guard)
        }
    }

    /// Removes `range` from the set of held ranges; called when a guard drops.
    fn unlock(&self, range: &Range<usize>) {
        // Empty ranges are never recorded by `try_lock`.
        if range.is_empty() {
            return;
        }
        // This runs from `Drop`, where panicking on a poisoned mutex could abort
        // the process during unwinding, so recover the set instead.
        let mut ranges = self
            .ranges
            .lock()
            .unwrap_or_else(PoisonError::into_inner);
        ranges.remove(range);
    }

    /// Returns the elements in `range` as a shared slice.
    ///
    /// # Safety
    ///
    /// `range` must lie within the data, and no mutable slice covering any of
    /// its elements may be live while the returned slice is.
    #[inline]
    unsafe fn get_slice(&self, range: &Range<usize>) -> &[T] {
        let cells: &[UnsafeCell<T>] = &self.data[range.clone()];
        // SAFETY: `UnsafeCell<T>` is `repr(transparent)`, so `[UnsafeCell<T>]`
        // and `[T]` have the same layout; the caller rules out concurrent writes.
        &*(cells as *const [UnsafeCell<T>] as *const [T])
    }

    /// Returns the elements in `range` as a mutable slice.
    ///
    /// # Safety
    ///
    /// `range` must lie within the data, and the caller must have exclusive
    /// access to its elements (no other slice covering them may be live) for
    /// the lifetime of the returned slice.
    #[inline]
    unsafe fn get_mut_slice(&self, range: &Range<usize>) -> &mut [T] {
        let cells: &[UnsafeCell<T>] = &self.data[range.clone()];
        // SAFETY: writing through a pointer derived from `&UnsafeCell<T>` is
        // permitted (this is how `UnsafeCell::get` works), the layouts match as
        // above, and the caller guarantees exclusivity.
        &mut *(cells as *const [UnsafeCell<T>] as *mut [T])
    }
}
/// Guard granting exclusive access to one locked range of a [`RangeLock`].
///
/// It dereferences to the locked sub-slice, and dropping it releases the
/// range (see its `Drop` impl).
#[must_use = "if unused the range will immediately unlock"]
#[derive(Debug)]
pub struct RangeLockGuard<'a, T> {
    /// The lock this guard's range belongs to.
    lock: &'a RangeLock<T>,
    /// The locked index range within the lock's data.
    range: Range<usize>,
}

// SAFETY: sharing `&RangeLockGuard` across threads only exposes `&[T]` (via
// `Deref`), which is sound exactly when `T: Sync`, the same rule as std's
// `MutexGuard`. This explicit impl replaces the auto impl, which would make the
// guard `Sync` whenever `RangeLock<T>` is, letting threads share `&T` for
// `T: !Sync` (e.g. a `RefCell`).
unsafe impl<'a, T: Sync> Sync for RangeLockGuard<'a, T> {}
impl<'a, T> RangeLockGuard<'a, T> {
fn new(lock: &'a RangeLock<T>,
range: Range<usize>) -> RangeLockGuard<'a, T> {
RangeLockGuard {
lock,
range,
}
}
}
impl<'a, T> Drop for RangeLockGuard<'a, T> {
    /// Hands this guard's range back to the owning lock.
    #[inline]
    fn drop(&mut self) {
        let (owner, held) = (self.lock, &self.range);
        owner.unlock(held);
    }
}
impl<'a, T> Deref for RangeLockGuard<'a, T> {
    type Target = [T];

    /// Views the locked range as a shared slice indexed from 0.
    #[inline]
    fn deref(&self) -> &[T] {
        let locked = &self.range;
        // SAFETY: relies on `RangeLock::try_lock` only handing out guards whose
        // ranges no other live guard overlaps.
        unsafe { self.lock.get_slice(locked) }
    }
}
impl<'a, T> DerefMut for RangeLockGuard<'a, T> {
    /// Views the locked range as a mutable slice indexed from 0.
    #[inline]
    fn deref_mut(&mut self) -> &mut [T] {
        let locked = &self.range;
        // SAFETY: `&mut self` rules out other borrows through this guard, and
        // `RangeLock::try_lock` is relied on to never hand out an overlapping
        // range to another live guard.
        unsafe { self.lock.get_mut_slice(locked) }
    }
}
#[cfg(test)]
mod tests {
    use std::cell::RefCell;
    use std::sync::{Arc, Barrier};
    use std::thread;
    use super::*;

    // Locking a sub-range yields a slice re-indexed from 0, writes go through,
    // and dropping the guard empties the set of held ranges. The remaining
    // scopes cover each `RangeBounds` form (`..=`, `..n`, `..=n`, `n..`, `..`).
    #[test]
    fn test_base() {
        {
            let a = RangeLock::new(vec![1_i32, 2, 3, 4, 5, 6]);
            {
                let mut g = a.try_lock(2..4).unwrap();
                assert!(!a.ranges.lock().unwrap().is_empty());
                assert_eq!(g[0..2], [3, 4]);
                g[1] = 10;
                assert_eq!(g[0..2], [3, 10]);
            }
            // `g` has been dropped, so its range must have been released.
            assert!(a.ranges.lock().unwrap().is_empty());
        }
        {
            let a = RangeLock::new(vec![1_i32, 2, 3, 4, 5, 6]);
            let g = a.try_lock(2..=4).unwrap();
            assert_eq!(g[0..3], [3, 4, 5]);
        }
        {
            let a = RangeLock::new(vec![1_i32, 2, 3, 4, 5, 6]);
            let g = a.try_lock(..4).unwrap();
            assert_eq!(g[0..4], [1, 2, 3, 4]);
        }
        {
            let a = RangeLock::new(vec![1_i32, 2, 3, 4, 5, 6]);
            let g = a.try_lock(..=4).unwrap();
            assert_eq!(g[0..5], [1, 2, 3, 4, 5]);
        }
        {
            let a = RangeLock::new(vec![1_i32, 2, 3, 4, 5, 6]);
            let g = a.try_lock(2..).unwrap();
            assert_eq!(g[0..4], [3, 4, 5, 6]);
        }
        {
            let a = RangeLock::new(vec![1_i32, 2, 3, 4, 5, 6]);
            let g = a.try_lock(..).unwrap();
            assert_eq!(g[0..6], [1, 2, 3, 4, 5, 6]);
        }
    }

    // Empty ranges succeed, can be taken repeatedly at the same position, and
    // are never recorded in `ranges`.
    #[test]
    fn test_empty_range() {
        let a = RangeLock::new(vec![1_i32, 2, 3, 4, 5, 6]);
        let g0 = a.try_lock(2..2).unwrap();
        assert!(a.ranges.lock().unwrap().is_empty());
        assert_eq!(g0[0..0], []);
        let g1 = a.try_lock(2..2).unwrap();
        assert!(a.ranges.lock().unwrap().is_empty());
        assert_eq!(g1[0..0], []);
    }

    // Reading past the end of the guard's slice (length 2 here, even though the
    // underlying vector is longer) panics with a normal slice index error.
    #[test]
    #[should_panic(expected="index out of bounds")]
    fn test_base_oob_read() {
        let a = RangeLock::new(vec![1_i32, 2, 3, 4, 5, 6]);
        let g = a.try_lock(2..4).unwrap();
        let _ = g[2];
    }

    // Same as above, for writes through `DerefMut`.
    #[test]
    #[should_panic(expected="index out of bounds")]
    fn test_base_oob_write() {
        let a = RangeLock::new(vec![1_i32, 2, 3, 4, 5, 6]);
        let mut g = a.try_lock(2..4).unwrap();
        g[2] = 10;
    }

    // While 2..4 is held, the overlapping 3..5 is refused (the second `expect`
    // fires).
    #[test]
    #[should_panic(expected="guard 1 panicked")]
    fn test_overlap0() {
        let a = RangeLock::new(vec![1_i32, 2, 3, 4, 5, 6]);
        let _g0 = a.try_lock(2..4).expect("guard 0 panicked");
        let _g1 = a.try_lock(3..5).expect("guard 1 panicked");
    }

    // Same conflict with the acquisition order reversed.
    #[test]
    #[should_panic(expected="guard 0 panicked")]
    fn test_overlap1() {
        let a = RangeLock::new(vec![1_i32, 2, 3, 4, 5, 6]);
        let _g1 = a.try_lock(3..5).expect("guard 1 panicked");
        let _g0 = a.try_lock(2..4).expect("guard 0 panicked");
    }

    // Two threads lock disjoint ranges. Thread 0 writes 10 at index 3 and drops
    // its guard before the barrier; after the barrier, thread 1 locks 3..5,
    // which spans the released range, and must observe that write.
    #[test]
    fn test_thread_no_overlap() {
        let a = Arc::new(RangeLock::new(vec![1_i32, 2, 3, 4, 5, 6]));
        let b = Arc::clone(&a);
        let c = Arc::clone(&a);
        let ba0 = Arc::new(Barrier::new(2));
        let ba1 = Arc::clone(&ba0);
        let j0 = thread::spawn(move || {
            {
                let mut g = b.try_lock(2..4).unwrap();
                assert!(!b.ranges.lock().unwrap().is_empty());
                assert_eq!(g[0..2], [3, 4]);
                g[1] = 10;
                assert_eq!(g[0..2], [3, 10]);
            }
            ba0.wait();
        });
        let j1 = thread::spawn(move || {
            {
                let g = c.try_lock(4..6).unwrap();
                assert!(!c.ranges.lock().unwrap().is_empty());
                assert_eq!(g[0..2], [5, 6]);
            }
            ba1.wait();
            let g = c.try_lock(3..5).unwrap();
            assert_eq!(g[0..2], [10, 5]);
        });
        j1.join().expect("Thread 1 panicked.");
        j0.join().expect("Thread 0 panicked.");
        assert!(a.ranges.lock().unwrap().is_empty());
    }

    // Element type that is `Send` but not `Sync`, because it holds a `RefCell`.
    struct NoSyncStruct(RefCell<u32>);

    // A `RangeLock` of `!Sync` elements can still be shared between threads via
    // `Arc`. Each thread holds a disjoint one-element guard across the barrier,
    // so both ranges are held at the same time.
    #[test]
    fn test_nosync() {
        let a = Arc::new(RangeLock::new(vec![
            NoSyncStruct(RefCell::new(1)),
            NoSyncStruct(RefCell::new(2)),
            NoSyncStruct(RefCell::new(3)),
            NoSyncStruct(RefCell::new(4)),
        ]));
        let b = Arc::clone(&a);
        let c = Arc::clone(&a);
        let ba0 = Arc::new(Barrier::new(2));
        let ba1 = Arc::clone(&ba0);
        let j0 = thread::spawn(move || {
            let _g = b.try_lock(0..1).unwrap();
            assert!(!b.ranges.lock().unwrap().is_empty());
            ba0.wait();
        });
        let j1 = thread::spawn(move || {
            let _g = c.try_lock(1..2).unwrap();
            assert!(!c.ranges.lock().unwrap().is_empty());
            ba1.wait();
        });
        j1.join().expect("Thread 1 panicked.");
        j0.join().expect("Thread 0 panicked.");
        assert!(a.ranges.lock().unwrap().is_empty());
    }
}