Skip to main content

embassy_sync/
rwlock.rs

1//! Async read-write lock.
2//!
3//! This module provides a read-write lock that can be used to synchronize data between asynchronous tasks.
4use core::cell::{RefCell, UnsafeCell};
5use core::fmt;
6use core::future::{Future, poll_fn};
7use core::ops::{Deref, DerefMut};
8use core::task::Poll;
9
10use crate::blocking_mutex::Mutex as BlockingMutex;
11use crate::blocking_mutex::raw::RawMutex;
12use crate::waitqueue::WakerRegistration;
13
/// Returned by [`RwLock::try_read`] and [`RwLock::try_write`] when the lock cannot be
/// acquired without waiting because it is already held.
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TryLockError;
18
19#[derive(Debug)]
20struct State {
21    readers: usize,
22    writer: bool,
23    waker: WakerRegistration,
24}
25
26/// Async read-write lock.
27///
28/// The read-write lock is generic over the raw mutex implementation `M` and the data `T` it protects.
29/// The raw read-write lock is used to guard access to the internal state. It
30/// is held for very short periods only, while locking and unlocking. It is *not* held
31/// for the entire time the async RwLock is locked.
32///
33/// Which implementation you select depends on the context in which you're using the read-write lock.
34///
35/// Use [`CriticalSectionRawMutex`](crate::blocking_mutex::raw::CriticalSectionRawMutex) when data can be shared between threads and interrupts.
36///
37/// Use [`NoopRawMutex`](crate::blocking_mutex::raw::NoopRawMutex) when data is only shared between tasks running on the same executor.
38///
39/// Use [`ThreadModeRawMutex`](crate::blocking_mutex::raw::ThreadModeRawMutex) when data is shared between tasks running on the same executor but you want a singleton.
40pub struct RwLock<M, T>
41where
42    M: RawMutex,
43    T: ?Sized,
44{
45    state: BlockingMutex<M, RefCell<State>>,
46    inner: UnsafeCell<T>,
47}
48
49unsafe impl<M: RawMutex + Send, T: ?Sized + Send> Send for RwLock<M, T> {}
50unsafe impl<M: RawMutex + Sync, T: ?Sized + Send> Sync for RwLock<M, T> {}
51
52/// Async read-write lock.
53impl<M, T> RwLock<M, T>
54where
55    M: RawMutex,
56{
57    /// Create a new read-write lock with the given value.
58    pub const fn new(value: T) -> Self {
59        Self {
60            inner: UnsafeCell::new(value),
61            state: BlockingMutex::new(RefCell::new(State {
62                readers: 0,
63                writer: false,
64                waker: WakerRegistration::new(),
65            })),
66        }
67    }
68}
69
70impl<M, T> RwLock<M, T>
71where
72    M: RawMutex,
73    T: ?Sized,
74{
75    /// Lock the read-write lock for reading.
76    ///
77    /// This will wait for the lock to be available if it's already locked for writing.
78    pub fn read(&self) -> impl Future<Output = RwLockReadGuard<'_, M, T>> {
79        poll_fn(|cx| {
80            let ready = self.state.lock(|s| {
81                let mut s = s.borrow_mut();
82                if s.writer {
83                    s.waker.register(cx.waker());
84                    false
85                } else {
86                    s.readers += 1;
87                    true
88                }
89            });
90
91            if ready {
92                Poll::Ready(RwLockReadGuard { rwlock: self })
93            } else {
94                Poll::Pending
95            }
96        })
97    }
98
99    /// Lock the read-write lock for writing.
100    ///
101    /// This will wait for the lock to be available if it's already locked for reading or writing.
102    pub fn write(&self) -> impl Future<Output = RwLockWriteGuard<'_, M, T>> {
103        poll_fn(|cx| {
104            let ready = self.state.lock(|s| {
105                let mut s = s.borrow_mut();
106                if s.writer || s.readers > 0 {
107                    s.waker.register(cx.waker());
108                    false
109                } else {
110                    s.writer = true;
111                    true
112                }
113            });
114
115            if ready {
116                Poll::Ready(RwLockWriteGuard { rwlock: self })
117            } else {
118                Poll::Pending
119            }
120        })
121    }
122
123    /// Attempt to immediately lock the rwlock.
124    ///
125    /// If the rwlock is already locked, this will return an error instead of waiting.
126    pub fn try_read(&self) -> Result<RwLockReadGuard<'_, M, T>, TryLockError> {
127        self.state
128            .lock(|s| {
129                let mut s = s.borrow_mut();
130                if s.writer {
131                    return Err(());
132                }
133                s.readers += 1;
134                Ok(())
135            })
136            .map_err(|_| TryLockError)?;
137
138        Ok(RwLockReadGuard { rwlock: self })
139    }
140
141    /// Attempt to immediately lock the rwlock.
142    ///
143    /// If the rwlock is already locked, this will return an error instead of waiting.
144    pub fn try_write(&self) -> Result<RwLockWriteGuard<'_, M, T>, TryLockError> {
145        self.state
146            .lock(|s| {
147                let mut s = s.borrow_mut();
148                if s.writer || s.readers > 0 {
149                    return Err(());
150                }
151                s.writer = true;
152                Ok(())
153            })
154            .map_err(|_| TryLockError)?;
155
156        Ok(RwLockWriteGuard { rwlock: self })
157    }
158
159    /// Consumes this read-write lock, returning the underlying data.
160    pub fn into_inner(self) -> T
161    where
162        T: Sized,
163    {
164        self.inner.into_inner()
165    }
166
167    /// Returns a mutable reference to the underlying data.
168    ///
169    /// Since this call borrows the RwLock mutably, no actual locking needs to
170    /// take place -- the mutable borrow statically guarantees no locks exist.
171    pub fn get_mut(&mut self) -> &mut T {
172        self.inner.get_mut()
173    }
174}
175
176impl<M: RawMutex, T> From<T> for RwLock<M, T> {
177    fn from(from: T) -> Self {
178        Self::new(from)
179    }
180}
181
182impl<M, T> Default for RwLock<M, T>
183where
184    M: RawMutex,
185    T: Default,
186{
187    fn default() -> Self {
188        Self::new(Default::default())
189    }
190}
191
192impl<M, T> fmt::Debug for RwLock<M, T>
193where
194    M: RawMutex,
195    T: ?Sized + fmt::Debug,
196{
197    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
198        let mut d = f.debug_struct("RwLock");
199        match self.try_read() {
200            Ok(guard) => d.field("inner", &&*guard),
201            Err(TryLockError) => d.field("inner", &"Locked"),
202        }
203        .finish_non_exhaustive()
204    }
205}
206
207/// Async read lock guard.
208///
209/// Owning an instance of this type indicates having
210/// successfully locked the read-write lock for reading, and grants access to the contents.
211///
212/// Dropping it unlocks the read-write lock.
213#[clippy::has_significant_drop]
214#[must_use = "if unused the RwLock will immediately unlock"]
215pub struct RwLockReadGuard<'a, R, T>
216where
217    R: RawMutex,
218    T: ?Sized,
219{
220    rwlock: &'a RwLock<R, T>,
221}
222
223impl<'a, M, T> Drop for RwLockReadGuard<'a, M, T>
224where
225    M: RawMutex,
226    T: ?Sized,
227{
228    fn drop(&mut self) {
229        self.rwlock.state.lock(|s| {
230            let mut s = unwrap!(s.try_borrow_mut());
231            s.readers -= 1;
232            if s.readers == 0 {
233                s.waker.wake();
234            }
235        })
236    }
237}
238
239impl<'a, M, T> Deref for RwLockReadGuard<'a, M, T>
240where
241    M: RawMutex,
242    T: ?Sized,
243{
244    type Target = T;
245    fn deref(&self) -> &Self::Target {
246        // Safety: the RwLockReadGuard represents shared access to the contents
247        // of the read-write lock, so it's OK to get it.
248        unsafe { &*(self.rwlock.inner.get() as *const T) }
249    }
250}
251
252impl<'a, M, T> fmt::Debug for RwLockReadGuard<'a, M, T>
253where
254    M: RawMutex,
255    T: ?Sized + fmt::Debug,
256{
257    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
258        fmt::Debug::fmt(&**self, f)
259    }
260}
261
262impl<'a, M, T> fmt::Display for RwLockReadGuard<'a, M, T>
263where
264    M: RawMutex,
265    T: ?Sized + fmt::Display,
266{
267    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
268        fmt::Display::fmt(&**self, f)
269    }
270}
271
272/// Async write lock guard.
273///
274/// Owning an instance of this type indicates having
275/// successfully locked the read-write lock for writing, and grants access to the contents.
276///
277/// Dropping it unlocks the read-write lock.
278#[clippy::has_significant_drop]
279#[must_use = "if unused the RwLock will immediately unlock"]
280pub struct RwLockWriteGuard<'a, R, T>
281where
282    R: RawMutex,
283    T: ?Sized,
284{
285    rwlock: &'a RwLock<R, T>,
286}
287
288impl<'a, R, T> Drop for RwLockWriteGuard<'a, R, T>
289where
290    R: RawMutex,
291    T: ?Sized,
292{
293    fn drop(&mut self) {
294        self.rwlock.state.lock(|s| {
295            let mut s = unwrap!(s.try_borrow_mut());
296            s.writer = false;
297            s.waker.wake();
298        })
299    }
300}
301
302impl<'a, R, T> Deref for RwLockWriteGuard<'a, R, T>
303where
304    R: RawMutex,
305    T: ?Sized,
306{
307    type Target = T;
308    fn deref(&self) -> &Self::Target {
309        // Safety: the RwLockWriteGuard represents exclusive access to the contents
310        // of the read-write lock, so it's OK to get it.
311        unsafe { &*(self.rwlock.inner.get() as *mut T) }
312    }
313}
314
315impl<'a, R, T> DerefMut for RwLockWriteGuard<'a, R, T>
316where
317    R: RawMutex,
318    T: ?Sized,
319{
320    fn deref_mut(&mut self) -> &mut Self::Target {
321        // Safety: the RwLockWriteGuard represents exclusive access to the contents
322        // of the read-write lock, so it's OK to get it.
323        unsafe { &mut *(self.rwlock.inner.get()) }
324    }
325}
326
327impl<'a, R, T> fmt::Debug for RwLockWriteGuard<'a, R, T>
328where
329    R: RawMutex,
330    T: ?Sized + fmt::Debug,
331{
332    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
333        fmt::Debug::fmt(&**self, f)
334    }
335}
336
337impl<'a, R, T> fmt::Display for RwLockWriteGuard<'a, R, T>
338where
339    R: RawMutex,
340    T: ?Sized + fmt::Display,
341{
342    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
343        fmt::Display::fmt(&**self, f)
344    }
345}
346
#[cfg(test)]
mod tests {
    use crate::blocking_mutex::raw::NoopRawMutex;
    use crate::rwlock::RwLock;

    #[futures_test::test]
    async fn read_guard_releases_lock_when_dropped() {
        let rwlock: RwLock<NoopRawMutex, [i32; 2]> = RwLock::new([0, 1]);

        {
            let guard = rwlock.read().await;
            assert_eq!(*guard, [0, 1]);
        }

        {
            let guard = rwlock.read().await;
            assert_eq!(*guard, [0, 1]);
        }

        assert_eq!(*rwlock.read().await, [0, 1]);
    }

    #[futures_test::test]
    async fn write_guard_releases_lock_when_dropped() {
        let rwlock: RwLock<NoopRawMutex, [i32; 2]> = RwLock::new([0, 1]);

        {
            let mut guard = rwlock.write().await;
            assert_eq!(*guard, [0, 1]);
            guard[1] = 2;
        }

        {
            let guard = rwlock.read().await;
            assert_eq!(*guard, [0, 2]);
        }

        assert_eq!(*rwlock.read().await, [0, 2]);
    }

    #[futures_test::test]
    async fn readers_share_the_lock_and_exclude_writers() {
        let rwlock: RwLock<NoopRawMutex, i32> = RwLock::new(7);

        let first = rwlock.read().await;
        let second = rwlock.read().await;
        assert_eq!(*first + *second, 14);
        assert!(rwlock.try_write().is_err());

        drop(first);
        // One reader is still alive, so writing must still be refused.
        assert!(rwlock.try_write().is_err());

        drop(second);
        assert!(rwlock.try_write().is_ok());
    }

    #[test]
    fn try_read_and_try_write_respect_held_guards() {
        let rwlock: RwLock<NoopRawMutex, i32> = RwLock::new(0);

        {
            let _reader = rwlock.try_read().unwrap();
            assert!(rwlock.try_read().is_ok());
            assert!(rwlock.try_write().is_err());
        }

        {
            let mut writer = rwlock.try_write().unwrap();
            *writer = 5;
            assert!(rwlock.try_read().is_err());
            assert!(rwlock.try_write().is_err());
        }

        assert_eq!(*rwlock.try_read().unwrap(), 5);
    }

    #[test]
    fn get_mut_and_into_inner_bypass_locking() {
        let mut rwlock: RwLock<NoopRawMutex, i32> = RwLock::new(1);
        *rwlock.get_mut() += 1;
        assert_eq!(rwlock.into_inner(), 2);
    }
}