1use std::marker::PhantomData;
2use std::ptr;
3use std::ptr::{null, null_mut, NonNull};
4use std::sync::atomic::AtomicPtr;
5use std::sync::atomic::Ordering::{Relaxed, SeqCst};
6
7use fast_smr::smr::{load, protect};
8
9use crate::smart_ptrs::{Arc, AsPtr, Guard, Weak};
10use crate::StrongPtr;
11
/// An atomically updatable slot holding either null or one strong reference to a
/// value managed by this crate's `Arc`.
pub struct AtomicArc<T: 'static> {
    ptr: AtomicPtr<T>,
    phantom: PhantomData<T>,
}

/// Creates an empty (null) slot.
///
/// Written by hand because `#[derive(Default)]` would add an unnecessary
/// `T: Default` bound; an empty slot never constructs a `T`.
impl<T: 'static> Default for AtomicArc<T> {
    fn default() -> Self {
        Self {
            ptr: AtomicPtr::new(null_mut()),
            phantom: PhantomData,
        }
    }
}
45
46impl<T: 'static> AtomicArc<T> {
47 pub fn new<D: Into<Option<T>>>(data: D) -> Self {
52 let ptr = data.into().map_or(null(), |x| Arc::into_raw(Arc::new(x)));
53 Self {
54 ptr: AtomicPtr::new(ptr.cast_mut()),
55 phantom: PhantomData,
56 }
57 }
58
59 pub fn compare_exchange<N: AsPtr<Target = T> + StrongPtr>(
63 &self,
64 current: *const T,
65 new: Option<&N>,
66 ) -> Result<(), Option<Guard<T>>> {
67 let c = current.cast_mut();
68 let n = new.map_or(null(), N::as_ptr).cast_mut();
69 match self.ptr.compare_exchange(c, n, SeqCst, SeqCst) {
70 Ok(before) => unsafe {
71 Self::after_swap(n, before);
72 Ok(())
73 },
74 Err(actual) => {
75 let mut opt = None;
76 if let Some(ptr) = NonNull::new(actual) {
77 if let Some(guard) = protect(&self.ptr, ptr) {
78 opt = Some(Guard { guard })
79 }
80 }
81 Err(opt)
82 }
83 }
84 }
85
86 pub fn load(&self) -> Option<Guard<T>> {
89 let guard = load(&self.ptr)?;
90 Some(Guard { guard })
91 }
92
93 pub fn store<N: AsPtr<Target = T> + StrongPtr>(&self, new: Option<&N>) {
95 let n = new.map_or(null(), N::as_ptr);
97 let before = self.ptr.swap(n.cast_mut(), SeqCst);
98 unsafe {
99 Self::after_swap(n, before);
100 }
101 }
102
103 unsafe fn after_swap(new: *const T, before: *const T) {
104 if !ptr::eq(new, before) {
105 if !new.is_null() {
106 Arc::increment_strong_count(new);
107 }
108 if !before.is_null() {
109 drop(Arc::from_raw(before));
110 }
111 }
112 }
113}
114
115impl<T: 'static> Clone for AtomicArc<T> {
116 fn clone(&self) -> Self {
117 let ptr = if let Some(guard) = self.load() {
118 unsafe {
119 Arc::increment_strong_count(guard.as_ptr());
120 }
121 guard.as_ptr().cast_mut()
122 } else {
123 null_mut()
124 };
125 Self {
126 ptr: AtomicPtr::new(ptr),
127 phantom: PhantomData,
128 }
129 }
130}
131
132impl<T: 'static> Drop for AtomicArc<T> {
133 fn drop(&mut self) {
134 if let Some(ptr) = NonNull::new(self.ptr.load(Relaxed)) {
135 unsafe {
136 drop(Arc::from_raw(ptr.as_ptr()));
137 }
138 }
139 }
140}
141
142unsafe impl<T: 'static + Send + Sync> Send for AtomicArc<T> {}
143
144unsafe impl<T: 'static + Send + Sync> Sync for AtomicArc<T> {}
145
/// An atomically updatable slot holding either null or one weak reference to a
/// value managed by this crate's `Arc`.
pub struct AtomicWeak<T: 'static> {
    ptr: AtomicPtr<T>,
}

/// Creates an empty (null) slot.
///
/// Written by hand because `#[derive(Default)]` would add an unnecessary
/// `T: Default` bound; an empty slot never constructs a `T`.
impl<T: 'static> Default for AtomicWeak<T> {
    fn default() -> Self {
        Self {
            ptr: AtomicPtr::new(null_mut()),
        }
    }
}
168
169impl<T: 'static> AtomicWeak<T> {
170 pub fn compare_exchange<N: AsPtr<Target = T>>(
174 &self,
175 current: *const T,
176 new: Option<&N>,
177 ) -> Result<(), Option<Guard<T>>> {
178 let c = current.cast_mut();
179 let n = new.map_or(null(), N::as_ptr).cast_mut();
180 match self.ptr.compare_exchange(c, n, SeqCst, SeqCst) {
181 Ok(before) => unsafe {
182 Self::after_swap(n, before);
183 Ok(())
184 },
185 Err(actual) => unsafe {
186 let mut opt = None;
187 if let Some(ptr) = NonNull::new(actual) {
188 if let Some(guard) = protect(&self.ptr, ptr) {
189 opt = (Arc::strong_count_raw(guard.as_ptr()) > 0).then_some(Guard { guard })
190 }
191 }
192 Err(opt)
193 },
194 }
195 }
196
197 pub fn load(&self) -> Option<Guard<T>> {
205 let guard = load(&self.ptr)?;
206 unsafe { (Arc::strong_count_raw(guard.as_ptr()) > 0).then_some(Guard { guard }) }
207 }
208
209 pub fn store<N: AsPtr<Target = T>>(&self, new: Option<&N>) {
211 let n = new.map_or(null(), N::as_ptr);
212 let before = self.ptr.swap(n.cast_mut(), SeqCst);
213 unsafe {
214 Self::after_swap(n, before);
215 }
216 }
217
218 unsafe fn after_swap(new: *const T, before: *const T) {
219 if !ptr::eq(new, before) {
220 if !new.is_null() {
221 Weak::increment_weak_count(new);
222 }
223 if !before.is_null() {
224 drop(Weak::from_raw(before));
225 }
226 }
227 }
228}
229
230impl<T: 'static> Clone for AtomicWeak<T> {
231 fn clone(&self) -> Self {
232 let ptr = if let Some(guard) = self.load() {
233 unsafe {
234 Weak::increment_weak_count(guard.as_ptr());
235 }
236 guard.as_ptr().cast_mut()
237 } else {
238 null_mut()
239 };
240 Self {
241 ptr: AtomicPtr::new(ptr),
242 }
243 }
244}
245
246impl<T: 'static> Drop for AtomicWeak<T> {
247 fn drop(&mut self) {
248 if let Some(ptr) = NonNull::new(self.ptr.load(Relaxed)) {
249 unsafe {
250 drop(Weak::from_raw(ptr.as_ptr()));
251 }
252 }
253 }
254}
255
256impl<T: 'static, P: AsPtr<Target = T> + StrongPtr> From<&P> for AtomicArc<T> {
257 fn from(value: &P) -> Self {
258 unsafe {
259 let ptr = P::as_ptr(value);
260 Arc::increment_strong_count(ptr);
261 Self {
262 ptr: AtomicPtr::new(ptr.cast_mut()),
263 phantom: PhantomData,
264 }
265 }
266 }
267}
268
269impl<T: 'static, P: AsPtr<Target = T>> From<&P> for AtomicWeak<T> {
270 fn from(value: &P) -> Self {
271 unsafe {
272 let ptr = P::as_ptr(value);
273 Weak::increment_weak_count(ptr);
274 Self {
275 ptr: AtomicPtr::new(ptr.cast_mut()),
276 }
277 }
278 }
279}
280
281unsafe impl<T: 'static + Send + Sync> Send for AtomicWeak<T> {}
282
283unsafe impl<T: 'static + Send + Sync> Sync for AtomicWeak<T> {}