surelock 0.1.0

Deadlock-free locks for Rust with compile-time guarantees, incremental locks, and atomic lock sets.
Documentation
//! Mostly-internal traits for types that can be locked atomically:
//! [`Acquirable`], plus the [`MutexRef`] helper used by its tuple
//! impls. Most users do not need to interact with these traits directly.

pub mod tuples;

use alloc::vec::Vec;

use crate::{
    id::LockId,
    level::IsLevel,
    mutex::{Mutex, guard::MutexGuard},
    raw_mutex::RawMutex,
};

/// Uniform access to a [`Mutex`] through any wrapper type.
///
/// Implemented for `&Mutex`, `&Arc<Mutex>`, `&Rc<Mutex>`, and
/// `&Box<Mutex>`. The `Arc` impl is only compiled when
/// `target_has_atomic = "ptr"`. Used by tuple [`Acquirable`] impls
/// so that `LockSet::new((&arc, &bare))` and
/// `key.lock_with(&(&arc, &rc), ...)` work transparently.
///
/// Associated types (`Data`, `Lvl`, `RawMtx`) are determined by
/// `Self`, avoiding unconstrained type parameter issues in generic
/// tuple impls.
pub trait MutexRef<'a> {
    /// The data type guarded by the mutex.
    type Data: 'a;

    /// The level of the mutex.
    type Lvl: IsLevel;

    /// The raw mutex backend.
    type RawMtx: RawMutex + 'a;

    /// Return the [`LockId`] of the underlying mutex.
    fn id(&self) -> LockId;

    /// Lock the underlying mutex and return a guard that borrows its
    /// data for `'a`.
    ///
    /// The impls in this module call the raw backend's `lock` directly
    /// and perform no level check of their own. Ordering is presumably
    /// enforced by the caller (e.g. the tuple `Acquirable` impls);
    /// confirm before calling this outside that path.
    fn lock_ref(&'a self) -> MutexGuard<'a, Self::RawMtx, Self::Data>;
}

impl<'a, T: 'a, Lvl: IsLevel, R: RawMutex + 'a> MutexRef<'a> for &'a Mutex<T, Lvl, R> {
    type Data = T;
    type Lvl = Lvl;
    type RawMtx = R;

    fn id(&self) -> LockId {
        // `self` is `&&Mutex`; copy out the inner reference.
        Mutex::id(*self)
    }

    fn lock_ref(&'a self) -> MutexGuard<'a, R, T> {
        // Base case: the wrapper *is* the reference, so no deref chain.
        let mutex: &'a Mutex<T, Lvl, R> = *self;
        let raw_guard = mutex.raw.lock();
        MutexGuard {
            data: &mutex.data,
            _raw_guard: raw_guard,
        }
    }
}

// Lets a borrowed `Arc<Mutex>` sit in a lock group next to bare
// references. Only compiled where `target_has_atomic = "ptr"` holds.
#[cfg(target_has_atomic = "ptr")]
impl<'a, T: 'a, Lvl: IsLevel, R: RawMutex + 'a> MutexRef<'a>
    for &'a alloc::sync::Arc<Mutex<T, Lvl, R>>
{
    type Data = T;
    type Lvl = Lvl;
    type RawMtx = R;

    fn id(&self) -> LockId {
        // Deref coercion walks `&&Arc<Mutex>` down to `&Mutex`.
        let mutex: &Mutex<T, Lvl, R> = self;
        Mutex::id(mutex)
    }

    fn lock_ref(&'a self) -> MutexGuard<'a, R, T> {
        let mutex: &'a Mutex<T, Lvl, R> = self;
        let raw_guard = mutex.raw.lock();
        MutexGuard {
            data: &mutex.data,
            _raw_guard: raw_guard,
        }
    }
}

// Lets a borrowed `Rc<Mutex>` sit in a lock group next to bare references.
impl<'a, T: 'a, Lvl: IsLevel, R: RawMutex + 'a> MutexRef<'a>
    for &'a alloc::rc::Rc<Mutex<T, Lvl, R>>
{
    type Data = T;
    type Lvl = Lvl;
    type RawMtx = R;

    fn id(&self) -> LockId {
        // Deref coercion walks `&&Rc<Mutex>` down to `&Mutex`.
        let mutex: &Mutex<T, Lvl, R> = self;
        Mutex::id(mutex)
    }

    fn lock_ref(&'a self) -> MutexGuard<'a, R, T> {
        let mutex: &'a Mutex<T, Lvl, R> = self;
        let raw_guard = mutex.raw.lock();
        MutexGuard {
            data: &mutex.data,
            _raw_guard: raw_guard,
        }
    }
}

// Lets a borrowed `Box<Mutex>` sit in a lock group next to bare references.
impl<'a, T: 'a, Lvl: IsLevel, R: RawMutex + 'a> MutexRef<'a>
    for &'a alloc::boxed::Box<Mutex<T, Lvl, R>>
{
    type Data = T;
    type Lvl = Lvl;
    type RawMtx = R;

    fn id(&self) -> LockId {
        // Deref coercion walks `&&Box<Mutex>` down to `&Mutex`.
        let mutex: &Mutex<T, Lvl, R> = self;
        Mutex::id(mutex)
    }

    fn lock_ref(&'a self) -> MutexGuard<'a, R, T> {
        let mutex: &'a Mutex<T, Lvl, R> = self;
        let raw_guard = mutex.raw.lock();
        MutexGuard {
            data: &mutex.data,
            _raw_guard: raw_guard,
        }
    }
}

/// A type (or collection of types) that can be locked atomically.
///
/// This trait is primarily internal bookkeeping -- most users
/// interact with [`Mutex`], [`LockSet`](crate::set::LockSet), and
/// [`MutexKey`](crate::key::MutexKey) directly and do not need to
/// implement it. Implementations are provided for:
///
/// - a bare [`Mutex`] and a single `&Mutex` reference;
/// - `&[Mutex]` slices, whose elements share one type and one level
///   (guards are returned as a `Vec`);
/// - tuples of up to 12 mutex references;
/// - `Arc<A>`, `Rc<A>`, and `Box<A>` for any `A: Acquirable`. These
///   delegate to the inner implementation; the `Arc` impl is only
///   compiled when `target_has_atomic = "ptr"`.
///
/// Elements can be at different levels. The `MinLvl` and `MaxLvl`
/// associated types track the range. For a single mutex, a slice, or
/// a same-level tuple, `MinLvl = MaxLvl`.
///
/// The lifetime `'a` is the borrow lifetime of the lock group. Guards
/// returned by [`lock_sorted`](Acquirable::lock_sorted) borrow from
/// the underlying mutexes for this lifetime.
///
/// This is a safe trait because a wrong implementation can only
/// cause deadlocks (a liveness failure), not undefined behavior.
/// The soundness invariant is on [`RawMutex`], not on `Acquirable`.
///
/// # Invariants
///
/// Custom implementations must uphold these invariants:
///
/// 1. **`collect_ids` and `lock_sorted` must agree on indexing.**
///    `collect_ids` pushes `N` ids; `lock_sorted` receives a
///    permutation of `0..N` and must lock the element at each
///    index in the given order.
/// 2. **`lock_sorted` must respect the permutation.** Locking in
///    an order other than the one given defeats the deadlock
///    prevention guarantee.
/// 3. **`MinLvl` and `MaxLvl` must be accurate.** `MinLvl` is
///    checked against the key's current level; `MaxLvl` becomes
///    the key's new level after acquisition. Incorrect bounds
///    can break cross-level ordering.
#[diagnostic::on_unimplemented(
    message = "`{Self}` cannot be used as a lock group",
    note = "`Acquirable` is implemented for `Mutex<T>`, `&Mutex<T>`, `&[Mutex<T>]`, tuples of mutex references, and `Arc`/`Rc`/`Box` of an `Acquirable` type"
)]
pub trait Acquirable<'a> {
    /// The minimum level in this collection.
    ///
    /// Used to check that all locks are above the key's current
    /// level (`MinLvl: LockAfter<CurrentKeyLevel>`).
    type MinLvl: IsLevel;

    /// The maximum level in this collection.
    ///
    /// The key advances to this level after acquisition.
    type MaxLvl: IsLevel;

    /// The guard type(s) returned when all locks are held.
    type Guard;

    /// Collect the [`LockId`] of each lock into `out`.
    fn collect_ids(&self, out: &mut Vec<LockId>);

    /// Acquire all locks in the order given by `sorted_indices`,
    /// returning the guards.
    ///
    /// `sorted_indices` must be a valid permutation of `0..N`
    /// (as produced by [`LockSet`](crate::set::LockSet)). Invalid
    /// indices will panic, not cause undefined behavior.
    ///
    /// Single-mutex implementations ignore `sorted_indices`, because
    /// the only permutation of one element is `[0]`. A repeated index
    /// makes the slice implementation lock the same mutex twice.
    fn lock_sorted(&'a self, sorted_indices: &[usize]) -> Self::Guard;
}

// Used by key.lock_with(&mutex, |guard| ...)
impl<'a, T: 'a, Lvl: IsLevel, R: RawMutex + 'a> Acquirable<'a> for Mutex<T, Lvl, R> {
    type MinLvl = Lvl;
    type MaxLvl = Lvl;
    type Guard = MutexGuard<'a, R, T>;

    fn collect_ids(&self, out: &mut Vec<LockId>) {
        out.push(self.id());
    }

    fn lock_sorted(&'a self, _sorted_indices: &[usize]) -> Self::Guard {
        let raw_guard = self.raw.lock();
        MutexGuard {
            data: &self.data,
            _raw_guard: raw_guard,
        }
    }
}

// Used by LockSet (which stores &Mutex)
impl<'a, T: 'a, Lvl: IsLevel, R: RawMutex + 'a> Acquirable<'a> for &'a Mutex<T, Lvl, R> {
    type MinLvl = Lvl;
    type MaxLvl = Lvl;
    type Guard = MutexGuard<'a, R, T>;

    fn collect_ids(&self, out: &mut Vec<LockId>) {
        out.push(self.id());
    }

    fn lock_sorted(&'a self, _sorted_indices: &[usize]) -> Self::Guard {
        let raw_guard = self.raw.lock();
        MutexGuard {
            data: &self.data,
            _raw_guard: raw_guard,
        }
    }
}

// -- Smart pointer delegation --
//
// These delegate to the inner type's Acquirable impl via Deref.
// Enables key.lock_with(&arc_mutex, |g| ...) without &*.

// Transparent wrapper: an `Arc<A>` locks exactly like the `A` it owns.
#[cfg(target_has_atomic = "ptr")]
impl<'a, T: Acquirable<'a>> Acquirable<'a> for alloc::sync::Arc<T> {
    type MinLvl = <T as Acquirable<'a>>::MinLvl;
    type MaxLvl = <T as Acquirable<'a>>::MaxLvl;
    type Guard = <T as Acquirable<'a>>::Guard;

    fn collect_ids(&self, out: &mut Vec<LockId>) {
        <T as Acquirable<'a>>::collect_ids(&**self, out);
    }

    fn lock_sorted(&'a self, sorted_indices: &[usize]) -> Self::Guard {
        <T as Acquirable<'a>>::lock_sorted(&**self, sorted_indices)
    }
}

// Transparent wrapper: an `Rc<A>` locks exactly like the `A` it owns.
impl<'a, T: Acquirable<'a>> Acquirable<'a> for alloc::rc::Rc<T> {
    type MinLvl = <T as Acquirable<'a>>::MinLvl;
    type MaxLvl = <T as Acquirable<'a>>::MaxLvl;
    type Guard = <T as Acquirable<'a>>::Guard;

    fn collect_ids(&self, out: &mut Vec<LockId>) {
        <T as Acquirable<'a>>::collect_ids(&**self, out);
    }

    fn lock_sorted(&'a self, sorted_indices: &[usize]) -> Self::Guard {
        <T as Acquirable<'a>>::lock_sorted(&**self, sorted_indices)
    }
}

// Transparent wrapper: a `Box<A>` locks exactly like the `A` it owns.
impl<'a, T: Acquirable<'a>> Acquirable<'a> for alloc::boxed::Box<T> {
    type MinLvl = <T as Acquirable<'a>>::MinLvl;
    type MaxLvl = <T as Acquirable<'a>>::MaxLvl;
    type Guard = <T as Acquirable<'a>>::Guard;

    fn collect_ids(&self, out: &mut Vec<LockId>) {
        <T as Acquirable<'a>>::collect_ids(&**self, out);
    }

    fn lock_sorted(&'a self, sorted_indices: &[usize]) -> Self::Guard {
        <T as Acquirable<'a>>::lock_sorted(&**self, sorted_indices)
    }
}

// A slice of mutexes: one element type, one level, length known only at
// run time. Guards come back in a `Vec`, pushed in acquisition order.
//
// NOTE(review): because guards are pushed in lock order, `guards[k]`
// belongs to `self[sorted_indices[k]]`. This matches slice position only
// when `sorted_indices` is the identity permutation -- confirm callers
// expect lock order rather than slice order.
#[allow(clippy::indexing_slicing)] // indices come from LockSet; a bad one panics
impl<'a, T: 'a, Lvl: IsLevel, R: RawMutex + 'a> Acquirable<'a> for &'a [Mutex<T, Lvl, R>] {
    type MinLvl = Lvl;
    type MaxLvl = Lvl;
    type Guard = Vec<MutexGuard<'a, R, T>>;

    fn collect_ids(&self, out: &mut Vec<LockId>) {
        out.extend(self.iter().map(|mutex| mutex.id()));
    }

    fn lock_sorted(&'a self, sorted_indices: &[usize]) -> Self::Guard {
        let mutexes: &'a [Mutex<T, Lvl, R>] = *self;
        let mut guards = Vec::with_capacity(sorted_indices.len());
        for &idx in sorted_indices {
            // Index once; the bounds check happens before this element is locked.
            let mutex = &mutexes[idx];
            guards.push(MutexGuard {
                data: &mutex.data,
                _raw_guard: mutex.raw.lock(),
            });
        }
        guards
    }
}