surelock 0.1.0

Deadlock-free locks for Rust with compile-time guarantees, incremental locks, and atomic lock sets.
Documentation
#![allow(unsafe_code)]
//! Deadlock-free mutex, generic over lock level and backend.
//!
//! [`Mutex`] wraps a [`RawMutex`]
//! implementation and a `T`, tagging the pair with a lock level.
//! All ordered locking goes through
//! [`MutexKey`](crate::key::MutexKey) +
//! [`LockSet`](crate::set::LockSet) -- there is no public `lock()`
//! method by default.
//!
//! With the `std` feature enabled (it is by default), the level
//! parameter `Lvl` defaults to [`Base`](crate::level::Base)
//! (= `Level<0>`) and the backend parameter `R` defaults to
//! [`StdMutex`](crate::raw_mutex::std_mutex::StdMutex).
//!
//! Users can then specify just the level without naming the backend:
//! `Mutex<u32, Level<3>>`. Without `std`, neither parameter has a
//! default and both must be named.

pub mod guard;

use core::{cell::UnsafeCell, fmt, marker::PhantomData};

use crate::{id::LockId, level::IsLevel, raw_mutex::RawMutex};

#[cfg(feature = "std")]
use crate::level::Base;

#[cfg(feature = "escape-hatch")]
use guard::MutexGuard;

/// A deadlock-free mutex, parameterised by the protected data `T`, a
/// compile-time lock level `Lvl`, and a raw locking backend `R`.
///
/// Both trailing parameters have defaults on `std`. `Lvl` is [`Base`]
/// (i.e. `Level<0>`), so levels only need to be spelled out when you
/// want incremental cross-level acquisition. `R` is
/// [`StdMutex`](crate::raw_mutex::std_mutex::StdMutex). Naming only
/// the level therefore keeps the default backend, e.g.
/// `Mutex<u32, Level<3>>`.
///
/// Ordered acquisition always goes through
/// [`MutexKey::lock`](crate::key::MutexKey::lock). A public `lock()`
/// method only exists when the `escape-hatch` feature is enabled.
///
/// # Examples
///
/// ```rust
/// use surelock::{key::lock_scope, mutex::Mutex};
///
/// let counter: Mutex<u32> = Mutex::new(0);
///
/// lock_scope(|key| {
///     let (mut guard, _key) = key.lock(&counter);
///     *guard += 1;
/// });
/// ```
#[cfg(feature = "std")]
pub struct Mutex<T, Lvl = Base, R = crate::raw_mutex::std_mutex::StdMutex>
where
    Lvl: IsLevel,
    R: RawMutex,
{
    // Unique identity, assigned at construction.
    id: LockId,
    // Backend lock that guards `data`.
    pub(crate) raw: R,
    // Protected value. It is intended to be reached only while `raw` is
    // held, or through exclusive ownership (`&mut self` / `self`).
    pub(crate) data: UnsafeCell<T>,
    // Carries the lock level in the type without storing a value.
    _level: PhantomData<Lvl>,
}

/// A deadlock-free mutex generic over lock level and backend.
///
/// `T` is the protected data. `Lvl` is the lock level (for example
/// [`Base`](crate::level::Base) = `Level<0>`). `R` is any
/// [`RawMutex`] implementation.
///
/// Without the `std` feature, neither `Lvl` nor `R` has a default.
/// There is no default backend, and Rust only allows defaulted generic
/// parameters after all non-defaulted ones, so `Lvl` cannot default
/// either. Both must be named explicitly.
///
/// Enable the `std` feature (on by default) to get defaults for both,
/// allowing `Mutex<u32>` or `Mutex<u32, Level<3>>` without naming the
/// backend.
///
/// # Examples
///
/// ```rust,ignore
/// use surelock::{key::lock_scope, mutex::Mutex};
///
/// type M<T> = Mutex<T, surelock::level::Base, spin::mutex::SpinMutex<()>>;
/// let counter: M<u32> = Mutex::new(0);
/// ```
#[cfg(not(feature = "std"))]
pub struct Mutex<T, Lvl: IsLevel, R: RawMutex> {
    id: LockId,
    pub(crate) raw: R,
    pub(crate) data: UnsafeCell<T>,
    _level: PhantomData<Lvl>,
}

impl<T, Lvl: IsLevel, R: RawMutex> Mutex<T, Lvl, R> {
    /// Wrap `data` in a new mutex.
    ///
    /// Each call draws the next [`LockId`] from a global atomic
    /// counter, so every mutex gets its own id. The backend is created
    /// with `RawMutex::new`.
    #[must_use]
    pub fn new(data: T) -> Self {
        // Draw the id before building the backend (same order as the
        // struct's fields).
        let id = LockId::next();
        let raw = <R as RawMutex>::new();
        Self {
            id,
            raw,
            data: UnsafeCell::new(data),
            _level: PhantomData,
        }
    }

    /// The [`LockId`] this mutex was given when it was created.
    #[must_use]
    pub const fn id(&self) -> LockId {
        let Self { id, .. } = self;
        *id
    }

    /// Borrow the protected value mutably without locking.
    ///
    /// Holding `&mut self` proves no other reference to this mutex
    /// exists, so no other thread can be holding the lock.
    #[must_use]
    pub const fn get_mut(&mut self) -> &mut T {
        UnsafeCell::get_mut(&mut self.data)
    }

    /// Take the mutex by value and hand back the protected value.
    #[must_use]
    pub fn into_inner(self) -> T {
        UnsafeCell::into_inner(self.data)
    }
}

// -- new_higher: ordered construction with inferred level --

/// Inherent constructor for mutexes that must be locked after existing
/// ones. It is fixed to the default `StdMutex` backend.
///
/// Code using another backend should call the [`NewHigher`] trait
/// method instead: `parent.new_higher(data)`.
#[cfg(feature = "std")]
impl<T> Mutex<T> {
    /// Create a mutex ordered after the given parent(s).
    ///
    /// `parents` is either a single `&Mutex` or a 2-tuple of `&Mutex`
    /// references. `Arc`, `Rc`, and `Box` wrappers around a mutex work
    /// as well. The child's level is one above the highest parent
    /// level, i.e. `max(parent levels) + 1`. Its backend is `StdMutex`.
    ///
    /// To keep a parent's custom backend, use the [`NewHigher`] trait
    /// method `parent.new_higher(data)`, which inherits the backend
    /// from the parent.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use surelock::mutex::Mutex;
    ///
    /// let config: Mutex<u32> = Mutex::new(10);
    /// let account = Mutex::new_higher(20u32, &config);  // Level<1>
    /// let txn = Mutex::new_higher(30u32, &account);      // Level<2>
    ///
    /// // Siblings: same parent, same level
    /// let acct_a = Mutex::new_higher(1u32, &config);  // Level<1>
    /// let acct_b = Mutex::new_higher(2u32, &config);  // Level<1>
    ///
    /// // Multi-parent: after both config and account
    /// let reconciler = Mutex::new_higher(0u32, (&config, &account));
    /// // Level = max(Level<0>, Level<1>) + 1 = Level<2>
    /// ```
    #[must_use]
    pub fn new_higher<Parents>(data: T, parents: Parents) -> Mutex<T, Parents::NextLvl>
    where
        Parents: NewHigher<T, crate::raw_mutex::std_mutex::StdMutex>,
    {
        // The default `R` of `Mutex` is `StdMutex`, matching the bound above.
        Parents::new_higher(&parents, data)
    }
}

// Escape hatch: direct lock bypassing the ordering system.
#[cfg(feature = "escape-hatch")]
impl<T, Lvl: IsLevel, R: RawMutex> Mutex<T, Lvl, R> {
    /// Acquire this mutex directly, bypassing the ordering system.
    ///
    /// This is the counterpart of `std::sync::Mutex::lock()`. No key is
    /// needed and no lock-level ordering is checked, so deadlock
    /// prevention is entirely the caller's responsibility. Unlike the
    /// `std` method, it returns the guard directly rather than a
    /// `LockResult`.
    ///
    /// The backend lock is taken with `self.raw.lock()`, and its guard
    /// is stored inside the returned [`MutexGuard`]. The lock is held
    /// for as long as that guard lives. Discarding the result drops the
    /// backend guard with it, which is why the result is `#[must_use]`.
    ///
    /// Calling this while the current thread already holds the same
    /// mutex (through a key or another guard) may deadlock, depending
    /// on the backend.
    ///
    /// Only available with the `escape-hatch` feature enabled.
    #[must_use = "if unused the mutex will immediately unlock"]
    pub fn unchecked_lock(&self) -> MutexGuard<'_, R, T> {
        let raw_guard = self.raw.lock();
        MutexGuard {
            data: &self.data,
            _raw_guard: raw_guard,
        }
    }
}

/// Debug-prints the mutex as `Mutex { id: .., .. }`.
///
/// The protected data is never printed (reading it would require
/// taking the lock). The impl therefore needs no `T: Debug` bound, so
/// mutexes over non-`Debug` data can still be printed, e.g. as a field
/// of a `#[derive(Debug)]` struct.
impl<T, Lvl: IsLevel, R: RawMutex> fmt::Debug for Mutex<T, Lvl, R> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("Mutex")
            .field("id", &self.id)
            .finish_non_exhaustive()
    }
}

// SAFETY: Sending a `Mutex` to another thread moves the `T` stored in
// the `UnsafeCell` (hence `T: Send`) and the backend `R` (hence
// `R: Send`). `Lvl` is only held as `PhantomData<Lvl>`, a marker with
// no `Lvl` value, so no `Lvl: Send` bound is needed.
// This is written out explicitly because the `Sync` impl below is
// manual.
// NOTE(review): a manual impl also skips the auto-trait requirement
// that `LockId: Send`. Presumably `LockId` is a plain id value (it is
// `Copy`), but confirm.
unsafe impl<T: Send, Lvl: IsLevel, R: RawMutex + Send> Send for Mutex<T, Lvl, R> {}

// SAFETY: `UnsafeCell<T>` is unconditionally `!Sync`, so this impl is
// what allows `&Mutex` to be shared across threads. Through a shared
// reference, another thread can reach:
// - `&R` (the `raw` field), hence `R: Sync`;
// - the `T` inside `data`, which is intended to be reached only after
//   acquiring `raw` (as `unchecked_lock` does). `&mut self` / `self`
//   access excludes sharing altogether.
// The lock serializes those accesses. Each holder effectively receives
// an exclusive `&mut T`, so `T` is moved between threads rather than
// shared. That is why `T: Send` (not `T: Sync`) suffices, the same
// bound as `std::sync::Mutex`. `Lvl` is `PhantomData` only, so it
// needs no bound.
// Soundness rests on every `RawMutex` implementation actually
// providing mutual exclusion.
// NOTE(review): confirm that is a documented (ideally `unsafe trait`)
// requirement of `RawMutex`.
unsafe impl<T: Send, Lvl: IsLevel, R: RawMutex + Sync> Sync for Mutex<T, Lvl, R> {}

// -- MutexLevel impls for Mutex references and smart pointers --
//
// These feed into the blanket `MaxLevelOf` impls in `level.rs`,
// which accept single parents and 2-tuples of any combination.

use crate::level::MutexLevel;

// A reference to a mutex, whether bare or behind `Arc`/`Rc`/`Box`,
// reports that mutex's level `L`.
impl<T, L, R> MutexLevel for &Mutex<T, L, R>
where
    L: IsLevel,
    R: RawMutex,
{
    type Lvl = L;
}

#[cfg(target_has_atomic = "ptr")]
impl<T, L, R> MutexLevel for &alloc::sync::Arc<Mutex<T, L, R>>
where
    L: IsLevel,
    R: RawMutex,
{
    type Lvl = L;
}

impl<T, L, R> MutexLevel for &alloc::rc::Rc<Mutex<T, L, R>>
where
    L: IsLevel,
    R: RawMutex,
{
    type Lvl = L;
}

impl<T, L, R> MutexLevel for &alloc::boxed::Box<Mutex<T, L, R>>
where
    L: IsLevel,
    R: RawMutex,
{
    type Lvl = L;
}

// -- NewHigher impls --
//
// Single mutex: child level = parent level + 1, backend inherited.
// Smart pointer blankets delegate via Deref.
// 2-tuple: same backend required, level = max(levels) + 1.

use crate::level::NewHigher;

impl<ParentT, ChildT, Lvl: IsLevel + crate::level::NextLevel, R: RawMutex> NewHigher<ChildT, R>
    for Mutex<ParentT, Lvl, R>
{
    type NextLvl = Lvl::Next;

    fn new_higher(&self, data: ChildT) -> Mutex<ChildT, Lvl::Next, R> {
        Mutex {
            id: LockId::next(),
            raw: RawMutex::new(),
            data: UnsafeCell::new(data),
            _level: PhantomData,
        }
    }
}

// Forward through a shared reference, so `&parent` works wherever
// `parent` does.
impl<ChildT, R, T> NewHigher<ChildT, R> for &T
where
    R: RawMutex,
    T: NewHigher<ChildT, R> + ?Sized,
{
    type NextLvl = T::NextLvl;

    fn new_higher(&self, data: ChildT) -> Mutex<ChildT, T::NextLvl, R> {
        // `self` is `&&T`; deref once to reach the referent's impl.
        <T as NewHigher<ChildT, R>>::new_higher(*self, data)
    }
}

// Forward through `Arc` to the pointee. `?Sized` matches the `&T`
// forwarding impl, so unsized pointees are accepted as well.
#[cfg(target_has_atomic = "ptr")]
impl<ChildT, R: RawMutex, T: NewHigher<ChildT, R> + ?Sized> NewHigher<ChildT, R>
    for alloc::sync::Arc<T>
{
    type NextLvl = T::NextLvl;

    fn new_higher(&self, data: ChildT) -> Mutex<ChildT, T::NextLvl, R> {
        // `&Arc<T>` deref-coerces to `&T`.
        T::new_higher(self, data)
    }
}

// Forward through `Rc` to the pointee. `?Sized` matches the `&T`
// forwarding impl, so unsized pointees are accepted as well.
impl<ChildT, R: RawMutex, T: NewHigher<ChildT, R> + ?Sized> NewHigher<ChildT, R>
    for alloc::rc::Rc<T>
{
    type NextLvl = T::NextLvl;

    fn new_higher(&self, data: ChildT) -> Mutex<ChildT, T::NextLvl, R> {
        // `&Rc<T>` deref-coerces to `&T`.
        T::new_higher(self, data)
    }
}

// Forward through `Box` to the pointee. `?Sized` matches the `&T`
// forwarding impl, so unsized pointees are accepted as well.
impl<ChildT, R: RawMutex, T: NewHigher<ChildT, R> + ?Sized> NewHigher<ChildT, R>
    for alloc::boxed::Box<T>
{
    type NextLvl = T::NextLvl;

    fn new_higher(&self, data: ChildT) -> Mutex<ChildT, T::NextLvl, R> {
        // `&Box<T>` deref-coerces to `&T`.
        T::new_higher(self, data)
    }
}

// 2-tuple: both parents must share the same backend R.
// Level = max(Lvl1, Lvl2) + 1.
//
// The macro below implements `NewHigher<ChildT, R>` for `(&A, &B)`.
// Each of A and B is one concrete parent wrapper: a bare
// `Mutex<_, LvlN, R>`, or an `Rc`, `Box`, or `Arc` around one. Naming
// the wrapper types directly in the impl header binds the RAW parent
// levels Lvl1/Lvl2 as impl generics, which the max computation needs.
// (Each parent's own `NewHigher::NextLvl` is already parent + 1, so it
// cannot be used here.)
//
// This covers:
//   (&Mutex<_, Lvl1, R>, &Mutex<_, Lvl2, R>)
//   (&Arc<Mutex<_, Lvl1, R>>, &Mutex<_, Lvl2, R>)
//   (&Arc<Mutex<_, Lvl1, R>>, &Arc<Mutex<_, Lvl2, R>>)
//   etc.

// Implements `NewHigher<ChildT, R>` for one pair `(&$first, &$second)` of
// parent wrapper types.
//
// The invocation must spell its types with the impl generics declared
// here: `ParentT1`/`Lvl1` for the first parent, `ParentT2`/`Lvl2` for the
// second, and `R` for the shared backend. The child's level is
// `max(Lvl1, Lvl2) + 1`.
macro_rules! impl_new_higher_tuple {
    ($first:ty, $second:ty) => {
        impl<ChildT, ParentT1, ParentT2, Lvl1, Lvl2, R> NewHigher<ChildT, R>
            for (&$first, &$second)
        where
            Lvl1: IsLevel + crate::level::MaxLevel<Lvl2>,
            Lvl2: IsLevel,
            R: RawMutex,
            <Lvl1 as crate::level::MaxLevel<Lvl2>>::Max: crate::level::NextLevel,
        {
            type NextLvl =
                <<Lvl1 as crate::level::MaxLevel<Lvl2>>::Max as crate::level::NextLevel>::Next;

            fn new_higher(&self, data: ChildT) -> Mutex<ChildT, Self::NextLvl, R> {
                // The parents only fix the level and backend types; the
                // child itself is a fresh mutex from the shared constructor.
                Mutex::new(data)
            }
        }
    };
}

// Each invocation below implements `NewHigher` for one ordered pair
// `(&first, &second)` of parent wrappers. Pair order is part of the
// implementing type, so every mixed pairing is listed in both orders.
// Wrapper types must be written with the macro's impl generics
// (`ParentT1`, `Lvl1`, `ParentT2`, `Lvl2`, `R`).
//
// 3x3 = 9 combinations of bare/Rc/Box (always available).
impl_new_higher_tuple!(Mutex<ParentT1, Lvl1, R>, Mutex<ParentT2, Lvl2, R>);
impl_new_higher_tuple!(Mutex<ParentT1, Lvl1, R>, alloc::rc::Rc<Mutex<ParentT2, Lvl2, R>>);
impl_new_higher_tuple!(Mutex<ParentT1, Lvl1, R>, alloc::boxed::Box<Mutex<ParentT2, Lvl2, R>>);
impl_new_higher_tuple!(alloc::rc::Rc<Mutex<ParentT1, Lvl1, R>>, Mutex<ParentT2, Lvl2, R>);
impl_new_higher_tuple!(
    alloc::rc::Rc<Mutex<ParentT1, Lvl1, R>>,
    alloc::rc::Rc<Mutex<ParentT2, Lvl2, R>>
);
impl_new_higher_tuple!(
    alloc::rc::Rc<Mutex<ParentT1, Lvl1, R>>,
    alloc::boxed::Box<Mutex<ParentT2, Lvl2, R>>
);
impl_new_higher_tuple!(alloc::boxed::Box<Mutex<ParentT1, Lvl1, R>>, Mutex<ParentT2, Lvl2, R>);
impl_new_higher_tuple!(
    alloc::boxed::Box<Mutex<ParentT1, Lvl1, R>>,
    alloc::rc::Rc<Mutex<ParentT2, Lvl2, R>>
);
impl_new_higher_tuple!(
    alloc::boxed::Box<Mutex<ParentT1, Lvl1, R>>,
    alloc::boxed::Box<Mutex<ParentT2, Lvl2, R>>
);

// +7 Arc combinations (4x4 - 3x3 = 7): every pairing that involves at
// least one `Arc`. `alloc::sync::Arc` only exists on targets with
// pointer-width atomics, hence the `cfg` on each impl.
#[cfg(target_has_atomic = "ptr")]
impl_new_higher_tuple!(Mutex<ParentT1, Lvl1, R>, alloc::sync::Arc<Mutex<ParentT2, Lvl2, R>>);
#[cfg(target_has_atomic = "ptr")]
impl_new_higher_tuple!(alloc::sync::Arc<Mutex<ParentT1, Lvl1, R>>, Mutex<ParentT2, Lvl2, R>);
#[cfg(target_has_atomic = "ptr")]
impl_new_higher_tuple!(
    alloc::sync::Arc<Mutex<ParentT1, Lvl1, R>>,
    alloc::sync::Arc<Mutex<ParentT2, Lvl2, R>>
);
#[cfg(target_has_atomic = "ptr")]
impl_new_higher_tuple!(
    alloc::sync::Arc<Mutex<ParentT1, Lvl1, R>>,
    alloc::rc::Rc<Mutex<ParentT2, Lvl2, R>>
);
#[cfg(target_has_atomic = "ptr")]
impl_new_higher_tuple!(
    alloc::sync::Arc<Mutex<ParentT1, Lvl1, R>>,
    alloc::boxed::Box<Mutex<ParentT2, Lvl2, R>>
);
#[cfg(target_has_atomic = "ptr")]
impl_new_higher_tuple!(
    alloc::rc::Rc<Mutex<ParentT1, Lvl1, R>>,
    alloc::sync::Arc<Mutex<ParentT2, Lvl2, R>>
);
#[cfg(target_has_atomic = "ptr")]
impl_new_higher_tuple!(
    alloc::boxed::Box<Mutex<ParentT1, Lvl1, R>>,
    alloc::sync::Arc<Mutex<ParentT2, Lvl2, R>>
);