use std::cell::UnsafeCell;
use std::mem::MaybeUninit;
use std::hint::spin_loop;
use std::sync::atomic::{AtomicU64, Ordering};
use std::panic::{catch_unwind, resume_unwind, AssertUnwindSafe};
// Version-word bit layout (one AtomicU64 per slot):
// bit 0 — set while a writer holds the slot's spin lock.
const LOCK_BIT: u64 = 1 << 0;
// bit 1 — set once the slot holds an initialized value.
const INIT_BIT: u64 = 1 << 1;
// bits 2.. — sequence counter, advanced by this amount on every write so
// in-flight seqlock readers can detect a concurrent write and retry.
const VERSION_INC: u64 = 1 << 2;
/// One cache-line-aligned slot: a seqlock version word plus the
/// (possibly uninitialized) payload it guards.
#[repr(align(64))]
pub struct Unit<E> {
// Packed LOCK_BIT / INIT_BIT / sequence counter.
version: AtomicU64,
// Payload storage; only valid while INIT_BIT is set in `version`.
value: UnsafeCell<MaybeUninit<E>>,
}
// SAFETY: a Unit may move to another thread iff its payload may (E: Send).
unsafe impl<E: Send> Send for Unit<E> {}
// SAFETY: shared access hands `&E` to readers (requires Sync) and moves
// values in and out under the write lock (requires Send).
unsafe impl<E: Send + Sync> Sync for Unit<E> {}
/// A fixed-capacity array of independently locked slots offering
/// seqlock-style retrying reads and spin-locked writes.
#[repr(align(64))]
pub struct AtomicArray<E> {
// The slots themselves; length is fixed at construction.
units: Box<[Unit<E>]>,
// Cached slot count (always equal to units.len()).
capacity: usize,
}
// SAFETY: delegates to the Send/Sync reasoning of Unit<E>, which owns the payloads.
unsafe impl<E: Send> Send for AtomicArray<E> {}
unsafe impl<E: Send + Sync> Sync for AtomicArray<E> {}
/// Strategy for reading a payload out of a slot (by clone or by bitwise copy).
pub trait DataType<E> {
/// Type produced by a read (the element type itself for both built-in strategies).
type Read;
/// Reads the slot's value. Callers must ensure the slot is initialized:
/// implementations call `assume_init*` on the payload cell.
fn read(unit: &Unit<E>) -> Self::Read;
}
/// Marker type: read by calling `Clone::clone` on the stored value.
pub struct CloneType;
/// Marker type: read by bitwise copy (requires `E: Copy`).
pub struct CopyType;
impl<E: Clone> DataType<E> for CloneType {
    type Read = E;

    /// Produces an owned copy of the slot's payload via `Clone`.
    fn read(unit: &Unit<E>) -> Self::Read {
        // SAFETY: callers only invoke `read` on slots whose INIT_BIT is set,
        // so the MaybeUninit cell holds a valid E.
        let stored: &E = unsafe { (*unit.value.get()).assume_init_ref() };
        stored.clone()
    }
}
impl<E: Copy> DataType<E> for CopyType {
    type Read = E;

    /// Produces the slot's payload by bitwise copy.
    fn read(unit: &Unit<E>) -> Self::Read {
        // SAFETY: callers only invoke `read` on initialized slots, so the
        // stored bits form a valid E; E: Copy makes duplicating them sound.
        unsafe { (*unit.value.get()).assume_init() }
    }
}
impl<E> AtomicArray<E> {
    /// Creates an array with `capacity` slots, all starting uninitialized.
    ///
    /// # Panics
    /// Panics if `capacity` is zero.
    #[must_use = "New instances of AtomicArray must serve a purpose!"]
    pub fn new(capacity: usize) -> Self {
        assert!(capacity > 0, "Capacity must be greater than 0!");
        let mut units = Vec::with_capacity(capacity);
        for _ in 0..capacity {
            units.push(Unit {
                version: AtomicU64::new(0),
                value: UnsafeCell::new(MaybeUninit::uninit()),
            });
        }
        Self { units: units.into_boxed_slice(), capacity }
    }

    /// Returns the fixed number of slots.
    #[inline(always)]
    pub fn capacity(&self) -> usize {
        self.capacity
    }

    /// Bounds-checked access to a slot.
    #[inline(always)]
    fn unit(&self, index: usize) -> &Unit<E> {
        assert!(index < self.units.len());
        &self.units[index]
    }

    /// Drops the value currently stored in `unit`'s payload cell.
    ///
    /// # Safety
    /// The caller must hold the slot's write lock and the slot must be
    /// initialized (INIT_BIT set); otherwise this drops garbage.
    #[inline]
    unsafe fn drop_slot(unit: &Unit<E>) {
        (*unit.value.get()).assume_init_drop();
    }

    /// Acquires the slot's spin lock, runs `closure` under it, then releases
    /// the lock while advancing the sequence counter so concurrent seqlock
    /// readers notice the write and retry.
    ///
    /// The lock is always released, even when `closure` panics; the panic is
    /// re-raised afterwards so a panicking writer cannot deadlock later ones.
    #[inline(always)]
    fn reserve_and<F>(&self, index: usize, closure: F)
    where
        F: FnOnce(&Unit<E>),
    {
        let unit = self.unit(index);
        // Spin until the lock is observed free AND we win the CAS for it.
        loop {
            let v = unit.version.load(Ordering::Acquire);
            if (v & LOCK_BIT) != 0 {
                spin_loop();
                continue;
            }
            if unit
                .version
                .compare_exchange(v, v | LOCK_BIT, Ordering::AcqRel, Ordering::Acquire)
                .is_ok()
            {
                break;
            }
        }
        // Catch panics so the unlock below always runs.
        let result = catch_unwind(AssertUnwindSafe(|| {
            closure(unit);
        }));
        // Clear LOCK_BIT and bump the sequence counter in a single store;
        // readers racing with this write see a changed version and retry.
        let new_v = (unit.version.load(Ordering::Relaxed) & !LOCK_BIT) + VERSION_INC;
        unit.version.store(new_v, Ordering::Release);
        if let Err(panic) = result {
            resume_unwind(panic);
        }
    }

    /// Stores `value` at `index`, marking the slot initialized.
    ///
    /// Bug fix: a previously stored value is now dropped instead of leaked
    /// when the slot is overwritten.
    ///
    /// # Panics
    /// Panics if `index >= capacity`.
    pub fn store(&self, index: usize, value: E) {
        assert!(index < self.capacity, "Index must be less than capacity at all times!");
        self.reserve_and(index, |unit| {
            let v = unit.version.load(Ordering::Relaxed);
            unsafe {
                if (v & INIT_BIT) != 0 {
                    // SAFETY: lock held and slot initialized (checked above).
                    Self::drop_slot(unit);
                }
                (*unit.value.get()) = MaybeUninit::new(value);
            }
            unit.version.store(v | INIT_BIT, Ordering::Relaxed);
        });
    }

    /// Seqlock read of the value at `index`: retries until it observes a
    /// version that is unlocked and unchanged across the read, so the
    /// returned value is not torn.
    ///
    /// # Panics
    /// Panics if `index >= capacity` or the slot is uninitialized.
    pub fn load<D: DataType<E>>(&self, index: usize) -> D::Read {
        assert!(index < self.capacity, "Index must be less than capacity at all times!");
        let unit = self.unit(index);
        loop {
            let v1 = unit.version.load(Ordering::Acquire);
            if (v1 & LOCK_BIT) != 0 {
                // Writer in progress — wait and retry.
                spin_loop();
                continue;
            }
            if (v1 & INIT_BIT) == 0 {
                panic!("load is performed on uninitialized index: {}", index);
            }
            let value = D::read(unit);
            // Unchanged version => no writer intervened; the read is consistent.
            let v2 = unit.version.load(Ordering::Acquire);
            if v1 == v2 {
                return value;
            }
        }
    }

    /// Like [`load`](Self::load) but returns `None` for an uninitialized
    /// slot instead of panicking.
    pub fn try_load<D: DataType<E>>(&self, index: usize) -> Option<D::Read> {
        assert!(index < self.capacity, "Index must be less than capacity at all times!");
        let unit = self.unit(index);
        loop {
            let v1 = unit.version.load(Ordering::Acquire);
            if (v1 & LOCK_BIT) != 0 {
                spin_loop();
                continue;
            }
            if (v1 & INIT_BIT) == 0 {
                return None;
            }
            let value = D::read(unit);
            let v2 = unit.version.load(Ordering::Acquire);
            if v1 == v2 {
                return Some(value);
            }
        }
    }

    /// Replaces the value at `index`, returning the previous one.
    ///
    /// Bug fix: the displaced value is now dropped (for `CloneType` the
    /// returned value is a clone, so the original used to be leaked).
    ///
    /// # Panics
    /// Panics if `index >= capacity` or the slot is uninitialized.
    pub fn replace<D: DataType<E>>(&self, index: usize, value: E) -> D::Read {
        assert!(index < self.capacity, "Index must be less than capacity at all times!");
        let mut old = None;
        self.reserve_and(index, |unit| {
            let v = unit.version.load(Ordering::Relaxed);
            if (v & INIT_BIT) == 0 {
                panic!("replace is performed on uninitialized index: {}", index);
            }
            unsafe {
                old = Some(D::read(unit));
                // SAFETY: lock held, slot initialized; drop the original
                // before overwriting so it is not leaked.
                Self::drop_slot(unit);
                (*unit.value.get()) = MaybeUninit::new(value);
            }
            unit.version.store(v | INIT_BIT, Ordering::Relaxed);
        });
        old.unwrap()
    }

    /// Like [`replace`](Self::replace) but returns `None` (and stores
    /// nothing) when the slot is uninitialized.
    pub fn try_replace<D: DataType<E>>(&self, index: usize, value: E) -> Option<D::Read> {
        assert!(index < self.capacity, "Index must be less than capacity at all times!");
        let mut old = None;
        self.reserve_and(index, |unit| {
            let v = unit.version.load(Ordering::Relaxed);
            if (v & INIT_BIT) == 0 {
                return;
            }
            unsafe {
                old = Some(D::read(unit));
                // SAFETY: lock held, slot initialized (checked above).
                Self::drop_slot(unit);
                (*unit.value.get()) = MaybeUninit::new(value);
            }
            unit.version.store(v | INIT_BIT, Ordering::Relaxed);
        });
        old
    }

    /// Atomically applies `closure` to the current value and stores the result.
    ///
    /// Bug fix: the old value is now dropped before being overwritten. The
    /// replacement is computed *before* the drop, so a panicking closure
    /// leaves the slot holding its previous, still-valid value.
    ///
    /// # Panics
    /// Panics if `index >= capacity` or the slot is uninitialized; a panic
    /// raised by `closure` is propagated.
    pub fn update<F, D>(&self, index: usize, closure: F)
    where
        D: DataType<E>,
        F: FnOnce(D::Read) -> E,
    {
        assert!(index < self.capacity, "Index must be less than capacity at all times!");
        self.reserve_and(index, |unit| {
            let v = unit.version.load(Ordering::Relaxed);
            if (v & INIT_BIT) == 0 {
                panic!("update is performed on uninitialized index: {}", index);
            }
            unsafe {
                let current = D::read(unit);
                let next = closure(current);
                // SAFETY: lock held, slot initialized (checked above).
                Self::drop_slot(unit);
                (*unit.value.get()) = MaybeUninit::new(next);
            }
            unit.version.store(v | INIT_BIT, Ordering::Relaxed);
        });
    }

    /// Like [`update`](Self::update) but returns `false` instead of
    /// panicking when the slot is uninitialized; `true` on success.
    pub fn try_update<F, D>(&self, index: usize, closure: F) -> bool
    where
        D: DataType<E>,
        F: FnOnce(D::Read) -> E,
    {
        assert!(index < self.capacity, "Index must be less than capacity at all times!");
        let mut status = false;
        self.reserve_and(index, |unit| {
            let v = unit.version.load(Ordering::Relaxed);
            if (v & INIT_BIT) == 0 {
                return;
            }
            unsafe {
                let current = D::read(unit);
                let next = closure(current);
                // SAFETY: lock held, slot initialized (checked above).
                Self::drop_slot(unit);
                (*unit.value.get()) = MaybeUninit::new(next);
            }
            unit.version.store(v | INIT_BIT, Ordering::Relaxed);
            status = true;
        });
        status
    }

    /// Counts initialized slots. O(capacity); the result is a snapshot and
    /// may be stale by the time it is returned.
    pub fn length(&self) -> usize {
        self.units
            .iter()
            .filter(|unit| unit.version.load(Ordering::Acquire) & INIT_BIT != 0)
            .count()
    }

    /// Returns `true` if the slot at `index` currently holds a value.
    ///
    /// # Panics
    /// Panics if `index >= capacity`.
    #[inline(always)]
    pub fn contains_index(&self, index: usize) -> bool {
        assert!(index < self.capacity, "Index must always be less than capacity!");
        self.unit(index).version.load(Ordering::Acquire) & INIT_BIT != 0
    }

    /// Clears the slot at `index`, dropping its value if it was initialized.
    ///
    /// Bug fix: the stored value is now dropped; previously only INIT_BIT
    /// was cleared, leaking the payload.
    ///
    /// # Panics
    /// Panics if `index >= capacity`.
    pub fn free(&self, index: usize) {
        assert!(index < self.capacity, "Index must always be less than capacity!");
        self.reserve_and(index, |unit| {
            let v = unit.version.load(Ordering::Relaxed);
            if (v & INIT_BIT) != 0 {
                // SAFETY: lock held and slot initialized (checked above).
                unsafe { Self::drop_slot(unit) };
            }
            unit.version.store(v & !INIT_BIT, Ordering::Relaxed);
        });
    }

    /// Alias for [`replace`](Self::replace).
    pub fn swap<D: DataType<E>>(&self, index: usize, value: E) -> D::Read {
        self.replace::<D>(index, value)
    }

    /// Returns the value at `index`, first initializing the slot with the
    /// result of `store` if it is empty. The INIT check is re-done under the
    /// lock, so `store` is skipped when another thread initialized the slot
    /// between the fast-path read and lock acquisition.
    ///
    /// # Panics
    /// Panics if `index >= capacity`.
    pub fn load_or_store<F, D>(&self, index: usize, store: F) -> D::Read
    where
        F: FnOnce() -> E,
        D: DataType<E>,
    {
        assert!(index < self.capacity, "Index must be less than capacity at all times!");
        // Fast path: lock-free read when already initialized.
        if let Some(value) = self.try_load::<D>(index) {
            return value;
        }
        self.reserve_and(index, |unit| {
            let v = unit.version.load(Ordering::Relaxed);
            if (v & INIT_BIT) == 0 {
                unsafe {
                    (*unit.value.get()) = MaybeUninit::new(store());
                }
                unit.version.store(v | INIT_BIT, Ordering::Relaxed);
            }
        });
        self.load::<D>(index)
    }
}
impl<E> Drop for AtomicArray<E> {
    /// Drops every payload whose slot is still marked initialized.
    fn drop(&mut self) {
        for unit in &mut *self.units {
            // `&mut self` guarantees exclusive access, so a Relaxed read of
            // the flags is sufficient here.
            let initialized = unit.version.load(Ordering::Relaxed) & INIT_BIT != 0;
            if initialized {
                // SAFETY: INIT_BIT set means the cell holds a valid value.
                unsafe { unit.value.get_mut().assume_init_drop() };
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::thread;
    use std::sync::atomic::{AtomicBool, Ordering};

    // Single-threaded round trip with a Copy payload.
    #[test]
    fn basic_store_load_copy() {
        let arr = AtomicArray::new(4);
        arr.store(0, 42u64);
        let v = arr.load::<CopyType>(0);
        assert_eq!(v, 42);
    }

    // Single-threaded round trip with a Clone (heap-owning) payload.
    #[test]
    fn basic_store_load_clone() {
        let arr = AtomicArray::new(2);
        arr.store(1, String::from("hello"));
        let v = arr.load::<CloneType>(1);
        assert_eq!(v, "hello");
    }

    #[test]
    fn try_load_uninitialized() {
        let arr: AtomicArray<u32> = AtomicArray::new(1);
        assert!(arr.try_load::<CopyType>(0).is_none());
    }

    #[test]
    fn replace_returns_old() {
        let arr = AtomicArray::new(1);
        arr.store(0, 10);
        let old = arr.replace::<CopyType>(0, 20);
        assert_eq!(old, 10);
        assert_eq!(arr.load::<CopyType>(0), 20);
    }

    #[test]
    fn update_in_place() {
        let arr = AtomicArray::new(1);
        arr.store(0, 5);
        arr.update::<_, CopyType>(0, |v| v * 2);
        assert_eq!(arr.load::<CopyType>(0), 10);
    }

    // Many readers, no writers: every read must see the stored value.
    #[test]
    fn concurrent_reads() {
        let arr = Arc::new(AtomicArray::new(1));
        arr.store(0, 123);
        let mut handles = vec![];
        for _ in 0..16 {
            let a = arr.clone();
            handles.push(thread::spawn(move || {
                for _ in 0..100_000 {
                    assert_eq!(a.load::<CopyType>(0), 123);
                }
            }));
        }
        for h in handles {
            h.join().unwrap();
        }
    }

    // One writer racing many readers; reads must never tear or hang.
    #[test]
    fn concurrent_read_write() {
        let arr = Arc::new(AtomicArray::new(1));
        arr.store(0, 0u64);
        let stop = Arc::new(AtomicBool::new(false));
        // Fix: keep the writer's JoinHandle so the thread can be joined;
        // previously the handle was dropped and the thread was detached.
        let writer = {
            let a = arr.clone();
            let s = stop.clone();
            thread::spawn(move || {
                let mut i = 0;
                while !s.load(Ordering::Relaxed) {
                    a.store(0, i);
                    i += 1;
                }
            })
        };
        let mut readers = vec![];
        for _ in 0..8 {
            let a = arr.clone();
            readers.push(thread::spawn(move || {
                for _ in 0..100_000 {
                    let _ = a.load::<CopyType>(0);
                }
            }));
        }
        for r in readers {
            r.join().unwrap();
        }
        stop.store(true, Ordering::Relaxed);
        writer.join().unwrap();
    }

    // Mixed loads and updates across all slots from many threads.
    #[test]
    fn stress_test() {
        let arr = Arc::new(AtomicArray::new(8));
        for i in 0..8 {
            arr.store(i, i as u64);
        }
        let mut threads = vec![];
        for t in 0..16 {
            let a = arr.clone();
            threads.push(thread::spawn(move || {
                for i in 0..200_000 {
                    let idx = (i + t) % 8;
                    let _ = a.load::<CopyType>(idx);
                    a.update::<_, CopyType>(idx, |v| v + 1);
                }
            }));
        }
        for t in threads {
            t.join().unwrap();
        }
    }

    // Every write must strictly advance the slot's version word.
    #[test]
    fn version_monotonicity() {
        let arr = AtomicArray::new(1);
        arr.store(0, 0u64);
        let unit = &arr.units[0];
        let mut last = unit.version.load(Ordering::Relaxed);
        for _ in 0..1000 {
            arr.update::<_, CopyType>(0, |v| v + 1);
            let now = unit.version.load(Ordering::Relaxed);
            assert!(now > last);
            last = now;
        }
    }

    // Dropping the array must drop each stored value exactly once.
    #[test]
    fn drop_clone_safety() {
        use std::sync::atomic::{AtomicUsize, Ordering};
        static DROPS: AtomicUsize = AtomicUsize::new(0);
        #[derive(Clone)]
        struct DropCounter;
        impl Drop for DropCounter {
            fn drop(&mut self) {
                DROPS.fetch_add(1, Ordering::Relaxed);
            }
        }
        {
            let arr = AtomicArray::new(1);
            arr.store(0, DropCounter);
        }
        assert_eq!(DROPS.load(Ordering::Relaxed), 1);
    }

    // A spinning reader must not starve the writer indefinitely.
    #[test]
    fn writer_eventually_makes_progress() {
        let arr = Arc::new(AtomicArray::new(1));
        arr.store(0, 0u64);
        let done = Arc::new(AtomicBool::new(false));
        let reader = {
            let a = arr.clone();
            let d = done.clone();
            thread::spawn(move || {
                while !d.load(Ordering::Relaxed) {
                    let _ = a.load::<CopyType>(0);
                }
            })
        };
        let writer = {
            let a = arr.clone();
            thread::spawn(move || {
                for _ in 0..1000 {
                    a.update::<_, CopyType>(0, |v| v + 1);
                }
            })
        };
        writer.join().unwrap();
        done.store(true, Ordering::Relaxed);
        reader.join().unwrap();
    }

    // Version-based detection means an A->B->A sequence is still a valid read.
    #[test]
    fn aba_like_pattern() {
        let arr = AtomicArray::new(1);
        arr.store(0, 10u64);
        let first = arr.load::<CopyType>(0);
        arr.store(0, 20);
        arr.store(0, 10);
        let second = arr.load::<CopyType>(0);
        assert_eq!(first, second);
    }

    // A panicking update closure must release the lock and leave the slot intact.
    #[test]
    fn update_panic_does_not_corrupt() {
        use std::panic::{catch_unwind, AssertUnwindSafe};
        let arr = AtomicArray::new(1);
        arr.store(0, 5u64);
        let result = catch_unwind(AssertUnwindSafe(|| {
            arr.update::<_, CopyType>(0, |_| {
                panic!("boom");
            });
        }));
        assert!(result.is_err());
        assert_eq!(arr.load::<CopyType>(0), 5);
    }

    // Two concurrent increments: at least one must be observed.
    #[test]
    fn linearizability_sanity() {
        let arr = Arc::new(AtomicArray::new(1));
        arr.store(0, 0u64);
        let a1 = arr.clone();
        let t1 = thread::spawn(move || {
            a1.update::<_, CopyType>(0, |v| v + 1);
        });
        let a2 = arr.clone();
        let t2 = thread::spawn(move || {
            a2.update::<_, CopyType>(0, |v| v + 1);
        });
        t1.join().unwrap();
        t2.join().unwrap();
        let v = arr.load::<CopyType>(0);
        assert!(v == 1 || v == 2);
    }

    // Dropping after concurrent clone-reads must be safe (no double free).
    #[test]
    fn drop_after_concurrent_access() {
        let arr = Arc::new(AtomicArray::new(1));
        arr.store(0, String::from("hello"));
        let a = arr.clone();
        let t = thread::spawn(move || {
            for _ in 0..1000 {
                let _ = a.load::<CloneType>(0);
            }
        });
        t.join().unwrap();
        drop(arr);
    }

    // Writes to one slot must never disturb reads from another.
    #[test]
    fn slots_are_independent() {
        let arr = Arc::new(AtomicArray::new(2));
        arr.store(0, 1u64);
        arr.store(1, 100u64);
        let a = arr.clone();
        let t = std::thread::spawn(move || {
            for _ in 0..1000 {
                a.update::<_, CopyType>(0, |v| v + 1);
            }
        });
        for _ in 0..1000 {
            assert_eq!(arr.load::<CopyType>(1), 100);
        }
        t.join().unwrap();
    }
}
// Loom-based model checks: exhaustively explore small thread interleavings.
// Only compiled with both the `loom_test` and `atomic-array` features.
// NOTE(review): AtomicArray itself uses std atomics and std::hint::spin_loop,
// not loom's shimmed types, so loom cannot control or explore the array's
// internal interleavings — confirm whether the impl is meant to be built
// against loom's atomics under cfg(loom).
#[cfg(all(feature = "loom_test", feature = "atomic-array"))]
mod loom_tests {
use super::*;
use loom::sync::Arc;
use loom::thread;
// One thread stores while another attempts a try_load; every modeled
// interleaving must complete without panic.
#[test]
fn loom_store_load() {
loom::model(|| {
let arr = Arc::new(AtomicArray::new(1));
let a1 = arr.clone();
let t1 = thread::spawn(move || {
a1.store(0, 1u64);
});
let a2 = arr.clone();
let t2 = thread::spawn(move || {
let _ = a2.try_load::<CopyType>(0);
});
t1.join().unwrap();
t2.join().unwrap();
});
}
// A reader racing alternating stores must only ever observe one of the
// two values actually written — never a torn mix.
#[test]
fn loom_no_torn_read() {
loom::model(|| {
let arr = loom::sync::Arc::new(AtomicArray::new(1));
let writer = {
let a = arr.clone();
loom::thread::spawn(move || {
for _ in 0..3 {
a.store(0, 1u64);
a.store(0, 2u64);
}
})
};
let reader = {
let a = arr.clone();
loom::thread::spawn(move || {
for _ in 0..3 {
if let Some(v) = a.try_load::<CopyType>(0) {
assert!(v == 1 || v == 2);
}
}
})
};
writer.join().unwrap();
reader.join().unwrap();
});
}
// Two concurrent read-modify-write updates: the final value reflects at
// least one of them in every interleaving.
#[test]
fn loom_concurrent_updates() {
loom::model(|| {
let arr = loom::sync::Arc::new(AtomicArray::new(1));
arr.store(0, 0u64);
let t1 = {
let a = arr.clone();
loom::thread::spawn(move || {
a.update::<_, CopyType>(0, |v| v + 1);
})
};
let t2 = {
let a = arr.clone();
loom::thread::spawn(move || {
a.update::<_, CopyType>(0, |v| v + 1);
})
};
t1.join().unwrap();
t2.join().unwrap();
let v = arr.load::<CopyType>(0);
assert!(v == 1 || v == 2);
});
}
// A panicking update must still release the slot lock: the final load
// would spin forever if the lock leaked.
#[test]
fn loom_panic_does_not_lock_forever() {
loom::model(|| {
let arr = Arc::new(AtomicArray::new(1));
arr.store(0, 1u64);
let a = arr.clone();
let _ = thread::spawn(move || {
let _ = std::panic::catch_unwind(AssertUnwindSafe(|| {
a.update::<_, CopyType>(0, |_| panic!("boom"));
}));
}).join();
let _ = arr.load::<CopyType>(0);
});
}
}