use std::alloc::{self, Layout};
use std::cell::UnsafeCell;
use std::marker::PhantomData;
use std::mem::{align_of, size_of};
use std::ptr::{self, NonNull};
use std::sync::Arc;
use std::sync::atomic::{AtomicPtr, AtomicUsize, Ordering};
/// Maximum number of levels in any node's tower (and in the head sentinel).
const MAX_HEIGHT: usize = 16;
/// Inverse promotion rate: roughly one in `BRANCHING` nodes that reach a level also reach
/// the next one (see `SkipListWriter::random_height`).
const BRANCHING: u64 = 4;
/// Minimum size in bytes (1 MiB) of each chunk the node arena requests from the global
/// allocator; larger requests get a chunk sized to fit them.
const CHUNK_SIZE: usize = 1024 * 1024;
/// Header of a skip-list node.
///
/// Only the key is declared as a field. The node's tower of forward links
/// (`AtomicPtr<Node<K>>`, one per level) is stored directly after the header in the same
/// allocation, starting at [`tower_offset`]. `#[repr(C)]` fixes the header layout.
#[repr(C)]
struct Node<K> {
    key: K,
}

/// Byte offset from the start of a node to its first tower slot: the header size rounded
/// up to the alignment of a link.
#[inline]
fn tower_offset<K>() -> usize {
    size_of::<Node<K>>().next_multiple_of(align_of::<AtomicPtr<Node<K>>>())
}

/// Allocation layout of a node whose tower has `height` slots.
///
/// The alignment is the stricter of the header's and a link's alignment, so both the key
/// and every tower slot are properly aligned.
///
/// # Panics
/// Panics with "valid node layout" if the total size is not a valid `Layout` size.
#[inline]
fn node_layout<K>(height: usize) -> Layout {
    let header = Layout::new::<Node<K>>();
    let link = Layout::new::<AtomicPtr<Node<K>>>();
    let total = tower_offset::<K>() + link.size() * height;
    Layout::from_size_align(total, header.align().max(link.align())).expect("valid node layout")
}

/// Returns a pointer to the first tower slot of `node`.
///
/// # Safety
/// `node` must point to an allocation created with `node_layout::<K>(h)` for some `h`, so
/// that offsetting it by [`tower_offset`] stays within that allocation.
#[inline]
unsafe fn tower<K>(node: *const Node<K>) -> *const AtomicPtr<Node<K>> {
    // SAFETY: the caller guarantees the allocation extends at least to the tower offset.
    unsafe { node.cast::<u8>().add(tower_offset::<K>()).cast() }
}
/// Bump allocator that hands out raw, uninitialised memory carved from large chunks.
///
/// Individual allocations are never freed or reused. Every chunk goes back to the global
/// allocator only when the arena is dropped. The arena never runs destructors for what is
/// stored in it.
struct Arena {
    /// Every chunk obtained from the global allocator, with the exact layout needed to free it.
    chunks: Vec<(NonNull<u8>, Layout)>,
    /// Next free byte in the current chunk (null before the first allocation).
    cursor: *mut u8,
    /// One past the last byte of the current chunk (null before the first allocation).
    end: *mut u8,
}

impl Arena {
    /// Creates an empty arena; no memory is requested until the first [`Arena::alloc`].
    fn new() -> Self {
        Self {
            chunks: Vec::new(),
            cursor: ptr::null_mut(),
            end: ptr::null_mut(),
        }
    }

    /// Number of bytes needed to round `addr` up to a multiple of `align`.
    ///
    /// `align` must be a power of two, which `Layout` guarantees.
    #[inline]
    fn padding_for(addr: usize, align: usize) -> usize {
        addr.wrapping_neg() & (align - 1)
    }

    /// Returns the padding to insert before the allocation if `layout` fits in the current
    /// chunk, or `None` if a new chunk is needed.
    ///
    /// The check compares against the remaining byte count instead of computing
    /// `aligned + size`, so it cannot overflow.
    #[inline]
    fn fit(&self, layout: Layout) -> Option<usize> {
        if self.cursor.is_null() {
            return None;
        }
        // Invariant: `cursor <= end` inside the current chunk.
        let remaining = self.end as usize - self.cursor as usize;
        let pad = Self::padding_for(self.cursor as usize, layout.align());
        (pad <= remaining && layout.size() <= remaining - pad).then_some(pad)
    }

    /// Returns a pointer to `layout.size()` uninitialised bytes aligned to `layout.align()`.
    ///
    /// The memory stays valid, and is never handed out again, until the arena is dropped.
    ///
    /// # Safety
    /// The returned pointer must not be used after the arena is dropped. The function is
    /// `unsafe` because it hands out raw, uninitialised memory.
    unsafe fn alloc(&mut self, layout: Layout) -> *mut u8 {
        let pad = match self.fit(layout) {
            Some(pad) => pad,
            None => {
                // SAFETY: forwarded from this function's contract.
                unsafe { self.grow(layout) };
                // Fresh chunks are aligned to at least `layout.align()`, so no padding is needed.
                debug_assert_eq!(Self::padding_for(self.cursor as usize, layout.align()), 0);
                0
            }
        };
        // SAFETY: `fit` (or the fresh chunk, whose size is >= `layout.size()`) guarantees
        // that `pad + layout.size()` bytes are available after `cursor` in the current chunk.
        // Both offsets therefore stay within, or one past the end of, that allocation.
        // Deriving the result from `cursor` via `add` keeps the chunk's pointer provenance,
        // unlike rebuilding it from an integer address.
        let start = unsafe { self.cursor.add(pad) };
        self.cursor = unsafe { start.add(layout.size()) };
        start
    }

    /// Starts a new chunk large enough for `layout` and makes it current.
    ///
    /// The chunk holds at least `CHUNK_SIZE` bytes, or the next power of two of the request
    /// if that is larger. It is aligned to at least 64 bytes. Any space left in the previous
    /// chunk is abandoned.
    ///
    /// # Safety
    /// No extra requirements; `unsafe` for symmetry with [`Arena::alloc`].
    #[cold]
    unsafe fn grow(&mut self, layout: Layout) {
        let align = layout.align().max(64);
        let size = CHUNK_SIZE.max(layout.size().next_power_of_two());
        let chunk_layout = Layout::from_size_align(size, align).expect("valid chunk layout");
        // SAFETY: `chunk_layout` has a non-zero size (`next_power_of_two` is at least 1).
        let raw = unsafe { alloc::alloc(chunk_layout) };
        let Some(chunk) = NonNull::new(raw) else {
            alloc::handle_alloc_error(chunk_layout);
        };
        self.chunks.push((chunk, chunk_layout));
        self.cursor = raw;
        // SAFETY: `raw` points to an allocation of exactly `size` bytes.
        self.end = unsafe { raw.add(size) };
    }
}

impl Drop for Arena {
    /// Returns every chunk to the global allocator.
    fn drop(&mut self) {
        for &(chunk, layout) in &self.chunks {
            // SAFETY: each chunk came from `alloc::alloc` with exactly this layout and is
            // freed only here, once.
            unsafe { alloc::dealloc(chunk.as_ptr(), layout) };
        }
    }
}
struct SkipListCore<K> {
head: Box<[AtomicPtr<Node<K>>]>,
arena: UnsafeCell<Arena>,
height: AtomicUsize,
len: AtomicUsize,
}
unsafe impl<K: Send + Sync> Send for SkipListCore<K> {}
unsafe impl<K: Send + Sync> Sync for SkipListCore<K> {}
impl<K> SkipListCore<K> {
fn new() -> Self {
let head = (0..MAX_HEIGHT)
.map(|_| AtomicPtr::new(ptr::null_mut()))
.collect::<Vec<_>>()
.into_boxed_slice();
Self {
head,
arena: UnsafeCell::new(Arena::new()),
height: AtomicUsize::new(1),
len: AtomicUsize::new(0),
}
}
#[inline]
fn next_slot(&self, node: *const Node<K>, level: usize) -> &AtomicPtr<Node<K>> {
if node.is_null() {
&self.head[level]
} else {
unsafe { &*tower(node).add(level) }
}
}
#[inline]
fn len(&self) -> usize {
self.len.load(Ordering::Acquire)
}
}
impl<K> Drop for SkipListCore<K> {
fn drop(&mut self) {
let mut node = self.head[0].load(Ordering::Relaxed);
while !node.is_null() {
let next = unsafe { (*tower(node)).load(Ordering::Relaxed) };
unsafe { ptr::drop_in_place(ptr::addr_of_mut!((*node).key)) };
node = next;
}
}
}
/// Creates an empty skip list and returns its unique writer together with a reader handle.
///
/// Both handles share the same underlying list. Its memory is released when the last
/// handle is dropped. The writer's height generator starts from a fixed seed, so every
/// list draws the same sequence of node heights.
pub fn new_skiplist<K: Ord + Send + Sync>() -> (SkipListWriter<K>, SkipListReader<K>) {
    let shared = Arc::new(SkipListCore::new());
    let reader = SkipListReader {
        core: Arc::clone(&shared),
    };
    let writer = SkipListWriter {
        core: shared,
        rng: 0x9E37_79B9_7F4A_7C15,
    };
    (writer, reader)
}
/// The single write handle of a skip list created by [`new_skiplist`].
///
/// It is not `Clone`. Holding the only writer, through `&mut self`, is what lets `insert`
/// use the shared arena without synchronisation.
pub struct SkipListWriter<K> {
    core: Arc<SkipListCore<K>>,
    /// State of the xorshift generator used to draw node heights.
    rng: u64,
}

impl<K: Ord> SkipListWriter<K> {
    /// Draws a tower height in `1..=MAX_HEIGHT`.
    ///
    /// Each zero base-`BRANCHING` digit at the low end of a fresh random value adds one
    /// level, so taller towers are geometrically rarer.
    fn random_height(&mut self) -> usize {
        // Advance the xorshift (13, 7, 17) generator.
        let mut state = self.rng;
        state ^= state << 13;
        state ^= state >> 7;
        state ^= state << 17;
        self.rng = state;

        let mut digits = state;
        let mut height = 1;
        while height < MAX_HEIGHT && digits.is_multiple_of(BRANCHING) {
            digits /= BRANCHING;
            height += 1;
        }
        height
    }

    /// Inserts `key`, keeping level 0 sorted in ascending order.
    ///
    /// Equal keys are all kept; the new one goes before any existing equal keys. Readers
    /// may run concurrently. The node is fully initialised before it is published, and it
    /// is linked bottom-up with `Release` stores, so a reader that reaches it on some level
    /// can follow its links on every lower level.
    pub fn insert(&mut self, key: K) {
        let top = self.core.height.load(Ordering::Relaxed);
        let preds = self.find_predecessors(&key, top);
        let height = self.random_height();
        let node = self.build_node(key, height, &preds);

        if height > top {
            self.core.height.store(height, Ordering::Release);
        }
        for (level, &pred) in preds[..height].iter().enumerate() {
            self.core.next_slot(pred, level).store(node, Ordering::Release);
        }
        self.core.len.fetch_add(1, Ordering::Release);
    }

    /// For every level below `top`, finds the last node whose key is `< key`.
    ///
    /// Null means the head sentinel. Levels at or above `top` are left null, because no
    /// node reaches them yet.
    fn find_predecessors(&self, key: &K, top: usize) -> [*const Node<K>; MAX_HEIGHT] {
        let mut preds = [ptr::null::<Node<K>>(); MAX_HEIGHT];
        let mut cursor: *const Node<K> = ptr::null();
        for level in (0..top).rev() {
            loop {
                let next = self.core.next_slot(cursor, level).load(Ordering::Acquire);
                // SAFETY: non-null links point to initialised nodes that live in the arena.
                let advance = !next.is_null() && unsafe { (*next).key < *key };
                if !advance {
                    break;
                }
                cursor = next;
            }
            preds[level] = cursor;
        }
        preds
    }

    /// Allocates a node with `height` tower slots, moves `key` into it, and points each
    /// slot at the current successor of `preds[level]`.
    ///
    /// The node is not yet reachable by readers when this returns.
    fn build_node(
        &mut self,
        key: K,
        height: usize,
        preds: &[*const Node<K>; MAX_HEIGHT],
    ) -> *mut Node<K> {
        let layout = node_layout::<K>(height);
        // SAFETY: `&mut self` on the unique writer makes this the only access to the arena.
        // `layout` covers the key plus `height` tower slots, so every write below stays
        // inside the fresh allocation.
        unsafe {
            let node = (*self.core.arena.get()).alloc(layout).cast::<Node<K>>();
            ptr::addr_of_mut!((*node).key).write(key);
            let links = tower::<K>(node).cast_mut();
            for (level, &pred) in preds[..height].iter().enumerate() {
                let succ = self.core.next_slot(pred, level).load(Ordering::Acquire);
                links.add(level).write(AtomicPtr::new(succ));
            }
            node
        }
    }
}
#[derive(Clone)]
pub struct SkipListReader<K> {
core: Arc<SkipListCore<K>>,
}
impl<K: Ord> SkipListReader<K> {
pub fn upper_bound_with<R>(&self, target: &K, f: impl FnOnce(&K) -> R) -> Option<R> {
let node = self.find_le(target);
if node.is_null() {
None
} else {
Some(f(unsafe { &(*node).key }))
}
}
fn find_le(&self, target: &K) -> *const Node<K> {
let height = self.core.height.load(Ordering::Acquire);
let mut pred: *const Node<K> = ptr::null();
for level in (0..height).rev() {
loop {
let next = self.core.next_slot(pred, level).load(Ordering::Acquire);
if !next.is_null() && unsafe { (*next).key <= *target } {
pred = next;
} else {
break;
}
}
}
pred
}
fn lower_bound(&self, start: &K) -> *const Node<K> {
let height = self.core.height.load(Ordering::Acquire);
let mut pred: *const Node<K> = ptr::null();
for level in (0..height).rev() {
loop {
let next = self.core.next_slot(pred, level).load(Ordering::Acquire);
if !next.is_null() && unsafe { (*next).key < *start } {
pred = next;
} else {
break;
}
}
}
self.core.next_slot(pred, 0).load(Ordering::Acquire)
}
pub fn iter(&self) -> Iter<'_, K> {
Iter {
node: self.core.head[0].load(Ordering::Acquire),
_marker: PhantomData,
}
}
pub fn range_from(&self, start: &K) -> Iter<'_, K> {
Iter {
node: self.lower_bound(start),
_marker: PhantomData,
}
}
pub fn front_with<R>(&self, f: impl FnOnce(&K) -> R) -> Option<R> {
let node = self.core.head[0].load(Ordering::Acquire);
if node.is_null() {
None
} else {
Some(f(unsafe { &(*node).key }))
}
}
pub fn len(&self) -> usize {
self.core.len()
}
}
pub struct Iter<'a, K> {
node: *const Node<K>,
_marker: PhantomData<&'a SkipListReader<K>>,
}
impl<'a, K> Iterator for Iter<'a, K> {
type Item = &'a K;
fn next(&mut self) -> Option<&'a K> {
if self.node.is_null() {
return None;
}
let node = unsafe { &*self.node };
self.node = unsafe { (*tower(self.node)).load(Ordering::Acquire) };
Some(&node.key)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::AtomicBool;
    use std::thread;

    /// Snapshot of every key currently visible to `reader`, in iteration order.
    fn collect(reader: &SkipListReader<i64>) -> Vec<i64> {
        reader.iter().copied().collect()
    }

    #[test]
    fn test_insert_keeps_sorted_order() {
        let (mut w, r) = new_skiplist::<i64>();
        for k in [5, 1, 9, 3, 7, 2, 8, 4, 6, 0] {
            w.insert(k);
        }
        assert_eq!(collect(&r), vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9]);
        assert_eq!(r.len(), 10);
    }

    #[test]
    fn test_empty() {
        let (_w, r) = new_skiplist::<i64>();
        assert_eq!(r.len(), 0);
        assert_eq!(collect(&r), Vec::<i64>::new());
        assert_eq!(r.upper_bound_with(&5, |k| *k), None);
        assert_eq!(r.front_with(|k| *k), None);
        assert_eq!(r.range_from(&0).count(), 0);
    }

    #[test]
    fn test_upper_bound_le() {
        let (mut w, r) = new_skiplist::<i64>();
        for k in [10, 20, 30, 40] {
            w.insert(k);
        }
        assert_eq!(r.upper_bound_with(&20, |k| *k), Some(20));
        assert_eq!(r.upper_bound_with(&40, |k| *k), Some(40));
        assert_eq!(r.upper_bound_with(&25, |k| *k), Some(20));
        assert_eq!(r.upper_bound_with(&39, |k| *k), Some(30));
        assert_eq!(r.upper_bound_with(&999, |k| *k), Some(40));
        assert_eq!(r.upper_bound_with(&5, |k| *k), None);
    }

    #[test]
    fn test_front_and_range_from() {
        let (mut w, r) = new_skiplist::<i64>();
        for k in [3, 1, 4, 1_000, 2] {
            w.insert(k);
        }
        assert_eq!(r.front_with(|k| *k), Some(1));
        assert_eq!(
            r.range_from(&3).copied().collect::<Vec<_>>(),
            vec![3, 4, 1_000]
        );
        assert_eq!(
            r.range_from(&0).copied().collect::<Vec<_>>(),
            vec![1, 2, 3, 4, 1_000]
        );
        assert_eq!(r.range_from(&2_000).count(), 0);
        assert_eq!(r.range_from(&5).copied().collect::<Vec<_>>(), vec![1_000]);
    }

    #[test]
    fn test_composite_key_dup_values() {
        let (mut w, r) = new_skiplist::<(i64, u64)>();
        for key in [(7, 0), (3, 1), (7, 2), (3, 0), (7, 1)] {
            w.insert(key);
        }
        let all: Vec<_> = r.iter().copied().collect();
        assert_eq!(all, vec![(3, 0), (3, 1), (7, 0), (7, 1), (7, 2)]);
        assert_eq!(r.upper_bound_with(&(7, 1), |k| *k), Some((7, 1)));
        assert_eq!(r.upper_bound_with(&(3, 5), |k| *k), Some((3, 1)));
    }

    #[test]
    fn test_duplicate_keys_are_retained() {
        let (mut w, r) = new_skiplist::<i64>();
        for k in [5, 5, 3, 5] {
            w.insert(k);
        }
        assert_eq!(collect(&r), vec![3, 5, 5, 5]);
        assert_eq!(r.len(), 4);
        assert_eq!(r.upper_bound_with(&5, |k| *k), Some(5));
        assert_eq!(r.upper_bound_with(&4, |k| *k), Some(3));
        assert_eq!(r.range_from(&4).copied().collect::<Vec<_>>(), vec![5, 5, 5]);
    }

    #[test]
    fn test_string_keys_drop() {
        let (mut w, r) = new_skiplist::<String>();
        for s in ["delta", "alpha", "charlie", "bravo"] {
            w.insert(s.to_string());
        }
        assert_eq!(
            r.iter().cloned().collect::<Vec<_>>(),
            vec!["alpha", "bravo", "charlie", "delta"]
        );
        assert_eq!(
            r.upper_bound_with(&"c".to_string(), |k| k.clone()),
            Some("bravo".to_string())
        );
        // Dropping both handles releases the last `Arc`, which drops every key.
        drop(w);
        drop(r);
    }

    #[test]
    fn test_many_inserts_force_chunk_growth() {
        let (mut w, r) = new_skiplist::<i64>();
        const N: i64 = 200_000;
        for k in (0..N).rev() {
            w.insert(k);
        }
        assert_eq!(r.len(), N as usize);
        assert!(r.iter().copied().eq(0..N));
        assert_eq!(r.upper_bound_with(&(N - 1), |k| *k), Some(N - 1));
    }

    #[test]
    fn test_concurrent_single_writer_many_readers() {
        const N: i64 = 50_000;
        let (mut w, r) = new_skiplist::<i64>();
        let done = Arc::new(AtomicBool::new(false));
        let readers: Vec<_> = (0..4)
            .map(|_| {
                let r = r.clone();
                let done = done.clone();
                thread::spawn(move || {
                    let mut max_seen = -1;
                    while !done.load(Ordering::Acquire) {
                        if let Some(top) = r.upper_bound_with(&N, |k| *k) {
                            assert!((0..N).contains(&top));
                            assert!(top >= max_seen, "visibility went backwards");
                            max_seen = top;
                            assert_eq!(r.upper_bound_with(&top, |k| *k), Some(top));
                        }
                    }
                    max_seen
                })
            })
            .collect();
        for k in 0..N {
            w.insert(k);
        }
        done.store(true, Ordering::Release);
        for h in readers {
            let max_seen = h.join().expect("reader thread panicked");
            assert!(max_seen < N);
        }
        assert_eq!(r.len(), N as usize);
        assert_eq!(collect(&r), (0..N).collect::<Vec<_>>());
    }
}