//! `AtomicBox<T>` — a shared, atomically replaceable box.
//!
//! Stores an `Arc<T>` behind an `AtomicPtr<T>`; readers clone the value out
//! with [`AtomicBox::get`] and writers swap it with [`AtomicBox::replace_with`].
use std::mem;
use std::ops::Deref;
use std::sync::atomic::{AtomicPtr, Ordering};
use std::sync::Arc;
/// `AtomicBox<T>` is a safe wrapper around [`AtomicPtr<T>`].
///
/// The current value lives in an `Arc<T>` whose strong count is "parked" in
/// the raw pointer slot. Readers clone it out via [`AtomicBox::get`]; writers
/// swap it atomically via [`AtomicBox::replace_with`].
#[derive(Debug)]
pub struct AtomicBox<T: Sized> {
    // Points at an `Arc`-managed `T`. Null only transiently, while one
    // thread is between `take()` and `release()`.
    ptr: AtomicPtr<T>,
}

impl<T: Sized> AtomicBox<T> {
    /// Allocates a new `AtomicBox` containing the given value.
    pub fn new(value: T) -> AtomicBox<T> {
        AtomicBox {
            ptr: AtomicPtr::new(AtomicBox::alloc_from(value)),
        }
    }

    /// Moves `value` into a fresh `Arc` and leaks it to a raw pointer.
    /// The pointer carries ownership of one strong count.
    #[inline]
    fn alloc_from(value: T) -> *mut T {
        Arc::into_raw(Arc::new(value)) as *mut T
    }

    /// Atomically swaps the stored pointer for null, spinning until it wins,
    /// and reconstructs the `Arc` parked there.
    ///
    /// Null acts as a lock marker: while one thread is between `take()` and
    /// `release()`, every other caller spins here until the value returns.
    fn take(&self) -> Arc<T> {
        loop {
            let curr = self.ptr.load(Ordering::Acquire);

            // Another thread currently holds the value; retry.
            if curr.is_null() {
                continue;
            }

            // `compare_exchange` replaces the deprecated `compare_and_swap`:
            // AcqRel on success (we acquire the value and publish the null
            // marker), Acquire on failure (we only observed the slot).
            if self
                .ptr
                .compare_exchange(
                    curr,
                    std::ptr::null_mut(),
                    Ordering::AcqRel,
                    Ordering::Acquire,
                )
                .is_ok()
            {
                // SAFETY: `curr` was produced by `Arc::into_raw` (in
                // `alloc_from`, `get`, or `replace_with`) and we won the
                // exchange, so we now own the strong count it carries.
                return unsafe { Arc::from_raw(curr) };
            }
        }
    }

    /// Publishes `ptr` (obtained from `Arc::into_raw`) as the current value,
    /// unblocking any thread spinning in `take`.
    fn release(&self, ptr: *mut T) {
        // Storing null here would leave the box permanently "locked".
        // (Replaces a leftover debugging assert against a magic address.)
        debug_assert!(!ptr.is_null(), "release() must restore a live pointer");
        self.ptr.store(ptr, Ordering::Release);
    }

    /// Returns a clone of the current value as an `Arc<T>`.
    pub fn get(&self) -> Arc<T> {
        let val = self.take();
        let copy = Arc::clone(&val);

        // Park the original strong count back in the pointer slot.
        self.release(Arc::into_raw(val) as *mut T);
        copy
    }

    /// Atomically replace the inner value with the result of applying the
    /// given closure to the current value.
    ///
    /// NOTE(review): if `f` panics, the slot stays null and later calls spin
    /// forever — callers must pass non-panicking closures.
    pub fn replace_with<F>(&self, f: F)
    where
        F: Fn(Arc<T>) -> T,
    {
        let val = self.take();
        let new_val = f(val);
        self.release(Arc::into_raw(Arc::new(new_val)) as *mut T);
    }
}
77
78impl<T: Sized + PartialEq> PartialEq for AtomicBox<T> {
79    fn eq(&self, other: &AtomicBox<T>) -> bool {
80        self == other
81    }
82}
83
84impl<T: Sized> Drop for AtomicBox<T> {
85    fn drop(&mut self) {
86        unsafe {
87            Arc::from_raw(self.ptr.load(Ordering::Acquire))
88        };
89    }
90}
91
92unsafe impl<T: Sized + Sync> Sync for AtomicBox<T> {}
93unsafe impl<T: Sized + Send> Send for AtomicBox<T> {}
#[cfg(test)]
mod tests {
    use std::sync::{Arc, Barrier};
    use std::thread;

    use super::AtomicBox;

    #[test]
    fn atomic_arc_new() {
        let boxed = AtomicBox::new(1024);

        assert_eq!(*boxed.get(), 1024);
    }

    #[test]
    fn atomic_arc_replace_with() {
        let initial: i64 = 1024;
        let boxed = AtomicBox::new(initial);

        boxed.replace_with(|v| *v * 2);

        assert_eq!(*boxed.get(), initial * 2);
    }

    #[test]
    fn atomic_arc_replace_with_ten_times() {
        let initial = 1024;
        let boxed = AtomicBox::new(initial);

        // Doubling ten times multiplies by 2^10.
        (0..10).for_each(|_| boxed.replace_with(|v| *v * 2));

        assert_eq!(*boxed.get(), initial * 2_i32.pow(10));
    }

    #[test]
    fn atomic_arc_replace_instance() {
        let original = Arc::new(AtomicBox::new(1024));
        let alias = Arc::clone(&original);

        // A write through one handle is visible through the other.
        alias.replace_with(|v| *v * 2);

        assert_eq!(*original.get(), 2048);
    }

    #[test]
    fn atomic_arc_threaded_leak_test() {
        let shared = Arc::new(AtomicBox::new(10));

        let handles: Vec<_> = (0..10)
            .map(|_| {
                let shared = Arc::clone(&shared);
                thread::spawn(move || shared.replace_with(|v| *v * 2))
            })
            .collect();

        for handle in handles {
            handle.join().unwrap();
        }

        assert_eq!(*shared.get(), 10 * 2_i32.pow(10));
    }

    #[test]
    fn atomic_arc_threaded_contention() {
        let shared = Arc::new(AtomicBox::new(0));
        let thread_num = 10;
        let barrier = Arc::new(Barrier::new(thread_num));

        // All threads start hammering replace_with at the same instant.
        let handles: Vec<_> = (0..thread_num)
            .map(|_| {
                let gate = Arc::clone(&barrier);
                let shared = Arc::clone(&shared);
                thread::spawn(move || {
                    gate.wait();
                    for _ in 0..1000 {
                        shared.replace_with(|v| *v + 100);
                    }
                })
            })
            .collect();

        for handle in handles {
            handle.join().unwrap();
        }

        assert_eq!(*shared.get(), thread_num * 1000 * 100);
    }

    #[test]
    fn atomic_arc_vector_container() {
        let values: Vec<i32> = (0..10).map(|x: i32| x.pow(2)).collect();
        let shared = Arc::new(AtomicBox::new(vec![]));

        // Each thread appends one element; no updates may be lost.
        let handles: Vec<_> = (0..10)
            .map(|i| {
                let shared = Arc::clone(&shared);
                let values = values.clone();
                thread::spawn(move || {
                    shared.replace_with(|current| {
                        let mut next = (*current).clone();
                        next.push(values[i]);
                        next
                    })
                })
            })
            .collect();

        for handle in handles {
            handle.join().unwrap();
        }

        assert_eq!(shared.get().len(), values.len());

        for value in values {
            assert!(shared.get().contains(&value));
        }
    }
}