use core::{
cell::{OnceCell, UnsafeCell},
fmt::{self, Debug, Display},
ops::{Deref, DerefMut},
sync::atomic::{AtomicBool, Ordering},
};
use crate::boot_services::{BootServices, StandardBootServices, tpl::Tpl};
/// A mutual-exclusion wrapper around a value of type `T` that pairs a lock flag with a
/// TPL change: locking calls `raise_tpl` on a [`BootServices`] implementation and
/// dropping the returned guard calls `restore_tpl`.
///
/// The boot services handle is kept in a [`OnceCell`], so it can be supplied either when
/// the mutex is built or at a later point.
pub struct TplMutex<T: ?Sized, B: BootServices = StandardBootServices> {
/// Boot services used to raise and restore the TPL; may still be unset.
boot_services: OnceCell<B>,
/// TPL handed to `raise_tpl` every time the mutex is locked.
tpl_lock_level: Tpl,
/// `true` while a guard for this mutex is outstanding.
lock: AtomicBool,
/// The protected value; only reached through a guard.
data: UnsafeCell<T>,
}
/// Guard produced by locking a [`TplMutex`].
///
/// It gives access to the protected value through `Deref`/`DerefMut`. Dropping it hands
/// the saved TPL back to `restore_tpl` and clears the mutex's lock flag.
#[must_use = "if unused the TplMutex will immediately unlock"]
pub struct TplMutexGuard<'a, T: ?Sized, B: BootServices> {
/// The mutex this guard was obtained from.
tpl_mutex: &'a TplMutex<T, B>,
/// Value returned by `raise_tpl` when the lock was taken; passed to `restore_tpl` on drop.
release_tpl: Tpl,
}
impl<T, B: BootServices> TplMutex<T, B> {
    /// Builds an unlocked mutex around `data` with `boot_services` already installed.
    ///
    /// `tpl_lock_level` is the TPL requested from `raise_tpl` whenever the mutex is locked.
    pub fn new(boot_services: B, tpl_lock_level: Tpl, data: T) -> Self {
        // A freshly built cell is always empty, so installing the services cannot fail here.
        let mutex = Self::new_uninit(tpl_lock_level, data);
        mutex.init(boot_services);
        mutex
    }

    /// Builds an unlocked mutex around `data` without any boot services.
    ///
    /// Being `const`, this can be used in constant contexts; [`init`](Self::init) has to be
    /// called before the mutex is locked.
    pub const fn new_uninit(tpl_lock_level: Tpl, data: T) -> Self {
        Self { boot_services: OnceCell::new(), tpl_lock_level, lock: AtomicBool::new(false), data: UnsafeCell::new(data) }
    }

    /// Installs the boot services used to raise and restore the TPL.
    ///
    /// # Panics
    ///
    /// Panics if boot services were already installed, either by [`new`](Self::new) or by
    /// an earlier call to `init`.
    pub fn init(&self, boot_services: B) {
        let installed = self.boot_services.set(boot_services);
        installed.map_err(|_| "Boot services already initialized!").unwrap();
    }
}
impl<T: ?Sized, B: BootServices> TplMutex<T, B> {
    /// Locks the mutex and returns a guard giving access to the protected value.
    ///
    /// # Panics
    ///
    /// Panics with "Re-entrant lock" when the mutex is already locked, and with
    /// "BootServices not initialized!" when no boot services have been installed.
    pub fn lock(&self) -> TplMutexGuard<'_, T, B> {
        self.try_lock().map_err(|_| "Re-entrant lock").unwrap()
    }

    /// Attempts to lock the mutex without panicking on contention.
    ///
    /// On success the lock flag is set, the TPL is raised to the mutex's lock level and the
    /// returned guard remembers the value `raise_tpl` returned so it can be restored later.
    ///
    /// # Errors
    ///
    /// Returns `Err(())` when the mutex is already locked.
    ///
    /// # Panics
    ///
    /// Panics with "BootServices not initialized!" when no boot services have been
    /// installed. The lock flag is left untouched in that case.
    #[allow(clippy::result_unit_err)]
    pub fn try_lock(&self) -> Result<TplMutexGuard<'_, T, B>, ()> {
        // Resolve the boot services before touching the lock flag: panicking after a
        // successful compare-exchange would leave the mutex locked with no guard to release it.
        let boot_services = self.boot_services.get().expect("BootServices not initialized!");
        self.lock
            .compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed)
            .map(|_| TplMutexGuard { release_tpl: boot_services.raise_tpl(self.tpl_lock_level), tpl_mutex: self })
            .map_err(|_| ())
    }
}
impl<T: ?Sized, B: BootServices> Drop for TplMutexGuard<'_, T, B> {
    /// Unlocks the mutex, then restores the TPL that was in effect before it was locked.
    ///
    /// # Panics
    ///
    /// Panics with "BootServices not initialized!" if the mutex has no boot services.
    fn drop(&mut self) {
        let mutex = self.tpl_mutex;
        // Clear the lock flag while the TPL is still raised. Per the UEFI specification,
        // RestoreTPL() may dispatch pending lower-priority event notifications before it
        // returns; if one of them locks this mutex it must already see it as free.
        // This is the same order EDK2's EfiReleaseLock() uses.
        mutex.lock.store(false, Ordering::Release);
        mutex.boot_services.get().expect("BootServices not initialized!").restore_tpl(self.release_tpl);
    }
}
impl<T: ?Sized, B: BootServices> Deref for TplMutexGuard<'_, T, B> {
    type Target = T;

    /// Borrows the protected value for as long as the guard itself is borrowed.
    fn deref(&self) -> &T {
        let raw: *mut T = self.tpl_mutex.data.get();
        // SAFETY: `raw` comes from the `UnsafeCell` inside the mutex this guard borrows, so it
        // is non-null and valid for the guard's lifetime. Guards are only handed out by
        // `try_lock` after it wins the lock flag.
        unsafe { &*raw }
    }
}
impl<T: ?Sized, B: BootServices> DerefMut for TplMutexGuard<'_, T, B> {
    /// Mutably borrows the protected value for as long as the guard is mutably borrowed.
    fn deref_mut(&mut self) -> &mut T {
        let raw: *mut T = self.tpl_mutex.data.get();
        // SAFETY: `raw` comes from the `UnsafeCell` inside the mutex this guard borrows, so it
        // is non-null and valid for the guard's lifetime. Guards are only handed out by
        // `try_lock` after it wins the lock flag, and `&mut self` keeps this borrow unique
        // per guard.
        unsafe { &mut *raw }
    }
}
impl<T: ?Sized + fmt::Debug, B: BootServices> fmt::Debug for TplMutex<T, B> {
    /// Formats the mutex as `TplMutex { data: .., .. }`.
    ///
    /// The `data` field shows the protected value when the mutex can be locked for the
    /// duration of the call, `<locked>` when it is already held, and
    /// `<boot services not initialized>` when no boot services have been installed yet.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let mut dbg = f.debug_struct("TplMutex");
        if self.boot_services.get().is_none() {
            // Locking requires boot services and panics without them; a `Debug` impl should
            // stay usable for diagnostics, so report the state instead of trying to lock.
            dbg.field("data", &format_args!("<boot services not initialized>"));
        } else {
            match self.try_lock() {
                Ok(guard) => dbg.field("data", &guard),
                Err(()) => dbg.field("data", &format_args!("<locked>")),
            };
        }
        dbg.finish_non_exhaustive()
    }
}
impl<T: ?Sized + fmt::Debug, B: BootServices> fmt::Debug for TplMutexGuard<'_, T, B> {
    /// Formats the guarded value with `T`'s own `Debug` implementation; the guard adds
    /// nothing of its own to the output.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let value: &T = &**self;
        fmt::Debug::fmt(value, f)
    }
}
impl<T: ?Sized + fmt::Display, B: BootServices> fmt::Display for TplMutexGuard<'_, T, B> {
    /// Formats the guarded value with `T`'s own `Display` implementation; the guard adds
    /// nothing of its own to the output.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let value: &T = &**self;
        fmt::Display::fmt(value, f)
    }
}
// SAFETY: the inner `T` is only reachable through a `TplMutexGuard`, and `try_lock` only
// creates one after winning the compare-exchange on `lock`, so a shared `&TplMutex` never
// yields two live guards; `T: Send` covers the value being used from whichever context
// holds the guard.
// NOTE(review): `boot_services` is a `core::cell::OnceCell`, which is not `Sync` by itself;
// this impl presumably relies on `init` never racing with other users of the mutex — confirm.
unsafe impl<T: ?Sized + Send, B: BootServices + Send> Sync for TplMutex<T, B> {}
// SAFETY: the mutex owns its `T` and `B`, so moving it presumably is sound whenever both
// are `Send` — the bounds encode exactly that requirement.
unsafe impl<T: ?Sized + Send, B: BootServices + Send> Send for TplMutex<T, B> {}
// SAFETY: a shared `&TplMutexGuard` only exposes `&T` (through `Deref` and the formatting
// impls), which is why `T: Sync` is required.
// NOTE(review): no bound is placed on `B` here — verify that this is intended.
unsafe impl<T: ?Sized + Sync, B: BootServices> Sync for TplMutexGuard<'_, T, B> {}
// Unit tests for `TplMutex` and `TplMutexGuard`, driven by a mocked `BootServices`.
#[cfg(test)]
#[coverage(off)]
mod tests {
use super::*;
use crate::boot_services::MockBootServices;
use mockall::predicate::*;
// Small payload type implementing both `Debug` and `Display`, used to check the
// formatting impls.
#[derive(Debug, Default)]
struct TestStruct {
field: u32,
}
impl Display for TestStruct {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", &self.field)
}
}
// Builds a mock that accepts `raise_tpl(Tpl::NOTIFY)` (returning `Tpl::APPLICATION`) and
// `restore_tpl(Tpl::APPLICATION)`.
fn boot_services() -> MockBootServices {
let mut boot_services = MockBootServices::new();
boot_services.expect_raise_tpl().with(eq(Tpl::NOTIFY)).return_const(Tpl::APPLICATION);
boot_services.expect_restore_tpl().with(eq(Tpl::APPLICATION)).return_const(());
boot_services
}
// `try_lock` succeeds on a free mutex, fails while a guard is alive, and succeeds again
// once that guard has been dropped.
#[test]
fn test_try_lock() {
let mutex = TplMutex::new(boot_services(), Tpl::NOTIFY, 0);
let guard_result = mutex.try_lock();
assert!(guard_result.is_ok(), "First lock should work.");
for _ in 0..2 {
assert!(
matches!(mutex.try_lock(), Err(())),
"Try lock should not work when there is already a lock guard."
);
}
drop(guard_result);
let guard_result = mutex.try_lock();
assert!(guard_result.is_ok(), "Lock should work after the guard has been dropped.");
}
// `lock` panics with "Re-entrant lock" when the mutex is already held.
#[test]
#[should_panic(expected = "Re-entrant lock")]
fn test_that_locking_a_locked_mutex_with_lock_fn_should_panic() {
let mutex = TplMutex::new(boot_services(), Tpl::NOTIFY, TestStruct::default());
let guard_result = mutex.try_lock();
assert!(guard_result.is_ok());
let _ = mutex.lock();
}
// `Debug` for the mutex shows the data when it is free and `<locked>` while a guard is held.
#[test]
fn test_debug_output_for_tpl_mutex() {
let mutex = TplMutex::new(boot_services(), Tpl::NOTIFY, TestStruct::default());
assert_eq!("TplMutex { data: TestStruct { field: 0 }, .. }", format!("{mutex:?}"));
let _guard = mutex.lock();
assert_eq!("TplMutex { data: <locked>, .. }", format!("{mutex:?}"));
}
// `Display` and `Debug` for the guard forward to the protected value's own impls.
#[test]
fn test_display_and_debug_output_for_tpl_mutex_guard() {
let mutex = TplMutex::new(boot_services(), Tpl::NOTIFY, TestStruct::default());
let guard = mutex.lock();
assert_eq!("0", format!("{guard}"));
assert_eq!("TestStruct { field: 0 }", format!("{guard:?}"));
}
}