thread-owned-lock 0.1.0

Mutex which can only be locked by the owning thread.
Documentation
//! This crate provides a concurrency primitive similar to [`Mutex`] but only allows the currently
//! bound thread to access its contents. Unlike [`Mutex`] it does not cause a thread to block if
//! another thread has acquired the lock, but the operation will fail immediately.
//!
//! The primitive also ensures that the owning thread cannot acquire the lock again while a guard
//! from a previous acquisition is still alive, so that Rust's aliasing rules are not broken.
//!
//! # Use Case
//!
//! This concurrency primitive is useful to enforce that only one specific thread can access the
//! data within. Depending on your OS, it may also be faster than a regular [`Mutex`]. You can run
//! this crate's benchmark to check how it fares on your machine.
//!
//! # Example
//! ```
//! use std::sync::RwLock;
//! use thread_owned_lock::StdThreadOwnedLock;
//!
//! struct SharedData {
//!     main_thread_data: StdThreadOwnedLock<i32>,
//!     shared_data: RwLock<i32>,
//! }
//!
//! let shared_data = std::sync::Arc::new(SharedData {
//!     main_thread_data: StdThreadOwnedLock::new(20),
//!     shared_data:RwLock::new(30)
//! });
//! {
//!     let guard = shared_data.main_thread_data.lock();
//!     // Main thread can now access the contents;
//! }
//! let data_cloned = shared_data.clone();
//! std::thread::spawn(move|| {
//!     if let Err(e) = data_cloned.main_thread_data.try_lock() {
//!         // On other threads, accessing the main thread data will fail.
//!     }
//! });
//! ```
//!
//! # no-std
//!
//! This crate is compatible with no-std. You just need to provide an implementation of the
//! [`ThreadIdProvider`] trait for your environment and enable the `no-std` feature.
//!
//! ```
//!use thread_owned_lock::{ThreadIdProvider, ThreadOwnedLock};
//! struct MYThreadIdProvider{}
//!
//! impl ThreadIdProvider for MYThreadIdProvider {
//!     type Id = u32;
//!     fn current_thread_id() -> Self::Id {
//!         todo!()
//!     }
//! }
//!
//! type MyThreadOwnedLock<T> = ThreadOwnedLock<T, MYThreadIdProvider>;
//! ```
//!
//! [`Mutex`]:std::sync::Mutex

// NOTE(review): this attribute stack does not make the crate `no_std`, and the `use` below is
// never compiled:
// * `cfg(feature = "no-std")` and `cfg(not(feature = "no-std"))` are both attached to the same
//   item, so the item is removed in every configuration.
// * `no_std` is a crate-level attribute; written as an outer `#[no_std]` on an item it does not
//   apply to the crate. Enabling it for the feature needs an inner attribute at the crate root,
//   e.g. `#![cfg_attr(feature = "no-std", no_std)]`, placed before the first item.
// * With that in place, `impl std::error::Error for ThreadOwnedMutexError` must also be gated
//   on `not(feature = "no-std")`, because the `std` crate is unavailable in `no_std` builds.
// The imported name `ThreadId` is not referenced elsewhere in this file (`StdThreadIdProvider`
// spells out `std::thread::ThreadId`).
#[cfg(feature = "no-std")]
#[no_std]
#[cfg(not(feature = "no-std"))]
use std::thread::ThreadId;

/// A mutual-exclusion primitive akin to [`Mutex`] whose contents can only be reached from the
/// thread that owns it.
///
/// The owning thread is recorded when the lock is created. The lock can be handed over to
/// another thread with [`rebind`]; moving the lock there requires the data to implement
/// [`Send`].
///
/// Only one guard can exist at a time: acquiring the lock again on the owning thread while a
/// guard is still alive fails ([`lock`] panics, `try_lock` returns an error).
///
/// [`Mutex`]:std::sync::Mutex
/// [`rebind`]:Self::rebind
/// [`lock`]:Self::lock
pub struct ThreadOwnedLock<T: ?Sized, P: ThreadIdProvider> {
    /// Set while a guard for this lock is alive; prevents handing out a second guard.
    guard: DoubleLockGuard,
    /// Identifier of the only thread allowed to acquire the lock.
    thread_id: P::Id,
    /// The protected value. It has to remain the last field because `T` may be unsized.
    data: core::cell::UnsafeCell<T>,
}

// SAFETY: Sending the lock moves the owned `T`, which is permitted because `T: Send`. The
// re-entrancy flag travels with it, and this crate only compares or overwrites the stored thread
// id. After the move, the receiving thread still has to pass the thread-id check (or call
// `rebind`) before it can reach the data.
unsafe impl<T, P> Send for ThreadOwnedLock<T, P>
where
    T: ?Sized + Send,
    P: ThreadIdProvider,
{
}

// SAFETY: Through a shared `&ThreadOwnedLock`, any thread other than the bound one can only call
// `try_lock`/`lock`. Those compare its id with `thread_id` and bail out before touching the
// re-entrancy flag or the data. `thread_id` is only written by `rebind`, which takes the lock by
// value, so shared readers never race with a write. As with `std::sync::Mutex`, `T: Send` is
// required.
unsafe impl<T, P> Sync for ThreadOwnedLock<T, P>
where
    T: ?Sized + Send,
    P: ThreadIdProvider,
{
}

/// RAII guard returned by [`ThreadOwnedLock::lock`] and [`ThreadOwnedLock::try_lock`].
///
/// While the guard is alive, the owning thread can reach the protected value through the
/// guard's [`Deref`] and [`DerefMut`] implementations. Dropping the guard (for example when it
/// falls out of scope) releases the lock so that it can be acquired again.
///
/// The guard can neither be sent to nor shared with another thread.
///
/// [`Deref`]: core::ops::Deref
/// [`DerefMut`]: core::ops::DerefMut
#[must_use = "if unused the ThreadOwnedLock will immediately unlock"]
pub struct ThreadOwnedLockGuard<'l, T: ?Sized + 'l, P: ThreadIdProvider> {
    /// Raw-pointer marker that makes the guard `!Send` and `!Sync`, pinning it to the thread
    /// that acquired it.
    p: core::marker::PhantomData<*mut ()>,
    /// The lock this guard was obtained from.
    lock: &'l ThreadOwnedLock<T, P>,
}

/// Reasons why a [`ThreadOwnedLock`] refuses to hand out a guard.
///
/// Returned by [`ThreadOwnedLock::try_lock`]; [`ThreadOwnedLock::lock`] panics with the
/// corresponding message instead.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ThreadOwnedMutexError {
    /// The thread attempting to access this lock is not the thread the lock is bound to.
    InvalidThread,
    /// There is already an active [`ThreadOwnedLockGuard`] for this lock.
    AlreadyLocked,
}

impl<T, P: ThreadIdProvider> ThreadOwnedLock<T, P> {
    /// Creates a new lock wrapping `value` and binds it to the calling thread.
    ///
    /// Only the calling thread can acquire the lock until ownership is moved with
    /// [`rebind`](Self::rebind).
    #[inline]
    pub fn new(value: T) -> Self {
        Self {
            data: core::cell::UnsafeCell::new(value),
            thread_id: P::current_thread_id(),
            guard: DoubleLockGuard::new(),
        }
    }

    /// Rebinds the lock to the calling thread, making it the new owner.
    ///
    /// The lock is taken and returned by value. As a result, `rebind` cannot be called while a
    /// guard borrowed from the lock is still in use. To transfer ownership, move the lock into
    /// the target thread (which requires `T: Send`) and call `rebind` there.
    ///
    /// # Example
    /// ```
    /// use thread_owned_lock::StdThreadOwnedLock;
    /// let lock = StdThreadOwnedLock::new(10);
    /// std::thread::spawn(move || {
    ///     let lock = lock.rebind();
    ///     // The lock can now be accessed on this thread.
    ///     let guard = lock.lock();
    ///     assert_eq!(*guard, 10);
    /// })
    /// .join()
    /// .unwrap();
    /// ```
    #[inline]
    #[must_use = "rebind consumes the lock; discarding the result drops the lock and its data"]
    pub fn rebind(mut self) -> Self {
        self.thread_id = P::current_thread_id();
        self
    }
}

impl<T: ?Sized, P: ThreadIdProvider> ThreadOwnedLock<T, P> {
    /// Acquires the lock, returning an RAII guard that gives access to the data.
    ///
    /// This never blocks. It behaves like [`try_lock`](Self::try_lock) and panics on failure.
    ///
    /// # Panics
    /// Panics if called from a thread other than the owning thread, or if the owning thread
    /// already holds a guard for this lock. The panic is reported at the caller's location.
    #[inline]
    #[track_caller]
    pub fn lock(&self) -> ThreadOwnedLockGuard<'_, T, P> {
        // `match` rather than `unwrap_or_else`: a closure would not inherit `#[track_caller]`,
        // so the panic location would point into this crate instead of at the caller.
        match self.try_lock() {
            Ok(guard) => guard,
            Err(e) => panic!("{}", e),
        }
    }

    /// Attempts to acquire the lock without blocking.
    ///
    /// # Errors
    /// * [`ThreadOwnedMutexError::InvalidThread`] if the calling thread is not the thread the
    ///   lock is bound to.
    /// * [`ThreadOwnedMutexError::AlreadyLocked`] if the owning thread already holds a guard
    ///   for this lock. The lock can only be held once at a time.
    #[inline]
    pub fn try_lock(&self) -> Result<ThreadOwnedLockGuard<'_, T, P>, ThreadOwnedMutexError> {
        // The thread check must come first: the re-entrancy flag is unsynchronized and may only
        // be touched by the owning thread.
        if P::current_thread_id() != self.thread_id {
            return Err(ThreadOwnedMutexError::InvalidThread);
        }
        // `try_enter` marks the lock as held and reports whether it already was.
        if self.guard.try_enter() {
            return Err(ThreadOwnedMutexError::AlreadyLocked);
        }
        Ok(ThreadOwnedLockGuard {
            lock: self,
            p: core::marker::PhantomData,
        })
    }
}

/// Abstraction over what identifies a thread and how to obtain the identifier of the thread
/// that is currently running.
///
/// A lock stores the identifier returned when it is created or rebound. It then compares that
/// identifier with the current one on every access. Implementations must therefore return
/// distinct identifiers for threads that are alive at the same time.
pub trait ThreadIdProvider {
    /// Identifier type, compared by value on every lock attempt.
    type Id: Copy + Eq;

    /// Returns the identifier of the calling thread.
    fn current_thread_id() -> Self::Id;
}

/// [`ThreadIdProvider`] implementation based on [`std::thread::ThreadId`].
#[cfg(not(feature = "no-std"))]
#[derive(Debug, Clone, Copy, Default)]
pub struct StdThreadIdProvider {}

#[cfg(not(feature = "no-std"))]
impl ThreadIdProvider for StdThreadIdProvider {
    type Id = std::thread::ThreadId;

    /// Returns the [`std::thread::ThreadId`] of the calling thread.
    // `#[inline]`: this non-generic function is called on every lock attempt from code that is
    // monomorphized in downstream crates; without the hint it cannot be inlined cross-crate.
    #[inline]
    fn current_thread_id() -> Self::Id {
        std::thread::current().id()
    }
}
/// [`ThreadOwnedLock`] that identifies threads through [`StdThreadIdProvider`], i.e. by
/// [`std::thread::ThreadId`].
#[cfg(not(feature = "no-std"))]
pub type StdThreadOwnedLock<T> = ThreadOwnedLock<T, StdThreadIdProvider>;

impl<T, P: ThreadIdProvider> From<T> for ThreadOwnedLock<T, P> {
    fn from(value: T) -> Self {
        Self::new(value)
    }
}
impl<T: Default, P: ThreadIdProvider> Default for ThreadOwnedLock<T, P> {
    fn default() -> Self {
        Self::new(T::default())
    }
}

impl<T, P> Drop for ThreadOwnedLockGuard<'_, T, P>
where
    T: ?Sized,
    P: ThreadIdProvider,
{
    /// Clears the re-entrancy flag so the owning thread can acquire the lock again.
    fn drop(&mut self) {
        let owner = self.lock;
        owner.guard.exit();
    }
}

impl<T, P> core::ops::Deref for ThreadOwnedLockGuard<'_, T, P>
where
    T: ?Sized,
    P: ThreadIdProvider,
{
    type Target = T;

    fn deref(&self) -> &T {
        let cell = &self.lock.data;
        // SAFETY: guards are only created by `try_lock` after the thread-id check passed and the
        // re-entrancy flag was set, so this guard is the only accessor of the data. The shared
        // borrow is tied to `&self`, which rules out a concurrent `deref_mut` on this guard.
        unsafe { &*cell.get() }
    }
}

impl<T, P> core::ops::DerefMut for ThreadOwnedLockGuard<'_, T, P>
where
    T: ?Sized,
    P: ThreadIdProvider,
{
    fn deref_mut(&mut self) -> &mut T {
        let cell = &self.lock.data;
        // SAFETY: this guard is the only guard for the lock (enforced by the re-entrancy flag in
        // `try_lock`), and `&mut self` guarantees no other borrow derived from it is alive, so
        // the mutable reference is exclusive.
        unsafe { &mut *cell.get() }
    }
}

impl<T, P> core::fmt::Debug for ThreadOwnedLockGuard<'_, T, P>
where
    T: ?Sized + core::fmt::Debug,
    P: ThreadIdProvider,
{
    /// Formats the protected value exactly as its own `Debug` implementation does.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let value: &T = self;
        core::fmt::Debug::fmt(value, f)
    }
}

impl<T, P> core::fmt::Display for ThreadOwnedLockGuard<'_, T, P>
where
    T: ?Sized + core::fmt::Display,
    P: ThreadIdProvider,
{
    /// Formats the protected value exactly as its own `Display` implementation does.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let value: &T = self;
        core::fmt::Display::fmt(value, f)
    }
}

impl core::fmt::Display for ThreadOwnedMutexError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self {
            ThreadOwnedMutexError::InvalidThread => {
                f.write_str("Current thread does not own this lock")
            }
            ThreadOwnedMutexError::AlreadyLocked => f.write_str("Already locked"),
        }
    }
}

/// Only available with the standard library: the `std` crate (and thus `std::error::Error`) is
/// not linked in `no_std` builds.
#[cfg(not(feature = "no-std"))]
impl std::error::Error for ThreadOwnedMutexError {}

/// Re-entrancy flag recording whether a guard for the owning lock is currently alive.
///
/// The flag is only touched by the thread the owning lock is bound to. `try_enter` runs only
/// after the thread-id check in `try_lock`, and `exit` runs when a `!Send` guard is dropped. No
/// synchronization is therefore needed, and a plain [`Cell`](core::cell::Cell) provides the
/// interior mutability without any `unsafe` code.
struct DoubleLockGuard(core::cell::Cell<bool>);

impl DoubleLockGuard {
    /// Creates a flag in the "not entered" state.
    fn new() -> Self {
        Self(core::cell::Cell::new(false))
    }

    /// Marks the lock as entered and returns whether it was *already* entered.
    ///
    /// A `true` result means the caller must not hand out another guard.
    #[inline]
    fn try_enter(&self) -> bool {
        self.0.replace(true)
    }

    /// Marks the lock as no longer entered.
    #[inline]
    fn exit(&self) {
        self.0.set(false);
    }
}

#[cfg(all(test, not(feature = "no-std")))]
mod test {
    use super::*;

    #[test]
    fn test_lock() {
        let lock = StdThreadOwnedLock::new(20);
        {
            let guard = lock.try_lock().expect("failed to acquire lock");
            assert_eq!(*guard, 20);
        }
        let h = std::thread::spawn(move || {
            let err = lock.try_lock().expect_err("Should fail");
            assert!(matches!(err, ThreadOwnedMutexError::InvalidThread));
        });
        h.join().unwrap();
    }

    #[test]
    fn test_double_lock_fails() {
        let lock = StdThreadOwnedLock::new(20);
        {
            let _guard = lock.try_lock().expect("failed to acquire lock");
            let err = lock.try_lock().expect_err("Should fail");
            assert!(matches!(err, ThreadOwnedMutexError::AlreadyLocked));
        }
        let _guard = lock.try_lock().expect("failed to acquire lock");
    }

    #[test]
    fn test_lock_rebind() {
        let lock = StdThreadOwnedLock::new(20);
        assert_eq!(lock.thread_id, std::thread::current().id());
        let h = std::thread::spawn(move || {
            let lock = lock.rebind();
            assert_eq!(lock.thread_id, std::thread::current().id());
            // After rebinding, the new owner must actually be able to use the lock.
            let mut guard = lock.try_lock().expect("failed to acquire lock after rebind");
            *guard += 1;
            assert_eq!(*guard, 21);
        });
        h.join().unwrap();
    }

    #[test]
    #[should_panic(expected = "Already locked")]
    fn test_lock_panics_when_already_locked() {
        let lock = StdThreadOwnedLock::new(20);
        let _guard = lock.lock();
        let _second = lock.lock();
    }

    #[test]
    fn test_lock_panics_on_foreign_thread() {
        let lock = StdThreadOwnedLock::new(20);
        let h = std::thread::spawn(move || {
            let _guard = lock.lock();
        });
        let payload = h
            .join()
            .expect_err("lock() on a non-owning thread should panic");
        let message = payload
            .downcast_ref::<String>()
            .map(String::as_str)
            .or_else(|| payload.downcast_ref::<&str>().copied())
            .unwrap_or_default();
        assert_eq!(message, "Current thread does not own this lock");
    }

    #[test]
    fn test_error_messages() {
        assert_eq!(
            ThreadOwnedMutexError::InvalidThread.to_string(),
            "Current thread does not own this lock"
        );
        assert_eq!(ThreadOwnedMutexError::AlreadyLocked.to_string(), "Already locked");
    }
}