peace_lock/
rwlock.rs

1#[cfg(feature = "owning_ref")]
2use owning_ref::StableAddress;
3#[cfg(feature = "serde")]
4use serde::{Deserialize, Deserializer, Serialize, Serializer};
5#[cfg(any(debug_assertions, feature = "check"))]
6use std::sync::atomic::{AtomicUsize, Ordering};
7use std::{
8    cell::UnsafeCell,
9    ops::{Deref, DerefMut},
10    panic::{RefUnwindSafe, UnwindSafe},
11};
12
// Locking bits are borrowed from [parking_lot](https://github.com/Amanieu/parking_lot).
// Set only while a writer holds the lock exclusively. `lock_exclusive` installs it
// with a compare-exchange from 0, so it is never set while readers are present,
// and a writer never waits for readers to leave (acquisition simply fails).
#[cfg(any(debug_assertions, feature = "check"))]
const WRITER_BIT: usize = 0b1000;
// Base unit for counting readers: the reader count occupies the bits above
// `WRITER_BIT`.
#[cfg(any(debug_assertions, feature = "check"))]
const ONE_READER: usize = 0b10000;
21
/// A read-write lock that verifies, rather than enforces, exclusive access.
///
/// With `debug_assertions` or the `check` feature enabled, the lock keeps an
/// atomic state word, and conflicting lock attempts fail (`try_*`) or panic
/// (`read`/`write`) instead of blocking. With neither enabled, the state word
/// is compiled out and no checking takes place.
#[derive(Debug)]
pub struct RwLock<T>
where
    T: ?Sized,
{
    /// Writer bit plus reader count; only present when checking is enabled.
    #[cfg(any(debug_assertions, feature = "check"))]
    state: AtomicUsize,
    /// The protected value.
    value: UnsafeCell<T>,
}
29
30impl<T> RefUnwindSafe for RwLock<T> where T: ?Sized {}
31impl<T> UnwindSafe for RwLock<T> where T: ?Sized {}
32unsafe impl<T> Send for RwLock<T> where T: ?Sized + Send {}
33unsafe impl<T> Sync for RwLock<T> where T: ?Sized + Send + Sync {}
34
35impl<T> From<T> for RwLock<T> {
36    fn from(val: T) -> Self {
37        Self::new(val)
38    }
39}
40
41impl<T> Default for RwLock<T>
42where
43    T: ?Sized + Default,
44{
45    fn default() -> Self {
46        Self::new(T::default())
47    }
48}
49
50impl<T> RwLock<T> {
51    /// Create a new `RwLock`.
52    #[inline]
53    pub const fn new(val: T) -> Self {
54        Self {
55            value: UnsafeCell::new(val),
56            #[cfg(any(debug_assertions, feature = "check"))]
57            state: AtomicUsize::new(0),
58        }
59    }
60
61    /// Consume the `RwLock`, returning the inner value.
62    #[inline]
63    pub fn into_inner(self) -> T {
64        self.value.into_inner()
65    }
66}
67
68impl<T> RwLock<T>
69where
70    T: ?Sized,
71{
72    /// Get a mutable reference of the inner value T. This is safe because we
73    /// have the mutable reference of the lock.
74    #[inline]
75    pub fn get_mut(&mut self) -> &mut T {
76        self.value.get_mut()
77    }
78
79    /// Try write lock the `RwLock`, returns the write guard. Returns None if the
80    /// `RwLock` is write locked.
81    #[inline]
82    pub fn try_write<'a>(&'a self) -> Option<RwLockWriteGuard<'a, T>> {
83        self.lock_exclusive()
84            .then(|| RwLockWriteGuard { lock: self })
85    }
86
87    /// Write lock the `RwLock`, returns the write guard.
88    ///
89    /// # Panics
90    ///
91    /// If the `RwLock` is already write locked, this will panic if the `check`
92    /// feature is turned on.
93    #[inline]
94    pub fn write<'a>(&'a self) -> RwLockWriteGuard<'a, T> {
95        if !self.lock_exclusive() {
96            #[cfg(any(debug_assertions, feature = "check"))]
97            panic!("The lock is already write locked")
98        }
99
100        RwLockWriteGuard { lock: self }
101    }
102
103    /// Try read lock the `RwLock`, returns the read guard. Returns None if the
104    /// `RwLock` is write locked.
105    #[inline]
106    pub fn try_read<'a>(&'a self) -> Option<RwLockReadGuard<'a, T>> {
107        self.lock_shared().then(|| RwLockReadGuard { lock: self })
108    }
109
110    /// Read lock the `RwLock`, returns the read guard.
111    ///
112    /// # Panics
113    ///
114    /// If the `RwLock` is already write locked, this will panic if the check feature
115    /// is turned on.
116    #[inline]
117    pub fn read<'a>(&'a self) -> RwLockReadGuard<'a, T> {
118        if !self.lock_shared() {
119            #[cfg(any(debug_assertions, feature = "check"))]
120            panic!("The lock is already write locked")
121        }
122
123        RwLockReadGuard { lock: self }
124    }
125
126    #[inline]
127    fn lock_exclusive(&self) -> bool {
128        #[cfg(any(debug_assertions, feature = "check"))]
129        {
130            self.state
131                .compare_exchange(0, WRITER_BIT, Ordering::Acquire, Ordering::Relaxed)
132                .is_ok()
133        }
134
135        #[cfg(not(any(debug_assertions, feature = "check")))]
136        true
137    }
138
139    #[inline]
140    fn unlock_exclusive(&self) -> bool {
141        #[cfg(any(debug_assertions, feature = "check"))]
142        {
143            self.state
144                .compare_exchange(WRITER_BIT, 0, Ordering::Acquire, Ordering::Relaxed)
145                .is_ok()
146        }
147
148        #[cfg(not(any(debug_assertions, feature = "check")))]
149        true
150    }
151
152    #[inline]
153    fn lock_shared(&self) -> bool {
154        #[cfg(any(debug_assertions, feature = "check"))]
155        loop {
156            let state = self.state.load(Ordering::Relaxed);
157            if state & WRITER_BIT != 0 {
158                // is write locked
159                return false;
160            }
161
162            if self
163                .state
164                .compare_exchange(
165                    state,
166                    state.checked_add(ONE_READER).expect("too many readers"),
167                    Ordering::Acquire,
168                    Ordering::Relaxed,
169                )
170                .is_ok()
171            {
172                break;
173            }
174        }
175
176        true
177    }
178
179    #[inline]
180    fn unlock_shared(&self) {
181        #[cfg(any(debug_assertions, feature = "check"))]
182        self.state.fetch_sub(ONE_READER, Ordering::Release);
183    }
184}
185
186pub struct RwLockWriteGuard<'a, T>
187where
188    T: ?Sized,
189{
190    lock: &'a RwLock<T>,
191}
192
193impl<'a, T> Deref for RwLockWriteGuard<'a, T>
194where
195    T: ?Sized,
196{
197    type Target = T;
198
199    #[inline]
200    fn deref(&self) -> &T {
201        unsafe { &*self.lock.value.get() }
202    }
203}
204
205impl<'a, T> DerefMut for RwLockWriteGuard<'a, T>
206where
207    T: ?Sized,
208{
209    #[inline]
210    fn deref_mut(&mut self) -> &mut T {
211        unsafe { &mut *self.lock.value.get() }
212    }
213}
214
215impl<'a, T> Drop for RwLockWriteGuard<'a, T>
216where
217    T: ?Sized,
218{
219    #[inline]
220    fn drop(&mut self) {
221        self.lock.unlock_exclusive();
222    }
223}
224
225pub struct RwLockReadGuard<'a, T>
226where
227    T: ?Sized,
228{
229    lock: &'a RwLock<T>,
230}
231
232impl<'a, T> Deref for RwLockReadGuard<'a, T>
233where
234    T: ?Sized,
235{
236    type Target = T;
237
238    #[inline]
239    fn deref(&self) -> &T {
240        unsafe { &*self.lock.value.get() }
241    }
242}
243
244impl<'a, T> Drop for RwLockReadGuard<'a, T>
245where
246    T: ?Sized,
247{
248    #[inline]
249    fn drop(&mut self) {
250        self.lock.unlock_shared();
251    }
252}
253
// SAFETY: neither guard stores the value itself. Both `Deref` impls return a
// pointer into the `UnsafeCell` of the borrowed `RwLock`, so moving a guard
// does not change the address it dereferences to.
#[cfg(feature = "owning_ref")]
unsafe impl<'a, T> StableAddress for RwLockReadGuard<'a, T> where T: 'a + ?Sized {}
// SAFETY: same reasoning as for the read guard directly above.
#[cfg(feature = "owning_ref")]
unsafe impl<'a, T> StableAddress for RwLockWriteGuard<'a, T> where T: 'a + ?Sized {}
258
#[cfg(feature = "serde")]
impl<T> Serialize for RwLock<T>
where
    T: ?Sized + Serialize,
{
    /// Serializes the protected value while holding a read guard.
    ///
    /// # Panics
    ///
    /// Like `RwLock::read`, panics if a writer currently holds the lock and
    /// state tracking (`debug_assertions` or the `check` feature) is compiled in.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let guard = self.read();
        T::serialize(&*guard, serializer)
    }
}
271
#[cfg(feature = "serde")]
impl<'de, T> Deserialize<'de> for RwLock<T>
where
    // `Deserialize` already implies `Sized`, so the former `?Sized` bound was
    // redundant and suggested unsized support that cannot exist.
    T: Deserialize<'de>,
{
    /// Deserializes a `T` and wraps it in a new lock via `RwLock::new`.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        T::deserialize(deserializer).map(RwLock::new)
    }
}