1use std::mem;
2use std::sync::Arc;
3use std::sync::atomic::{AtomicPtr, Ordering};
4use std::ops::Deref;
5
6
/// A slot that stores a single raw pointer to a `T` behind an [`AtomicPtr`].
///
/// The pointer placed in the slot by the constructor comes from
/// `Arc::into_raw`, so the slot effectively owns one strong `Arc<T>`
/// reference while the pointer is non-null.
#[derive(Debug)]
pub struct AtomicBox<T> {
    /// Raw pointer to the current value; manipulated only through atomic
    /// operations.
    ptr: AtomicPtr<T>,
}
14
15impl<T: Sized> AtomicBox<T> {
16 pub fn new(value: T) -> AtomicBox<T> {
18 AtomicBox {
19 ptr: AtomicPtr::new(AtomicBox::alloc_from(value)),
20 }
21 }
22
23 #[inline]
24 fn alloc_from(value: T) -> *mut T {
25 let total: Arc<T> = Arc::new(value);
26
27 Arc::into_raw(total) as *mut T
28 }
29
30 fn compare_and_swap(&self,
31 current: *mut T,
32 new: *mut T,
33 order: Ordering) -> *mut T {
34 self.ptr.compare_and_swap(current, new, order)
35 }
36
37 fn take(&self) -> Arc<T> {
38 loop {
39 let curr = self.ptr.load(Ordering::Acquire);
40 let null: *mut T = std::ptr::null_mut();
41
42 if curr == null {
43 continue;
44 }
45
46 if self.compare_and_swap(curr, null, Ordering::AcqRel) == curr {
47 return unsafe { Arc::from_raw(curr) };
48 }
49 }
50 }
51
52 fn release(&self, ptr: *mut T) {
53 assert!(ptr != 0xfffffffffffffff0 as *mut T);
54 self.ptr.store(ptr, Ordering::Release);
55 }
56
57 pub fn get(&self) -> Arc<T> {
58 let val = self.take();
59 let copy = Arc::clone(&val);
60 let ptr = Arc::into_raw(val) as *mut T;
61
62 self.release(ptr);
63 copy
64 }
65
66 pub fn replace_with<F>(&self, f: F)
69 where F: Fn(Arc<T>) -> T
70 {
71 let val = self.take();
72 let new_val = f(val);
73 let ptr = Arc::into_raw(Arc::new(new_val)) as *mut T;
74 self.release(ptr);
75 }
76}
77
78impl<T: Sized + PartialEq> PartialEq for AtomicBox<T> {
79 fn eq(&self, other: &AtomicBox<T>) -> bool {
80 self == other
81 }
82}
83
84impl<T: Sized> Drop for AtomicBox<T> {
85 fn drop(&mut self) {
86 unsafe {
87 Arc::from_raw(self.ptr.load(Ordering::Acquire))
88 };
89 }
90}
91
92unsafe impl<T: Sized + Sync> Sync for AtomicBox<T> {}
93unsafe impl<T: Sized + Send> Send for AtomicBox<T> {}
94
#[cfg(test)]
mod tests {
    use std::sync::{Arc, Barrier};
    use std::thread;

    use super::AtomicBox;

    /// A new box hands back the value it was built from.
    #[test]
    fn atomic_arc_new() {
        let boxed = AtomicBox::new(1024);
        let seen = boxed.get();

        assert_eq!(*seen, 1024);
    }

    /// One `replace_with` call applies the closure to the stored value.
    #[test]
    fn atomic_arc_replace_with() {
        let initial: i64 = 1024;
        let boxed = AtomicBox::new(initial);

        boxed.replace_with(|current| *current * 2);

        assert_eq!(*boxed.get(), initial * 2);
    }

    /// Ten sequential doublings accumulate.
    #[test]
    fn atomic_arc_replace_with_ten_times() {
        let initial = 1024;
        let boxed = AtomicBox::new(initial);

        (0..10).for_each(|_| boxed.replace_with(|current| *current * 2));

        assert_eq!(*boxed.get(), initial * 2_i32.pow(10));
    }

    /// An update made through one `Arc` handle is visible through another.
    #[test]
    fn atomic_arc_replace_instance() {
        let owner = Arc::new(AtomicBox::new(1024));
        let alias = Arc::clone(&owner);

        alias.replace_with(|current| *current * 2);

        assert_eq!(*owner.get(), 2048);
    }

    /// Ten threads each double the value once; no update may be lost.
    #[test]
    fn atomic_arc_threaded_leak_test() {
        let shared = Arc::new(AtomicBox::new(10));
        let handles: Vec<Arc<AtomicBox<i32>>> =
            (0..10).map(|_| Arc::clone(&shared)).collect();

        let workers: Vec<_> = handles
            .iter()
            .map(|handle| {
                let local = Arc::clone(handle);
                thread::spawn(move || {
                    local.replace_with(|current| *current * 2);
                })
            })
            .collect();

        for worker in workers {
            worker.join().unwrap();
        }

        assert_eq!(*shared.get(), 10 * 2_i32.pow(10));
    }

    /// Threads released together by a barrier hammer the same box; the final
    /// sum must account for every increment.
    #[test]
    fn atomic_arc_threaded_contention() {
        let shared = Arc::new(AtomicBox::new(0));
        let thread_num = 10;
        let start_line = Arc::new(Barrier::new(thread_num));

        let workers: Vec<_> = (0..thread_num)
            .map(|_| {
                let gate = Arc::clone(&start_line);
                let local = Arc::clone(&shared);
                thread::spawn(move || {
                    gate.wait();
                    (0..1000).for_each(|_| local.replace_with(|current| *current + 100));
                })
            })
            .collect();

        for worker in workers {
            worker.join().unwrap();
        }

        assert_eq!(*shared.get(), thread_num * 1000 * 100);
    }

    /// Each thread appends one square to a shared vector; afterwards every
    /// square must be present.
    #[test]
    fn atomic_arc_vector_container() {
        let squares: Vec<i32> = (0..10).map(|n: i32| n.pow(2)).collect();
        let shared = Arc::new(AtomicBox::new(vec![]));

        let workers: Vec<_> = (0..10)
            .map(|index| {
                let local = Arc::clone(&shared);
                let source: Vec<i32> = squares.clone();

                thread::spawn(move || {
                    local.replace_with(|current| {
                        let mut grown = (*current).clone();
                        grown.push(source[index]);
                        grown
                    })
                })
            })
            .collect();

        for worker in workers {
            worker.join().unwrap();
        }

        assert_eq!(shared.get().len(), squares.len());

        for square in squares {
            assert_eq!(shared.get().contains(&square), true);
        }
    }
}