atomic_ref2/
option_ref.rs

1use super::spinlock::SpinRwLock;
2use super::IntoOptionArc;
3use std::mem;
4use std::ptr::null_mut;
5use std::sync::atomic::{AtomicPtr, Ordering};
6use std::sync::Arc;
7
8/// An atomic reference that may be updated atomically.
9pub struct AtomicOptionRef<T> {
10    ptr: AtomicPtr<T>,
11    lock: SpinRwLock,
12}
13
14impl<T> AtomicOptionRef<T> {
15    /// Creates a new atomic reference with `None` initial value.
16    pub fn new() -> Self {
17        Self::default()
18    }
19
20    /// Creates a new atomic reference from the given initial value.
21    pub fn from(value: impl IntoOptionArc<T>) -> Self {
22        Self {
23            ptr: AtomicPtr::new(option_arc_to_ptr(value)),
24            lock: SpinRwLock::new(),
25        }
26    }
27
28    /// Returns `true` if the optional reference has `Some` value.
29    pub fn is_some(&self) -> bool {
30        self.ptr.load(Ordering::SeqCst).is_null()
31    }
32
33    /// Loads and returns a reference to the value or `None`
34    /// if the value is not set.
35    pub fn load(&self) -> Option<Arc<T>> {
36        let _guard = self.lock.read();
37        ptr_to_option_arc(self.ptr.load(Ordering::SeqCst), true)
38    }
39
40    /// Stores the value.
41    pub fn store(&self, value: impl IntoOptionArc<T>) {
42        self.swap(value);
43    }
44
45    /// Swaps the value, returning the previous value.
46    pub fn swap(&self, value: impl IntoOptionArc<T>) -> Option<Arc<T>> {
47        let _guard = self.lock.write();
48        ptr_to_option_arc(
49            self.ptr.swap(option_arc_to_ptr(value), Ordering::SeqCst),
50            false,
51        )
52    }
53}
54
55impl<T> Default for AtomicOptionRef<T> {
56    fn default() -> Self {
57        Self::from(None)
58    }
59}
60
61impl<T> Drop for AtomicOptionRef<T> {
62    fn drop(&mut self) {
63        let ptr = self.ptr.swap(null_mut(), Ordering::SeqCst);
64        if !ptr.is_null() {
65            unsafe {
66                // Reconstruct the Arc from the raw ptr which will trigger our destructor
67                // if there is one
68                let _ = Arc::from_raw(ptr);
69            }
70        }
71    }
72}
73
74fn option_arc_to_ptr<T>(value: impl IntoOptionArc<T>) -> *mut T {
75    if let Some(value) = value.into_option_arc() {
76        Arc::into_raw(value) as *mut _
77    } else {
78        null_mut()
79    }
80}
81
/// Turns a stored raw pointer back into an optional `Arc`.
///
/// Null yields `None`. Otherwise:
/// - with `increment == false` the returned `Arc` takes over the strong count
///   owned by the pointer;
/// - with `increment == true` the pointer keeps its strong count and the
///   returned `Arc` holds a fresh one.
fn ptr_to_option_arc<T>(ptr: *mut T, increment: bool) -> Option<Arc<T>> {
    if ptr.is_null() {
        return None;
    }

    // SAFETY: callers pass pointers created by `Arc::into_raw` whose strong
    // count is still owned by the stored pointer.
    let stored = unsafe { Arc::from_raw(ptr as *const T) };

    if !increment {
        return Some(stored);
    }

    // Never run the destructor of the rebuilt handle, so the count owned by the
    // pointer is left untouched, and hand out a cloned handle instead.
    let stored = mem::ManuallyDrop::new(stored);
    Some(Arc::clone(&*stored))
}
99
#[cfg(test)]
mod tests {
    use super::AtomicOptionRef;

    #[test]
    fn test_new_is_empty() {
        let m = AtomicOptionRef::<String>::new();

        // A freshly created reference holds `None`.
        assert!(m.load().is_none());
    }

    #[test]
    fn test_store_load() {
        let m = AtomicOptionRef::<String>::new();

        // Store
        m.store(String::from("2"));

        // Load and assert
        assert_eq!(m.load().unwrap().as_ref(), "2");
    }

    #[test]
    fn test_overwrite() {
        let m = AtomicOptionRef::<String>::new();

        // Store
        m.store(String::from("Hello World"));

        // Take a reference
        let m0 = m.load();

        // Store (again)
        m.store(String::from("Goodbye World"));

        // The earlier loaded reference still sees the old value
        assert_eq!(m0.unwrap().as_ref(), "Hello World");

        // A fresh load sees the new value
        assert_eq!(m.load().unwrap().as_ref(), "Goodbye World");
    }

    #[test]
    fn test_drop() {
        use std::sync::atomic::{AtomicUsize, Ordering};

        static DROPS: AtomicUsize = AtomicUsize::new(0);

        struct Foo;

        impl Drop for Foo {
            fn drop(&mut self) {
                DROPS.fetch_add(1, Ordering::SeqCst);
            }
        }

        let m = AtomicOptionRef::<Foo>::new();

        // The first swap replaces `None`, so nothing is dropped yet.
        assert!(m.swap(Foo).is_none());
        // The second swap returns the first `Foo`, dropped at the end of the statement.
        m.swap(Foo);

        assert_eq!(DROPS.load(Ordering::SeqCst), 1);

        // Dropping the reference releases the value it still holds.
        drop(m);
        assert_eq!(DROPS.load(Ordering::SeqCst), 2);
    }

    #[test]
    fn test_threads() {
        use rand::{thread_rng, Rng};
        use std::sync::atomic::{AtomicUsize, Ordering};
        use std::sync::Arc;
        use std::thread;
        use std::time::Duration;

        const THREADS: usize = 100;
        const ITERATIONS: usize = 100;

        static DROPS: AtomicUsize = AtomicUsize::new(0);

        #[derive(Default)]
        struct Foo {
            dropped: AtomicUsize,
        }

        impl Drop for Foo {
            fn drop(&mut self) {
                self.dropped.fetch_add(1, Ordering::SeqCst);
                DROPS.fetch_add(1, Ordering::SeqCst);
            }
        }

        let m = Arc::new(AtomicOptionRef::<Foo>::new());
        m.store(Foo::default());

        let mut threads = Vec::new();

        for _ in 0..THREADS {
            let m0 = Arc::clone(&m);
            threads.push(thread::spawn(move || {
                for _ in 0..ITERATIONS {
                    let value = m0.load().unwrap();

                    // A loaded value must never have been destroyed.
                    assert_eq!(value.dropped.load(Ordering::SeqCst), 0);

                    let ms = thread_rng().gen_range(0, 10);
                    thread::sleep(Duration::from_millis(ms));
                }
            }));

            let m1 = Arc::clone(&m);
            threads.push(thread::spawn(move || {
                for _ in 0..ITERATIONS {
                    m1.swap(Foo::default());

                    let ms = thread_rng().gen_range(0, 10);
                    thread::sleep(Duration::from_millis(ms));
                }
            }));
        }

        // Propagate panics from worker threads; otherwise a failed assertion in
        // a reader thread would be silently ignored and the test would pass.
        for thread in threads {
            thread.join().expect("worker thread panicked");
        }

        // Every swap releases exactly one previously stored `Foo`.
        assert_eq!(DROPS.load(Ordering::SeqCst), THREADS * ITERATIONS);
    }
}