thread_owned_lock/
lib.rs

//! This crate provides a concurrency primitive similar to [`Mutex`], but it only allows the
//! currently bound thread to access its contents. Unlike [`Mutex`], it never blocks: an attempt
//! to acquire the lock from any thread other than the bound thread fails immediately.
4//!
5//! The primitive also ensures that the owning thread can only acquire the lock once in order to
6//! not break Rust's aliasing rules.
7//!
8//! # Use Case
9//!
10//! This concurrency primitive is useful to enforce that only one specific thread can access the
//! data within. Depending on your OS, it may also be faster than a regular [`Mutex`]. You can run
//! this crate's benchmark to check how it fares on your machine.
13//!
14//! # Example
15//! ```
16//! use std::sync::RwLock;
17//! use thread_owned_lock::StdThreadOwnedLock;
18//!
19//! struct SharedData {
20//!     main_thread_data: StdThreadOwnedLock<i32>,
21//!     shared_data: RwLock<i32>,
22//! }
23//!
24//! let shared_data = std::sync::Arc::new(SharedData {
25//!     main_thread_data: StdThreadOwnedLock::new(20),
26//!     shared_data:RwLock::new(30)
27//! });
28//! {
29//!     let guard = shared_data.main_thread_data.lock();
30//!     // Main thread can now access the contents;
31//! }
32//! let data_cloned = shared_data.clone();
33//! std::thread::spawn(move|| {
34//!     if let Err(e) = data_cloned.main_thread_data.try_lock() {
35//!         // On other threads, accessing the main thread data will fail.
36//!     }
37//! });
38//! ```
39//!
40//! # no-std
41//!
//! This crate is compatible with `no_std` environments. Provide an implementation of the
//! [`ThreadIdProvider`] trait for your environment and enable the `no-std` feature.
//!
//! ```
//! use thread_owned_lock::{ThreadIdProvider, ThreadOwnedLock};
47//! struct MYThreadIdProvider{}
48//!
49//! impl ThreadIdProvider for MYThreadIdProvider {
50//!     type Id = u32;
51//!     fn current_thread_id() -> Self::Id {
52//!         todo!()
53//!     }
54//! }
55//!
56//! type MyThreadOwnedLock<T> = ThreadOwnedLock<T, MYThreadIdProvider>;
57//! ```
58//!
59//! [`Mutex`]:std::sync::Mutex
60
// `no_std` is a crate-level attribute and only takes effect as an inner attribute (`#![...]`).
// The previous outer `#[no_std]` sat on a `use` item whose two `cfg`s contradicted each other,
// so the `no-std` feature never actually removed `std`.
#![cfg_attr(feature = "no-std", no_std)]
65
66/// A mutual exclusion primitive similar to [`Mutex`] but it only allows the owning
67/// thread to access the data.
68///
69/// Lock ownership can be transferred to another thread if the data implements [`Send`] with the
70/// [`rebind`] function.
71///
72/// Attempting to [`lock`] more than one time on the same thread will result in an error.
73///
74/// [`Mutex`]:std::sync::Mutex
75/// [`rebind`]:Self::rebind
76/// [`lock`]:Self::lock
77pub struct ThreadOwnedLock<T: ?Sized, P: ThreadIdProvider> {
78    thread_id: P::Id,
79    guard: DoubleLockGuard,
80    data: core::cell::UnsafeCell<T>,
81}
82
83unsafe impl<T: ?Sized + Send, P: ThreadIdProvider> Send for ThreadOwnedLock<T, P> {}
84unsafe impl<T: ?Sized + Send, P: ThreadIdProvider> Sync for ThreadOwnedLock<T, P> {}
85
86/// An RAII implementation of a "scoped lock" of a mutex. When this structure is
87/// dropped (falls out of scope), the lock will be unlocked.
88///
89/// The data protected by the mutex can be accessed through this guard via its
90/// [`Deref`] and [`DerefMut`] implementations.
91///
92/// This structure is created by the [`lock`] and [`try_lock`] methods on
93/// [`ThreadOwnedLock`].
94///
95/// [`lock`]: ThreadOwnedLock::lock
96/// [`try_lock`]: ThreadOwnedLock::try_lock
97#[must_use = "if unused the ThreadOwnedLock will immediately unlock"]
98pub struct ThreadOwnedLockGuard<'l, T: ?Sized + 'l, P: ThreadIdProvider> {
99    lock: &'l ThreadOwnedLock<T, P>,
100    p: core::marker::PhantomData<*mut ()>, // Prevent the guard from becoming send.
101}
102
/// Reasons why acquiring a [`ThreadOwnedLock`] can fail.
///
/// The enum is a plain fieldless value, so it is cheap to copy and can be compared directly
/// (e.g. `assert_eq!(err, ThreadOwnedMutexError::AlreadyLocked)`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ThreadOwnedMutexError {
    /// The thread attempting to access this lock is not the thread it is bound to.
    InvalidThread,
    /// The bound thread already holds an active [`ThreadOwnedLockGuard`] for this lock.
    AlreadyLocked,
}
110
111impl<T, P: ThreadIdProvider> ThreadOwnedLock<T, P> {
112    /// Create a new instance of [`ThreadOwnedLock`] and bind it to the current thread.
113    #[inline]
114    pub fn new(value: T) -> Self {
115        Self {
116            data: core::cell::UnsafeCell::new(value),
117            thread_id: P::current_thread_id(),
118            guard: DoubleLockGuard::new(),
119        }
120    }
121
122    /// Transfer ownership of the lock to another thread.
123    ///
124    /// # Example
125    /// ```
126    /// use thread_owned_lock::StdThreadOwnedLock;
127    /// let lock = StdThreadOwnedLock::new(10);
128    /// std::thread::spawn(move|| {
129    ///     let lock = lock.rebind();
130    ///     // lock can now be accessed on this thread.
131    ///     let guard = lock.lock();
132    /// });
133    /// ```
134    pub fn rebind(mut self) -> Self {
135        self.thread_id = P::current_thread_id();
136        self
137    }
138}
139
140impl<T: ?Sized, P: ThreadIdProvider> ThreadOwnedLock<T, P> {
141    /// Acquires the mutex, returning an RAII style guard which allows access to the data.
142    ///
143    /// # Panics
144    /// This call will panic if this method is called from a thread other than the owning thread
145    /// or if the lock has already been acquired.
146    #[inline]
147    pub fn lock(&self) -> ThreadOwnedLockGuard<'_, T, P> {
148        match self.try_lock() {
149            Ok(v) => v,
150            Err(e) => panic!("{}", e),
151        }
152    }
153
154    /// Try to acquire the mutex. If one of the following conditions fails, an error will be
155    /// returned:
156    ///  * The thread accessing the lock must be the bound thread.
157    ///  * The lock can only be acquired on time.
158    ///
159    ///
160    #[inline]
161    pub fn try_lock(&self) -> Result<ThreadOwnedLockGuard<'_, T, P>, ThreadOwnedMutexError> {
162        let current_thread_id = P::current_thread_id();
163        if current_thread_id != self.thread_id {
164            return Err(ThreadOwnedMutexError::InvalidThread);
165        }
166        if self.guard.try_enter() {
167            return Err(ThreadOwnedMutexError::AlreadyLocked);
168        }
169        Ok(ThreadOwnedLockGuard {
170            lock: self,
171            p: core::marker::PhantomData,
172        })
173    }
174}
175
/// Abstraction over how the identity of the running thread is obtained.
///
/// A [`ThreadOwnedLock`] records the id returned when it is created (or re-bound) and compares
/// it with `==` against the id of every thread that later tries to lock it. Distinct live
/// threads should therefore yield distinct ids.
pub trait ThreadIdProvider {
    /// Identifier of a thread.
    type Id: PartialEq + Eq + Copy;

    /// Returns the identifier of the thread that is currently executing.
    fn current_thread_id() -> Self::Id;
}
183
184/// ThreadIdProvider implementation based on std::thread::ThreadId
185#[cfg(not(feature = "no-std"))]
186pub struct StdThreadIdProvider {}
187
188#[cfg(not(feature = "no-std"))]
189impl ThreadIdProvider for StdThreadIdProvider {
190    type Id = std::thread::ThreadId;
191
192    fn current_thread_id() -> Self::Id {
193        std::thread::current().id()
194    }
195}
196#[cfg(not(feature = "no-std"))]
197pub type StdThreadOwnedLock<T> = ThreadOwnedLock<T, StdThreadIdProvider>;
198
199impl<T, P: ThreadIdProvider> From<T> for ThreadOwnedLock<T, P> {
200    fn from(value: T) -> Self {
201        Self::new(value)
202    }
203}
204impl<T: Default, P: ThreadIdProvider> Default for ThreadOwnedLock<T, P> {
205    fn default() -> Self {
206        Self::new(T::default())
207    }
208}
209
210impl<T: ?Sized, P: ThreadIdProvider> Drop for ThreadOwnedLockGuard<'_, T, P> {
211    fn drop(&mut self) {
212        self.lock.guard.exit();
213    }
214}
215
216impl<T: ?Sized, P: ThreadIdProvider> core::ops::Deref for ThreadOwnedLockGuard<'_, T, P> {
217    type Target = T;
218    fn deref(&self) -> &Self::Target {
219        unsafe { &*self.lock.data.get() }
220    }
221}
222
223impl<T: ?Sized, P: ThreadIdProvider> core::ops::DerefMut for ThreadOwnedLockGuard<'_, T, P> {
224    fn deref_mut(&mut self) -> &mut Self::Target {
225        unsafe { &mut *self.lock.data.get() }
226    }
227}
228
229impl<T: ?Sized + core::fmt::Debug, P: ThreadIdProvider> core::fmt::Debug
230    for ThreadOwnedLockGuard<'_, T, P>
231{
232    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
233        core::fmt::Debug::fmt(&**self, f)
234    }
235}
236
237impl<T: ?Sized + core::fmt::Display, P: ThreadIdProvider> core::fmt::Display
238    for ThreadOwnedLockGuard<'_, T, P>
239{
240    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
241        (**self).fmt(f)
242    }
243}
244
245impl core::fmt::Display for ThreadOwnedMutexError {
246    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
247        match self {
248            ThreadOwnedMutexError::InvalidThread => {
249                f.write_str("Current thread does not own this lock")
250            }
251            ThreadOwnedMutexError::AlreadyLocked => f.write_str("Already locked"),
252        }
253    }
254}
255
256impl std::error::Error for ThreadOwnedMutexError {}
257
/// Re-entrancy flag that stops the bound thread from holding two guards for one lock.
///
/// A plain, non-atomic [`Cell`](core::cell::Cell) is enough: `ThreadOwnedLock::try_lock`
/// checks the thread id before calling [`try_enter`](Self::try_enter), and the guard that
/// calls [`exit`](Self::exit) on drop is `!Send`. The flag is therefore only ever touched by
/// the bound thread. Using `Cell` instead of `UnsafeCell` removes the need for `unsafe` here.
#[doc(hidden)]
struct DoubleLockGuard(core::cell::Cell<bool>);

impl DoubleLockGuard {
    /// Creates a flag in the "not locked" state.
    fn new() -> Self {
        Self(core::cell::Cell::new(false))
    }

    /// Marks the lock as held and returns the *previous* state.
    ///
    /// `false` means the caller has just acquired the lock. `true` means it was already held.
    /// The flag is left set in both cases.
    #[inline]
    fn try_enter(&self) -> bool {
        self.0.replace(true)
    }

    /// Marks the lock as released.
    #[inline]
    fn exit(&self) {
        self.0.set(false);
    }
}
286
#[cfg(all(test, not(feature = "no-std")))]
mod test {
    use super::*;

    #[test]
    fn test_lock() {
        let lock = StdThreadOwnedLock::new(20);
        {
            let guard = lock.try_lock().expect("failed to acquire lock");
            assert_eq!(*guard, 20);
        }
        let h = std::thread::spawn(move || {
            let err = lock.try_lock().expect_err("Should fail");
            assert!(matches!(err, ThreadOwnedMutexError::InvalidThread));
        });
        h.join().unwrap();
    }

    #[test]
    fn test_double_lock_fails() {
        let lock = StdThreadOwnedLock::new(20);
        {
            let _guard = lock.try_lock().expect("failed to acquire lock");
            let err = lock.try_lock().expect_err("Should fail");
            assert!(matches!(err, ThreadOwnedMutexError::AlreadyLocked));
        }
        let _guard = lock.try_lock().expect("failed to acquire lock");
    }

    #[test]
    fn test_lock_rebind() {
        let lock = StdThreadOwnedLock::new(20);
        assert_eq!(lock.thread_id, std::thread::current().id());
        let h = std::thread::spawn(move || {
            let lock = lock.rebind();
            assert_eq!(lock.thread_id, std::thread::current().id());
        });
        h.join().unwrap();
    }

    /// Writes made through `DerefMut` must be visible to later guards.
    #[test]
    fn test_guard_mutation_persists() {
        let lock = StdThreadOwnedLock::new(20);
        *lock.lock() += 1;
        assert_eq!(*lock.lock(), 21);
    }

    /// `lock` reports a double lock on the bound thread by panicking.
    #[test]
    #[should_panic(expected = "Already locked")]
    fn test_lock_panics_when_already_locked() {
        let lock = StdThreadOwnedLock::new(20);
        let _first = lock.lock();
        let _second = lock.lock();
    }

    /// `lock` panics on a foreign thread, and the failed attempt leaves the lock usable by
    /// its owner.
    #[test]
    fn test_lock_panics_on_foreign_thread() {
        let lock = StdThreadOwnedLock::new(20);
        let outcome = std::thread::scope(|s| s.spawn(|| drop(lock.lock())).join());
        assert!(outcome.is_err());
        assert_eq!(*lock.lock(), 20);
    }

    #[test]
    fn test_from_and_default() {
        let from: StdThreadOwnedLock<i32> = 7.into();
        assert_eq!(*from.lock(), 7);
        let empty: StdThreadOwnedLock<String> = Default::default();
        assert!(empty.lock().is_empty());
    }
}