1use std::marker::PhantomData;
2use std::ptr::null_mut;
3use std::sync::atomic::{AtomicPtr, Ordering};
4
/// Zero-sized marker that suppresses the auto-derived `Send` and `Sync` impls
/// (raw pointers implement neither), so a type containing it must opt into
/// thread-safety explicitly with `unsafe impl`.
type PhantomUnsync<T> = PhantomData<*mut T>;

/// A lock-free slot holding an optional heap-allocated `T`.
///
/// The slot owns its contents: values go in and come out as `Box<T>`, and no
/// method ever hands out a reference to the value currently stored. Every
/// operation takes `&self`, so one slot can be shared between threads to move
/// boxed values from one thread to another.
pub struct AtomicOption<T> {
    /// Null encodes `None`; any other value is a pointer produced by
    /// `Box::into_raw` that this slot exclusively owns.
    inner: AtomicPtr<T>,
    /// Opts out of the automatic `Send`/`Sync` impls; the manual impls below
    /// state the exact conditions instead.
    _phantom: PhantomUnsync<T>,
}

impl<T> AtomicOption<T> {
    /// Creates a slot that initially holds `data`.
    #[inline]
    pub fn new(data: Option<Box<T>>) -> AtomicOption<T> {
        // The slot is not shared with anyone yet, so the initial pointer is
        // written directly instead of through an atomic swap.
        AtomicOption {
            inner: AtomicPtr::new(Self::into_raw_ptr(data)),
            _phantom: PhantomData,
        }
    }

    /// Atomically replaces the contents with `new` and returns the previous
    /// contents.
    ///
    /// `AcqRel` ordering: the release half publishes the pointee of `new` to
    /// whichever thread later removes it, and the acquire half makes the
    /// pointee of the returned box (written by the thread that stored it)
    /// visible to the caller.
    #[inline]
    pub fn swap(&self, new: Option<Box<T>>) -> Option<Box<T>> {
        let old = self.inner.swap(Self::into_raw_ptr(new), Ordering::AcqRel);
        // SAFETY: `inner` only ever holds null or a pointer from
        // `Box::into_raw` owned by the slot. The atomic swap removed `old` from
        // the slot, so no other thread can obtain the same pointer and the
        // caller becomes its unique owner.
        unsafe { Self::from_raw_ptr(old) }
    }

    /// Atomically empties the slot and returns what it held.
    #[inline]
    pub fn take(&self) -> Option<Box<T>> {
        self.swap(None)
    }

    /// Atomically replaces the contents with `new`. The previous contents, if
    /// any, are dropped on the calling thread.
    #[inline]
    pub fn store(&self, new: Option<Box<T>>) {
        drop(self.swap(new));
    }

    /// Encodes an optional box as the slot's raw representation (null for
    /// `None`), transferring ownership of the allocation to the pointer.
    #[inline]
    fn into_raw_ptr(data: Option<Box<T>>) -> *mut T {
        data.map_or(null_mut(), Box::into_raw)
    }

    /// Decodes the slot's raw representation back into an optional box.
    ///
    /// # Safety
    ///
    /// `ptr` must be null or a pointer obtained from `Box::into_raw` that the
    /// caller exclusively owns; ownership moves into the returned box.
    #[inline]
    unsafe fn from_raw_ptr(ptr: *mut T) -> Option<Box<T>> {
        if ptr.is_null() {
            None
        } else {
            // SAFETY: upheld by the caller, see `# Safety` above.
            Some(unsafe { Box::from_raw(ptr) })
        }
    }
}

impl<T> Default for AtomicOption<T> {
    /// Creates an empty slot, equivalent to `AtomicOption::new(None)`.
    #[inline]
    fn default() -> Self {
        Self::new(None)
    }
}

// SAFETY: the slot never exposes `&T`; it only moves `Box<T>` values in and out
// through a single atomic pointer. Sharing `&AtomicOption<T>` between threads
// therefore only ever sends `T` between threads, which requires `T: Send` but
// not `T: Sync` (the same reasoning as `Mutex<T>: Sync where T: Send`).
unsafe impl<T> Sync for AtomicOption<T> where T: Send {}
// SAFETY: moving the slot to another thread moves the owned `T` along with it.
unsafe impl<T> Send for AtomicOption<T> where T: Send {}

impl<T> Drop for AtomicOption<T> {
    /// Drops the stored value, if any.
    fn drop(&mut self) {
        // `&mut self` proves no other thread can reach the slot, so the pointer
        // is read directly instead of with an atomic read-modify-write.
        let ptr = *self.inner.get_mut();
        // SAFETY: the slot owns `ptr` (null or from `Box::into_raw`), and the
        // slot is being destroyed, so nothing uses the pointer after this.
        drop(unsafe { Self::from_raw_ptr(ptr) });
    }
}
58
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::thread;

    use super::AtomicOption;

    /// Single-threaded walk through `take`, `swap` and `store`.
    #[test]
    fn test_simple() {
        let opt = AtomicOption::new(None);
        assert_eq!(opt.take(), None);
        assert_eq!(opt.swap(Some(Box::new(0))), None);
        assert_eq!(opt.take(), Some(Box::new(0)));
        opt.store(Some(Box::new(1)));
        opt.store(Some(Box::new(2)));
        assert_eq!(opt.swap(Some(Box::new(3))), Some(Box::new(2)));
    }

    /// Values replaced by `store` or still held when the slot is dropped must
    /// be dropped exactly once.
    #[test]
    fn test_drops_contents() {
        let tracker = Arc::new(());
        {
            let opt = AtomicOption::new(Some(Box::new(Arc::clone(&tracker))));
            opt.store(Some(Box::new(Arc::clone(&tracker))));
            assert_eq!(Arc::strong_count(&tracker), 2);
        }
        assert_eq!(Arc::strong_count(&tracker), 1);
    }

    /// One thread only fills the slot and the other only empties it. The
    /// producer counts fills into an *empty* slot and the consumer counts
    /// successful takes; both reach `ROUNDS`, and the slot ends up empty.
    #[test]
    fn test_two_threads() {
        const ROUNDS: i64 = 100;
        for _ in 0..100 {
            // Shared through `Arc` rather than a transmuted `&'static` borrow:
            // if one worker panics, the first `join().unwrap()` unwinds while
            // the other worker may still be running, and a borrowed stack
            // local would then be freed underneath it.
            let opt = Arc::new(AtomicOption::<i64>::new(None));

            let producer = {
                let opt = Arc::clone(&opt);
                move || {
                    let mut remain = ROUNDS;
                    while remain > 0 {
                        // `None` back means the slot was empty, so this value
                        // is a fresh hand-off rather than an overwrite.
                        if opt.swap(Some(Box::new(remain))).is_none() {
                            remain -= 1;
                        }
                    }
                }
            };

            let consumer = {
                let opt = Arc::clone(&opt);
                move || {
                    let mut remain = ROUNDS;
                    while remain > 0 {
                        if opt.swap(None).is_some() {
                            remain -= 1;
                        }
                    }
                }
            };

            for h in [thread::spawn(producer), thread::spawn(consumer)] {
                h.join().unwrap();
            }
            assert_eq!(opt.take(), None);
        }
    }
}