1use core::cell::UnsafeCell;
2use core::fmt::{Debug, Formatter};
3use core::mem::{forget, ManuallyDrop};
4use core::ops::DerefMut;
5use core::sync::atomic::*;
6use core::convert::identity;
7
/// A spin-lock based cell that allows swapping or updating its contents
/// through a shared reference.
///
/// `mark` acts as a one-bit spin lock: `true` while some thread is inside the
/// critical section, `false` while the cell is free.
pub struct AtomicCell<T> {
    // Lock flag: `true` while a thread holds exclusive access to `cell`.
    mark: AtomicBool,
    // The stored value. `ManuallyDrop` lets `swap`/`into_inner` move the value
    // out through a raw pointer without ever running its destructor twice.
    cell: UnsafeCell<ManuallyDrop<T>>,
}

// SAFETY: every access to the inner value is serialized through `mark`, so the
// cell may be sent/shared across threads whenever `T` itself may be.
// NOTE(review): `T: Send` alone would likely suffice (no `&T` is ever handed
// out across threads), but the stricter original bound is kept unchanged.
unsafe impl<T> Send for AtomicCell<T> where T: Send + Sync {}

unsafe impl<T> Sync for AtomicCell<T> where T: Send + Sync {}

impl<T> AtomicCell<T> {
    /// Creates a new cell holding `value`.
    pub const fn new(value: T) -> Self {
        Self {
            mark: AtomicBool::new(false),
            cell: UnsafeCell::new(ManuallyDrop::new(value)),
        }
    }

    /// Tries once to replace the stored value with `value`.
    ///
    /// Returns `Ok(previous)` on success, or `Err(value)` (handing the
    /// argument back) if another thread currently holds the cell.
    pub fn try_swap(&self, value: T) -> Result<T, T> {
        // BUG FIX: the original used `compare_exchange_weak` and derived
        // success from the returned *previous* value. A spurious weak failure
        // yields `Err(false)`, which that check treated as a successful lock
        // acquisition even though the mark was never set — letting two threads
        // into the critical section at once. The strong `compare_exchange`
        // never fails spuriously, and `is_err()` is the correct success test.
        if self
            .mark
            .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
            .is_err()
        {
            return Err(value);
        }
        // SAFETY: the mark was just acquired, so we have exclusive access to
        // the cell until the `Release` store below.
        unsafe {
            let previous = self.cell.get().read_volatile();
            self.cell.get().write_volatile(ManuallyDrop::new(value));
            self.mark.store(false, Ordering::Release); // release the lock
            Ok(ManuallyDrop::into_inner(previous))
        }
    }

    /// Replaces the stored value, spinning until the cell becomes free.
    pub fn swap(&self, mut value: T) -> T {
        loop {
            match self.try_swap(value) {
                Ok(old) => return old,
                Err(returned) => {
                    value = returned;
                    // `spin_loop_hint()` is deprecated since Rust 1.51;
                    // `core::hint::spin_loop()` is the direct replacement.
                    core::hint::spin_loop();
                }
            }
        }
    }

    /// Tries once to run `func` on a mutable view of the stored value.
    ///
    /// Returns `Ok(result)` on success, or `Err(func)` (handing the closure
    /// back) if another thread currently holds the cell. If `func` panics the
    /// lock is released and the stored value is left unchanged.
    pub fn try_apply<F, R>(&self, func: F) -> Result<R, F>
    where
        F: FnOnce(&mut T) -> R,
        T: Copy,
    {
        // Same fix as in `try_swap`: strong CAS + `is_err()` so a spurious
        // failure can never be mistaken for a lock acquisition.
        if self
            .mark
            .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
            .is_err()
        {
            return Err(func);
        }
        // Clears the mark even if `func` unwinds, so a panicking closure
        // cannot leave the cell locked forever.
        struct UnwindGuard<'a>(&'a AtomicBool);
        impl<'a> Drop for UnwindGuard<'a> {
            fn drop(&mut self) {
                self.0.store(false, Ordering::Release);
            }
        }
        // SAFETY: the mark was just acquired, so we have exclusive access.
        // `T: Copy` means the read is a plain bit copy; a panic in `func`
        // simply discards the scratch copy, leaving the cell's value intact.
        unsafe {
            let mut scratch = self.cell.get().read_volatile();
            let guard = UnwindGuard(&self.mark);
            let res = func(&mut scratch);
            // Write back while the lock is still held; dropping the guard
            // then releases it.
            self.cell.get().write_volatile(scratch);
            drop(guard);
            Ok(res)
        }
    }

    /// Runs `func` on the stored value, spinning until the cell becomes free.
    pub fn apply<F, R>(&self, mut func: F) -> R
    where
        F: FnOnce(&mut T) -> R,
        T: Copy,
    {
        loop {
            match self.try_apply(func) {
                Ok(res) => return res,
                Err(returned) => {
                    func = returned;
                    core::hint::spin_loop();
                }
            }
        }
    }

    /// Returns a mutable reference to the contents without locking.
    #[inline(always)]
    pub fn get_mut(&mut self) -> &mut T {
        // SAFETY: `&mut self` statically guarantees exclusive access, so no
        // other thread can be inside the critical section.
        unsafe { &mut *self.cell.get() }
    }

    /// Consumes the cell and returns the stored value.
    #[inline(always)]
    pub fn into_inner(self) -> T {
        // SAFETY: we own `self`; `forget` prevents the `Drop` impl below from
        // dropping the value a second time after it has been moved out here.
        unsafe {
            let data = self.cell.get().read();
            forget(self);
            ManuallyDrop::into_inner(data)
        }
    }
}

impl<T> Drop for AtomicCell<T> {
    fn drop(&mut self) {
        // SAFETY: the value is still live here — `swap` always leaves a value
        // in the cell and `into_inner` forgets `self` — so dropping it exactly
        // once is sound.
        unsafe {
            ManuallyDrop::drop(&mut *self.cell.get());
        }
    }
}

impl<T: Debug> Debug for AtomicCell<T> {
    /// Prints the type name and the lock flag; the payload itself is not read
    /// to avoid racing with concurrent writers.
    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
        write!(f, "AtomicCell<{}>", core::any::type_name::<T>())?;
        f.debug_struct("")
            .field("holds_lock", &self.mark.load(Ordering::Relaxed))
            .finish()
    }
}
161
162
#[cfg(test)]
mod test {
    extern crate std;
    use std::collections::hash_map::DefaultHasher;
    use std::collections::HashSet;
    use std::hash::{Hash, Hasher};
    use std::mem::replace;
    use std::num::*;
    use std::prelude::v1::*;
    use std::sync::{Arc, Barrier};
    use std::thread::spawn;
    use super::*;

    /// Pushes `threads * per_thread` distinct values through a shared cell via
    /// `op` and asserts that every value comes back out exactly once — i.e.
    /// concurrent swaps neither lose nor duplicate values.
    fn test_swap_many_case<T, F>(threads: u64, per_thread: u64, mut factory: impl FnMut(u64) -> T, op: F)
        where T: Send + Sync + Eq + Hash + 'static,
              F: Fn(&AtomicCell<Option<T>>, Option<T>) -> Option<T> + Send + Sync + 'static {
        let swap = Arc::new(AtomicCell::new(None));
        let op = Arc::new(op);
        // Build each thread's input batch up front (factory inputs 1..=N,
        // disjoint across threads), then spawn the swapping workers.
        let thr = (0..threads).map(|t| {
            let mut v = Vec::new();
            for i in 0..per_thread {
                v.push(Some(factory(t * per_thread + i + 1)));
            }
            (v, swap.clone(), op.clone())
        }).collect::<Vec<_>>().into_iter().map(|(vec, swap, op)| spawn(move || {
            // Each op returns whatever was previously resident in the cell.
            let mut res = Vec::new();
            for val in vec {
                res.push(op(&swap, val));
            }
            res
        })).collect::<Vec<_>>();
        // Collect everything that ever came out of the cell, plus the final
        // resident value (drained with a trailing `None` swap).
        let mut data = thr.into_iter().map(|j| j.join().unwrap()).flatten().collect::<HashSet<_>>();
        data.insert(op(&swap, None));
        // The initial `None` must have been observed somewhere.
        assert!(data.contains(&None));
        // Every generated value must appear in the combined output set.
        let res = (1..(per_thread * threads + 1)).filter(|v| !data.contains(&Some(factory(*v)))).collect::<Vec<_>>();
        assert!(res.is_empty(), "Results not empty {:#?}", &res);
    }


    /// Barrier-choreographed exclusivity test: each iteration the main thread
    /// plants a single `unique` value in the cell while `threads` workers swap
    /// `default` back in; exactly one worker must observe `unique`.
    fn test_swap_single_case<T, F>(threads: usize, iters: usize, repeats: usize, default: T, unique: T, op: F)
        where T: Send + Sync + Eq + Clone + 'static,
              F: Fn(&AtomicCell<T>, T) -> T + Send + Sync + 'static {
        // Three barriers (all sized threads + 1, including the main thread)
        // delimit the phases of every iteration.
        let barriers = Arc::new((Barrier::new(threads + 1), Barrier::new(threads + 1), Barrier::new(threads + 1)));
        let swap = Arc::new(AtomicCell::new(default.clone()));
        let op = Arc::new(op);
        let handles = (0..threads).map(|_| {
            let b = barriers.clone();
            let default = default.clone();
            let unique = unique.clone();
            let swap = swap.clone();
            let op = op.clone();
            spawn(move || {
                // One bool per iteration: did THIS worker see `unique`?
                let mut it = Vec::with_capacity(iters);
                for _ in 0..iters {
                    let mut v = Vec::with_capacity(repeats + 1);
                    // Phase sync: main resets the cell between b.0 and b.1,
                    // then plants `unique` before releasing b.2.
                    b.0.wait();
                    b.1.wait();
                    for _ in 0..repeats {
                        v.push(op(&swap, default.clone()));
                    }
                    b.2.wait();
                    // One final swap after the plant so `unique` is guaranteed
                    // to be picked up by some worker this iteration.
                    v.push(op(&swap, default.clone()));
                    it.push(v.into_iter().find(|v| v == &unique).is_some());
                }
                it
            })
        }).collect::<Vec<_>>();

        let mut defs = Vec::with_capacity(iters);
        for _ in 0..iters {
            barriers.0.wait();
            // Reset the cell to `default`, then plant `unique` mid-iteration
            // and remember what it displaced.
            op(&swap, default.clone());
            barriers.1.wait();
            defs.push(op(&swap, unique.clone()));
            barriers.2.wait();
        }
        let results = handles.into_iter().map(|v| v.join().unwrap()).collect::<Vec<_>>();
        // Planting `unique` must always displace a `default`, never another
        // `unique` or a torn value.
        assert!(defs.into_iter().all(|v| v == default));
        let len = results.iter().map(|v| v.len()).min().unwrap();
        assert_eq!(len, iters);
        // Exactly one worker per iteration may have observed `unique`.
        for i in 0..iters {
            let count = results.iter().filter(|v| v[i]).count();
            assert_eq!(count, 1); }
    }

    // Large (1 KiB) payload used to stress non-trivial copies through the
    // cell; contents are derived deterministically from a seed value.
    #[derive(Clone, Eq, PartialEq, Hash)]
    struct TestData {
        d0: [u64; 32],
        d1: [u64; 32],
        d2: [u64; 32],
        d3: [u64; 32],
    }

    impl TestData {
        /// Builds a deterministic pseudo-random payload from `val` by feeding
        /// `val` into one continuing `DefaultHasher` stream for each slot.
        pub fn new(val: u64) -> Self {
            let (mut d1, mut d2, mut d3) = ([0; 32], [0; 32], [0; 32]);
            let mut h = DefaultHasher::default();
            for a in d1.iter_mut() {
                h.write_u64(val);
                *a = h.finish();
            }
            for a in d2.iter_mut() {
                h.write_u64(val);
                *a = h.finish();
            }
            for a in d3.iter_mut() {
                h.write_u64(val);
                *a = h.finish();
            }
            Self {
                d0: [val; 32],
                d1,
                d2,
                d3,
            }
        }
    }

    /// Adapter exercising `AtomicCell::swap`.
    fn swap_func<T>() -> impl Fn(&AtomicCell<T>, T) -> T { |s, o| s.swap(o) }

    /// Adapter exercising `AtomicCell::apply` (a swap expressed via `replace`).
    fn apply_func<T: Copy>() -> impl Fn(&AtomicCell<T>, T) -> T {
        |s, o| {
            s.apply(move |val| {
                replace(val, o)
            })
        }
    }

    /// Single-threaded smoke test of the whole API surface.
    #[test]
    fn test_basic() {
        let swap = AtomicCell::new(1);
        assert_eq!(swap.try_swap(2), Ok(1));
        assert_eq!(swap.swap(3), 2);
        assert_eq!(swap.swap(12345), 3);
        swap.try_apply(|val| {
            assert_eq!(*val, 12345);
            *val = 10;
        }).ok().unwrap();
        assert_eq!(swap.swap(0), 10);
    }

    #[test]
    fn test_swap_single() {
        test_swap_single_case(8, 1000, 1000, 11, 22, swap_func());
        test_swap_single_case(8, 1000, 100, TestData::new(1), TestData::new(2), swap_func());
    }

    #[test]
    fn test_apply_single() {
        test_swap_single_case(8, 1000, 1000, 11, 22, apply_func());
    }

    #[test]
    fn test_swap_many() {
        test_swap_many_case(8, 10000, |v| NonZeroU32::new(v as u32).unwrap(), swap_func());
        test_swap_many_case(8, 5000, |v| TestData::new(v), swap_func());
    }

    #[test]
    fn test_apply_many() {
        test_swap_many_case(8, 10000, |v| NonZeroU32::new(v as u32).unwrap(), apply_func());
    }
}