atomic_ref2/
option_ref.rs1use super::spinlock::SpinRwLock;
2use super::IntoOptionArc;
3use std::mem;
4use std::ptr::null_mut;
5use std::sync::atomic::{AtomicPtr, Ordering};
6use std::sync::Arc;
7
8pub struct AtomicOptionRef<T> {
10 ptr: AtomicPtr<T>,
11 lock: SpinRwLock,
12}
13
14impl<T> AtomicOptionRef<T> {
15 pub fn new() -> Self {
17 Self::default()
18 }
19
20 pub fn from(value: impl IntoOptionArc<T>) -> Self {
22 Self {
23 ptr: AtomicPtr::new(option_arc_to_ptr(value)),
24 lock: SpinRwLock::new(),
25 }
26 }
27
28 pub fn is_some(&self) -> bool {
30 self.ptr.load(Ordering::SeqCst).is_null()
31 }
32
33 pub fn load(&self) -> Option<Arc<T>> {
36 let _guard = self.lock.read();
37 ptr_to_option_arc(self.ptr.load(Ordering::SeqCst), true)
38 }
39
40 pub fn store(&self, value: impl IntoOptionArc<T>) {
42 self.swap(value);
43 }
44
45 pub fn swap(&self, value: impl IntoOptionArc<T>) -> Option<Arc<T>> {
47 let _guard = self.lock.write();
48 ptr_to_option_arc(
49 self.ptr.swap(option_arc_to_ptr(value), Ordering::SeqCst),
50 false,
51 )
52 }
53}
54
55impl<T> Default for AtomicOptionRef<T> {
56 fn default() -> Self {
57 Self::from(None)
58 }
59}
60
61impl<T> Drop for AtomicOptionRef<T> {
62 fn drop(&mut self) {
63 let ptr = self.ptr.swap(null_mut(), Ordering::SeqCst);
64 if !ptr.is_null() {
65 unsafe {
66 let _ = Arc::from_raw(ptr);
69 }
70 }
71 }
72}
73
74fn option_arc_to_ptr<T>(value: impl IntoOptionArc<T>) -> *mut T {
75 if let Some(value) = value.into_option_arc() {
76 Arc::into_raw(value) as *mut _
77 } else {
78 null_mut()
79 }
80}
81
/// Rebuilds an `Option<Arc<T>>` from a raw pointer produced by
/// `option_arc_to_ptr`.
///
/// With `increment == false`, the strong count represented by `ptr` is moved
/// into the returned `Arc`. With `increment == true`, one extra strong count
/// is leaked first, so the pointer's own reference stays valid and the caller
/// receives an additional one.
fn ptr_to_option_arc<T>(ptr: *mut T, increment: bool) -> Option<Arc<T>> {
    if ptr.is_null() {
        return None;
    }

    // SAFETY: callers pass only null or pointers obtained from
    // `Arc::into_raw` whose strong count is still alive.
    let restored = unsafe { Arc::from_raw(ptr) };

    if increment {
        // Leak a clone so the count owned by `ptr` is left untouched.
        let extra = Arc::clone(&restored);
        mem::forget(extra);
    }

    Some(restored)
}
99
#[cfg(test)]
mod tests {
    use super::AtomicOptionRef;

    #[test]
    fn test_store_load() {
        let m = AtomicOptionRef::<String>::new();

        m.store(String::from("2"));

        assert_eq!(m.load().unwrap().as_ref(), "2");
    }

    #[test]
    fn test_overwrite() {
        let m = AtomicOptionRef::<String>::new();

        m.store(String::from("Hello World"));

        // A loaded reference must stay valid after the slot is overwritten.
        let m0 = m.load();

        m.store(String::from("Goodbye World"));

        assert_eq!(m0.unwrap().as_ref(), "Hello World");

        assert_eq!(m.load().unwrap().as_ref(), "Goodbye World");
    }

    #[test]
    fn test_drop() {
        use std::sync::atomic::{AtomicUsize, Ordering};

        static DROPS: AtomicUsize = AtomicUsize::new(0);

        struct Foo;

        impl Drop for Foo {
            fn drop(&mut self) {
                DROPS.fetch_add(1, Ordering::SeqCst);
            }
        }

        let m = AtomicOptionRef::<Foo>::new();

        // The first swap returns `None`; the second returns (and drops) the
        // first `Foo`. The second `Foo` is still held by `m`.
        m.swap(Foo);
        m.swap(Foo);

        assert_eq!(DROPS.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_threads() {
        use rand::{thread_rng, Rng};
        use std::sync::atomic::{AtomicUsize, Ordering};
        use std::sync::Arc;
        use std::thread;
        use std::time::Duration;

        const THREADS: usize = 100;
        const ITERATIONS: usize = 100;

        static DROPS: AtomicUsize = AtomicUsize::new(0);

        #[derive(Default)]
        struct Foo {
            dropped: AtomicUsize,
        }

        impl Drop for Foo {
            fn drop(&mut self) {
                self.dropped.fetch_add(1, Ordering::SeqCst);
                DROPS.fetch_add(1, Ordering::SeqCst);
            }
        }

        let m = Arc::new(AtomicOptionRef::<Foo>::new());
        m.store(Foo::default());

        let mut threads = Vec::new();

        for _ in 0..THREADS {
            // Reader: every loaded value must still be alive.
            let m0 = Arc::clone(&m);
            threads.push(thread::spawn(move || {
                for _ in 0..ITERATIONS {
                    let value = m0.load().unwrap();

                    assert_eq!(value.dropped.load(Ordering::SeqCst), 0);

                    let ms = thread_rng().gen_range(0, 10);
                    thread::sleep(Duration::from_millis(ms));
                }
            }));

            // Writer: each swap retires exactly one previous `Foo`.
            let m1 = Arc::clone(&m);
            threads.push(thread::spawn(move || {
                for _ in 0..ITERATIONS {
                    m1.swap(Foo::default());

                    let ms = thread_rng().gen_range(0, 10);
                    thread::sleep(Duration::from_millis(ms));
                }
            }));
        }

        for thread in threads {
            // Propagate worker panics (e.g. a failed reader assertion) instead
            // of discarding them, which would let this test pass regardless.
            thread.join().expect("worker thread panicked");
        }

        // THREADS * ITERATIONS swaps each retired one value; the latest value
        // is still held by `m`.
        assert_eq!(DROPS.load(Ordering::SeqCst), THREADS * ITERATIONS);
    }
}