1use std::{
2 alloc::Layout,
3 cmp::max,
4 marker::PhantomData,
5 ptr,
6 ptr::NonNull,
7 sync::{
8 atomic::{AtomicU64, AtomicUsize, Ordering},
9 Arc,
10 },
11};
12
/// A single atomic word that stores an `Arc` data address together with an
/// external reference counter.
///
/// The address occupies the upper 48 bits of `state` and the counter the lower
/// 16 bits (see `pack_state` / `unpack_state`).
struct AtomicPtr<T> {
    // `(address << 16) | external_count`.
    state: AtomicU64,
    // No `Arc<T>` is stored in the struct; the marker ties the type parameter
    // to `Arc<T>` so auto traits follow those of `Arc<T>`.
    phantom: PhantomData<Arc<T>>,
}
17
18impl<T> AtomicPtr<T> {
19 fn new(value: *const T) -> Self {
21 let state = new_state(value);
22 Self {
23 state: AtomicU64::new(state),
24 phantom: PhantomData,
25 }
26 }
27
28 fn load(&self, order: Ordering) -> *const T {
30 let state = self.state.fetch_add(1, order);
31 let (addr, count) = unpack_state(state);
32 if count >= RESERVED_COUNT {
33 panic!("external reference count overflow");
34 }
35 if count >= RESERVED_COUNT / 2 {
36 self.push_count(addr);
37 }
38 addr as _
39 }
40
41 fn swap(&self, value: *const T, order: Ordering) -> *const T {
43 let state = self.state.swap(new_state(value), order);
44 let (addr, count) = unpack_state(state);
45 unsafe {
46 decrease_count::<T>(addr, RESERVED_COUNT - count);
47 addr as _
48 }
49 }
50
51 fn compare_exchange(
53 &self,
54 current: *const T,
55 new: *const T,
56 success: Ordering,
57 failure: Ordering,
58 ) -> Result<*const T, *const T> {
59 let new_state = pack_state(new.addr());
60 let mut state = self.state.load(failure);
61 loop {
62 let (addr, count) = unpack_state(state);
63 if addr != current.addr() {
64 unsafe {
65 increase_count::<T>(addr, 1);
66 }
67 return Err(addr as _);
68 }
69 match self
70 .state
71 .compare_exchange_weak(state, new_state, success, failure)
72 {
73 Ok(_) => {
74 unsafe {
75 decrease_count::<T>(addr, RESERVED_COUNT - count);
76 increase_count::<T>(new.addr(), RESERVED_COUNT + 1);
77 }
78 return Ok(addr as _);
79 }
80 Err(now_state) => state = now_state,
81 }
82 }
83 }
84
85 fn push_count(&self, expect_addr: usize) {
87 let mut current = self.state.load(Ordering::Acquire);
88 let desired = pack_state(expect_addr);
89 loop {
90 let (addr, count) = unpack_state(current);
91 if addr != expect_addr || count < RESERVED_COUNT / 2 {
92 break;
94 }
95 match self.state.compare_exchange_weak(
96 current,
97 desired,
98 Ordering::Release,
99 Ordering::Relaxed,
100 ) {
101 Ok(_) => unsafe {
102 increase_count::<T>(addr, count);
103 },
104 Err(actual) => current = actual,
105 }
106 }
107 }
108}
109
110impl<T> Drop for AtomicPtr<T> {
111 fn drop(&mut self) {
112 let state = self.state.load(Ordering::Acquire);
113 let (addr, count) = unpack_state(state);
114 unsafe {
115 decrease_count::<T>(addr, RESERVED_COUNT + 1 - count);
116 }
117 }
118}
119
120impl<T> Default for AtomicPtr<T> {
121 fn default() -> Self {
122 Self {
123 state: AtomicU64::new(0),
124 phantom: PhantomData,
125 }
126 }
127}
128
/// An `Arc<T>` slot that can be loaded, swapped and compare-exchanged through
/// a shared reference, backed by [`AtomicPtr`].
pub struct AtomicArc<T>(AtomicPtr<T>);
173
174impl<T> AtomicArc<T> {
175 pub fn new(value: Arc<T>) -> Self {
177 Self(AtomicPtr::new(Arc::into_raw(value)))
178 }
179
180 pub fn load(&self, order: Ordering) -> Arc<T> {
185 let ptr = self.0.load(order);
186 unsafe { Arc::from_raw(ptr) }
187 }
188
189 pub fn swap(&self, value: Arc<T>, order: Ordering) -> Arc<T> {
191 let new = Arc::into_raw(value);
192 let current = self.0.swap(new, order);
193 unsafe { Arc::from_raw(current) }
194 }
195
196 pub fn compare_exchange(
199 &self,
200 current: &Arc<T>,
201 new: &Arc<T>,
202 success: Ordering,
203 failure: Ordering,
204 ) -> Result<Arc<T>, Arc<T>> {
205 let new = Arc::as_ptr(new);
206 let current = Arc::as_ptr(current);
207 self.0
208 .compare_exchange(current, new, success, failure)
209 .map(|ptr| unsafe { Arc::from_raw(ptr) })
210 .map_err(|ptr| unsafe { Arc::from_raw(ptr) })
211 }
212}
213
/// Like [`AtomicArc`], but the slot may be empty; an empty slot is represented
/// by a null pointer inside the underlying [`AtomicPtr`].
pub struct AtomicOptionArc<T>(AtomicPtr<T>);
218
219impl<T> AtomicOptionArc<T> {
220 pub fn new(value: Arc<T>) -> Self {
222 Self(AtomicPtr::new(Arc::into_raw(value)))
223 }
224
225 pub fn load(&self, order: Ordering) -> Option<Arc<T>> {
229 let ptr = self.0.load(order);
230 unsafe { Self::from_ptr(ptr) }
231 }
232
233 pub fn swap(&self, value: Option<Arc<T>>, order: Ordering) -> Option<Arc<T>> {
235 let new = Self::into_ptr(value);
236 let current = self.0.swap(new, order);
237 unsafe { Self::from_ptr(current) }
238 }
239
240 pub fn compare_exchange(
243 &self,
244 current: Option<&Arc<T>>,
245 new: Option<&Arc<T>>,
246 success: Ordering,
247 failure: Ordering,
248 ) -> Result<Option<Arc<T>>, Option<Arc<T>>> {
249 let new = new.map(Arc::as_ptr).unwrap_or(ptr::null());
250 let current = current.map(Arc::as_ptr).unwrap_or(ptr::null());
251 self.0
252 .compare_exchange(current, new, success, failure)
253 .map(|ptr| unsafe { Self::from_ptr(ptr) })
254 .map_err(|ptr| unsafe { Self::from_ptr(ptr) })
255 }
256
257 fn into_ptr(value: Option<Arc<T>>) -> *const T {
258 value.map(Arc::into_raw).unwrap_or(ptr::null())
259 }
260
261 unsafe fn from_ptr(ptr: *const T) -> Option<Arc<T>> {
262 if ptr.is_null() {
263 None
264 } else {
265 Some(unsafe { Arc::from_raw(ptr) })
266 }
267 }
268}
269
270impl<T> Default for AtomicOptionArc<T> {
271 fn default() -> Self {
272 Self(AtomicPtr::new(ptr::null()))
273 }
274}
275
/// Number of strong references reserved whenever a pointer is installed in a
/// cell (0x8000); also the limit at which `load` reports an external count
/// overflow.
const RESERVED_COUNT: usize = 1 << 15;
277
278fn new_state<T>(ptr: *const T) -> u64 {
279 let addr = ptr.addr();
280 unsafe {
281 increase_count::<T>(addr, RESERVED_COUNT);
282 pack_state(addr)
283 }
284}
285
/// Packs `addr` into the upper 48 bits of a state word, leaving the 16-bit
/// external count at zero.
///
/// # Panics
///
/// Panics if `addr` does not fit into 48 bits.
fn pack_state(addr: usize) -> u64 {
    let wide = addr as u64;
    // The top 16 bits must be free, otherwise the shift would drop them.
    assert_eq!(wide >> 48, 0);
    wide << 16
}
291
/// Splits a state word into `(address, external_count)`: the address is the
/// upper 48 bits, the count the lower 16 bits.
fn unpack_state(state: u64) -> (usize, usize) {
    let addr = state >> 16;
    let count = state & 0xFFFF;
    (addr as usize, count as usize)
}
295
/// The two counters that `inner_ptr` expects to find in front of the data an
/// `Arc` pointer refers to.
// NOTE(review): this presumably mirrors the private header of the standard
// library's `Arc` allocation, which is not a stable guarantee — confirm
// against the targeted toolchain.
#[repr(C)]
struct ArcInner {
    // Strong reference counter; the only field modified in this file.
    count: AtomicUsize,
    // Not accessed here; kept so the struct has the size of both counters.
    weak_count: AtomicUsize,
}
303
304unsafe fn inner_ptr<T>(addr: usize) -> NonNull<ArcInner> {
305 let align = align_of::<T>();
306 let layout = Layout::new::<ArcInner>();
307 let offset = max(layout.size(), align);
308 NonNull::new_unchecked((addr - offset) as _)
309}
310
311unsafe fn increase_count<T>(addr: usize, count: usize) {
312 if addr != 0 {
313 let ptr = inner_ptr::<T>(addr);
314 ptr.as_ref().count.fetch_add(count, Ordering::Release);
315 }
316}
317
318unsafe fn decrease_count<T>(addr: usize, count: usize) {
319 if addr != 0 {
320 let ptr = inner_ptr::<T>(addr);
321 ptr.as_ref().count.fetch_sub(count, Ordering::Release);
322 }
323}
324
#[cfg(test)]
mod tests {
    use super::*;

    /// Walks an `AtomicArc` through load, swap and compare_exchange and checks
    /// the strong counts of the two values after every step.
    #[test]
    fn test_arc() {
        let a = Arc::new(1);
        let b = Arc::new(2);
        // `a` itself + the clone moved into the cell + the reservation.
        let x = AtomicArc::new(a.clone());
        {
            let c = x.load(Ordering::Acquire);
            assert_eq!(c, a);
            // A load is served from the reservation, so the strong count does
            // not move while `c` is alive.
            assert_eq!(Arc::strong_count(&a), RESERVED_COUNT + 2);
        }
        {
            let c = x.swap(b.clone(), Ordering::AcqRel);
            assert_eq!(c, a);
            // The unused reservation for `a` was released: only `a` and `c`
            // remain.
            assert_eq!(Arc::strong_count(&a), 2);
            assert_eq!(Arc::strong_count(&b), RESERVED_COUNT + 2);
            let c = x.load(Ordering::Acquire);
            assert_eq!(c, b);
            assert_eq!(Arc::strong_count(&b), RESERVED_COUNT + 2);
        }
        {
            // The cell holds `b`, so this exchange succeeds.
            let c = x
                .compare_exchange(&b, &a, Ordering::AcqRel, Ordering::Acquire)
                .unwrap();
            assert_eq!(c, b);
            assert_eq!(Arc::strong_count(&b), 2);
            assert_eq!(Arc::strong_count(&a), RESERVED_COUNT + 2);
            // The cell now holds `a`, so the same exchange fails and returns
            // an extra reference to `a`.
            let c = x
                .compare_exchange(&b, &a, Ordering::AcqRel, Ordering::Acquire)
                .unwrap_err();
            assert_eq!(c, a);
            assert_eq!(Arc::strong_count(&a), RESERVED_COUNT + 3);
        }
        // Dropping the cell releases everything it still owned.
        drop(x);
        assert_eq!(Arc::strong_count(&a), 1);
        assert_eq!(Arc::strong_count(&b), 1);
    }

    /// Same walk-through for `AtomicOptionArc`, ending with an empty cell.
    #[test]
    fn test_option_arc() {
        let a = Arc::new(1);
        let b = Arc::new(2);
        let x = AtomicOptionArc::new(a.clone());
        {
            let c = x.load(Ordering::Acquire);
            assert_eq!(c, Some(a.clone()));
        }
        {
            let c = x.swap(Some(b.clone()), Ordering::AcqRel);
            assert_eq!(c, Some(a.clone()));
            let c = x.load(Ordering::Acquire);
            assert_eq!(c, Some(b.clone()));
        }
        {
            // Replace `b` by the empty state.
            let c = x
                .compare_exchange(Some(&b), None, Ordering::AcqRel, Ordering::Relaxed)
                .unwrap();
            assert_eq!(c, Some(b.clone()));
            // The cell is empty now, so expecting `b` fails and reports `None`.
            let c = x
                .compare_exchange(Some(&b), None, Ordering::AcqRel, Ordering::Relaxed)
                .unwrap_err();
            assert_eq!(c, None);
        }
        assert_eq!(x.load(Ordering::Acquire), None);
        // The cell is still alive but empty, so it holds no references.
        assert_eq!(Arc::strong_count(&a), 1);
        assert_eq!(Arc::strong_count(&b), 1);
    }

    /// Checks that the external count is folded into the strong count once
    /// half of the reservation has been handed out.
    #[test]
    fn test_push_count() {
        let x = AtomicArc::new(Arc::new(1));
        let mut v = Vec::new();
        // The first RESERVED_COUNT / 2 loads stay below the threshold and
        // leave the strong count untouched.
        for _ in 0..(RESERVED_COUNT / 2) {
            let a = x.load(Ordering::Relaxed);
            assert_eq!(Arc::strong_count(&a), RESERVED_COUNT + 1);
            v.push(a);
        }
        // This load observes a count at the threshold and triggers the push:
        // all `v.len() + 1` handed-out references are added to the strong
        // count.
        let a = x.load(Ordering::Relaxed);
        assert_eq!(Arc::strong_count(&a), RESERVED_COUNT + v.len() + 2);
        // After the swap only the loaded references, `a` and `b` remain.
        let b = x.swap(Arc::new(2), Ordering::Relaxed);
        assert_eq!(Arc::strong_count(&b), v.len() + 2);
    }
}