glar_base/range_rwlock.rs
1use std::cell::UnsafeCell;
2use std::sync::Mutex;
3
/// # RangeLockResult
/// The result of a range locking attempt.
///
/// `Guard` is the guard type handed back on success. The type derives
/// `Debug`, `PartialEq` and `Eq` (whenever `Guard` does) so results can be
/// logged and compared directly.
#[derive(Debug, PartialEq, Eq)]
pub enum RangeLockResult<Guard> {
    /// The lock was acquired; wraps the guard that was handed out.
    Ok(Guard),
    /// The request clashes with access that is currently held.
    RangeConflict,
    /// The requested range/index is not valid for the lock.
    BadRange,
    /// Any other failure.
    OtherError,
}
12
13impl<Guard> RangeLockResult<Guard> {
14 pub fn unwrap(self) -> Guard {
15 match self {
16 RangeLockResult::Ok(guard) => guard,
17 RangeLockResult::RangeConflict => panic!("RangeConflict Error"),
18 RangeLockResult::BadRange => panic!("BadRange"),
19 RangeLockResult::OtherError => panic!("OtherError"),
20 }
21 }
22}
23
24/// make sure the RangeLock can be shared between threads
25unsafe impl<'a, T> Sync for RangeLock<'a, T> {}
26
27pub struct RangeLockWriteGuard<'lock, 'a, T: 'lock> {
28 rlock: &'lock RangeLock<'a, T>,
29 idx: usize,
30 start: usize,
31 end: usize,
32}
33
34impl<'lock, 'a, T> RangeLockWriteGuard<'lock, 'a, T>
35where
36 T: 'lock,
37{
38 pub fn get(&self) -> &mut [T] {
39 &mut self.rlock.data_mut()[self.start..self.end]
40 }
41
42 pub fn change_kc(&self, kc: usize) {
43 self.rlock.change_kc(kc);
44 }
45}
46
47impl<'a, 'lock, T> Drop for RangeLockWriteGuard<'a, 'lock, T>
48where
49 T: 'lock,
50{
51 fn drop(&mut self) {
52 self.rlock.remove_write(self.idx);
53 }
54}
55
56pub struct RangeLockReadGuard<'lock, 'a, T: 'lock> {
57 rlock: &'lock RangeLock<'a, T>,
58}
59
60impl<'lock, 'a, T> RangeLockReadGuard<'lock, 'a, T>
61where
62 T: 'lock,
63{
64 pub fn get(&self) -> &[T] {
65 self.rlock.data()
66 }
67}
68
69impl<'a, 'lock, T> Drop for RangeLockReadGuard<'a, 'lock, T>
70where
71 T: 'lock,
72{
73 fn drop(&mut self) {
74 self.rlock.remove_read();
75 }
76}
77
// A variation of a read-write lock where
// - write access is indexed: each index selects one of the `n` sub-slices (see the `n` struct field),
// - read access always covers the entire slice.
// This design has the least complexity while still offering enough features for my purpose.
82
/// # RangeLock
/// Allows multiple immutable and mutable borrows based on access ranges.
///
/// Wraps a borrowed mutable slice. Writers lock one indexed chunk at a time;
/// readers get the whole slice and are only admitted while no chunk is
/// write-locked.
pub struct RangeLock<'a, T> {
    /// Number of write chunks; also the length of the occupancy vector in `ranges`.
    n: usize,
    /// Per-chunk length factor computed in `from`; multiplied by `kc` in `write`
    /// to get the chunk length in elements.
    mc_chunk_len: usize,
    /// Mutex-guarded bookkeeping: `.0` holds one "occupied" flag per write chunk,
    /// `.1` counts live read guards, `.2` is the `kc` last set by `from`/`change_kc`.
    ranges: Mutex<(Vec<bool>, usize, usize)>,
    /// The protected slice, kept in an `UnsafeCell` so it can be reached through `&self`.
    data: UnsafeCell<&'a mut [T]>,
}
91
92impl<'a, T> RangeLock<'a, T> {
93 pub fn from(data: &'a mut [T], n: usize, mc: usize, kc: usize, mr: usize) -> Self {
94 let pool_size = mc * kc;
95 assert!(pool_size <= data.len(), "pool_size: {}, data.len(): {}", pool_size, data.len());
96 let mc_chunk_len = ((mc + n * mr - 1) / (n * mr)) * mr;
97 let ranges = Mutex::new((vec![false; n], 0, kc));
98 RangeLock { n, mc_chunk_len, ranges, data: UnsafeCell::new(data) }
99 }
100
101 pub fn get_mc(&self) -> usize {
102 self.mc_chunk_len
103 }
104
105 pub fn change_kc(&self, kc: usize) {
106 let mut x = self.ranges.lock().unwrap();
107 assert!(x.1 == 0, "read mode is on, cannot change kc: {}", x.1);
108 x.2 = kc;
109 }
110
111 pub fn len(&self) -> usize {
112 unsafe { (*self.data.get()).len() }
113 }
114
115 /// get a reference to the data
116 fn data(&self) -> &[T] {
117 unsafe { *self.data.get() }
118 }
119
120 /// get a mutable reference to the data
121 fn data_mut(&self) -> &mut [T] {
122 unsafe { *self.data.get() }
123 }
124
125 pub fn read(&self) -> RangeLockResult<RangeLockReadGuard<'a, '_, T>> {
126 let mut x = self.ranges.lock().unwrap();
127 let ranges = &mut x.0;
128 // if ranges is not empty, then there is a conflict
129 if ranges.iter().any(|&x| x) {
130 return RangeLockResult::RangeConflict;
131 }
132 let read_mode = &mut x.1;
133 // println!("reading read_mode: {}", *read_mode);
134 *read_mode += 1;
135 RangeLockResult::Ok(RangeLockReadGuard { rlock: &self })
136 }
137
138 pub fn write(&self, idx: usize, kc: usize) -> RangeLockResult<RangeLockWriteGuard<'a, '_, T>> {
139 if idx > self.n {
140 return RangeLockResult::BadRange;
141 }
142 let mut x = self.ranges.lock().unwrap();
143 let is_occupied = &x.0;
144 let read_mode = &x.1;
145 let chunk_len = self.mc_chunk_len * kc;
146
147 // TODO: add check for kc_len stays the same from write->write
148 // it is fine to use contains since len of ranges is small ( ~ num threads / ic_par or jc_par)
149 // on average 2-4
150 // conflict if the idx is already in ranges or read_mode is on
151 if is_occupied[idx] || *read_mode > 0 {
152 return RangeLockResult::RangeConflict;
153 }
154 let is_occupied = &mut x.0;
155 is_occupied[idx] = true;
156
157 let (start, end) = (idx * chunk_len, ((idx + 1) * chunk_len).min(self.len()));
158 RangeLockResult::Ok(RangeLockWriteGuard { rlock: &self, idx, start, end })
159 }
160
161 fn remove_write(&self, idx: usize) {
162 let mut x = self.ranges.lock().unwrap();
163 let ranges = &mut x.0;
164 ranges[idx] = false;
165 }
166
167 fn remove_read(&self) {
168 let mut x = self.ranges.lock().unwrap();
169 let read_mode = &mut x.1;
170 *read_mode -= 1;
171 }
172}
173
174// mod test {
175// use super::*;
176// // #[test]
177// fn range_lock_read_test() {
178// let mut data: Vec<usize> = vec![0, 1, 2, 3, 4, 5];
179// let lock = RangeLock::from(&mut data, 2);
180// assert!(lock.data()[0..3] == [0, 1, 2])
181// }
182
183// // #[test]
184// fn range_lock_write_test() {
185// let mut data: Vec<usize> = vec![0usize; 12];
186// let lock = RangeLock::from(&mut data, 4);
187// {
188// let guard = lock.write(0).unwrap();
189// let guard_ref = guard.get();
190
191// guard_ref[0] = 2;
192// guard_ref[1] = 1;
193// guard_ref[2] = 0;
194// }
195// assert!(data[0..3] == [2, 1, 0])
196// }
197
198// // #[test]
199// fn range_lock_write_test_mt() {
200// use glar_dev::random_matrix_uniform;
201// // create vec of random integers
202// let n_thread = 4;
203// let chunk_num = 4;
204// let vec_len = 12;
205// let chunk_len = vec_len / chunk_num;
206// let mut data_0 = vec![0i32; vec_len];
207// random_matrix_uniform(vec_len, 1, &mut data_0, vec_len);
208
209// // lock for the data_0
210// let lock0 = RangeLock::from(&mut data_0, n_thread);
211// let lock0_r = &lock0;
212
213// let mut data = vec![0i32; vec_len];
214// let lock = RangeLock::from(&mut data, n_thread);
215// let lock_r = &lock;
216// use std::thread;
217
218// thread::scope(|s| {
219// for i in 0..n_thread {
220// s.spawn(move || {
221// let guard = lock_r.write(i).unwrap();
222// let data_slice = guard.get();
223// let guard0 = lock0_r.read().unwrap();
224// let data_slice0 = guard0.get();
225// let offset = i * chunk_len;
226// for j in 0..chunk_len {
227// data_slice[j] = data_slice0[j + offset];
228// std::thread::sleep(std::time::Duration::from_secs(1));
229// }
230// });
231// }
232// });
233
234// let n_thread_r = 13;
235// thread::scope(|s| {
236// for _ in 0..n_thread_r {
237// s.spawn(move || {
238// let guard = lock_r.read().unwrap();
239// let data_slice = guard.get();
240// let guard0 = lock0_r.read().unwrap();
241// let data_slice0 = guard0.get();
242
243// assert!(data_slice[0..12] == data_slice0[0..12]);
244// });
245// }
246// });
247
248// }
249// }