use std::{
error::Error,
fmt,
marker::PhantomData,
mem::{self, ManuallyDrop},
rc::Rc,
thread::{self, ThreadId},
thread_local,
};
/// A wrapper that may be moved and shared between threads, but only exposes its
/// contents on the thread that created it (the "origin thread").
pub struct ThreadSafe<T: ?Sized> {
    /// Cached `mem::needs_drop::<T>()`: when `false`, dropping the wrapper on a foreign
    /// thread is allowed because there is no destructor to run.
    handle_drop: bool,
    /// Id of the thread the wrapper was created on.
    origin_thread: ThreadId,
    /// The wrapped value. `ManuallyDrop` lets the `Drop` impl decide whether and where
    /// `T`'s destructor runs. Must stay the last field so `T` may be unsized.
    inner: ManuallyDrop<T>,
}
impl<T: Default> Default for ThreadSafe<T> {
#[inline]
fn default() -> Self {
Self {
inner: ManuallyDrop::new(T::default()),
handle_drop: mem::needs_drop::<T>(),
origin_thread: thread::current().id(),
}
}
}
impl<T: fmt::Debug + ?Sized> fmt::Debug for ThreadSafe<T> {
    /// Formats the wrapped value when called on the origin thread. On any other thread
    /// it prints a placeholder instead of touching the value.
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        if self.origin_thread == thread::current().id() {
            // Deref through `ManuallyDrop` so the output is `T`'s own `Debug` form rather
            // than the wrapper's derived `ManuallyDrop { value: .. }`.
            fmt::Debug::fmt(&*self.inner, f)
        } else {
            f.write_str("<not in origin thread>")
        }
    }
}
// SAFETY: the safe API only reads, writes, or moves out `inner` after checking that the
// caller is on `origin_thread`. The destructor only runs `T`'s drop glue when that check
// passes, or when `T` has no drop glue at all. Everything else is behind `unsafe` methods
// that shift the obligation to the caller. Moving the wrapper to another thread therefore
// never lets that thread touch `T`. This holds for unsized `T` as well, because the
// accessors, `Debug` and `Drop` all perform the same check for `T: ?Sized`.
unsafe impl<T: ?Sized> Send for ThreadSafe<T> {}
// SAFETY: a `&ThreadSafe<T>` on a foreign thread can only observe the origin-thread
// check failing, so sharing references across threads cannot give concurrent access to `T`.
unsafe impl<T: ?Sized> Sync for ThreadSafe<T> {}
impl<T> ThreadSafe<T> {
    /// Wraps `inner`, recording the calling thread as its origin thread.
    #[inline]
    pub fn new(inner: T) -> ThreadSafe<T> {
        ThreadSafe {
            origin_thread: thread::current().id(),
            handle_drop: mem::needs_drop::<T>(),
            inner: ManuallyDrop::new(inner),
        }
    }

    /// Unwraps the value if called on the origin thread.
    ///
    /// # Errors
    /// Returns the wrapper unchanged when called from any other thread.
    #[inline]
    pub fn try_into_inner(self) -> Result<T, ThreadSafe<T>> {
        self.try_into_inner_with_key(ThreadKey::get())
    }

    /// Like [`ThreadSafe::try_into_inner`], using an already obtained `key`.
    ///
    /// # Errors
    /// Returns the wrapper unchanged when `key` does not belong to the origin thread.
    #[inline]
    pub fn try_into_inner_with_key(mut self, key: ThreadKey) -> Result<T, ThreadSafe<T>> {
        if self.origin_thread == key.id() {
            // SAFETY: `inner` is taken exactly once, and `self` is forgotten right after,
            // so `Drop` never sees the moved-out value.
            let inner = unsafe { ManuallyDrop::take(&mut self.inner) };
            mem::forget(self);
            Ok(inner)
        } else {
            Err(self)
        }
    }

    /// Unwraps the value.
    ///
    /// # Panics
    /// Panics when called from a thread other than the origin thread. The wrapped value
    /// is leaked in that case.
    #[inline]
    pub fn into_inner(self) -> T {
        self.into_inner_with_key(ThreadKey::get())
    }

    /// Like [`ThreadSafe::into_inner`], using an already obtained `key`.
    ///
    /// # Panics
    /// Panics when `key` does not belong to the origin thread. The wrapped value is
    /// leaked in that case.
    #[inline]
    pub fn into_inner_with_key(self, key: ThreadKey) -> T {
        match self.try_into_inner_with_key(key) {
            Ok(inner) => inner,
            Err(this) => {
                // If `this` were still alive during the unwind below, its `Drop` would run
                // on a foreign thread and panic a second time, which aborts the process.
                mem::forget(this);
                panic!("Attempted to use a ThreadSafe outside of its origin thread")
            }
        }
    }

    /// Unwraps the value without checking the current thread.
    ///
    /// # Safety
    /// The caller must guarantee that taking ownership of `T` on the current thread is
    /// sound. Typically this means being on the origin thread.
    #[inline]
    pub unsafe fn into_inner_unchecked(mut self) -> T {
        let inner = ManuallyDrop::take(&mut self.inner);
        mem::forget(self);
        inner
    }
}
impl<T: ?Sized> ThreadSafe<T> {
    /// Borrows the wrapped value if called on the origin thread.
    ///
    /// # Errors
    /// Returns [`NotInOriginThread`] when called from any other thread.
    #[inline]
    pub fn try_get_ref(&self) -> Result<&T, NotInOriginThread> {
        let key = ThreadKey::get();
        self.try_get_ref_with_key(key)
    }

    /// Like [`ThreadSafe::try_get_ref`], using an already obtained `key`.
    ///
    /// # Errors
    /// Returns [`NotInOriginThread`] when `key` does not belong to the origin thread.
    #[inline]
    pub fn try_get_ref_with_key(&self, key: ThreadKey) -> Result<&T, NotInOriginThread> {
        if self.origin_thread != key.id() {
            return Err(NotInOriginThread);
        }
        Ok(&*self.inner)
    }

    /// Borrows the wrapped value.
    ///
    /// # Panics
    /// Panics when called from a thread other than the origin thread.
    #[inline]
    pub fn get_ref(&self) -> &T {
        self.get_ref_with_key(ThreadKey::get())
    }

    /// Like [`ThreadSafe::get_ref`], using an already obtained `key`.
    ///
    /// # Panics
    /// Panics when `key` does not belong to the origin thread.
    #[inline]
    pub fn get_ref_with_key(&self, key: ThreadKey) -> &T {
        self.try_get_ref_with_key(key).unwrap_or_else(|NotInOriginThread| {
            panic!("Attempted to use a ThreadSafe outside of its origin thread")
        })
    }

    /// Borrows the wrapped value without checking the current thread.
    ///
    /// # Safety
    /// The caller must guarantee that accessing `T` from the current thread is sound.
    /// Typically this means being on the origin thread.
    #[inline]
    pub unsafe fn get_ref_unchecked(&self) -> &T {
        &*self.inner
    }

    /// Mutably borrows the wrapped value if called on the origin thread.
    ///
    /// # Errors
    /// Returns [`NotInOriginThread`] when called from any other thread.
    #[inline]
    pub fn try_get_mut(&mut self) -> Result<&mut T, NotInOriginThread> {
        let key = ThreadKey::get();
        self.try_get_mut_with_key(key)
    }

    /// Like [`ThreadSafe::try_get_mut`], using an already obtained `key`.
    ///
    /// # Errors
    /// Returns [`NotInOriginThread`] when `key` does not belong to the origin thread.
    #[inline]
    pub fn try_get_mut_with_key(&mut self, key: ThreadKey) -> Result<&mut T, NotInOriginThread> {
        if self.origin_thread != key.id() {
            return Err(NotInOriginThread);
        }
        Ok(&mut *self.inner)
    }

    /// Mutably borrows the wrapped value.
    ///
    /// # Panics
    /// Panics when called from a thread other than the origin thread.
    #[inline]
    pub fn get_mut(&mut self) -> &mut T {
        self.get_mut_with_key(ThreadKey::get())
    }

    /// Like [`ThreadSafe::get_mut`], using an already obtained `key`.
    ///
    /// # Panics
    /// Panics when `key` does not belong to the origin thread.
    #[inline]
    pub fn get_mut_with_key(&mut self, key: ThreadKey) -> &mut T {
        self.try_get_mut_with_key(key).unwrap_or_else(|NotInOriginThread| {
            panic!("Attempted to use a ThreadSafe outside of its origin thread")
        })
    }

    /// Mutably borrows the wrapped value without checking the current thread.
    ///
    /// # Safety
    /// The caller must guarantee that accessing `T` from the current thread is sound.
    /// Typically this means being on the origin thread.
    #[inline]
    pub unsafe fn get_mut_unchecked(&mut self) -> &mut T {
        &mut *self.inner
    }
}
impl<T: Clone> ThreadSafe<T> {
    /// Clones the wrapper if called on the origin thread. The clone keeps the same
    /// origin thread.
    ///
    /// # Errors
    /// Returns [`NotInOriginThread`] when called from any other thread.
    #[inline]
    pub fn try_clone(&self) -> Result<ThreadSafe<T>, NotInOriginThread> {
        let key = ThreadKey::get();
        self.try_clone_with_key(key)
    }

    /// Like [`ThreadSafe::try_clone`], using an already obtained `key`.
    ///
    /// # Errors
    /// Returns [`NotInOriginThread`] when `key` does not belong to the origin thread.
    #[inline]
    pub fn try_clone_with_key(&self, key: ThreadKey) -> Result<ThreadSafe<T>, NotInOriginThread> {
        self.try_get_ref_with_key(key).map(|value| ThreadSafe {
            inner: ManuallyDrop::new(value.clone()),
            handle_drop: self.handle_drop,
            origin_thread: self.origin_thread,
        })
    }

    /// Clones the wrapper using an already obtained `key`.
    ///
    /// # Panics
    /// Panics when `key` does not belong to the origin thread.
    #[inline]
    pub fn clone_with_key(&self, key: ThreadKey) -> ThreadSafe<T> {
        match self.try_clone_with_key(key) {
            Ok(copy) => copy,
            Err(NotInOriginThread) => {
                panic!("Attempted to use a ThreadSafe outside of its origin thread")
            }
        }
    }
}
impl<T: Clone> Clone for ThreadSafe<T> {
#[inline]
fn clone(&self) -> ThreadSafe<T> {
self.clone_with_key(ThreadKey::get())
}
}
impl<T: ?Sized> Drop for ThreadSafe<T> {
    /// Runs `T`'s destructor, but only on the origin thread when `T` needs dropping.
    ///
    /// # Panics
    /// Panics when a value that needs dropping is dropped on a foreign thread. If that
    /// thread is already panicking, the value is leaked instead, because a second panic
    /// would abort the process.
    #[inline]
    fn drop(&mut self) {
        if self.handle_drop && self.origin_thread != thread::current().id() {
            if thread::panicking() {
                // Leak rather than double-panic (which aborts) while unwinding.
                return;
            }
            panic!("Attempted to drop ThreadSafe<_> outside of its origin thread");
        }
        // SAFETY: `inner` is never used again after this point. Either we are on the
        // origin thread, or `T` has no drop glue, which makes this a no-op.
        unsafe { ManuallyDrop::drop(&mut self.inner) };
    }
}
/// Proof of which thread the holder is running on.
///
/// A key comes from [`ThreadKey::get`] and can be passed to the `*_with_key` methods so the
/// current thread id does not have to be looked up again. The `PhantomData<Rc<_>>` marker
/// makes the key neither `Send` nor `Sync`, so safe code cannot move or share a key
/// obtained on one thread to another.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
pub struct ThreadKey {
    id: ThreadId,
    _phantom: PhantomData<Rc<ThreadId>>,
}

impl Default for ThreadKey {
    /// Equivalent to [`ThreadKey::get`].
    #[inline]
    fn default() -> Self {
        ThreadKey::get()
    }
}

impl ThreadKey {
    /// Returns the key for the calling thread.
    ///
    /// The id is cached in a thread-local. If that thread-local can no longer be accessed
    /// (e.g. while thread-locals are being destroyed), the id is looked up through
    /// `thread::current()` instead.
    #[inline]
    pub fn get() -> Self {
        thread_local! {
            static CACHED_ID: ThreadId = thread::current().id();
        }
        let id = match CACHED_ID.try_with(|cached| *cached) {
            Ok(cached) => cached,
            Err(_) => thread::current().id(),
        };
        ThreadKey { id, _phantom: PhantomData }
    }

    /// Builds a key that claims to belong to the thread `id`.
    ///
    /// # Safety
    /// The `*_with_key` methods trust the key's id. The caller must ensure `id` is the id
    /// of the thread the key is used on. Otherwise the origin-thread checks can be bypassed.
    #[inline]
    pub unsafe fn new(id: ThreadId) -> Self {
        ThreadKey { id, _phantom: PhantomData }
    }

    /// Returns the thread id this key represents.
    #[inline]
    pub fn id(self) -> ThreadId {
        let ThreadKey { id, .. } = self;
        id
    }
}

impl From<ThreadKey> for ThreadId {
    #[inline]
    fn from(key: ThreadKey) -> ThreadId {
        key.id()
    }
}
/// Error returned by the fallible accessors when they are used on a thread other than the
/// one the value was created on.
#[derive(Clone, Copy, Debug, Default, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct NotInOriginThread;

impl fmt::Display for NotInOriginThread {
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        const MESSAGE: &str = "Attempted to use ThreadSafe<_> outside of its origin thread";
        f.write_str(MESSAGE)
    }
}

impl Error for NotInOriginThread {}