atomic_ext/lib.rs

use std::{
    alloc::Layout,
    cmp::max,
    marker::PhantomData,
    mem::align_of,
    ptr::NonNull,
    sync::{
        atomic::{AtomicU64, AtomicUsize, Ordering},
        Arc,
    },
};

/// A lightweight atomic pointer to an [`Arc`].
///
/// **Note**: The implementation manipulates the internal reference count of
/// the original [`Arc`] as an optimization. This means that the result of
/// [`Arc::strong_count`] is inflated until the [`Arc`] is released from this
/// pointer's control (with [`AtomicArc::swap`]). Users who depend on the
/// exact value of [`Arc::strong_count`] should be careful.
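///
/// For example, the inflated amount is the reserved external count (`0x8000`,
/// see the Limitations section below):
///
/// ```
/// use std::sync::{atomic::Ordering, Arc};
///
/// use atomic_ext::AtomicArc;
///
/// let a = Arc::new(1);
/// let x = AtomicArc::new(a.clone());
/// // Two real references (`a` and `x`) plus the reserved count.
/// assert_eq!(Arc::strong_count(&a), 0x8000 + 2);
/// // Swapping the value out releases control and returns the reserve.
/// drop(x.swap(None, Ordering::AcqRel));
/// assert_eq!(Arc::strong_count(&a), 1);
/// ```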
///
/// # Limitations
///
/// The implementation packs the `Arc` pointer together with a 16-bit external
/// reference count into a single 64-bit state (a technique called "split
/// reference counting"). It will panic in the extreme scenario where the
/// external reference count is increased past a threshold (2^15) at the same
/// time. This is almost impossible unless more than 2^15 threads load the
/// same pointer at the same time.
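///
/// The packed state looks like this (see `pack_state` and `unpack_state`
/// below); the high 16 bits of the address must be zero, which holds on
/// current 64-bit platforms with 48-bit virtual addresses:
///
/// ```text
///  63                           16 15               0
/// +--------------------------------+-----------------+
/// |       `Arc` pointer bits       |  external count |
/// +--------------------------------+-----------------+
/// ```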
///
/// # Examples
///
/// ```
/// use std::{
///     sync::{atomic::Ordering, Arc},
///     thread,
/// };
///
/// use atomic_ext::AtomicArc;
///
/// let a = Arc::new(1);
/// let b = Arc::new(2);
/// let x = Arc::new(AtomicArc::new(a));
/// {
///     let x = x.clone();
///     thread::spawn(move || {
///         x.swap(Some(b), Ordering::AcqRel) // Returns `Some(a)`
///     });
/// }
/// {
///     let x = x.clone();
///     thread::spawn(move || {
///         x.load(Ordering::Acquire) // Returns either `Some(a)` or `Some(b)`
///     });
/// }
/// ```
pub struct AtomicArc<T> {
    state: AtomicU64,
    phantom: PhantomData<*mut Arc<T>>,
}

impl<T> AtomicArc<T> {
    /// Constructs a new [`AtomicArc`].
    pub fn new(value: Arc<T>) -> Self {
        let state = new_state(value);
        Self {
            state: AtomicU64::new(state),
            phantom: PhantomData,
        }
    }

    /// Loads an [`Arc`] from the pointer.
    ///
    /// The fast path uses just one atomic operation to load the [`Arc`] and
    /// increase its reference count.
    ///
    /// Returns [`None`] if the pointer is null.
    pub fn load(&self, order: Ordering) -> Option<Arc<T>> {
        // Take one reference out of the reserve by bumping the external
        // count in the low 16 bits of the state.
        let state = self.state.fetch_add(1, order);
        let (addr, count) = unpack_state(state);
        if addr == 0 {
            // Undo the speculative increment; see `pop_count`.
            self.pop_count(state + 1);
            return None;
        }
        if count >= RESERVED_COUNT {
            panic!("external reference count overflow");
        }
        if count >= RESERVED_COUNT / 2 {
            // Half of the reserve has been handed out; replenish it before
            // the count can overflow.
            self.push_count(addr);
        }
        Some(unsafe { Arc::from_raw(addr as _) })
    }

    /// Stores an [`Arc`] into the pointer, returning the previous value.
    pub fn swap(&self, value: Option<Arc<T>>, order: Ordering) -> Option<Arc<T>> {
        let state = self.state.swap(value.map(new_state).unwrap_or(0), order);
        let (addr, count) = unpack_state(state);
        if addr == 0 {
            return None;
        }
        unsafe {
            // Return the unused part of the reserve to the `Arc` before
            // handing it back to the caller.
            decrease_count::<T>(addr, RESERVED_COUNT - count);
            Some(Arc::from_raw(addr as _))
        }
    }

    /// Pushes the external reference count back to the original [`Arc`].
    fn push_count(&self, expect_addr: usize) {
        let mut current = self.state.load(Ordering::Acquire);
        let desired = pack_state(expect_addr);
        loop {
            let (addr, count) = unpack_state(current);
            if addr != expect_addr || count < RESERVED_COUNT / 2 {
                // Someone else has changed the address or the reference count.
                break;
            }
            match self.state.compare_exchange_weak(
                current,
                desired,
                Ordering::AcqRel,
                Ordering::Relaxed,
            ) {
                Ok(_) => {
                    // The external count was reset to zero, so move the
                    // handed-out references into the `Arc`'s internal count.
                    unsafe { increase_count::<T>(addr, count) };
                    break;
                }
                Err(actual) => current = actual,
            }
        }
    }

    /// Rolls back a speculative increment made by [`AtomicArc::load`] on a
    /// null pointer, so that repeated loads of a null pointer cannot
    /// overflow the count into the address bits.
    fn pop_count(&self, mut current: u64) {
        loop {
            let (addr, count) = unpack_state(current);
            if addr != 0 || count == 0 {
                // A new pointer was installed; the residue was discarded
                // together with the old state.
                return;
            }
            match self.state.compare_exchange_weak(
                current,
                current - 1,
                Ordering::Relaxed,
                Ordering::Relaxed,
            ) {
                Ok(_) => return,
                Err(actual) => current = actual,
            }
        }
    }
}

impl<T> Drop for AtomicArc<T> {
    fn drop(&mut self) {
        // Release the controlled `Arc` (if any) along with its reserve.
        self.swap(None, Ordering::AcqRel);
    }
}

impl<T> Default for AtomicArc<T> {
    fn default() -> Self {
        Self {
            state: AtomicU64::new(0),
            phantom: PhantomData,
        }
    }
}

// An `AtomicArc<T>` hands out `Arc<T>`s across threads, so it is only safe
// to share or send when `Arc<T>` itself is (i.e., `T: Send + Sync`).
unsafe impl<T: Send + Sync> Sync for AtomicArc<T> {}
unsafe impl<T: Send + Sync> Send for AtomicArc<T> {}

/// The number of references reserved from the `Arc` for external hand-out.
const RESERVED_COUNT: usize = 0x8000;

fn new_state<T>(value: Arc<T>) -> u64 {
    let addr = Arc::into_raw(value) as usize;
    unsafe {
        // Reserve a batch of references up front so that `load` only needs
        // a single `fetch_add`.
        increase_count::<T>(addr, RESERVED_COUNT);
        pack_state(addr)
    }
}

fn pack_state(addr: usize) -> u64 {
    let addr = addr as u64;
    // This requires pointers to fit in 48 bits, which holds on current
    // mainstream 64-bit platforms.
    assert_eq!(addr >> 48, 0);
    addr << 16
}

fn unpack_state(state: u64) -> (usize, usize) {
    ((state >> 16) as usize, (state & 0xFFFF) as usize)
}

/// Mirrors the layout of the reference counts in [`std::sync::Arc`]'s heap
/// allocation (`ArcInner`) so that we can manipulate the internal strong
/// count directly.
#[repr(C)]
struct ArcInner {
    count: AtomicUsize,
    weak_count: AtomicUsize,
}

unsafe fn inner_ptr<T>(addr: usize) -> NonNull<ArcInner> {
    // `addr` points at the data `T`, which lives after the two counters,
    // padded up to `T`'s alignment.
    let align = align_of::<T>();
    let layout = Layout::new::<ArcInner>();
    let offset = max(layout.size(), align);
    NonNull::new_unchecked((addr - offset) as _)
}

unsafe fn increase_count<T>(addr: usize, count: usize) {
    let ptr = inner_ptr::<T>(addr);
    ptr.as_ref().count.fetch_add(count, Ordering::Release);
}

unsafe fn decrease_count<T>(addr: usize, count: usize) {
    let ptr = inner_ptr::<T>(addr);
    ptr.as_ref().count.fetch_sub(count, Ordering::Release);
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn simple() {
        let a = Arc::new(1);
        let b = Arc::new(2);
        let x = AtomicArc::new(a.clone());
        {
            let c = x.load(Ordering::Acquire).unwrap();
            assert_eq!(c, a);
            // `a` and `x` hold one reference each, plus `x`'s reserve
            // (which `c` borrowed from).
            assert_eq!(Arc::strong_count(&c), RESERVED_COUNT + 2);
        }
        {
            let c = x.swap(Some(b.clone()), Ordering::AcqRel).unwrap();
            assert_eq!(c, a);
            // With `x` out of the picture, only `a` and `c` remain.
            assert_eq!(Arc::strong_count(&c), 2);
        }
        {
            let c = x.load(Ordering::Acquire).unwrap();
            assert_eq!(c, b);
            assert_eq!(Arc::strong_count(&c), RESERVED_COUNT + 2);
        }
    }
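
    // A sanity check (added here as a sketch) that `inner_ptr` locates the
    // strong count at the offset assumed above: reading it directly should
    // agree with `Arc::strong_count` for an untouched `Arc`.
    #[test]
    fn inner_layout() {
        let a = Arc::new(7u64);
        let addr = Arc::as_ptr(&a) as usize;
        let count = unsafe {
            inner_ptr::<u64>(addr).as_ref().count.load(Ordering::Relaxed)
        };
        assert_eq!(count, Arc::strong_count(&a));
    }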

    #[test]
    fn option() {
        let x = AtomicArc::default();
        assert!(x.load(Ordering::Acquire).is_none());
        let a = Arc::new(1);
        assert!(x.swap(Some(a.clone()), Ordering::AcqRel).is_none());
        let b = x.swap(None, Ordering::AcqRel).unwrap();
        assert_eq!(b, a);
        assert!(x.load(Ordering::Acquire).is_none());
    }
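
    // A quick check of the state packing described above: the address sits
    // in the high 48 bits and the external count in the low 16 bits, so a
    // `fetch_add(1)` on the state bumps only the count.
    #[test]
    fn pack_unpack() {
        let addr = 0xDEAD_BEE0_usize;
        let state = pack_state(addr);
        assert_eq!(unpack_state(state), (addr, 0));
        // Two loads add two to the state, i.e. to the external count.
        assert_eq!(unpack_state(state + 2), (addr, 2));
    }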

    #[test]
    fn push_count() {
        let x = AtomicArc::new(Arc::new(1));
        let mut v = Vec::new();
        for _ in 0..(RESERVED_COUNT / 2) {
            let a = x.load(Ordering::Relaxed).unwrap();
            assert_eq!(Arc::strong_count(&a), RESERVED_COUNT + 1);
            v.push(a);
        }
        // This load crosses the threshold and pushes the external count back
        // to the `Arc`.
        let a = x.load(Ordering::Relaxed).unwrap();
        assert_eq!(Arc::strong_count(&a), RESERVED_COUNT + v.len() + 2);
        let b = x.swap(Some(Arc::new(2)), Ordering::Relaxed).unwrap();
        assert_eq!(Arc::strong_count(&b), v.len() + 2);
    }
}
244}