1use core::cell::{RefCell, UnsafeCell};
5use core::future::{poll_fn, Future};
6use core::ops::{Deref, DerefMut};
7use core::task::Poll;
8use core::{fmt, mem};
9
10use crate::blocking_mutex::raw::RawMutex;
11use crate::blocking_mutex::Mutex as BlockingMutex;
12use crate::waitqueue::WakerRegistration;
13
14#[derive(PartialEq, Eq, Clone, Copy, Debug)]
16#[cfg_attr(feature = "defmt", derive(defmt::Format))]
17pub struct TryLockError;
18
19struct State {
20 locked: bool,
21 waker: WakerRegistration,
22}
23
24pub struct Mutex<M, T>
40where
41 M: RawMutex,
42 T: ?Sized,
43{
44 state: BlockingMutex<M, RefCell<State>>,
45 inner: UnsafeCell<T>,
46}
47
48unsafe impl<M: RawMutex + Send, T: ?Sized + Send> Send for Mutex<M, T> {}
49unsafe impl<M: RawMutex + Sync, T: ?Sized + Send> Sync for Mutex<M, T> {}
50
51impl<M, T> Mutex<M, T>
53where
54 M: RawMutex,
55{
56 pub const fn new(value: T) -> Self {
58 Self {
59 inner: UnsafeCell::new(value),
60 state: BlockingMutex::new(RefCell::new(State {
61 locked: false,
62 waker: WakerRegistration::new(),
63 })),
64 }
65 }
66}
67
68impl<M, T> Mutex<M, T>
69where
70 M: RawMutex,
71 T: ?Sized,
72{
73 pub fn lock(&self) -> impl Future<Output = MutexGuard<'_, M, T>> {
77 poll_fn(|cx| {
78 let ready = self.state.lock(|s| {
79 let mut s = s.borrow_mut();
80 if s.locked {
81 s.waker.register(cx.waker());
82 false
83 } else {
84 s.locked = true;
85 true
86 }
87 });
88
89 if ready {
90 Poll::Ready(MutexGuard { mutex: self })
91 } else {
92 Poll::Pending
93 }
94 })
95 }
96
97 pub fn try_lock(&self) -> Result<MutexGuard<'_, M, T>, TryLockError> {
101 self.state.lock(|s| {
102 let mut s = s.borrow_mut();
103 if s.locked {
104 Err(TryLockError)
105 } else {
106 s.locked = true;
107 Ok(())
108 }
109 })?;
110
111 Ok(MutexGuard { mutex: self })
112 }
113
114 pub fn into_inner(self) -> T
116 where
117 T: Sized,
118 {
119 self.inner.into_inner()
120 }
121
122 pub fn get_mut(&mut self) -> &mut T {
127 self.inner.get_mut()
128 }
129}
130
131impl<M: RawMutex, T> From<T> for Mutex<M, T> {
132 fn from(from: T) -> Self {
133 Self::new(from)
134 }
135}
136
137impl<M, T> Default for Mutex<M, T>
138where
139 M: RawMutex,
140 T: Default,
141{
142 fn default() -> Self {
143 Self::new(Default::default())
144 }
145}
146
147impl<M, T> fmt::Debug for Mutex<M, T>
148where
149 M: RawMutex,
150 T: ?Sized + fmt::Debug,
151{
152 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
153 let mut d = f.debug_struct("Mutex");
154 match self.try_lock() {
155 Ok(value) => {
156 d.field("inner", &&*value);
157 }
158 Err(TryLockError) => {
159 d.field("inner", &format_args!("<locked>"));
160 }
161 }
162
163 d.finish_non_exhaustive()
164 }
165}
166
167#[clippy::has_significant_drop]
174#[must_use = "if unused the Mutex will immediately unlock"]
175pub struct MutexGuard<'a, M, T>
176where
177 M: RawMutex,
178 T: ?Sized,
179{
180 mutex: &'a Mutex<M, T>,
181}
182
183impl<'a, M, T> MutexGuard<'a, M, T>
184where
185 M: RawMutex,
186 T: ?Sized,
187{
188 pub fn map<U>(this: Self, fun: impl FnOnce(&mut T) -> &mut U) -> MappedMutexGuard<'a, M, U> {
190 let mutex = this.mutex;
191 let value = fun(unsafe { &mut *this.mutex.inner.get() });
192 mem::forget(this);
195 MappedMutexGuard {
196 state: &mutex.state,
197 value,
198 }
199 }
200}
201
202impl<'a, M, T> Drop for MutexGuard<'a, M, T>
203where
204 M: RawMutex,
205 T: ?Sized,
206{
207 fn drop(&mut self) {
208 self.mutex.state.lock(|s| {
209 let mut s = unwrap!(s.try_borrow_mut());
210 s.locked = false;
211 s.waker.wake();
212 })
213 }
214}
215
216impl<'a, M, T> Deref for MutexGuard<'a, M, T>
217where
218 M: RawMutex,
219 T: ?Sized,
220{
221 type Target = T;
222 fn deref(&self) -> &Self::Target {
223 unsafe { &*(self.mutex.inner.get() as *const T) }
226 }
227}
228
229impl<'a, M, T> DerefMut for MutexGuard<'a, M, T>
230where
231 M: RawMutex,
232 T: ?Sized,
233{
234 fn deref_mut(&mut self) -> &mut Self::Target {
235 unsafe { &mut *(self.mutex.inner.get()) }
238 }
239}
240
241impl<'a, M, T> fmt::Debug for MutexGuard<'a, M, T>
242where
243 M: RawMutex,
244 T: ?Sized + fmt::Debug,
245{
246 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
247 fmt::Debug::fmt(&**self, f)
248 }
249}
250
251impl<'a, M, T> fmt::Display for MutexGuard<'a, M, T>
252where
253 M: RawMutex,
254 T: ?Sized + fmt::Display,
255{
256 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
257 fmt::Display::fmt(&**self, f)
258 }
259}
260
261#[clippy::has_significant_drop]
266pub struct MappedMutexGuard<'a, M, T>
267where
268 M: RawMutex,
269 T: ?Sized,
270{
271 state: &'a BlockingMutex<M, RefCell<State>>,
272 value: *mut T,
273}
274
275impl<'a, M, T> MappedMutexGuard<'a, M, T>
276where
277 M: RawMutex,
278 T: ?Sized,
279{
280 pub fn map<U>(this: Self, fun: impl FnOnce(&mut T) -> &mut U) -> MappedMutexGuard<'a, M, U> {
282 let state = this.state;
283 let value = fun(unsafe { &mut *this.value });
284 mem::forget(this);
287 MappedMutexGuard { state, value }
288 }
289}
290
291impl<'a, M, T> Deref for MappedMutexGuard<'a, M, T>
292where
293 M: RawMutex,
294 T: ?Sized,
295{
296 type Target = T;
297 fn deref(&self) -> &Self::Target {
298 unsafe { &*self.value }
301 }
302}
303
304impl<'a, M, T> DerefMut for MappedMutexGuard<'a, M, T>
305where
306 M: RawMutex,
307 T: ?Sized,
308{
309 fn deref_mut(&mut self) -> &mut Self::Target {
310 unsafe { &mut *self.value }
313 }
314}
315
316impl<'a, M, T> Drop for MappedMutexGuard<'a, M, T>
317where
318 M: RawMutex,
319 T: ?Sized,
320{
321 fn drop(&mut self) {
322 self.state.lock(|s| {
323 let mut s = unwrap!(s.try_borrow_mut());
324 s.locked = false;
325 s.waker.wake();
326 })
327 }
328}
329
330unsafe impl<M, T> Send for MappedMutexGuard<'_, M, T>
331where
332 M: RawMutex + Sync,
333 T: Send + ?Sized,
334{
335}
336
337unsafe impl<M, T> Sync for MappedMutexGuard<'_, M, T>
338where
339 M: RawMutex + Sync,
340 T: Sync + ?Sized,
341{
342}
343
344impl<'a, M, T> fmt::Debug for MappedMutexGuard<'a, M, T>
345where
346 M: RawMutex,
347 T: ?Sized + fmt::Debug,
348{
349 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
350 fmt::Debug::fmt(&**self, f)
351 }
352}
353
354impl<'a, M, T> fmt::Display for MappedMutexGuard<'a, M, T>
355where
356 M: RawMutex,
357 T: ?Sized + fmt::Display,
358{
359 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
360 fmt::Display::fmt(&**self, f)
361 }
362}
363
364#[cfg(test)]
365mod tests {
366 use crate::blocking_mutex::raw::NoopRawMutex;
367 use crate::mutex::{Mutex, MutexGuard};
368
369 #[futures_test::test]
370 async fn mapped_guard_releases_lock_when_dropped() {
371 let mutex: Mutex<NoopRawMutex, [i32; 2]> = Mutex::new([0, 1]);
372
373 {
374 let guard = mutex.lock().await;
375 assert_eq!(*guard, [0, 1]);
376 let mut mapped = MutexGuard::map(guard, |this| &mut this[1]);
377 assert_eq!(*mapped, 1);
378 *mapped = 2;
379 }
380
381 {
382 let guard = mutex.lock().await;
383 assert_eq!(*guard, [0, 2]);
384 let mut mapped = MutexGuard::map(guard, |this| &mut this[1]);
385 assert_eq!(*mapped, 2);
386 *mapped = 3;
387 }
388
389 assert_eq!(*mutex.lock().await, [0, 3]);
390 }
391}