#[cfg(feature = "no-std")]
#[no_std]
#[cfg(not(feature = "no-std"))]
use std::thread::ThreadId;
/// A lock whose protected value can only be reached from the thread whose id
/// is stored in `thread_id`.
///
/// `P` supplies the thread identity (see `ThreadIdProvider`); access to the
/// value goes through `try_lock` / `lock`, which hand out a
/// `ThreadOwnedLockGuard`.
pub struct ThreadOwnedLock<T: ?Sized, P: ThreadIdProvider> {
// Id of the owning thread; `try_lock` compares the caller's id against it.
thread_id: P::Id,
// Flag recording whether a guard is currently outstanding.
guard: DoubleLockGuard,
// The protected value. Kept as the last field because `T` may be unsized.
data: core::cell::UnsafeCell<T>,
}
// SAFETY: the lock owns its `T`, so sending the lock to another thread sends the
// `T` with it; hence the `T: Send` requirement.
// NOTE(review): `P::Id` is stored in the lock as well but is not required to be
// `Send` here — confirm that providers only use plain-data ids.
unsafe impl<T: ?Sized + Send, P: ThreadIdProvider> Send for ThreadOwnedLock<T, P> {}
// SAFETY: sharing `&ThreadOwnedLock` across threads is gated at run time:
// `try_lock` returns `InvalidThread` unless the caller's id equals `thread_id`,
// and it only touches the non-atomic `guard` flag and `data` after that check.
// NOTE(review): this relies on `P::current_thread_id()` never reporting the same
// id for two different live threads, and on `P::Id`'s `==` being safe to call
// from several threads at once — confirm for every provider.
unsafe impl<T: ?Sized + Send, P: ThreadIdProvider> Sync for ThreadOwnedLock<T, P> {}
/// Guard returned by `ThreadOwnedLock::try_lock` and `ThreadOwnedLock::lock`.
///
/// Dereferences to the protected value; dropping it clears the lock's
/// "locked" flag so the lock can be taken again.
#[must_use = "if unused the ThreadOwnedLock will immediately unlock"]
pub struct ThreadOwnedLockGuard<'l, T: ?Sized + 'l, P: ThreadIdProvider> {
// The lock this guard was obtained from.
lock: &'l ThreadOwnedLock<T, P>,
// Raw-pointer marker: makes the guard neither `Send` nor `Sync`.
p: core::marker::PhantomData<*mut ()>, }
/// Reasons why `ThreadOwnedLock::try_lock` can refuse to hand out a guard.
///
/// The type is a plain field-less enum, so it is cheap to copy and can be
/// compared directly (`err == ThreadOwnedMutexError::AlreadyLocked`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ThreadOwnedMutexError {
    /// The calling thread's id differs from the id the lock is bound to.
    InvalidThread,
    /// A guard for this lock is already outstanding.
    AlreadyLocked,
}
impl<T, P: ThreadIdProvider> ThreadOwnedLock<T, P> {
    /// Wraps `value` in a lock bound to the thread that calls this constructor.
    ///
    /// The lock starts out unlocked.
    #[inline]
    pub fn new(value: T) -> Self {
        let cell = core::cell::UnsafeCell::new(value);
        let owner = P::current_thread_id();
        Self {
            thread_id: owner,
            guard: DoubleLockGuard::new(),
            data: cell,
        }
    }

    /// Consumes the lock and returns it bound to the calling thread instead.
    ///
    /// Taking `self` by value means no guard can be borrowed from the lock
    /// while it is being rebound.
    pub fn rebind(self) -> Self {
        let mut rebound = self;
        rebound.thread_id = P::current_thread_id();
        rebound
    }
}
impl<T: ?Sized, P: ThreadIdProvider> ThreadOwnedLock<T, P> {
    /// Acquires the lock or panics.
    ///
    /// # Panics
    ///
    /// Panics with the error's `Display` text whenever `try_lock` fails, i.e.
    /// when called from a thread that does not own the lock or while a guard
    /// is already outstanding.
    #[inline]
    pub fn lock(&self) -> ThreadOwnedLockGuard<'_, T, P> {
        self.try_lock().unwrap_or_else(|err| panic!("{}", err))
    }

    /// Attempts to acquire the lock without panicking.
    ///
    /// # Errors
    ///
    /// * `InvalidThread` — the calling thread's id differs from the id the
    ///   lock is bound to.
    /// * `AlreadyLocked` — the lock's flag was already set when this call
    ///   tried to set it.
    #[inline]
    pub fn try_lock(&self) -> Result<ThreadOwnedLockGuard<'_, T, P>, ThreadOwnedMutexError> {
        let caller = P::current_thread_id();
        if caller != self.thread_id {
            return Err(ThreadOwnedMutexError::InvalidThread);
        }

        // `try_enter` sets the flag and reports the state it had before, so
        // `true` means somebody already holds the lock.
        let was_held = self.guard.try_enter();
        if was_held {
            Err(ThreadOwnedMutexError::AlreadyLocked)
        } else {
            Ok(ThreadOwnedLockGuard {
                lock: self,
                p: core::marker::PhantomData,
            })
        }
    }
}
/// Source of thread identity used by `ThreadOwnedLock` to decide which thread
/// owns a lock.
pub trait ThreadIdProvider {
/// Identifier type; it is compared with `!=` in `try_lock` and copied freely.
type Id: PartialEq + Eq + Copy;
/// Returns the identifier of the thread this function is called on
/// (the std-backed provider returns `std::thread::current().id()`).
fn current_thread_id() -> Self::Id;
}
/// `ThreadIdProvider` backed by the standard library's thread ids.
/// Only compiled when the `no-std` feature is disabled.
#[cfg(not(feature = "no-std"))]
pub struct StdThreadIdProvider {}
#[cfg(not(feature = "no-std"))]
impl ThreadIdProvider for StdThreadIdProvider {
    type Id = std::thread::ThreadId;

    /// Looks up the handle of the calling thread and returns its id.
    fn current_thread_id() -> Self::Id {
        let handle = std::thread::current();
        handle.id()
    }
}
/// `ThreadOwnedLock` that identifies threads through `StdThreadIdProvider`.
/// Only available when the `no-std` feature is disabled.
#[cfg(not(feature = "no-std"))]
pub type StdThreadOwnedLock<T> = ThreadOwnedLock<T, StdThreadIdProvider>;
impl<T, P: ThreadIdProvider> From<T> for ThreadOwnedLock<T, P> {
fn from(value: T) -> Self {
Self::new(value)
}
}
impl<T: Default, P: ThreadIdProvider> Default for ThreadOwnedLock<T, P> {
    /// Builds a lock around `T::default()`, bound to the calling thread.
    fn default() -> Self {
        let initial = T::default();
        Self::new(initial)
    }
}
impl<T: ?Sized, P: ThreadIdProvider> Drop for ThreadOwnedLockGuard<'_, T, P> {
fn drop(&mut self) {
self.lock.guard.exit();
}
}
impl<T: ?Sized, P: ThreadIdProvider> core::ops::Deref for ThreadOwnedLockGuard<'_, T, P> {
    type Target = T;

    /// Gives shared access to the protected value.
    fn deref(&self) -> &Self::Target {
        let value: *mut T = self.lock.data.get();
        // SAFETY: `try_lock` only creates a guard after the thread-id check
        // passed and the flag was previously clear, so this guard is the only
        // path to `data`; the returned borrow is tied to `&self`.
        unsafe { &*value }
    }
}
impl<T: ?Sized, P: ThreadIdProvider> core::ops::DerefMut for ThreadOwnedLockGuard<'_, T, P> {
    /// Gives exclusive access to the protected value.
    fn deref_mut(&mut self) -> &mut Self::Target {
        let value: *mut T = self.lock.data.get();
        // SAFETY: `try_lock` only creates a guard after the thread-id check
        // passed and the flag was previously clear, so this guard is the only
        // path to `data`; `&mut self` makes the returned borrow unique.
        unsafe { &mut *value }
    }
}
impl<T: ?Sized + core::fmt::Debug, P: ThreadIdProvider> core::fmt::Debug
    for ThreadOwnedLockGuard<'_, T, P>
{
    /// Formats the protected value itself; the guard adds no decoration.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let inner: &T = &**self;
        <T as core::fmt::Debug>::fmt(inner, f)
    }
}
impl<T: ?Sized + core::fmt::Display, P: ThreadIdProvider> core::fmt::Display
    for ThreadOwnedLockGuard<'_, T, P>
{
    /// Formats the protected value itself; the guard adds no decoration.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let inner: &T = &**self;
        <T as core::fmt::Display>::fmt(inner, f)
    }
}
impl core::fmt::Display for ThreadOwnedMutexError {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
match self {
ThreadOwnedMutexError::InvalidThread => {
f.write_str("Current thread does not own this lock")
}
ThreadOwnedMutexError::AlreadyLocked => f.write_str("Already locked"),
}
}
}
// `std::error::Error` lives in `std`, so this impl is only compiled when the
// `no-std` feature is disabled, like every other `std` use in this file.
#[cfg(not(feature = "no-std"))]
impl std::error::Error for ThreadOwnedMutexError {}
/// Non-atomic "is a guard outstanding?" flag used by `ThreadOwnedLock`.
///
/// Backed by `Cell<bool>`, so flipping the flag through `&self` needs no
/// `unsafe`. The flag is not synchronized; it must only be touched from one
/// thread at a time.
#[doc(hidden)]
struct DoubleLockGuard(core::cell::Cell<bool>);

impl DoubleLockGuard {
    /// Creates the flag in the "not entered" state.
    fn new() -> Self {
        Self(core::cell::Cell::new(false))
    }

    /// Sets the flag and returns the state it had before the call.
    ///
    /// `false` means the caller just entered; `true` means the flag was
    /// already set (and it stays set).
    #[inline]
    fn try_enter(&self) -> bool {
        self.0.replace(true)
    }

    /// Clears the flag.
    #[inline]
    fn exit(&self) {
        self.0.set(false);
    }
}
#[cfg(all(test, not(feature = "no-std")))]
mod test {
    use super::*;

    /// The creating thread can lock; a different thread gets `InvalidThread`.
    #[test]
    fn test_lock() {
        let lock = StdThreadOwnedLock::new(20);

        let guard = lock.try_lock().expect("failed to acquire lock");
        assert_eq!(*guard, 20);
        // Release before the lock is moved into the other thread.
        drop(guard);

        std::thread::spawn(move || {
            let err = lock.try_lock().expect_err("Should fail");
            assert!(matches!(err, ThreadOwnedMutexError::InvalidThread));
        })
        .join()
        .unwrap();
    }

    /// A second `try_lock` while a guard is alive fails; after the guard is
    /// dropped the lock can be taken again.
    #[test]
    fn test_double_lock_fails() {
        let lock = StdThreadOwnedLock::new(20);

        let first = lock.try_lock().expect("failed to acquire lock");
        let err = lock.try_lock().expect_err("Should fail");
        assert!(matches!(err, ThreadOwnedMutexError::AlreadyLocked));
        drop(first);

        let _guard = lock.try_lock().expect("failed to acquire lock");
    }

    /// `rebind` re-targets the lock at whichever thread calls it.
    #[test]
    fn test_lock_rebind() {
        let lock = StdThreadOwnedLock::new(20);
        assert_eq!(lock.thread_id, std::thread::current().id());

        std::thread::spawn(move || {
            let rebound = lock.rebind();
            assert_eq!(rebound.thread_id, std::thread::current().id());
        })
        .join()
        .unwrap();
    }
}