atomic_ext/
lib.rs

use std::{
    alloc::Layout,
    cmp::max,
    marker::PhantomData,
    ptr,
    ptr::NonNull,
    sync::{
        atomic::{AtomicU64, AtomicUsize, Ordering},
        Arc,
    },
};

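/// A lock-free pointer to an [`Arc`] based on split reference counting.
///
/// The 64-bit state packs the pointer address into the upper 48 bits and an
/// external reference count into the lower 16 bits. `new_state` reserves a
/// batch of references from the `Arc` up front, so a load can claim one of
/// them with a single `fetch_add` on the state.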
struct AtomicPtr<T> {
    state: AtomicU64,
    phantom: PhantomData<Arc<T>>,
}

impl<T> AtomicPtr<T> {
    /// Creates an [`AtomicPtr`].
    fn new(value: *const T) -> Self {
        let state = new_state(value);
        Self {
            state: AtomicU64::new(state),
            phantom: PhantomData,
        }
    }

    /// Loads an [`Arc`] pointer.
    fn load(&self, order: Ordering) -> *const T {
        // Fast path: a single `fetch_add` claims one of the reserved
        // references and bumps the external count.
        let state = self.state.fetch_add(1, order);
        let (addr, count) = unpack_state(state);
        if count >= RESERVED_COUNT {
            panic!("external reference count overflow");
        }
        if count >= RESERVED_COUNT / 2 {
            // Half of the reserve has been handed out; push the external
            // count back to the `Arc` to replenish the reserve.
            self.push_count(addr);
        }
        addr as _
    }

    /// Stores an [`Arc`] pointer and returns the previous pointer.
    fn swap(&self, value: *const T, order: Ordering) -> *const T {
        let state = self.state.swap(new_state(value), order);
        let (addr, count) = unpack_state(state);
        unsafe {
            // Release the unclaimed part of the old reserve; the one
            // remaining reference is transferred to the caller.
            decrease_count::<T>(addr, RESERVED_COUNT - count);
            addr as _
        }
    }

    /// Stores an [`Arc`] pointer if the current value is the same as `current`.
    fn compare_exchange(
        &self,
        current: *const T,
        new: *const T,
        success: Ordering,
        failure: Ordering,
    ) -> Result<*const T, *const T> {
        let new_state = pack_state(new.addr());
        let mut state = self.state.load(failure);
        loop {
            let (addr, count) = unpack_state(state);
            if addr != current.addr() {
                unsafe {
                    // The caller receives the observed pointer, so it needs
                    // one extra reference.
                    increase_count::<T>(addr, 1);
                }
                return Err(addr as _);
            }
            match self
                .state
                .compare_exchange_weak(state, new_state, success, failure)
            {
                Ok(_) => {
                    unsafe {
                        // Release the old pointer's unclaimed reserve (minus
                        // the reference returned to the caller) and reserve a
                        // fresh batch for the new pointer.
                        decrease_count::<T>(addr, RESERVED_COUNT - count);
                        increase_count::<T>(new.addr(), RESERVED_COUNT + 1);
                    }
                    return Ok(addr as _);
                }
                Err(now_state) => state = now_state,
            }
        }
    }

    /// Pushes the external reference count back to the original [`Arc`].
    fn push_count(&self, expect_addr: usize) {
        let mut current = self.state.load(Ordering::Acquire);
        let desired = pack_state(expect_addr);
        loop {
            let (addr, count) = unpack_state(current);
            if addr != expect_addr || count < RESERVED_COUNT / 2 {
                // Someone else has changed the address or the reference count.
                break;
            }
            match self.state.compare_exchange_weak(
                current,
                desired,
                Ordering::Release,
                Ordering::Relaxed,
            ) {
                Ok(_) => {
                    // The external count was reset to zero, so hand the
                    // claimed references back to the `Arc` and stop retrying.
                    unsafe {
                        increase_count::<T>(addr, count);
                    }
                    break;
                }
                Err(actual) => current = actual,
            }
        }
    }
}

impl<T> Drop for AtomicPtr<T> {
    fn drop(&mut self) {
        let state = self.state.load(Ordering::Acquire);
        let (addr, count) = unpack_state(state);
        unsafe {
            // Return the unclaimed reserve, then drop the reference this
            // pointer owns so the value is destroyed if it was the last one.
            decrease_count::<T>(addr, RESERVED_COUNT - count);
            if addr != 0 {
                drop(Arc::from_raw(addr as *const T));
            }
        }
    }
}

impl<T> Default for AtomicPtr<T> {
    fn default() -> Self {
        Self {
            state: AtomicU64::new(0),
            phantom: PhantomData,
        }
    }
}

/// An atomic pointer for [`Arc`].
///
/// **Note**: The implementation manipulates the internal reference count of
/// the original [`Arc`] for optimization. This means that the result of
/// [`Arc::strong_count`] is inflated until the [`Arc`] leaves this pointer's
/// control (for example, via [`AtomicArc::swap`]). Users who depend on the
/// exact value of [`Arc::strong_count`] should be careful.
///
/// # Limitations
///
/// The implementation borrows some bits from the `Arc` pointer as an external
/// reference count (a technique called "split reference counting"). It panics
/// in the extreme scenario where the external reference count exceeds a
/// threshold (2^15). This is almost impossible in practice: it would require
/// more than 2^15 threads loading the same pointer at the same time.
///
/// # Examples
///
/// ```
/// use std::{
///     sync::{atomic::Ordering, Arc},
///     thread,
/// };
///
/// use atomic_ext::AtomicArc;
///
/// let a = Arc::new(1);
/// let b = Arc::new(2);
/// let x = Arc::new(AtomicArc::new(a));
/// {
///     let x = x.clone();
///     thread::spawn(move || {
///         x.swap(b, Ordering::AcqRel) // Returns `a`
///     });
/// }
/// {
///     let x = x.clone();
///     thread::spawn(move || {
///         x.load(Ordering::Acquire) // Returns either `a` or `b`
///     });
/// }
/// ```
pub struct AtomicArc<T>(AtomicPtr<T>);

impl<T> AtomicArc<T> {
    /// Creates an [`AtomicArc`] with the value.
    pub fn new(value: Arc<T>) -> Self {
        Self(AtomicPtr::new(Arc::into_raw(value)))
    }

    /// Loads an [`Arc`] from the pointer.
    ///
    /// The fast path uses just one atomic operation to load the [`Arc`] and
    /// increase its reference count.
    pub fn load(&self, order: Ordering) -> Arc<T> {
        let ptr = self.0.load(order);
        unsafe { Arc::from_raw(ptr) }
    }

    /// Stores an [`Arc`] into the pointer, returning the previous value.
    pub fn swap(&self, value: Arc<T>, order: Ordering) -> Arc<T> {
        let new = Arc::into_raw(value);
        let current = self.0.swap(new, order);
        unsafe { Arc::from_raw(current) }
    }

    /// Stores an [`Arc`] into the pointer if the current value is the same as
    /// `current`.
    pub fn compare_exchange(
        &self,
        current: &Arc<T>,
        new: &Arc<T>,
        success: Ordering,
        failure: Ordering,
    ) -> Result<Arc<T>, Arc<T>> {
        let new = Arc::as_ptr(new);
        let current = Arc::as_ptr(current);
        self.0
            .compare_exchange(current, new, success, failure)
            .map(|ptr| unsafe { Arc::from_raw(ptr) })
            .map_err(|ptr| unsafe { Arc::from_raw(ptr) })
    }
}

/// An atomic pointer for [`Option<Arc>`].
///
/// This is similar to [`AtomicArc`], but allows null values.
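///
/// # Examples
///
/// A minimal sketch of the null handling, using only the methods defined
/// below:
///
/// ```
/// use std::sync::{atomic::Ordering, Arc};
///
/// use atomic_ext::AtomicOptionArc;
///
/// let x = AtomicOptionArc::new(Arc::new(1));
/// let a = x.swap(None, Ordering::AcqRel); // Returns the original `Arc`
/// assert_eq!(x.load(Ordering::Acquire), None);
/// x.swap(a, Ordering::AcqRel); // Returns `None`
/// ```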
pub struct AtomicOptionArc<T>(AtomicPtr<T>);

impl<T> AtomicOptionArc<T> {
    /// Creates an [`AtomicOptionArc`] with the value.
    pub fn new(value: Arc<T>) -> Self {
        Self(AtomicPtr::new(Arc::into_raw(value)))
    }

    /// Loads an [`Arc`] from the pointer.
    ///
    /// Returns [`None`] if the pointer is null.
    pub fn load(&self, order: Ordering) -> Option<Arc<T>> {
        let ptr = self.0.load(order);
        unsafe { Self::from_ptr(ptr) }
    }

    /// Stores an [`Arc`] into the pointer, returning the previous value.
    pub fn swap(&self, value: Option<Arc<T>>, order: Ordering) -> Option<Arc<T>> {
        let new = Self::into_ptr(value);
        let current = self.0.swap(new, order);
        unsafe { Self::from_ptr(current) }
    }

    /// Stores an [`Arc`] into the pointer if the current value is the same as
    /// `current`.
    pub fn compare_exchange(
        &self,
        current: Option<&Arc<T>>,
        new: Option<&Arc<T>>,
        success: Ordering,
        failure: Ordering,
    ) -> Result<Option<Arc<T>>, Option<Arc<T>>> {
        let new = new.map(Arc::as_ptr).unwrap_or(ptr::null());
        let current = current.map(Arc::as_ptr).unwrap_or(ptr::null());
        self.0
            .compare_exchange(current, new, success, failure)
            .map(|ptr| unsafe { Self::from_ptr(ptr) })
            .map_err(|ptr| unsafe { Self::from_ptr(ptr) })
    }

    fn into_ptr(value: Option<Arc<T>>) -> *const T {
        value.map(Arc::into_raw).unwrap_or(ptr::null())
    }

    unsafe fn from_ptr(ptr: *const T) -> Option<Arc<T>> {
        if ptr.is_null() {
            None
        } else {
            Some(unsafe { Arc::from_raw(ptr) })
        }
    }
}

impl<T> Default for AtomicOptionArc<T> {
    fn default() -> Self {
        Self(AtomicPtr::new(ptr::null()))
    }
}

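/// The number of references reserved from the [`Arc`] when a pointer is
/// installed. Loads hand these out one at a time; once half of the reserve
/// (2^14) is consumed, the external count is pushed back to replenish it.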
const RESERVED_COUNT: usize = 0x8000;

fn new_state<T>(ptr: *const T) -> u64 {
    let addr = ptr.addr();
    unsafe {
        // Reserve a batch of references so that subsequent loads only need
        // to touch the packed state.
        increase_count::<T>(addr, RESERVED_COUNT);
        pack_state(addr)
    }
}

fn pack_state(addr: usize) -> u64 {
    let addr = addr as u64;
    // The address must fit in 48 bits (true on mainstream 64-bit targets);
    // the low 16 bits hold the external reference count.
    assert_eq!(addr >> 48, 0);
    addr << 16
}

fn unpack_state(state: u64) -> (usize, usize) {
    ((state >> 16) as usize, (state & 0xFFFF) as usize)
}

/// Mirrors the layout of [`std::sync::Arc`]'s heap allocation so that we can
/// manipulate the internal reference count.
#[repr(C)]
struct ArcInner {
    count: AtomicUsize,
    weak_count: AtomicUsize,
}

unsafe fn inner_ptr<T>(addr: usize) -> NonNull<ArcInner> {
    // `Arc` stores the counters in front of the value, so the header lives
    // `offset` bytes before the value: the header size rounded up to the
    // value's alignment, which `max` computes here because the header size
    // is a power of two.
    let align = align_of::<T>();
    let layout = Layout::new::<ArcInner>();
    let offset = max(layout.size(), align);
    NonNull::new_unchecked((addr - offset) as _)
}

unsafe fn increase_count<T>(addr: usize, count: usize) {
    // A zero address represents the null pointer of `AtomicOptionArc`.
    if addr != 0 {
        let ptr = inner_ptr::<T>(addr);
        ptr.as_ref().count.fetch_add(count, Ordering::Release);
    }
}

unsafe fn decrease_count<T>(addr: usize, count: usize) {
    if addr != 0 {
        let ptr = inner_ptr::<T>(addr);
        ptr.as_ref().count.fetch_sub(count, Ordering::Release);
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_arc() {
        let a = Arc::new(1);
        let b = Arc::new(2);
        let x = AtomicArc::new(a.clone());
        {
            let c = x.load(Ordering::Acquire);
            assert_eq!(c, a);
            assert_eq!(Arc::strong_count(&a), RESERVED_COUNT + 2);
        }
        {
            let c = x.swap(b.clone(), Ordering::AcqRel);
            assert_eq!(c, a);
            assert_eq!(Arc::strong_count(&a), 2);
            assert_eq!(Arc::strong_count(&b), RESERVED_COUNT + 2);
            let c = x.load(Ordering::Acquire);
            assert_eq!(c, b);
            assert_eq!(Arc::strong_count(&b), RESERVED_COUNT + 2);
        }
        {
            let c = x
                .compare_exchange(&b, &a, Ordering::AcqRel, Ordering::Acquire)
                .unwrap();
            assert_eq!(c, b);
            assert_eq!(Arc::strong_count(&b), 2);
            assert_eq!(Arc::strong_count(&a), RESERVED_COUNT + 2);
            let c = x
                .compare_exchange(&b, &a, Ordering::AcqRel, Ordering::Acquire)
                .unwrap_err();
            assert_eq!(c, a);
            assert_eq!(Arc::strong_count(&a), RESERVED_COUNT + 3);
        }
        drop(x);
        assert_eq!(Arc::strong_count(&a), 1);
        assert_eq!(Arc::strong_count(&b), 1);
    }

    #[test]
    fn test_option_arc() {
        let a = Arc::new(1);
        let b = Arc::new(2);
        let x = AtomicOptionArc::new(a.clone());
        {
            let c = x.load(Ordering::Acquire);
            assert_eq!(c, Some(a.clone()));
        }
        {
            let c = x.swap(Some(b.clone()), Ordering::AcqRel);
            assert_eq!(c, Some(a.clone()));
            let c = x.load(Ordering::Acquire);
            assert_eq!(c, Some(b.clone()));
        }
        {
            let c = x
                .compare_exchange(Some(&b), None, Ordering::AcqRel, Ordering::Relaxed)
                .unwrap();
            assert_eq!(c, Some(b.clone()));
            let c = x
                .compare_exchange(Some(&b), None, Ordering::AcqRel, Ordering::Relaxed)
                .unwrap_err();
            assert_eq!(c, None);
        }
        assert_eq!(x.load(Ordering::Acquire), None);
        assert_eq!(Arc::strong_count(&a), 1);
        assert_eq!(Arc::strong_count(&b), 1);
    }
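
    // An extra check (not part of the original suite): a default
    // `AtomicOptionArc` starts out null and can be written afterwards.
    #[test]
    fn test_default_option_arc() {
        let x = AtomicOptionArc::<i32>::default();
        assert_eq!(x.load(Ordering::Acquire), None);
        let a = Arc::new(1);
        assert_eq!(x.swap(Some(a.clone()), Ordering::AcqRel), None);
        assert_eq!(x.load(Ordering::Acquire), Some(a.clone()));
        drop(x);
        assert_eq!(Arc::strong_count(&a), 1);
    }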

    #[test]
    fn test_push_count() {
        let x = AtomicArc::new(Arc::new(1));
        let mut v = Vec::new();
        for _ in 0..(RESERVED_COUNT / 2) {
            let a = x.load(Ordering::Relaxed);
            assert_eq!(Arc::strong_count(&a), RESERVED_COUNT + 1);
            v.push(a);
        }
        // This load will push the external count back to the `Arc`.
        let a = x.load(Ordering::Relaxed);
        assert_eq!(Arc::strong_count(&a), RESERVED_COUNT + v.len() + 2);
        let b = x.swap(Arc::new(2), Ordering::Relaxed);
        assert_eq!(Arc::strong_count(&b), v.len() + 2);
    }
}