use std::{
    cell::UnsafeCell,
    ops::{Deref, DerefMut},
    sync::atomic::{AtomicU32, Ordering},
};

use crate::backoff::Backoff;

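// Lock-state encoding: the low bits of `flag` count active readers, while
// `WRITE_NUMBER` marks an exclusive writer. Writers only set it via a
// compare-exchange from 0, so the two never mix as long as the reader count
// stays below 2^30.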
const WRITE_NUMBER: u32 = 1_u32 << 30;

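/// A reader-writer spinlock: any number of concurrent readers, or a single
/// exclusive writer.
///
/// A minimal usage sketch (added for illustration, not part of the original
/// source; the surrounding crate layout is assumed, so the example is marked
/// `ignore`):
///
/// ```ignore
/// let lock = RwSpinlock::new(vec![1, 2, 3]);
/// {
///     let mut w = lock.write(); // exclusive access
///     w.push(4);
/// }
/// let r = lock.read(); // shared access, no `&mut` available
/// assert_eq!(r.len(), 4);
/// ```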
pub struct RwSpinlock<T> {
    flag: AtomicU32,
    value: UnsafeCell<T>,
}

unsafe impl<T: Send> Send for RwSpinlock<T> {}

// Read guards hand out `&T` on several threads at once, so `Sync` also
// requires `T: Sync` (as `std::sync::RwLock` does).
unsafe impl<T: Send + Sync> Sync for RwSpinlock<T> {}

impl<T> RwSpinlock<T> {
    #[inline(always)]
    pub fn new(value: T) -> Self {
        Self {
            flag: AtomicU32::new(0),
            value: UnsafeCell::new(value),
        }
    }

    pub fn try_write(&self) -> Option<RwSpinlockWriteGuard<'_, T>> {
        // Strong compare-exchange: `try_write` must not fail spuriously
        // while the lock is actually free.
        if self
            .flag
            .compare_exchange(0, WRITE_NUMBER, Ordering::Acquire, Ordering::Relaxed)
            .is_ok()
        {
            return Some(RwSpinlockWriteGuard { parent: self });
        }
        None
    }

    pub fn try_read(&self) -> Option<RwSpinlockReadGuard<'_, T>> {
        // Optimistically register as a reader; `Acquire` pairs with the
        // writer's `Release` store on unlock. If a writer holds the lock,
        // the stale increment is discarded when the writer stores 0.
        let pre_value = self.flag.fetch_add(1, Ordering::Acquire);
        if pre_value < WRITE_NUMBER {
            return Some(RwSpinlockReadGuard { parent: self });
        }
        None
    }

    pub fn write(&self) -> RwSpinlockWriteGuard<'_, T> {
        let backoff = Backoff::default();
        loop {
            if self
                .flag
                .compare_exchange_weak(0, WRITE_NUMBER, Ordering::Acquire, Ordering::Relaxed)
                .is_ok()
            {
                break;
            }

            // Wait for every reader (and any writer) to release before
            // retrying the compare-exchange.
            while self.flag.load(Ordering::Relaxed) != 0 {
                backoff.spin();
            }
        }
        RwSpinlockWriteGuard { parent: self }
    }

    pub fn read(&self) -> RwSpinlockReadGuard<'_, T> {
        let backoff = Backoff::default();
        loop {
            let pre_value = self.flag.fetch_add(1, Ordering::Acquire);
            if pre_value < WRITE_NUMBER {
                break;
            }

            // A writer holds the lock; our increment is stale and will be
            // discarded when the writer stores 0 on release.
            while self.flag.load(Ordering::Relaxed) != 0 {
                backoff.spin();
            }
        }
        RwSpinlockReadGuard { parent: self }
    }
}

pub struct RwSpinlockWriteGuard<'a, T> {
    parent: &'a RwSpinlock<T>,
}

impl<T> Drop for RwSpinlockWriteGuard<'_, T> {
    #[inline(always)]
    fn drop(&mut self) {
        // Storing 0 (rather than subtracting WRITE_NUMBER) also discards any
        // stale increments left by readers that failed while we held the lock.
        self.parent.flag.store(0, Ordering::Release);
    }
}

impl<T> Deref for RwSpinlockWriteGuard<'_, T> {
    type Target = T;

    #[inline(always)]
    fn deref(&self) -> &T {
        unsafe { &*self.parent.value.get() }
    }
}

impl<T> DerefMut for RwSpinlockWriteGuard<'_, T> {
    #[inline(always)]
    fn deref_mut(&mut self) -> &mut T {
        unsafe { &mut *self.parent.value.get() }
    }
}

pub struct RwSpinlockReadGuard<'a, T> {
    parent: &'a RwSpinlock<T>,
}

impl<T> Drop for RwSpinlockReadGuard<'_, T> {
    #[inline(always)]
    fn drop(&mut self) {
        // `Release` pairs with the `Acquire` in the lock operations.
        self.parent.flag.fetch_sub(1, Ordering::Release);
    }
}

// Deliberately `Deref` only: a shared read guard must never hand out `&mut T`.
impl<T> Deref for RwSpinlockReadGuard<'_, T> {
    type Target = T;

    #[inline(always)]
    fn deref(&self) -> &T {
        unsafe { &*self.parent.value.get() }
    }
}

#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;

    #[allow(unused_variables)]
    #[test]
    fn test_read_unlock() {
        let m = RwSpinlock::<i32>::new(0);
        {
            let r1 = m.read();
            {
                let r2 = m.read();
                let r3 = m.read();
                assert!(m.try_write().is_none());
            }
            assert!(m.try_write().is_none());
        }
        assert!(m.try_write().is_some());
    }

    #[allow(unused_variables)]
    #[test]
    fn test_write_unlock() {
        let m = RwSpinlock::<i32>::new(0);
        {
            let w1 = m.write();
            assert!(m.try_read().is_none());
        }
        assert!(m.try_read().is_some());
    }

    #[test]
    fn test_rw_arc() {
        let arc = Arc::new(RwSpinlock::new(0));
        let arc2 = arc.clone();
        let (tx, rx) = std::sync::mpsc::channel();

        std::thread::spawn(move || {
            let mut lock = arc2.write();
            for _ in 0..10 {
                let tmp = *lock;
                *lock = -1;
                std::thread::yield_now();
                *lock = tmp + 1;
            }
            tx.send(()).unwrap();
        });

        let mut children = Vec::new();
        for _ in 0..5 {
            let arc3 = arc.clone();
            children.push(std::thread::spawn(move || {
                let lock = arc3.read();
                assert!(*lock >= 0);
            }));
        }

        for r in children {
            assert!(r.join().is_ok());
        }

        rx.recv().unwrap();
        let lock = arc.read();
        assert_eq!(*lock, 10);
    }

    #[test]
    fn test_rw_access_in_unwind() {
        let arc = Arc::new(RwSpinlock::new(1));
        let arc2 = arc.clone();
        let _ = std::thread::spawn(move || {
            struct Unwinder {
                i: Arc<RwSpinlock<isize>>,
            }
            impl Drop for Unwinder {
                fn drop(&mut self) {
                    let mut lock = self.i.write();
                    *lock += 1;
                }
            }
            let _u = Unwinder { i: arc2 };
            panic!();
        })
        .join();
        let lock = arc.read();
        assert_eq!(*lock, 2);
    }

    #[test]
    fn test_rwlock_unsized() {
        let rw: &RwSpinlock<[i32; 3]> = &RwSpinlock::new([1, 2, 3]);
        {
            let b = &mut *rw.write();
            b[0] = 4;
            b[2] = 5;
        }
        let comp: &[i32] = &[4, 2, 5];
        assert_eq!(&*rw.read(), comp);
    }

    #[test]
    fn test_rwlock_try_write() {
        let lock = RwSpinlock::new(0isize);
        let read_guard = lock.read();

        assert!(
            lock.try_write().is_none(),
            "try_write should not succeed while read_guard is in scope"
        );

        drop(read_guard);
    }

    #[test]
    fn test_rw_try_read() {
        let m = RwSpinlock::new(0);
        // Leak the write guard so the lock stays write-locked.
        std::mem::forget(m.write());
        assert!(m.try_read().is_none());
    }
}