use std::cell::UnsafeCell;
use std::collections::LinkedList;
use std::mem::MaybeUninit;
use std::sync::{
atomic::{AtomicU32, Ordering},
Arc, RwLock,
};
/// Number of element slots in each heap-allocated chunk. A chunk is allocated once and
/// only appended to the chunk list, so the addresses of its slots never change.
const CHUNK_SIZE: usize = 1 << 7;
/// An append-only list whose elements never move once pushed.
///
/// Elements live in fixed-size heap chunks, so a `&T` obtained from the list stays valid
/// while further elements are pushed through a shared `&self`. The handle is a
/// reference-counted pointer: clones share the same underlying storage.
#[derive(Debug)]
pub struct StableList<T>(Arc<StableListInner<T>>);

// Implemented by hand instead of derived: `#[derive(Clone)]` would require `T: Clone`,
// but cloning the handle only bumps the `Arc` refcount and never clones any `T`.
impl<T> Clone for StableList<T> {
    fn clone(&self) -> Self {
        Self(Arc::clone(&self.0))
    }
}
impl<T> StableList<T> {
    /// Creates an empty list. No chunk is allocated until the first `push`.
    pub fn new() -> Self {
        Self(Arc::new(StableListInner::new()))
    }

    /// Iterates over the list from index 0 with no upper bound (test-only helper).
    #[cfg(test)]
    pub fn iter(&self) -> StableListIterator<T> {
        self.bounded_iter(0, None)
    }

    /// Iterates over elements starting at index `start`, stopping before index `end`
    /// when `end` is `Some`.
    ///
    /// The iterator is resumable: when it returns `None` because it has caught up with
    /// the elements pushed so far, a later call to `next` yields any elements pushed in
    /// the meantime.
    pub fn bounded_iter(&self, start: usize, end: Option<usize>) -> StableListIterator<T> {
        StableListIterator {
            idx: start,
            end_idx: end,
            chunk: std::ptr::null(),
            list: self,
        }
    }

    /// Appends `item` to the end of the list.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned or if the list already holds `u32::MAX`
    /// elements.
    pub fn push(&self, item: T) {
        self.0.push(item)
    }

    /// Returns a reference to the element at `idx`, or `None` if it has not been pushed
    /// yet (test-only helper).
    #[cfg(test)]
    pub fn get(&self, idx: usize) -> Option<&T> {
        self.0.get(idx)
    }

    /// Returns the number of elements pushed so far.
    pub fn len(&self) -> usize {
        self.0.len()
    }

    /// Returns `true` if no element has been pushed yet.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}

impl<T> Default for StableList<T> {
    /// Equivalent to [`StableList::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// Shared state behind a [`StableList`] handle.
///
/// `list_lock` holds raw pointers to heap-allocated chunks of `CHUNK_SIZE` slots. Chunks
/// are only ever appended, so slot addresses never change. `last_global_idx` counts the
/// initialized slots and is advanced only after the corresponding slot has been written.
#[derive(Debug)]
struct StableListInner<T> {
    list_lock: RwLock<LinkedList<*const [UnsafeCell<MaybeUninit<T>>; CHUNK_SIZE]>>,
    last_global_idx: AtomicU32,
}

// SAFETY: the stored `T`s are reached only through the raw chunk pointers owned by this
// value, so sending it to another thread moves those `T`s, which requires `T: Send`.
unsafe impl<T> Send for StableListInner<T> where T: Send {}
// SAFETY: through a shared `&StableListInner<T>`, any thread can obtain `&T` to stored
// elements (requires `T: Sync`) and can `push` a `T` into storage that other threads then
// read (requires `T: Send`). Chunk-list mutation is serialized by `list_lock`, and slot
// writes are published by incrementing `last_global_idx` after the write.
unsafe impl<T> Sync for StableListInner<T> where T: Send + Sync {}
impl<T> StableListInner<T> {
    /// Creates an empty state with no chunks allocated and a length of 0.
    fn new() -> Self {
        let list: LinkedList<*const _> = LinkedList::new();
        StableListInner {
            list_lock: RwLock::new(list),
            last_global_idx: AtomicU32::new(0),
        }
    }

    /// Appends `item`, allocating a fresh chunk when the previous one is full.
    ///
    /// The whole operation runs under the write lock, so concurrent pushes are
    /// serialized. The element count is incremented only after the slot is written, so
    /// a reader that observes the new length through `len` also observes the value.
    ///
    /// # Panics
    ///
    /// Panics if the lock is poisoned or if the list already holds `u32::MAX` elements.
    fn push(&self, item: T) {
        let mut list = match self.list_lock.write() {
            Ok(lock) => lock,
            Err(_) => panic!("StableList's internal mutex has been poisoned"),
        };
        let global_idx = self.last_global_idx.load(Ordering::SeqCst) as usize;
        if global_idx == u32::MAX as usize {
            panic!("list is full, cannot index past 2^32");
        }
        let slot = global_idx % CHUNK_SIZE;
        if slot == 0 {
            // SAFETY: each element is an `UnsafeCell<MaybeUninit<T>>`, which may be left
            // uninitialized, so an uninitialized array of them is a valid value.
            #[allow(clippy::uninit_assumed_init)]
            let block: [UnsafeCell<MaybeUninit<T>>; CHUNK_SIZE] =
                unsafe { MaybeUninit::uninit().assume_init() };
            list.push_back(Box::into_raw(Box::new(block)));
        }
        let last_block = *list
            .back()
            .expect("no block in list even though we tried to add one");
        // SAFETY: `last_block` came from `Box::into_raw` and is not freed while `self` is
        // alive. Slot `global_idx` is not published yet (the count still equals
        // `global_idx`), so no reader references it, and the write lock excludes other
        // writers.
        unsafe { *(*last_block)[slot].get() = MaybeUninit::new(item) };
        self.last_global_idx.fetch_add(1, Ordering::SeqCst);
    }

    /// Returns the number of initialized (published) elements.
    fn len(&self) -> usize {
        self.last_global_idx.load(Ordering::SeqCst) as usize
    }

    /// Returns the pointer to the `idx`-th chunk, which holds elements
    /// `idx * CHUNK_SIZE .. (idx + 1) * CHUNK_SIZE`. Returns `None` if that chunk has not
    /// been allocated yet.
    ///
    /// # Safety
    ///
    /// The lookup itself is safe. Callers must only dereference the returned pointer while
    /// `self` is alive, and must only read slots whose global index is below `len()`.
    unsafe fn get_chunk(
        &self,
        idx: usize,
    ) -> Option<*const [UnsafeCell<MaybeUninit<T>>; CHUNK_SIZE]> {
        match self.list_lock.read() {
            Ok(lock) => lock.iter().nth(idx).copied(),
            Err(_) => panic!("StableList's internal mutex has been poisoned"),
        }
    }

    /// Returns a reference to element `idx`, or `None` if it has not been pushed yet.
    #[cfg(test)]
    fn get(&self, idx: usize) -> Option<&T> {
        if idx < self.last_global_idx.load(Ordering::SeqCst) as usize {
            let list = match self.list_lock.read() {
                Ok(lock) => lock,
                Err(_) => panic!("StableList's internal mutex has been poisoned"),
            };
            list.iter()
                .nth(idx / CHUNK_SIZE)
                // SAFETY: `idx` is below the published length, so the slot is initialized.
                .map(|ch| unsafe { unwrap_value(&(&**ch)[idx % CHUNK_SIZE]) })
        } else {
            None
        }
    }
}

impl<T> Drop for StableListInner<T> {
    /// Drops every initialized element and frees every chunk allocation.
    fn drop(&mut self) {
        let mut remaining = *self.last_global_idx.get_mut() as usize;
        // A poisoned lock only means a panic happened while the write lock was held. The
        // stored chunk pointers are still valid, so reclaim them rather than leak them.
        let chunks = match self.list_lock.get_mut() {
            Ok(list) => list,
            Err(poisoned) => poisoned.into_inner(),
        };
        while let Some(raw) = chunks.pop_front() {
            // SAFETY: every pointer in the list came from `Box::into_raw` in `push`, and it
            // is removed from the list before being reclaimed, so it is reclaimed once.
            let chunk = unsafe {
                Box::from_raw(raw as *mut [UnsafeCell<MaybeUninit<T>>; CHUNK_SIZE])
            };
            let initialized = remaining.min(CHUNK_SIZE);
            for cell in &chunk[..initialized] {
                // SAFETY: slots are filled in index order and the count is incremented only
                // after each write, so the first `initialized` slots hold live values.
                unsafe { std::ptr::drop_in_place((*cell.get()).as_mut_ptr()) };
            }
            remaining -= initialized;
            // `chunk` is freed here. `MaybeUninit` never drops its contents, so the values
            // dropped above are not dropped a second time.
        }
    }
}
/// Borrows the value stored in an initialized chunk slot.
///
/// # Safety
///
/// `cell` must contain an initialized `T`, and the slot must not be written for as long
/// as the returned reference is alive.
unsafe fn unwrap_value<T>(cell: &UnsafeCell<MaybeUninit<T>>) -> &T {
    // `UnsafeCell::get` never returns null, so the pointer can be dereferenced directly.
    let slot: *const MaybeUninit<T> = cell.get();
    &*(*slot).as_ptr()
}
/// Iterator over the elements of a [`StableList`], created by `bounded_iter`.
///
/// When the iterator catches up with the elements pushed so far it returns `None`
/// without moving, so calling `next` again after further pushes yields the new elements.
#[must_use = "iterators are lazy and do nothing unless consumed"]
#[derive(Debug)]
pub struct StableListIterator<'a, T> {
    /// While `chunk` is null, the index of the next element to yield. Afterwards, the
    /// index of the element yielded most recently.
    idx: usize,
    /// Exclusive upper bound on the indices to yield, if any.
    end_idx: Option<usize>,
    /// Chunk containing element `idx`. Null until the first element is yielded.
    chunk: *const [UnsafeCell<MaybeUninit<T>>; CHUNK_SIZE],
    /// The list being iterated.
    list: &'a StableList<T>,
}
impl<'a, T> Iterator for StableListIterator<'a, T> {
    type Item = &'a T;

    /// Yields the next element. Returns `None` if the end bound is reached or if every
    /// element pushed so far has already been yielded.
    ///
    /// A `None` caused by reaching the current end of the list leaves the iterator where
    /// it was, so later calls can yield elements pushed in the meantime.
    fn next(&mut self) -> Option<Self::Item> {
        if self.chunk.is_null() {
            // Nothing yielded yet: `idx` is the element to yield now. Both checks must be
            // inclusive. Otherwise an empty bounded range would still yield `start`, and a
            // start equal to the current length would read an unwritten slot.
            if matches!(self.end_idx, Some(end) if self.idx >= end) {
                return None;
            }
            if self.idx >= self.list.len() {
                return None;
            }
            // SAFETY: only a pointer lookup; it is dereferenced below while `list` is alive.
            self.chunk = unsafe { self.list.0.get_chunk(self.idx / CHUNK_SIZE) }?;
            // SAFETY: `idx < len`, so the slot is initialized, and the chunk outlives `'a`.
            return Some(unsafe { unwrap_value(&(&*self.chunk)[self.idx % CHUNK_SIZE]) });
        }

        let next_idx = self.idx + 1;
        if matches!(self.end_idx, Some(end) if next_idx >= end) {
            return None;
        }
        if next_idx >= self.list.len() {
            return None;
        }
        if next_idx % CHUNK_SIZE == 0 {
            // The next element is the first slot of the following chunk.
            // SAFETY: only a pointer lookup; the chunk exists because `next_idx < len`.
            self.chunk = unsafe { self.list.0.get_chunk(next_idx / CHUNK_SIZE) }?;
        }
        self.idx = next_idx;
        // SAFETY: `idx < len`, so the slot is initialized, and `chunk` contains `idx`.
        Some(unsafe { unwrap_value(&(&*self.chunk)[self.idx % CHUNK_SIZE]) })
    }
}
#[cfg(test)]
mod test {
    use super::*;

    /// Pushes `100..100 + count` into a fresh list and checks that every element reads
    /// back at its own index.
    fn assert_round_trip(count: usize) {
        let list = StableList::new();
        assert!(list.get(0).is_none());
        (0..count).for_each(|i| list.push(100 + i));
        for (idx, want) in (100..100 + count).enumerate() {
            assert_eq!(list.get(idx), Some(&want));
        }
    }

    #[test]
    fn push_and_check_single_item() {
        let list = StableList::new();
        assert!(list.get(0).is_none());
        list.push(1002);
        assert_eq!(Some(&1002), list.get(0));
        let mut it = list.iter();
        assert_eq!(Some(&1002), it.next());
    }

    #[test]
    fn push_and_check_full_chunk() {
        assert_round_trip(CHUNK_SIZE);
    }

    #[test]
    fn push_and_check_multiple_chunks() {
        assert_round_trip(CHUNK_SIZE * 2);
    }

    #[test]
    fn populate_and_iterate_simple() {
        let list = StableList::new();
        // Created before any push: iteration must still see everything pushed later.
        let iter = list.iter();
        let total = CHUNK_SIZE * 2 + 1;
        (0..total).for_each(|i| list.push(i * 10));
        assert_eq!(total, list.len());
        let mut seen = 0;
        for (pos, val) in iter.enumerate() {
            assert_eq!(pos * 10, *val);
            seen = pos + 1;
        }
        assert_eq!(total, seen);
    }

    #[test]
    fn iterator_resumption() {
        let list = StableList::new();
        let mut iter = list.iter();
        assert_eq!(0, list.len());
        assert!(iter.next().is_none());
        list.push(1000);
        assert_eq!(1, list.len());
        assert_eq!(Some(&1000), iter.next());
    }

    #[test]
    fn multiple_iterators() {
        let list = StableList::<i32>::new();
        let (mut first, mut second) = (list.iter(), list.iter());
        for value in 100..200 {
            list.push(value);
            let pair = (first.next(), second.next());
            assert_eq!((Some(&value), Some(&value)), pair);
        }
    }

    #[test]
    fn bounded_iterator_resumption() {
        let list = StableList::new();
        let mut iter = list.bounded_iter(0, Some(5));
        assert!(iter.next().is_none());
        (1000..1010).for_each(|v| list.push(v));
        assert_eq!(10, list.len());
        for want in 1000..1005 {
            assert_eq!(Some(&want), iter.next());
        }
        assert!(iter.next().is_none());
    }

    #[test]
    fn bounded_iterator_non_zero_start() {
        let list = StableList::new();
        let mut iter = list.bounded_iter(5, None);
        assert!(iter.next().is_none());
        (1000..1010).for_each(|v| list.push(v));
        for want in 1005..1010 {
            assert_eq!(Some(&want), iter.next());
        }
        assert!(iter.next().is_none());
    }

    #[test]
    fn bounded_iterator_boundary_start() {
        let list = StableList::new();
        (0..CHUNK_SIZE * 2).for_each(|i| list.push(i * 2));
        let expected: Vec<usize> = ((CHUNK_SIZE - 1)..=(CHUNK_SIZE + 1))
            .map(|v| v * 2)
            .collect();
        // Start just before, exactly at, and just after the first chunk boundary.
        let from_lower: Vec<usize> = list
            .bounded_iter(CHUNK_SIZE - 1, None)
            .take(3)
            .copied()
            .collect();
        let from_middle: Vec<usize> = list
            .bounded_iter(CHUNK_SIZE, None)
            .take(2)
            .copied()
            .collect();
        let from_upper = list.bounded_iter(CHUNK_SIZE + 1, None).next();
        assert_eq!(expected, from_lower);
        assert_eq!(expected[1..], from_middle);
        assert_eq!(expected.get(2), from_upper);
    }
}