Skip to main content

embed_collections/
irc.rs

1//! Intrusive Reference Counter (Irc)
2//!
3//! `Irc` is an intrusive reference counting smart pointer, similar to `Arc` but without weak reference support.
//! It requires the inner type to implement the [IrcItem] trait, which exposes the counter field.
5//!
//! The underlying pointer type of `Irc` is customizable (it defaults to `Box`).
//! Unlike `Arc`, which wraps your type in a hidden `ArcInner`,
//! `Irc` uses the pointer to your type directly, obtained through [Pointer::into_raw].
9//!
//! The atomic orderings mostly mirror those of std `Arc` (verified with miri test cases).
11//!
12//! # Benefits
13//!
//! - No need to manually implement the increment / decrement of the counter.
15//!
//! - No mandatory weak counter when you don't need it (every atomic operation has a cost).
17//!
//! - Customizable counter type (not limited to `AtomicUsize`).
19//!
//! - [IrcItem::on_drop] lets you take ownership of the underlying memory once the reference
//!   count of `Irc` drops to zero. You only need to define the drop behavior once, instead of
//!   repeating the same `Arc::into_inner` logic at every place a reference may be released
//!   (forgetting one of them can make your code block and be hard to debug).
24//!
//! - When `Irc` wraps a `Box`, there is no additional memory allocation or fragmentation, and
//!   no extra dereference cost (compared with `Arc<Box<T>>`).
27//!
//! - You can allocate a `Box` when the object is created and wrap it with `Irc` for temporary
//!   use, without moving bytes to / from the stack (which matters when the inner object is large).
30//!
//! - Advanced usage: multiple layers of customized counters on the same heap object, while
//!   preserving the safety boundary.
33//!
34//! # Example
35//!
//! The following example shows `Irc` wrapping a `Box` (you can do the same with `Arc` or any
//! other [Pointer] type by changing the parameter `P`).
37//!
38//! ```rust
39//! use embed_collections::irc::{Irc, IrcItem};
40//! use core::sync::atomic::{AtomicUsize, AtomicBool, Ordering};
41//! use crossfire::oneshot;
42//! use std::thread;
43//! use std::time::Duration;
44//!
45//! // Usually we use Irc for some large structure, but we show a simple demo here.
46//! struct MyItem {
47//!     is_done: AtomicBool,
48//!     counter: AtomicUsize,
49//!     done_tx: Option<oneshot::TxOneshot<Box<MyItem>>>,
50//! }
51//!
//! // The default parameters are Tag=(), P=Box<Self>
53//! unsafe impl IrcItem for MyItem {
54//!     type Counter = AtomicUsize;
55//!     fn counter(&self) -> &Self::Counter {
56//!         &self.counter
57//!     }
58//!
//!     // override the default behavior to send the item through the channel
60//!     fn on_drop(mut this: Box<Self>) {
61//!         let done_tx = this.done_tx.take().unwrap();
62//!         done_tx.send(this);
63//!     }
64//! }
65//!
66//! let (done_tx, done_rx) = oneshot::oneshot();
67//! let boxed_item = Box::new(MyItem {
68//!     is_done: AtomicBool::new(false),
69//!     counter: AtomicUsize::new(0),
70//!     done_tx: Some(done_tx),
71//! });
72//!
73//! // Convert from Box to Irc, which does not have additional allocation.
74//! let item = Irc::from(boxed_item);
75//! thread::spawn(move || {
76//!     thread::sleep(Duration::from_secs(1));
77//!     item.is_done.store(true, Ordering::SeqCst);
78//!     drop(item);
79//! });
80//! let item: Box<MyItem> = done_rx.recv().unwrap();
81//! assert!(item.is_done.load(Ordering::SeqCst));
82//! ```
83
84use crate::{Pointer, SmartPointer};
85use alloc::boxed::Box;
86use atomic_traits::{
87    Atomic, NumOps,
88    fetch::{Add, Sub},
89};
90use core::fmt;
91use core::marker::PhantomData;
92use core::ops::Deref;
93use core::ptr::NonNull;
94use core::sync::atomic::{
95    Ordering::{Acquire, Relaxed, Release},
96    fence,
97};
98
99/// trait for types that can be wrapped by [Irc]
100///
101/// # Safety
102///
103/// Tag is for distinguish multiple Irc from the same Inner type.
104/// When implement multiple types of Irc from the same object,
105/// you must make sure they don't have overlapped Counter fields.
106pub unsafe trait IrcItem<Tag = (), P = Box<Self>>: Sized + Send + Sync
107where
108    <Self::Counter as Atomic>::Type: From<u8> + Into<usize> + PartialEq,
109    P: Pointer<Target = Self>,
110{
111    /// The type of counter
112    type Counter: NumOps;
113
114    /// return reference to the field of counter
115    fn counter(&self) -> &Self::Counter;
116
117    /// The default behavior for Irc is dropping the inner smart pointer type.
118    ///
119    /// You can overwrite this if you want to send the inner somewhere.
120    #[inline(always)]
121    fn on_drop(_this: P) {}
122
123    #[inline]
124    fn strong_count(&self) -> usize {
125        self.counter().load(Relaxed).into()
126    }
127}
128
/// Intrusive reference-counted pointer to a `T`, which supports conversion from and to `P`.
///
/// The count lives inside `T` itself (see [IrcItem::counter]); `Irc` only stores the raw
/// pointer obtained from `P`. Weak references are not supported.
///
/// The bounds are spelled out on the struct (not only on the impls) because Rust requires a
/// `Drop` impl to carry exactly the same bounds as the type it is implemented for.
pub struct Irc<T, Tag = (), P = Box<T>>
where
    T: IrcItem<Tag, P>,
    P: Pointer<Target = T>,
{
    // Raw pointer produced by `P::into_raw`; turned back into `P` by the handle that brings
    // the count to zero.
    inner: NonNull<T>,
    // Records `Tag` and `P` as type parameters without owning a value of either; `fn` pointers
    // are `Send + Sync`, so this marker does not influence the auto traits on their account.
    _phan: PhantomData<fn(&Tag, &P)>,
}
140
141impl<T, Tag, P> Irc<T, Tag, P>
142where
143    T: IrcItem<Tag, P>,
144    P: SmartPointer<Target = T>,
145{
146    /// Wrap a stack value T inside P with Irc.
147    ///
148    /// The counter will be reset to 1 on initialization.
149    #[inline]
150    pub fn new(inner: T) -> Self {
151        Self::from(P::new(inner))
152    }
153}
154
155impl<T: IrcItem<Tag, P>, Tag, P> SmartPointer for Irc<T, Tag, P>
156where
157    T: IrcItem<Tag, P>,
158    P: SmartPointer<Target = T>,
159{
160    #[inline]
161    fn new(inner: T) -> Self {
162        Irc::new(inner)
163    }
164}
165
166impl<T, Tag, P> From<P> for Irc<T, Tag, P>
167where
168    T: IrcItem<Tag, P>,
169    P: Pointer<Target = T>,
170{
171    /// Convert a [Pointer] containing `T` into Irc.
172    ///
173    /// The counter will be reset to 1 on initialization.
174    #[inline]
175    fn from(inner: P) -> Self {
176        inner.as_ref().counter().store(1u8.into(), Relaxed);
177        Self {
178            inner: unsafe { NonNull::new_unchecked(inner.into_raw() as *mut T) },
179            _phan: Default::default(),
180        }
181    }
182}
183
184impl<T, Tag, P> Irc<T, Tag, P>
185where
186    T: IrcItem<Tag, P>,
187    P: Pointer<Target = T>,
188{
189    #[inline(always)]
190    fn get_inner(&self) -> &T {
191        unsafe { self.inner.as_ref() }
192    }
193
194    #[inline]
195    pub fn ptr_eq(this: &Self, other: &Self) -> bool {
196        this.inner == other.inner
197    }
198
199    /// If is_unique returns true, then this thread is the only owner
200    ///
201    /// # False negative
202    ///
203    /// it's possible to return false when counter drop to 1,
204    /// Because of using Acquire load and Release on drop.
205    ///
206    /// # Example
207    ///
208    ///
209    /// ```rust
210    /// use embed_collections::irc::{Irc, IrcItem};
211    /// use core::sync::atomic::AtomicUsize;
212    ///
213    /// struct Tag;
214    ///
215    /// struct MyItem {
216    ///     value: i32,
217    ///     counter: AtomicUsize,
218    /// }
219    ///
220    /// unsafe impl IrcItem<Tag> for MyItem {
221    ///     type Counter = AtomicUsize;
222    ///     fn counter(&self) -> &Self::Counter {
223    ///         &self.counter
224    ///     }
225    /// }
226    ///
227    /// // Create a new Irc
228    /// let irc1 = Irc::<_, Tag>::new(MyItem { value: 10, counter: AtomicUsize::new(0) });
229    /// assert_eq!(irc1.value, 10);
230    /// assert!(irc1.is_unique());
231    ///
232    /// // Clone the Irc
233    /// let irc2 = irc1.clone();
234    /// assert_eq!(irc1.strong_count(), 2);
235    /// assert!(!irc1.is_unique());
236    /// ```
237    #[inline]
238    pub fn is_unique(&self) -> bool {
239        // Safety:
240        // we have make sure counter reset to 1 on init.
241        // although clone use Relaxed, it can never pass this fence
242        self.counter().load(Acquire) == 1u8.into()
243    }
244
245    /// return mutable reference if we are the only owner
246    ///
247    /// # False negative
248    ///
249    /// It can return None even when only one reference left
250    #[inline]
251    pub fn get_mut(this: &mut Self) -> Option<&mut T> {
252        if this.is_unique() { Some(unsafe { this.inner.as_mut() }) } else { None }
253    }
254}
255
256impl<T, Tag, P> Irc<T, Tag, P>
257where
258    T: IrcItem<Tag, P> + Clone,
259    P: SmartPointer<Target = T>,
260{
261    /// The Cow function, the same as `Arc::make_mut()`
262    ///
263    /// # Example
264    ///
265    /// ```rust
266    /// use embed_collections::irc::{Irc, IrcItem};
267    /// use core::sync::atomic::AtomicUsize;
268    ///
269    /// struct Tag;
270    /// struct MyItem {
271    ///     value: i32,
272    ///     counter: AtomicUsize,
273    /// }
274    ///
275    /// impl Clone for MyItem {
276    ///     fn clone(&self) -> Self {
277    ///         Self { value: self.value, counter: AtomicUsize::new(0) }
278    ///     }
279    /// }
280    ///
281    /// unsafe impl IrcItem<Tag> for MyItem {
282    ///     type Counter = AtomicUsize;
283    ///     fn counter(&self) -> &Self::Counter {
284    ///         &self.counter
285    ///     }
286    /// }
287    ///
288    /// let mut irc1 = Irc::<_, Tag>::new(MyItem { value: 10, counter: AtomicUsize::new(0) });
289    /// let irc2 = irc1.clone();
290    ///
291    /// // This will clone the inner item because it's shared
292    /// let m = Irc::make_mut(&mut irc1);
293    /// m.value = 20;
294    ///
295    /// assert_eq!(irc1.value, 20);
296    /// assert_eq!(irc2.value, 10);
297    /// ```
298    #[inline]
299    pub fn make_mut(this: &mut Self) -> &mut T {
300        if !this.is_unique() {
301            let cloned_item = this.get_inner().clone();
302            let mut new_irc = Self::new(cloned_item);
303            core::mem::swap(this, &mut new_irc);
304        }
305        unsafe { this.inner.as_mut() }
306    }
307}
308
309impl<T, Tag, P> Deref for Irc<T, Tag, P>
310where
311    T: IrcItem<Tag, P>,
312    P: Pointer<Target = T>,
313{
314    type Target = T;
315    #[inline(always)]
316    fn deref(&self) -> &Self::Target {
317        self.get_inner()
318    }
319}
320
321impl<T, Tag, P> AsRef<T> for Irc<T, Tag, P>
322where
323    T: IrcItem<Tag, P>,
324    P: Pointer<Target = T>,
325{
326    #[inline(always)]
327    fn as_ref(&self) -> &T {
328        self.get_inner()
329    }
330}
331
332unsafe impl<T, Tag, P> Send for Irc<T, Tag, P>
333where
334    T: IrcItem<Tag, P>,
335    P: Pointer<Target = T>,
336{
337}
338unsafe impl<T, Tag, P> Sync for Irc<T, Tag, P>
339where
340    T: IrcItem<Tag, P>,
341    P: Pointer<Target = T>,
342{
343}
344
345impl<T, Tag, P> Clone for Irc<T, Tag, P>
346where
347    T: IrcItem<Tag, P>,
348    P: Pointer<Target = T>,
349{
350    #[inline]
351    fn clone(&self) -> Self {
352        self.get_inner().counter().fetch_add(1u8.into(), Relaxed);
353        Self { inner: self.inner, _phan: Default::default() }
354    }
355}
356
impl<T, Tag, P> Drop for Irc<T, Tag, P>
where
    T: IrcItem<Tag, P>,
    P: Pointer<Target = T>,
{
    /// Releases this handle's count. The handle that brings the count to zero rebuilds `P`
    /// and passes it to [IrcItem::on_drop].
    #[inline]
    fn drop(&mut self) {
        // The item is accessed through the raw pointer rather than `self.get_inner()`,
        // presumably to avoid keeping a `&T` around once our decrement has landed (after
        // which another thread may free the item).
        let p = self.inner.as_ptr();
        unsafe {
            // `fetch_sub` returns the previous value, so 1 means this was the last handle.
            // `Release` publishes this handle's accesses to the item before giving up its count.
            if (*p).counter().fetch_sub(1u8.into(), Release) == 1u8.into() {
                // Pairs with the `Release` decrements of the other handles, so all their
                // accesses happen-before the item is handed to `on_drop` (same scheme as `Arc`).
                fence(Acquire);
                // `p` originally came from `P::into_raw` in `From<P>`; ownership returns to `P`.
                let inner = P::from_raw(p);
                IrcItem::<Tag, P>::on_drop(inner);
            }
        }
    }
}
374
375impl<T, Tag, P> fmt::Debug for Irc<T, Tag, P>
376where
377    T: IrcItem<Tag, P> + fmt::Debug,
378    P: Pointer<Target = T>,
379{
380    #[inline]
381    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
382        self.get_inner().fmt(f)
383    }
384}
385
386impl<T, Tag, P> fmt::Display for Irc<T, Tag, P>
387where
388    T: IrcItem<Tag, P> + fmt::Display,
389    P: Pointer<Target = T>,
390{
391    #[inline]
392    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
393        self.get_inner().fmt(f)
394    }
395}
396
397impl<T: IrcItem<Tag, P>, Tag, P> Pointer for Irc<T, Tag, P>
398where
399    T: IrcItem<Tag, P>,
400    P: Pointer<Target = T>,
401{
402    type Target = T;
403
404    #[inline]
405    fn as_ref(&self) -> &Self::Target {
406        unsafe { self.inner.as_ref() }
407    }
408
409    /// # Safety
410    ///
411    /// must be pointer acquire from [Irc::into_raw()]
412    #[inline]
413    unsafe fn from_raw(p: *const Self::Target) -> Self {
414        Self {
415            inner: unsafe { NonNull::new_unchecked(p as *mut Self::Target) },
416            _phan: Default::default(),
417        }
418    }
419
420    #[inline]
421    fn into_raw(self) -> *const Self::Target {
422        let p = self.inner.as_ptr();
423        core::mem::forget(self);
424        p
425    }
426}
427
#[cfg(test)]
mod tests {
    //! Lifetimes are tracked through `CounterI32` from `crate::test`: `alive_count()` reports
    //! how many such values currently exist. Every test starts with `reset_alive_count()`.

    use super::*;
    use crate::test::{CounterI32, alive_count, reset_alive_count};
    use alloc::sync::Arc;
    use core::sync::atomic::AtomicUsize;
    use std::thread;

    struct TestItem {
        value: CounterI32,
        counter: AtomicUsize,
    }

    impl TestItem {
        fn new(val: i32) -> Self {
            Self { value: CounterI32::new(val), counter: AtomicUsize::new(0) }
        }
    }

    impl Clone for TestItem {
        fn clone(&self) -> Self {
            Self { value: self.value.clone(), counter: AtomicUsize::new(0) }
        }
    }

    unsafe impl IrcItem for TestItem {
        type Counter = AtomicUsize;
        fn counter(&self) -> &Self::Counter {
            &self.counter
        }
    }

    struct ArcTestItem {
        value: CounterI32,
        counter: AtomicUsize,
    }

    impl ArcTestItem {
        fn new(val: i32) -> Self {
            Self { value: CounterI32::new(val), counter: AtomicUsize::new(0) }
        }
    }

    unsafe impl IrcItem<(), Arc<ArcTestItem>> for ArcTestItem {
        type Counter = AtomicUsize;
        fn counter(&self) -> &Self::Counter {
            &self.counter
        }
    }

    #[test]
    fn test_basic() {
        reset_alive_count();
        {
            let item = TestItem::new(10);
            let irc1 = Irc::<_, _, _>::new(item);
            assert_eq!(irc1.value.value, 10);
            assert_eq!(irc1.strong_count(), 1);
            assert!(irc1.is_unique());
            assert_eq!(alive_count(), 1);

            let irc2 = irc1.clone();
            assert!(Irc::ptr_eq(&irc1, &irc2));
            assert_eq!(irc1.strong_count(), 2);
            assert_eq!(irc2.strong_count(), 2);
            assert!(!irc1.is_unique());
            assert_eq!(alive_count(), 1);

            drop(irc1);
            assert_eq!(irc2.strong_count(), 1);
            assert!(irc2.is_unique());
            assert_eq!(alive_count(), 1);
        }
        assert_eq!(alive_count(), 0);
    }

    #[test]
    fn test_arc_underlayer() {
        reset_alive_count();
        {
            let item = ArcTestItem::new(10);
            let irc1 = Irc::<ArcTestItem, (), Arc<ArcTestItem>>::new(item);
            assert_eq!(irc1.value.value, 10);
            assert_eq!(irc1.strong_count(), 1);
            assert!(irc1.is_unique());
            assert_eq!(alive_count(), 1);

            let irc2 = irc1.clone();
            assert_eq!(irc1.strong_count(), 2);
            assert_eq!(alive_count(), 1);

            drop(irc1);
            assert_eq!(irc2.strong_count(), 1);
            assert_eq!(alive_count(), 1);
        }
        assert_eq!(alive_count(), 0);
    }

    #[test]
    fn test_get_mut() {
        reset_alive_count();
        {
            let mut irc = Irc::<_, _, _>::new(TestItem::new(10));
            Irc::get_mut(&mut irc).expect("a fresh Irc is unique").value.value = 11;
            assert_eq!(irc.value.value, 11);

            let _irc2 = irc.clone();
            assert!(Irc::get_mut(&mut irc).is_none());
        }
        assert_eq!(alive_count(), 0);
    }

    #[test]
    fn test_make_mut() {
        reset_alive_count();
        {
            let mut irc = Irc::new(TestItem::new(10));

            // Unique, no clone
            {
                let m = Irc::make_mut(&mut irc);
                m.value.value = 20;
            }
            assert_eq!(irc.value.value, 20);
            assert_eq!(alive_count(), 1);

            // Not unique, should clone
            let irc2 = irc.clone();
            assert_eq!(alive_count(), 1);
            {
                let m = Irc::make_mut(&mut irc);
                m.value.value = 30;
            }
            assert_eq!(irc.value.value, 30);
            assert_eq!(irc2.value.value, 20);
            assert_eq!(alive_count(), 2);
            assert!(!Irc::ptr_eq(&irc, &irc2));

            assert!(irc.is_unique());
            assert!(irc2.is_unique());
        }
        assert_eq!(alive_count(), 0);
    }

    #[test]
    fn test_multithread_count() {
        reset_alive_count();
        {
            let irc = Irc::new(TestItem::new(0));
            let mut handles = vec![];

            for _ in 0..10 {
                let irc_clone = irc.clone();
                handles.push(thread::spawn(move || {
                    for _ in 0..1000 {
                        let temp = irc_clone.clone();
                        assert_eq!(temp.value.value, 0);
                    }
                }));
            }

            for handle in handles {
                handle.join().unwrap();
            }

            assert_eq!(irc.strong_count(), 1);
            assert!(irc.is_unique());
            assert_eq!(alive_count(), 1);
        }
        assert_eq!(alive_count(), 0);
    }

    #[test]
    fn test_multithread_drop() {
        reset_alive_count();
        {
            let irc = Irc::new(TestItem::new(0));
            let mut handles = vec![];
            for _ in 0..10 {
                let irc_clone = irc.clone();
                handles.push(thread::spawn(move || {
                    for _ in 0..1000 {
                        let temp = irc_clone.clone();
                        assert_eq!(temp.value.value, 0);
                    }
                }));
            }
            drop(irc);
            for handle in handles {
                handle.join().unwrap();
            }
        }
        assert_eq!(alive_count(), 0);
    }

    #[test]
    fn test_drop_all() {
        reset_alive_count();
        let irc = Irc::new(TestItem::new(0));
        let mut clones = vec![];
        for _ in 0..100 {
            clones.push(irc.clone());
        }
        assert_eq!(alive_count(), 1);
        drop(clones);
        assert_eq!(alive_count(), 1);
        drop(irc);
        assert_eq!(alive_count(), 0);
    }

    #[test]
    fn test_from_into_raw() {
        reset_alive_count();
        {
            let irc = Irc::new(TestItem::new(0));
            let irc_1 = irc.clone();
            let irc_2 = irc.clone();
            let irc1_p = irc_1.into_raw();
            let irc2_p = irc_2.into_raw();
            // `into_raw` keeps the counts of the handles it consumed.
            assert_eq!(irc.strong_count(), 3);
            assert_eq!(alive_count(), 1);
            // `from_raw` takes those counts back over without incrementing.
            let _irc1 = unsafe { Irc::from_raw(irc1_p) };
            let _irc2 = unsafe { Irc::from_raw(irc2_p) };
            assert_eq!(irc.strong_count(), 3);
            assert_eq!(alive_count(), 1);
        }
        assert_eq!(alive_count(), 0);
    }
}