Skip to main content

juggle/utils/
cell.rs

1use core::cell::UnsafeCell;
2use core::fmt::{Debug, Formatter};
3use core::mem::{forget, ManuallyDrop};
4use core::ops::DerefMut;
5use core::sync::atomic::*;
6use core::convert::identity;
7
/// Wrapper struct that allows modifying and swapping value without using locks.
///
/// AtomicCell does not use atomic load/store/cas on contained data so that it can hold structs
/// of arbitrary size.
///
/// Access to the value is instead serialized through an internal atomic flag: `try_swap` and
/// `try_apply` give up immediately when the flag is already set, while `swap` and `apply` retry
/// in a busy loop until they succeed.
pub struct AtomicCell<T> {
    // Busy flag: set to `true` by the thread that is currently reading/writing `cell`,
    // reset to `false` when that thread is done.
    mark: AtomicBool,
    // The wrapped value. `ManuallyDrop` lets the methods copy it in and out bitwise without an
    // implicit drop; the `Drop` impl of `AtomicCell` releases it explicitly.
    cell: UnsafeCell<ManuallyDrop<T>>,
}
16
// SAFETY: moving the cell to another thread moves the contained `T` with it, which the
// `T: Send` bound covers. Soundness relies on `try_swap`/`try_apply` touching `cell` only after
// they have claimed `mark`.
// NOTE(review): the extra `T: Sync` bound looks stricter than necessary — confirm whether it is
// intentional before relaxing it.
unsafe impl<T> Send for AtomicCell<T> where T: Send + Sync {}
18
// SAFETY: through `&self` the value is only reachable via `try_swap`/`try_apply` (and the
// spinning wrappers built on them), which access `cell` only while holding `mark`; values may
// be moved in and out from any thread, which the `T: Send` bound covers. No method hands out
// a `&T` through `&self`.
unsafe impl<T> Sync for AtomicCell<T> where T: Send + Sync {}
20
21impl<T> AtomicCell<T> {
22    /// Create new atomic cell with initial value.
23    pub const fn new(value: T) -> Self {
24        Self {
25            mark: AtomicBool::new(false),
26            cell: UnsafeCell::new(ManuallyDrop::new(value)),
27        }
28    }
29
30    /// Tries to swap the value inside cell.
31    ///
32    /// When successful returns `Ok` with previous value. In case of failure, returns `Err` with
33    /// argument passed to it, returning ownership of value.
34    ///
35    /// AtomicCell does not use atomic load/store/cas on contained data so that it can hold structs
36    /// of arbitrary size. This method might fail in case some other thread is now modifying this value.
37    /// In case of failure you can perform additional checks or try to swap value again until success.
38    pub fn try_swap(&self, value: T) -> Result<T, T> {
39        let res = self.mark.compare_exchange_weak(false, true, Ordering::AcqRel, Ordering::Acquire);
40        if res.unwrap_or_else(identity) {
41            Err(value) //other thread interfered
42        } else {
43            //we know for sure we are only thread writing to this location
44            //swap values
45            unsafe {
46                let first = self.cell.get().read_volatile();
47                self.cell.get().write_volatile(ManuallyDrop::new(value));
48                self.mark.store(false, Ordering::Release);
49                Ok(ManuallyDrop::into_inner(first))
50            }
51        }
52    }
53
54    /// Swap the value inside cell.
55    ///
56    /// Returns previous value from cell.
57    ///
58    /// AtomicCell does not use atomic load/store/cas on contained data so that it can hold structs
59    /// of arbitrary size. This method tries to swap value in busy loop until success.
60    pub fn swap(&self, mut value: T) -> T {
61        loop {
62            match self.try_swap(value) {
63                Ok(val) => return val,
64                Err(val) => {
65                    value = val;
66                    spin_loop_hint();
67                }
68            }
69        }
70    }
71
72    /// Tries to perform action on value inside cell, possibly mutating it.
73    ///
74    /// When successful returns `Ok` with value returned from executed function.
75    /// In case of failure, returns `Err` with argument passed to it,
76    /// returning ownership of function.
77    ///
78    /// `T` is required to be copy in case if given function panics. When this occurs, the value is
79    /// restored to state from before applying the function and panic is propagated.
80    ///
81    /// AtomicCell does not use atomic load/store/cas on contained data so that it can hold structs
82    /// of arbitrary size. This method might fail in case some other thread is now modifying this value.
83    /// In case of failure you can perform additional checks or try to apply action again until success.
84    pub fn try_apply<F, R>(&self, func: F) -> Result<R, F> where F: FnOnce(&mut T) -> R, T: Copy {
85        let res = self.mark.compare_exchange_weak(false, true, Ordering::AcqRel, Ordering::Acquire);
86        if res.unwrap_or_else(identity) {
87            Err(func) //other thread interfered
88        } else {
89            //we know for sure we are only thread writing to this location
90            struct UnwindGuard<'a>(&'a AtomicBool);
91            impl<'a> Drop for UnwindGuard<'a> {
92                fn drop(&mut self) { //perform cleanup on normal execution and if closure panics
93                    self.0.store(false, Ordering::Release);
94                }
95            }
96            //modify value
97            unsafe {
98                let mut first = self.cell.get().read_volatile();
99                let guard = UnwindGuard(&self.mark);
100                let res = func(&mut first.deref_mut());//modify local copy to ensure volatile operations
101                self.cell.get().write_volatile(first);
102                drop(guard);//explicit drop
103                Ok(res)
104            }
105        }
106    }
107
108    /// Perform action on value inside cell, possibly mutating it.
109    ///
110    /// Returns value returned from executed function.
111    ///
112    /// `T` is required to be copy in case if given function panics. When this occurs, the value is
113    /// restored to state from before applying the function and panic is propagated.
114    ///
115    /// AtomicCell does not use atomic load/store/cas on contained data so that it can hold structs
116    /// of arbitrary size. This method tries to apply function in busy loop until success.
117    pub fn apply<F, R>(&self, mut func: F) -> R where F: FnOnce(&mut T) -> R, T: Copy {
118        loop {
119            match self.try_apply(func) {
120                Ok(res) => return res,
121                Err(f) => {
122                    func = f;
123                    spin_loop_hint();
124                }
125            }
126        }
127    }
128
129    /// Get mutable reference to content of this struct. This method statically ensures that mutation
130    /// is allowed because it takes self by mutable reference.
131    #[inline(always)]
132    pub fn get_mut(&mut self) -> &mut T {
133        unsafe { &mut *self.cell.get() }
134    }
135    /// Takes ownership of cell and extracts wrapped value from it.
136    #[inline(always)]
137    pub fn into_inner(self) -> T {
138        unsafe {
139            let data = self.cell.get().read();
140            forget(self);//don't run destructor
141            ManuallyDrop::into_inner(data)
142        }
143    }
144}
145
146impl<T> Drop for AtomicCell<T> {
147    fn drop(&mut self) {
148        unsafe {
149            ManuallyDrop::drop(&mut *self.cell.get());
150        }
151    }
152}
153
154impl<T: Debug> Debug for AtomicCell<T> {
155    //Debug bound just in case we will support showing content.
156    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
157        write!(f, "AtomicCell<{}>", core::any::type_name::<T>())?;
158        f.debug_struct("").field("holds_lock", &self.mark.load(Ordering::Relaxed)).finish()
159    }
160}
161
162
#[cfg(test)]
mod test {
    extern crate std;
    use std::collections::hash_map::DefaultHasher;
    use std::collections::HashSet;
    use std::hash::{Hash, Hasher};
    use std::mem::replace;
    use std::num::*;
    use std::prelude::v1::*;
    use std::sync::{Arc, Barrier};
    use std::thread::spawn;
    use super::*;

    /// Stress test: `threads` workers each push `per_thread` distinct values through one shared
    /// cell using `op`. Every value put into the cell must come back out exactly once, either
    /// to some worker or to the final drain performed by the calling thread.
    fn test_swap_many_case<T, F>(threads: u64, per_thread: u64, mut factory: impl FnMut(u64) -> T, op: F)
        where T: Send + Sync + Eq + Hash + 'static,
              F: Fn(&AtomicCell<Option<T>>, Option<T>) -> Option<T> + Send + Sync + 'static {
        let shared = Arc::new(AtomicCell::new(None));
        let op = Arc::new(op);

        // Build every worker's input before any thread is started; ids are 1-based and unique.
        let mut batches = Vec::new();
        for t in 0..threads {
            let first = t * per_thread + 1;
            let batch = (first..first + per_thread).map(|id| Some(factory(id))).collect::<Vec<_>>();
            batches.push(batch);
        }

        let workers = batches.into_iter().map(|batch| {
            let cell = shared.clone();
            let op = op.clone();
            spawn(move || batch.into_iter().map(|val| op(&cell, val)).collect::<Vec<_>>())
        }).collect::<Vec<_>>();

        // Everything the workers got back, plus whatever is still left inside the cell.
        let mut seen = HashSet::new();
        for worker in workers {
            seen.extend(worker.join().unwrap());
        }
        seen.insert(op(&shared, None));

        // The initial `None` must have been handed to somebody as well.
        assert!(seen.contains(&None));
        let missing = (1..=per_thread * threads)
            .filter(|id| !seen.contains(&Some(factory(*id))))
            .collect::<Vec<_>>();
        assert!(missing.is_empty(), "Results not empty {:#?}", &missing);
    }


    /// Stress test: in every iteration the calling thread places one `unique` value into the
    /// cell while `threads` workers keep swapping `default` in. Exactly one worker per
    /// iteration may observe the `unique` value.
    ///
    /// Three barriers keep all participants in lock step; the order of the waits below is
    /// essential.
    fn test_swap_single_case<T, F>(threads: usize, iters: usize, repeats: usize, default: T, unique: T, op: F)
        where T: Send + Sync + Eq + Clone + 'static,
              F: Fn(&AtomicCell<T>, T) -> T + Send + Sync + 'static {
        let gates = Arc::new((Barrier::new(threads + 1), Barrier::new(threads + 1), Barrier::new(threads + 1)));
        let shared = Arc::new(AtomicCell::new(default.clone()));
        let op = Arc::new(op);

        let mut workers = Vec::with_capacity(threads);
        for _ in 0..threads {
            let gates = gates.clone();
            let filler = default.clone();
            let needle = unique.clone();
            let cell = shared.clone();
            let op = op.clone();
            workers.push(spawn(move || {
                let mut saw_unique = Vec::with_capacity(iters);
                for _ in 0..iters {
                    let mut taken = Vec::with_capacity(repeats + 1);
                    gates.0.wait();
                    gates.1.wait();
                    for _ in 0..repeats {
                        taken.push(op(&cell, filler.clone()));
                    }
                    gates.2.wait();
                    // One more swap after the last barrier, in case the unique value was
                    // stored only after this worker had finished its repeats.
                    taken.push(op(&cell, filler.clone()));
                    saw_unique.push(taken.into_iter().any(|val| val == needle));
                }
                saw_unique
            }));
        }

        let mut displaced = Vec::with_capacity(iters);
        for _ in 0..iters {
            gates.0.wait();
            op(&shared, default.clone()); // reset the cell while the workers are parked
            gates.1.wait();
            displaced.push(op(&shared, unique.clone()));
            gates.2.wait();
        }

        let outcomes = workers.into_iter().map(|worker| worker.join().unwrap()).collect::<Vec<_>>();
        assert!(displaced.iter().all(|val| *val == default));
        let shortest = outcomes.iter().map(|flags| flags.len()).min().unwrap();
        assert_eq!(shortest, iters);
        for round in 0..iters {
            let hits = outcomes.iter().filter(|flags| flags[round]).count();
            assert_eq!(hits, 1); //only exactly single swap resulted in other value in each iteration
        }
    }

    //large struct to test statistical data integrity
    #[derive(Clone, Eq, PartialEq, Hash)]
    struct TestData {
        d0: [u64; 32],
        d1: [u64; 32],
        d2: [u64; 32],
        d3: [u64; 32],
    }

    impl TestData {
        /// Builds a value whose content is fully determined by `val`: `d0` repeats `val`,
        /// `d1`..`d3` hold a running hash chain fed with `val`.
        pub fn new(val: u64) -> Self {
            let mut hasher = DefaultHasher::default();
            let mut lanes = [[0u64; 32]; 3];
            // A single hasher is threaded through all three lanes in order.
            for lane in lanes.iter_mut() {
                for slot in lane.iter_mut() {
                    hasher.write_u64(val);
                    *slot = hasher.finish();
                }
            }
            Self {
                d0: [val; 32],
                d1: lanes[0],
                d2: lanes[1],
                d3: lanes[2],
            }
        }
    }

    /// Operation under test: plain `swap`.
    fn swap_func<T>() -> impl Fn(&AtomicCell<T>, T) -> T {
        |cell, incoming| cell.swap(incoming)
    }

    /// Operation under test: a swap expressed through `apply`.
    fn apply_func<T: Copy>() -> impl Fn(&AtomicCell<T>, T) -> T {
        |cell, incoming| cell.apply(move |slot| replace(slot, incoming))
    }

    #[test]
    fn test_basic() {
        let cell = AtomicCell::new(1);
        assert_eq!(cell.try_swap(2), Ok(1));
        assert_eq!(cell.swap(3), 2);
        assert_eq!(cell.swap(12345), 3);
        let applied = cell.try_apply(|val| {
            assert_eq!(*val, 12345);
            *val = 10;
        });
        applied.ok().unwrap();
        assert_eq!(cell.swap(0), 10);
    }

    #[test]
    fn test_swap_single() {
        test_swap_single_case(8, 1000, 1000, 11, 22, swap_func());
        test_swap_single_case(8, 1000, 100, TestData::new(1), TestData::new(2), swap_func());
    }

    #[test]
    fn test_apply_single() {
        test_swap_single_case(8, 1000, 1000, 11, 22, apply_func());
    }

    #[test]
    fn test_swap_many() {
        test_swap_many_case(8, 10000, |id| NonZeroU32::new(id as u32).unwrap(), swap_func());
        test_swap_many_case(8, 5000, TestData::new, swap_func());
    }

    #[test]
    fn test_apply_many() {
        test_swap_many_case(8, 10000, |id| NonZeroU32::new(id as u32).unwrap(), apply_func());
    }
}
327}