1use std::marker::PhantomData;
2use std::ptr::{eq, null, null_mut, NonNull};
3use std::sync::atomic::AtomicPtr;
4use std::sync::atomic::Ordering::SeqCst;
5
6use crate::smart_ptrs::{find_inner_ptr, ArcInner, Guard, CTX};
7use crate::Arc;
8
9#[derive(Default)]
45pub struct AtomicArc<T: 'static> {
46 ptr: AtomicPtr<ArcInner<T>>,
47 phantom: PhantomData<ArcInner<T>>,
48}
49
50impl<T: 'static> AtomicArc<T> {
51 pub fn new<D: Into<Option<T>>>(data: D) -> Self {
56 let ptr = data.into().map_or(null_mut(), ArcInner::new);
57 Self {
58 ptr: AtomicPtr::new(ptr),
59 phantom: PhantomData,
60 }
61 }
62
63 pub fn load(&self) -> Option<Guard<'static, T>> {
66 let guard = CTX.with_borrow(|ctx| ctx.load(&self.ptr, 1))?;
67 Some(Guard { guard })
68 }
69
70 pub fn swap<N: Into<NonNull<T>>>(&self, new: Option<N>) -> Option<Arc<T>> {
72 unsafe {
73 let n = new.map_or(null_mut(), |n| find_inner_ptr(n.into().as_ptr()).cast_mut());
74 if !n.is_null() {
75 ArcInner::increment(n);
76 }
77 let before = NonNull::new(self.ptr.swap(n, SeqCst))?;
78 Some(Arc {
79 ptr: before,
80 phantom: PhantomData,
81 })
82 }
83 }
84
85 pub fn store<N: Into<NonNull<T>>>(&self, new: Option<N>) {
87 _ = self.swap(new)
88 }
89}
90
91pub trait CompareExchange<T, N> {
97 fn compare_exchange<C: Into<NonNull<T>>>(
98 &self,
99 current: Option<C>,
100 new: Option<N>,
101 ) -> Result<(), Option<Guard<'static, T>>>;
102}
103
104impl<T: 'static> CompareExchange<T, &Guard<'static, T>> for AtomicArc<T> {
105 fn compare_exchange<C: Into<NonNull<T>>>(
106 &self,
107 current: Option<C>,
108 new: Option<&Guard<'static, T>>,
109 ) -> Result<(), Option<Guard<'static, T>>> {
110 unsafe {
111 let c = current.map_or(null_mut(), |c| find_inner_ptr(c.into().as_ptr()).cast_mut());
112 let n = new.map_or(null(), Guard::inner_ptr).cast_mut();
113 match self.ptr.compare_exchange(c, n, SeqCst, SeqCst) {
114 Ok(before) => {
115 if !eq(before, n) {
116 if !n.is_null() {
117 ArcInner::increment(n);
118 }
119 if !before.is_null() {
120 ArcInner::delayed_decrement(before);
121 }
122 }
123 Ok(())
124 }
125 Err(actual) => {
126 if let Some(ptr) = NonNull::new(actual) {
127 let mut opt = None;
128 let loaded = CTX.with_borrow(|ctx| ctx.protect(&self.ptr, ptr, 1));
129 if let Some(guard) = loaded {
130 opt = Some(Guard { guard })
131 }
132 Err(opt)
133 } else {
134 Err(None)
135 }
136 }
137 }
138 }
139 }
140}
141
142impl<T: 'static> CompareExchange<T, &Arc<T>> for AtomicArc<T> {
143 fn compare_exchange<C: Into<NonNull<T>>>(
144 &self,
145 current: Option<C>,
146 new: Option<&Arc<T>>,
147 ) -> Result<(), Option<Guard<'static, T>>> {
148 let g = new.map(Guard::from);
149 CompareExchange::compare_exchange(self, current, g.as_ref())
150 }
151}
152
153impl<T: 'static> Clone for AtomicArc<T> {
154 fn clone(&self) -> Self {
155 let ptr = if let Some(guard) = self.load() {
156 unsafe {
157 let ptr = guard.guard.as_ptr();
158 _ = (*ptr).ref_count.fetch_add(1, SeqCst);
159 ptr
160 }
161 } else {
162 null_mut()
163 };
164 Self {
165 ptr: AtomicPtr::new(ptr.cast_mut()),
166 phantom: PhantomData,
167 }
168 }
169}
170
171impl<T: 'static> Drop for AtomicArc<T> {
172 fn drop(&mut self) {
173 if let Some(ptr) = NonNull::new(self.ptr.load(SeqCst)) {
174 unsafe {
175 ArcInner::delayed_decrement(ptr.as_ptr());
176 }
177 }
178 }
179}
180
181unsafe impl<T: 'static + Send + Sync> Send for AtomicArc<T> {}
182
183unsafe impl<T: 'static + Send + Sync> Sync for AtomicArc<T> {}
184
185impl<T: 'static, P: Into<NonNull<T>>> From<P> for AtomicArc<T> {
186 fn from(value: P) -> Self {
187 unsafe {
188 let inner_ptr = find_inner_ptr(value.into().as_ptr());
189 _ = (*inner_ptr).ref_count.fetch_add(1, SeqCst);
190 Self {
191 ptr: AtomicPtr::new(inner_ptr.cast_mut()),
192 phantom: PhantomData,
193 }
194 }
195 }
196}
197
#[cfg(test)]
mod tests {
    use crate::{Arc, AtomicArc, CompareExchange};

    #[test]
    fn test_new_with_value() {
        let atomic = AtomicArc::new(42);
        let guard = atomic.load().unwrap();
        assert_eq!(*guard, 42);
    }

    #[test]
    fn test_new_with_none() {
        let atomic: AtomicArc<i32> = AtomicArc::new(None);
        assert!(atomic.load().is_none());
    }

    #[test]
    fn test_default_is_empty() {
        let atomic: AtomicArc<i32> = AtomicArc::default();
        assert!(atomic.load().is_none());
    }

    #[test]
    fn test_swap() {
        let atomic = AtomicArc::new(10);
        let arc = Arc::new(20);

        let old = atomic.swap(Some(&arc));
        assert!(old.is_some());
        assert_eq!(*old.unwrap(), 10);

        let guard = atomic.load().unwrap();
        assert_eq!(*guard, 20);
    }

    #[test]
    fn test_swap_none() {
        let atomic = AtomicArc::new(10);
        let old = atomic.swap::<&Arc<i32>>(None);

        assert!(old.is_some());
        assert_eq!(*old.unwrap(), 10);
        assert!(atomic.load().is_none());
    }

    #[test]
    fn test_store_none() {
        let atomic = AtomicArc::new(7);
        atomic.store::<&Arc<i32>>(None);
        assert!(atomic.load().is_none());
    }

    #[test]
    fn test_store_arc() {
        let arc = Arc::new(42);
        let atomic = AtomicArc::new(0);
        atomic.store(Some(&arc));

        let guard = atomic.load().unwrap();
        assert_eq!(*guard, 42);
        assert_eq!(*arc, 42);
    }

    #[test]
    fn test_clone() {
        let atomic = AtomicArc::new(42);
        let cloned = atomic.clone();

        let guard1 = atomic.load().unwrap();
        let guard2 = cloned.load().unwrap();

        assert_eq!(*guard1, 42);
        assert_eq!(*guard2, 42);
    }

    #[test]
    fn test_clone_none() {
        let atomic: AtomicArc<i32> = AtomicArc::new(None);
        let cloned = atomic.clone();

        assert!(atomic.load().is_none());
        assert!(cloned.load().is_none());
    }

    #[test]
    fn test_clone_is_independent() {
        let atomic = AtomicArc::new(1);
        let cloned = atomic.clone();
        let replacement = Arc::new(2);
        atomic.store(Some(&replacement));

        assert_eq!(*atomic.load().unwrap(), 2);
        assert_eq!(*cloned.load().unwrap(), 1);
    }

    #[test]
    fn test_compare_exchange_success_with_arc() {
        let arc1 = Arc::new(10);
        let arc2 = Arc::new(20);
        let atomic = AtomicArc::new(10);
        atomic.store(Some(&arc1));

        let result = atomic.compare_exchange(Some(&arc1), Some(&arc2));
        assert!(result.is_ok());

        let guard = atomic.load().unwrap();
        assert_eq!(*guard, 20);
    }

    #[test]
    fn test_compare_exchange_failure_with_arc() {
        let arc1 = Arc::new(10);
        let arc2 = Arc::new(20);
        let arc3 = Arc::new(30);
        let atomic = AtomicArc::new(10);
        atomic.store(Some(&arc1));

        let result = atomic.compare_exchange(Some(&arc2), Some(&arc3));
        // A failed exchange reports the value that is actually stored.
        let actual = result
            .expect_err("exchange must fail when `current` does not match")
            .expect("stored value is non-null, so a guard is returned");
        assert_eq!(*actual, 10);
        drop(actual);

        let guard = atomic.load().unwrap();
        assert_eq!(*guard, 10);
        // The rejected replacement must be left intact.
        assert_eq!(*arc3, 30);
    }

    #[test]
    fn test_compare_exchange_from_empty() {
        let atomic: AtomicArc<i32> = AtomicArc::new(None);
        let arc = Arc::new(5);

        let result = atomic.compare_exchange(None::<&Arc<i32>>, Some(&arc));
        assert!(result.is_ok());
        assert_eq!(*atomic.load().unwrap(), 5);
        assert_eq!(*arc, 5);
    }

    #[test]
    fn test_compare_exchange_same_value() {
        let arc = Arc::new(10);
        let atomic = AtomicArc::new(0);
        atomic.store(Some(&arc));

        // Exchanging a pointer for itself must succeed and keep everything alive.
        assert!(atomic.compare_exchange(Some(&arc), Some(&arc)).is_ok());
        assert_eq!(*atomic.load().unwrap(), 10);
        assert_eq!(*arc, 10);
    }

    #[test]
    fn test_compare_exchange_with_guard() {
        let arc1 = Arc::new(10);
        let arc2 = Arc::new(20);
        let atomic = AtomicArc::new(10);
        atomic.store(Some(&arc1));

        let guard = atomic.load().unwrap();
        let result = atomic.compare_exchange(Some(&guard), Some(&arc2));
        assert!(result.is_ok());

        let new_guard = atomic.load().unwrap();
        assert_eq!(*new_guard, 20);
    }

    #[test]
    fn test_from_arc() {
        let arc = Arc::new(42);
        let atomic: AtomicArc<i32> = AtomicArc::from(&arc);

        let guard = atomic.load().unwrap();
        assert_eq!(*guard, 42);
        assert_eq!(*arc, 42);
    }
}