1#![cfg_attr(not(test), no_std)]
2
3extern crate alloc;
4
5use core::{
6 cell::UnsafeCell,
7 fmt::{self, Debug, Display},
8 ops::{Deref, DerefMut},
9 sync::atomic::{AtomicBool, Ordering},
10};
11
12use boot_services::{tpl::Tpl, BootServices, StandardBootServices};
13
/// A mutual-exclusion primitive that keeps the Task Priority Level (TPL) raised to
/// `tpl_lock_level` (through the supplied boot services) for as long as it is held.
///
/// This mutex never waits: if it is already held, [`TplMutex::try_lock`] returns `Err(())`
/// and [`TplMutex::lock`] panics.
pub struct TplMutex<'a, T: ?Sized, B: BootServices = StandardBootServices<'a>> {
    /// Boot services used to raise and restore the TPL around the critical section.
    boot_services: &'a B,
    /// TPL passed to `raise_tpl` whenever the mutex is locked.
    tpl_lock_level: Tpl,
    /// `true` while a [`TplMutexGuard`] for this mutex exists.
    lock: AtomicBool,
    /// The protected value; only reached through a guard. Must stay the last field because
    /// `T` may be unsized.
    data: UnsafeCell<T>,
}
21
/// RAII guard returned by [`TplMutex::lock`] and [`TplMutex::try_lock`].
///
/// Dereferences to the protected data. Dropping the guard clears the mutex's lock flag and
/// restores the TPL that `raise_tpl` reported when the lock was taken.
#[must_use = "if unused the TplMutex will immediately unlock"]
pub struct TplMutexGuard<'a, T: ?Sized, B: BootServices> {
    /// The mutex this guard currently holds.
    tpl_mutex: &'a TplMutex<'a, T, B>,
    /// Previous TPL returned by `raise_tpl` at lock time; handed back to `restore_tpl` on drop.
    release_tpl: Tpl,
}
28
29impl<'a, T, B: BootServices> TplMutex<'a, T, B> {
30 pub const fn new(boot_services: &'a B, tpl_lock_level: Tpl, data: T) -> Self {
32 Self { boot_services, tpl_lock_level, lock: AtomicBool::new(false), data: UnsafeCell::new(data) }
33 }
34}
35
36impl<'a, T: ?Sized, B: BootServices> TplMutex<'a, T, B> {
37 pub fn lock(&'a self) -> TplMutexGuard<'a, T, B> {
42 self.try_lock().map_err(|_| "Re-entrant lock").unwrap()
43 }
44
45 pub fn try_lock(&'a self) -> Result<TplMutexGuard<'a, T, B>, ()> {
50 self.lock
51 .compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed)
52 .map(|_| TplMutexGuard { release_tpl: self.boot_services.raise_tpl(self.tpl_lock_level), tpl_mutex: &self })
53 .map_err(|_| ())
54 }
55}
56
57impl<T: ?Sized, B: BootServices> Drop for TplMutexGuard<'_, T, B> {
58 fn drop(&mut self) {
59 self.tpl_mutex.boot_services.restore_tpl(self.release_tpl);
60 self.tpl_mutex.lock.store(false, Ordering::Release);
61 }
62}
63
64impl<'a, T: ?Sized, B: BootServices> Deref for TplMutexGuard<'a, T, B> {
65 type Target = T;
66 fn deref(&self) -> &'a T {
67 unsafe { self.tpl_mutex.data.get().as_ref::<'a>().unwrap() }
71 }
72}
73
74impl<'a, T: ?Sized, B: BootServices> DerefMut for TplMutexGuard<'a, T, B> {
75 fn deref_mut(&mut self) -> &'a mut T {
76 unsafe { self.tpl_mutex.data.get().as_mut().unwrap() }
80 }
81}
82
83impl<'a, T: ?Sized + fmt::Debug, B: BootServices> fmt::Debug for TplMutex<'a, T, B> {
84 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
85 let mut dbg = f.debug_struct("TplMutex");
86 match self.try_lock() {
87 Ok(guard) => dbg.field("data", &guard),
88 Err(()) => dbg.field("data", &format_args!("<locked>")),
89 };
90 dbg.finish_non_exhaustive()
91 }
92}
93
94impl<'a, T: ?Sized + fmt::Debug, B: BootServices> fmt::Debug for TplMutexGuard<'a, T, B> {
95 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
96 Debug::fmt(self.deref(), f)
97 }
98}
99
100impl<'a, T: ?Sized + fmt::Display, B: BootServices> fmt::Display for TplMutexGuard<'a, T, B> {
101 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
102 Display::fmt(self.deref(), f)
103 }
104}
105
// SAFETY: `data` is only reachable through a `TplMutexGuard`, and the `lock` flag (taken
// with `compare_exchange`, cleared on guard drop) allows at most one guard at a time. So
// sharing or moving the mutex exposes `T` to one owner at a time, which only requires
// `T: Send` — the same bounds `std::sync::Mutex` uses.
// NOTE(review): both impls also share/move the `&'a B` boot-services reference without
// requiring `B: Sync`; confirm every `BootServices` implementation tolerates that.
unsafe impl<T: ?Sized + Send, B: BootServices> Sync for TplMutex<'_, T, B> {}
unsafe impl<T: ?Sized + Send, B: BootServices> Send for TplMutex<'_, T, B> {}

// SAFETY: a shared `&TplMutexGuard` only yields `&T` (via `Deref`), so sharing the guard
// is sound when `T: Sync`, mirroring `std::sync::MutexGuard`.
unsafe impl<T: ?Sized + Sync, B: BootServices> Sync for TplMutexGuard<'_, T, B> {}
110
#[cfg(test)]
mod test {
    use super::*;
    use boot_services::MockBootServices;
    use mockall::predicate::*;

    /// Small payload whose `Debug` and `Display` renderings differ, so the guard's two
    /// formatting impls can be told apart.
    #[derive(Debug, Default)]
    struct TestStruct {
        field: u32,
    }

    impl Display for TestStruct {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            write!(f, "{}", self.field)
        }
    }

    /// Mock boot services: `raise_tpl(NOTIFY)` reports `APPLICATION` as the previous TPL, and
    /// `restore_tpl(APPLICATION)` is accepted. No call counts are constrained.
    fn mock_boot_services() -> MockBootServices {
        let mut mock = MockBootServices::new();
        mock.expect_raise_tpl().with(eq(Tpl::NOTIFY)).return_const(Tpl::APPLICATION);
        mock.expect_restore_tpl().with(eq(Tpl::APPLICATION)).return_const(());
        mock
    }

    #[test]
    fn test_try_lock() {
        let bs = mock_boot_services();
        let mutex = TplMutex::new(&bs, Tpl::NOTIFY, 0);

        let first = mutex.try_lock();
        assert!(first.is_ok(), "First lock should work.");

        for _ in 0..2 {
            let contended = mutex.try_lock();
            assert!(
                contended.is_err(),
                "Try lock should not work when there is already a lock guard."
            );
        }

        drop(first);
        let second = mutex.try_lock();
        assert!(second.is_ok(), "Lock should work after the guard has been dropped.");
    }

    #[test]
    #[should_panic(expected = "Re-entrant lock")]
    fn test_that_locking_a_locked_mutex_with_lock_fn_should_panic() {
        let bs = mock_boot_services();
        let mutex = TplMutex::new(&bs, Tpl::NOTIFY, TestStruct::default());
        let held = mutex.try_lock();
        assert!(held.is_ok());
        let _ = mutex.lock();
    }

    #[test]
    fn test_debug_output_for_tpl_mutex() {
        let bs = mock_boot_services();
        let mutex = TplMutex::new(&bs, Tpl::NOTIFY, TestStruct::default());

        let unlocked = format!("{mutex:?}");
        assert_eq!("TplMutex { data: TestStruct { field: 0 }, .. }", unlocked);

        let _guard = mutex.lock();
        let locked = format!("{mutex:?}");
        assert_eq!("TplMutex { data: <locked>, .. }", locked);
    }

    #[test]
    fn test_display_and_debug_output_for_tpl_mutex_guard() {
        let bs = mock_boot_services();
        let mutex = TplMutex::new(&bs, Tpl::NOTIFY, TestStruct::default());
        let guard = mutex.lock();

        let shown = format!("{guard}");
        let debugged = format!("{guard:?}");
        assert_eq!("0", shown);
        assert_eq!("TestStruct { field: 0 }", debugged);
    }
}