use std::borrow::Borrow;
use std::cell::{Cell, RefCell};
use std::collections::{HashMap, HashSet};
use std::fmt::{Debug, Display};
use std::hash::{Hash, Hasher};
use std::ops::Deref;
use std::ptr::NonNull;
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use crate::thread_id::current_thread;
// SAFETY: a `SendRc` may be moved to another thread because every access path
// (`check_pinned`, called first by deref/clone/drop/strong_count/try_unwrap/
// get_mut) compares the atomic `pinned_to` against the current thread and
// refuses on mismatch *before* reading any of the non-atomic `Cell`s. Values are
// only handed over as a whole via `PreSend`/`PostSend`. `T: Send` is required
// because the value itself ends up being used on the receiving thread.
unsafe impl<T> Send for SendRc<T> where T: Send {}
/// Heap allocation shared by all clones of one `SendRc` value.
struct Inner<T> {
/// Id (from `current_thread()`) of the only thread allowed to access the value;
/// set to 0 by `PreSend::ready()` while the value is in transit.
/// NOTE(review): relies on `current_thread()` never returning 0 — confirm in `thread_id`.
pinned_to: AtomicU64,
/// 0 when not parked; otherwise the `pre_send_id` of the `PreSend` parking it.
parking: Cell<u64>,
/// The shared value.
val: T,
/// Number of live `SendRc` handles pointing at this allocation.
strong_count: Cell<usize>,
}
/// Source of per-handle ids (`SendRc::id`); `PreSend` uses them to recognise the
/// same handle being parked more than once.
static NEXT_SENDRC_ID: AtomicUsize = AtomicUsize::new(0);
/// Source of `PreSend` ids. Starts at 1 because 0 in `Inner::parking` means
/// "not parked".
static NEXT_PRE_SEND_ID: AtomicU64 = AtomicU64::new(1);
/// Reference-counted pointer, like `Rc`, that can be moved to another thread
/// together with all of its clones by going through `SendRc::pre_send()`.
pub struct SendRc<T> {
/// Shared allocation, created by `Box::into_raw` in `SendRc::new`.
ptr: NonNull<Inner<T>>,
/// Unique id of this handle (not of the value); every clone gets a new one.
id: usize,
}
/// Why `SendRc::check_pinned` refused an access.
enum PinError {
    /// The value is pinned to a different thread (or is in transit between threads).
    BadThread,
    /// The value has been parked by a `PreSend` and is about to be sent away.
    Parking,
}

impl PinError {
    /// Human-readable explanation used as the tail of the panic message.
    fn msg(&self) -> &'static str {
        match *self {
            Self::BadThread => "SendRc accessed from wrong thread; call pre_send() first",
            Self::Parking => "access to SendRc that is about to be sent to a new thread",
        }
    }

    /// Panics with `"<what>: <msg>"`, where `what` names the refused operation.
    fn panic(&self, what: &str) -> ! {
        let reason = self.msg();
        panic!("{}: {}", what, reason)
    }
}
impl<T> SendRc<T> {
pub fn new(val: T) -> Self {
let ptr = Box::into_raw(Box::new(Inner {
pinned_to: AtomicU64::new(current_thread()),
parking: Cell::new(0),
val,
strong_count: Cell::new(1),
}));
SendRc::from_inner_ptr(ptr)
}
fn from_inner_ptr(ptr: *mut Inner<T>) -> Self {
SendRc {
ptr: NonNull::new(ptr).unwrap(),
id: NEXT_SENDRC_ID.fetch_add(1, Ordering::Relaxed),
}
}
fn inner(&self) -> &Inner<T> {
unsafe { self.ptr.as_ref() }
}
unsafe fn inner_mut(&mut self) -> &mut Inner<T> {
unsafe { self.ptr.as_mut() }
}
#[inline]
fn check_pinned(&self) -> Result<(), PinError> {
let inner = self.inner();
if inner.pinned_to.load(Ordering::Relaxed) != current_thread() {
return Err(PinError::BadThread);
}
if inner.parking.get() != 0 {
return Err(PinError::Parking);
}
Ok(())
}
#[inline]
fn assert_pinned(&self, op: &str) {
self.check_pinned()
.unwrap_or_else(|pinerr| pinerr.panic(op));
}
pub fn pre_send() -> PreSend<T> {
PreSend {
parked: Default::default(),
pre_send_id: NEXT_PRE_SEND_ID.fetch_add(1, Ordering::Relaxed),
}
}
pub fn strong_count(this: &Self) -> usize {
this.assert_pinned("SendRc::strong_count()");
this.inner().strong_count.get()
}
pub fn try_unwrap(this: Self) -> Result<T, Self> {
this.assert_pinned("SendRc::try_unwrap()");
if this.inner().strong_count.get() == 1 {
let inner_box = unsafe { Box::from_raw(this.ptr.as_ptr()) };
std::mem::forget(this);
Ok(inner_box.val)
} else {
Err(this)
}
}
pub fn get_mut(this: &mut Self) -> Option<&mut T> {
this.assert_pinned("SendRc::get_mut()");
if this.inner().strong_count.get() == 1 {
unsafe { Some(&mut this.inner_mut().val) }
} else {
None
}
}
pub fn ptr_eq(this: &Self, other: &Self) -> bool {
this.ptr == other.ptr
}
}
/// Collects ("parks") every `SendRc` whose value is about to be moved to
/// another thread. Obtained from [`SendRc::pre_send`] and finished with
/// [`PreSend::ready`].
pub struct PreSend<T> {
    /// Parked handles keyed by their handle id (`SendRc::id`), so parking the
    /// same handle again only overwrites its own entry.
    parked: RefCell<HashMap<usize, NonNull<Inner<T>>>>,
    /// Written into `Inner::parking` of every value parked here. It is taken
    /// from `NEXT_PRE_SEND_ID`, which starts at 1, so it differs from the 0
    /// "not parked" marker.
    pre_send_id: u64,
}

impl<T> PreSend<T> {
    /// Parks `send_rc`, marking its value as about to be sent to another thread.
    ///
    /// Parking the same handle repeatedly in the same `PreSend` is allowed.
    /// Returns a reference to the value: once parked, the handle itself refuses
    /// access. The reference borrows `self`, so it cannot outlive
    /// [`PreSend::ready`], which consumes `self`.
    ///
    /// # Panics
    ///
    /// Panics if `send_rc` is used from the wrong thread, or if its value is
    /// already being parked by a different `PreSend`.
    pub fn park<'a>(&'a self, send_rc: &'a mut SendRc<T>) -> &'a T {
        let inner = send_rc.inner();
        match send_rc.check_pinned() {
            Err(pinerr @ PinError::BadThread) => pinerr.panic("PreSend::park()"),
            Err(PinError::Parking) if inner.parking.get() != self.pre_send_id => {
                panic!("PreSend::park(): call from different PreSend")
            }
            // Already parked by this very `PreSend`: just record the handle.
            Err(PinError::Parking) => {}
            Ok(()) => inner.parking.set(self.pre_send_id),
        }
        self.parked.borrow_mut().insert(send_rc.id, send_rc.ptr);
        &inner.val
    }

    /// Consumes the `PreSend` and releases all parked values for transfer.
    ///
    /// Each parked value is unpinned (its owning thread id set to 0), so no
    /// thread can use it until [`PostSend::unpark`] is called on the receiving
    /// thread.
    ///
    /// # Panics
    ///
    /// Panics if [`PreSend::is_ready`] is `false`, i.e. some handle to a parked
    /// value was not parked.
    pub fn ready(self) -> PostSend<T> {
        assert!(
            self.is_ready(),
            "PreSend::ready() called before all SendRcs have been parked"
        );
        let ptrs: HashSet<NonNull<Inner<T>>> = self.parked.into_inner().into_values().collect();
        ptrs.iter().for_each(|ptr| {
            // SAFETY: a parked value cannot be freed (drop/try_unwrap refuse
            // parked values), so the pointer is still live.
            let inner = unsafe { ptr.as_ref() };
            inner.pinned_to.store(0, Ordering::Relaxed);
        });
        PostSend { ptrs }
    }

    /// Returns `true` when, for every value with at least one parked handle,
    /// the number of distinct parked handles equals the value's strong count,
    /// i.e. no handle to a parked value was left unparked. Trivially `true`
    /// when nothing is parked.
    pub fn is_ready(&self) -> bool {
        let parked = self.parked.borrow();
        let mut handles_per_value: HashMap<NonNull<Inner<T>>, usize> = HashMap::new();
        for &ptr in parked.values() {
            *handles_per_value.entry(ptr).or_insert(0) += 1;
        }
        handles_per_value.into_iter().all(|(ptr, handles)| {
            // SAFETY: parked values cannot be freed while parked (see `ready`).
            let inner = unsafe { ptr.as_ref() };
            handles == inner.strong_count.get()
        })
    }

    /// Reports whether `send_rc` and its value are parked.
    ///
    /// # Panics
    ///
    /// Panics if `send_rc` is used from the wrong thread.
    pub fn park_status_of(&self, send_rc: &SendRc<T>) -> ParkStatus {
        match send_rc.check_pinned() {
            Ok(()) => ParkStatus {
                sendrc_parked: false,
                value_parked: false,
            },
            Err(PinError::Parking) => ParkStatus {
                sendrc_parked: self.parked.borrow().contains_key(&send_rc.id),
                value_parked: true,
            },
            Err(pinerr @ PinError::BadThread) => pinerr.panic("PreSend::park_status_of()"),
        }
    }
}

/// Result of [`PreSend::park_status_of`].
pub struct ParkStatus {
    /// Whether this particular handle has been parked in the queried `PreSend`.
    pub sendrc_parked: bool,
    /// Whether the value behind the handle is currently parked by some
    /// `PreSend` (not necessarily the queried one).
    pub value_parked: bool,
}
/// Values released by [`PreSend::ready`], waiting to be re-pinned by
/// [`PostSend::unpark`] on the thread that receives them. Until then, none of
/// the moved `SendRc`s can be used on any thread.
#[must_use]
pub struct PostSend<T> {
    /// Distinct allocations of all values being handed over.
    ptrs: HashSet<NonNull<Inner<T>>>,
}

// SAFETY: `ready()` left every pointed-to value unpinned, so no thread can touch
// it through a `SendRc` until `unpark` re-pins it on the thread holding this
// `PostSend`. `T: Send` is required because the values move to that thread.
unsafe impl<T: Send> Send for PostSend<T> {}

impl<T> PostSend<T> {
    /// Pins every transferred value to the calling thread and clears its parking
    /// mark, making the moved `SendRc`s usable here.
    pub fn unpark(self) {
        let here = current_thread();
        self.ptrs.into_iter().for_each(|ptr| {
            // SAFETY: parked values cannot be freed while in transit, so the
            // allocation is still live.
            let inner = unsafe { ptr.as_ref() };
            inner.pinned_to.store(here, Ordering::Relaxed);
            inner.parking.set(0);
        });
    }
}
impl<T: Display> Display for SendRc<T> {
    /// Formats the shared value with its own `Display` impl.
    /// Panics (via `Deref`) on the wrong thread or while the value is parked.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let value: &T = self;
        Display::fmt(value, f)
    }
}

impl<T: Debug> Debug for SendRc<T> {
    /// Formats the shared value with its own `Debug` impl; the wrapper itself
    /// is transparent. Panics (via `Deref`) under the same conditions as `Display`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let value: &T = self;
        Debug::fmt(value, f)
    }
}
impl<T> Deref for SendRc<T> {
    type Target = T;

    /// Borrows the shared value.
    ///
    /// # Panics
    ///
    /// Panics if called from a thread other than the one the value is pinned
    /// to, or while the value is being parked by a `PreSend`.
    fn deref(&self) -> &T {
        if let Err(pinerr) = self.check_pinned() {
            pinerr.panic("SendRc::deref()");
        }
        let inner = self.inner();
        &inner.val
    }
}
impl<T> Clone for SendRc<T> {
    /// Creates another handle to the same value and increments the strong count.
    ///
    /// # Panics
    ///
    /// Panics if called from a thread other than the one the value is pinned
    /// to, or while the value is being parked by a `PreSend`.
    ///
    /// # Aborts
    ///
    /// Aborts the process if the strong count would overflow `usize`, like
    /// `std::rc::Rc`. A wrapped-around count would let the value be freed while
    /// handles to it still exist. Overflow is reachable, e.g. on 32-bit targets
    /// by repeatedly `mem::forget`-ing clones.
    fn clone(&self) -> Self {
        self.assert_pinned("SendRc::clone()");
        let inner = self.inner();
        let new_count = match inner.strong_count.get().checked_add(1) {
            Some(count) => count,
            None => std::process::abort(),
        };
        inner.strong_count.set(new_count);
        SendRc::from_inner_ptr(self.ptr.as_ptr())
    }
}
impl<T> Drop for SendRc<T> {
    /// Releases this handle. The value is freed when the last handle goes away.
    ///
    /// # Panics
    ///
    /// Panics if dropped on the wrong thread or while the value is being
    /// parked. The exception is when the thread is already panicking: the
    /// handle is then leaked with its count left untouched.
    fn drop(&mut self) {
        match self.check_pinned() {
            Ok(()) => {}
            // Panicking again while unwinding would abort; leak instead.
            Err(_) if std::thread::panicking() => return,
            Err(pinerr) => pinerr.panic("SendRc::drop()"),
        }
        let count = self.inner().strong_count.get();
        if count != 1 {
            self.inner().strong_count.set(count - 1);
            return;
        }
        // SAFETY: this was the last handle, so nothing else can reach the
        // allocation created by `Box::into_raw` in `SendRc::new`.
        drop(unsafe { Box::from_raw(self.ptr.as_ptr()) });
    }
}
impl<T> AsRef<T> for SendRc<T> {
    /// Same as dereferencing. Panics on the wrong thread or while the value is parked.
    fn as_ref(&self) -> &T {
        &**self
    }
}

impl<T> Borrow<T> for SendRc<T> {
    /// Same as dereferencing. Panics on the wrong thread or while the value is parked.
    fn borrow(&self) -> &T {
        &**self
    }
}
impl<T: Default> Default for SendRc<T> {
fn default() -> SendRc<T> {
SendRc::new(Default::default())
}
}
impl<T: Eq> Eq for SendRc<T> {}
impl<T: PartialEq> PartialEq for SendRc<T> {
fn eq(&self, other: &SendRc<T>) -> bool {
SendRc::ptr_eq(self, other) || **self == **other
}
}
impl<T: PartialOrd> PartialOrd for SendRc<T> {
fn partial_cmp(&self, other: &SendRc<T>) -> Option<std::cmp::Ordering> {
(**self).partial_cmp(&**other)
}
}
impl<T: Ord> Ord for SendRc<T> {
fn cmp(&self, other: &SendRc<T>) -> std::cmp::Ordering {
(**self).cmp(&**other)
}
}
impl<T: Hash> Hash for SendRc<T> {
fn hash<H: Hasher>(&self, state: &mut H) {
(**self).hash(state);
}
}
// Unit tests for SendRc / PreSend / PostSend. Tests that expect a panic on a
// worker thread surface it on the test thread (via `join()` + `unwrap()`, or
// `resume_unwind`) so `#[should_panic]` can match the message.
#[cfg(test)]
mod tests {
use std::cell::RefCell;
use std::rc::Rc;
use std::sync::{Arc, Mutex};
use super::SendRc;
// Clones share one value: a write through one handle is visible via the other.
#[test]
fn trivial() {
let r1 = SendRc::new(RefCell::new(1));
let r2 = SendRc::clone(&r1);
*r1.borrow_mut() = 2;
assert_eq!(*r2.borrow(), 2);
}
// The payload's Drop runs exactly when the last handle is dropped, checked
// with one, two and three handles.
#[test]
fn drops() {
struct Payload(Rc<RefCell<bool>>);
impl Drop for Payload {
fn drop(&mut self) {
*self.0.as_ref().borrow_mut() = true;
}
}
let make = || {
let is_dropped = Rc::new(RefCell::new(false));
let payload = Payload(Rc::clone(&is_dropped));
(SendRc::new(payload), is_dropped)
};
let (r1, is_dropped) = make();
assert!(!*is_dropped.borrow());
drop(r1);
assert!(*is_dropped.borrow());
let (r1, is_dropped) = make();
let r2 = SendRc::clone(&r1);
assert!(!*is_dropped.borrow());
drop(r1);
assert!(!*is_dropped.borrow());
drop(r2);
assert!(*is_dropped.borrow());
let (r1, is_dropped) = make();
let r2 = SendRc::clone(&r1);
let r3 = SendRc::clone(&r1);
assert!(!*is_dropped.borrow());
drop(r1);
assert!(!*is_dropped.borrow());
drop(r2);
assert!(!*is_dropped.borrow());
drop(r3);
assert!(*is_dropped.borrow());
}
// Happy path: park every handle, ready(), move the handles plus the PostSend to
// another thread, unpark() there, then use the handles.
#[test]
fn ok_send() {
let mut r1 = SendRc::new(RefCell::new(1));
let mut r2 = SendRc::clone(&r1);
let pre_send = SendRc::pre_send();
pre_send.park(&mut r1);
pre_send.park(&mut r2);
let post_send = pre_send.ready();
std::thread::spawn(move || {
post_send.unpark();
*r1.borrow_mut() += 1;
assert_eq!(*r2.borrow(), 2);
})
.join()
.unwrap();
}
// Dropping a handle on a thread it was never transferred to (no pre_send)
// panics; the worker's message ends up in the test thread's panic.
#[test]
#[should_panic = "drop()"]
fn missing_pre_send_drop() {
let r = SendRc::new(RefCell::new(1));
std::thread::spawn(move || {
drop(r);
})
.join()
.map_err(|e| e.downcast::<String>().unwrap())
.unwrap();
}
// Dereferencing (here via RefCell::borrow_mut through Deref) on a foreign
// thread without pre_send panics.
#[test]
#[should_panic = "deref()"]
fn missing_pre_send_deref() {
let r1 = SendRc::new(RefCell::new(1));
let r2 = SendRc::clone(&r1);
std::thread::spawn(move || {
*r1.borrow_mut() = 2; assert_eq!(*r2.borrow(), 2);
})
.join()
.map_err(|e| e.downcast::<String>().unwrap())
.unwrap();
}
// ready() must panic when only one of the value's two handles was parked.
#[test]
#[should_panic = "ready() called before"]
fn incomplete_pre_send() {
let mut r1 = SendRc::new(RefCell::new(1));
let _r2 = SendRc::clone(&r1);
let pre_send = SendRc::pre_send();
pre_send.park(&mut r1);
let _ = pre_send.ready(); }
// One value is fully parked (r1, r2) but another is only partially parked
// (q1 without _q2): ready() must still panic.
#[test]
#[should_panic = "before all SendRcs have been parked"]
fn incomplete_pre_send_other_shared_value() {
let mut r1 = SendRc::new(RefCell::new(1));
let mut r2 = SendRc::clone(&r1);
let mut q1 = SendRc::new(RefCell::new(1));
let _q2 = SendRc::clone(&q1);
let pre_send = SendRc::pre_send();
pre_send.park(&mut r1);
pre_send.park(&mut r2);
pre_send.park(&mut q1);
let _ = pre_send.ready(); }
// Parking the same handle twice counts once, so it cannot stand in for the
// unparked clone `_r2`.
#[test]
#[should_panic = "before all SendRcs have been parked"]
fn faked_pre_send_count_reusing_same_ptr() {
let mut r1 = SendRc::new(RefCell::new(1));
let _r2 = SendRc::clone(&r1);
let pre_send = SendRc::pre_send();
pre_send.park(&mut r1);
pre_send.park(&mut r1);
let _ = pre_send.ready();
}
// A value already parked in one PreSend cannot also be parked in another; the
// panic happens at `pre_send2.park`, so the lines after it are not expected to run.
#[test]
#[should_panic = "call from different PreSend"]
fn same_sendrc_different_presend() {
let mut r1 = SendRc::new(RefCell::new(1));
let mut r2 = SendRc::clone(&r1);
let pre_send1 = SendRc::pre_send();
pre_send1.park(&mut r1);
pre_send1.park(&mut r2);
let pre_send2 = SendRc::pre_send();
let _ref1: &RefCell<u32> = pre_send2.park(&mut r1); let post_send = pre_send1.ready();
let t = std::thread::spawn(move || {
post_send.unpark();
let _ref2: &RefCell<u32> = &*r2;
});
t.join().unwrap();
}
// Re-parking the same handle in the same PreSend is harmless; ready() followed
// by unpark() on the same thread also works.
#[test]
fn park_twice_good() {
let mut r1 = SendRc::new(RefCell::new(1));
let mut r2 = SendRc::new(RefCell::new(1));
let pre_send = SendRc::pre_send();
pre_send.park(&mut r1);
pre_send.park(&mut r1);
pre_send.park(&mut r2);
let post_send = pre_send.ready();
post_send.unpark();
}
// Runs in a worker thread and records progress in `state`. The panic must occur
// exactly at `pre_send2.park(&mut r2)`: after state was set to 3, before 4.
#[test]
fn park_twice_bad() {
let state = Arc::new(Mutex::new(0));
let result = std::thread::spawn({
let state = state.clone();
move || {
let mut r1 = SendRc::new(RefCell::new(1));
let mut r2 = SendRc::new(RefCell::new(1));
let pre_send1 = SendRc::pre_send();
let pre_send2 = SendRc::pre_send();
*state.lock().unwrap() = 1;
pre_send1.park(&mut r1);
*state.lock().unwrap() = 2;
pre_send1.park(&mut r2);
*state.lock().unwrap() = 3;
pre_send2.park(&mut r2); *state.lock().unwrap() = 4;
}
})
.join();
assert!(result.is_err());
assert_eq!(*state.lock().unwrap(), 3);
}
// try_unwrap with a single handle returns the value.
#[test]
fn test_try_unwrap_success() {
let r = SendRc::new(RefCell::new(42));
let val = SendRc::try_unwrap(r).unwrap();
assert_eq!(*val.borrow(), 42);
}
// try_unwrap fails (handing the handle back) while another clone exists.
#[test]
fn test_try_unwrap_failure() {
let r1 = SendRc::new(RefCell::new(42));
let r2 = SendRc::clone(&r1);
assert!(SendRc::try_unwrap(r1).is_err());
drop(r2);
}
// Handles parked together must all go to the thread that calls unpark():
// dropping r2 on a different thread panics with the wrong-thread message,
// which is re-raised on the test thread with resume_unwind.
#[test]
#[should_panic = "SendRc accessed from wrong thread"]
fn clones_sent_to_different_threads() {
let mut r1 = SendRc::new(RefCell::new(1));
let mut r2 = SendRc::clone(&r1);
let pre_send = SendRc::pre_send();
pre_send.park(&mut r1);
pre_send.park(&mut r2);
let post_send = pre_send.ready();
let jh1 = std::thread::spawn(move || {
post_send.unpark();
drop(r1);
});
let jh2 = std::thread::spawn(move || {
drop(r2); });
jh1.join().unwrap();
jh2.join()
.map_err(|e| std::panic::resume_unwind(e))
.unwrap();
}
}