#![cfg_attr(docsrs, feature(doc_cfg))]
#![cfg_attr(
all(not(loom), feature = "current_thread_id"),
feature(current_thread_id)
)]
#![warn(missing_docs)]
#[cfg(feature = "futures")]
#[cfg_attr(docsrs, doc(cfg(feature = "futures")))]
mod futures;
use std::{
fmt,
mem::{self, ManuallyDrop},
pin::Pin,
};
// Brings `current_id()`, the `thread` module and `ThreadId` into scope for the rest of this
// file. Under loom, or when the `current_thread_id` feature is disabled, the calling thread's
// id is cached in a thread-local; otherwise the unstable `std::thread::current_id` (enabled by
// the crate-level `feature(current_thread_id)` attribute) is imported directly.
cfg_if::cfg_if! {
if #[cfg(any(loom, not(feature = "current_thread_id")))] {
// Under loom, take the thread-local, `Cell` and thread APIs from loom instead of std.
#[cfg(loom)]
use loom::{thread_local, cell::Cell, thread::{self, ThreadId}};
#[cfg(not(loom))]
use std::{thread_local, cell::Cell, thread::{self, ThreadId}};
thread_local! {
// Lazily initialised, per thread, from `thread::current().id()`; never written afterwards.
static THREAD_ID: Cell<ThreadId> = Cell::new(thread::current().id());
}
// Returns the id of the calling thread, read from the cached thread-local.
pub(crate) fn current_id() -> ThreadId {
THREAD_ID.with(|id| id.get())
}
} else {
// `thread` stays imported because other code in this file uses it (`thread::panicking()`).
use std::thread::{self, current_id, ThreadId};
}
}
/// A wrapper that can be moved to, and shared with, other threads even when `T` itself is
/// neither `Send` nor `Sync`, while only allowing access to `T` on the thread that created it.
///
/// The checked accessors ([`get`](Self::get), [`get_mut`](Self::get_mut),
/// [`take`](Self::take), ...) return `None` or panic when used from any other thread.
/// Dropping the wrapper on another thread panics, unless `T` has no drop glue or that
/// thread is already panicking.
pub struct SendWrapper<T> {
    /// The wrapped value. It is held in `ManuallyDrop` so that `Drop` can decide whether
    /// `T`'s destructor may run on the current thread.
    data: ManuallyDrop<T>,
    /// Identifier of the thread that created the wrapper.
    thread_id: ThreadId,
}
impl<T> SendWrapper<T> {
    /// Wraps `data` and records the calling thread as the only thread allowed to access it.
    #[inline]
    pub fn new(data: T) -> SendWrapper<T> {
        SendWrapper {
            data: ManuallyDrop::new(data),
            thread_id: current_id(),
        }
    }

    /// Returns `true` if the calling thread is the thread that created this wrapper.
    ///
    /// When this returns `true`, the checked accessors ([`get`](Self::get),
    /// [`get_mut`](Self::get_mut), [`take`](Self::take), ...) succeed.
    #[inline]
    pub fn valid(&self) -> bool {
        self.thread_id == current_id()
    }

    /// Moves the wrapped value out of the wrapper without checking the calling thread.
    ///
    /// The wrapper's `Drop` implementation does not run, so this never panics.
    ///
    /// # Safety
    ///
    /// The caller must ensure it is sound to take ownership of the value on the calling
    /// thread. Normally this means checking that [`valid`](Self::valid) returns `true`.
    pub unsafe fn take_unchecked(self) -> T {
        // Wrap `self` so that `SendWrapper::drop` never runs and cannot touch `data` afterwards.
        let mut this = ManuallyDrop::new(self);
        // SAFETY: `this` is never used or dropped again, so `data` is taken exactly once.
        unsafe { ManuallyDrop::take(&mut this.data) }
    }

    /// Moves the wrapped value out of the wrapper.
    ///
    /// # Panics
    ///
    /// Panics if called from a thread other than the one that created the wrapper. While
    /// unwinding from that panic, the wrapper is dropped on the foreign thread. A `T` with
    /// drop glue is then leaked instead of dropped, and no second panic is raised.
    #[track_caller]
    pub fn take(self) -> T {
        if self.valid() {
            // SAFETY: `valid()` just confirmed that we are on the creating thread.
            unsafe { self.take_unchecked() }
        } else {
            invalid_deref()
        }
    }

    /// Returns a shared reference to the wrapped value without checking the calling thread.
    ///
    /// # Safety
    ///
    /// The caller must ensure that accessing the value from the calling thread is sound.
    /// Normally this means checking that [`valid`](Self::valid) returns `true`.
    #[inline]
    pub unsafe fn get_unchecked(&self) -> &T {
        &self.data
    }

    /// Returns a mutable reference to the wrapped value without checking the calling thread.
    ///
    /// # Safety
    ///
    /// The caller must ensure that accessing the value from the calling thread is sound.
    /// Normally this means checking that [`valid`](Self::valid) returns `true`.
    #[inline]
    pub unsafe fn get_unchecked_mut(&mut self) -> &mut T {
        &mut self.data
    }

    // NOTE(review): the pinned accessors below project the pin structurally onto `data`.
    // When a wrapper around a `T` with drop glue is dropped on a foreign thread, `Drop`
    // never drops `data`: it either panics or, while already panicking, returns silently.
    // A pinned `!Unpin` value may therefore have its storage released without its destructor
    // having run. Confirm this is acceptable under `Pin`'s drop guarantee.

    /// Returns a pinned shared reference to the wrapped value without checking the calling
    /// thread.
    ///
    /// # Safety
    ///
    /// Same requirement as [`get_unchecked`](Self::get_unchecked): accessing the value from
    /// the calling thread must be sound.
    #[inline]
    pub unsafe fn get_unchecked_pinned(self: Pin<&Self>) -> Pin<&T> {
        // SAFETY: `data` is only moved out by `take`/`take_unchecked`, which need an owned
        // (hence unpinned) wrapper, so projecting the pin onto it is sound. The thread
        // requirement is forwarded to our caller.
        unsafe { self.map_unchecked(|s| &*s.data) }
    }

    /// Returns a pinned mutable reference to the wrapped value without checking the calling
    /// thread.
    ///
    /// # Safety
    ///
    /// Same requirement as [`get_unchecked_mut`](Self::get_unchecked_mut): accessing the
    /// value from the calling thread must be sound.
    #[inline]
    pub unsafe fn get_unchecked_pinned_mut(self: Pin<&mut Self>) -> Pin<&mut T> {
        // SAFETY: the closure only re-borrows `data` and never moves it. The pin stays
        // structural, as for `get_unchecked_pinned`, and the thread requirement is
        // forwarded to our caller.
        unsafe { self.map_unchecked_mut(|s| &mut *s.data) }
    }

    /// Returns a shared reference to the wrapped value, or `None` when called from a thread
    /// other than the one that created the wrapper.
    #[inline]
    pub fn get(&self) -> Option<&T> {
        if self.valid() { Some(&self.data) } else { None }
    }

    /// Returns a mutable reference to the wrapped value, or `None` when called from a thread
    /// other than the one that created the wrapper.
    #[inline]
    pub fn get_mut(&mut self) -> Option<&mut T> {
        if self.valid() {
            Some(&mut self.data)
        } else {
            None
        }
    }

    /// Returns a pinned shared reference to the wrapped value, or `None` when called from a
    /// thread other than the one that created the wrapper.
    #[inline]
    pub fn get_pinned(self: Pin<&Self>) -> Option<Pin<&T>> {
        if self.valid() {
            // SAFETY: `valid()` just confirmed that we are on the creating thread.
            Some(unsafe { self.get_unchecked_pinned() })
        } else {
            None
        }
    }

    /// Returns a pinned mutable reference to the wrapped value, or `None` when called from a
    /// thread other than the one that created the wrapper.
    #[inline]
    pub fn get_pinned_mut(self: Pin<&mut Self>) -> Option<Pin<&mut T>> {
        if self.valid() {
            // SAFETY: `valid()` just confirmed that we are on the creating thread.
            Some(unsafe { self.get_unchecked_pinned_mut() })
        } else {
            None
        }
    }

    /// Returns a data-less `SendWrapper<()>` bound to the same thread as `self`.
    ///
    /// The tracker's [`valid`](Self::valid) always agrees with `self.valid()`. Because `()`
    /// has no drop glue, the tracker can be dropped on any thread without panicking.
    #[inline]
    pub fn tracker(&self) -> SendWrapper<()> {
        SendWrapper {
            data: ManuallyDrop::new(()),
            thread_id: self.thread_id,
        }
    }
}
// SAFETY: every safe path to the inner value (`get`, `get_mut`, `take`, `Debug`, `Clone`, ...)
// first checks that the calling thread is the creating thread. Sharing `&SendWrapper<T>`
// across threads therefore cannot expose `T` to a foreign thread without `unsafe`.
unsafe impl<T> Sync for SendWrapper<T> {}
// SAFETY: moving the wrapper only moves the bytes of `T`. `Drop` refuses to run `T`'s
// destructor on a foreign thread, and all safe access is thread-checked as above.
unsafe impl<T> Send for SendWrapper<T> {}
impl<T> Drop for SendWrapper<T> {
    /// Drops the wrapped value, if that is allowed on the current thread.
    ///
    /// # Panics
    ///
    /// Panics when `T` has drop glue and the wrapper is dropped on a thread other than the
    /// one that created it. If that thread is already panicking, it returns silently instead.
    /// In both cases the wrapped value is not dropped.
    #[track_caller]
    fn drop(&mut self) {
        // A `T` without drop glue runs no code when dropped, so any thread may drop it.
        let may_drop_here = !mem::needs_drop::<T>() || self.valid();
        if !may_drop_here {
            invalid_drop();
            return;
        }
        // SAFETY: either `T` has no drop glue or we are on the creating thread. `data` is
        // dropped only here, because `take_unchecked` suppresses this `Drop` entirely.
        unsafe { ManuallyDrop::drop(&mut self.data) }
    }
}
impl<T: fmt::Debug> fmt::Debug for SendWrapper<T> {
    /// Formats as `SendWrapper { data: .., thread_id: .. }`.
    ///
    /// The wrapped value is only formatted on the creating thread. On any other thread the
    /// placeholder string `"<invalid>"` is shown instead, so this never panics.
    #[track_caller]
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let shown: &dyn fmt::Debug = match self.get() {
            Some(value) => value,
            None => &"<invalid>",
        };
        f.debug_struct("SendWrapper")
            .field("data", shown)
            .field("thread_id", &self.thread_id)
            .finish()
    }
}
impl<T: Clone> Clone for SendWrapper<T> {
    /// Clones the wrapped value into a new wrapper bound to the calling thread. On success,
    /// that is necessarily the thread that created `self`.
    ///
    /// # Panics
    ///
    /// Panics if called from a thread other than the one that created the wrapper.
    #[track_caller]
    fn clone(&self) -> Self {
        // Call `invalid_deref` directly rather than from an `unwrap_or_else` closure.
        // Closures are not `#[track_caller]`, so the reported panic location would otherwise
        // point at this file instead of the caller of `clone`.
        match self.get() {
            Some(data) => Self::new(data.clone()),
            None => invalid_deref(),
        }
    }
}
/// Panics with the error reported when a `SendWrapper` is accessed from a thread other than
/// the one that created it. The panic location is attributed to the caller.
#[cold]
#[inline(never)]
#[track_caller]
fn invalid_deref() -> ! {
    const MESSAGE: &str = concat!(
        "Accessed SendWrapper<T> variable from a thread different to the ",
        "one it has been created with."
    );
    panic!("{}", MESSAGE)
}
/// Panics with the error reported when a wrapped future is polled from a thread other than
/// the one that created it. The panic location is attributed to the caller.
#[cold]
#[inline(never)]
#[track_caller]
#[cfg(feature = "futures")]
fn invalid_poll() -> ! {
    const MESSAGE: &str = concat!(
        "Polling SendWrapper<T> variable from a thread different to the one ",
        "it has been created with."
    );
    panic!("{}", MESSAGE)
}
/// Panics with the error reported when a `SendWrapper` is dropped on a thread other than the
/// one that created it, unless the current thread is already panicking.
#[cold]
#[inline(never)]
#[track_caller]
fn invalid_drop() {
    const MESSAGE: &str = concat!(
        "Dropped SendWrapper<T> variable from a thread different to the one ",
        "it has been created with."
    );
    // Panicking again while already unwinding would abort the process, so stay silent.
    if thread::panicking() {
        return;
    }
    panic!("{}", MESSAGE)
}
#[cfg(test)]
mod tests {
    use std::{
        pin::Pin,
        rc::Rc,
        sync::{Arc, mpsc::channel},
        thread,
    };

    use super::SendWrapper;

    #[test]
    fn get_and_get_mut_on_creator_thread_and_pinned_variants() {
        let mut wrapper = SendWrapper::new(1_i32);
        let r = wrapper.get();
        assert!(r.is_some());
        assert_eq!(*r.unwrap(), 1);
        let r_mut = wrapper.get_mut();
        assert!(r_mut.is_some());
        *r_mut.unwrap() = 2;
        let r_after = wrapper.get();
        assert!(r_after.is_some());
        assert_eq!(*r_after.unwrap(), 2);
        let pinned = Pin::new(&wrapper);
        let pinned_ref = pinned.get_pinned();
        assert!(pinned_ref.is_some());
        assert_eq!(*pinned_ref.unwrap(), 2);
        let mut wrapper2 = SendWrapper::new(10_i32);
        let pinned_mut = Pin::new(&mut wrapper2);
        let pinned_mut_ref = pinned_mut.get_pinned_mut();
        assert!(pinned_mut_ref.is_some());
        *pinned_mut_ref.unwrap() = 11;
        let after_mut = wrapper2.get();
        assert!(after_mut.is_some());
        assert_eq!(*after_mut.unwrap(), 11);
    }

    #[test]
    fn accessors_return_none_on_non_creator_thread() {
        let mut wrapper = SendWrapper::new(123_i32);
        let handle = thread::spawn(move || {
            assert!(wrapper.get().is_none());
            assert!(wrapper.get_mut().is_none());
            let pinned = Pin::new(&wrapper);
            assert!(pinned.get_pinned().is_none());
            let mut wrapper = wrapper;
            let pinned_mut = Pin::new(&mut wrapper);
            assert!(pinned_mut.get_pinned_mut().is_none());
        });
        handle.join().unwrap();
    }

    #[test]
    fn test_valid() {
        let (sender, receiver) = channel();
        let w = SendWrapper::new(Rc::new(42));
        assert!(w.valid());
        let t = thread::spawn(move || {
            sender.send(w).unwrap();
        });
        let w2 = receiver.recv().unwrap();
        assert!(w2.valid());
        assert!(t.join().is_ok());
    }

    #[test]
    fn test_invalid() {
        let w = SendWrapper::new(Rc::new(42));
        let t = thread::spawn(move || {
            assert!(!w.valid());
            w
        });
        let join_result = t.join();
        assert!(join_result.is_ok());
    }

    #[test]
    fn test_drop_panic() {
        let w = SendWrapper::new(Rc::new(42));
        let t = thread::spawn(move || {
            drop(w);
        });
        let join_result = t.join();
        assert!(join_result.is_err());
    }

    #[test]
    fn test_take() {
        let w = SendWrapper::new(Rc::new(42));
        let inner: Rc<usize> = w.take();
        assert_eq!(42, *inner);
    }

    #[test]
    fn test_take_panic() {
        let w = SendWrapper::new(Rc::new(42));
        let t = thread::spawn(move || {
            let _ = w.take();
        });
        assert!(t.join().is_err());
    }

    #[test]
    fn test_sync() {
        // `Arc<X>` is only `Send` when `X: Send + Sync`. Wrapping a non-`Send`, non-`Sync`
        // `Rc` therefore checks the wrapper's own `Send`/`Sync` impls at compile time.
        let arc = Arc::new(SendWrapper::new(Rc::new(42)));
        let shared = Arc::clone(&arc);
        // `shared` must actually be used inside the closure. A bare `let _ = arc;` is not a
        // capture under the Rust 2021+ closure-capture rules, so nothing would cross threads.
        let t = thread::spawn(move || {
            assert!(!shared.valid());
        });
        assert!(t.join().is_ok());
        // The main thread still holds the last reference, so the wrapper is dropped here, on
        // its creating thread.
        assert!(arc.valid());
    }

    #[test]
    fn test_debug() {
        let w = SendWrapper::new(Rc::new(42));
        let info = format!("{:?}", w);
        assert!(info.contains("SendWrapper {"));
        assert!(info.contains("data: 42,"));
        assert!(info.contains("thread_id: ThreadId("));
    }

    #[test]
    fn test_debug_invalid() {
        let w = SendWrapper::new(Rc::new(42));
        let t = thread::spawn(move || {
            let info = format!("{:?}", w);
            assert!(info.contains("SendWrapper {"));
            assert!(info.contains("data: \"<invalid>\","));
            assert!(info.contains("thread_id: ThreadId("));
            w
        });
        assert!(t.join().is_ok());
    }

    #[test]
    fn test_clone() {
        let w1 = SendWrapper::new(Rc::new(42));
        let w2 = w1.clone();
        assert_eq!(format!("{:?}", w1), format!("{:?}", w2));
    }

    #[test]
    fn test_clone_panic() {
        let w = SendWrapper::new(Rc::new(42));
        let t = thread::spawn(move || {
            let _ = w.clone();
        });
        assert!(t.join().is_err());
    }
}