cycle_ptr 0.1.1

Smart pointers, with cycles
Documentation
//! Module for the sendable strong pointer type.
use super::SingleOrMultiThreadPtr;
use crate::object::{MTObject, MTObjectIntf, MTObjectPtr};
use crate::prelude::GcPtrEq;
use crate::sync::{GcMtMemberPtr, Metadata};
use crate::{GcMemberPtr, GcPtr};
use std::fmt;
use std::mem;
use std::ops::Deref;
use std::pin::Pin;
use std::ptr;

#[cfg(feature = "weak_pointer")]
use super::sync_weak::Weak;
#[cfg(feature = "weak_pointer")]
use crate::errors::Error;

/// A strong reference to an object that may be shared between threads.
///
/// This is for use in:
/// - global scope
/// - function scope
/// - members on objects that don't participate in the reachability-graph
///
/// This implements [Send] and [Sync], and consequently can only accept objects that are [Send] and [Sync].
///
/// Cloning a [GcMtPtr] increments the strong reference count of the pointed-to object,
/// and dropping one releases that reference again.
#[cfg_attr(docsrs, doc(cfg(feature = "multi_thread")))]
pub struct GcMtPtr<T>
where
    // These bounds must stay on the struct: the `Drop` impl is required to use exactly the
    // same bounds as the type definition.
    T: 'static + Send + Sync,
{
    /// Pointer to the actual [MTObject] holding on to `T` and the associated [MTControlBlock][crate::generation::MTControlBlock].
    pub(super) ptr: Pin<MTObjectPtr<T>>,
}

impl<T> Clone for GcMtPtr<T>
where
    T: 'static + Send + Sync,
{
    #[inline]
    fn clone(&self) -> Self {
        let ptr = self.ptr.clone();
        ptr.get_control_block().refcount_inc_strong();
        GcMtPtr { ptr }
    }
}

impl<T> Drop for GcMtPtr<T>
where
    T: 'static + Send + Sync,
{
    /// Give back the reference held by this handle by decrementing the control block's count.
    #[inline]
    fn drop(&mut self) {
        let control_block = self.ptr.get_control_block();
        control_block.refcount_dec();
    }
}

impl<T> Deref for GcMtPtr<T>
where
    T: 'static + Send + Sync,
{
    type Target = T;

    /// Borrow the `T` stored inside the pointed-to object.
    #[inline]
    fn deref(&self) -> &Self::Target {
        let object = &self.ptr;
        object.get_data()
    }
}

impl<T> fmt::Debug for GcMtPtr<T>
where
    T: 'static + Send + Sync + fmt::Debug,
{
    /// Format the pointee transparently: the output is exactly `T`'s own `Debug` output.
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Deref coercion turns `&GcMtPtr<T>` into `&T`.
        let target: &T = self;
        <T as fmt::Debug>::fmt(target, f)
    }
}

impl<T> GcMtPtr<T>
where
    T: 'static + Send + Sync,
{
    /// Create a new [GcMtPtr].
    ///
    /// The factory function is given a [Metadata] that can be used to initialize [member pointers][GcMtMemberPtr].
    #[inline]
    pub fn new<Factory>(factory: Factory) -> GcMtPtr<T>
    where
        Factory: FnOnce(Metadata) -> T,
    {
        // SAFETY: `new_from_raw` adopts exactly one reference. The leading `1` handed to
        // `MTObject::new_ptr` is presumably that initial reference count — confirm against
        // `MTObject::new_ptr`.
        unsafe { GcMtPtr::new_from_raw(MTObject::new_ptr(1, factory, None)) }
    }

    /// Create a new [GcMtPtr].
    ///
    /// The factory function is given a [Metadata] that can be used to initialize [member pointers][GcMtMemberPtr].
    ///
    /// A [Weak] pointer is also provided to the factory function.
    /// Note that the [Weak] pointer can't be dereferenced until the object has actually been constructed.
    #[cfg(feature = "weak_pointer")]
    #[cfg_attr(docsrs, doc(cfg(feature = "weak_pointer")))]
    #[inline]
    pub fn new_cyclic<Factory>(factory: Factory) -> GcMtPtr<T>
    where
        Factory: FnOnce(Metadata, Weak<T>) -> T,
    {
        // SAFETY: as in `new`, the leading `1` is presumably the single initial reference
        // that `new_from_raw` adopts — confirm against `MTObject::new_cyclic_ptr`.
        unsafe { GcMtPtr::new_from_raw(MTObject::new_cyclic_ptr(1, factory, None)) }
    }

    /// Downgrade to a [Weak] pointer.
    #[cfg(feature = "weak_pointer")]
    #[cfg_attr(docsrs, doc(cfg(feature = "weak_pointer")))]
    #[inline]
    pub fn downgrade(this: &Self) -> Weak<T> {
        Weak::new_ptr(this.ptr.clone())
    }

    /// Attempt to promote an unreferenced [MTObjectPtr] to a strong pointer.
    ///
    /// # Errors
    ///
    /// Returns an [Error] if the [MTObject] has [expired][crate::generation::ObjectState::Expired].
    #[cfg(feature = "weak_pointer")]
    pub(super) fn new_from_weak(ptr: &Pin<MTObjectPtr<T>>) -> Result<Self, Error> {
        // Only wrap the pointer once the count has actually been incremented; on failure no
        // handle is created, so nothing will be decremented later.
        ptr.get_control_block()
            .try_refcount_inc()
            .map(|_| GcMtPtr { ptr: ptr.clone() })
    }

    /// Create a new pointer from a raw [MTObjectPtr], adopting its reference.
    ///
    /// # Safety
    ///
    /// `ptr` must carry exactly one reference that the caller owns. Ownership of that
    /// reference moves into the returned [GcMtPtr], whose `Drop` decrements the count.
    /// Adopting a pointer whose reference was never counted would unbalance the count.
    #[inline]
    pub(crate) const unsafe fn new_from_raw(ptr: Pin<MTObjectPtr<T>>) -> Self {
        GcMtPtr { ptr }
    }

    /// Release the internal pointer, without decrementing the reference counter.
    ///
    /// # Safety
    ///
    /// The reference owned by `this` is transferred to the returned pointer and is *not*
    /// given back. The caller becomes responsible for it, e.g. by re-adopting it with
    /// [GcMtPtr::new_from_raw]; otherwise the reference is leaked.
    #[inline]
    pub(super) unsafe fn release(this: Self) -> Pin<MTObjectPtr<T>> {
        // Suppress our `Drop`, which would otherwise decrement the reference count.
        let this = mem::ManuallyDrop::new(this);
        // SAFETY: `this` is wrapped in `ManuallyDrop` and never used again, so moving `ptr`
        // out with a bitwise read cannot lead to a double drop.
        unsafe { ptr::read(&this.ptr) }
    }
}

impl<T> GcPtrEq<GcPtr<T>> for GcMtPtr<T>
where
    T: 'static + Send + Sync,
{
    /// Test if a [GcMtPtr] and a [GcPtr] point at the same object.
    ///
    /// A [GcPtr] that holds a single-thread object is never considered equal to a [GcMtPtr].
    #[inline]
    fn ptr_eq(this: &Self, other: &GcPtr<T>) -> bool {
        match &other.ptr {
            SingleOrMultiThreadPtr::SingleThread(_) => false,
            // Compare object addresses, not values.
            SingleOrMultiThreadPtr::MultiThread(other_ptr) => ptr::eq(&*this.ptr, &**other_ptr),
        }
    }
}

impl<T> GcPtrEq<GcMtPtr<T>> for GcMtPtr<T>
where
    T: 'static + Send + Sync,
{
    /// Test if two [GcMtPtr] point at the same object.
    ///
    /// This is an identity check on the object's address; the pointee values are not compared.
    #[inline]
    fn ptr_eq(this: &Self, other: &GcMtPtr<T>) -> bool {
        let (lhs, rhs) = (&*this.ptr, &*other.ptr);
        ptr::eq(lhs, rhs)
    }
}

impl<T> GcPtrEq<GcMemberPtr<T>> for GcMtPtr<T>
where
    T: 'static + Send + Sync,
{
    /// Test if a [GcMtPtr] and a [GcMemberPtr] point at the same object.
    ///
    /// A [GcMemberPtr] that holds a single-thread object is never considered equal to a [GcMtPtr].
    #[inline]
    fn ptr_eq(this: &Self, other: &GcMemberPtr<T>) -> bool {
        match &other.ptr {
            SingleOrMultiThreadPtr::SingleThread(_) => false,
            // Compare object addresses, not values.
            SingleOrMultiThreadPtr::MultiThread(other_ptr) => ptr::eq(&*this.ptr, &**other_ptr),
        }
    }
}

impl<T> GcPtrEq<GcMtMemberPtr<T>> for GcMtPtr<T>
where
    T: 'static + Send + Sync,
{
    /// Test if a [GcMtPtr] and a [GcMtMemberPtr] point at the same object.
    ///
    /// This is an identity check on the object's address; the pointee values are not compared.
    #[inline]
    fn ptr_eq(this: &Self, other: &GcMtMemberPtr<T>) -> bool {
        ptr::eq(&*this.ptr, &*other.ptr)
    }
}

#[cfg(feature = "weak_pointer")]
impl<T> GcPtrEq<crate::Weak<T>> for GcMtPtr<T>
where
    T: 'static + Send + Sync,
{
    /// Test if a [GcMtPtr] and a [Weak][crate::Weak] pointer point at the same object.
    ///
    /// Returns `false` if the weak pointer holds no object, or if it refers to a
    /// single-thread object.
    #[inline]
    fn ptr_eq(this: &Self, other: &crate::Weak<T>) -> bool {
        other
            .ptr
            .as_ref()
            .map(|other_ptr| match other_ptr {
                SingleOrMultiThreadPtr::SingleThread(_) => false,
                // Compare object addresses, not values.
                SingleOrMultiThreadPtr::MultiThread(other_ptr) => ptr::eq(&*this.ptr, &**other_ptr),
            })
            // An empty weak pointer aliases nothing.
            .unwrap_or(false)
    }
}

#[cfg(feature = "weak_pointer")]
impl<T> GcPtrEq<crate::sync::Weak<T>> for GcMtPtr<T>
where
    T: 'static + Send + Sync,
{
    /// Test if a [GcMtPtr] and a multi-thread [Weak][crate::sync::Weak] pointer point at the same object.
    ///
    /// Returns `false` if the weak pointer holds no object.
    #[inline]
    fn ptr_eq(this: &Self, other: &crate::sync::Weak<T>) -> bool {
        other
            .ptr
            .as_ref()
            // Compare object addresses, not values.
            .map(|other_ptr| ptr::eq(&*this.ptr, &**other_ptr))
            // An empty weak pointer aliases nothing.
            .unwrap_or(false)
    }
}

#[cfg(test)]
mod tests {
    use super::GcMtPtr;
    use crate::prelude::GcPtrEq;
    use crate::sync::GenerationRef;
    use std::sync::{Arc, Barrier, Mutex};
    use std::thread;

    /// Test payload that tracks how many instances are alive.
    ///
    /// The shared counter is incremented in [Bla::new] and decremented in `Drop`, so a
    /// value of 0 means every `Bla` created against that counter has been destroyed.
    #[derive(Debug)]
    struct Bla {
        n: Arc<Mutex<i32>>,
    }

    impl Bla {
        /// Construct a `Bla`, incrementing the shared live-instance counter.
        fn new(n: Arc<Mutex<i32>>) -> Bla {
            *n.lock().unwrap() += 1;
            Bla { n }
        }
    }

    impl Drop for Bla {
        /// Decrement the shared live-instance counter.
        fn drop(&mut self) {
            *self.n.lock().unwrap() -= 1;
        }
    }

    /// Creating a pointer constructs exactly one object; dropping the only pointer destroys it.
    #[test]
    #[cfg_attr(
        feature = "single_generation_mt",
        ignore = "In single-generation, any of the other test threads may be running the GC task, making it not run in this function, and thus fail the n=0 check at the end."
    )]
    fn create_pointer() {
        let n = Arc::new(Mutex::new(0));
        let p = GcMtPtr::new(|_| Bla::new(n.clone()));
        assert_eq!(*n.lock().unwrap(), 1);

        drop(p);
        assert_eq!(*n.lock().unwrap(), 0);
    }

    /// Cloning shares the object: it survives until the last clone is dropped.
    #[test]
    #[cfg_attr(
        feature = "single_generation_mt",
        ignore = "In single-generation, any of the other test threads may be running the GC task, making it not run in this function, and thus fail the n=0 check at the end."
    )]
    fn clone_pointer() {
        let n = Arc::new(Mutex::new(0));
        let p = GcMtPtr::new(|_| Bla::new(n.clone()));
        let q = p.clone();
        assert_eq!(*n.lock().unwrap(), 1);

        // One clone still holds the object.
        drop(p);
        assert_eq!(*n.lock().unwrap(), 1);

        drop(q);
        assert_eq!(*n.lock().unwrap(), 0);
    }

    /// A pointer and its clone refer to the same object.
    #[test]
    fn equality() {
        let n = Arc::new(Mutex::new(0));
        let p = GcMtPtr::new(|_| Bla::new(n.clone()));
        let q = p.clone();

        assert!(GcMtPtr::ptr_eq(&p, &q));
    }

    /// Compile-time check helper: only callable when `T: Send`; always returns `true`.
    fn is_send<T: Send>(_: &T) -> bool {
        true
    }

    /// `GcMtPtr` must be `Send` (this fails to compile otherwise).
    #[test]
    fn pointer_is_send() {
        let n = Arc::new(Mutex::new(0));
        let p = GcMtPtr::new(|_| Bla::new(n.clone()));

        assert!(is_send(&p));
    }

    /// Several threads allocate into one shared generation, then release all their
    /// pointers at the same moment; afterwards every object must have been destroyed.
    #[test]
    #[cfg_attr(
        feature = "single_generation_mt",
        ignore = "In single-generation, any of the other test threads may be running the GC task, making it not run in this function, and thus fail the n=0 check at the end."
    )]
    fn pointer_works_with_threads() {
        // Tests run with debug asserts, which can make the mark-sweep garbage-collector
        // quadratic in the number of objects, so keep these numbers somewhat low.
        // NOTE(review): the original remark blamed a `List` reachability debug assert, but
        // this test only allocates `Bla`; the note looks copied from another test — confirm.
        const THREADS: usize = 4;
        const ELEMENTS_PER_THREAD: usize = 400;

        let n = Arc::new(Mutex::new(0));
        {
            let barrier = Barrier::new(THREADS); // Use a barrier, to cause all threads to release their pointers simultaneously.
            let generation = GenerationRef::default(); // All threads use a shared generation.
            // Scoped threads let the closures borrow `barrier`, `generation` and `n` directly;
            // the scope joins every thread before returning.
            thread::scope(|s| {
                for _ in 0..THREADS {
                    s.spawn(|| {
                        let vec_of_pointers: Vec<_> = (0..ELEMENTS_PER_THREAD)
                            .map(|_| generation.make(|_| Bla::new(n.clone())))
                            .collect();
                        barrier.wait();
                        drop(vec_of_pointers);
                    });
                }
            });
        }

        // All threads joined and the generation went out of scope: nothing may survive.
        assert_eq!(*n.lock().unwrap(), 0);
    }
}