// heapless/pool/arc.rs
//! `std::sync::Arc`-like API on top of a lock-free memory pool
//!
//! # Example usage
//!
//! ```
//! use core::ptr::addr_of_mut;
//! use heapless::{arc_pool, pool::arc::{Arc, ArcBlock}};
//!
//! arc_pool!(MyArcPool: u128);
//!
//! // cannot allocate without first giving memory blocks to the pool
//! assert!(MyArcPool.alloc(42).is_err());
//!
//! // (some `no_std` runtimes have safe APIs to create `&'static mut` references)
//! let block: &'static mut ArcBlock<u128> = unsafe {
//!     static mut BLOCK: ArcBlock<u128> = ArcBlock::new();
//!     addr_of_mut!(BLOCK).as_mut().unwrap()
//! };
//!
//! MyArcPool.manage(block);
//!
//! let arc = MyArcPool.alloc(1).unwrap();
//!
//! // number of smart pointers is limited to the number of blocks managed by the pool
//! let res = MyArcPool.alloc(2);
//! assert!(res.is_err());
//!
//! // but cloning does not consume an `ArcBlock`
//! let arc2 = arc.clone();
//!
//! assert_eq!(1, *arc2);
//!
//! // `arc`'s destructor returns the memory block to the pool
//! drop(arc2); // decrease reference counter
//! drop(arc); // release memory
//!
//! // it's now possible to allocate a new `Arc` smart pointer
//! let res = MyArcPool.alloc(3);
//!
//! assert!(res.is_ok());
//! ```
//!
//! # Array block initialization
//!
//! You can create a static variable that contains an array of memory blocks and give all the blocks
//! to the `ArcPool`. This requires an intermediate `const` value as shown below:
//!
//! ```
//! use core::ptr::addr_of_mut;
//! use heapless::{arc_pool, pool::arc::ArcBlock};
//!
//! arc_pool!(MyArcPool: u128);
//!
//! const POOL_CAPACITY: usize = 8;
//!
//! let blocks: &'static mut [ArcBlock<u128>] = {
//!     const BLOCK: ArcBlock<u128> = ArcBlock::new(); // <=
//!     static mut BLOCKS: [ArcBlock<u128>; POOL_CAPACITY] = [BLOCK; POOL_CAPACITY];
//!     unsafe { addr_of_mut!(BLOCKS).as_mut().unwrap() }
//! };
//!
//! for block in blocks {
//!     MyArcPool.manage(block);
//! }
//! ```

// reference counting logic is based on version 1.63.0 of the Rust standard library (`alloc` crate)
// which is licensed under 'MIT or APACHE-2.0'
// https://github.com/rust-lang/rust/blob/1.63.0/library/alloc/src/sync.rs#L235 (last visited
// 2022-09-05)

use core::{
    fmt,
    hash::{Hash, Hasher},
    mem::MaybeUninit,
    ops, ptr,
};

// With the `portable-atomic` feature enabled, atomics come from the
// `portable-atomic` crate (polyfills targets without native atomic CAS);
// otherwise the `core` atomics are used under the same `atomic::` alias.
#[cfg(not(feature = "portable-atomic"))]
use core::sync::atomic;
#[cfg(feature = "portable-atomic")]
use portable_atomic as atomic;

use atomic::{AtomicUsize, Ordering};

use super::treiber::{NonNullPtr, Stack, UnionNode};

/// Creates a new `ArcPool` singleton with the given `$name` that manages the specified `$data_type`
///
/// For more extensive documentation see the [module level documentation](crate::pool::arc)
#[macro_export]
macro_rules! arc_pool {
    ($name:ident: $data_type:ty) => {
        // The generated unit struct doubles as the pool's handle and as the
        // type parameter `P` of the `Arc<P>` smart pointers it produces.
        pub struct $name;

        impl $crate::pool::arc::ArcPool for $name {
            type Data = $data_type;

            fn singleton() -> &'static $crate::pool::arc::ArcPoolImpl<$data_type> {
                // Even though the static variable is not exposed to user code, it is
                // still useful to have a descriptive symbol name for debugging.
                #[allow(non_upper_case_globals)]
                static $name: $crate::pool::arc::ArcPoolImpl<$data_type> =
                    $crate::pool::arc::ArcPoolImpl::new();

                &$name
            }
        }

        // Inherent forwarders so callers can write `MyPool.alloc(..)` without
        // importing the `ArcPool` trait.
        impl $name {
            /// Inherent method version of `ArcPool::alloc`
            #[allow(dead_code)]
            pub fn alloc(
                &self,
                value: $data_type,
            ) -> Result<$crate::pool::arc::Arc<$name>, $data_type> {
                <$name as $crate::pool::arc::ArcPool>::alloc(value)
            }

            /// Inherent method version of `ArcPool::manage`
            #[allow(dead_code)]
            pub fn manage(&self, block: &'static mut $crate::pool::arc::ArcBlock<$data_type>) {
                <$name as $crate::pool::arc::ArcPool>::manage(block)
            }
        }
    };
}

129/// A singleton that manages `pool::arc::Arc` smart pointers
130pub trait ArcPool: Sized {
131    /// The data type managed by the memory pool
132    type Data: 'static;
133
134    /// `arc_pool!` implementation detail
135    #[doc(hidden)]
136    fn singleton() -> &'static ArcPoolImpl<Self::Data>;
137
138    /// Allocate a new `Arc` smart pointer initialized to the given `value`
139    ///
140    /// `manage` should be called at least once before calling `alloc`
141    ///
142    /// # Errors
143    ///
144    /// The `Err`or variant is returned when the memory pool has run out of memory blocks
145    fn alloc(value: Self::Data) -> Result<Arc<Self>, Self::Data> {
146        Ok(Arc {
147            node_ptr: Self::singleton().alloc(value)?,
148        })
149    }
150
151    /// Add a statically allocated memory block to the memory pool
152    fn manage(block: &'static mut ArcBlock<Self::Data>) {
153        Self::singleton().manage(block);
154    }
155}
156
/// `arc_pool!` implementation detail
// newtype to avoid having to make field types public
#[doc(hidden)]
pub struct ArcPoolImpl<T> {
    // lock-free intrusive stack holding the currently-free memory blocks
    stack: Stack<UnionNode<MaybeUninit<ArcInner<T>>>>,
}

164impl<T> ArcPoolImpl<T> {
165    /// `arc_pool!` implementation detail
166    #[doc(hidden)]
167    #[allow(clippy::new_without_default)]
168    pub const fn new() -> Self {
169        Self {
170            stack: Stack::new(),
171        }
172    }
173
174    fn alloc(&self, value: T) -> Result<NonNullPtr<UnionNode<MaybeUninit<ArcInner<T>>>>, T> {
175        if let Some(node_ptr) = self.stack.try_pop() {
176            let inner = ArcInner {
177                data: value,
178                strong: AtomicUsize::new(1),
179            };
180            unsafe { node_ptr.as_ptr().cast::<ArcInner<T>>().write(inner) }
181
182            Ok(node_ptr)
183        } else {
184            Err(value)
185        }
186    }
187
188    fn manage(&self, block: &'static mut ArcBlock<T>) {
189        let node: &'static mut _ = &mut block.node;
190
191        // SAFETY: The node within an `ArcBlock` is always properly initialized for linking because
192        // the only way for         client code to construct an `ArcBlock` is through
193        // `ArcBlock::new`. The `NonNullPtr` comes from a         reference, so it is
194        // guaranteed to be dereferencable. It is also unique because the `ArcBlock` itself
195        //         is passed as a `&mut`
196        unsafe { self.stack.push(NonNullPtr::from_static_mut_ref(node)) }
197    }
198}
199
// SAFETY: the only field is the treiber `Stack`, whose `push`/`try_pop` are lock-free
// operations taking `&self`, so sharing the pool across threads is sound.
// NOTE(review): there is no `T: Send` bound here; cross-thread transfer of `T` itself
// appears to be gated by `Arc<P>`'s own `Send`/`Sync` bounds below — confirm.
unsafe impl<T> Sync for ArcPoolImpl<T> {}

/// Like `std::sync::Arc` but managed by memory pool `P`
pub struct Arc<P>
where
    P: ArcPool,
{
    // Points at a pool node whose storage `ArcPoolImpl::alloc` initialized as
    // an `ArcInner<P::Data>` (payload + strong count).
    node_ptr: NonNullPtr<UnionNode<MaybeUninit<ArcInner<P::Data>>>>,
}

impl<P> Arc<P>
where
    P: ArcPool,
{
    /// Returns a shared reference to the reference-counted header (payload + strong count).
    fn inner(&self) -> &ArcInner<P::Data> {
        // SAFETY: `node_ptr` always points at a live `ArcInner` that
        // `ArcPoolImpl::alloc` initialized before handing the pointer out.
        unsafe { &*self.node_ptr.as_ptr().cast::<ArcInner<P::Data>>() }
    }

    /// Wraps an already-initialized node pointer in the smart-pointer type.
    fn from_inner(node_ptr: NonNullPtr<UnionNode<MaybeUninit<ArcInner<P::Data>>>>) -> Self {
        Self { node_ptr }
    }

    /// Returns a unique reference to the payload.
    ///
    /// # Safety
    ///
    /// The caller must guarantee no other reference to the payload exists,
    /// e.g. because `this` is the last `Arc` pointing at it.
    unsafe fn get_mut_unchecked(this: &mut Self) -> &mut P::Data {
        // `addr_of_mut!` projects straight to the `data` field without
        // materializing a reference to the whole `ArcInner`.
        &mut *ptr::addr_of_mut!((*this.node_ptr.as_ptr().cast::<ArcInner<P::Data>>()).data)
    }

    // Cold path of `Drop`: only taken when the strong count reaches zero.
    // `inline(never)` keeps it out of every (hot) `drop` call site.
    #[inline(never)]
    unsafe fn drop_slow(&mut self) {
        // run `P::Data`'s destructor
        ptr::drop_in_place(Self::get_mut_unchecked(self));

        // return memory to pool
        P::singleton().stack.push(self.node_ptr);
    }
}

236impl<P> AsRef<P::Data> for Arc<P>
237where
238    P: ArcPool,
239{
240    fn as_ref(&self) -> &P::Data {
241        self
242    }
243}
244
// Soft cap on the strong count (same value as `std::sync::Arc` uses); exceeding it is
// only possible via `mem::forget`-style leaks and makes `Clone` panic to avoid overflow.
const MAX_REFCOUNT: usize = (isize::MAX) as usize;

impl<P> Clone for Arc<P>
where
    P: ArcPool,
{
    fn clone(&self) -> Self {
        // Relaxed suffices here: a new reference can only be created from an
        // existing one, so another thread cannot observe the count "too low"
        // (same argument as in `std::sync::Arc::clone`).
        let old_size = self.inner().strong.fetch_add(1, Ordering::Relaxed);

        // Guard against reference-count overflow (e.g. `mem::forget` in a loop).
        if old_size > MAX_REFCOUNT {
            // XXX original code calls `intrinsics::abort` which is unstable API
            panic!();
        }

        Self::from_inner(self.node_ptr)
    }
}

263impl<A> fmt::Debug for Arc<A>
264where
265    A: ArcPool,
266    A::Data: fmt::Debug,
267{
268    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
269        A::Data::fmt(self, f)
270    }
271}
272
impl<P> ops::Deref for Arc<P>
where
    P: ArcPool,
{
    type Target = P::Data;

    fn deref(&self) -> &Self::Target {
        // SAFETY: `node_ptr` points at an `ArcInner` initialized by
        // `ArcPoolImpl::alloc`; `addr_of!` projects to the `data` field
        // without creating an intermediate reference to the whole node.
        unsafe { &*ptr::addr_of!((*self.node_ptr.as_ptr().cast::<ArcInner<P::Data>>()).data) }
    }
}

284impl<A> fmt::Display for Arc<A>
285where
286    A: ArcPool,
287    A::Data: fmt::Display,
288{
289    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
290        A::Data::fmt(self, f)
291    }
292}
293
impl<A> Drop for Arc<A>
where
    A: ArcPool,
{
    fn drop(&mut self) {
        // `Release` publishes all writes made through this reference to the
        // thread that ends up destroying the payload (paired with the fence
        // below). If other references remain, there is nothing more to do.
        if self.inner().strong.fetch_sub(1, Ordering::Release) != 1 {
            return;
        }

        // Synchronizes with the `Release` decrements of every other dropped
        // reference, so the destructor sees all of their writes.
        atomic::fence(Ordering::Acquire);

        // SAFETY: the strong count just reached zero, so this is the sole
        // remaining reference to the allocation.
        unsafe { self.drop_slow() }
    }
}

// Marker impl: comparison delegates to `A::Data` (via `PartialEq` below), which
// guarantees a total equivalence relation when `A::Data: Eq`.
impl<A> Eq for Arc<A>
where
    A: ArcPool,
    A::Data: Eq,
{
}

316impl<A> Hash for Arc<A>
317where
318    A: ArcPool,
319    A::Data: Hash,
320{
321    fn hash<H>(&self, state: &mut H)
322    where
323        H: Hasher,
324    {
325        (**self).hash(state);
326    }
327}
328
329impl<A> Ord for Arc<A>
330where
331    A: ArcPool,
332    A::Data: Ord,
333{
334    fn cmp(&self, other: &Self) -> core::cmp::Ordering {
335        A::Data::cmp(self, other)
336    }
337}
338
339impl<A, B> PartialEq<Arc<B>> for Arc<A>
340where
341    A: ArcPool,
342    B: ArcPool,
343    A::Data: PartialEq<B::Data>,
344{
345    fn eq(&self, other: &Arc<B>) -> bool {
346        A::Data::eq(self, &**other)
347    }
348}
349
350impl<A, B> PartialOrd<Arc<B>> for Arc<A>
351where
352    A: ArcPool,
353    B: ArcPool,
354    A::Data: PartialOrd<B::Data>,
355{
356    fn partial_cmp(&self, other: &Arc<B>) -> Option<core::cmp::Ordering> {
357        A::Data::partial_cmp(self, &**other)
358    }
359}
360
// SAFETY: same bounds as `std::sync::Arc`. Sending an `Arc<A>` to another thread
// can share `&A::Data` across threads (requires `Sync`) and, if it is the last
// reference, run `A::Data`'s destructor there (requires `Send`).
unsafe impl<A> Send for Arc<A>
where
    A: ArcPool,
    A::Data: Sync + Send,
{
}

// SAFETY: `&Arc<A>` allows cloning (creating a new owner on another thread) as
// well as shared access to `A::Data`, so the same `Send + Sync` bounds apply.
unsafe impl<A> Sync for Arc<A>
where
    A: ArcPool,
    A::Data: Sync + Send,
{
}

// The payload lives in the pool behind a stable pointer; moving the `Arc`
// handle never moves `A::Data`, so the handle is unconditionally `Unpin`.
impl<A> Unpin for Arc<A> where A: ArcPool {}

// Pool-side allocation header: the payload together with its reference count.
struct ArcInner<T> {
    data: T,
    // number of `Arc` handles currently pointing at this allocation
    strong: AtomicUsize,
}

/// A chunk of memory that an `ArcPool` can manage
pub struct ArcBlock<T> {
    // `UnionNode` lets the same storage act as a free-list link while the
    // block sits in the pool, and as `ArcInner<T>` storage once allocated.
    node: UnionNode<MaybeUninit<ArcInner<T>>>,
}

387impl<T> ArcBlock<T> {
388    /// Creates a new memory block
389    pub const fn new() -> Self {
390        Self {
391            node: UnionNode::unlinked(),
392        }
393    }
394}
395
396impl<T> Default for ArcBlock<T> {
397    fn default() -> Self {
398        Self::new()
399    }
400}
401
#[cfg(test)]
mod tests {
    use super::*;
    use std::ptr::addr_of_mut;

    #[test]
    fn cannot_alloc_if_empty() {
        arc_pool!(MyArcPool: i32);

        // no blocks were `manage`d, so the pool hands the value back
        assert_eq!(MyArcPool.alloc(42), Err(42));
    }

    #[test]
    fn can_alloc_if_manages_one_block() {
        arc_pool!(MyArcPool: i32);

        let blk = unsafe {
            static mut BLOCK: ArcBlock<i32> = ArcBlock::new();
            addr_of_mut!(BLOCK).as_mut().unwrap()
        };
        MyArcPool.manage(blk);

        assert_eq!(*MyArcPool.alloc(42).unwrap(), 42);
    }

    #[test]
    fn alloc_drop_alloc() {
        arc_pool!(MyArcPool: i32);

        let blk = unsafe {
            static mut BLOCK: ArcBlock<i32> = ArcBlock::new();
            addr_of_mut!(BLOCK).as_mut().unwrap()
        };
        MyArcPool.manage(blk);

        let first = MyArcPool.alloc(1).unwrap();

        // dropping the only handle returns the block to the pool ...
        drop(first);

        // ... which makes a second allocation possible
        assert_eq!(*MyArcPool.alloc(2).unwrap(), 2);
    }

    #[test]
    fn strong_count_starts_at_one() {
        arc_pool!(MyArcPool: i32);

        let blk = unsafe {
            static mut BLOCK: ArcBlock<i32> = ArcBlock::new();
            addr_of_mut!(BLOCK).as_mut().unwrap()
        };
        MyArcPool.manage(blk);

        let handle = MyArcPool.alloc(1).ok().unwrap();

        assert_eq!(handle.inner().strong.load(Ordering::Relaxed), 1);
    }

    #[test]
    fn clone_increases_strong_count() {
        arc_pool!(MyArcPool: i32);

        let blk = unsafe {
            static mut BLOCK: ArcBlock<i32> = ArcBlock::new();
            addr_of_mut!(BLOCK).as_mut().unwrap()
        };
        MyArcPool.manage(blk);

        let handle = MyArcPool.alloc(1).ok().unwrap();
        let count_before = handle.inner().strong.load(Ordering::Relaxed);

        let second = handle.clone();

        // both handles observe the incremented count
        assert_eq!(handle.inner().strong.load(Ordering::Relaxed), count_before + 1);
        assert_eq!(second.inner().strong.load(Ordering::Relaxed), count_before + 1);
    }

    #[test]
    fn drop_decreases_strong_count() {
        arc_pool!(MyArcPool: i32);

        let blk = unsafe {
            static mut BLOCK: ArcBlock<i32> = ArcBlock::new();
            addr_of_mut!(BLOCK).as_mut().unwrap()
        };
        MyArcPool.manage(blk);

        let handle = MyArcPool.alloc(1).ok().unwrap();
        let second = handle.clone();

        let count_before = handle.inner().strong.load(Ordering::Relaxed);

        drop(handle);

        assert_eq!(second.inner().strong.load(Ordering::Relaxed), count_before - 1);
    }

    #[test]
    fn runs_destructor_exactly_once_when_strong_count_reaches_zero() {
        static DROP_COUNT: AtomicUsize = AtomicUsize::new(0);

        pub struct MyStruct;

        impl Drop for MyStruct {
            fn drop(&mut self) {
                DROP_COUNT.fetch_add(1, Ordering::Relaxed);
            }
        }

        arc_pool!(MyArcPool: MyStruct);

        let blk = unsafe {
            static mut BLOCK: ArcBlock<MyStruct> = ArcBlock::new();
            addr_of_mut!(BLOCK).as_mut().unwrap()
        };
        MyArcPool.manage(blk);

        let handle = MyArcPool.alloc(MyStruct).ok().unwrap();

        // handle still alive: destructor must not have run yet
        assert_eq!(DROP_COUNT.load(Ordering::Relaxed), 0);

        drop(handle);

        // last handle gone: destructor ran exactly once
        assert_eq!(DROP_COUNT.load(Ordering::Relaxed), 1);
    }

    #[test]
    fn zst_is_well_aligned() {
        #[repr(align(4096))]
        pub struct Zst4096;

        arc_pool!(MyArcPool: Zst4096);

        let blk = unsafe {
            static mut BLOCK: ArcBlock<Zst4096> = ArcBlock::new();
            addr_of_mut!(BLOCK).as_mut().unwrap()
        };
        MyArcPool.manage(blk);

        let handle = MyArcPool.alloc(Zst4096).ok().unwrap();

        // the payload inside the pool block must honor the type's alignment
        let addr = std::ptr::from_ref::<Zst4096>(&*handle) as usize;
        assert_eq!(addr % 4096, 0);
    }
}