1use core::cell::{RefCell, UnsafeCell};
5use core::future::{poll_fn, Future};
6use core::ops::{Deref, DerefMut};
7use core::task::Poll;
8use core::{fmt, mem};
9
10use crate::blocking_mutex::raw::RawMutex;
11use crate::blocking_mutex::Mutex as BlockingMutex;
12use crate::waitqueue::WakerRegistration;
13
/// Error returned by [`Mutex::try_lock`] when the mutex is already locked.
#[derive(PartialEq, Eq, Clone, Copy, Debug)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct TryLockError;

/// Human-readable description so the error can be printed with `{}`.
impl fmt::Display for TryLockError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("mutex is already locked")
    }
}
18
/// Lock bookkeeping for [`Mutex`], stored inside the blocking mutex `M`.
#[derive(Debug)]
struct State {
    /// `true` while a guard (plain or mapped) owns the lock.
    locked: bool,
    /// Waker registered by `lock()` while the mutex is held; woken when a guard is dropped.
    waker: WakerRegistration,
}
24
25pub struct Mutex<M, T>
41where
42 M: RawMutex,
43 T: ?Sized,
44{
45 state: BlockingMutex<M, RefCell<State>>,
46 inner: UnsafeCell<T>,
47}
48
49unsafe impl<M: RawMutex + Send, T: ?Sized + Send> Send for Mutex<M, T> {}
50unsafe impl<M: RawMutex + Sync, T: ?Sized + Send> Sync for Mutex<M, T> {}
51
52impl<M, T> Mutex<M, T>
54where
55 M: RawMutex,
56{
57 pub const fn new(value: T) -> Self {
59 Self {
60 inner: UnsafeCell::new(value),
61 state: BlockingMutex::new(RefCell::new(State {
62 locked: false,
63 waker: WakerRegistration::new(),
64 })),
65 }
66 }
67}
68
69impl<M, T> Mutex<M, T>
70where
71 M: RawMutex,
72 T: ?Sized,
73{
74 pub fn lock(&self) -> impl Future<Output = MutexGuard<'_, M, T>> {
78 poll_fn(|cx| {
79 let ready = self.state.lock(|s| {
80 let mut s = s.borrow_mut();
81 if s.locked {
82 s.waker.register(cx.waker());
83 false
84 } else {
85 s.locked = true;
86 true
87 }
88 });
89
90 if ready {
91 Poll::Ready(MutexGuard { mutex: self })
92 } else {
93 Poll::Pending
94 }
95 })
96 }
97
98 pub fn try_lock(&self) -> Result<MutexGuard<'_, M, T>, TryLockError> {
102 self.state.lock(|s| {
103 let mut s = s.borrow_mut();
104 if s.locked {
105 Err(TryLockError)
106 } else {
107 s.locked = true;
108 Ok(())
109 }
110 })?;
111
112 Ok(MutexGuard { mutex: self })
113 }
114
115 pub fn into_inner(self) -> T
117 where
118 T: Sized,
119 {
120 self.inner.into_inner()
121 }
122
123 pub fn get_mut(&mut self) -> &mut T {
128 self.inner.get_mut()
129 }
130}
131
132impl<M: RawMutex, T> From<T> for Mutex<M, T> {
133 fn from(from: T) -> Self {
134 Self::new(from)
135 }
136}
137
138impl<M, T> Default for Mutex<M, T>
139where
140 M: RawMutex,
141 T: Default,
142{
143 fn default() -> Self {
144 Self::new(Default::default())
145 }
146}
147
148impl<M, T> fmt::Debug for Mutex<M, T>
149where
150 M: RawMutex,
151 T: ?Sized + fmt::Debug,
152{
153 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
154 let mut d = f.debug_struct("Mutex");
155 match self.try_lock() {
156 Ok(value) => {
157 d.field("inner", &&*value);
158 }
159 Err(TryLockError) => {
160 d.field("inner", &format_args!("<locked>"));
161 }
162 }
163
164 d.finish_non_exhaustive()
165 }
166}
167
168#[clippy::has_significant_drop]
175#[must_use = "if unused the Mutex will immediately unlock"]
176pub struct MutexGuard<'a, M, T>
177where
178 M: RawMutex,
179 T: ?Sized,
180{
181 mutex: &'a Mutex<M, T>,
182}
183
184impl<'a, M, T> MutexGuard<'a, M, T>
185where
186 M: RawMutex,
187 T: ?Sized,
188{
189 pub fn map<U>(this: Self, fun: impl FnOnce(&mut T) -> &mut U) -> MappedMutexGuard<'a, M, U> {
191 let mutex = this.mutex;
192 let value = fun(unsafe { &mut *this.mutex.inner.get() });
193 mem::forget(this);
196 MappedMutexGuard {
197 state: &mutex.state,
198 value,
199 }
200 }
201}
202
203impl<'a, M, T> Drop for MutexGuard<'a, M, T>
204where
205 M: RawMutex,
206 T: ?Sized,
207{
208 fn drop(&mut self) {
209 self.mutex.state.lock(|s| {
210 let mut s = unwrap!(s.try_borrow_mut());
211 s.locked = false;
212 s.waker.wake();
213 })
214 }
215}
216
217impl<'a, M, T> Deref for MutexGuard<'a, M, T>
218where
219 M: RawMutex,
220 T: ?Sized,
221{
222 type Target = T;
223 fn deref(&self) -> &Self::Target {
224 unsafe { &*(self.mutex.inner.get() as *const T) }
227 }
228}
229
230impl<'a, M, T> DerefMut for MutexGuard<'a, M, T>
231where
232 M: RawMutex,
233 T: ?Sized,
234{
235 fn deref_mut(&mut self) -> &mut Self::Target {
236 unsafe { &mut *(self.mutex.inner.get()) }
239 }
240}
241
242impl<'a, M, T> fmt::Debug for MutexGuard<'a, M, T>
243where
244 M: RawMutex,
245 T: ?Sized + fmt::Debug,
246{
247 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
248 fmt::Debug::fmt(&**self, f)
249 }
250}
251
252impl<'a, M, T> fmt::Display for MutexGuard<'a, M, T>
253where
254 M: RawMutex,
255 T: ?Sized + fmt::Display,
256{
257 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
258 fmt::Display::fmt(&**self, f)
259 }
260}
261
262#[clippy::has_significant_drop]
267pub struct MappedMutexGuard<'a, M, T>
268where
269 M: RawMutex,
270 T: ?Sized,
271{
272 state: &'a BlockingMutex<M, RefCell<State>>,
273 value: *mut T,
274}
275
276impl<'a, M, T> MappedMutexGuard<'a, M, T>
277where
278 M: RawMutex,
279 T: ?Sized,
280{
281 pub fn map<U>(this: Self, fun: impl FnOnce(&mut T) -> &mut U) -> MappedMutexGuard<'a, M, U> {
283 let state = this.state;
284 let value = fun(unsafe { &mut *this.value });
285 mem::forget(this);
288 MappedMutexGuard { state, value }
289 }
290}
291
292impl<'a, M, T> Deref for MappedMutexGuard<'a, M, T>
293where
294 M: RawMutex,
295 T: ?Sized,
296{
297 type Target = T;
298 fn deref(&self) -> &Self::Target {
299 unsafe { &*self.value }
302 }
303}
304
305impl<'a, M, T> DerefMut for MappedMutexGuard<'a, M, T>
306where
307 M: RawMutex,
308 T: ?Sized,
309{
310 fn deref_mut(&mut self) -> &mut Self::Target {
311 unsafe { &mut *self.value }
314 }
315}
316
317impl<'a, M, T> Drop for MappedMutexGuard<'a, M, T>
318where
319 M: RawMutex,
320 T: ?Sized,
321{
322 fn drop(&mut self) {
323 self.state.lock(|s| {
324 let mut s = unwrap!(s.try_borrow_mut());
325 s.locked = false;
326 s.waker.wake();
327 })
328 }
329}
330
331unsafe impl<M, T> Send for MappedMutexGuard<'_, M, T>
332where
333 M: RawMutex + Sync,
334 T: Send + ?Sized,
335{
336}
337
338unsafe impl<M, T> Sync for MappedMutexGuard<'_, M, T>
339where
340 M: RawMutex + Sync,
341 T: Sync + ?Sized,
342{
343}
344
345impl<'a, M, T> fmt::Debug for MappedMutexGuard<'a, M, T>
346where
347 M: RawMutex,
348 T: ?Sized + fmt::Debug,
349{
350 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
351 fmt::Debug::fmt(&**self, f)
352 }
353}
354
355impl<'a, M, T> fmt::Display for MappedMutexGuard<'a, M, T>
356where
357 M: RawMutex,
358 T: ?Sized + fmt::Display,
359{
360 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
361 fmt::Display::fmt(&**self, f)
362 }
363}
364
#[cfg(test)]
mod tests {
    use crate::blocking_mutex::raw::NoopRawMutex;
    use crate::mutex::{Mutex, MutexGuard, TryLockError};

    /// Dropping a `MappedMutexGuard` must release the lock, and writes made through it
    /// must be visible to the next holder.
    #[futures_test::test]
    async fn mapped_guard_releases_lock_when_dropped() {
        let mutex: Mutex<NoopRawMutex, [i32; 2]> = Mutex::new([0, 1]);

        {
            let guard = mutex.lock().await;
            assert_eq!(*guard, [0, 1]);
            let mut mapped = MutexGuard::map(guard, |this| &mut this[1]);
            assert_eq!(*mapped, 1);
            *mapped = 2;
        }

        {
            let guard = mutex.lock().await;
            assert_eq!(*guard, [0, 2]);
            let mut mapped = MutexGuard::map(guard, |this| &mut this[1]);
            assert_eq!(*mapped, 2);
            *mapped = 3;
        }

        assert_eq!(*mutex.lock().await, [0, 3]);
    }

    /// `try_lock` must fail while any guard (plain or mapped) is alive and succeed again
    /// once that guard has been dropped.
    #[futures_test::test]
    async fn try_lock_fails_while_locked() {
        let mutex: Mutex<NoopRawMutex, u32> = Mutex::new(7);

        let guard = mutex.lock().await;
        assert_eq!(mutex.try_lock().err(), Some(TryLockError));
        drop(guard);

        let mut guard = mutex.try_lock().unwrap();
        *guard += 1;
        let mapped = MutexGuard::map(guard, |v| v);
        assert_eq!(mutex.try_lock().err(), Some(TryLockError));
        drop(mapped);

        assert_eq!(*mutex.try_lock().unwrap(), 8);
    }
}