atomption/
lib.rs

1use std::marker::PhantomData;
2use std::ptr::null_mut;
3use std::sync::atomic::{AtomicPtr, Ordering};
4
/// Zero-sized marker holding a raw pointer to `T`.
///
/// Raw pointers are neither `Send` nor `Sync`, so embedding this marker
/// switches off the automatic thread-safety impls of the containing type;
/// thread-safety is then granted explicitly with `unsafe impl`.
type PhantomUnsync<T> = PhantomData<*mut T>;

/// A lock-free slot that may hold one heap-allocated `T`.
///
/// Values move in and out of the slot as whole `Box<T>`s through atomic
/// swaps, so several threads can exchange ownership of values through a
/// shared reference to the same slot.
pub struct AtomicOption<T> {
    /// Null when the slot is empty; otherwise a pointer produced by
    /// `Box::into_raw` that the slot currently owns.
    inner: AtomicPtr<T>,
    /// Suppresses the automatic `Send`/`Sync` impls (see `PhantomUnsync`).
    _phantom: PhantomUnsync<T>,
}
11
12impl<T> AtomicOption<T> {
13    #[inline(always)]
14    pub fn new(data: Option<Box<T>>) -> AtomicOption<T> {
15        let empty = AtomicOption {
16            inner: AtomicPtr::new(null_mut()),
17            _phantom: PhantomData,
18        };
19        empty.store(data);
20        empty
21    }
22
23    #[inline(always)]
24    pub fn swap(&self, new: Option<Box<T>>) -> Option<Box<T>> {
25        let addr = if let Some(new) = new {
26            Box::into_raw(new)
27        } else {
28            null_mut()
29        };
30
31        let addr = self.inner.swap(addr, Ordering::AcqRel);
32        if addr.is_null() {
33            None
34        } else {
35            Some(unsafe { Box::from_raw(addr) })
36        }
37    }
38
39    #[inline(always)]
40    pub fn take(&self) -> Option<Box<T>> {
41        self.swap(None)
42    }
43
44    #[inline(always)]
45    pub fn store(&self, new: Option<Box<T>>) {
46        drop(self.swap(new))
47    }
48}
49
50unsafe impl<T> Sync for AtomicOption<T> where T: Send {}
51unsafe impl<T> Send for AtomicOption<T> where T: Send {}
52
53impl<T> Drop for AtomicOption<T> {
54    fn drop(&mut self) {
55        let _ = self.take();
56    }
57}
58
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::thread;

    use super::AtomicOption;

    /// Exercises `new`, `take`, `swap` and `store` from a single thread.
    #[test]
    fn test_simple() {
        let opt = AtomicOption::new(None);
        assert_eq!(opt.take(), None);
        assert_eq!(opt.swap(Some(Box::new(0))), None);
        assert_eq!(opt.take(), Some(Box::new(0)));
        opt.store(Some(Box::new(1)));
        opt.store(Some(Box::new(2)));
        assert_eq!(opt.swap(Some(Box::new(3))), Some(Box::new(2)));
    }

    /// Races a producer that swaps values into the slot against a consumer
    /// that swaps them out, then checks the slot ends up empty.
    ///
    /// The producer stops once its swap has found the slot empty 100 times.
    /// The consumer stops once it has taken a value out 100 times. Those two
    /// kinds of state transitions alternate (empty -> full -> empty ...), so
    /// both loops terminate, and the consumer's take is the last transition.
    #[test]
    fn test_two_threads() {
        for _ in 0..100 {
            // Shared through `Arc` instead of a `transmute`d `&'static` borrow
            // of a stack local: if one `join` below panics, the other thread is
            // detached but still co-owns the slot, so it can never be freed
            // while that thread is using it.
            let opt = Arc::new(AtomicOption::<i64>::new(None));

            let producer = {
                let opt = Arc::clone(&opt);
                move || {
                    let mut remain = 100;
                    while remain > 0 {
                        if opt.swap(Some(Box::new(remain))).is_none() {
                            remain -= 1;
                        }
                    }
                }
            };

            let consumer = {
                let opt = Arc::clone(&opt);
                move || {
                    let mut remain = 100;
                    while remain > 0 {
                        if opt.swap(None).is_some() {
                            remain -= 1;
                        }
                    }
                }
            };

            for handle in [thread::spawn(producer), thread::spawn(consumer)] {
                handle.join().unwrap();
            }
            assert_eq!(opt.take(), None);
        }
    }
}