// aarc/atomics.rs

1use std::marker::PhantomData;
2use std::ptr;
3use std::ptr::{null, null_mut, NonNull};
4use std::sync::atomic::AtomicPtr;
5use std::sync::atomic::Ordering::{Relaxed, SeqCst};
6
7use fast_smr::smr::{load, protect};
8
9use crate::smart_ptrs::{Arc, AsPtr, Guard, Weak};
10use crate::StrongPtr;
11
/// An [`Arc`] with an atomically updatable pointer.
///
/// Usage notes:
/// * An `AtomicArc` can intrinsically store `None` (a hypothetical `Option<AtomicArc<T>>` would
///   no longer be atomic).
/// * An `AtomicArc` contributes to the strong count of the pointed-to allocation, if any. However,
///   it does not implement `Deref`, so methods like `load` must be used to obtain a [`Guard`]
///   through which the data can be accessed.
/// * `T` must be `Sized` for compatibility with `AtomicPtr`. This may be relaxed in the future.
/// * When an `AtomicArc` is updated or dropped, the strong count of the previously pointed-to
///   object may not be immediately decremented. Thus:
///     * `T` must be `'static` to support delayed deallocations.
///     * The value returned by `strong_count` will likely be an overestimate.
/// * `AtomicArc::default()` creates an empty (null) `AtomicArc` and does not require
///   `T: Default`.
///
/// # Examples
/// ```
/// use aarc::{Arc, AtomicArc, Guard, RefCount};
///
/// let atomic = AtomicArc::new(53);
///
/// let guard = atomic.load().unwrap(); // guard doesn't affect strong count
/// assert_eq!(*guard, 53);
///
/// let arc = Arc::from(&guard);
/// assert_eq!(arc.strong_count(), 2);
///
/// assert_eq!(*arc, *guard);
/// ```
pub struct AtomicArc<T: 'static> {
    /// Either null, or a pointer for which this `AtomicArc` owns exactly one strong reference.
    ptr: AtomicPtr<T>,
    /// Tells the compiler (e.g. for drop checking) that dropping an `AtomicArc<T>` may drop a
    /// `T`, which the raw `AtomicPtr` alone does not express.
    phantom: PhantomData<T>,
}

// Implemented by hand: `#[derive(Default)]` would add an unnecessary `T: Default` bound, even
// though an empty `AtomicArc` never constructs a `T`.
impl<T: 'static> Default for AtomicArc<T> {
    /// Creates an `AtomicArc` holding a null pointer (the same state as `AtomicArc::new(None)`).
    fn default() -> Self {
        Self {
            ptr: AtomicPtr::default(),
            phantom: PhantomData,
        }
    }
}
45
46impl<T: 'static> AtomicArc<T> {
47    /// Similar to [`Arc::new`], but `None` is a valid input, in which case the `AtomicArc` will
48    /// store a null pointer.
49    ///
50    /// To create an `AtomicArc` from an existing `Arc`, use `from`.
51    pub fn new<D: Into<Option<T>>>(data: D) -> Self {
52        let ptr = data.into().map_or(null(), |x| Arc::into_raw(Arc::new(x)));
53        Self {
54            ptr: AtomicPtr::new(ptr.cast_mut()),
55            phantom: PhantomData,
56        }
57    }
58
59    /// If `self` and `current` point to the same object, new’s pointer will be stored into self
60    /// and the result will be an empty `Ok`. Otherwise, a `load` occurs, and an `Err` containing
61    /// a [`Guard`] will be returned.
62    pub fn compare_exchange<N: AsPtr<Target = T> + StrongPtr>(
63        &self,
64        current: *const T,
65        new: Option<&N>,
66    ) -> Result<(), Option<Guard<T>>> {
67        let c = current.cast_mut();
68        let n = new.map_or(null(), N::as_ptr).cast_mut();
69        match self.ptr.compare_exchange(c, n, SeqCst, SeqCst) {
70            Ok(before) => unsafe {
71                Self::after_swap(n, before);
72                Ok(())
73            },
74            Err(actual) => {
75                let mut opt = None;
76                if let Some(ptr) = NonNull::new(actual) {
77                    if let Some(guard) = protect(&self.ptr, ptr) {
78                        opt = Some(Guard { guard })
79                    }
80                }
81                Err(opt)
82            }
83        }
84    }
85
86    /// Loads a [`Guard`], which allows the pointed-to value to be accessed. `None` indicates that
87    /// the inner atomic pointer is null.
88    pub fn load(&self) -> Option<Guard<T>> {
89        let guard = load(&self.ptr)?;
90        Some(Guard { guard })
91    }
92
93    /// Stores `new`'s pointer (or `None`) into `self`.
94    pub fn store<N: AsPtr<Target = T> + StrongPtr>(&self, new: Option<&N>) {
95        // TODO: rework this method to possibly take ownership of new (avoid increment).
96        let n = new.map_or(null(), N::as_ptr);
97        let before = self.ptr.swap(n.cast_mut(), SeqCst);
98        unsafe {
99            Self::after_swap(n, before);
100        }
101    }
102
103    unsafe fn after_swap(new: *const T, before: *const T) {
104        if !ptr::eq(new, before) {
105            if !new.is_null() {
106                Arc::increment_strong_count(new);
107            }
108            if !before.is_null() {
109                drop(Arc::from_raw(before));
110            }
111        }
112    }
113}
114
115impl<T: 'static> Clone for AtomicArc<T> {
116    fn clone(&self) -> Self {
117        let ptr = if let Some(guard) = self.load() {
118            unsafe {
119                Arc::increment_strong_count(guard.as_ptr());
120            }
121            guard.as_ptr().cast_mut()
122        } else {
123            null_mut()
124        };
125        Self {
126            ptr: AtomicPtr::new(ptr),
127            phantom: PhantomData,
128        }
129    }
130}
131
132impl<T: 'static> Drop for AtomicArc<T> {
133    fn drop(&mut self) {
134        if let Some(ptr) = NonNull::new(self.ptr.load(Relaxed)) {
135            unsafe {
136                drop(Arc::from_raw(ptr.as_ptr()));
137            }
138        }
139    }
140}
141
142unsafe impl<T: 'static + Send + Sync> Send for AtomicArc<T> {}
143
144unsafe impl<T: 'static + Send + Sync> Sync for AtomicArc<T> {}
145
/// A [`Weak`] with an atomically updatable pointer.
///
/// See [`AtomicArc`] for usage notes. `AtomicWeak` differs in that it contributes to the weak
/// count instead of the strong count, and in that `load` returns `None` when the pointed-to
/// allocation's strong count is 0. `AtomicWeak::default()` creates an empty (null) `AtomicWeak`
/// and does not require `T: Default`.
///
/// # Examples
/// ```
/// use aarc::{Arc, AtomicWeak, RefCount, Weak};
///
/// let arc = Arc::new(53);
///
/// let atomic = AtomicWeak::from(&arc); // +1 weak count
///
/// let guard = atomic.load().unwrap();
///
/// assert_eq!(*arc, *guard);
/// assert_eq!(arc.weak_count(), 1);
/// ```
pub struct AtomicWeak<T: 'static> {
    /// Either null, or a pointer for which this `AtomicWeak` owns exactly one weak reference.
    ptr: AtomicPtr<T>,
}

// Implemented by hand: `#[derive(Default)]` would add an unnecessary `T: Default` bound, even
// though an empty `AtomicWeak` never constructs a `T`.
impl<T: 'static> Default for AtomicWeak<T> {
    /// Creates an `AtomicWeak` holding a null pointer.
    fn default() -> Self {
        Self {
            ptr: AtomicPtr::default(),
        }
    }
}
168
169impl<T: 'static> AtomicWeak<T> {
170    /// If `self` and `current` point to the same object, new’s pointer will be stored into self
171    /// and the result will be an empty `Ok`. Otherwise, a load will be attempted and a
172    /// [`Guard`] will be returned if possible. See `load`.
173    pub fn compare_exchange<N: AsPtr<Target = T>>(
174        &self,
175        current: *const T,
176        new: Option<&N>,
177    ) -> Result<(), Option<Guard<T>>> {
178        let c = current.cast_mut();
179        let n = new.map_or(null(), N::as_ptr).cast_mut();
180        match self.ptr.compare_exchange(c, n, SeqCst, SeqCst) {
181            Ok(before) => unsafe {
182                Self::after_swap(n, before);
183                Ok(())
184            },
185            Err(actual) => unsafe {
186                let mut opt = None;
187                if let Some(ptr) = NonNull::new(actual) {
188                    if let Some(guard) = protect(&self.ptr, ptr) {
189                        opt = (Arc::strong_count_raw(guard.as_ptr()) > 0).then_some(Guard { guard })
190                    }
191                }
192                Err(opt)
193            },
194        }
195    }
196
197    /// Attempts to load a [`Guard`]. This method differs from the one on `AtomicArc` in that
198    /// `None` may indicate one of two things:
199    /// * The `AtomicWeak` is indeed not pointing to anything (null pointer).
200    /// * The pointer is not null, but the strong count is 0, so a `Guard` cannot be loaded.
201    ///
202    /// There is currently no way for the user to differentiate between the two cases (this may
203    /// change in the future).
204    pub fn load(&self) -> Option<Guard<T>> {
205        let guard = load(&self.ptr)?;
206        unsafe { (Arc::strong_count_raw(guard.as_ptr()) > 0).then_some(Guard { guard }) }
207    }
208
209    /// Stores `new`'s pointer (or `None`) into `self`.
210    pub fn store<N: AsPtr<Target = T>>(&self, new: Option<&N>) {
211        let n = new.map_or(null(), N::as_ptr);
212        let before = self.ptr.swap(n.cast_mut(), SeqCst);
213        unsafe {
214            Self::after_swap(n, before);
215        }
216    }
217
218    unsafe fn after_swap(new: *const T, before: *const T) {
219        if !ptr::eq(new, before) {
220            if !new.is_null() {
221                Weak::increment_weak_count(new);
222            }
223            if !before.is_null() {
224                drop(Weak::from_raw(before));
225            }
226        }
227    }
228}
229
230impl<T: 'static> Clone for AtomicWeak<T> {
231    fn clone(&self) -> Self {
232        let ptr = if let Some(guard) = self.load() {
233            unsafe {
234                Weak::increment_weak_count(guard.as_ptr());
235            }
236            guard.as_ptr().cast_mut()
237        } else {
238            null_mut()
239        };
240        Self {
241            ptr: AtomicPtr::new(ptr),
242        }
243    }
244}
245
246impl<T: 'static> Drop for AtomicWeak<T> {
247    fn drop(&mut self) {
248        if let Some(ptr) = NonNull::new(self.ptr.load(Relaxed)) {
249            unsafe {
250                drop(Weak::from_raw(ptr.as_ptr()));
251            }
252        }
253    }
254}
255
256impl<T: 'static, P: AsPtr<Target = T> + StrongPtr> From<&P> for AtomicArc<T> {
257    fn from(value: &P) -> Self {
258        unsafe {
259            let ptr = P::as_ptr(value);
260            Arc::increment_strong_count(ptr);
261            Self {
262                ptr: AtomicPtr::new(ptr.cast_mut()),
263                phantom: PhantomData,
264            }
265        }
266    }
267}
268
269impl<T: 'static, P: AsPtr<Target = T>> From<&P> for AtomicWeak<T> {
270    fn from(value: &P) -> Self {
271        unsafe {
272            let ptr = P::as_ptr(value);
273            Weak::increment_weak_count(ptr);
274            Self {
275                ptr: AtomicPtr::new(ptr.cast_mut()),
276            }
277        }
278    }
279}
280
281unsafe impl<T: 'static + Send + Sync> Send for AtomicWeak<T> {}
282
283unsafe impl<T: 'static + Send + Sync> Sync for AtomicWeak<T> {}