use super::{SingleOrMultiThreadDynPtr, SingleOrMultiThreadPtr};
use crate::errors::{Error, ErrorEnum};
use crate::generation::edge::MTEdge;
use crate::generation::{MTGenerationPtr, ObjectState};
use crate::object::{DynMTObjectPtr, DynObjectPtr, MTObjectIntf, MTObjectPtr};
use crate::prelude::GcPtrEq;
use crate::sync::GcMtPtr;
use crate::{GcMemberPtr, GcPtr};
use std::fmt;
use std::ops::Deref;
use std::pin::Pin;
use std::ptr;
#[cfg(not(feature = "single_generation_mt"))]
use crate::generation::MTGeneration;
/// A member pointer owned by a garbage-collected object (its *origin*) that
/// points at a multi-threaded GC object of type `T`.
///
/// It is created either from a multi-threaded origin (`new_mt_ptr`) or from a
/// single-threaded origin (`new_ptr`):
/// - In the multi-threaded case, an edge to the target is registered on the
///   origin's control block.
/// - In the single-threaded case, only the origin's creation backtrace is
///   kept, and the target is held by a plain strong reference.
#[cfg_attr(docsrs, doc(cfg(feature = "multi_thread")))]
pub struct GcMtMemberPtr<T>
where
    T: 'static + Send + Sync,
{
    /// Either the multi-threaded origin together with the pinned edge that
    /// was registered on its control block, or (single-threaded origin) only
    /// the origin's creation backtrace.
    ///
    /// NOTE: field order is drop order. After `Drop::drop` runs, `origin`
    /// (including the edge) is dropped before `ptr`. Keep this order unless
    /// it has been verified to be irrelevant.
    pub(super) origin: SingleOrMultiThreadDynPtr,
    /// Pinned pointer to the target object.
    pub(super) ptr: Pin<MTObjectPtr<T>>,
}
impl<T> GcMtMemberPtr<T>
where
    T: 'static + Send + Sync,
{
    /// Creates a member pointer whose origin is the multi-threaded object
    /// `origin` and whose target is the object behind `sp`.
    ///
    /// `sp` is consumed, and its strong reference is handed over to the new
    /// member pointer. An `MTEdge` to the target's control block is pinned in
    /// a `Box` and registered on the origin's control block. What happens to
    /// the strong reference depends on the generations involved:
    ///
    /// - Origin and target share a generation: the strong count is given back
    ///   with `refcount_dec_no_gc`.
    /// - (without `single_generation_mt`) The target's generation has a larger
    ///   `id`: the strong count is kept, and `Drop` later returns it.
    /// - (without `single_generation_mt`) Otherwise: the two generations are
    ///   merged with `MTGeneration::merge`, and the check is repeated.
    ///
    /// # Panics
    ///
    /// With `single_generation_mt`, panics if `origin` and the target are not
    /// in the same generation.
    pub(crate) fn new_mt_ptr(origin: Pin<DynMTObjectPtr>, sp: GcMtPtr<T>) -> Self {
        // SAFETY: assumes the control block lives inside the pinned target
        // allocation and therefore never moves while the edge refers to it.
        // TODO confirm against the `MTObjectPtr` layout.
        let edge = Box::pin(MTEdge::new(unsafe {
            Pin::new_unchecked(sp.ptr.get_control_block())
        }));
        unsafe {
            // `release` consumes `sp` without giving up its strong reference.
            // The branches below decide whether that count is returned.
            let sp_ptr = GcMtPtr::release(sp);
            #[cfg(feature = "single_generation_mt")]
            let origin_generation = origin.get_generation_ptr();
            // Mutable: the retry loop and the merge below replace it.
            #[cfg(not(feature = "single_generation_mt"))]
            let mut origin_generation = origin.get_generation_ptr();
            #[cfg(feature = "single_generation_mt")]
            {
                // Presumably there is only one generation in this
                // configuration; the assert enforces that both sides share it.
                let sp_ptr_generation = sp_ptr.get_generation_ptr();
                assert!(MTGenerationPtr::ptr_eq(
                    &origin_generation,
                    &sp_ptr_generation
                ));
                // NOTE(review): unlike `Drop`, no generation lock is held
                // around `register_edge` here. Confirm this is intended.
                origin.get_control_block().register_edge(edge.as_ref());
                sp_ptr.get_control_block().refcount_dec_no_gc();
            }
            #[cfg(not(feature = "single_generation_mt"))]
            loop {
                // Read-lock the origin's generation. The origin's generation
                // pointer may change between reading it and acquiring the
                // lock. So re-read it under the lock, and retry until the
                // locked generation is still the current one.
                let origin_generation_lock = loop {
                    let re_read_generation = {
                        let lock = origin_generation.lock_read();
                        let re_read_generation = origin.get_generation_ptr();
                        if lock.same_generation(&re_read_generation) {
                            break lock;
                        }
                        re_read_generation
                    };
                    origin_generation = re_read_generation;
                };
                // Only the origin's generation is locked; the target's
                // generation pointer is read without a lock.
                let sp_ptr_generation = sp_ptr.get_generation_ptr();
                if MTGenerationPtr::ptr_eq(&origin_generation, &sp_ptr_generation) {
                    // Same generation: give the strong count back. The
                    // registered edge now represents the reference.
                    origin.get_control_block().register_edge(edge.as_ref());
                    sp_ptr.get_control_block().refcount_dec_no_gc();
                    break;
                } else if origin_generation.id < sp_ptr_generation.id {
                    // Target in a higher-id generation: register the edge but
                    // keep the strong count. `Drop` returns it via `refcount_dec`.
                    origin.get_control_block().register_edge(edge.as_ref());
                    break;
                } else {
                    // Release the read lock before merging (presumably `merge`
                    // takes generation locks itself; TODO confirm), then retry
                    // with the merged generation.
                    drop(origin_generation_lock);
                    origin_generation =
                        MTGeneration::merge(sp_ptr_generation, origin_generation.clone());
                }
            }
            GcMtMemberPtr {
                origin: SingleOrMultiThreadDynPtr::MultiThread(origin, edge),
                ptr: sp_ptr,
            }
        }
    }
    /// Creates a member pointer whose origin is the single-threaded object
    /// `origin` and whose target is the object behind `sp`.
    ///
    /// No edge is registered. `sp`'s strong reference is kept for the
    /// lifetime of the member pointer and returned in `Drop`. Only the
    /// origin's creation backtrace is stored.
    pub(crate) fn new_ptr(origin: Pin<DynObjectPtr>, sp: GcMtPtr<T>) -> Self {
        unsafe {
            // Take over `sp`'s strong reference without decrementing it.
            let sp_ptr = GcMtPtr::release(sp);
            GcMtMemberPtr {
                origin: SingleOrMultiThreadDynPtr::SingleThread(origin.created_backtrace().clone()),
                ptr: sp_ptr,
            }
        }
    }
    /// Returns a new strong [`GcMtPtr`] to the target.
    ///
    /// # Panics
    ///
    /// Panics if the origin is multi-threaded and its state is
    /// `ObjectState::Expired`. See [`Self::try_as_ptr`] for a fallible variant.
    #[inline]
    pub fn as_ptr(&self) -> GcMtPtr<T> {
        if let SingleOrMultiThreadDynPtr::MultiThread(origin, _) = &self.origin
            && origin.object_state() == ObjectState::Expired
        {
            panic!("cannot dereference member pointers on targets that are unreachable");
        }
        // The count taken here is owned by the returned `GcMtPtr`.
        self.ptr.get_control_block().refcount_inc();
        unsafe { GcMtPtr::new_from_raw(self.ptr.clone()) }
    }
    /// Borrows the target's data.
    ///
    /// # Errors
    ///
    /// Returns `ErrorEnum::OriginExpired`, carrying a creation backtrace, if
    /// the origin is multi-threaded and its state is `ObjectState::Expired`.
    #[inline]
    pub fn try_deref(&self) -> Result<&T, Error> {
        if let SingleOrMultiThreadDynPtr::MultiThread(origin, _) = &self.origin
            && origin.object_state() == ObjectState::Expired
        {
            // NOTE(review): the backtrace comes from `self.origin` here, but
            // from the origin object in `try_as_ptr`. Confirm both are the same.
            return Err(Error::new(ErrorEnum::OriginExpired(
                self.origin.created_backtrace().clone(),
            )));
        }
        Ok(self.ptr.get_data())
    }
    /// Returns a new strong [`GcMtPtr`] to the target.
    ///
    /// # Errors
    ///
    /// Returns `ErrorEnum::OriginExpired`, carrying the origin's creation
    /// backtrace, if the origin is multi-threaded and its state is
    /// `ObjectState::Expired`.
    #[inline]
    pub fn try_as_ptr(&self) -> Result<GcMtPtr<T>, Error> {
        if let SingleOrMultiThreadDynPtr::MultiThread(origin, _) = &self.origin
            && origin.object_state() == ObjectState::Expired
        {
            return Err(Error::new(ErrorEnum::OriginExpired(
                origin.created_backtrace().clone(),
            )));
        }
        // The count taken here is owned by the returned `GcMtPtr`.
        self.ptr.get_control_block().refcount_inc();
        Ok(unsafe { GcMtPtr::new_from_raw(self.ptr.clone()) })
    }
}
/// Identity comparison between a multi-threaded member pointer and a
/// [`GcPtr`], defined as the `GcPtr`-side comparison with the operands swapped.
impl<T> GcPtrEq<GcPtr<T>> for GcMtMemberPtr<T>
where
    T: 'static + Send + Sync,
{
    #[inline]
    fn ptr_eq(this: &Self, other: &GcPtr<T>) -> bool {
        // Reuse the reverse-direction implementation so both directions agree.
        GcPtr::<T>::ptr_eq(other, this)
    }
}
/// Identity comparison between a multi-threaded member pointer and a
/// [`GcMtPtr`], defined as the `GcMtPtr`-side comparison with the operands
/// swapped.
impl<T> GcPtrEq<GcMtPtr<T>> for GcMtMemberPtr<T>
where
    T: 'static + Send + Sync,
{
    #[inline]
    fn ptr_eq(this: &Self, other: &GcMtPtr<T>) -> bool {
        // Reuse the reverse-direction implementation so both directions agree.
        GcMtPtr::<T>::ptr_eq(other, this)
    }
}
/// Identity comparison between a multi-threaded member pointer and a
/// [`GcMemberPtr`].
impl<T> GcPtrEq<GcMemberPtr<T>> for GcMtMemberPtr<T>
where
    T: 'static + Send + Sync,
{
    /// Returns `true` only when `other` holds a multi-threaded target at the
    /// same address as `this`'s target.
    ///
    /// A `GcMemberPtr` with a single-threaded target always compares unequal.
    #[inline]
    fn ptr_eq(this: &Self, other: &GcMemberPtr<T>) -> bool {
        match &other.ptr {
            SingleOrMultiThreadPtr::MultiThread(mt_ptr) => {
                let own_target = &*this.ptr;
                ptr::eq(own_target, &**mt_ptr)
            }
            SingleOrMultiThreadPtr::SingleThread(_) => false,
        }
    }
}
/// Identity comparison between two multi-threaded member pointers.
impl<T> GcPtrEq<GcMtMemberPtr<T>> for GcMtMemberPtr<T>
where
    T: 'static + Send + Sync,
{
    /// Returns `true` when both member pointers target the same object
    /// (address equality, not value equality).
    #[inline]
    fn ptr_eq(this: &Self, other: &GcMtMemberPtr<T>) -> bool {
        let (lhs, rhs) = (&*this.ptr, &*other.ptr);
        ptr::eq(lhs, rhs)
    }
}
/// Identity comparison between a multi-threaded member pointer and a
/// (possibly empty) [`crate::Weak`].
#[cfg(feature = "weak_pointer")]
impl<T> GcPtrEq<crate::Weak<T>> for GcMtMemberPtr<T>
where
    T: 'static + Send + Sync,
{
    /// Returns `true` only when `other` is non-empty and holds a
    /// multi-threaded target at the same address as `this`'s target.
    #[inline]
    fn ptr_eq(this: &Self, other: &crate::Weak<T>) -> bool {
        let Some(target) = other.ptr.as_ref() else {
            return false;
        };
        match target {
            SingleOrMultiThreadPtr::MultiThread(mt_ptr) => ptr::eq(&*this.ptr, &**mt_ptr),
            SingleOrMultiThreadPtr::SingleThread(_) => false,
        }
    }
}
/// Identity comparison between a multi-threaded member pointer and a
/// (possibly empty) [`crate::sync::Weak`].
#[cfg(all(feature = "multi_thread", feature = "weak_pointer"))]
impl<T> GcPtrEq<crate::sync::Weak<T>> for GcMtMemberPtr<T>
where
    T: 'static + Send + Sync,
{
    /// Returns `true` only when `other` is non-empty and points at the same
    /// address as `this`'s target.
    #[inline]
    fn ptr_eq(this: &Self, other: &crate::sync::Weak<T>) -> bool {
        other
            .ptr
            .as_ref()
            .is_some_and(|target| ptr::eq(&*this.ptr, &**target))
    }
}
impl<T> Drop for GcMtMemberPtr<T>
where
    T: 'static + Send + Sync,
{
    /// Releases this member pointer's hold on its target.
    ///
    /// - Single-threaded origin: the strong reference kept by `new_ptr` is
    ///   returned with `refcount_dec`.
    /// - Multi-threaded origin: the edge is deregistered from the origin's
    ///   control block while the origin's generation is read-locked. Then:
    ///   - If origin and target share a generation, the lock is handed to
    ///     `gc_if_zero_refs`.
    ///   - Otherwise the target must be in a generation with a larger `id`.
    ///     That is the case in which `new_mt_ptr` kept a strong reference, so
    ///     the lock is released and that reference is returned with
    ///     `refcount_dec`.
    ///
    /// # Panics
    ///
    /// Panics (via `assert!`) if the generation invariant does not hold.
    /// With `single_generation_mt`, the invariant is that origin and target
    /// are in the same generation.
    #[allow(
        clippy::missing_inline_in_public_items,
        reason = "The drop of a member function does a lot of complicated things. There's probably very little benefit to inlining this, and not inlining it hopefully allows for easier debugging."
    )]
    fn drop(&mut self) {
        match &self.origin {
            SingleOrMultiThreadDynPtr::SingleThread(_) => {
                self.ptr.get_control_block().refcount_dec();
            }
            SingleOrMultiThreadDynPtr::MultiThread(origin, edge) => {
                // Read-lock the origin's generation. Re-read the origin's
                // generation pointer under the lock, and retry until the
                // locked generation is still the current one.
                let mut origin_generation = origin.get_generation_ptr();
                let origin_generation_lock = loop {
                    let re_read_generation = {
                        let lock = origin_generation.lock_read();
                        let re_read_generation = origin.get_generation_ptr();
                        if lock.same_generation(&re_read_generation) {
                            break lock;
                        }
                        re_read_generation
                    };
                    origin_generation = re_read_generation;
                };
                // Only the origin's generation is locked; the target's
                // generation pointer is read without a lock.
                let ptr_generation = self.ptr.get_generation_ptr();
                origin
                    .get_control_block()
                    .deregister_edge(edge.as_ref().get_ref());
                #[cfg(feature = "single_generation_mt")]
                {
                    assert!(MTGenerationPtr::ptr_eq(&origin_generation, &ptr_generation));
                    // The lock is moved into `gc_if_zero_refs`.
                    self.ptr
                        .get_control_block()
                        .gc_if_zero_refs(origin_generation_lock);
                }
                #[cfg(not(feature = "single_generation_mt"))]
                if MTGenerationPtr::ptr_eq(&origin_generation, &ptr_generation) {
                    // Same generation: `new_mt_ptr` gave the strong count back,
                    // so there is nothing to decrement here. The lock is moved
                    // into `gc_if_zero_refs`.
                    self.ptr
                        .get_control_block()
                        .gc_if_zero_refs(origin_generation_lock);
                } else {
                    // Cross-generation edge: `new_mt_ptr` only creates these
                    // when the target's generation id is larger, keeping the
                    // strong count, which is returned here after unlocking.
                    assert!(origin_generation.id < ptr_generation.id);
                    drop(origin_generation_lock);
                    self.ptr.get_control_block().refcount_dec();
                }
            }
        };
    }
}
impl<T> Deref for GcMtMemberPtr<T>
where
    T: 'static + Send + Sync,
{
    type Target = T;

    /// Borrows the target's data.
    ///
    /// # Panics
    ///
    /// Panics when the member pointer belongs to a multi-threaded origin whose
    /// state is `ObjectState::Expired`. Use `try_deref` for a fallible
    /// alternative.
    #[inline]
    fn deref(&self) -> &Self::Target {
        let origin_is_expired = matches!(
            &self.origin,
            SingleOrMultiThreadDynPtr::MultiThread(origin, _)
                if origin.object_state() == ObjectState::Expired
        );
        if origin_is_expired {
            panic!("cannot dereference member pointers on targets that are unreachable");
        }
        self.ptr.get_data()
    }
}
impl<T> fmt::Debug for GcMtMemberPtr<T>
where
T: 'static + Send + Sync + fmt::Debug,
{
#[inline]
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
fmt::Debug::fmt(&**self, f)
}
}
#[cfg(test)]
mod tests {
    use super::GcMtMemberPtr;
    use crate::prelude::GcMemberPtrNew;
    use crate::sync::GenerationRef;
    use crate::sync::{GcMtPtr, Metadata};
    use std::sync::{Arc, Barrier, Mutex, RwLock};
    use std::thread;

    /// Leaf test object. `n` counts live instances: it is incremented in
    /// `Bla::new` and decremented in `Drop`.
    #[derive(Debug)]
    struct Bla {
        n: Arc<Mutex<i32>>,
    }

    /// Test object holding one member pointer to a `Bla`.
    #[derive(Debug)]
    struct BlaParent {
        _bla: GcMtMemberPtr<Bla>,
    }

    impl Bla {
        fn new(n: Arc<Mutex<i32>>) -> GcMtPtr<Bla> {
            *n.lock().unwrap() += 1;
            GcMtPtr::new(|_| Bla { n })
        }
    }

    impl Drop for Bla {
        fn drop(&mut self) {
            *self.n.lock().unwrap() -= 1;
        }
    }

    impl BlaParent {
        /// Allocates a parent whose member pointer, created through the
        /// parent's metadata, targets `bla`.
        fn new(bla: GcMtPtr<Bla>) -> GcMtPtr<BlaParent> {
            GcMtPtr::new(|metadata| BlaParent {
                _bla: metadata.new_pointer(bla),
            })
        }
    }

    /// Test object that can point at another `Cycle`, which allows building
    /// reference cycles. `n` counts live instances, as in `Bla`.
    struct Cycle {
        n: Arc<Mutex<i32>>,
        cptr: RwLock<Option<GcMtMemberPtr<Cycle>>>,
        metadata: Metadata,
    }

    impl Cycle {
        fn new(n: Arc<Mutex<i32>>) -> GcMtPtr<Cycle> {
            *n.lock().unwrap() += 1;
            GcMtPtr::new(|metadata| Cycle {
                n,
                cptr: RwLock::new(None),
                metadata,
            })
        }

        /// Like `new`, but allocates through `generation.make`.
        fn new_with_generation(generation: &GenerationRef, n: Arc<Mutex<i32>>) -> GcMtPtr<Cycle> {
            *n.lock().unwrap() += 1;
            generation.make(|metadata| Cycle {
                n,
                cptr: RwLock::new(None),
                metadata,
            })
        }

        /// Replaces this object's member pointer with one targeting `ptr`.
        fn assign(&self, ptr: GcMtPtr<Cycle>) {
            *self.cptr.write().unwrap() = Some(self.metadata.new_pointer(ptr));
        }
    }

    impl Drop for Cycle {
        fn drop(&mut self) {
            *self.n.lock().unwrap() -= 1;
        }
    }

    #[test]
    #[cfg_attr(
        feature = "single_generation_mt",
        ignore = "In single-generation, any of the other test threads may be running the GC task, making it not run in this function, and thus fail the n=0 check at the end."
    )]
    fn create_pointer() {
        let n = Arc::new(Mutex::new(0));
        let bla = Bla::new(n.clone());
        let bla_parent = BlaParent::new(bla.clone());
        assert_eq!(*n.lock().unwrap(), 1);
        drop(bla);
        assert_eq!(
            *n.lock().unwrap(),
            1,
            "`bla` is to remain live, because `bla_parent` has a member-pointer pointing at it"
        );
        drop(bla_parent);
        assert_eq!(*n.lock().unwrap(), 0);
    }

    #[test]
    #[cfg_attr(
        feature = "single_generation_mt",
        ignore = "In single-generation, any of the other test threads may be running the GC task, making it not run in this function, and thus fail the n=0 check at the end."
    )]
    fn create_cycle() {
        let n = Arc::new(Mutex::new(0));
        let p = Cycle::new(n.clone());
        let q = Cycle::new(n.clone());
        assert_eq!(*n.lock().unwrap(), 2);
        p.assign(q.clone());
        q.assign(p.clone());
        assert_eq!(*n.lock().unwrap(), 2);
        drop(p);
        assert_eq!(
            *n.lock().unwrap(),
            2,
            "`cycle p` is to remain live, because `cycle q` has a member-pointer pointing at it"
        );
        let p = q.cptr.read().unwrap().as_ref().unwrap().as_ptr();
        drop(q);
        assert_eq!(
            *n.lock().unwrap(),
            2,
            "`cycle q` is to remain live, because `cycle p` has a member-pointer pointing at it"
        );
        drop(p);
        assert_eq!(*n.lock().unwrap(), 0);
    }

    #[test]
    #[cfg_attr(
        feature = "single_generation_mt",
        ignore = "In single-generation, any of the other test threads may be running the GC task, making it not run in this function, and thus fail the n=0 check at the end."
    )]
    fn pointer_works_with_threads() {
        const THREADS: usize = 4;
        const ELEMENTS_PER_THREAD: usize = 400;
        let n = Arc::new(Mutex::new(0));
        {
            // Every thread builds its cycles in one shared generation. The
            // barrier ensures no thread starts dropping before all threads
            // have finished building.
            let barrier = Barrier::new(THREADS);
            let generation = GenerationRef::default();
            thread::scope(|s| {
                for _ in 0..THREADS {
                    s.spawn(|| {
                        let vec_of_pointers: Vec<_> = (0..ELEMENTS_PER_THREAD)
                            .map(|_| {
                                let p = Cycle::new_with_generation(&generation, n.clone());
                                let q = Cycle::new_with_generation(&generation, n.clone());
                                q.assign(p.clone());
                                p.assign(q);
                                p
                            })
                            .collect();
                        barrier.wait();
                        drop(vec_of_pointers);
                    });
                }
            });
        }
        assert_eq!(*n.lock().unwrap(), 0);
    }
}