tpl_mutex/
tpl_mutex.rs

1#![cfg_attr(not(test), no_std)]
2
3extern crate alloc;
4
5use core::{
6    cell::UnsafeCell,
7    fmt::{self, Debug, Display},
8    ops::{Deref, DerefMut},
9    sync::atomic::{AtomicBool, Ordering},
10};
11
12use boot_services::{tpl::Tpl, BootServices, StandardBootServices};
13
14/// Type use for mutual exclusion of data across Tpl (task priority level)
15pub struct TplMutex<'a, T: ?Sized, B: BootServices = StandardBootServices<'a>> {
16    boot_services: &'a B,
17    tpl_lock_level: Tpl,
18    lock: AtomicBool,
19    data: UnsafeCell<T>,
20}
21
22/// RAII implementation of a [TplMutex] lock. When this structure is dropped, the lock will be unlocked.
23#[must_use = "if unused the TplMutex will immediately unlock"]
24pub struct TplMutexGuard<'a, T: ?Sized, B: BootServices> {
25    tpl_mutex: &'a TplMutex<'a, T, B>,
26    release_tpl: Tpl,
27}
28
29impl<'a, T, B: BootServices> TplMutex<'a, T, B> {
30    /// Create an new TplMutex in an unlock state.
31    pub const fn new(boot_services: &'a B, tpl_lock_level: Tpl, data: T) -> Self {
32        Self { boot_services, tpl_lock_level, lock: AtomicBool::new(false), data: UnsafeCell::new(data) }
33    }
34}
35
36impl<'a, T: ?Sized, B: BootServices> TplMutex<'a, T, B> {
37    /// Attempt to lock the mutex and return a [TplMutexGuard] if the mutex was not locked.
38    ///
39    /// # Panics
40    /// This call will panic if the mutex is already locked.
41    pub fn lock(&'a self) -> TplMutexGuard<'a, T, B> {
42        self.try_lock().map_err(|_| "Re-entrant lock").unwrap()
43    }
44
45    /// Attempt to lock the mutex and return [TplMutexGuard] if the mutex was not locked.
46    ///
47    /// # Errors
48    /// If the mutex is already lock, then this call will return [Err].
49    pub fn try_lock(&'a self) -> Result<TplMutexGuard<'a, T, B>, ()> {
50        self.lock
51            .compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed)
52            .map(|_| TplMutexGuard { release_tpl: self.boot_services.raise_tpl(self.tpl_lock_level), tpl_mutex: &self })
53            .map_err(|_| ())
54    }
55}
56
57impl<T: ?Sized, B: BootServices> Drop for TplMutexGuard<'_, T, B> {
58    fn drop(&mut self) {
59        self.tpl_mutex.boot_services.restore_tpl(self.release_tpl);
60        self.tpl_mutex.lock.store(false, Ordering::Release);
61    }
62}
63
64impl<'a, T: ?Sized, B: BootServices> Deref for TplMutexGuard<'a, T, B> {
65    type Target = T;
66    fn deref(&self) -> &'a T {
67        // SAFETY:
68        // `as_ref` is guarantee to have a valid pointer because it come from a UnsafeCell.
69        // This also comply to the aliasing rule because it is the only way to get a reference to the data, thus no other mutable reference to this data exist.
70        unsafe { self.tpl_mutex.data.get().as_ref::<'a>().unwrap() }
71    }
72}
73
74impl<'a, T: ?Sized, B: BootServices> DerefMut for TplMutexGuard<'a, T, B> {
75    fn deref_mut(&mut self) -> &'a mut T {
76        // SAFETY:
77        // `as_ref` is guarantee to have a valid pointer because it come from a UnsafeCell.
78        // This also comply to the mutability rule because it is the only way to get a reference to the data, thus no other mutable reference to this data exist.
79        unsafe { self.tpl_mutex.data.get().as_mut().unwrap() }
80    }
81}
82
83impl<'a, T: ?Sized + fmt::Debug, B: BootServices> fmt::Debug for TplMutex<'a, T, B> {
84    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
85        let mut dbg = f.debug_struct("TplMutex");
86        match self.try_lock() {
87            Ok(guard) => dbg.field("data", &guard),
88            Err(()) => dbg.field("data", &format_args!("<locked>")),
89        };
90        dbg.finish_non_exhaustive()
91    }
92}
93
94impl<'a, T: ?Sized + fmt::Debug, B: BootServices> fmt::Debug for TplMutexGuard<'a, T, B> {
95    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
96        Debug::fmt(self.deref(), f)
97    }
98}
99
100impl<'a, T: ?Sized + fmt::Display, B: BootServices> fmt::Display for TplMutexGuard<'a, T, B> {
101    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
102        Display::fmt(self.deref(), f)
103    }
104}
105
106unsafe impl<T: ?Sized + Send, B: BootServices> Sync for TplMutex<'_, T, B> {}
107unsafe impl<T: ?Sized + Send, B: BootServices> Send for TplMutex<'_, T, B> {}
108
109unsafe impl<T: ?Sized + Sync, B: BootServices> Sync for TplMutexGuard<'_, T, B> {}
110
#[cfg(test)]
mod test {
    use super::*;
    use boot_services::MockBootServices;
    use mockall::predicate::*;

    /// Payload with distinct `Debug` (`TestStruct { field: N }`) and `Display` (`N`) output.
    #[derive(Debug, Default)]
    struct TestStruct {
        field: u32,
    }

    impl Display for TestStruct {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            write!(f, "{}", self.field)
        }
    }

    /// Mock boot services: `raise_tpl(Tpl::NOTIFY)` reports `Tpl::APPLICATION` as the previous
    /// level, and `restore_tpl(Tpl::APPLICATION)` is accepted.
    fn boot_services() -> MockBootServices {
        let mut mock = MockBootServices::new();
        mock.expect_raise_tpl().with(eq(Tpl::NOTIFY)).return_const(Tpl::APPLICATION);
        mock.expect_restore_tpl().with(eq(Tpl::APPLICATION)).return_const(());
        mock
    }

    #[test]
    fn test_try_lock() {
        let mock = boot_services();
        let mutex = TplMutex::new(&mock, Tpl::NOTIFY, 0);

        let first = mutex.try_lock();
        assert!(first.is_ok(), "First lock should work.");

        for _attempt in 0..2 {
            let contended = mutex.try_lock();
            assert!(contended.is_err(), "Try lock should not work when there is already a lock guard.");
        }

        drop(first);
        let second = mutex.try_lock();
        assert!(second.is_ok(), "Lock should work after the guard has been dropped.");
    }

    #[test]
    #[should_panic(expected = "Re-entrant lock")]
    fn test_that_locking_a_locked_mutex_with_lock_fn_should_panic() {
        let mock = boot_services();
        let mutex = TplMutex::new(&mock, Tpl::NOTIFY, TestStruct::default());
        let held = mutex.try_lock();
        assert!(held.is_ok());
        let _ = mutex.lock();
    }

    #[test]
    fn test_debug_output_for_tpl_mutex() {
        let mock = boot_services();
        let mutex = TplMutex::new(&mock, Tpl::NOTIFY, TestStruct::default());
        let unlocked = format!("{mutex:?}");
        assert_eq!("TplMutex { data: TestStruct { field: 0 }, .. }", unlocked);

        let _guard = mutex.lock();
        let locked = format!("{mutex:?}");
        assert_eq!("TplMutex { data: <locked>, .. }", locked);
    }

    #[test]
    fn test_display_and_debug_output_for_tpl_mutex_guard() {
        let mock = boot_services();
        let mutex = TplMutex::new(&mock, Tpl::NOTIFY, TestStruct::default());
        let guard = mutex.lock();
        let displayed = format!("{guard}");
        let debugged = format!("{guard:?}");
        assert_eq!("0", displayed);
        assert_eq!("TestStruct { field: 0 }", debugged);
    }
}