#[cfg(feature = "owning_ref")]
use owning_ref::StableAddress;
#[cfg(feature = "serde")]
use serde::{Deserialize, Deserializer, Serialize, Serializer};
#[cfg(any(debug_assertions, feature = "check"))]
use std::sync::atomic::{AtomicUsize, Ordering};
use std::{
    cell::UnsafeCell,
    ops::{Deref, DerefMut},
    panic::{RefUnwindSafe, UnwindSafe},
};

#[cfg(any(debug_assertions, feature = "check"))]
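/// Bit set in `state` while a writer holds the lock.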
const WRITER_BIT: usize = 0b1000;
#[cfg(any(debug_assertions, feature = "check"))]
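/// Amount added to `state` for each live reader.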
const ONE_READER: usize = 0b10000;

#[derive(Debug)]
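/// A reader-writer lock that checks, rather than enforces, exclusive access:
/// when `debug_assertions` or the `check` feature is enabled it tracks live
/// readers and writers and panics on conflicting access, while in unchecked
/// builds the bookkeeping compiles away and guards are handed out
/// unconditionally.
///
/// A minimal usage sketch (marked `ignore` since the crate path is not shown
/// here):
///
/// ```ignore
/// let lock = RwLock::new(5);
/// {
///     let mut w = lock.write();
///     *w += 1;
/// } // the write guard is dropped here, releasing the lock
/// assert_eq!(*lock.read(), 6);
/// ```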
pub struct RwLock<T: ?Sized> {
    #[cfg(any(debug_assertions, feature = "check"))]
    state: AtomicUsize,
    value: UnsafeCell<T>,
}

impl<T> RefUnwindSafe for RwLock<T> where T: ?Sized {}
impl<T> UnwindSafe for RwLock<T> where T: ?Sized {}
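// SAFETY: the guards hand out `&T`/`&mut T` to the inner value, so the same
// `Send`/`Sync` bounds as `std::sync::RwLock` apply.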
unsafe impl<T> Send for RwLock<T> where T: ?Sized + Send {}
unsafe impl<T> Sync for RwLock<T> where T: ?Sized + Send + Sync {}

impl<T> From<T> for RwLock<T> {
    fn from(val: T) -> Self {
        Self::new(val)
    }
}

impl<T> Default for RwLock<T>
where
    T: ?Sized + Default,
{
    fn default() -> Self {
        Self::new(T::default())
    }
}

impl<T> RwLock<T> {
    #[inline]
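    /// Creates a new, unlocked `RwLock` containing `val`.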
    pub const fn new(val: T) -> Self {
        Self {
            value: UnsafeCell::new(val),
            #[cfg(any(debug_assertions, feature = "check"))]
            state: AtomicUsize::new(0),
        }
    }

    #[inline]
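    /// Consumes this `RwLock` and returns the inner value.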
    pub fn into_inner(self) -> T {
        self.value.into_inner()
    }
}

impl<T> RwLock<T>
where
    T: ?Sized,
{
    #[inline]
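    /// Returns a mutable reference to the inner value; the exclusive `&mut`
    /// borrow makes any runtime check unnecessary.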
    pub fn get_mut(&mut self) -> &mut T {
        self.value.get_mut()
    }

    #[inline]
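    /// Attempts to lock this `RwLock` for writing, returning `None` if it is
    /// already locked. Without checking enabled this never fails.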
    pub fn try_write<'a>(&'a self) -> Option<RwLockWriteGuard<'a, T>> {
        self.lock_exclusive()
            .then(|| RwLockWriteGuard { lock: self })
    }

    #[inline]
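    /// Locks this `RwLock` for writing.
    ///
    /// # Panics
    ///
    /// Panics if the lock is already held by a reader or writer and checking
    /// (`debug_assertions` or the `check` feature) is enabled.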
    pub fn write<'a>(&'a self) -> RwLockWriteGuard<'a, T> {
        if !self.lock_exclusive() {
            #[cfg(any(debug_assertions, feature = "check"))]
            panic!("The lock is already locked for reading or writing")
        }

        RwLockWriteGuard { lock: self }
    }

    #[inline]
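    /// Attempts to lock this `RwLock` for reading, returning `None` if it is
    /// already write locked. Without checking enabled this never fails.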
    pub fn try_read<'a>(&'a self) -> Option<RwLockReadGuard<'a, T>> {
        self.lock_shared().then(|| RwLockReadGuard { lock: self })
    }

    #[inline]
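    /// Locks this `RwLock` for reading; any number of readers may hold the
    /// lock at the same time.
    ///
    /// # Panics
    ///
    /// Panics if the lock is already write locked and checking
    /// (`debug_assertions` or the `check` feature) is enabled.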
    pub fn read<'a>(&'a self) -> RwLockReadGuard<'a, T> {
        if !self.lock_shared() {
            #[cfg(any(debug_assertions, feature = "check"))]
            panic!("The lock is already write locked")
        }

        RwLockReadGuard { lock: self }
    }

    #[inline]
    fn lock_exclusive(&self) -> bool {
        // With checking enabled, take the writer bit only if the lock is
        // completely unlocked (no readers and no writer).
        #[cfg(any(debug_assertions, feature = "check"))]
        {
            self.state
                .compare_exchange(0, WRITER_BIT, Ordering::Acquire, Ordering::Relaxed)
                .is_ok()
        }

        // Without checking, "locking" always succeeds.
        #[cfg(not(any(debug_assertions, feature = "check")))]
        true
    }

    #[inline]
    fn unlock_exclusive(&self) -> bool {
        #[cfg(any(debug_assertions, feature = "check"))]
        {
            // Unlocking publishes the writer's changes, so the success
            // ordering must be `Release`.
            self.state
                .compare_exchange(WRITER_BIT, 0, Ordering::Release, Ordering::Relaxed)
                .is_ok()
        }

        #[cfg(not(any(debug_assertions, feature = "check")))]
        true
    }

    #[inline]
    fn lock_shared(&self) -> bool {
        #[cfg(any(debug_assertions, feature = "check"))]
        loop {
            let state = self.state.load(Ordering::Relaxed);
            // Shared access is refused while a writer holds the lock.
            if state & WRITER_BIT != 0 {
                return false;
            }

            // Register one more reader; retry if another thread raced us.
            if self
                .state
                .compare_exchange(
                    state,
                    state.checked_add(ONE_READER).expect("too many readers"),
                    Ordering::Acquire,
                    Ordering::Relaxed,
                )
                .is_ok()
            {
                break;
            }
        }

        true
    }

    #[inline]
    fn unlock_shared(&self) {
        #[cfg(any(debug_assertions, feature = "check"))]
        self.state.fetch_sub(ONE_READER, Ordering::Release);
    }
}

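/// RAII guard granting exclusive access to the data; the write lock is
/// released when the guard is dropped.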
pub struct RwLockWriteGuard<'a, T>
where
    T: ?Sized,
{
    lock: &'a RwLock<T>,
}

impl<'a, T> Deref for RwLockWriteGuard<'a, T>
where
    T: ?Sized,
{
    type Target = T;

    #[inline]
    fn deref(&self) -> &T {
        unsafe { &*self.lock.value.get() }
    }
}

impl<'a, T> DerefMut for RwLockWriteGuard<'a, T>
where
    T: ?Sized,
{
    #[inline]
    fn deref_mut(&mut self) -> &mut T {
        unsafe { &mut *self.lock.value.get() }
    }
}

impl<'a, T> Drop for RwLockWriteGuard<'a, T>
where
    T: ?Sized,
{
    #[inline]
    fn drop(&mut self) {
        self.lock.unlock_exclusive();
    }
}

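/// RAII guard granting shared access to the data; the read lock is released
/// when the guard is dropped.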
pub struct RwLockReadGuard<'a, T>
where
    T: ?Sized,
{
    lock: &'a RwLock<T>,
}

impl<'a, T> Deref for RwLockReadGuard<'a, T>
where
    T: ?Sized,
{
    type Target = T;

    #[inline]
    fn deref(&self) -> &T {
        unsafe { &*self.lock.value.get() }
    }
}

impl<'a, T> Drop for RwLockReadGuard<'a, T>
where
    T: ?Sized,
{
    #[inline]
    fn drop(&mut self) {
        self.lock.unlock_shared();
    }
}

// SAFETY: the guards deref to the value inside the `UnsafeCell`, whose
// address is stable for as long as the guard (and thus the borrow of the
// lock) is alive.
#[cfg(feature = "owning_ref")]
unsafe impl<'a, T: 'a> StableAddress for RwLockReadGuard<'a, T> where T: ?Sized {}
#[cfg(feature = "owning_ref")]
unsafe impl<'a, T: 'a> StableAddress for RwLockWriteGuard<'a, T> where T: ?Sized {}

#[cfg(feature = "serde")]
impl<T> Serialize for RwLock<T>
where
    T: Serialize + ?Sized,
{
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        self.read().serialize(serializer)
    }
}

#[cfg(feature = "serde")]
impl<'de, T> Deserialize<'de> for RwLock<T>
where
    T: Deserialize<'de> + ?Sized,
{
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        Deserialize::deserialize(deserializer).map(RwLock::new)
    }
}