use crate::{Pointer, SmartPointer};
use alloc::boxed::Box;
use atomic_traits::{
Atomic, NumOps,
fetch::{Add, Sub},
};
use core::fmt;
use core::marker::PhantomData;
use core::ops::Deref;
use core::ptr::NonNull;
use core::sync::atomic::{
Ordering::{Acquire, Relaxed, Release},
fence,
};
/// A type that stores its own atomic reference count, so it can be shared through [`Irc`]
/// without a separate control block.
///
/// * `Tag` is a marker parameter that is never used at runtime; it only distinguishes
///   otherwise identical implementations (and the matching [`Irc`] types).
/// * `P` is the owning pointer the item is allocated in. When the last [`Irc`] handle is
///   dropped, the item is rebuilt as a `P` and handed to [`IrcItem::on_drop`].
///
/// # Safety
///
/// [`Irc`] trusts the counter returned by [`IrcItem::counter`] to decide when the item is
/// released. Implementors must ensure that:
///
/// * `counter` always returns the same counter for the whole lifetime of the item;
/// * nothing but [`Irc`] modifies that counter while handles to the item exist.
///
/// `Send + Sync` are required because [`Irc`] handles are themselves `Send` and `Sync`.
pub unsafe trait IrcItem<Tag = (), P = Box<Self>>: Sized + Send + Sync
where
    <Self::Counter as Atomic>::Type: From<u8> + Into<usize> + PartialEq,
    P: Pointer<Target = Self>,
{
    /// Atomic integer type holding the reference count.
    type Counter: NumOps;

    /// Returns the item's reference counter.
    fn counter(&self) -> &Self::Counter;

    /// Hook invoked with the owning pointer once the last [`Irc`] handle has been dropped.
    ///
    /// The default implementation simply drops `this`.
    #[inline(always)]
    fn on_drop(this: P) {
        drop(this);
    }

    /// Returns the current number of handles.
    ///
    /// This is a `Relaxed` snapshot. It may already be stale if other threads are cloning or
    /// dropping handles concurrently.
    #[inline]
    fn strong_count(&self) -> usize {
        let snapshot = self.counter().load(Relaxed);
        snapshot.into()
    }
}
/// Intrusively reference-counted shared pointer to a `T`.
///
/// The count lives inside the item itself (see [`IrcItem::counter`]), and every `Irc` handle
/// owns one unit of it. When the last handle is dropped, the item is rebuilt into its owning
/// pointer `P` and passed to [`IrcItem::on_drop`].
pub struct Irc<T: IrcItem<Tag, P>, Tag = (), P: Pointer<Target = T> = Box<T>> {
    /// Pointer obtained from `P::into_raw`. It is non-null and valid while any handle exists.
    inner: NonNull<T>,
    /// Mentions `Tag` and `P` in the type without storing them.
    _phan: PhantomData<fn(&Tag, &P)>,
}
impl<T, Tag, P> Irc<T, Tag, P>
where
    P: SmartPointer<Target = T>,
    T: IrcItem<Tag, P>,
{
    /// Allocates `inner` in a fresh `P` and returns the first handle to it.
    ///
    /// The item's counter is set to 1 by the `From<P>` conversion.
    #[inline]
    pub fn new(inner: T) -> Self {
        let owner = P::new(inner);
        owner.into()
    }
}
impl<T, Tag, P> SmartPointer for Irc<T, Tag, P>
where
    T: IrcItem<Tag, P>,
    P: SmartPointer<Target = T>,
{
    /// Allocates `inner` in a new `P` and wraps it as the sole handle (same as `Irc::new`).
    #[inline]
    fn new(inner: T) -> Self {
        Self::from(P::new(inner))
    }
}
impl<T, Tag, P> From<P> for Irc<T, Tag, P>
where
T: IrcItem<Tag, P>,
P: Pointer<Target = T>,
{
#[inline]
fn from(inner: P) -> Self {
inner.as_ref().counter().store(1u8.into(), Relaxed);
Self {
inner: unsafe { NonNull::new_unchecked(inner.into_raw() as *mut T) },
_phan: Default::default(),
}
}
}
impl<T, Tag, P> Irc<T, Tag, P>
where
    T: IrcItem<Tag, P>,
    P: Pointer<Target = T>,
{
    /// Borrows the shared item.
    #[inline(always)]
    fn get_inner(&self) -> &T {
        // SAFETY: `inner` came from `P::into_raw` and stays valid while this handle owns a
        // unit of the count.
        unsafe { self.inner.as_ref() }
    }

    /// Returns `true` when both handles point to the same item.
    #[inline]
    pub fn ptr_eq(this: &Self, other: &Self) -> bool {
        core::ptr::eq(this.inner.as_ptr(), other.inner.as_ptr())
    }

    /// Returns `true` when the item's count is exactly 1, i.e. this is the only handle.
    ///
    /// The load uses `Acquire`, which pairs with the `Release` decrement performed when other
    /// handles are dropped.
    #[inline]
    pub fn is_unique(&self) -> bool {
        let count = self.get_inner().counter().load(Acquire);
        count == 1u8.into()
    }

    /// Returns a mutable reference to the item if this is the only handle, otherwise `None`.
    ///
    /// NOTE(review): if the counter is a field of `T`, overwriting the whole item through the
    /// returned reference also overwrites its counter. The replacement should carry a count
    /// of 1.
    #[inline]
    pub fn get_mut(this: &mut Self) -> Option<&mut T> {
        if !this.is_unique() {
            return None;
        }
        // SAFETY: the count is 1 and `this` is borrowed mutably, so no other handle or
        // borrow can reach the item.
        Some(unsafe { this.inner.as_mut() })
    }
}
impl<T, Tag, P> Irc<T, Tag, P>
where
    T: IrcItem<Tag, P> + Clone,
    P: SmartPointer<Target = T>,
{
    /// Returns a mutable reference to the item. If the item is shared, `this` first gets its
    /// own copy (clone-on-write, like `Arc::make_mut`).
    ///
    /// When other handles exist, the item is cloned into a fresh allocation via `Irc::new`,
    /// which sets the copy's counter to 1, and `this` is repointed at it. The other handles
    /// keep the original item.
    #[inline]
    pub fn make_mut(this: &mut Self) -> &mut T {
        if !this.is_unique() {
            let copy = T::clone(this.get_inner());
            // Assigning drops the previous handle, releasing `this`'s unit of the shared count.
            *this = Self::new(copy);
        }
        // SAFETY: `this` is now the only handle (count == 1) and is borrowed mutably.
        unsafe { this.inner.as_mut() }
    }
}
impl<T, Tag, P> Deref for Irc<T, Tag, P>
where
    T: IrcItem<Tag, P>,
    P: Pointer<Target = T>,
{
    type Target = T;

    /// Borrows the shared item.
    #[inline(always)]
    fn deref(&self) -> &T {
        // SAFETY: `inner` stays valid while this handle owns a unit of the count.
        unsafe { self.inner.as_ref() }
    }
}
impl<T, Tag, P> AsRef<T> for Irc<T, Tag, P>
where
T: IrcItem<Tag, P>,
P: Pointer<Target = T>,
{
#[inline(always)]
fn as_ref(&self) -> &T {
self.get_inner()
}
}
// SAFETY: moving a handle only moves a `NonNull`. The shared state is the item, which
// `IrcItem` requires to be `Send + Sync`, and its count is only changed through atomics.
// NOTE(review): the thread that drops the last handle rebuilds and drops a `P` (see the `Drop`
// impl). This therefore also assumes that every `P` used with `Irc` may be dropped on any
// thread, which holds for `Box<T>`/`Arc<T>` with `T: Send + Sync`. Consider a `P: Send` bound.
unsafe impl<T: IrcItem<Tag, P>, Tag, P: Pointer<Target = T>> Send for Irc<T, Tag, P> {}
// SAFETY: `&Irc` only exposes `&T` (`T: Sync` via `IrcItem`) and `clone`, which bumps the
// count atomically.
// NOTE(review): a handle cloned through `&Irc` on another thread may be the last one dropped
// there, which rebuilds and drops a `P` on that thread. This assumes every `P` used with `Irc`
// may be dropped on any thread, which holds for `Box<T>`/`Arc<T>` with `T: Send + Sync`.
// Consider a `P: Send` bound.
unsafe impl<T: IrcItem<Tag, P>, Tag, P: Pointer<Target = T>> Sync for Irc<T, Tag, P> {}
impl<T, Tag, P> Clone for Irc<T, Tag, P>
where
    T: IrcItem<Tag, P>,
    P: Pointer<Target = T>,
{
    /// Returns another handle to the same item by incrementing its counter.
    ///
    /// # Panics
    ///
    /// Panics if the counter has wrapped around. There are two cases:
    ///
    /// * this increment overflowed the counter's integer type, e.g. a 256th handle with an
    ///   8-bit counter;
    /// * the count was observed at zero even though this handle is alive, e.g. because the
    ///   item was replaced through `get_mut`/`make_mut` by a value whose counter was 0.
    ///
    /// The increment is undone before panicking. Without this check, a later drop could
    /// observe a count of 1 and release the item while other handles still point to it.
    #[inline]
    fn clone(&self) -> Self {
        let counter = self.get_inner().counter();
        // Relaxed is enough for the increment itself: the new handle is derived from a live
        // one, so the item cannot be released concurrently.
        let previous = counter.fetch_add(1u8.into(), Relaxed);

        // Largest value of the counter's value type. The std atomics whose value type is
        // `From<u8> + Into<usize>` are `AtomicU8`, `AtomicU16` and `AtomicUsize`, all
        // unsigned, so the maximum is "all ones" over the type's width.
        let bits = core::mem::size_of_val(&previous) * 8;
        let max = if bits >= usize::BITS as usize { usize::MAX } else { (1usize << bits) - 1 };
        let previous: usize = previous.into();

        // `previous == max`: this increment wrapped the counter to 0.
        // `previous == 0`: impossible for a consistent count while `self` is alive, so the
        // counter has already wrapped or been overwritten.
        // Like any check made after `fetch_add`, this does not fully cover many threads
        // racing at exactly the boundary.
        if previous == max || previous == 0 {
            counter.fetch_sub(1u8.into(), Relaxed);
            panic!("Irc::clone: reference count overflowed or was corrupted");
        }

        Self { inner: self.inner, _phan: Default::default() }
    }
}
impl<T, Tag, P> Drop for Irc<T, Tag, P>
where
    T: IrcItem<Tag, P>,
    P: Pointer<Target = T>,
{
    /// Releases this handle's unit of the count.
    ///
    /// The handle that takes the count from 1 to 0 rebuilds the owning `P` and passes it to
    /// [`IrcItem::on_drop`].
    #[inline]
    fn drop(&mut self) {
        let raw = self.inner.as_ptr();
        // SAFETY: this handle still owns a unit of the count, so the item is alive.
        let counter = unsafe { (*raw).counter() };
        // Release makes this handle's earlier accesses visible to whichever thread releases
        // the item.
        if counter.fetch_sub(1u8.into(), Release) != 1u8.into() {
            return;
        }
        // Pairs with the Release decrements of all other handles, so their accesses
        // happen-before the item is handed to `on_drop`.
        fence(Acquire);
        // SAFETY: this was the last handle, and `raw` originates from `P::into_raw`.
        let owner = unsafe { P::from_raw(raw) };
        <T as IrcItem<Tag, P>>::on_drop(owner);
    }
}
impl<T, Tag, P> fmt::Debug for Irc<T, Tag, P>
where
    T: IrcItem<Tag, P> + fmt::Debug,
    P: Pointer<Target = T>,
{
    /// Formats the shared item with `T`'s own `Debug`. The handle itself adds nothing.
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        <T as fmt::Debug>::fmt(self.get_inner(), f)
    }
}
impl<T, Tag, P> fmt::Display for Irc<T, Tag, P>
where
    T: IrcItem<Tag, P> + fmt::Display,
    P: Pointer<Target = T>,
{
    /// Displays the shared item with `T`'s own `Display`.
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        <T as fmt::Display>::fmt(self.get_inner(), f)
    }
}
impl<T: IrcItem<Tag, P>, Tag, P> Pointer for Irc<T, Tag, P>
where
T: IrcItem<Tag, P>,
P: Pointer<Target = T>,
{
type Target = T;
#[inline]
fn as_ref(&self) -> &Self::Target {
unsafe { self.inner.as_ref() }
}
#[inline]
unsafe fn from_raw(p: *const Self::Target) -> Self {
Self {
inner: unsafe { NonNull::new_unchecked(p as *mut Self::Target) },
_phan: Default::default(),
}
}
#[inline]
fn into_raw(self) -> *const Self::Target {
let p = self.inner.as_ptr();
core::mem::forget(self);
p
}
}
#[cfg(test)]
mod tests {
    //! Unit tests for `Irc`.
    //!
    //! `alive_count`/`reset_alive_count` come from `crate::test`. Judging by the assertions
    //! below, they track the number of live `CounterI32` values.
    //! NOTE(review): if that counter is process-global, tests running in parallel can disturb
    //! each other's counts. Confirm it is synchronized, or run with `--test-threads=1`.
    use super::*;
    use crate::test::{CounterI32, alive_count, reset_alive_count};
    use alloc::sync::Arc;
    use core::sync::atomic::AtomicUsize;
    use std::thread;

    /// Item allocated in the default `Box` pointer.
    struct TestItem {
        value: CounterI32,
        counter: AtomicUsize,
    }

    impl TestItem {
        fn new(val: i32) -> Self {
            Self { value: CounterI32::new(val), counter: AtomicUsize::new(0) }
        }
    }

    impl Clone for TestItem {
        /// Copies the payload but starts with a fresh counter.
        fn clone(&self) -> Self {
            Self { value: self.value.clone(), counter: AtomicUsize::new(0) }
        }
    }

    unsafe impl IrcItem for TestItem {
        type Counter = AtomicUsize;
        fn counter(&self) -> &Self::Counter {
            &self.counter
        }
    }

    /// Item allocated in an `Arc`, exercising a non-`Box` underlying pointer.
    struct ArcTestItem {
        value: CounterI32,
        counter: AtomicUsize,
    }

    impl ArcTestItem {
        fn new(val: i32) -> Self {
            Self { value: CounterI32::new(val), counter: AtomicUsize::new(0) }
        }
    }

    unsafe impl IrcItem<(), Arc<ArcTestItem>> for ArcTestItem {
        type Counter = AtomicUsize;
        fn counter(&self) -> &Self::Counter {
            &self.counter
        }
    }

    #[test]
    fn test_basic() {
        reset_alive_count();
        {
            let item = TestItem::new(10);
            let irc1 = Irc::<_, _, _>::new(item);
            assert_eq!(irc1.value.value, 10);
            assert_eq!(irc1.strong_count(), 1);
            assert!(irc1.is_unique());
            assert_eq!(alive_count(), 1);
            let irc2 = irc1.clone();
            assert_eq!(irc1.strong_count(), 2);
            assert_eq!(irc2.strong_count(), 2);
            assert!(!irc1.is_unique());
            assert_eq!(alive_count(), 1);
            drop(irc1);
            assert_eq!(irc2.strong_count(), 1);
            assert!(irc2.is_unique());
            assert_eq!(alive_count(), 1);
        }
        assert_eq!(alive_count(), 0);
    }

    #[test]
    fn test_arc_underlayer() {
        reset_alive_count();
        {
            let item = ArcTestItem::new(10);
            let irc1 = Irc::<ArcTestItem, (), Arc<ArcTestItem>>::new(item);
            assert_eq!(irc1.value.value, 10);
            assert_eq!(irc1.strong_count(), 1);
            assert!(irc1.is_unique());
            assert_eq!(alive_count(), 1);
            let irc2 = irc1.clone();
            assert_eq!(irc1.strong_count(), 2);
            assert_eq!(alive_count(), 1);
            drop(irc1);
            assert_eq!(irc2.strong_count(), 1);
            assert_eq!(alive_count(), 1);
        }
        assert_eq!(alive_count(), 0);
    }

    #[test]
    fn test_get_mut() {
        reset_alive_count();
        let mut irc = Irc::<_, _, _>::new(TestItem::new(10));
        assert!(Irc::get_mut(&mut irc).is_some());
        let _irc2 = irc.clone();
        assert!(Irc::get_mut(&mut irc).is_none());
    }

    #[test]
    fn test_make_mut() {
        reset_alive_count();
        let mut irc = Irc::new(TestItem::new(10));
        {
            let m = Irc::make_mut(&mut irc);
            m.value.value = 20;
        }
        assert_eq!(irc.value.value, 20);
        assert_eq!(alive_count(), 1);
        let irc2 = irc.clone();
        assert_eq!(alive_count(), 1);
        {
            // Shared now, so this must detach `irc` onto a fresh copy.
            let m = Irc::make_mut(&mut irc);
            m.value.value = 30;
        }
        assert_eq!(irc.value.value, 30);
        assert_eq!(irc2.value.value, 20);
        assert_eq!(alive_count(), 2);
        assert!(irc.is_unique());
        assert!(irc2.is_unique());
    }

    #[test]
    fn test_multithread_count() {
        reset_alive_count();
        {
            let irc = Irc::new(TestItem::new(0));
            let mut handles = vec![];
            for _ in 0..10 {
                let irc_clone = irc.clone();
                handles.push(thread::spawn(move || {
                    for _ in 0..1000 {
                        let temp = irc_clone.clone();
                        assert_eq!(temp.value.value, 0);
                    }
                }));
            }
            for handle in handles {
                handle.join().unwrap();
            }
            assert_eq!(irc.strong_count(), 1);
            assert!(irc.is_unique());
            assert_eq!(alive_count(), 1);
        }
        assert_eq!(alive_count(), 0);
    }

    #[test]
    fn test_multithread_drop() {
        reset_alive_count();
        {
            let irc = Irc::new(TestItem::new(0));
            let mut handles = vec![];
            for _ in 0..10 {
                let irc_clone = irc.clone();
                handles.push(thread::spawn(move || {
                    for _ in 0..1000 {
                        let temp = irc_clone.clone();
                        assert_eq!(temp.value.value, 0);
                    }
                }));
            }
            // The last handle may now be dropped on a worker thread.
            drop(irc);
            for handle in handles {
                handle.join().unwrap();
            }
        }
        assert_eq!(alive_count(), 0);
    }

    #[test]
    fn test_drop_all() {
        reset_alive_count();
        let irc = Irc::new(TestItem::new(0));
        let mut clones = vec![];
        for _ in 0..100 {
            clones.push(irc.clone());
        }
        assert_eq!(alive_count(), 1);
        drop(clones);
        assert_eq!(alive_count(), 1);
        drop(irc);
        assert_eq!(alive_count(), 0);
    }

    #[test]
    fn test_from_into_raw() {
        // Start from a clean count like the other tests instead of relying on test order.
        reset_alive_count();
        {
            let irc = Irc::new(TestItem::new(0));
            let irc_1 = irc.clone();
            let irc_2 = irc.clone();
            // `into_raw`/`from_raw` must leave the count untouched.
            let irc1_p = irc_1.into_raw();
            let irc2_p = irc_2.into_raw();
            assert_eq!(irc.strong_count(), 3);
            assert_eq!(alive_count(), 1);
            let _irc1 = unsafe { Irc::from_raw(irc1_p) };
            let _irc2 = unsafe { Irc::from_raw(irc2_p) };
            assert_eq!(irc.strong_count(), 3);
            assert_eq!(alive_count(), 1);
        }
        assert_eq!(alive_count(), 0);
    }
}