hipstr/
smart.rs

1//! Limited but generic smart pointer.
2//!
3//! This module provides a smart pointer that abstracts over its "kind":
4//!
5//! - unique,
6//! - reference counted,
7//! - atomically reference counted.
8
9use alloc::boxed::Box;
10use core::cell::Cell;
11use core::mem::ManuallyDrop;
12use core::ops::Deref;
13use core::ptr::NonNull;
14#[cfg(not(loom))]
15use core::sync::atomic::{fence, AtomicUsize, Ordering};
16
17#[cfg(loom)]
18use loom::sync::atomic::{fence, AtomicUsize, Ordering};
19
20#[cfg(test)]
21mod tests;
22
/// Unique reference marker.
///
/// Zero-sized counter "kind" for pointers that are never shared: its [`Kind`]
/// impl reports a constant count of one and refuses any update.
pub struct Unique(/* nothing but not constructible either */);
25
/// Local (thread-unsafe) reference counter.
///
/// The stored value is the logical count minus one: a fresh counter stores
/// `0` and means a count of one (see the [`Kind`] impl).
pub struct Rc(Cell<usize>);
28
/// Atomic (thread-safe) reference counter.
///
/// Same minus-one encoding as [`Rc`], but updated atomically; only available
/// on targets with pointer-sized atomics.
#[cfg(target_has_atomic = "ptr")]
pub struct Arc(AtomicUsize);
32
/// Reference counting update result.
#[derive(Clone, Copy, PartialEq, Eq)]
pub enum UpdateResult {
    /// The update was successful.
    Done,
    /// No update was performed because the counter had already reached a boundary.
    Overflow,
}
41
/// Trait for a basic reference counter.
pub trait Kind {
    /// Creates a new counter that starts at one.
    fn one() -> Self;

    /// Tries to increment the counter.
    ///
    /// Returns [`UpdateResult::Overflow`] when the counter is already at its
    /// upper boundary and cannot grow (always the case for [`Unique`]).
    fn incr(&self) -> UpdateResult;

    /// Tries to decrement the counter.
    ///
    /// Returns [`UpdateResult::Overflow`] when the counter is already at its
    /// lower boundary, i.e. a count of one.
    fn decr(&self) -> UpdateResult;

    /// Returns the current value of the counter.
    fn get(&self) -> usize;

    /// Checks if the counter is at one.
    ///
    /// In case of atomics, the [`Ordering::Acquire`] semantics is expected.
    fn is_unique(&self) -> bool {
        self.get() == 1
    }
}
63
64impl Kind for Unique {
65    #[inline]
66    fn one() -> Self {
67        Self {}
68    }
69    #[inline]
70    fn incr(&self) -> UpdateResult {
71        UpdateResult::Overflow
72    }
73    #[inline]
74    fn decr(&self) -> UpdateResult {
75        UpdateResult::Overflow
76    }
77    #[inline]
78    fn get(&self) -> usize {
79        1
80    }
81}
82
83impl Kind for Rc {
84    #[inline]
85    fn one() -> Self {
86        Self(Cell::new(0))
87    }
88
89    #[inline]
90    fn incr(&self) -> UpdateResult {
91        let new = self.0.get() + 1;
92        if new < usize::MAX {
93            // usize::MAX is forbidden
94            self.0.set(new);
95            UpdateResult::Done
96        } else {
97            UpdateResult::Overflow
98        }
99    }
100
101    #[inline]
102    fn decr(&self) -> UpdateResult {
103        self.0
104            .get()
105            .checked_sub(1)
106            .map_or(UpdateResult::Overflow, |new| {
107                self.0.set(new);
108                UpdateResult::Done
109            })
110    }
111
112    #[inline]
113    fn get(&self) -> usize {
114        // the count is strictly less than `usize::MAX`
115        self.0.get() + 1
116    }
117}
118
#[cfg(target_has_atomic = "ptr")]
impl Kind for Arc {
    /// Starts at a stored value of zero: the stored value is the logical
    /// count minus one, so zero means a count of one.
    #[inline]
    fn one() -> Self {
        Self(AtomicUsize::new(0))
    }

    /// Decrements with `Release` ordering.
    ///
    /// When the previous stored value was zero (logical count of one), the
    /// decrement "overflows": an `Acquire` fence then synchronizes with the
    /// earlier `Release` decrements of other owners before the caller frees
    /// the allocation (same fence pattern as the standard `Arc` drop).
    #[inline]
    fn decr(&self) -> UpdateResult {
        let old_value = self.0.fetch_sub(1, Ordering::Release);
        if old_value == 0 {
            // NOTE(review): the stored value wraps around here, but on
            // `Overflow` the caller is expected to free the allocation, so
            // the wrapped counter is never read again — confirm in `Drop`.
            fence(Ordering::Acquire);
            UpdateResult::Overflow
        } else {
            UpdateResult::Done
        }
    }

    /// Increments through a CAS loop, refusing once the stored value is
    /// `usize::MAX - 1` or above (the top stored values are forbidden).
    #[inline]
    fn incr(&self) -> UpdateResult {
        let set_order = Ordering::Release;
        let fetch_order = Ordering::Relaxed;

        let atomic = &self.0;
        let mut old = atomic.load(fetch_order);
        while old < usize::MAX - 1 {
            let new = old + 1;
            // weak CAS may fail spuriously; the loop simply retries with the
            // freshly observed value
            match atomic.compare_exchange_weak(old, new, set_order, fetch_order) {
                Ok(_) => {
                    return UpdateResult::Done;
                }
                Err(next_prev) => old = next_prev,
            }
        }
        UpdateResult::Overflow
    }

    /// Returns the logical count (stored value plus one).
    #[inline]
    fn get(&self) -> usize {
        self.0.load(Ordering::Relaxed) + 1
    }

    /// A stored value of zero means a single owner; the `Acquire` fence
    /// pairs with the `Release` decrements of previous, now-gone owners.
    #[inline]
    fn is_unique(&self) -> bool {
        if self.0.load(Ordering::Relaxed) == 0 {
            fence(Ordering::Acquire);
            true
        } else {
            false
        }
    }
}
171
/// Smart pointer inner cell.
///
/// Heap-allocated pair of a reference counter and the stored value.
pub struct Inner<T, C>
where
    C: Kind,
{
    // reference counter (kind-dependent: unique, local, or atomic)
    count: C,
    // the stored value
    value: T,
}
180
181impl<T, C> Clone for Inner<T, C>
182where
183    T: Clone,
184    C: Kind,
185{
186    fn clone(&self) -> Self {
187        Self {
188            count: C::one(),
189            value: self.value.clone(),
190        }
191    }
192}
193
/// Basic smart pointer, with generic counter.
///
/// Invariant: the pointer is non-null, aligned, and refers to a live,
/// heap-allocated [`Inner`] cell for the whole lifetime of this value.
pub struct Smart<T, C>(NonNull<Inner<T, C>>)
where
    T: Clone,
    C: Kind;
199
200#[allow(unused)]
201impl<T, C> Smart<T, C>
202where
203    T: Clone,
204    C: Kind,
205{
206    /// Creates the smart pointer.
207    #[inline]
208    #[must_use]
209    pub fn new(value: T) -> Self {
210        let ptr = Box::into_raw(Box::new(Inner {
211            count: C::one(),
212            value,
213        }));
214        Self(unsafe { NonNull::new_unchecked(ptr) })
215    }
216
217    #[inline]
218    #[must_use]
219    const fn inner(&self) -> &Inner<T, C> {
220        // SAFETY: type invariant
221        unsafe { self.0.as_ref() }
222    }
223
224    /// Converts the smart pointer to a raw pointer.
225    #[inline]
226    #[must_use]
227    pub fn into_raw(self) -> NonNull<Inner<T, C>> {
228        let smart = ManuallyDrop::new(self);
229        smart.0
230    }
231
232    /// Creates a smart pointer from a raw pointer.
233    #[inline]
234    pub fn from_raw(ptr: NonNull<Inner<T, C>>) -> Self {
235        debug_assert!(ptr.is_aligned());
236        unsafe { Self(ptr) }
237    }
238
239    /// Gets a reference to the value.
240    #[inline]
241    #[must_use]
242    pub const fn as_ref(&self) -> &T {
243        &self.inner().value
244    }
245
246    /// Checks if this reference is unique.
247    #[inline]
248    #[must_use]
249    pub fn is_unique(&self) -> bool {
250        self.inner().count.is_unique()
251    }
252
253    /// Gets a mutable reference to the value
254    #[inline]
255    #[must_use]
256    pub fn as_mut(&mut self) -> Option<&mut T> {
257        // SAFETY: type invariant, the raw pointer is valid
258        if self.is_unique() {
259            // SAFETY: uniqueness checked
260            Some(unsafe { &mut self.0.as_mut().value })
261        } else {
262            None
263        }
264    }
265
266    /// Gets a mutable reference to the value without checking the uniqueness
267    ///
268    /// # Safety
269    ///
270    /// Any caller should check the uniqueness first with [`Self::is_unique`].
271    #[inline]
272    pub const unsafe fn as_mut_unchecked(&mut self) -> &mut T {
273        // SAFETY: uniqueness precondition
274        unsafe { &mut self.0.as_mut().value }
275    }
276
277    /// Gets a mutable reference to the value without checking the uniqueness
278    ///
279    /// # Safety
280    ///
281    /// - Any caller should check the uniqueness first with [`Self::is_unique`].
282    /// - The referenced value must outlive `'a`.
283    #[inline]
284    pub(crate) const unsafe fn as_mut_unchecked_extended<'a>(&mut self) -> &'a mut T
285    where
286        Self: 'a,
287    {
288        // SAFETY: uniqueness precondition
289        unsafe { &mut self.0.as_mut().value }
290    }
291
292    /// Gets the reference count.
293    #[inline]
294    #[must_use]
295    pub(crate) fn ref_count(&self) -> usize {
296        // SAFETY: type invariant, the raw pointer cannot be dangling
297        let inner = unsafe { self.0.as_ref() };
298        inner.count.get()
299    }
300
301    /// Try to unwrap to its inner value.
302    #[inline]
303    pub fn try_unwrap(self) -> Result<T, Self> {
304        unsafe {
305            if self.is_unique() {
306                // do not drop `self`!
307                let this = ManuallyDrop::new(self);
308                // SAFETY: type invariant, pointer must be valid
309                let inner = unsafe { Box::from_raw(this.0.as_ptr()) };
310                Ok(inner.value)
311            } else {
312                Err(self)
313            }
314        }
315    }
316
317    pub(crate) fn incr(&self) -> UpdateResult {
318        self.inner().count.incr()
319    }
320}
321
322impl<T, C> Clone for Smart<T, C>
323where
324    T: Clone,
325    C: Kind,
326{
327    fn clone(&self) -> Self {
328        if unsafe { &(*self.0.as_ptr()).count }.incr() == UpdateResult::Done {
329            Self(self.0)
330        } else {
331            let inner = self.inner().clone();
332            let ptr = Box::into_raw(Box::new(inner));
333            // SAFETY: duh
334            let nonnull = unsafe { NonNull::new_unchecked(ptr) };
335            Self(nonnull)
336        }
337    }
338}
339
340impl<T, C> Drop for Smart<T, C>
341where
342    T: Clone,
343    C: Kind,
344{
345    fn drop(&mut self) {
346        // SAFETY: type invariant, cannot be dangling
347        unsafe {
348            if self.inner().count.decr() == UpdateResult::Overflow {
349                let ptr = self.0.as_ptr();
350                let _ = Box::from_raw(ptr);
351            }
352        }
353    }
354}
355
356impl<T, C> Deref for Smart<T, C>
357where
358    T: Clone,
359    C: Kind,
360{
361    type Target = T;
362
363    #[inline]
364    fn deref(&self) -> &Self::Target {
365        self.as_ref()
366    }
367}
368
// SAFETY: the bounds mirror those of `std::sync::Arc`: sending the pointer to
// another thread can give that thread both shared (`&T`) and, when unique,
// exclusive access to the value, hence `T: Sync + Send`; the counter itself
// must also be sendable.
unsafe impl<T, C> Send for Smart<T, C>
where
    T: Sync + Send + Clone,
    C: Send + Kind,
{
}
375
// SAFETY: the bounds mirror those of `std::sync::Arc`: sharing `&Smart`
// across threads allows concurrent reads of the value and concurrent counter
// updates (via `clone`), hence `T: Sync` and `C: Sync`; clones can then be
// sent to other threads, requiring `T: Send` as well.
unsafe impl<T, C> Sync for Smart<T, C>
where
    T: Sync + Send + Clone,
    C: Sync + Kind,
{
}