1use std::{
2 alloc::Layout,
3 cmp::max,
4 marker::PhantomData,
5 ptr::NonNull,
6 sync::{
7 atomic::{AtomicU64, AtomicUsize, Ordering},
8 Arc,
9 },
10};
11
/// An atomically swappable `Option<Arc<T>>`.
///
/// The stored `Arc`'s data address and a 16-bit counter share one `AtomicU64`:
/// the address occupies the upper 48 bits and the low 16 bits count how many
/// loads have claimed references from a block of strong references that was
/// pre-added to the `Arc` when it was stored.
pub struct AtomicArc<T> {
    /// Packed `(address << 16) | claimed_count`; an address of `0` means empty.
    state: AtomicU64,
    /// Ties the type to `Arc<T>`. The raw pointer opts out of the automatic
    /// `Send`/`Sync` impls, so thread-safety is declared explicitly.
    phantom: PhantomData<*mut Arc<T>>,
}
59
60impl<T> AtomicArc<T> {
61 pub fn new(value: Arc<T>) -> Self {
63 let state = new_state(value);
64 Self {
65 state: AtomicU64::new(state),
66 phantom: PhantomData,
67 }
68 }
69
70 pub fn load(&self, order: Ordering) -> Option<Arc<T>> {
77 let state = self.state.fetch_add(1, order);
78 let (addr, count) = unpack_state(state);
79 if addr == 0 {
80 return None;
81 }
82 if count >= RESERVED_COUNT {
83 panic!("external reference count overflow");
84 }
85 if count >= RESERVED_COUNT / 2 {
86 self.push_count(addr);
87 }
88 Some(unsafe { Arc::from_raw(addr as _) })
89 }
90
91 pub fn swap(&self, value: Option<Arc<T>>, order: Ordering) -> Option<Arc<T>> {
93 let state = self.state.swap(value.map(new_state).unwrap_or(0), order);
94 let (addr, count) = unpack_state(state);
95 if addr == 0 {
96 return None;
97 }
98 unsafe {
99 decrease_count::<T>(addr, RESERVED_COUNT - count);
100 Some(Arc::from_raw(addr as _))
101 }
102 }
103
104 fn push_count(&self, expect_addr: usize) {
106 let mut current = self.state.load(Ordering::Acquire);
107 let desired = pack_state(expect_addr);
108 loop {
109 let (addr, count) = unpack_state(current);
110 if addr != expect_addr || count < RESERVED_COUNT / 2 {
111 break;
113 }
114 match self.state.compare_exchange_weak(
115 current,
116 desired,
117 Ordering::AcqRel,
118 Ordering::Relaxed,
119 ) {
120 Ok(_) => unsafe {
121 increase_count::<T>(addr, count);
122 },
123 Err(actual) => current = actual,
124 }
125 }
126 }
127}
128
129impl<T> Drop for AtomicArc<T> {
130 fn drop(&mut self) {
131 self.swap(None, Ordering::AcqRel);
132 }
133}
134
135impl<T> Default for AtomicArc<T> {
136 fn default() -> Self {
137 Self {
138 state: AtomicU64::new(0),
139 phantom: PhantomData,
140 }
141 }
142}
143
144unsafe impl<T> Sync for AtomicArc<T> {}
145unsafe impl<T> Send for AtomicArc<T> {}
146
/// Number of extra strong references added to an `Arc` when it is stored.
///
/// Each `load` claims one of them by bumping the 16-bit counter in the packed
/// state, so the value must stay below `1 << 16`. Loads panic once the counter
/// reaches this value, and they replenish the reservation from half of it.
const RESERVED_COUNT: usize = 1 << 15;
148
149fn new_state<T>(value: Arc<T>) -> u64 {
150 let addr = Arc::into_raw(value) as usize;
151 unsafe {
152 increase_count::<T>(addr, RESERVED_COUNT);
153 pack_state(addr)
154 }
155}
156
/// Packs an `Arc` data address into the upper 48 bits of a state word, with
/// the low 16-bit claim counter set to zero.
///
/// # Panics
///
/// Panics if `addr` does not fit in 48 bits.
fn pack_state(addr: usize) -> u64 {
    const COUNT_BITS: u32 = 16;
    const ADDR_BITS: u32 = u64::BITS - COUNT_BITS;
    let wide = addr as u64;
    assert_eq!(wide >> ADDR_BITS, 0);
    wide << COUNT_BITS
}
162
/// Splits a packed state word into `(address, claim_count)`: the address from
/// the upper 48 bits and the counter from the low 16 bits.
fn unpack_state(state: u64) -> (usize, usize) {
    const COUNT_BITS: u32 = 16;
    let count_mask = (1u64 << COUNT_BITS) - 1;
    let addr = state >> COUNT_BITS;
    let count = state & count_mask;
    (addr as usize, count as usize)
}
166
/// Header of an `Arc` allocation: the two counters that precede the data.
///
/// NOTE(review): this assumes the standard library's private `ArcInner<T>`
/// layout (`repr(C)`, strong count first, weak count second, then the data).
/// That layout is an implementation detail of `std`, not a stable guarantee.
#[repr(C)]
struct ArcInner {
    /// Strong reference count, adjusted by `increase_count` / `decrease_count`.
    count: AtomicUsize,
    /// Weak reference count. It is not read by the visible code, but it keeps
    /// the header size (used by `inner_ptr`) equal to two counters.
    weak_count: AtomicUsize,
}
174
175unsafe fn inner_ptr<T>(addr: usize) -> NonNull<ArcInner> {
176 let align = align_of::<T>();
177 let layout = Layout::new::<ArcInner>();
178 let offset = max(layout.size(), align);
179 NonNull::new_unchecked((addr - offset) as _)
180}
181
182unsafe fn increase_count<T>(addr: usize, count: usize) {
183 let ptr = inner_ptr::<T>(addr);
184 ptr.as_ref().count.fetch_add(count, Ordering::Release);
185}
186
187unsafe fn decrease_count<T>(addr: usize, count: usize) {
188 let ptr = inner_ptr::<T>(addr);
189 ptr.as_ref().count.fetch_sub(count, Ordering::Release);
190}
191
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn simple() {
        let a = Arc::new(1);
        let b = Arc::new(2);
        let x = AtomicArc::new(a.clone());
        {
            let c = x.load(Ordering::Acquire).unwrap();
            assert_eq!(c, a);
            assert_eq!(Arc::strong_count(&c), RESERVED_COUNT + 2);
        }
        {
            let c = x.swap(Some(b.clone()), Ordering::AcqRel).unwrap();
            assert_eq!(c, a);
            assert_eq!(Arc::strong_count(&c), 2);
        }
        {
            let c = x.load(Ordering::Acquire).unwrap();
            assert_eq!(c, b);
            assert_eq!(Arc::strong_count(&c), RESERVED_COUNT + 2);
        }
    }

    #[test]
    fn option() {
        let x = AtomicArc::default();
        assert!(x.load(Ordering::Acquire).is_none());
        let a = Arc::new(1);
        assert!(x.swap(Some(a.clone()), Ordering::AcqRel).is_none());
        let b = x.swap(None, Ordering::AcqRel).unwrap();
        assert_eq!(b, a);
        assert!(x.load(Ordering::Acquire).is_none());
    }

    #[test]
    fn push_count() {
        let x = AtomicArc::new(Arc::new(1));
        let mut v = Vec::new();
        for _ in 0..(RESERVED_COUNT / 2) {
            let a = x.load(Ordering::Relaxed).unwrap();
            assert_eq!(Arc::strong_count(&a), RESERVED_COUNT + 1);
            v.push(a);
        }
        let a = x.load(Ordering::Relaxed).unwrap();
        assert_eq!(Arc::strong_count(&a), RESERVED_COUNT + v.len() + 2);
        let b = x.swap(Some(Arc::new(2)), Ordering::Relaxed).unwrap();
        assert_eq!(Arc::strong_count(&b), v.len() + 2);
    }

    /// Several threads draining the same reservation in parallel must
    /// replenish it without losing or leaking strong references.
    #[test]
    fn concurrent_loads_keep_counts_consistent() {
        let value = Arc::new(7usize);
        let slot = Arc::new(AtomicArc::new(value.clone()));
        let workers: Vec<_> = (0..4)
            .map(|_| {
                let slot = Arc::clone(&slot);
                std::thread::spawn(move || {
                    for _ in 0..RESERVED_COUNT {
                        let loaded = slot.load(Ordering::Acquire).unwrap();
                        assert_eq!(*loaded, 7);
                    }
                })
            })
            .collect();
        for worker in workers {
            worker.join().unwrap();
        }
        let taken = slot.swap(None, Ordering::AcqRel).unwrap();
        assert!(Arc::ptr_eq(&taken, &value));
        // Only `value` and `taken` remain.
        assert_eq!(Arc::strong_count(&taken), 2);
    }

    /// Readers racing a writer always see one of the stored values. Once the
    /// slot is emptied, every reservation has been returned.
    #[test]
    fn concurrent_load_and_swap() {
        let a = Arc::new(1usize);
        let b = Arc::new(2usize);
        let slot = Arc::new(AtomicArc::new(a.clone()));
        let readers: Vec<_> = (0..3)
            .map(|_| {
                let slot = Arc::clone(&slot);
                std::thread::spawn(move || {
                    for _ in 0..20_000 {
                        let loaded = slot.load(Ordering::Acquire).unwrap();
                        assert!(*loaded == 1 || *loaded == 2);
                    }
                })
            })
            .collect();
        for i in 0..1_000 {
            let next = if i % 2 == 0 { b.clone() } else { a.clone() };
            drop(slot.swap(Some(next), Ordering::AcqRel));
        }
        for reader in readers {
            reader.join().unwrap();
        }
        drop(slot.swap(None, Ordering::AcqRel));
        assert_eq!(Arc::strong_count(&a), 1);
        assert_eq!(Arc::strong_count(&b), 1);
    }
}