1use core::cell::{RefCell, UnsafeCell};
5use core::fmt;
6use core::future::{Future, poll_fn};
7use core::ops::{Deref, DerefMut};
8use core::task::Poll;
9
10use crate::blocking_mutex::Mutex as BlockingMutex;
11use crate::blocking_mutex::raw::RawMutex;
12use crate::waitqueue::WakerRegistration;
13
/// Error returned by [`RwLock::try_read`] and [`RwLock::try_write`] when the lock
/// cannot be acquired immediately.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct TryLockError;
18
19#[derive(Debug)]
20struct State {
21 readers: usize,
22 writer: bool,
23 waker: WakerRegistration,
24}
25
26pub struct RwLock<M, T>
41where
42 M: RawMutex,
43 T: ?Sized,
44{
45 state: BlockingMutex<M, RefCell<State>>,
46 inner: UnsafeCell<T>,
47}
48
49unsafe impl<M: RawMutex + Send, T: ?Sized + Send> Send for RwLock<M, T> {}
50unsafe impl<M: RawMutex + Sync, T: ?Sized + Send> Sync for RwLock<M, T> {}
51
52impl<M, T> RwLock<M, T>
54where
55 M: RawMutex,
56{
57 pub const fn new(value: T) -> Self {
59 Self {
60 inner: UnsafeCell::new(value),
61 state: BlockingMutex::new(RefCell::new(State {
62 readers: 0,
63 writer: false,
64 waker: WakerRegistration::new(),
65 })),
66 }
67 }
68}
69
70impl<M, T> RwLock<M, T>
71where
72 M: RawMutex,
73 T: ?Sized,
74{
75 pub fn read(&self) -> impl Future<Output = RwLockReadGuard<'_, M, T>> {
79 poll_fn(|cx| {
80 let ready = self.state.lock(|s| {
81 let mut s = s.borrow_mut();
82 if s.writer {
83 s.waker.register(cx.waker());
84 false
85 } else {
86 s.readers += 1;
87 true
88 }
89 });
90
91 if ready {
92 Poll::Ready(RwLockReadGuard { rwlock: self })
93 } else {
94 Poll::Pending
95 }
96 })
97 }
98
99 pub fn write(&self) -> impl Future<Output = RwLockWriteGuard<'_, M, T>> {
103 poll_fn(|cx| {
104 let ready = self.state.lock(|s| {
105 let mut s = s.borrow_mut();
106 if s.writer || s.readers > 0 {
107 s.waker.register(cx.waker());
108 false
109 } else {
110 s.writer = true;
111 true
112 }
113 });
114
115 if ready {
116 Poll::Ready(RwLockWriteGuard { rwlock: self })
117 } else {
118 Poll::Pending
119 }
120 })
121 }
122
123 pub fn try_read(&self) -> Result<RwLockReadGuard<'_, M, T>, TryLockError> {
127 self.state
128 .lock(|s| {
129 let mut s = s.borrow_mut();
130 if s.writer {
131 return Err(());
132 }
133 s.readers += 1;
134 Ok(())
135 })
136 .map_err(|_| TryLockError)?;
137
138 Ok(RwLockReadGuard { rwlock: self })
139 }
140
141 pub fn try_write(&self) -> Result<RwLockWriteGuard<'_, M, T>, TryLockError> {
145 self.state
146 .lock(|s| {
147 let mut s = s.borrow_mut();
148 if s.writer || s.readers > 0 {
149 return Err(());
150 }
151 s.writer = true;
152 Ok(())
153 })
154 .map_err(|_| TryLockError)?;
155
156 Ok(RwLockWriteGuard { rwlock: self })
157 }
158
159 pub fn into_inner(self) -> T
161 where
162 T: Sized,
163 {
164 self.inner.into_inner()
165 }
166
167 pub fn get_mut(&mut self) -> &mut T {
172 self.inner.get_mut()
173 }
174}
175
176impl<M: RawMutex, T> From<T> for RwLock<M, T> {
177 fn from(from: T) -> Self {
178 Self::new(from)
179 }
180}
181
182impl<M, T> Default for RwLock<M, T>
183where
184 M: RawMutex,
185 T: Default,
186{
187 fn default() -> Self {
188 Self::new(Default::default())
189 }
190}
191
192impl<M, T> fmt::Debug for RwLock<M, T>
193where
194 M: RawMutex,
195 T: ?Sized + fmt::Debug,
196{
197 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
198 let mut d = f.debug_struct("RwLock");
199 match self.try_read() {
200 Ok(guard) => d.field("inner", &&*guard),
201 Err(TryLockError) => d.field("inner", &"Locked"),
202 }
203 .finish_non_exhaustive()
204 }
205}
206
207#[clippy::has_significant_drop]
214#[must_use = "if unused the RwLock will immediately unlock"]
215pub struct RwLockReadGuard<'a, R, T>
216where
217 R: RawMutex,
218 T: ?Sized,
219{
220 rwlock: &'a RwLock<R, T>,
221}
222
223impl<'a, M, T> Drop for RwLockReadGuard<'a, M, T>
224where
225 M: RawMutex,
226 T: ?Sized,
227{
228 fn drop(&mut self) {
229 self.rwlock.state.lock(|s| {
230 let mut s = unwrap!(s.try_borrow_mut());
231 s.readers -= 1;
232 if s.readers == 0 {
233 s.waker.wake();
234 }
235 })
236 }
237}
238
239impl<'a, M, T> Deref for RwLockReadGuard<'a, M, T>
240where
241 M: RawMutex,
242 T: ?Sized,
243{
244 type Target = T;
245 fn deref(&self) -> &Self::Target {
246 unsafe { &*(self.rwlock.inner.get() as *const T) }
249 }
250}
251
252impl<'a, M, T> fmt::Debug for RwLockReadGuard<'a, M, T>
253where
254 M: RawMutex,
255 T: ?Sized + fmt::Debug,
256{
257 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
258 fmt::Debug::fmt(&**self, f)
259 }
260}
261
262impl<'a, M, T> fmt::Display for RwLockReadGuard<'a, M, T>
263where
264 M: RawMutex,
265 T: ?Sized + fmt::Display,
266{
267 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
268 fmt::Display::fmt(&**self, f)
269 }
270}
271
272#[clippy::has_significant_drop]
279#[must_use = "if unused the RwLock will immediately unlock"]
280pub struct RwLockWriteGuard<'a, R, T>
281where
282 R: RawMutex,
283 T: ?Sized,
284{
285 rwlock: &'a RwLock<R, T>,
286}
287
288impl<'a, R, T> Drop for RwLockWriteGuard<'a, R, T>
289where
290 R: RawMutex,
291 T: ?Sized,
292{
293 fn drop(&mut self) {
294 self.rwlock.state.lock(|s| {
295 let mut s = unwrap!(s.try_borrow_mut());
296 s.writer = false;
297 s.waker.wake();
298 })
299 }
300}
301
302impl<'a, R, T> Deref for RwLockWriteGuard<'a, R, T>
303where
304 R: RawMutex,
305 T: ?Sized,
306{
307 type Target = T;
308 fn deref(&self) -> &Self::Target {
309 unsafe { &*(self.rwlock.inner.get() as *mut T) }
312 }
313}
314
315impl<'a, R, T> DerefMut for RwLockWriteGuard<'a, R, T>
316where
317 R: RawMutex,
318 T: ?Sized,
319{
320 fn deref_mut(&mut self) -> &mut Self::Target {
321 unsafe { &mut *(self.rwlock.inner.get()) }
324 }
325}
326
327impl<'a, R, T> fmt::Debug for RwLockWriteGuard<'a, R, T>
328where
329 R: RawMutex,
330 T: ?Sized + fmt::Debug,
331{
332 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
333 fmt::Debug::fmt(&**self, f)
334 }
335}
336
337impl<'a, R, T> fmt::Display for RwLockWriteGuard<'a, R, T>
338where
339 R: RawMutex,
340 T: ?Sized + fmt::Display,
341{
342 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
343 fmt::Display::fmt(&**self, f)
344 }
345}
346
#[cfg(test)]
mod tests {
    use crate::blocking_mutex::raw::NoopRawMutex;
    use crate::rwlock::RwLock;

    /// Read guards can be taken one after another, each dropped before the next, and
    /// every one observes the initial value.
    #[futures_test::test]
    async fn read_guard_releases_lock_when_dropped() {
        let rwlock: RwLock<NoopRawMutex, [i32; 2]> = RwLock::new([0, 1]);

        let first = rwlock.read().await;
        assert_eq!(*first, [0, 1]);
        drop(first);

        let second = rwlock.read().await;
        assert_eq!(*second, [0, 1]);
        drop(second);

        assert_eq!(*rwlock.read().await, [0, 1]);
    }

    /// After a write guard is dropped, readers can acquire the lock and see the
    /// modification made through the write guard.
    #[futures_test::test]
    async fn write_guard_releases_lock_when_dropped() {
        let rwlock: RwLock<NoopRawMutex, [i32; 2]> = RwLock::new([0, 1]);

        let mut writer = rwlock.write().await;
        assert_eq!(*writer, [0, 1]);
        writer[1] = 2;
        drop(writer);

        let reader = rwlock.read().await;
        assert_eq!(*reader, [0, 2]);
        drop(reader);

        assert_eq!(*rwlock.read().await, [0, 2]);
    }
}