Skip to main content

embed_collections/
irc.rs

1//! Intrusive Reference Counter (Irc)
2//!
3//! `Irc` is an intrusive reference counting smart pointer, similar to `Arc` but without weak reference support.
//! It requires the inner type to implement the [IrcItem] trait, which provides the counter field.
5//!
//! Internally an `Irc` is backed by a `Box`. Unlike `Arc`, which wraps your value in a hidden
//! `ArcInner` allocation, `Irc` points directly at the memory of your own value.
8//!
//! [IrcItem::on_drop] lets you take back ownership of the underlying allocation (as a `Box`) once the last `Irc` reference has been dropped.
10//!
11//! # Example
12//!
13//! ```rust
14//! use embed_collections::irc::{Irc, IrcItem};
15//! use core::sync::atomic::{AtomicUsize, AtomicBool, Ordering};
16//! use crossfire::oneshot;
17//! use std::thread;
18//! use std::time::Duration;
19//!
20//! // Usually we use Irc for some large structure, but we show a simple demo here.
21//! struct MyItem {
22//!     is_done: AtomicBool,
23//!     counter: AtomicUsize,
24//!     done_tx: Option<oneshot::TxOneshot<Box<MyItem>>>,
25//! }
26//!
27//! unsafe impl IrcItem<()> for MyItem {
28//!     type Counter = AtomicUsize;
29//!     fn counter(&self) -> &Self::Counter {
30//!         &self.counter
31//!     }
32//!
33//!     // overwrite default behavior to send the item through channel
34//!     fn on_drop(mut this: Box<Self>) {
35//!         let done_tx = this.done_tx.take().unwrap();
36//!         done_tx.send(this);
37//!     }
38//! }
39//!
40//! let (done_tx, done_rx) = oneshot::oneshot();
41//! let boxed_item = Box::new(MyItem {
42//!     is_done: AtomicBool::new(false),
43//!     counter: AtomicUsize::new(0),
44//!     done_tx: Some(done_tx),
45//! });
46//!
47//! // Convert from Box to Irc, which does not have additional allocation.
48//! let item = Irc::<_, ()>::from(boxed_item);
49//! thread::spawn(move || {
50//!     thread::sleep(Duration::from_secs(1));
51//!     item.is_done.store(true, Ordering::SeqCst);
52//!     drop(item);
53//! });
54//! let item: Box<MyItem> = done_rx.recv().unwrap();
55//! assert!(item.is_done.load(Ordering::SeqCst));
56//! ```
57
58use crate::{Pointer, SmartPointer};
59use alloc::boxed::Box;
60use atomic_traits::{
61    Atomic, NumOps,
62    fetch::{Add, Sub},
63};
64use core::fmt;
65use core::marker::PhantomData;
66use core::ops::Deref;
67use core::ptr::NonNull;
68use core::sync::atomic::{
69    Ordering::{Acquire, Relaxed, Release},
70    fence,
71};
72
73/// trait for types that can be wrapped by [Irc]
74///
75/// # Safety
76///
77/// Tag is for distinguish multiple Irc from the same Inner type.
78/// When implement multiple types of Irc from the same object,
79/// you must make sure they don't have overlapped Counter fields.
80pub unsafe trait IrcItem<Tag>: Sized + Send + Sync
81where
82    <Self::Counter as Atomic>::Type: From<u8> + Into<usize> + PartialEq,
83{
84    /// The type of counter
85    type Counter: NumOps;
86
87    /// return reference to the field of counter
88    fn counter(&self) -> &Self::Counter;
89
90    /// The default behavior for Irc is dropping the boxed inner.
91    ///
92    /// You can overwrite this if you want to send the inner somewhere.
93    /// We pass box here to reduce moving cost.
94    #[allow(clippy::boxed_local)]
95    #[inline(always)]
96    fn on_drop(_this: Box<Self>) {}
97
98    #[inline]
99    fn strong_count(&self) -> usize {
100        self.counter().load(Relaxed).into()
101    }
102}
103
104/// Intrusive reference counter, which support conversion bwteween `Box<T>`.
105///
106/// It does not support weak reference.
107pub struct Irc<T: IrcItem<Tag>, Tag> {
108    inner: NonNull<T>,
109    _phan: PhantomData<fn(&Tag)>,
110}
111
112impl<T: IrcItem<Tag>, Tag> Irc<T, Tag> {
113    /// Wrap a stack value T into Irc.
114    ///
115    /// The counter will be reset to 1 on initialization.
116    #[inline]
117    pub fn new(inner: T) -> Self {
118        inner.counter().store(1u8.into(), Relaxed);
119        Self {
120            inner: unsafe { NonNull::new_unchecked(Box::into_raw(Box::new(inner))) },
121            _phan: Default::default(),
122        }
123    }
124
125    #[inline(always)]
126    fn get_inner(&self) -> &T {
127        unsafe { self.inner.as_ref() }
128    }
129
130    #[inline]
131    pub fn ptr_eq(this: &Self, other: &Self) -> bool {
132        this.inner == other.inner
133    }
134
135    /// If is_unique returns true, then this thread is the only owner
136    ///
137    /// # False negative
138    ///
139    /// it's possible to return false when counter drop to 1,
140    /// Because of using Acquire load and Release on drop.
141    ///
142    /// # Example
143    ///
144    ///
145    /// ```rust
146    /// use embed_collections::irc::{Irc, IrcItem};
147    /// use core::sync::atomic::AtomicUsize;
148    ///
149    /// struct Tag;
150    ///
151    /// struct MyItem {
152    ///     value: i32,
153    ///     counter: AtomicUsize,
154    /// }
155    ///
156    /// unsafe impl IrcItem<Tag> for MyItem {
157    ///     type Counter = AtomicUsize;
158    ///     fn counter(&self) -> &Self::Counter {
159    ///         &self.counter
160    ///     }
161    /// }
162    ///
163    /// // Create a new Irc
164    /// let irc1 = Irc::<_, Tag>::new(MyItem { value: 10, counter: AtomicUsize::new(0) });
165    /// assert_eq!(irc1.value, 10);
166    /// assert!(irc1.is_unique());
167    ///
168    /// // Clone the Irc
169    /// let irc2 = irc1.clone();
170    /// assert_eq!(irc1.strong_count(), 2);
171    /// assert!(!irc1.is_unique());
172    #[inline]
173    pub fn is_unique(&self) -> bool {
174        // Safety:
175        // we have make sure counter reset to 1 on init.
176        // although clone use Relaxed, it can never pass this fence
177        self.counter().load(Acquire) == 1u8.into()
178    }
179
180    /// return mutable reference if we are the only owner
181    ///
182    /// # False negative
183    ///
184    /// It can return None even when only one reference left
185    #[inline]
186    pub fn get_mut(this: &mut Self) -> Option<&mut T> {
187        if this.is_unique() { Some(unsafe { this.inner.as_mut() }) } else { None }
188    }
189}
190
191impl<T: IrcItem<Tag> + Clone, Tag> Irc<T, Tag> {
192    /// The Cow function, the same as `Arc::make_mut()`
193    ///
194    /// # Example
195    ///
196    /// ```rust
197    /// use embed_collections::irc::{Irc, IrcItem};
198    /// use core::sync::atomic::AtomicUsize;
199    ///
200    /// struct Tag;
201    /// struct MyItem {
202    ///     value: i32,
203    ///     counter: AtomicUsize,
204    /// }
205    ///
206    /// impl Clone for MyItem {
207    ///     fn clone(&self) -> Self {
208    ///         Self { value: self.value, counter: AtomicUsize::new(0) }
209    ///     }
210    /// }
211    ///
212    /// unsafe impl IrcItem<Tag> for MyItem {
213    ///     type Counter = AtomicUsize;
214    ///     fn counter(&self) -> &Self::Counter {
215    ///         &self.counter
216    ///     }
217    /// }
218    ///
219    /// let mut irc1 = Irc::<_, Tag>::new(MyItem { value: 10, counter: AtomicUsize::new(0) });
220    /// let irc2 = irc1.clone();
221    ///
222    /// // This will clone the inner item because it's shared
223    /// let m = Irc::make_mut(&mut irc1);
224    /// m.value = 20;
225    ///
226    /// assert_eq!(irc1.value, 20);
227    /// assert_eq!(irc2.value, 10);
228    /// ```
229    #[inline]
230    pub fn make_mut(this: &mut Self) -> &mut T {
231        if !this.is_unique() {
232            let cloned_item = this.get_inner().clone();
233            let mut new_irc = Self::new(cloned_item);
234            core::mem::swap(this, &mut new_irc);
235        }
236        unsafe { this.inner.as_mut() }
237    }
238}
239
240impl<T: IrcItem<Tag>, Tag> Deref for Irc<T, Tag> {
241    type Target = T;
242    #[inline(always)]
243    fn deref(&self) -> &Self::Target {
244        self.get_inner()
245    }
246}
247
248impl<T: IrcItem<Tag>, Tag> AsRef<T> for Irc<T, Tag> {
249    #[inline(always)]
250    fn as_ref(&self) -> &T {
251        self.get_inner()
252    }
253}
254
255unsafe impl<T: IrcItem<Tag>, Tag> Send for Irc<T, Tag> {}
256unsafe impl<T: IrcItem<Tag>, Tag> Sync for Irc<T, Tag> {}
257
258impl<T: IrcItem<Tag>, Tag> From<Box<T>> for Irc<T, Tag> {
259    /// Convert a boxed T into Irc.
260    ///
261    /// The counter will be reset to 1 on initialization.
262    #[inline]
263    fn from(inner: Box<T>) -> Self {
264        inner.counter().store(1u8.into(), Relaxed);
265        Self {
266            inner: unsafe { NonNull::new_unchecked(Box::into_raw(inner)) },
267            _phan: Default::default(),
268        }
269    }
270}
271
272impl<T: IrcItem<Tag>, Tag> Clone for Irc<T, Tag> {
273    #[inline]
274    fn clone(&self) -> Self {
275        self.get_inner().counter().fetch_add(1u8.into(), Relaxed);
276        Self { inner: self.inner, _phan: Default::default() }
277    }
278}
279
280impl<T: IrcItem<Tag>, Tag> Drop for Irc<T, Tag> {
281    #[inline]
282    fn drop(&mut self) {
283        let p = self.inner.as_ptr();
284        unsafe {
285            if (*p).counter().fetch_sub(1u8.into(), Release) == 1u8.into() {
286                fence(Acquire);
287                let inner = Box::from_raw(p);
288                IrcItem::<Tag>::on_drop(inner);
289            }
290        }
291    }
292}
293
294impl<T: IrcItem<Tag> + fmt::Debug, Tag> fmt::Debug for Irc<T, Tag> {
295    #[inline]
296    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
297        self.get_inner().fmt(f)
298    }
299}
300
301impl<T: IrcItem<Tag> + fmt::Display, Tag> fmt::Display for Irc<T, Tag> {
302    #[inline]
303    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
304        self.get_inner().fmt(f)
305    }
306}
307
308impl<T: IrcItem<Tag>, Tag> Pointer for Irc<T, Tag> {
309    type Target = T;
310
311    #[inline]
312    fn as_ref(&self) -> &Self::Target {
313        unsafe { self.inner.as_ref() }
314    }
315
316    /// # Safety
317    ///
318    /// must be pointer acquire from [Irc::into_raw()]
319    #[inline]
320    unsafe fn from_raw(p: *const Self::Target) -> Self {
321        Self {
322            inner: unsafe { NonNull::new_unchecked(p as *mut Self::Target) },
323            _phan: Default::default(),
324        }
325    }
326
327    #[inline]
328    fn into_raw(self) -> *const Self::Target {
329        let p = self.inner.as_ptr();
330        core::mem::forget(self);
331        p
332    }
333}
334
335impl<T: IrcItem<Tag>, Tag> SmartPointer for Irc<T, Tag> {
336    #[inline]
337    fn new(inner: T) -> Self {
338        Irc::new(inner)
339    }
340}
341
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test::{CounterI32, alive_count, reset_alive_count};
    use core::sync::atomic::AtomicUsize;
    use std::thread;

    struct Tag;

    struct TestItem {
        value: CounterI32,
        counter: AtomicUsize,
    }

    impl TestItem {
        fn new(val: i32) -> Self {
            Self { value: CounterI32::new(val), counter: AtomicUsize::new(0) }
        }
    }

    impl Clone for TestItem {
        fn clone(&self) -> Self {
            // A clone lives in a separate allocation, so it starts with its own counter.
            Self { value: self.value.clone(), counter: AtomicUsize::new(0) }
        }
    }

    unsafe impl IrcItem<Tag> for TestItem {
        type Counter = AtomicUsize;
        fn counter(&self) -> &Self::Counter {
            &self.counter
        }
    }

    #[test]
    fn test_basic() {
        reset_alive_count();
        {
            let item = TestItem::new(10);
            let irc1 = Irc::<_, Tag>::new(item);
            assert_eq!(irc1.value.value, 10);
            assert_eq!(irc1.strong_count(), 1);
            assert!(irc1.is_unique());
            assert_eq!(alive_count(), 1);

            let irc2 = irc1.clone();
            assert_eq!(irc1.strong_count(), 2);
            assert_eq!(irc2.strong_count(), 2);
            assert!(!irc1.is_unique());
            assert_eq!(alive_count(), 1);

            drop(irc1);
            assert_eq!(irc2.strong_count(), 1);
            assert!(irc2.is_unique());
            assert_eq!(alive_count(), 1);
        }
        assert_eq!(alive_count(), 0);
    }

    #[test]
    fn test_get_mut() {
        reset_alive_count();
        let mut irc = Irc::<_, Tag>::new(TestItem::new(10));
        Irc::get_mut(&mut irc).expect("a fresh Irc is unique").value.value = 11;
        assert_eq!(irc.value.value, 11);

        let irc2 = irc.clone();
        assert!(Irc::get_mut(&mut irc).is_none());

        // Once the other handle is gone, mutable access is available again.
        drop(irc2);
        assert!(Irc::get_mut(&mut irc).is_some());

        drop(irc);
        assert_eq!(alive_count(), 0);
    }

    #[test]
    fn test_make_mut() {
        reset_alive_count();
        let mut irc = Irc::<_, Tag>::new(TestItem::new(10));

        // Unique, no clone
        {
            let m = Irc::make_mut(&mut irc);
            m.value.value = 20;
        }
        assert_eq!(irc.value.value, 20);
        assert_eq!(alive_count(), 1);

        // Not unique, should clone
        let irc2 = irc.clone();
        assert_eq!(alive_count(), 1);
        {
            let m = Irc::make_mut(&mut irc);
            m.value.value = 30;
        }
        assert_eq!(irc.value.value, 30);
        assert_eq!(irc2.value.value, 20);
        assert_eq!(alive_count(), 2);

        assert!(irc.is_unique());
        assert!(irc2.is_unique());

        // Both allocations are independent and must each be freed.
        drop(irc);
        assert_eq!(alive_count(), 1);
        drop(irc2);
        assert_eq!(alive_count(), 0);
    }

    #[test]
    fn test_multithread_count() {
        reset_alive_count();
        {
            let irc = Irc::<_, Tag>::new(TestItem::new(0));
            let mut handles = vec![];

            for _ in 0..10 {
                let irc_clone = irc.clone();
                handles.push(thread::spawn(move || {
                    for _ in 0..1000 {
                        let temp = irc_clone.clone();
                        assert_eq!(temp.value.value, 0);
                    }
                }));
            }

            for handle in handles {
                handle.join().unwrap();
            }

            assert_eq!(irc.strong_count(), 1);
            assert!(irc.is_unique());
            assert_eq!(alive_count(), 1);
        }
        assert_eq!(alive_count(), 0);
    }

    #[test]
    fn test_multithread_drop() {
        reset_alive_count();
        {
            let irc = Irc::<_, Tag>::new(TestItem::new(0));
            let mut handles = vec![];
            for _ in 0..10 {
                let irc_clone = irc.clone();
                handles.push(thread::spawn(move || {
                    for _ in 0..1000 {
                        let temp = irc_clone.clone();
                        assert_eq!(temp.value.value, 0);
                    }
                }));
            }
            drop(irc);
            for handle in handles {
                handle.join().unwrap();
            }
        }
        assert_eq!(alive_count(), 0);
    }

    #[test]
    fn test_drop_all() {
        reset_alive_count();
        let irc = Irc::<_, Tag>::new(TestItem::new(0));
        let mut clones = vec![];
        for _ in 0..100 {
            clones.push(irc.clone());
        }
        assert_eq!(alive_count(), 1);
        drop(clones);
        assert_eq!(alive_count(), 1);
        drop(irc);
        assert_eq!(alive_count(), 0);
    }

    #[test]
    fn test_from_box_resets_counter() {
        reset_alive_count();
        let boxed = Box::new(TestItem::new(5));
        // Garbage left in the counter must not leak into the new Irc.
        boxed.counter.store(42, core::sync::atomic::Ordering::Relaxed);
        let irc = Irc::<_, Tag>::from(boxed);
        assert_eq!(irc.strong_count(), 1);
        assert!(irc.is_unique());
        assert_eq!(irc.value.value, 5);
        drop(irc);
        assert_eq!(alive_count(), 0);
    }

    #[test]
    fn test_custom_on_drop() {
        use core::sync::atomic::{AtomicBool, Ordering};

        static RECLAIMED: AtomicBool = AtomicBool::new(false);

        struct Reclaim {
            counter: AtomicUsize,
        }

        unsafe impl IrcItem<Tag> for Reclaim {
            type Counter = AtomicUsize;
            fn counter(&self) -> &Self::Counter {
                &self.counter
            }
            fn on_drop(this: Box<Self>) {
                // Only the last handle may reclaim the value, with the count at zero.
                assert_eq!(this.counter.load(Ordering::Relaxed), 0);
                assert!(!RECLAIMED.swap(true, Ordering::SeqCst), "on_drop called twice");
            }
        }

        let a = Irc::<_, Tag>::new(Reclaim { counter: AtomicUsize::new(0) });
        let b = a.clone();
        drop(a);
        assert!(!RECLAIMED.load(Ordering::SeqCst));
        drop(b);
        assert!(RECLAIMED.load(Ordering::SeqCst));
    }
}