aarc/
atomics.rs

1use std::marker::PhantomData;
2use std::ptr::{eq, null, null_mut, NonNull};
3use std::sync::atomic::AtomicPtr;
4use std::sync::atomic::Ordering::SeqCst;
5
6use crate::smart_ptrs::{find_inner_ptr, ArcInner, Guard, CTX};
7use crate::Arc;
8
9/// An [`Arc`] with an atomically updatable pointer.
10///
11/// Usage notes:
12/// * An `AtomicArc` can intrinsically store `None` (a hypothetical `Option<AtomicArc<T>>` would
13///   no longer be atomic).
14/// * An `AtomicArc` contributes to the strong count of the pointed-to allocation, if any. However,
15///   it does not implement `Deref`, so methods like `load` must be used to obtain a [`Guard`]
16///   through which the data can be accessed.
17/// * `T` must be `Sized`. This may be relaxed in the future.
18/// * When an `AtomicArc` is updated or dropped, the strong count of the previously pointed-to
19///   object may not be immediately decremented. Thus:
20///     * `T` must be `'static` to support delayed deallocations.
21///     * The value returned by `ref_count` may be an overestimate.
22///
23/// # Examples
24/// ```
25/// use aarc::{Arc, AtomicArc, Guard};
26///
27/// // ref count: 1
28/// let x = Arc::new(53);
29/// assert_eq!(Arc::ref_count(&x), 1);
30///
31/// // ref count: 2
32/// let atomic = AtomicArc::new(0);
33/// atomic.store(Some(&x));
34/// assert_eq!(Arc::ref_count(&x), 2);
35///
36/// // guard doesn't affect the ref count
37/// let guard = atomic.load().unwrap();
38/// assert_eq!(Arc::ref_count(&x), 2);
39///
40/// // both the `Arc` and the `Guard` point to the same block
41/// assert_eq!(*guard, 53);
42/// assert_eq!(*guard, *x);
43/// ```
44#[derive(Default)]
45pub struct AtomicArc<T: 'static> {
46    ptr: AtomicPtr<ArcInner<T>>,
47    phantom: PhantomData<ArcInner<T>>,
48}
49
50impl<T: 'static> AtomicArc<T> {
51    /// Similar to [`Arc::new`], but `None` is a valid input, in which case the `AtomicArc` will
52    /// store a null pointer.
53    ///
54    /// To create an `AtomicArc` from an existing `Arc`, use `from`.
55    pub fn new<D: Into<Option<T>>>(data: D) -> Self {
56        let ptr = data.into().map_or(null_mut(), ArcInner::new);
57        Self {
58            ptr: AtomicPtr::new(ptr),
59            phantom: PhantomData,
60        }
61    }
62
63    /// Loads a [`Guard`], which allows the pointed-to value to be accessed. `None` indicates that
64    /// the inner atomic pointer is null.
65    pub fn load(&self) -> Option<Guard<'static, T>> {
66        let guard = CTX.with_borrow(|ctx| ctx.load(&self.ptr, 1))?;
67        Some(Guard { guard })
68    }
69
70    /// Stores `new`'s pointer (or `None`) into `self` and returns the previously-stored `Arc`.
71    pub fn swap<N: Into<NonNull<T>>>(&self, new: Option<N>) -> Option<Arc<T>> {
72        unsafe {
73            let n = new.map_or(null_mut(), |n| find_inner_ptr(n.into().as_ptr()).cast_mut());
74            if !n.is_null() {
75                ArcInner::increment(n);
76            }
77            let before = NonNull::new(self.ptr.swap(n, SeqCst))?;
78            Some(Arc {
79                ptr: before,
80                phantom: PhantomData,
81            })
82        }
83    }
84
85    /// Stores `new`'s pointer (or `None`) into `self`. Equivalent to `swap`, but discards the result.
86    pub fn store<N: Into<NonNull<T>>>(&self, new: Option<N>) {
87        _ = self.swap(new)
88    }
89}
90
91/// A trait for implementations of `compare_exchange` on `AtomicArc`.
92///
93/// If `self` and `current` point to the same object, new’s pointer will be stored into self
94/// and the result will be an empty `Ok`. Otherwise, a `load` occurs, and an `Err` containing
95/// a [`Guard`] will be returned.
96pub trait CompareExchange<T, N> {
97    fn compare_exchange<C: Into<NonNull<T>>>(
98        &self,
99        current: Option<C>,
100        new: Option<N>,
101    ) -> Result<(), Option<Guard<'static, T>>>;
102}
103
104impl<T: 'static> CompareExchange<T, &Guard<'static, T>> for AtomicArc<T> {
105    fn compare_exchange<C: Into<NonNull<T>>>(
106        &self,
107        current: Option<C>,
108        new: Option<&Guard<'static, T>>,
109    ) -> Result<(), Option<Guard<'static, T>>> {
110        unsafe {
111            let c = current.map_or(null_mut(), |c| find_inner_ptr(c.into().as_ptr()).cast_mut());
112            let n = new.map_or(null(), Guard::inner_ptr).cast_mut();
113            match self.ptr.compare_exchange(c, n, SeqCst, SeqCst) {
114                Ok(before) => {
115                    if !eq(before, n) {
116                        if !n.is_null() {
117                            ArcInner::increment(n);
118                        }
119                        if !before.is_null() {
120                            ArcInner::delayed_decrement(before);
121                        }
122                    }
123                    Ok(())
124                }
125                Err(actual) => {
126                    if let Some(ptr) = NonNull::new(actual) {
127                        let mut opt = None;
128                        let loaded = CTX.with_borrow(|ctx| ctx.protect(&self.ptr, ptr, 1));
129                        if let Some(guard) = loaded {
130                            opt = Some(Guard { guard })
131                        }
132                        Err(opt)
133                    } else {
134                        Err(None)
135                    }
136                }
137            }
138        }
139    }
140}
141
142impl<T: 'static> CompareExchange<T, &Arc<T>> for AtomicArc<T> {
143    fn compare_exchange<C: Into<NonNull<T>>>(
144        &self,
145        current: Option<C>,
146        new: Option<&Arc<T>>,
147    ) -> Result<(), Option<Guard<'static, T>>> {
148        let g = new.map(Guard::from);
149        CompareExchange::compare_exchange(self, current, g.as_ref())
150    }
151}
152
153impl<T: 'static> Clone for AtomicArc<T> {
154    fn clone(&self) -> Self {
155        let ptr = if let Some(guard) = self.load() {
156            unsafe {
157                let ptr = guard.guard.as_ptr();
158                _ = (*ptr).ref_count.fetch_add(1, SeqCst);
159                ptr
160            }
161        } else {
162            null_mut()
163        };
164        Self {
165            ptr: AtomicPtr::new(ptr.cast_mut()),
166            phantom: PhantomData,
167        }
168    }
169}
170
171impl<T: 'static> Drop for AtomicArc<T> {
172    fn drop(&mut self) {
173        if let Some(ptr) = NonNull::new(self.ptr.load(SeqCst)) {
174            unsafe {
175                ArcInner::delayed_decrement(ptr.as_ptr());
176            }
177        }
178    }
179}
180
181unsafe impl<T: 'static + Send + Sync> Send for AtomicArc<T> {}
182
183unsafe impl<T: 'static + Send + Sync> Sync for AtomicArc<T> {}
184
185impl<T: 'static, P: Into<NonNull<T>>> From<P> for AtomicArc<T> {
186    fn from(value: P) -> Self {
187        unsafe {
188            let inner_ptr = find_inner_ptr(value.into().as_ptr());
189            _ = (*inner_ptr).ref_count.fetch_add(1, SeqCst);
190            Self {
191                ptr: AtomicPtr::new(inner_ptr.cast_mut()),
192                phantom: PhantomData,
193            }
194        }
195    }
196}
197
#[cfg(test)]
mod tests {
    use crate::{Arc, AtomicArc, CompareExchange};

    #[test]
    fn test_new_with_value() {
        let atomic = AtomicArc::new(42);
        let guard = atomic.load().unwrap();
        assert_eq!(*guard, 42);
    }

    #[test]
    fn test_new_with_none() {
        let atomic: AtomicArc<i32> = AtomicArc::new(None);
        assert!(atomic.load().is_none());
    }

    #[test]
    fn test_swap() {
        let atomic = AtomicArc::new(10);
        let arc = Arc::new(20);

        let old = atomic.swap(Some(&arc));
        assert!(old.is_some());
        assert_eq!(*old.unwrap(), 10);

        let guard = atomic.load().unwrap();
        assert_eq!(*guard, 20);
    }

    #[test]
    fn test_swap_none() {
        let atomic = AtomicArc::new(10);
        let old = atomic.swap::<&Arc<i32>>(None);

        assert!(old.is_some());
        assert_eq!(*old.unwrap(), 10);
        assert!(atomic.load().is_none());
    }

    #[test]
    fn test_clone() {
        let atomic = AtomicArc::new(42);
        let cloned = atomic.clone();

        let guard1 = atomic.load().unwrap();
        let guard2 = cloned.load().unwrap();

        assert_eq!(*guard1, 42);
        assert_eq!(*guard2, 42);
    }

    #[test]
    fn test_clone_none() {
        let atomic: AtomicArc<i32> = AtomicArc::new(None);
        let cloned = atomic.clone();

        assert!(atomic.load().is_none());
        assert!(cloned.load().is_none());
    }

    #[test]
    fn test_compare_exchange_success_with_arc() {
        let arc1 = Arc::new(10);
        let arc2 = Arc::new(20);
        let atomic = AtomicArc::new(10);
        atomic.store(Some(&arc1));

        let result = atomic.compare_exchange(Some(&arc1), Some(&arc2));
        assert!(result.is_ok());

        let guard = atomic.load().unwrap();
        assert_eq!(*guard, 20);
    }

    #[test]
    fn test_compare_exchange_failure_with_arc() {
        let arc1 = Arc::new(10);
        let arc2 = Arc::new(20);
        let arc3 = Arc::new(30);
        // Initial value differs from `arc1` so the final check proves `arc1` is still stored.
        let atomic = AtomicArc::new(0);
        atomic.store(Some(&arc1));

        // `arc2` is not the current value: the exchange fails and reports the actual value.
        match atomic.compare_exchange(Some(&arc2), Some(&arc3)) {
            Err(Some(actual)) => assert_eq!(*actual, 10),
            Err(None) => panic!("expected a guard to the current value"),
            Ok(()) => panic!("compare_exchange should have failed"),
        }

        // Value should remain unchanged
        let guard = atomic.load().unwrap();
        assert_eq!(*guard, 10);
    }

    #[test]
    fn test_compare_exchange_from_none() {
        let arc = Arc::new(7);
        let atomic: AtomicArc<i32> = AtomicArc::new(None);

        let result = atomic.compare_exchange(None::<&Arc<i32>>, Some(&arc));
        assert!(result.is_ok());
        assert_eq!(*atomic.load().unwrap(), 7);
        assert_eq!(Arc::ref_count(&arc), 2);
    }

    #[test]
    fn test_compare_exchange_same_pointer_keeps_ref_count() {
        let arc = Arc::new(10);
        let atomic: AtomicArc<i32> = AtomicArc::new(None);
        atomic.store(Some(&arc));
        assert_eq!(Arc::ref_count(&arc), 2);

        // Exchanging a value for itself succeeds without taking an extra reference.
        assert!(atomic.compare_exchange(Some(&arc), Some(&arc)).is_ok());
        assert_eq!(Arc::ref_count(&arc), 2);
    }

    #[test]
    fn test_compare_exchange_with_guard() {
        let arc1 = Arc::new(10);
        let arc2 = Arc::new(20);
        let atomic = AtomicArc::new(10);
        atomic.store(Some(&arc1));

        let guard = atomic.load().unwrap();
        let result = atomic.compare_exchange(Some(&guard), Some(&arc2));
        assert!(result.is_ok());

        let new_guard = atomic.load().unwrap();
        assert_eq!(*new_guard, 20);
    }

    #[test]
    fn test_from_arc() {
        let arc = Arc::new(42);
        let atomic: AtomicArc<i32> = AtomicArc::from(&arc);
        // `from` takes its own strong reference to the shared allocation.
        assert_eq!(Arc::ref_count(&arc), 2);

        let guard = atomic.load().unwrap();
        assert_eq!(*guard, 42);
        assert_eq!(*arc, 42);
    }

    #[test]
    fn test_store_arc() {
        let arc = Arc::new(42);
        let atomic = AtomicArc::new(0);
        atomic.store(Some(&arc));

        let guard = atomic.load().unwrap();
        assert_eq!(*guard, 42);
        assert_eq!(*arc, 42);
    }
}