use super::SingleOrMultiThreadPtr;
use crate::object::{MTObject, MTObjectIntf, MTObjectPtr};
use crate::prelude::GcPtrEq;
use crate::sync::{GcMtMemberPtr, Metadata};
use crate::{GcMemberPtr, GcPtr};
use std::fmt;
use std::mem;
use std::ops::Deref;
use std::pin::Pin;
use std::ptr;
#[cfg(feature = "weak_pointer")]
use super::sync_weak::Weak;
#[cfg(feature = "weak_pointer")]
use crate::errors::Error;
/// A strong, reference-counted handle to an object stored in an `MTObject`
/// allocation.
///
/// Cloning the handle increments the object's strong count and dropping it
/// decrements the count again (see the `Clone` and `Drop` impls). The value is
/// required to be `Send + Sync`, so the handle may be shared between threads.
#[cfg_attr(docsrs, doc(cfg(feature = "multi_thread")))]
pub struct GcMtPtr<T: 'static + Send + Sync> {
    /// Pinned pointer to the backing allocation; owns one strong count.
    pub(super) ptr: Pin<MTObjectPtr<T>>,
}
impl<T> Clone for GcMtPtr<T>
where
T: 'static + Send + Sync,
{
#[inline]
fn clone(&self) -> Self {
let ptr = self.ptr.clone();
ptr.get_control_block().refcount_inc_strong();
GcMtPtr { ptr }
}
}
impl<T> Drop for GcMtPtr<T>
where
    T: 'static + Send + Sync,
{
    /// Gives back the strong count owned by this handle.
    #[inline]
    fn drop(&mut self) {
        let control_block = self.ptr.get_control_block();
        control_block.refcount_dec();
    }
}
impl<T> Deref for GcMtPtr<T>
where
    T: 'static + Send + Sync,
{
    type Target = T;

    /// Borrows the managed value stored in the backing allocation.
    #[inline]
    fn deref(&self) -> &Self::Target {
        self.ptr.get_data()
    }
}
impl<T> fmt::Debug for GcMtPtr<T>
where
    T: 'static + Send + Sync + fmt::Debug,
{
    /// Formats the pointee exactly as `T`'s own `Debug` impl would; the
    /// handle itself adds no wrapper text.
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let inner: &T = self;
        <T as fmt::Debug>::fmt(inner, f)
    }
}
impl<T> GcMtPtr<T>
where
    T: 'static + Send + Sync,
{
    /// Allocates a new object whose value is produced by `factory`, and
    /// returns the first strong handle to it.
    ///
    /// `factory` receives the allocation's `Metadata`.
    #[inline]
    pub fn new<Factory>(factory: Factory) -> GcMtPtr<T>
    where
        Factory: FnOnce(Metadata) -> T,
    {
        // SAFETY: `new_ptr` is called with a count of 1 (assumed to be the
        // initial strong count — TODO confirm against `MTObject::new_ptr`);
        // the returned handle takes ownership of that single count.
        unsafe { GcMtPtr::new_from_raw(MTObject::new_ptr(1, factory, None)) }
    }

    /// Like [`GcMtPtr::new`], but `factory` additionally receives a `Weak<T>`
    /// pointing at the object under construction, allowing self-references.
    #[cfg(feature = "weak_pointer")]
    #[cfg_attr(docsrs, doc(cfg(feature = "weak_pointer")))]
    #[inline]
    pub fn new_cyclic<Factory>(factory: Factory) -> GcMtPtr<T>
    where
        Factory: FnOnce(Metadata, Weak<T>) -> T,
    {
        // SAFETY: same contract as in `new`: the allocation is created with a
        // count of 1 (presumed strong count) which this handle takes over.
        unsafe { GcMtPtr::new_from_raw(MTObject::new_cyclic_ptr(1, factory, None)) }
    }

    /// Creates a `Weak` pointer to the same object as `this`.
    #[cfg(feature = "weak_pointer")]
    #[cfg_attr(docsrs, doc(cfg(feature = "weak_pointer")))]
    #[inline]
    pub fn downgrade(this: &Self) -> Weak<T> {
        Weak::new_ptr(this.ptr.clone())
    }

    /// Tries to obtain a new strong handle from the pointer held by a weak
    /// pointer.
    ///
    /// # Errors
    ///
    /// Returns the error produced by the control block's `try_refcount_inc`
    /// when a strong count cannot be acquired.
    #[cfg(feature = "weak_pointer")]
    pub(super) fn new_from_weak(ptr: &Pin<MTObjectPtr<T>>) -> Result<Self, Error> {
        ptr.get_control_block()
            .try_refcount_inc()
            .map(|_| GcMtPtr { ptr: ptr.clone() })
    }

    /// Wraps a raw pinned object pointer in a strong handle without touching
    /// the reference count.
    ///
    /// # Safety
    ///
    /// `ptr` must carry one strong count that the caller transfers to the new
    /// handle: the handle's `Drop` calls `refcount_dec`, but this constructor
    /// never increments.
    #[inline]
    pub(crate) const unsafe fn new_from_raw(ptr: Pin<MTObjectPtr<T>>) -> Self {
        GcMtPtr { ptr }
    }

    /// Consumes the handle without running its `Drop`, returning the raw
    /// pinned pointer together with the strong count it owned.
    ///
    /// # Safety
    ///
    /// The caller becomes responsible for that strong count, e.g. by passing
    /// the pointer back to [`GcMtPtr::new_from_raw`]; otherwise the count is
    /// never decremented.
    #[inline]
    pub(super) unsafe fn release(this: Self) -> Pin<MTObjectPtr<T>> {
        // Suppress `Drop` so the strong count owned by `this` is not decremented.
        let this = mem::ManuallyDrop::new(this);
        // SAFETY: `this` is wrapped in `ManuallyDrop` and never used again, so
        // its `ptr` field is moved out exactly once and never dropped in place.
        unsafe { ptr::read(&this.ptr) }
    }
}
impl<T> GcPtrEq<GcPtr<T>> for GcMtPtr<T>
where
    T: 'static + Send + Sync,
{
    /// True when `other` is backed by the very same multi-thread object as
    /// `this`; a single-thread-backed `GcPtr` never compares equal.
    #[inline]
    fn ptr_eq(this: &Self, other: &GcPtr<T>) -> bool {
        matches!(
            &other.ptr,
            SingleOrMultiThreadPtr::MultiThread(target) if ptr::eq(&*this.ptr, &**target)
        )
    }
}
impl<T> GcPtrEq<GcMtPtr<T>> for GcMtPtr<T>
where
    T: 'static + Send + Sync,
{
    /// True when both handles point at the same backing object (identity,
    /// not value equality).
    #[inline]
    fn ptr_eq(this: &Self, other: &GcMtPtr<T>) -> bool {
        let lhs = &*this.ptr;
        let rhs = &*other.ptr;
        ptr::eq(lhs, rhs)
    }
}
impl<T> GcPtrEq<GcMemberPtr<T>> for GcMtPtr<T>
where
    T: 'static + Send + Sync,
{
    /// True when `other` is backed by the very same multi-thread object as
    /// `this`; a single-thread-backed `GcMemberPtr` never compares equal.
    #[inline]
    fn ptr_eq(this: &Self, other: &GcMemberPtr<T>) -> bool {
        matches!(
            &other.ptr,
            SingleOrMultiThreadPtr::MultiThread(target) if ptr::eq(&*this.ptr, &**target)
        )
    }
}
impl<T> GcPtrEq<GcMtMemberPtr<T>> for GcMtPtr<T>
where
    T: 'static + Send + Sync,
{
    /// True when the member pointer refers to the same backing object as
    /// `this` (identity comparison).
    #[inline]
    fn ptr_eq(this: &Self, other: &GcMtMemberPtr<T>) -> bool {
        let lhs = &*this.ptr;
        let rhs = &*other.ptr;
        ptr::eq(lhs, rhs)
    }
}
#[cfg(feature = "weak_pointer")]
impl<T> GcPtrEq<crate::Weak<T>> for GcMtPtr<T>
where
    T: 'static + Send + Sync,
{
    /// True when the weak pointer targets the same multi-thread object as
    /// `this`. A weak pointer with no target (`None`) or with a single-thread
    /// target never compares equal.
    #[inline]
    fn ptr_eq(this: &Self, other: &crate::Weak<T>) -> bool {
        matches!(
            &other.ptr,
            Some(SingleOrMultiThreadPtr::MultiThread(target))
                if ptr::eq(&*this.ptr, &**target)
        )
    }
}
#[cfg(feature = "weak_pointer")]
impl<T> GcPtrEq<crate::sync::Weak<T>> for GcMtPtr<T>
where
    T: 'static + Send + Sync,
{
    /// True when the weak pointer targets the same object as `this`; a weak
    /// pointer with no target (`None`) never compares equal.
    #[inline]
    fn ptr_eq(this: &Self, other: &crate::sync::Weak<T>) -> bool {
        matches!(&other.ptr, Some(target) if ptr::eq(&*this.ptr, &**target))
    }
}
#[cfg(test)]
mod tests {
    use super::GcMtPtr;
    use crate::prelude::GcPtrEq;
    use crate::sync::GenerationRef;
    use std::sync::{Arc, Barrier, Mutex};
    use std::thread;

    /// Test payload that tracks the number of live instances in a shared
    /// counter: incremented on construction, decremented on drop.
    #[derive(Debug)]
    struct Bla {
        n: Arc<Mutex<i32>>,
    }

    impl Bla {
        fn new(n: Arc<Mutex<i32>>) -> Bla {
            *n.lock().unwrap() += 1;
            Bla { n }
        }
    }

    impl Drop for Bla {
        fn drop(&mut self) {
            *self.n.lock().unwrap() -= 1;
        }
    }

    #[test]
    #[cfg_attr(
        feature = "single_generation_mt",
        ignore = "In single-generation, any of the other test threads may be running the GC task, making it not run in this function, and thus fail the n=0 check at the end."
    )]
    fn create_pointer() {
        let n = Arc::new(Mutex::new(0));
        let p = GcMtPtr::new(|_| Bla::new(n.clone()));
        assert_eq!(*n.lock().unwrap(), 1);
        drop(p);
        assert_eq!(*n.lock().unwrap(), 0);
    }

    #[test]
    #[cfg_attr(
        feature = "single_generation_mt",
        ignore = "In single-generation, any of the other test threads may be running the GC task, making it not run in this function, and thus fail the n=0 check at the end."
    )]
    fn clone_pointer() {
        let n = Arc::new(Mutex::new(0));
        let p = GcMtPtr::new(|_| Bla::new(n.clone()));
        let q = p.clone();
        assert_eq!(*n.lock().unwrap(), 1);
        drop(p);
        assert_eq!(*n.lock().unwrap(), 1);
        drop(q);
        assert_eq!(*n.lock().unwrap(), 0);
    }

    #[test]
    fn equality() {
        let n = Arc::new(Mutex::new(0));
        let p = GcMtPtr::new(|_| Bla::new(n.clone()));
        let q = p.clone();
        let r = GcMtPtr::new(|_| Bla::new(n.clone()));
        assert!(GcMtPtr::ptr_eq(&p, &q));
        // Distinct allocations must never compare equal, even with equal contents.
        assert!(!GcMtPtr::ptr_eq(&p, &r));
    }

    #[test]
    fn debug_delegates_to_inner() {
        let n = Arc::new(Mutex::new(0));
        let p = GcMtPtr::new(|_| Bla::new(n.clone()));
        assert!(format!("{p:?}").starts_with("Bla {"));
    }

    fn is_send<T: Send>(_: &T) -> bool {
        true
    }

    #[test]
    fn pointer_is_send() {
        let n = Arc::new(Mutex::new(0));
        let p = GcMtPtr::new(|_| Bla::new(n.clone()));
        assert!(is_send(&p));
    }

    #[test]
    #[cfg_attr(
        feature = "single_generation_mt",
        ignore = "In single-generation, any of the other test threads may be running the GC task, making it not run in this function, and thus fail the n=0 check at the end."
    )]
    fn pointer_works_with_threads() {
        const THREADS: usize = 4;
        const ELEMENTS_PER_THREAD: usize = 400;
        let n = Arc::new(Mutex::new(0));
        // Inner scope so `generation` is dropped before the final check
        // (presumably so any remaining objects get collected).
        {
            let barrier = Barrier::new(THREADS);
            let generation = GenerationRef::default();
            thread::scope(|s| {
                for _ in 0..THREADS {
                    s.spawn(|| {
                        let vec_of_pointers: Vec<_> = (0..ELEMENTS_PER_THREAD)
                            .map(|_| generation.make(|_| Bla::new(n.clone())))
                            .collect();
                        // Every thread finishes allocating before any thread
                        // starts dropping its pointers.
                        barrier.wait();
                        drop(vec_of_pointers);
                    });
                }
            });
        }
        assert_eq!(*n.lock().unwrap(), 0);
    }
}