1use std::fmt::Formatter;
5use std::ptr::null_mut;
6use std::sync::atomic::{AtomicPtr, Ordering};
7use std::sync::Arc;
8
/// A thread-safe slot holding an optional `Arc<T>` that can be read and replaced atomically.
///
/// The slot stores either null (empty) or a pointer obtained from [`Arc::into_raw`], and it owns
/// exactly one strong reference to the allocation it points at.
#[repr(transparent)]
pub struct AtomicRef<T>(AtomicPtr<T>);

// SAFETY: `AtomicRef<T>` hands out `Arc<T>` clones to any thread and drops replaced values on
// whichever thread replaced them. It is therefore exactly as thread-safe as `Arc<T>`, which is
// `Send`/`Sync` only when `T: Send + Sync`. Without these bounds a `!Send` payload (e.g. `Rc`)
// could be moved to, or shared with, another thread.
unsafe impl<T: Send + Sync> Send for AtomicRef<T> {}
unsafe impl<T: Send + Sync> Sync for AtomicRef<T> {}
35
36impl<T> AtomicRef<T> {
37 pub fn new(value: T) -> Self {
40 let arc = Arc::new(value);
41 let ptr = Arc::into_raw(arc) as *mut _;
42 AtomicRef(AtomicPtr::new(ptr))
43 }
44
45 pub fn get(&self) -> Option<Arc<T>> {
49 let ptr = self.0.load(Ordering::SeqCst);
50 if ptr.is_null() {
51 None
52 } else {
53 let arc = unsafe { Arc::from_raw(ptr) };
54 let result = arc.clone();
55 std::mem::forget(arc);
56 Some(result)
57 }
58 }
59
60 pub fn swap(&self, value: T) -> Option<Arc<T>> {
62 let new_ptr = Arc::into_raw(Arc::new(value)) as *mut _;
63 let prev = self.0.swap(new_ptr, Ordering::Release);
64 if prev.is_null() {
65 None
66 } else {
67 let arc = unsafe { Arc::from_raw(prev) };
68 Some(arc)
69 }
70 }
71
72 pub fn take(&self) -> Option<Arc<T>> {
74 let prev = self.0.swap(null_mut(), Ordering::Release);
75 if prev.is_null() {
76 None
77 } else {
78 let arc = unsafe { Arc::from_raw(prev) };
79 Some(arc)
80 }
81 }
82
83 pub fn update<F>(&self, f: F)
90 where
91 F: Fn(Option<&T>) -> T,
92 {
93 loop {
94 let old_ptr = self.0.load(Ordering::SeqCst);
95 let old_value = unsafe { old_ptr.as_ref() };
96
97 let new_value = f(old_value);
99
100 let new_ptr = Arc::into_raw(Arc::new(new_value)) as *mut _;
101
102 let swapped =
103 self.0
104 .compare_exchange(old_ptr, new_ptr, Ordering::AcqRel, Ordering::Relaxed);
105
106 match swapped {
107 Ok(old) => {
108 if !old.is_null() {
109 unsafe { Arc::decrement_strong_count(old) }; }
111 break; }
113 Err(new) => {
114 if !new.is_null() {
115 unsafe { Arc::decrement_strong_count(new) }; }
117 }
118 }
119 }
120 }
121}
122
123impl<T: Copy> AtomicRef<T> {
124 pub fn get_owned(&self) -> Option<T> {
128 let ptr = self.0.load(Ordering::SeqCst);
129 if ptr.is_null() {
130 None
131 } else {
132 let arc = unsafe { Arc::from_raw(ptr) };
133 let result = *arc;
134 std::mem::forget(arc);
135 Some(result)
136 }
137 }
138}
139
140impl<T> Drop for AtomicRef<T> {
141 fn drop(&mut self) {
142 unsafe {
143 let ptr = self.0.load(Ordering::Acquire);
144 if !ptr.is_null() {
145 Arc::decrement_strong_count(ptr);
146 }
147 }
148 }
149}
150
151impl<T> PartialEq for AtomicRef<T>
152where
153 T: PartialEq,
154{
155 fn eq(&self, other: &Self) -> bool {
156 let a = self.0.load(Ordering::Acquire);
157 let b = other.0.load(Ordering::Acquire);
158 if std::ptr::eq(a, b) {
159 true
160 } else {
161 unsafe { a.as_ref() == b.as_ref() }
162 }
163 }
164}
165
166impl<T> Eq for AtomicRef<T> where T: Eq {}
167
168impl<T: std::fmt::Debug> std::fmt::Debug for AtomicRef<T> {
169 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
170 let value = self.get();
171 write!(f, "AtomicRef({:?})", value.as_deref())
172 }
173}
174
175impl<T> Default for AtomicRef<T> {
176 fn default() -> Self {
177 AtomicRef(AtomicPtr::new(null_mut()))
178 }
179}
180
#[cfg(test)]
mod test {
    // `super` keeps the tests independent of where this module sits in the crate.
    use super::AtomicRef;

    #[test]
    fn init_get() {
        let atom = AtomicRef::new(1);
        let value = atom.get();
        assert_eq!(value.as_deref().cloned(), Some(1));
    }

    #[test]
    fn update() {
        let atom = AtomicRef::new(vec!["John"]);
        let old_users = atom.get().unwrap();
        let actual: &[&str] = &old_users;
        assert_eq!(actual, &["John"]);

        atom.update(|users| {
            let mut users_copy = users.cloned().unwrap_or_else(Vec::default);
            users_copy.push("Susan");
            users_copy
        });

        let new_users = atom.get().unwrap();
        let actual: &[&str] = &new_users;
        assert_eq!(actual, &["John", "Susan"]);

        // A snapshot taken before the update still sees the old value.
        let actual: &[&str] = &old_users;
        assert_eq!(actual, &["John"]);
    }

    #[test]
    fn update_on_empty_slot() {
        let atom: AtomicRef<i32> = AtomicRef::default();
        atom.update(|current| current.copied().unwrap_or(0) + 1);
        assert_eq!(atom.get_owned(), Some(1));
    }

    #[test]
    fn swap_and_take() {
        let atom = AtomicRef::new(String::from("a"));

        let prev = atom.swap(String::from("b"));
        assert_eq!(prev.as_deref().map(String::as_str), Some("a"));

        let taken = atom.take();
        assert_eq!(taken.as_deref().map(String::as_str), Some("b"));
        assert!(atom.get().is_none());
        assert!(atom.take().is_none());

        assert!(atom.swap(String::from("c")).is_none());
        assert_eq!(atom.get().as_deref().map(String::as_str), Some("c"));
    }

    #[test]
    fn eq_and_debug() {
        assert_eq!(AtomicRef::new(3), AtomicRef::new(3));
        assert_ne!(AtomicRef::new(3), AtomicRef::new(4));
        assert_eq!(AtomicRef::<i32>::default(), AtomicRef::default());
        assert_ne!(AtomicRef::new(3), AtomicRef::default());
        assert_eq!(format!("{:?}", AtomicRef::new(3)), "AtomicRef(Some(3))");
        assert_eq!(format!("{:?}", AtomicRef::<i32>::default()), "AtomicRef(None)");
    }
}