patina 21.1.1

Common types and functionality used in UEFI development.
Documentation
//! A module containing a TPL aware Mutex implementation.
//!
//! ## License
//!
//! Copyright (c) Microsoft Corporation.
//!
//! SPDX-License-Identifier: Apache-2.0
//!
use core::{
    cell::{OnceCell, UnsafeCell},
    fmt::{self, Debug, Display},
    ops::{Deref, DerefMut},
    sync::atomic::{AtomicBool, Ordering},
};

use crate::boot_services::{BootServices, StandardBootServices, tpl::Tpl};

/// Type use for mutual exclusion of data across Tpl (task priority level)
///
/// This mutex will raise the TPL to the specified level when locked, and restore it when the lock is released.
///
/// The mutex owns the BootServices instance. Callers pass an owned instance or clone if needed.
pub struct TplMutex<T: ?Sized, B: BootServices = StandardBootServices> {
    boot_services: OnceCell<B>,
    tpl_lock_level: Tpl,
    lock: AtomicBool,
    data: UnsafeCell<T>,
}

/// RAII implementation of a [TplMutex] lock. When this structure is dropped, the lock will be unlocked.
#[must_use = "if unused the TplMutex will immediately unlock"]
pub struct TplMutexGuard<'a, T: ?Sized, B: BootServices> {
    tpl_mutex: &'a TplMutex<T, B>,
    release_tpl: Tpl,
}

impl<T, B: BootServices> TplMutex<T, B> {
    /// Create a new TplMutex in an unlocked state.
    /// Takes ownership of the boot_services instance. Callers can pass an owned
    /// instance directly or clone if they need to retain a copy.
    ///
    /// # Panics
    /// This call will panic if the mutex is already initialized (should not be possible here).
    pub fn new(boot_services: B, tpl_lock_level: Tpl, data: T) -> Self {
        let bs_cell = OnceCell::new();
        bs_cell.set(boot_services).map_err(|_| "Boot services already initialized!").unwrap();
        Self { boot_services: bs_cell, tpl_lock_level, lock: AtomicBool::new(false), data: UnsafeCell::new(data) }
    }

    /// Create a new TplMutex in an unlocked, uninitialized state.
    /// The resulting TplMutex will not be usable until `boot_services` is initialized.
    pub const fn new_uninit(tpl_lock_level: Tpl, data: T) -> Self {
        Self {
            boot_services: OnceCell::new(),
            tpl_lock_level,
            lock: AtomicBool::new(false),
            data: UnsafeCell::new(data),
        }
    }

    /// Initialize the boot services for this TplMutex. This must be called before the mutex can be used.
    ///
    /// # Panics
    /// This call will panic if the mutex is already initialized.
    pub fn init(&self, boot_services: B) {
        self.boot_services.set(boot_services).map_err(|_| "Boot services already initialized!").unwrap();
    }
}

impl<T: ?Sized, B: BootServices> TplMutex<T, B> {
    /// Attempt to lock the mutex and return a [TplMutexGuard] if the mutex was not locked.
    ///
    /// # Panics
    /// This call will panic if the mutex is already locked.
    pub fn lock(&self) -> TplMutexGuard<'_, T, B> {
        self.try_lock().map_err(|_| "Re-entrant lock").unwrap()
    }

    /// Attempt to lock the mutex and return [TplMutexGuard] if the mutex was not locked.
    ///
    /// # Errors
    /// If the mutex is already lock, then this call will return [Err].
    ///
    /// # Panics
    /// This call will panic if the mutex is not initialized.
    #[allow(clippy::result_unit_err)]
    pub fn try_lock(&self) -> Result<TplMutexGuard<'_, T, B>, ()> {
        self.lock
            .compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed)
            .map(|_| TplMutexGuard {
                release_tpl: self
                    .boot_services
                    .get()
                    .expect("BootServices not initialized!")
                    .raise_tpl(self.tpl_lock_level),
                tpl_mutex: self,
            })
            .map_err(|_| ())
    }
}

impl<T: ?Sized, B: BootServices> Drop for TplMutexGuard<'_, T, B> {
    fn drop(&mut self) {
        self.tpl_mutex.boot_services.get().expect("BootServices not initialized!").restore_tpl(self.release_tpl);
        self.tpl_mutex.lock.store(false, Ordering::Release);
    }
}

impl<T: ?Sized, B: BootServices> Deref for TplMutexGuard<'_, T, B> {
    type Target = T;
    fn deref(&self) -> &T {
        // SAFETY:
        // `as_ref` is guarantee to have a valid pointer because it come from a UnsafeCell.
        // This also comply to the aliasing rule because it is the only way to get a reference to the data, thus no other mutable reference to this data exist.
        unsafe { self.tpl_mutex.data.get().as_ref().unwrap() }
    }
}

impl<T: ?Sized, B: BootServices> DerefMut for TplMutexGuard<'_, T, B> {
    fn deref_mut(&mut self) -> &mut T {
        // SAFETY:
        // `as_ref` is guarantee to have a valid pointer because it come from a UnsafeCell.
        // This also comply to the mutability rule because it is the only way to get a reference to the data, thus no other mutable reference to this data exist.
        unsafe { self.tpl_mutex.data.get().as_mut().unwrap() }
    }
}

impl<T: ?Sized + fmt::Debug, B: BootServices> fmt::Debug for TplMutex<T, B> {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let mut dbg = f.debug_struct("TplMutex");
        match self.try_lock() {
            Ok(guard) => dbg.field("data", &guard),
            Err(()) => dbg.field("data", &format_args!("<locked>")),
        };
        dbg.finish_non_exhaustive()
    }
}

impl<T: ?Sized + fmt::Debug, B: BootServices> fmt::Debug for TplMutexGuard<'_, T, B> {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        Debug::fmt(self.deref(), f)
    }
}

impl<T: ?Sized + fmt::Display, B: BootServices> fmt::Display for TplMutexGuard<'_, T, B> {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        Display::fmt(self.deref(), f)
    }
}

// SAFETY: TplMutex is Sync because it ensures exclusive access to T through TPL-based locking.
// The lock/unlock operations at TPL_HIGH_LEVEL prevent concurrent access. T must be Send to
// allow transfer between threads, and the mutex ensures only one thread accesses T at a time.
unsafe impl<T: ?Sized + Send, B: BootServices + Send> Sync for TplMutex<T, B> {}
// SAFETY: TplMutex is Send because it owns T (which is Send) and uses TPL locking to ensure
// thread-safe access. The mutex can be safely transferred between threads.
unsafe impl<T: ?Sized + Send, B: BootServices + Send> Send for TplMutex<T, B> {}

// SAFETY: TplMutexGuard is Sync when T is Sync because the guard represents exclusive access
// to T through the TPL mutex. The guard can be shared across threads safely.
unsafe impl<T: ?Sized + Sync, B: BootServices> Sync for TplMutexGuard<'_, T, B> {}

#[cfg(test)]
#[coverage(off)]
mod tests {
    use super::*;
    use crate::boot_services::MockBootServices;
    use mockall::predicate::*;

    #[derive(Debug, Default)]
    struct TestStruct {
        field: u32,
    }
    impl Display for TestStruct {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            write!(f, "{}", &self.field)
        }
    }

    fn boot_services() -> MockBootServices {
        let mut boot_services = MockBootServices::new();
        boot_services.expect_raise_tpl().with(eq(Tpl::NOTIFY)).return_const(Tpl::APPLICATION);
        boot_services.expect_restore_tpl().with(eq(Tpl::APPLICATION)).return_const(());
        boot_services
    }

    #[test]
    fn test_try_lock() {
        let mutex = TplMutex::new(boot_services(), Tpl::NOTIFY, 0);

        let guard_result = mutex.try_lock();
        assert!(guard_result.is_ok(), "First lock should work.");

        for _ in 0..2 {
            assert!(
                matches!(mutex.try_lock(), Err(())),
                "Try lock should not work when there is already a lock guard."
            );
        }

        drop(guard_result);
        let guard_result = mutex.try_lock();
        assert!(guard_result.is_ok(), "Lock should work after the guard has been dropped.");
    }

    #[test]
    #[should_panic(expected = "Re-entrant lock")]
    fn test_that_locking_a_locked_mutex_with_lock_fn_should_panic() {
        let mutex = TplMutex::new(boot_services(), Tpl::NOTIFY, TestStruct::default());
        let guard_result = mutex.try_lock();
        assert!(guard_result.is_ok());
        let _ = mutex.lock();
    }

    #[test]
    fn test_debug_output_for_tpl_mutex() {
        let mutex = TplMutex::new(boot_services(), Tpl::NOTIFY, TestStruct::default());
        assert_eq!("TplMutex { data: TestStruct { field: 0 }, .. }", format!("{mutex:?}"));
        let _guard = mutex.lock();
        assert_eq!("TplMutex { data: <locked>, .. }", format!("{mutex:?}"));
    }

    #[test]
    fn test_display_and_debug_output_for_tpl_mutex_guard() {
        let mutex = TplMutex::new(boot_services(), Tpl::NOTIFY, TestStruct::default());
        let guard = mutex.lock();
        assert_eq!("0", format!("{guard}"));
        assert_eq!("TestStruct { field: 0 }", format!("{guard:?}"));
    }
}