// glar_base/range_rwlock.rs

use std::cell::UnsafeCell;
use std::sync::Mutex;

/// # RangeLockResult
/// The result of a range-locking attempt.
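///
/// A non-panicking caller can match on the variants. An illustrative sketch
/// (`lock`, `idx`, and `kc` are assumed to be in scope):
/// ```ignore
/// match lock.write(idx, kc) {
///     RangeLockResult::Ok(guard) => { /* use guard.get() */ }
///     RangeLockResult::RangeConflict => { /* chunk busy: retry or back off */ }
///     RangeLockResult::BadRange | RangeLockResult::OtherError => panic!("lock error"),
/// }
/// ```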
pub enum RangeLockResult<Guard> {
    Ok(Guard),
    RangeConflict,
    BadRange,
    OtherError,
}

impl<Guard> RangeLockResult<Guard> {
    pub fn unwrap(self) -> Guard {
        match self {
            RangeLockResult::Ok(guard) => guard,
            RangeLockResult::RangeConflict => panic!("RangeConflict Error"),
            RangeLockResult::BadRange => panic!("BadRange"),
            RangeLockResult::OtherError => panic!("OtherError"),
        }
    }
}

// SAFETY: shared mutation of `data` is mediated by the `ranges` mutex, which
// hands out non-overlapping chunks to writers. `T: Send + Sync` is required
// so that references into the underlying slice may cross threads soundly.
unsafe impl<'a, T: Send + Sync> Sync for RangeLock<'a, T> {}

pub struct RangeLockWriteGuard<'lock, 'a, T: 'lock> {
    rlock: &'lock RangeLock<'a, T>,
    idx: usize,
    start: usize,
    end: usize,
}

impl<'lock, 'a, T> RangeLockWriteGuard<'lock, 'a, T>
where
    T: 'lock,
{
    /// Mutable view of this guard's chunk. Disjointness across guards is
    /// enforced by the lock's per-chunk flags, not by the borrow checker.
    pub fn get(&self) -> &mut [T] {
        &mut self.rlock.data_mut()[self.start..self.end]
    }

    pub fn change_kc(&self, kc: usize) {
        self.rlock.change_kc(kc);
    }
}

impl<'lock, 'a, T> Drop for RangeLockWriteGuard<'lock, 'a, T>
where
    T: 'lock,
{
    fn drop(&mut self) {
        self.rlock.remove_write(self.idx);
    }
}

pub struct RangeLockReadGuard<'lock, 'a, T: 'lock> {
    rlock: &'lock RangeLock<'a, T>,
}

impl<'lock, 'a, T> RangeLockReadGuard<'lock, 'a, T>
where
    T: 'lock,
{
    pub fn get(&self) -> &[T] {
        self.rlock.data()
    }
}

impl<'lock, 'a, T> Drop for RangeLockReadGuard<'lock, 'a, T>
where
    T: 'lock,
{
    fn drop(&mut self) {
        self.rlock.remove_read();
    }
}

// A variation of an rwlock: write access is scoped to one of `n` fixed
// subslices (chunks), selected by index, while read access covers the entire
// slice. This has the least complexity and enough features for the packing
// use case here.

/// # RangeLock
/// Allows multiple simultaneous borrows of one slice: mutable borrows of
/// disjoint chunks (selected by index) and immutable borrows of the whole slice.
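///
/// # Example
/// An illustrative sketch (the parameter values below are assumptions for the
/// sake of the example, not canonical):
/// ```ignore
/// let mut buf = vec![0usize; 6];
/// // n = 2 chunks, mc = 3, kc = 2, mr = 1  =>  pool_size = 6, mc_chunk_len = 2
/// let lock = RangeLock::from(&mut buf, 2, 3, 2, 1);
/// {
///     let w = lock.write(0, 2).unwrap(); // exclusive access to chunk 0
///     w.get()[0] = 7;
/// } // guard drops here, releasing chunk 0
/// let r = lock.read().unwrap(); // shared access to the whole slice
/// assert_eq!(r.get()[0], 7);
/// ```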
pub struct RangeLock<'a, T> {
    /// number of chunks the slice is split into
    n: usize,
    /// chunk length factor: `ceil(mc / (n * mr)) * mr`; a write guard's chunk
    /// spans `mc_chunk_len * kc` elements
    mc_chunk_len: usize,
    /// (per-chunk write flags, active reader count, current kc)
    ranges: Mutex<(Vec<bool>, usize, usize)>,
    data: UnsafeCell<&'a mut [T]>,
}

impl<'a, T> RangeLock<'a, T> {
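    /// Builds a `RangeLock` over `data`, split into `n` chunks.
    ///
    /// The chunk length factor is computed by ceiling division:
    /// `mc_chunk_len = ceil(mc / (n * mr)) * mr`, so each chunk is a multiple
    /// of `mr`. A worked example (illustrative values): `mc = 3`, `n = 2`,
    /// `mr = 1` gives `mc_chunk_len = ceil(3 / 2) * 1 = 2`.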
    pub fn from(data: &'a mut [T], n: usize, mc: usize, kc: usize, mr: usize) -> Self {
        let pool_size = mc * kc;
        assert!(pool_size <= data.len(), "pool_size: {}, data.len(): {}", pool_size, data.len());
        let mc_chunk_len = ((mc + n * mr - 1) / (n * mr)) * mr;
        let ranges = Mutex::new((vec![false; n], 0, kc));
        RangeLock { n, mc_chunk_len, ranges, data: UnsafeCell::new(data) }
    }

    pub fn get_mc(&self) -> usize {
        self.mc_chunk_len
    }

    pub fn change_kc(&self, kc: usize) {
        let mut x = self.ranges.lock().unwrap();
        assert!(x.1 == 0, "cannot change kc while {} reader(s) are active", x.1);
        x.2 = kc;
    }

    pub fn len(&self) -> usize {
        unsafe { (*self.data.get()).len() }
    }

    /// get a shared reference to the underlying slice
    fn data(&self) -> &[T] {
        unsafe { &**self.data.get() }
    }

    /// get a mutable reference to the underlying slice
    fn data_mut(&self) -> &mut [T] {
        unsafe { &mut **self.data.get() }
    }

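    /// Acquires shared access to the whole slice. Fails with `RangeConflict`
    /// while any chunk is write-locked.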
    pub fn read(&self) -> RangeLockResult<RangeLockReadGuard<'_, 'a, T>> {
        let mut x = self.ranges.lock().unwrap();
        let is_occupied = &x.0;
        // any write-locked chunk conflicts with a whole-slice read
        if is_occupied.iter().any(|&b| b) {
            return RangeLockResult::RangeConflict;
        }
        let read_mode = &mut x.1;
        *read_mode += 1;
        RangeLockResult::Ok(RangeLockReadGuard { rlock: self })
    }

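    /// Acquires exclusive access to chunk `idx` of `n`, whose length is
    /// `mc_chunk_len * kc` (the final chunk is clamped to the slice length).
    /// Fails with `RangeConflict` if the chunk is already write-locked or a
    /// reader is active, and with `BadRange` if `idx` is out of bounds.
    ///
    /// Illustrative, non-panicking use (`lock`, `idx`, `kc`, and `pack_chunk`
    /// are assumed/hypothetical names):
    /// ```ignore
    /// if let RangeLockResult::Ok(guard) = lock.write(idx, kc) {
    ///     pack_chunk(guard.get()); // mutate only this chunk
    /// } // guard drops here, releasing the chunk
    /// ```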
    pub fn write(&self, idx: usize, kc: usize) -> RangeLockResult<RangeLockWriteGuard<'_, 'a, T>> {
        if idx >= self.n {
            return RangeLockResult::BadRange;
        }
        let mut x = self.ranges.lock().unwrap();
        let chunk_len = self.mc_chunk_len * kc;

        // TODO: add a check that kc stays the same from write to write.
        // Linear scans over the flags are fine: `ranges` has one flag per
        // chunk (~ num threads / ic_par or jc_par, on average 2-4 entries).
        // Conflict if this chunk is already write-locked or read mode is on.
        if x.0[idx] || x.1 > 0 {
            return RangeLockResult::RangeConflict;
        }
        x.0[idx] = true;

        let (start, end) = (idx * chunk_len, ((idx + 1) * chunk_len).min(self.len()));
        RangeLockResult::Ok(RangeLockWriteGuard { rlock: self, idx, start, end })
    }

    /// Releases chunk `idx`; called from the write guard's `Drop`.
    fn remove_write(&self, idx: usize) {
        let mut x = self.ranges.lock().unwrap();
        let is_occupied = &mut x.0;
        is_occupied[idx] = false;
    }

    /// Decrements the reader count; called from the read guard's `Drop`.
    fn remove_read(&self) {
        let mut x = self.ranges.lock().unwrap();
        let read_mode = &mut x.1;
        *read_mode -= 1;
    }
}

// mod test {
//     use super::*;
//     // NOTE: the calls below use the current `from(data, n, mc, kc, mr)` and
//     // `write(idx, kc)` signatures; the parameter values are illustrative.
//     // #[test]
//     fn range_lock_read_test() {
//         let mut data: Vec<usize> = vec![0, 1, 2, 3, 4, 5];
//         let lock = RangeLock::from(&mut data, 2, 3, 2, 1);
//         assert!(lock.data()[0..3] == [0, 1, 2])
//     }

//     // #[test]
//     fn range_lock_write_test() {
//         let mut data: Vec<usize> = vec![0usize; 12];
//         let lock = RangeLock::from(&mut data, 4, 12, 1, 1);
//         {
//             let guard = lock.write(0, 1).unwrap();
//             let guard_ref = guard.get();

//             guard_ref[0] = 2;
//             guard_ref[1] = 1;
//             guard_ref[2] = 0;
//         }
//         drop(lock); // release the borrow of `data` before reading it
//         assert!(data[0..3] == [2, 1, 0])
//     }

//     // #[test]
//     fn range_lock_write_test_mt() {
//         use glar_dev::random_matrix_uniform;
//         // create a vec of random integers
//         let n_thread = 4;
//         let chunk_num = 4;
//         let vec_len = 12;
//         let chunk_len = vec_len / chunk_num;
//         let mut data_0 = vec![0i32; vec_len];
//         random_matrix_uniform(vec_len, 1, &mut data_0, vec_len);

//         // lock for data_0
//         let lock0 = RangeLock::from(&mut data_0, n_thread, vec_len, 1, 1);
//         let lock0_r = &lock0;

//         let mut data = vec![0i32; vec_len];
//         let lock = RangeLock::from(&mut data, n_thread, vec_len, 1, 1);
//         let lock_r = &lock;
//         use std::thread;

//         thread::scope(|s| {
//             for i in 0..n_thread {
//                 s.spawn(move || {
//                     let guard = lock_r.write(i, 1).unwrap();
//                     let data_slice = guard.get();
//                     let guard0 = lock0_r.read().unwrap();
//                     let data_slice0 = guard0.get();
//                     let offset = i * chunk_len;
//                     for j in 0..chunk_len {
//                         data_slice[j] = data_slice0[j + offset];
//                         std::thread::sleep(std::time::Duration::from_secs(1));
//                     }
//                 });
//             }
//         });

//         let n_thread_r = 13;
//         thread::scope(|s| {
//             for _ in 0..n_thread_r {
//                 s.spawn(move || {
//                     let guard = lock_r.read().unwrap();
//                     let data_slice = guard.get();
//                     let guard0 = lock0_r.read().unwrap();
//                     let data_slice0 = guard0.get();

//                     assert!(data_slice[0..12] == data_slice0[0..12]);
//                 });
//             }
//         });

//     }
// }