1use core::cell::{RefCell, UnsafeCell};
5use core::fmt;
6use core::future::{poll_fn, Future};
7use core::ops::{Deref, DerefMut};
8use core::task::Poll;
9
10use crate::blocking_mutex::raw::RawMutex;
11use crate::blocking_mutex::Mutex as BlockingMutex;
12use crate::waitqueue::WakerRegistration;
13
/// Error returned by [`RwLock::try_read`] and [`RwLock::try_write`] when the
/// requested access cannot be granted immediately.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct TryLockError;
18
19struct State {
20 readers: usize,
21 writer: bool,
22 waker: WakerRegistration,
23}
24
25pub struct RwLock<M, T>
42where
43 M: RawMutex,
44 T: ?Sized,
45{
46 state: BlockingMutex<M, RefCell<State>>,
47 inner: UnsafeCell<T>,
48}
49
50unsafe impl<M: RawMutex + Send, T: ?Sized + Send> Send for RwLock<M, T> {}
51unsafe impl<M: RawMutex + Sync, T: ?Sized + Send> Sync for RwLock<M, T> {}
52
53impl<M, T> RwLock<M, T>
55where
56 M: RawMutex,
57{
58 pub const fn new(value: T) -> Self {
60 Self {
61 inner: UnsafeCell::new(value),
62 state: BlockingMutex::new(RefCell::new(State {
63 readers: 0,
64 writer: false,
65 waker: WakerRegistration::new(),
66 })),
67 }
68 }
69}
70
71impl<M, T> RwLock<M, T>
72where
73 M: RawMutex,
74 T: ?Sized,
75{
76 pub fn read(&self) -> impl Future<Output = RwLockReadGuard<'_, M, T>> {
80 poll_fn(|cx| {
81 let ready = self.state.lock(|s| {
82 let mut s = s.borrow_mut();
83 if s.writer {
84 s.waker.register(cx.waker());
85 false
86 } else {
87 s.readers += 1;
88 true
89 }
90 });
91
92 if ready {
93 Poll::Ready(RwLockReadGuard { rwlock: self })
94 } else {
95 Poll::Pending
96 }
97 })
98 }
99
100 pub fn write(&self) -> impl Future<Output = RwLockWriteGuard<'_, M, T>> {
104 poll_fn(|cx| {
105 let ready = self.state.lock(|s| {
106 let mut s = s.borrow_mut();
107 if s.writer || s.readers > 0 {
108 s.waker.register(cx.waker());
109 false
110 } else {
111 s.writer = true;
112 true
113 }
114 });
115
116 if ready {
117 Poll::Ready(RwLockWriteGuard { rwlock: self })
118 } else {
119 Poll::Pending
120 }
121 })
122 }
123
124 pub fn try_read(&self) -> Result<RwLockReadGuard<'_, M, T>, TryLockError> {
128 self.state
129 .lock(|s| {
130 let mut s = s.borrow_mut();
131 if s.writer {
132 return Err(());
133 }
134 s.readers += 1;
135 Ok(())
136 })
137 .map_err(|_| TryLockError)?;
138
139 Ok(RwLockReadGuard { rwlock: self })
140 }
141
142 pub fn try_write(&self) -> Result<RwLockWriteGuard<'_, M, T>, TryLockError> {
146 self.state
147 .lock(|s| {
148 let mut s = s.borrow_mut();
149 if s.writer || s.readers > 0 {
150 return Err(());
151 }
152 s.writer = true;
153 Ok(())
154 })
155 .map_err(|_| TryLockError)?;
156
157 Ok(RwLockWriteGuard { rwlock: self })
158 }
159
160 pub fn into_inner(self) -> T
162 where
163 T: Sized,
164 {
165 self.inner.into_inner()
166 }
167
168 pub fn get_mut(&mut self) -> &mut T {
173 self.inner.get_mut()
174 }
175}
176
177impl<M: RawMutex, T> From<T> for RwLock<M, T> {
178 fn from(from: T) -> Self {
179 Self::new(from)
180 }
181}
182
183impl<M, T> Default for RwLock<M, T>
184where
185 M: RawMutex,
186 T: Default,
187{
188 fn default() -> Self {
189 Self::new(Default::default())
190 }
191}
192
193impl<M, T> fmt::Debug for RwLock<M, T>
194where
195 M: RawMutex,
196 T: ?Sized + fmt::Debug,
197{
198 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
199 let mut d = f.debug_struct("RwLock");
200 match self.try_read() {
201 Ok(guard) => d.field("inner", &&*guard),
202 Err(TryLockError) => d.field("inner", &"Locked"),
203 }
204 .finish_non_exhaustive()
205 }
206}
207
208#[clippy::has_significant_drop]
215#[must_use = "if unused the RwLock will immediately unlock"]
216pub struct RwLockReadGuard<'a, R, T>
217where
218 R: RawMutex,
219 T: ?Sized,
220{
221 rwlock: &'a RwLock<R, T>,
222}
223
224impl<'a, M, T> Drop for RwLockReadGuard<'a, M, T>
225where
226 M: RawMutex,
227 T: ?Sized,
228{
229 fn drop(&mut self) {
230 self.rwlock.state.lock(|s| {
231 let mut s = unwrap!(s.try_borrow_mut());
232 s.readers -= 1;
233 if s.readers == 0 {
234 s.waker.wake();
235 }
236 })
237 }
238}
239
240impl<'a, M, T> Deref for RwLockReadGuard<'a, M, T>
241where
242 M: RawMutex,
243 T: ?Sized,
244{
245 type Target = T;
246 fn deref(&self) -> &Self::Target {
247 unsafe { &*(self.rwlock.inner.get() as *const T) }
250 }
251}
252
253impl<'a, M, T> fmt::Debug for RwLockReadGuard<'a, M, T>
254where
255 M: RawMutex,
256 T: ?Sized + fmt::Debug,
257{
258 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
259 fmt::Debug::fmt(&**self, f)
260 }
261}
262
263impl<'a, M, T> fmt::Display for RwLockReadGuard<'a, M, T>
264where
265 M: RawMutex,
266 T: ?Sized + fmt::Display,
267{
268 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
269 fmt::Display::fmt(&**self, f)
270 }
271}
272
273#[clippy::has_significant_drop]
280#[must_use = "if unused the RwLock will immediately unlock"]
281pub struct RwLockWriteGuard<'a, R, T>
282where
283 R: RawMutex,
284 T: ?Sized,
285{
286 rwlock: &'a RwLock<R, T>,
287}
288
289impl<'a, R, T> Drop for RwLockWriteGuard<'a, R, T>
290where
291 R: RawMutex,
292 T: ?Sized,
293{
294 fn drop(&mut self) {
295 self.rwlock.state.lock(|s| {
296 let mut s = unwrap!(s.try_borrow_mut());
297 s.writer = false;
298 s.waker.wake();
299 })
300 }
301}
302
303impl<'a, R, T> Deref for RwLockWriteGuard<'a, R, T>
304where
305 R: RawMutex,
306 T: ?Sized,
307{
308 type Target = T;
309 fn deref(&self) -> &Self::Target {
310 unsafe { &*(self.rwlock.inner.get() as *mut T) }
313 }
314}
315
316impl<'a, R, T> DerefMut for RwLockWriteGuard<'a, R, T>
317where
318 R: RawMutex,
319 T: ?Sized,
320{
321 fn deref_mut(&mut self) -> &mut Self::Target {
322 unsafe { &mut *(self.rwlock.inner.get()) }
325 }
326}
327
328impl<'a, R, T> fmt::Debug for RwLockWriteGuard<'a, R, T>
329where
330 R: RawMutex,
331 T: ?Sized + fmt::Debug,
332{
333 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
334 fmt::Debug::fmt(&**self, f)
335 }
336}
337
338impl<'a, R, T> fmt::Display for RwLockWriteGuard<'a, R, T>
339where
340 R: RawMutex,
341 T: ?Sized + fmt::Display,
342{
343 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
344 fmt::Display::fmt(&**self, f)
345 }
346}
347
#[cfg(test)]
mod tests {
    use crate::blocking_mutex::raw::NoopRawMutex;
    use crate::rwlock::RwLock;

    #[futures_test::test]
    async fn read_guard_releases_lock_when_dropped() {
        let rwlock: RwLock<NoopRawMutex, [i32; 2]> = RwLock::new([0, 1]);

        {
            let guard = rwlock.read().await;
            assert_eq!(*guard, [0, 1]);
        }

        {
            let guard = rwlock.read().await;
            assert_eq!(*guard, [0, 1]);
        }

        assert_eq!(*rwlock.read().await, [0, 1]);
    }

    #[futures_test::test]
    async fn write_guard_releases_lock_when_dropped() {
        let rwlock: RwLock<NoopRawMutex, [i32; 2]> = RwLock::new([0, 1]);

        {
            let mut guard = rwlock.write().await;
            assert_eq!(*guard, [0, 1]);
            guard[1] = 2;
        }

        {
            let guard = rwlock.read().await;
            assert_eq!(*guard, [0, 2]);
        }

        assert_eq!(*rwlock.read().await, [0, 2]);
    }

    /// Multiple readers may coexist, and a writer is refused until the very
    /// last reader has dropped its guard.
    #[test]
    fn readers_share_but_exclude_writers() {
        let rwlock: RwLock<NoopRawMutex, u32> = RwLock::new(5);

        let first = rwlock.try_read().unwrap();
        let second = rwlock.try_read().unwrap();
        assert_eq!(*first, 5);
        assert_eq!(*second, 5);
        assert!(rwlock.try_write().is_err());

        drop(first);
        // One reader is still active, so writing must still be refused.
        assert!(rwlock.try_write().is_err());

        drop(second);
        assert!(rwlock.try_write().is_ok());
    }

    /// A writer excludes both readers and other writers, and its modification
    /// is visible once it has been released.
    #[test]
    fn writer_excludes_everyone() {
        let rwlock: RwLock<NoopRawMutex, u32> = RwLock::new(5);

        {
            let mut guard = rwlock.try_write().unwrap();
            assert!(rwlock.try_read().is_err());
            assert!(rwlock.try_write().is_err());
            *guard = 6;
        }

        assert_eq!(*rwlock.try_read().unwrap(), 6);
    }
}