use crate::basic::clocks::{check_clock, me};
use shuttle::sync::atomic::*;
use shuttle::{check_dfs, future, thread};
use std::collections::HashSet;
use std::sync::Arc;
use test_log::test;
// Generates a module called `$name` containing the tests below, instantiated for the atomic
// integer type `$ty` (built with `$ty::new(0)`).
//
// Every test runs its scenario under `check_dfs` and inserts what each execution produced
// into a shared `HashSet`; once exploration is over, the set of recorded outcomes is asserted.
//
// NOTE(review): `me` and `check_clock` live in `crate::basic::clocks`, which is not visible
// from here. The comments below assume `me()` is the index of the calling thread (main = 0,
// then spawn order) and that `check_clock` asserts its predicate for every
// `(thread index, clock entry)` pair of the caller's clock -- confirm against that module.
macro_rules! int_tests {
    ($name:ident, $ty:ident) => {
        mod $name {
            use super::*;
            use test_log::test;

            // Thread 1 stores `x` then `y`; thread 2 loads `y` then `x`.
            // The pair `(x, y) == (0, 1)` must never be recorded.
            #[test]
            fn store_store_reordering() {
                let outcomes = Arc::new(std::sync::Mutex::new(HashSet::new()));
                let recorder = Arc::clone(&outcomes);

                check_dfs(
                    move || {
                        let x = Arc::new($ty::new(0));
                        let y = Arc::new($ty::new(0));

                        let writer = thread::spawn({
                            let (x, y) = (Arc::clone(&x), Arc::clone(&y));
                            move || {
                                x.store(1, Ordering::SeqCst);
                                y.store(1, Ordering::SeqCst);
                                assert_eq!(me(), 1usize);
                                // Only the entries for threads 0 and 1 may be non-zero.
                                check_clock(|i, c| (c > 0) == matches!(i, 0 | 1));
                            }
                        });

                        let reader = thread::spawn({
                            let (x, y) = (Arc::clone(&x), Arc::clone(&y));
                            move || {
                                let seen_y = y.load(Ordering::SeqCst);
                                let seen_x = x.load(Ordering::SeqCst);
                                assert_eq!(me(), 2usize);
                                // The entry for thread 1 is non-zero exactly when at least one
                                // of the loads returned a value other than the initial 0.
                                let saw_store = (seen_x, seen_y) != (0, 0);
                                check_clock(|i, c| {
                                    (c > 0) == (i == 0 || i == 2 || (saw_store && i == 1))
                                });
                                (seen_x, seen_y)
                            }
                        });

                        let from_writer = writer.join().unwrap();
                        let from_reader = reader.join().unwrap();
                        recorder.lock().unwrap().insert((from_writer, from_reader));
                    },
                    None,
                );

                let outcomes = Arc::try_unwrap(outcomes).unwrap().into_inner().unwrap();
                for reachable in [((), (0, 0)), ((), (1, 0)), ((), (1, 1))] {
                    assert!(outcomes.contains(&reachable));
                }
                assert!(!outcomes.contains(&((), (0, 1))));
            }

            // Thread 1 loads `x` then stores `y`; thread 2 loads `y` then stores `x`.
            // Both loads returning 1 must never be recorded.
            #[test]
            fn store_load_reordering() {
                let outcomes = Arc::new(std::sync::Mutex::new(HashSet::new()));
                let recorder = Arc::clone(&outcomes);

                check_dfs(
                    move || {
                        let x = Arc::new($ty::new(0));
                        let y = Arc::new($ty::new(0));

                        let first = thread::spawn({
                            let (x, y) = (Arc::clone(&x), Arc::clone(&y));
                            move || {
                                let seen_x = x.load(Ordering::SeqCst);
                                y.store(1, Ordering::SeqCst);
                                assert_eq!(me(), 1usize);
                                // Thread 2's entry is non-zero exactly when its store was read.
                                let saw_peer = seen_x != 0;
                                check_clock(|i, c| {
                                    (c > 0) == (i == 0 || i == 1 || (saw_peer && i == 2))
                                });
                                (seen_x,)
                            }
                        });

                        let second = thread::spawn({
                            let (x, y) = (Arc::clone(&x), Arc::clone(&y));
                            move || {
                                let seen_y = y.load(Ordering::SeqCst);
                                x.store(1, Ordering::SeqCst);
                                assert_eq!(me(), 2usize);
                                // Thread 1's entry is non-zero exactly when its store was read.
                                let saw_peer = seen_y != 0;
                                check_clock(|i, c| {
                                    (c > 0) == (i == 0 || i == 2 || (saw_peer && i == 1))
                                });
                                (seen_y,)
                            }
                        });

                        let from_first = first.join().unwrap();
                        let from_second = second.join().unwrap();
                        recorder.lock().unwrap().insert((from_first, from_second));
                    },
                    None,
                );

                let outcomes = Arc::try_unwrap(outcomes).unwrap().into_inner().unwrap();
                for reachable in [((0,), (0,)), ((0,), (1,)), ((1,), (0,))] {
                    assert!(outcomes.contains(&reachable));
                }
                assert!(!outcomes.contains(&((1,), (1,))));
            }

            // Thread 1 stores `x` then loads `y`; thread 2 stores `y` then loads `x`.
            // Both loads returning 0 must never be recorded.
            #[test]
            fn load_store_reordering() {
                let outcomes = Arc::new(std::sync::Mutex::new(HashSet::new()));
                let recorder = Arc::clone(&outcomes);

                check_dfs(
                    move || {
                        let x = Arc::new($ty::new(0));
                        let y = Arc::new($ty::new(0));

                        let first = thread::spawn({
                            let (x, y) = (Arc::clone(&x), Arc::clone(&y));
                            move || {
                                x.store(1, Ordering::SeqCst);
                                let seen_y = y.load(Ordering::SeqCst);
                                assert_eq!(me(), 1usize);
                                // Thread 2's entry is non-zero exactly when its store was read.
                                let saw_peer = seen_y != 0;
                                check_clock(|i, c| {
                                    (c > 0) == (i == 0 || i == 1 || (saw_peer && i == 2))
                                });
                                (seen_y,)
                            }
                        });

                        let second = thread::spawn({
                            let (x, y) = (Arc::clone(&x), Arc::clone(&y));
                            move || {
                                y.store(1, Ordering::SeqCst);
                                let seen_x = x.load(Ordering::SeqCst);
                                assert_eq!(me(), 2usize);
                                // Thread 1's entry is non-zero exactly when its store was read.
                                let saw_peer = seen_x != 0;
                                check_clock(|i, c| {
                                    (c > 0) == (i == 0 || i == 2 || (saw_peer && i == 1))
                                });
                                (seen_x,)
                            }
                        });

                        let from_first = first.join().unwrap();
                        let from_second = second.join().unwrap();
                        recorder.lock().unwrap().insert((from_first, from_second));
                    },
                    None,
                );

                let outcomes = Arc::try_unwrap(outcomes).unwrap().into_inner().unwrap();
                assert!(!outcomes.contains(&((0,), (0,))));
                for reachable in [((0,), (1,)), ((1,), (0,)), ((1,), (1,))] {
                    assert!(outcomes.contains(&reachable));
                }
            }

            // Two threads store 1 and 2 while the main thread adds 5 through `fetch_update`.
            // The final value of the atomic is recorded for every execution.
            #[test]
            fn fetch_update() {
                let final_values = Arc::new(std::sync::Mutex::new(HashSet::new()));
                let recorder = Arc::clone(&final_values);

                check_dfs(
                    move || {
                        let x = Arc::new($ty::new(0));

                        let store_one = thread::spawn({
                            let x = Arc::clone(&x);
                            move || {
                                x.store(1, Ordering::SeqCst);
                                assert_eq!(me(), 1);
                            }
                        });

                        let store_two = thread::spawn({
                            let x = Arc::clone(&x);
                            move || {
                                x.store(2, Ordering::SeqCst);
                                assert_eq!(me(), 2);
                            }
                        });

                        let previous = x
                            .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |v| Some(v + 5))
                            .unwrap();
                        assert_eq!(me(), 0);
                        check_clock(|i, c| {
                            let known = c > 0;
                            match previous {
                                // Initial value read: only the main thread's entry is set.
                                0 => known == (i == 0),
                                // Value stored by thread 1 read: its entry must be set; the
                                // predicate leaves the other entries unconstrained.
                                1 => i != 1 || known,
                                // Same, for the value stored by thread 2.
                                2 => i != 2 || known,
                                _ => unreachable!(),
                            }
                        });

                        store_one.join().unwrap();
                        store_two.join().unwrap();

                        let last = Arc::try_unwrap(x).unwrap().into_inner();
                        recorder.lock().unwrap().insert(last);
                    },
                    None,
                );

                let final_values = Arc::try_unwrap(final_values).unwrap().into_inner().unwrap();
                assert_eq!(final_values.len(), 4);
                for expected in [1, 2, 6, 7] {
                    assert!(final_values.contains(&expected));
                }
            }

            // A thread stores 1 while the main thread attempts `compare_exchange(1, 5)`.
            // The returned `Result` is recorded for every execution.
            #[test]
            fn compare_exchange() {
                let results = Arc::new(std::sync::Mutex::new(HashSet::new()));
                let recorder = Arc::clone(&results);

                check_dfs(
                    move || {
                        let x = Arc::new($ty::new(0));

                        let setter = thread::spawn({
                            let x = Arc::clone(&x);
                            move || {
                                x.store(1, Ordering::SeqCst);
                                assert_eq!(me(), 1);
                            }
                        });

                        let outcome = x.compare_exchange(1, 5, Ordering::SeqCst, Ordering::SeqCst);
                        // The exchange can only succeed after thread 1's store, so thread 1's
                        // entry is non-zero exactly when the exchange succeeded.
                        let swapped = outcome.is_ok();
                        check_clock(|i, c| (c > 0) == (i == 0 || (swapped && i == 1)));

                        setter.join().unwrap();
                        recorder.lock().unwrap().insert(outcome);
                    },
                    None,
                );

                let results = Arc::try_unwrap(results).unwrap().into_inner().unwrap();
                assert_eq!(results.len(), 2);
                assert!(results.contains(&Ok(1)));
                assert!(results.contains(&Err(0)));
            }

            // Same scenario as `compare_exchange`, using `compare_exchange_weak`.
            #[test]
            fn compare_exchange_weak() {
                let results = Arc::new(std::sync::Mutex::new(HashSet::new()));
                let recorder = Arc::clone(&results);

                check_dfs(
                    move || {
                        let x = Arc::new($ty::new(0));

                        let setter = thread::spawn({
                            let x = Arc::clone(&x);
                            move || {
                                x.store(1, Ordering::SeqCst);
                                assert_eq!(me(), 1);
                            }
                        });

                        let outcome =
                            x.compare_exchange_weak(1, 5, Ordering::SeqCst, Ordering::SeqCst);
                        // Thread 1's entry is non-zero exactly when the exchange succeeded.
                        let swapped = outcome.is_ok();
                        check_clock(|i, c| (c > 0) == (i == 0 || (swapped && i == 1)));

                        setter.join().unwrap();
                        recorder.lock().unwrap().insert(outcome);
                    },
                    None,
                );

                let results = Arc::try_unwrap(results).unwrap().into_inner().unwrap();
                assert_eq!(results.len(), 2);
                assert!(results.contains(&Ok(1)));
                assert!(results.contains(&Err(0)));
            }
        }
    };
}
// Instantiate the `int_tests!` suite once per integer atomic type. The first argument is the
// name of the generated test module, the second the atomic type it exercises.
// Signed integer atomics.
int_tests!(int_i8, AtomicI8);
int_tests!(int_i16, AtomicI16);
int_tests!(int_i32, AtomicI32);
int_tests!(int_i64, AtomicI64);
int_tests!(int_isize, AtomicIsize);
// Unsigned integer atomics.
int_tests!(int_u8, AtomicU8);
int_tests!(int_u16, AtomicU16);
int_tests!(int_u32, AtomicU32);
int_tests!(int_u64, AtomicU64);
int_tests!(int_usize, AtomicUsize);
mod bool {
use super::*;
use test_log::test;
#[test]
fn fetch_update() {
let observed_values = Arc::new(std::sync::Mutex::new(HashSet::new()));
let observed_values_clone = observed_values.clone();
check_dfs(
move || {
let x = Arc::new(AtomicBool::new(false));
let thd1 = {
let x = x.clone();
thread::spawn(move || {
x.store(true, Ordering::SeqCst);
assert_eq!(me(), 1);
})
};
let val = x
.fetch_update(Ordering::SeqCst, Ordering::SeqCst, |x| Some(!x))
.unwrap();
assert_eq!(me(), 0);
check_clock(|i, c| {
(c > 0)
== match val {
false => i == 0, true => i == 0 || i == 1, }
});
thd1.join().unwrap();
let x = Arc::try_unwrap(x).unwrap().into_inner();
observed_values_clone.lock().unwrap().insert(x);
},
None,
);
let observed_values = Arc::try_unwrap(observed_values).unwrap().into_inner().unwrap();
assert_eq!(observed_values.len(), 2);
}
#[test]
fn compare_exchange() {
let observed_values = Arc::new(std::sync::Mutex::new(HashSet::new()));
let observed_values_clone = observed_values.clone();
check_dfs(
move || {
let x = Arc::new(AtomicBool::new(false));
let thd1 = {
let x = x.clone();
thread::spawn(move || {
x.store(true, Ordering::SeqCst);
assert_eq!(me(), 1);
})
};
let result = x.compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst);
check_clock(|i, c| {
(c > 0)
== match result {
Ok(_) => i == 0, Err(_) => i == 0 || i == 1, }
});
thd1.join().unwrap();
observed_values_clone.lock().unwrap().insert(result);
},
None,
);
let observed_values = Arc::try_unwrap(observed_values).unwrap().into_inner().unwrap();
assert_eq!(observed_values.len(), 2);
assert!(observed_values.contains(&Ok(false)));
assert!(observed_values.contains(&Err(true)));
}
}
mod ptr {
use super::*;
use test_log::test;
#[test]
fn fetch_update() {
let observed_values = Arc::new(std::sync::Mutex::new(HashSet::new()));
let observed_values_clone = observed_values.clone();
check_dfs(
move || {
let mut x = 0usize;
let x_ptr = &mut x as *mut _;
let ptr = Arc::new(AtomicPtr::new(x_ptr));
let thd1 = {
let ptr = ptr.clone();
thread::spawn(move || {
let mut y = 0usize;
ptr.store(&mut y as *mut _, Ordering::SeqCst);
assert_eq!(me(), 1);
})
};
let val = ptr
.fetch_update(Ordering::SeqCst, Ordering::SeqCst, |_| Some(x_ptr))
.unwrap();
check_clock(|i, c| (c > 0) == if val == x_ptr { i == 0 } else { i == 0 || i == 1 });
thd1.join().unwrap();
let ptr = Arc::try_unwrap(ptr).unwrap().into_inner();
observed_values_clone.lock().unwrap().insert(ptr == x_ptr);
},
None,
);
let observed_values = Arc::try_unwrap(observed_values).unwrap().into_inner().unwrap();
assert_eq!(observed_values.len(), 2);
}
#[test]
fn compare_exchange() {
let observed_values = Arc::new(std::sync::Mutex::new(HashSet::new()));
let observed_values_clone = observed_values.clone();
check_dfs(
move || {
let mut x = 0usize;
let x_ptr = &mut x as *mut _;
let mut x2 = 0usize;
let x2_ptr = &mut x2 as *mut _;
let ptr = Arc::new(AtomicPtr::new(x_ptr));
let thd1 = {
let ptr = ptr.clone();
thread::spawn(move || {
let mut y = 0usize;
ptr.store(&mut y as *mut _, Ordering::SeqCst);
assert_eq!(me(), 1);
})
};
let result = ptr.compare_exchange(x_ptr, x2_ptr, Ordering::SeqCst, Ordering::SeqCst);
check_clock(|i, c| {
(c > 0)
== match result {
Ok(_) => i == 0, Err(_) => i == 0 || i == 1, }
});
thd1.join().unwrap();
observed_values_clone.lock().unwrap().insert(result.is_ok());
},
None,
);
let observed_values = Arc::try_unwrap(observed_values).unwrap().into_inner().unwrap();
assert_eq!(observed_values.len(), 2);
}
}
// Runs a flag/data handoff in which every atomic access uses `Ordering::Relaxed`.
//
// NOTE(review): the assertion below can only hold on every explored schedule if the model
// gives `Relaxed` accesses sequentially consistent behaviour -- confirm against Shuttle's
// documentation for its atomics.
#[test]
fn relaxed_orderings_work() {
    let scenario = move || {
        let ready = Arc::new(AtomicBool::new(false));
        let payload = Arc::new(AtomicU64::new(0));

        // The spawned thread publishes the payload first, then raises the flag. Its join
        // handle is dropped immediately; the thread is never joined.
        let (ready_tx, payload_tx) = (ready.clone(), payload.clone());
        thread::spawn(move || {
            payload_tx.store(42, Ordering::Relaxed);
            ready_tx.store(true, Ordering::Relaxed);
        });

        // Nothing to check unless the flag has already been raised.
        if !ready.load(Ordering::Relaxed) {
            return;
        }
        assert_eq!(payload.load(Ordering::Relaxed), 42);
    };

    check_dfs(scenario, None);
}
// Increments an atomic counter from a spawned future and checks both the value returned by
// `fetch_add` and the counter's value after the future has completed.
#[test]
fn atomics_futures() {
    let scenario = move || {
        let counter = Arc::new(AtomicUsize::new(0));

        let shared = counter.clone();
        let task = future::spawn(async move { shared.fetch_add(1, Ordering::SeqCst) });

        // `fetch_add` yields the value held before the increment.
        let previous = future::block_on(task).unwrap();
        assert_eq!(previous, 0);
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    };

    check_dfs(scenario, None);
}