use std::{
alloc::{Layout, alloc, dealloc},
fmt::Display,
ptr,
sync::atomic::{AtomicUsize, Ordering},
};
use crossbeam_utils::CachePadded;
use seqlock::SeqLock;
use crate::GOLDEN_RATIO;
/// Errors returned by the array operations in this module.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum Error {
    /// The index is not below the capacity that was observed; carries the offending index.
    OutOfBounds(usize),
    /// The addressed slot currently holds no value.
    Uninitialized,
    /// The addressed slot already holds a value (reported by the `cas_set*` operations).
    AlreadyInitialized,
    /// A resize was in progress when a `*_without_resize` operation inspected the array.
    Resizing,
    /// The capacity no longer matches the snapshot passed to a `*_without_resize` operation.
    Resized,
}

impl Display for Error {
    /// Writes the human-readable message for each variant.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Error::OutOfBounds(index) => write!(f, "Index out of bounds: {}", index),
            Error::Uninitialized => write!(f, "Slot is uninitialized"),
            Error::AlreadyInitialized => write!(f, "Slot is already initialized"),
            Error::Resizing => write!(
                f,
                "Array is currently resizing which will invalidate the requested capacity snapshot"
            ),
            Error::Resized => write!(
                f,
                "Array has been resized since the requested capacity snapshot"
            ),
        }
    }
}

// Implementing the standard error trait lets callers box this type as
// `Box<dyn std::error::Error>` and propagate it with `?` into generic error types.
impl std::error::Error for Error {}
/// Result alias used by this module: the error type is always [`Error`].
pub type Result<T> = core::result::Result<T, Error>;
/// A fixed-capacity-until-resized array of optional `T` values in which every slot is
/// wrapped in its own `SeqLock`, and the buffer descriptor itself is guarded by a `SeqLock`.
pub struct SeqArray<T: Copy + Clone + Send + Sync + 'static> {
// Current buffer descriptor (capacity + pointer, or an in-progress resize/clone).
state: CachePadded<SeqLock<State<T>>>,
// Counter of occupied slots; incremented/decremented by the set/unset operations.
len: CachePadded<AtomicUsize>,
}
/// Descriptor of the slot buffer owned by a `SeqArray`, including transitional phases.
#[derive(Copy, Clone, PartialEq, Eq, Debug)]
enum State<T: Copy + Clone + Send + Sync + 'static> {
// A resize is copying slots from the old buffer into a larger new buffer.
Resizing {
old_capacity: usize,
new_capacity: usize,
old_ptr: *mut SeqLock<Option<T>>,
new_ptr: *mut SeqLock<Option<T>>,
},
// A clone is being produced; `ptr` is the buffer being filled for the clone
// (that is what `Clone::clone` stores here), not the source buffer.
Cloning {
capacity: usize,
ptr: *mut SeqLock<Option<T>>,
},
// Steady state: `ptr` addresses `capacity` initialised slots.
Available {
capacity: usize,
ptr: *mut SeqLock<Option<T>>,
},
}
// The raw pointers inside `State` prevent the automatic `Send`/`Sync` impls, so they are
// asserted manually here.
// NOTE(review): soundness relies on every access through these pointers going via the
// per-slot `SeqLock` and on `T: Send + Sync` — confirm whenever new pointer uses are added.
unsafe impl<T: Copy + Clone + Send + Sync + 'static> Send for State<T> {}
unsafe impl<T: Copy + Clone + Send + Sync + 'static> Sync for State<T> {}
impl<T: Copy + Clone + Send + Sync + 'static> State<T> {
    /// Capacity described by this state; during a resize this is the *new* capacity.
    #[inline(always)]
    const fn capacity(&self) -> usize {
        match *self {
            State::Resizing { new_capacity, .. } => new_capacity,
            State::Cloning { capacity, .. } | State::Available { capacity, .. } => capacity,
        }
    }

    /// Buffer pointer described by this state; during a resize this is the *new* buffer.
    #[inline(always)]
    const fn ptr(&self) -> *mut SeqLock<Option<T>> {
        match *self {
            State::Resizing { new_ptr, .. } => new_ptr,
            State::Cloning { ptr, .. } | State::Available { ptr, .. } => ptr,
        }
    }
}
impl<T: Copy + Clone + Send + Sync + 'static> SeqArray<T> {
#[inline(always)]
pub fn new() -> Self {
SeqArray::with_capacity(10)
}
#[inline(always)]
pub fn with_capacity(cap: usize) -> Self {
let layout = Layout::array::<SeqLock<Option<T>>>(cap).unwrap();
let ptr = unsafe { alloc(layout) as *mut SeqLock<Option<T>> };
for i in 0..cap {
unsafe {
ptr.add(i).write(SeqLock::new(None));
}
}
SeqArray {
state: CachePadded::new(SeqLock::new(State::Available { capacity: cap, ptr })),
len: CachePadded::new(AtomicUsize::new(0)),
}
}
#[inline(always)]
pub fn capacity(&self) -> usize {
self.state.read().capacity()
}
#[inline(always)]
pub fn len(&self) -> usize {
self.len.load(Ordering::SeqCst)
}
#[inline(always)]
pub fn is_empty(&self) -> bool {
self.len() == 0
}
#[inline(always)]
pub fn resizing(&self) -> bool {
matches!(self.state.read(), State::Resizing { .. })
}
#[inline(always)]
pub fn cloning(&self) -> bool {
matches!(self.state.read(), State::Cloning { .. })
}
#[inline(always)]
pub fn get_without_resize(&self, snapshot_cap: usize, index: usize) -> Result<T> {
let (cap, ptr) = match self.state.read() {
State::Resizing { .. } => return Err(Error::Resizing),
State::Cloning { capacity, ptr } => (capacity, ptr),
State::Available { capacity, ptr } => (capacity, ptr),
};
if ptr.is_null() {
return Err(Error::Resizing);
}
if snapshot_cap != cap {
return Err(Error::Resized);
}
if index >= cap {
return Err(Error::OutOfBounds(index));
}
let slot = unsafe { &*ptr.add(index) };
match slot.read() {
Some(value) => Ok(value),
None => Err(Error::Uninitialized),
}
}
#[inline(always)]
pub fn get(&self, index: usize) -> Result<T> {
loop {
let (cap, ptr) = match self.state.read() {
State::Resizing { .. } => {
std::thread::yield_now();
continue;
}
State::Cloning { capacity, ptr } => (capacity, ptr),
State::Available { capacity, ptr } => (capacity, ptr),
};
if index >= cap {
return Err(Error::OutOfBounds(index));
}
if ptr.is_null() {
std::thread::yield_now();
continue;
}
let slot = unsafe { &*ptr.add(index) };
match slot.read() {
Some(val) => return Ok(val),
None => return Err(Error::Uninitialized),
}
}
}
#[inline(always)]
pub fn unset(&self, index: usize) -> Result<T> {
loop {
let (cap, ptr) = match self.state.read() {
State::Resizing { .. } => {
std::thread::yield_now();
continue;
}
State::Cloning { capacity, ptr } => (capacity, ptr),
State::Available { capacity, ptr } => (capacity, ptr),
};
if index >= cap {
return Err(Error::OutOfBounds(index));
}
if ptr.is_null() {
std::thread::yield_now();
continue;
}
let slot = unsafe { &*ptr.add(index) };
match self.state.read() {
State::Resizing { .. } | State::Cloning { .. } => {
std::thread::yield_now();
continue;
}
_ => (),
};
let mut guard = slot.lock_write();
let val = guard.take();
match val {
Some(val) => {
self.len.fetch_sub(1, Ordering::SeqCst);
return Ok(val);
}
None => return Err(Error::Uninitialized),
}
}
}
#[inline(always)]
pub fn unset_without_resize(&self, snapshot_cap: usize, index: usize) -> Result<T> {
loop {
let (cap, ptr) = match self.state.read() {
State::Resizing { .. } => {
return Err(Error::Resizing);
}
State::Cloning { .. } => {
std::thread::yield_now();
continue;
}
State::Available { capacity, ptr } => (capacity, ptr),
};
if index >= cap {
return Err(Error::OutOfBounds(index));
}
if cap != snapshot_cap {
return Err(Error::Resized);
}
if ptr.is_null() {
return Err(Error::Resizing);
}
let slot = unsafe { &*ptr.add(index) };
match self.state.read() {
State::Resizing { .. } => {
return Err(Error::Resizing);
}
State::Cloning { .. } => {
std::thread::yield_now();
continue;
}
State::Available { .. } => {
let mut guard = slot.lock_write();
match guard.take() {
None => return Err(Error::Uninitialized),
Some(val) => {
self.len.fetch_sub(1, Ordering::SeqCst);
return Ok(val);
}
}
}
}
}
}
#[inline(always)]
pub fn set(&self, index: usize, value: T) -> Result<Option<T>> {
loop {
let (cap, ptr) = match self.state.read() {
State::Resizing { .. } => {
std::thread::yield_now();
continue;
}
State::Cloning { capacity, ptr } => (capacity, ptr),
State::Available { capacity, ptr } => (capacity, ptr),
};
if index >= cap {
return Err(Error::OutOfBounds(index));
}
if ptr.is_null() {
std::thread::yield_now();
continue;
}
let slot = unsafe { &*ptr.add(index) };
match self.state.read() {
State::Resizing { .. } | State::Cloning { .. } => {
std::thread::yield_now();
continue;
}
State::Available { .. } => {
let mut guard = slot.lock_write();
if guard.is_none() {
self.len.fetch_add(1, Ordering::SeqCst);
}
return Ok(guard.replace(value));
}
}
}
}
#[inline(always)]
pub fn set_without_resize(
&self,
snapshot_cap: usize,
index: usize,
value: T,
) -> Result<Option<T>> {
loop {
let (cap, ptr) = match self.state.read() {
State::Resizing { .. } => {
return Err(Error::Resizing);
}
State::Cloning { .. } => {
std::thread::yield_now();
continue;
}
State::Available { capacity, ptr } => (capacity, ptr),
};
if index >= cap {
return Err(Error::OutOfBounds(index));
}
if cap != snapshot_cap {
return Err(Error::Resized);
}
if ptr.is_null() {
return Err(Error::Resizing);
}
let slot = unsafe { &*ptr.add(index) };
match self.state.read() {
State::Resizing { .. } => {
return Err(Error::Resizing);
}
State::Cloning { .. } => {
std::thread::yield_now();
continue;
}
State::Available { .. } => {
let mut guard = slot.lock_write();
if guard.is_none() {
self.len.fetch_add(1, Ordering::SeqCst);
}
return Ok(guard.replace(value));
}
}
}
}
#[inline(always)]
pub fn cas_set(&self, index: usize, value: T) -> Result<()> {
loop {
let (cap, ptr) = match self.state.read() {
State::Resizing { .. } => {
std::thread::yield_now();
continue;
}
State::Cloning { .. } => {
std::thread::yield_now();
continue;
}
State::Available { capacity, ptr } => (capacity, ptr),
};
if index >= cap {
return Err(Error::OutOfBounds(index));
}
if ptr.is_null() {
std::thread::yield_now();
continue;
}
let slot = unsafe { &*ptr.add(index) };
match self.state.read() {
State::Resizing { .. } | State::Cloning { .. } => {
std::thread::yield_now();
continue;
}
State::Available { .. } => {
let mut guard = slot.lock_write();
if guard.is_some() {
return Err(Error::AlreadyInitialized);
} else {
self.len.fetch_add(1, Ordering::SeqCst);
guard.replace(value);
return Ok(());
}
}
}
}
}
#[inline(always)]
pub fn cas_set_without_resize(
&self,
snapshot_cap: usize,
index: usize,
value: T,
) -> Result<()> {
loop {
let (cap, ptr) = match self.state.read() {
State::Resizing { .. } => {
return Err(Error::Resizing);
}
State::Cloning { .. } => {
continue;
}
State::Available { capacity, ptr } => (capacity, ptr),
};
if index >= cap {
return Err(Error::OutOfBounds(index));
}
if cap != snapshot_cap {
return Err(Error::Resized);
}
if ptr.is_null() {
return Err(Error::Resizing);
}
let slot = unsafe { &*ptr.add(index) };
match self.state.read() {
State::Resizing { .. } => return Err(Error::Resizing),
State::Cloning { .. } => {
std::thread::yield_now();
continue;
}
State::Available { .. } => {
let mut guard = slot.lock_write();
if guard.is_some() {
return Err(Error::AlreadyInitialized);
} else {
guard.replace(value);
self.len.fetch_add(1, Ordering::SeqCst);
return Ok(());
}
}
}
}
}
#[inline(always)]
pub fn reserve(&self, new_cap: usize) {
let cap = self.capacity();
if new_cap > cap {
self.resize(new_cap);
}
}
#[inline(always)]
pub fn scale_up(&self) {
let cap = self.capacity();
self.resize((cap as f64 * GOLDEN_RATIO).ceil() as usize);
}
#[inline(always)]
fn resize(&self, new_cap: usize) {
loop {
match self.state.read() {
State::Resizing { new_capacity, .. } => {
if new_cap <= new_capacity {
return;
} else {
std::thread::yield_now();
continue;
}
}
State::Cloning { .. } => {
std::thread::yield_now();
return;
}
State::Available { .. } => (),
}
let mut state_guard = self.state.lock_write();
let (old_cap, old_ptr) = match *state_guard {
State::Resizing { new_capacity, .. } => {
if new_cap <= new_capacity {
return;
} else {
std::thread::yield_now();
continue;
}
}
State::Cloning { .. } => {
std::thread::yield_now();
return;
}
State::Available { capacity, ptr } => (capacity, ptr),
};
if old_cap >= new_cap {
return;
}
let layout = Layout::array::<SeqLock<Option<T>>>(new_cap).unwrap();
let new_ptr = unsafe { alloc(layout) as *mut SeqLock<Option<T>> };
*state_guard = State::Resizing {
old_capacity: old_cap,
new_capacity: new_cap,
old_ptr,
new_ptr,
};
for i in 0..old_cap {
let old_slot = unsafe { &*old_ptr.add(i) }.read();
unsafe {
new_ptr.add(i).write(SeqLock::new(old_slot));
}
}
for i in old_cap..new_cap {
unsafe {
new_ptr.add(i).write(SeqLock::new(None));
}
}
*state_guard = State::Available {
capacity: new_cap,
ptr: new_ptr,
};
}
}
}
impl<T: Copy + Clone + Send + Sync + 'static> Clone for SeqArray<T> {
    /// Produces an independent array with the same capacity and the slot values read
    /// from `self` while its state write lock is held.
    ///
    /// Each slot value is read through the source slot's `SeqLock` and stored behind a
    /// freshly constructed lock. The clone therefore never inherits the lock or sequence
    /// state of a source slot that a concurrent writer happens to hold, which a bitwise
    /// copy of the slots would do. The clone's length is the number of values actually
    /// copied.
    fn clone(&self) -> Self {
        loop {
            // Optimistic check before contending for the write lock.
            match self.state.read() {
                State::Available { .. } => (),
                State::Resizing { .. } | State::Cloning { .. } => {
                    std::thread::yield_now();
                    continue;
                }
            }
            let mut state_guard = self.state.lock_write();
            let (cap, src) = match *state_guard {
                State::Available { capacity, ptr } => (capacity, ptr),
                State::Resizing { .. } | State::Cloning { .. } => {
                    std::thread::yield_now();
                    continue;
                }
            };
            let layout = Layout::array::<SeqLock<Option<T>>>(cap).unwrap();
            let dst = if layout.size() == 0 {
                // Zero-sized layouts must not be passed to the allocator; an empty clone
                // owns no buffer at all.
                ptr::null_mut()
            } else {
                // SAFETY: `layout` has a non-zero size, as `alloc` requires.
                let raw = unsafe { alloc(layout) } as *mut SeqLock<Option<T>>;
                if raw.is_null() {
                    std::alloc::handle_alloc_error(layout);
                }
                raw
            };
            *state_guard = State::Cloning {
                capacity: cap,
                ptr: dst,
            };
            let mut occupied = 0;
            for i in 0..cap {
                // SAFETY: `i < cap`, and `src` addresses `cap` initialised slots.
                let value = unsafe { &*src.add(i) }.read();
                if value.is_some() {
                    occupied += 1;
                }
                // SAFETY: `i < cap`; the destination is uninitialised storage, which
                // `write` does not read or drop.
                unsafe {
                    dst.add(i).write(SeqLock::new(value));
                }
            }
            *state_guard = State::Available {
                capacity: cap,
                ptr: src,
            };
            return SeqArray {
                state: CachePadded::new(SeqLock::new(State::Available {
                    capacity: cap,
                    ptr: dst,
                })),
                len: CachePadded::new(AtomicUsize::new(occupied)),
            };
        }
    }
}
impl<T: Copy + Clone + Send + Sync + 'static> Default for SeqArray<T> {
fn default() -> Self {
SeqArray::with_capacity(10)
}
}
impl<T: Copy + Clone + Send + Sync + 'static> Drop for SeqArray<T> {
    /// Drops every slot of the current buffer in place and releases the buffer.
    /// A null buffer pointer means there is nothing to release.
    fn drop(&mut self) {
        let state = self.state.lock_write();
        let (slots, base) = (state.capacity(), state.ptr());
        if base.is_null() {
            return;
        }
        let layout = Layout::array::<SeqLock<Option<T>>>(slots).unwrap();
        unsafe {
            let mut offset = 0;
            while offset < slots {
                ptr::drop_in_place(base.add(offset));
                offset += 1;
            }
            dealloc(base as *mut u8, layout);
        }
    }
}
impl<T: Copy + Clone + Send + Sync + 'static> IntoIterator for &SeqArray<T> {
    type Item = T;
    type IntoIter = std::vec::IntoIter<T>;

    /// Clones the array and yields, in index order, the values of the clone's occupied
    /// slots; later changes to the original do not affect the returned iterator.
    fn into_iter(self) -> Self::IntoIter {
        let snapshot = self.clone();
        let mut values = Vec::new();
        for index in 0..snapshot.capacity() {
            if let Ok(value) = snapshot.get(index) {
                values.push(value);
            }
        }
        values.into_iter()
    }
}
impl<T: Copy + Clone + Send + Sync + 'static + PartialEq> PartialEq for SeqArray<T> {
fn eq(&self, other: &Self) -> bool {
use std::thread;
let handle_a = thread::spawn({
let this = self.clone();
move || this
});
let handle_b = thread::spawn({
let that = other.clone();
move || that
});
let clone_a = handle_a.join().unwrap();
let clone_b = handle_b.join().unwrap();
if clone_a.capacity() != clone_b.capacity() {
return false;
}
if clone_a.capacity() == 0 {
return true; }
let cap = clone_a.capacity();
(0..cap)
.map(|i| clone_a.get(i))
.zip((0..cap).map(|i| clone_b.get(i)))
.all(|(a, b)| a == b)
}
}
// Marker impl: when `T: Eq`, the `PartialEq` comparison above is declared a full
// equivalence relation.
impl<T: Copy + Clone + Send + Sync + 'static + Eq> Eq for SeqArray<T> {}
#[cfg(test)]
use rayon::prelude::*;
/// Fills one million distinct slots in parallel and checks that every slot is set and
/// that no value appears twice.
#[test]
fn concurrent_set_unique() {
    use std::collections::HashSet;
    let total = 1000000;
    let array: SeqArray<usize> = SeqArray::with_capacity(1000000);
    (0..total).into_par_iter().for_each(|idx| {
        array.set(idx, idx).unwrap();
    });
    let mut unique = HashSet::new();
    for idx in 0..array.capacity() {
        let v = array.get(idx).expect("All slots should be filled");
        assert!(unique.insert(v), "Duplicate value {v} found");
    }
    assert_eq!(unique.len(), total);
}
/// Writes `2 * i` into slot `i` in parallel, then reads every slot in parallel and
/// checks each value.
#[test]
fn concurrent_set_and_get_double() {
    let total = 1000000;
    let array: SeqArray<usize> = SeqArray::with_capacity(1000000);
    (0..total).into_par_iter().for_each(|idx| {
        array.set(idx, idx * 2).unwrap();
    });
    let observed: Vec<_> = (0..array.capacity())
        .into_par_iter()
        .map(|idx| array.get(idx))
        .collect();
    observed
        .into_iter()
        .enumerate()
        .for_each(|(idx, got)| assert_eq!(got, Ok(idx * 2)));
}
/// Grows a small array with `reserve` and then fills it in parallel; every value
/// `0..n` must be present afterwards.
#[test]
fn concurrent_set_with_reserve() {
    let n = 1000000;
    let array: SeqArray<usize> = SeqArray::with_capacity(10);
    array.reserve(n);
    (0..n).into_par_iter().for_each(|idx| {
        array.set(idx, idx).unwrap();
    });
    let mut values: Vec<_> = (0..array.capacity())
        .map(|idx| array.get(idx).unwrap())
        .collect();
    values.sort_unstable();
    for (actual, expected) in values.into_iter().zip(0..n) {
        assert_eq!(expected, actual, "Value mismatch at index {expected}");
    }
}
/// Has ten writers race on each of 1000 slots; every slot must end up holding one of
/// the ten written values.
#[test]
fn concurrent_contention_on_same_indexes() {
    let n = 1000;
    let array: SeqArray<usize> = SeqArray::with_capacity(1000);
    array.reserve(n);
    (0..n)
        .flat_map(|slot| (0..10).map(move |writer| (slot, writer)))
        .par_bridge()
        .for_each(|(slot, writer)| {
            array.set(slot, writer).unwrap();
        });
    for slot in 0..n {
        let v = array.get(slot).expect("Slot should be filled");
        assert!((0..10).contains(&v), "Value {v} not in expected range");
    }
}
/// A first `cas_set` on an empty slot succeeds; a second one fails and leaves the
/// stored value untouched.
#[test]
fn cas_set_basic_unique() {
    let array: SeqArray<usize> = SeqArray::with_capacity(10);
    (0..10).for_each(|idx| {
        array.cas_set(idx, idx).unwrap();
        assert_eq!(array.get(idx), Ok(idx));
        array.cas_set(idx, idx + 100).unwrap_err();
        assert_eq!(array.get(idx), Ok(idx));
    });
}
/// Runs `cas_set` on one million distinct slots in parallel and checks each slot holds
/// its own index.
#[test]
fn cas_set_concurrent_unique() {
    use rayon::prelude::*;
    let array: SeqArray<usize> = SeqArray::with_capacity(1000000);
    (0..1000000).into_par_iter().for_each(|idx| {
        let _ = array.cas_set(idx, idx);
    });
    (0..1000000).for_each(|idx| assert_eq!(array.get(idx), Ok(idx)));
}
/// Has ten `cas_set` callers race on each of 100 slots; losing with
/// `AlreadyInitialized` is fine, any other error is not, and every slot must end up
/// with one of the ten candidate values.
#[test]
fn cas_set_concurrent_contention_range() {
    use rayon::prelude::*;
    let array: SeqArray<i32> = SeqArray::with_capacity(100);
    (0..100)
        .flat_map(|slot| (0..10).map(move |candidate| (slot, candidate)))
        .par_bridge()
        .for_each(|(slot, candidate)| {
            if let Err(e) = array.cas_set(slot, candidate) {
                if e != Error::AlreadyInitialized {
                    panic!("Unexpected error: {:?}", e);
                }
            }
        });
    for i in 0..100 {
        let v = array.get(i).unwrap();
        assert!((0..10).contains(&v), "Slot {i} has unexpected value {v}");
    }
}
/// Fills all but one slot, probes linearly for the remaining empty slot, inserts a
/// special entry there, and verifies both the original entries and the special one.
#[test]
fn full_table_probe_and_insert() {
    let n = 16;
    let table: SeqArray<(usize, usize)> = SeqArray::with_capacity(n);
    for i in 0..(n - 1) {
        table.cas_set(i, (i, i * 10)).unwrap();
    }
    let key = 123456;
    let value = 9999;

    // Probe for the first slot whose read fails and claim it.
    let vacant = (0..n).find(|&i| table.get(i).is_err());
    if let Some(i) = vacant {
        table.cas_set(i, (key, value)).unwrap();
    }
    assert!(vacant.is_some(), "Should have found an empty slot");

    for i in 0..(n - 1) {
        assert_eq!(table.get(i), Ok((i, i * 10)));
    }

    let mut located = false;
    for (k, v) in (0..n).filter_map(|i| table.get(i).ok()) {
        if k == key {
            assert_eq!(v, value);
            located = true;
        }
    }
    assert!(located, "Special key not found");
}
/// A clone compares equal to its source; arrays of different capacity do not.
#[test]
fn test_partial_eq_with_parallel_cloning() {
    let source = SeqArray::with_capacity(10);
    (0..10).for_each(|i| {
        source.set(i, i * i).unwrap();
    });
    let copy = source.clone();
    assert!(source == copy, "Cloned arrays should be equal");

    let larger = SeqArray::with_capacity(20);
    assert!(
        source != larger,
        "Arrays with different capacities should not be equal"
    );
}
/// Empty arrays with different capacities are unequal.
#[test]
fn test_partial_eq_early_return() {
    let small: SeqArray<i32> = SeqArray::with_capacity(10);
    let large: SeqArray<i32> = SeqArray::with_capacity(20);
    assert!(small != large, "Early return on capacity mismatch");
}
/// Clones the array repeatedly while another task grows and fills it; successive
/// clones must never report a smaller capacity than the one seen before.
#[test]
fn clone_during_concurrent_growth() {
    use rayon::join;
    let n = 10_000;
    let array: SeqArray<usize> = SeqArray::with_capacity(4);
    let grow_and_fill = || {
        for i in 0..n {
            while i >= array.capacity() {
                array.scale_up();
            }
            array.set(i, i).unwrap();
        }
    };
    let clone_repeatedly = || {
        let mut last_seen = array.capacity();
        for _ in 0..100 {
            let copy = array.clone();
            assert!(
                copy.capacity() >= last_seen,
                "Clone capacity should be at least as large as original"
            );
            last_seen = copy.capacity();
        }
    };
    join(grow_and_fill, clone_repeatedly);
}
/// Values collected through `into_iter` are a snapshot: later writes to the array do
/// not change them, and a new iteration observes the new values.
#[test]
fn into_iter_snapshot() {
    let array: SeqArray<usize> = SeqArray::with_capacity(8);
    (0..8).for_each(|i| {
        array.set(i, i).unwrap();
    });
    let initial: Vec<usize> = (0..8).collect();
    let snapshot: Vec<_> = (&array).into_iter().collect();
    assert_eq!(snapshot, initial);

    (0..8).for_each(|i| {
        array.set(i, i * 10).unwrap();
    });
    assert_eq!(snapshot, initial);

    let scaled: Vec<usize> = (0..8).map(|i| i * 10).collect();
    let new_snapshot: Vec<_> = (&array).into_iter().collect();
    assert_eq!(new_snapshot, scaled);
}