use std::cell::UnsafeCell;
use std::cmp::Ordering as CmpOrdering;
use std::ops::RangeBounds;
use std::ptr;
use std::sync::atomic::{AtomicPtr, AtomicUsize, Ordering};
/// Maximum number of forward-pointer levels a node tower can have.
const MAX_HEIGHT: usize = 12;
/// Capacity in bytes of the first arena block.
const ARENA_INITIAL_BLOCK: usize = 4096;
/// Upper bound on the capacity of any single arena block (1 MiB).
const ARENA_MAX_BLOCK: usize = 1 << 20;
/// Bump allocator backed by a list of progressively larger byte blocks.
///
/// NOTE(review): `blocks` and `current_offset` live in `UnsafeCell`s with no
/// internal synchronization, so `alloc` is only sound under an external
/// single-writer discipline — confirm against callers.
struct Arena {
    // Every block ever allocated. Pointers into a block stay stable because a
    // block's heap buffer is never reallocated (new blocks are pushed instead).
    blocks: UnsafeCell<Vec<Vec<u8>>>,
    // Bump offset within the *last* block of `blocks`.
    current_offset: UnsafeCell<usize>,
    // Total bytes handed out (including alignment padding); statistics only.
    bytes_allocated: AtomicUsize,
}
impl Arena {
    /// Creates an empty arena; the first block is allocated lazily by `alloc`.
    fn new() -> Self {
        Self {
            blocks: UnsafeCell::new(Vec::new()),
            current_offset: UnsafeCell::new(0),
            bytes_allocated: AtomicUsize::new(0),
        }
    }
    /// Bump-allocates `size` bytes aligned to `align`, returning a raw pointer
    /// into one of the arena's blocks.
    ///
    /// # Safety
    /// - Must not be called concurrently from multiple threads: it mutates the
    ///   `UnsafeCell` state without synchronization.
    /// - `align` is assumed to be a power of two (required by the mask below)
    ///   — TODO confirm all call sites guarantee this.
    /// - The returned memory is uninitialized; the caller must `ptr::write`
    ///   before reading through the pointer.
    unsafe fn alloc(&self, size: usize, align: usize) -> *mut u8 {
        unsafe {
            let blocks = &mut *self.blocks.get();
            let offset = &mut *self.current_offset.get();
            // Fast path: carve the allocation out of the current block.
            if let Some(block) = blocks.last() {
                let base = block.as_ptr() as usize;
                // Round the bump pointer up to the requested alignment.
                let aligned = (base + *offset + align - 1) & !(align - 1);
                let new_offset = aligned - base + size;
                if new_offset <= block.capacity() {
                    // `actual` includes any alignment padding consumed.
                    let actual = new_offset - *offset; *offset = new_offset;
                    self.bytes_allocated.fetch_add(actual, Ordering::Relaxed);
                    return aligned as *mut u8;
                }
            }
            // Slow path: start a new block, doubling capacity up to the cap,
            // but never smaller than what this allocation needs.
            let prev_cap = blocks.last().map_or(0, Vec::capacity);
            let block_size = if prev_cap == 0 {
                ARENA_INITIAL_BLOCK
            } else {
                (prev_cap * 2).min(ARENA_MAX_BLOCK)
            }
            .max(size + align);
            let block = Vec::<u8>::with_capacity(block_size);
            let base = block.as_ptr() as usize;
            let aligned = (base + align - 1) & !(align - 1);
            // Reset the bump offset relative to the new block's base.
            let actual = aligned - base + size; *offset = actual;
            self.bytes_allocated.fetch_add(actual, Ordering::Relaxed);
            // Pushing moves the `Vec` header, not its heap buffer, so the
            // pointer computed above stays valid.
            blocks.push(block);
            aligned as *mut u8
        }
    }
    /// Allocates uninitialized, correctly aligned storage for one `Node<K, V>`.
    ///
    /// # Safety
    /// Same contract as `alloc`; the caller must fully initialize the node
    /// with `ptr::write` before use.
    unsafe fn alloc_node<K, V>(&self) -> *mut Node<K, V> {
        unsafe {
            self.alloc(
                std::mem::size_of::<Node<K, V>>(),
                std::mem::align_of::<Node<K, V>>(),
            ) as *mut Node<K, V>
        }
    }
}
/// A skip-list node allocated from the arena.
#[repr(C)]
struct Node<K, V> {
    key: K,
    value: V,
    // Number of levels (1..=MAX_HEIGHT) this node participates in; only the
    // first `height` entries of `next` are ever linked.
    height: u8,
    // Per-level forward pointers.
    next: [AtomicPtr<Node<K, V>>; MAX_HEIGHT],
    // Back pointer at level 0, exposed through `node_prev0` for reverse walks.
    prev0: AtomicPtr<Node<K, V>>,
}
/// An arena-backed skip list with atomic forward pointers.
///
/// NOTE(review): despite the name, `insert` mutates `UnsafeCell` state without
/// synchronization; the tests only exercise concurrent *reads* after
/// single-threaded inserts — confirm the intended single-writer contract.
pub struct ConcurrentSkipList<K: Ord + Clone, V: Clone> {
    // Per-level head pointers; the list has no sentinel head node.
    head: [AtomicPtr<Node<K, V>>; MAX_HEIGHT],
    // Last node at level 0, for O(1) `tail_ptr`.
    tail: AtomicPtr<Node<K, V>>,
    len: AtomicUsize,
    // Highest level currently in use; searches descend from here.
    max_height: AtomicUsize,
    // Every node ever allocated, so `Drop` can run K/V destructors.
    all_nodes: UnsafeCell<Vec<*mut Node<K, V>>>,
    arena: Arena,
}
// SAFETY: all node storage is owned by the contained arena and raw pointers
// are only dereferenced while `self` is alive. NOTE(review): `Sync` also
// relies on `insert` (which mutates `UnsafeCell` contents through `&self`)
// never racing with other operations — confirm callers uphold a
// single-writer / multi-reader discipline.
unsafe impl<K: Ord + Clone + Send, V: Clone + Send> Send for ConcurrentSkipList<K, V> {}
unsafe impl<K: Ord + Clone + Send + Sync, V: Clone + Send + Sync> Sync
    for ConcurrentSkipList<K, V>
{
}
impl<K: Ord + Clone, V: Clone> ConcurrentSkipList<K, V> {
    /// Creates an empty list.
    pub fn new() -> Self {
        Self {
            head: std::array::from_fn(|_| AtomicPtr::new(ptr::null_mut())),
            tail: AtomicPtr::new(ptr::null_mut()),
            len: AtomicUsize::new(0),
            max_height: AtomicUsize::new(1),
            all_nodes: UnsafeCell::new(Vec::new()),
            arena: Arena::new(),
        }
    }
    /// Inserts `key`/`value`. Duplicate keys are kept: the search stops at the
    /// first node with key >= `key`, so a new duplicate is linked *before*
    /// existing equal keys.
    ///
    /// NOTE(review): mutates `UnsafeCell` state without locks; not safe for
    /// concurrent writers — confirm single-writer usage.
    pub fn insert(&self, key: K, value: V) {
        let height = random_height();
        let cur_max = self.max_height.load(Ordering::Relaxed);
        if height > cur_max {
            self.max_height.store(height, Ordering::Relaxed);
        }
        // `prev[level]` = rightmost node at `level` with key < `key`
        // (null when the new node becomes that level's head).
        let mut prev: [*mut Node<K, V>; MAX_HEIGHT] = [ptr::null_mut(); MAX_HEIGHT];
        let max_h = height.max(cur_max);
        for level in (0..max_h).rev() {
            // Resume from the predecessor found one level up, if any.
            let mut current: *mut Node<K, V> = if level + 1 < max_h && !prev[level + 1].is_null() {
                prev[level + 1]
            } else {
                ptr::null_mut()
            };
            if current.is_null() {
                // Walk from the level head until the next key is >= `key`.
                let mut next = self.head[level].load(Ordering::Acquire);
                while !next.is_null() {
                    let node = unsafe { &*next };
                    if node.key >= key {
                        break;
                    }
                    current = next;
                    next = node.next[level].load(Ordering::Acquire);
                }
                prev[level] = current;
            } else {
                // Same walk, starting mid-list from `current`.
                let mut next = unsafe { &*current }.next[level].load(Ordering::Acquire);
                while !next.is_null() {
                    let node = unsafe { &*next };
                    if node.key >= key {
                        break;
                    }
                    current = next;
                    next = node.next[level].load(Ordering::Acquire);
                }
                prev[level] = current;
            }
        }
        let new_node: *mut Node<K, V> = unsafe { self.arena.alloc_node() };
        // SAFETY: `alloc_node` returned aligned uninitialized storage;
        // `ptr::write` initializes it without dropping the garbage old value.
        unsafe {
            ptr::write(
                new_node,
                Node {
                    key,
                    value,
                    height: height as u8,
                    next: std::array::from_fn(|_| AtomicPtr::new(ptr::null_mut())),
                    prev0: AtomicPtr::new(ptr::null_mut()),
                },
            );
        }
        // Record the node so `Drop` can run its destructor later.
        unsafe {
            (*self.all_nodes.get()).push(new_node);
        }
        let new_ref = unsafe { &*new_node };
        // Splice the node in at each of its levels. The node's own fields are
        // set (Relaxed) before the predecessor's pointer is published (Release).
        for (level, &prev_node) in prev[..height].iter().enumerate() {
            if prev_node.is_null() {
                // New head at this level.
                let old_head = self.head[level].load(Ordering::Relaxed);
                new_ref.next[level].store(old_head, Ordering::Relaxed);
                if level == 0 {
                    new_ref.prev0.store(ptr::null_mut(), Ordering::Relaxed);
                    if !old_head.is_null() {
                        unsafe { &*old_head }
                            .prev0
                            .store(new_node, Ordering::Release);
                    }
                }
                self.head[level].store(new_node, Ordering::Release);
            } else {
                // Link between `prev_node` and its old successor.
                let p = unsafe { &*prev_node };
                let old_next = p.next[level].load(Ordering::Relaxed);
                new_ref.next[level].store(old_next, Ordering::Relaxed);
                if level == 0 {
                    new_ref.prev0.store(prev_node, Ordering::Relaxed);
                    if !old_next.is_null() {
                        unsafe { &*old_next }
                            .prev0
                            .store(new_node, Ordering::Release);
                    }
                }
                p.next[level].store(new_node, Ordering::Release);
            }
        }
        // A node with no level-0 successor is the new tail.
        if new_ref.next[0].load(Ordering::Relaxed).is_null() {
            self.tail.store(new_node, Ordering::Release);
        }
        self.len.fetch_add(1, Ordering::Relaxed);
    }
    /// Returns a clone of the value stored under `key`, or `None` if absent.
    pub fn get<Q>(&self, key: &Q) -> Option<V>
    where
        K: std::borrow::Borrow<Q>,
        Q: Ord + ?Sized,
    {
        let max_h = self.max_height.load(Ordering::Acquire);
        // `current` = rightmost node seen so far with key < `key`;
        // null means we are still at the level heads.
        let mut current: *const Node<K, V> = ptr::null();
        for level in (0..max_h).rev() {
            let mut next = if current.is_null() {
                self.head[level].load(Ordering::Acquire)
            } else {
                unsafe { &*current }.next[level].load(Ordering::Acquire)
            };
            while !next.is_null() {
                let n = unsafe { &*next };
                match n.key.borrow().cmp(key) {
                    CmpOrdering::Less => {
                        current = next;
                        next = n.next[level].load(Ordering::Acquire);
                    }
                    CmpOrdering::Equal => return Some(n.value.clone()),
                    CmpOrdering::Greater => break,
                }
            }
        }
        None
    }
    /// Returns clones of the first entry with key >= `target`, or `None`.
    pub fn lower_bound(&self, target: &K) -> Option<(K, V)> {
        let max_h = self.max_height.load(Ordering::Acquire);
        let mut current: *const Node<K, V> = ptr::null();
        // `candidate` remembers a node >= target found at an upper level; the
        // level-0 scan below normally supersedes it.
        let mut candidate: *const Node<K, V> = ptr::null();
        for level in (0..max_h).rev() {
            let mut next = if current.is_null() {
                self.head[level].load(Ordering::Acquire)
            } else {
                unsafe { &*current }.next[level].load(Ordering::Acquire)
            };
            while !next.is_null() {
                let n = unsafe { &*next };
                if n.key < *target {
                    current = next;
                    next = n.next[level].load(Ordering::Acquire);
                } else {
                    candidate = next;
                    break;
                }
            }
        }
        // Finish the search along level 0 from the last predecessor.
        let start = if current.is_null() {
            self.head[0].load(Ordering::Acquire)
        } else {
            unsafe { &*current }.next[0].load(Ordering::Acquire)
        };
        let mut ptr = start;
        while !ptr.is_null() {
            let n = unsafe { &*ptr };
            if n.key >= *target {
                return Some((n.key.clone(), n.value.clone()));
            }
            ptr = n.next[0].load(Ordering::Acquire);
        }
        if !candidate.is_null() {
            let n = unsafe { &*candidate };
            return Some((n.key.clone(), n.value.clone()));
        }
        None
    }
    /// Collects clones of every entry with key >= `target`, ascending.
    pub fn range_from(&self, target: &K) -> Vec<(K, V)> {
        let max_h = self.max_height.load(Ordering::Acquire);
        let mut current: *const Node<K, V> = ptr::null();
        // Descend to the rightmost node with key < target.
        for level in (0..max_h).rev() {
            let mut next = if current.is_null() {
                self.head[level].load(Ordering::Acquire)
            } else {
                unsafe { &*current }.next[level].load(Ordering::Acquire)
            };
            while !next.is_null() {
                let n = unsafe { &*next };
                if n.key < *target {
                    current = next;
                    next = n.next[level].load(Ordering::Acquire);
                } else {
                    break;
                }
            }
        }
        let start = if current.is_null() {
            self.head[0].load(Ordering::Acquire)
        } else {
            unsafe { &*current }.next[0].load(Ordering::Acquire)
        };
        let mut result = Vec::new();
        let mut ptr = start;
        // Keys are sorted, so once the first match is found everything after
        // also qualifies; the per-node check only skips leading smaller keys.
        while !ptr.is_null() {
            let n = unsafe { &*ptr };
            if n.key >= *target {
                result.push((n.key.clone(), n.value.clone()));
            }
            ptr = n.next[0].load(Ordering::Acquire);
        }
        result
    }
    /// Returns a double-ended iterator over a snapshot of the whole list.
    pub fn iter(&self) -> SkipListIter<K, V> {
        let entries = self.collect_all();
        let len = entries.len();
        SkipListIter {
            entries,
            front: 0,
            back_exclusive: len,
        }
    }
    /// Returns an iterator over a snapshot filtered to `bounds`.
    pub fn range<R: RangeBounds<K>>(&self, bounds: R) -> SkipListIter<K, V> {
        let all = self.collect_all();
        let entries: Vec<(K, V)> = all
            .into_iter()
            .filter(|(k, _)| bounds.contains(k))
            .collect();
        let len = entries.len();
        SkipListIter {
            entries,
            front: 0,
            back_exclusive: len,
        }
    }
    /// Borrows the key and value behind a raw node pointer.
    ///
    /// # Safety
    /// `ptr` must be a non-null pointer previously obtained from this list
    /// (e.g. `seek_ge_raw`, `head_ptr`), and the list must still be alive.
    pub unsafe fn node_kv(&self, ptr: *const ()) -> (&K, &V) {
        unsafe {
            let node = &*(ptr as *const Node<K, V>);
            (&node.key, &node.value)
        }
    }
    /// Returns the level-0 successor of `ptr` (null at the end).
    ///
    /// # Safety
    /// Same contract as `node_kv`.
    pub unsafe fn node_next0(&self, ptr: *const ()) -> *const () {
        unsafe {
            let node = &*(ptr as *const Node<K, V>);
            node.next[0].load(Ordering::Acquire) as *const ()
        }
    }
    /// Returns the level-0 predecessor of `ptr` (null at the front).
    ///
    /// # Safety
    /// Same contract as `node_kv`.
    pub unsafe fn node_prev0(&self, ptr: *const ()) -> *const () {
        unsafe {
            let node = &*(ptr as *const Node<K, V>);
            node.prev0.load(Ordering::Acquire) as *const ()
        }
    }
    /// Raw pointer to the first (smallest-key) node, or null when empty.
    pub fn head_ptr(&self) -> *const () {
        self.head[0].load(Ordering::Acquire) as *const ()
    }
    /// Raw pointer to the first node with key >= `target`, or null.
    pub fn seek_ge_raw(&self, target: &K) -> *const () {
        let max_h = self.max_height.load(Ordering::Acquire);
        let mut current: *const Node<K, V> = ptr::null();
        // Descend levels, keeping `current` strictly below `target`.
        for level in (0..max_h).rev() {
            let mut next = if current.is_null() {
                self.head[level].load(Ordering::Acquire)
            } else {
                unsafe { &*current }.next[level].load(Ordering::Acquire)
            };
            while !next.is_null() {
                let n = unsafe { &*next };
                if n.key < *target {
                    current = next;
                    next = n.next[level].load(Ordering::Acquire);
                } else {
                    break;
                }
            }
        }
        // The successor of `current` at level 0 is the first candidate.
        let start = if current.is_null() {
            self.head[0].load(Ordering::Acquire)
        } else {
            unsafe { &*current }.next[0].load(Ordering::Acquire)
        };
        let mut ptr = start;
        while !ptr.is_null() {
            let n = unsafe { &*ptr };
            if n.key >= *target {
                return ptr as *const ();
            }
            ptr = n.next[0].load(Ordering::Acquire);
        }
        ptr::null()
    }
    /// Raw pointer to the last node with key < `target`, or null.
    pub fn seek_lt_raw(&self, target: &K) -> *const () {
        let max_h = self.max_height.load(Ordering::Acquire);
        // `max_height` starts at 1 and never decreases, so this is defensive.
        if max_h == 0 {
            return ptr::null();
        }
        let mut current: *const Node<K, V> = ptr::null();
        for level in (0..max_h).rev() {
            let mut next = if current.is_null() {
                self.head[level].load(Ordering::Acquire)
            } else {
                unsafe { &*current }.next[level].load(Ordering::Acquire)
            };
            while !next.is_null() {
                let n = unsafe { &*next };
                if n.key < *target {
                    current = next;
                    next = n.next[level].load(Ordering::Acquire);
                } else {
                    break;
                }
            }
        }
        // `current` is now the rightmost node < target across all levels.
        if current.is_null() {
            ptr::null()
        } else {
            current as *const ()
        }
    }
    /// Raw pointer to the last node with key <= `target`, or null.
    pub fn seek_le_raw(&self, target: &K) -> *const () {
        let ge_ptr = self.seek_ge_raw(target);
        if !ge_ptr.is_null() {
            // SAFETY: `ge_ptr` was just produced by this list and is non-null.
            let (k, _) = unsafe { self.node_kv(ge_ptr) };
            if k == target {
                return ge_ptr;
            }
        }
        // No exact match: fall back to the strict predecessor.
        self.seek_lt_raw(target)
    }
    /// Raw pointer to the last (largest-key) node, or null when empty.
    pub fn tail_ptr(&self) -> *const () {
        self.tail.load(Ordering::Acquire) as *const ()
    }
    /// Number of entries in the list.
    pub fn len(&self) -> usize {
        self.len.load(Ordering::Relaxed)
    }
    /// True when the list has no entries.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
    /// Clones every entry by walking level 0 in ascending order.
    fn collect_all(&self) -> Vec<(K, V)> {
        let mut result = Vec::new();
        let mut ptr = self.head[0].load(Ordering::Acquire);
        while !ptr.is_null() {
            let node = unsafe { &*ptr };
            result.push((node.key.clone(), node.value.clone()));
            ptr = node.next[0].load(Ordering::Acquire);
        }
        result
    }
}
impl<K: Ord + Clone, V: Clone> Default for ConcurrentSkipList<K, V> {
fn default() -> Self {
Self::new()
}
}
impl<K: Ord + Clone, V: Clone> Drop for ConcurrentSkipList<K, V> {
    /// Runs `K`/`V` destructors for every node; the backing memory itself is
    /// released when the arena's blocks are dropped afterwards.
    fn drop(&mut self) {
        let nodes = self.all_nodes.get_mut();
        for &node_ptr in nodes.iter() {
            // SAFETY: each pointer was fully initialized by `insert` and
            // appears exactly once in `all_nodes`, so it is dropped once.
            unsafe {
                ptr::drop_in_place(node_ptr);
            }
        }
    }
}
/// Draws a tower height in `1..=MAX_HEIGHT`: each additional level requires
/// another successful coin flip from `cheap_random_bool`.
fn random_height() -> usize {
    let mut height = 1;
    loop {
        // Stop at the cap, or as soon as a promotion flip fails.
        if height >= MAX_HEIGHT || !cheap_random_bool() {
            return height;
        }
        height += 1;
    }
}
/// Cheap thread-local xorshift64 step returning `true` with probability 1/4.
///
/// Used by `random_height` to decide level promotion. The state is seeded from
/// a stack address (ASLR entropy) xor a mixing constant.
///
/// Fix: zero is a fixed point of xorshift64 — if the address happened to xor
/// to exactly 0, the state would stay 0 forever and this function would return
/// `true` on every call (since `0 & 3 == 0`), pinning every tower at
/// `MAX_HEIGHT`. A zero seed is therefore remapped to a nonzero fallback.
fn cheap_random_bool() -> bool {
    thread_local! {
        static STATE: std::cell::Cell<u64> = std::cell::Cell::new(
            {
                let x = 0u8;
                let addr = &x as *const u8 as u64;
                let seed = addr ^ 0x517cc1b727220a95;
                // xorshift never escapes state 0; substitute a fixed odd seed.
                if seed == 0 { 0x9e3779b97f4a7c15 } else { seed }
            }
        );
    }
    STATE.with(|s| {
        // One xorshift64 step (Marsaglia); full period for any nonzero state.
        let mut x = s.get();
        x ^= x << 13;
        x ^= x >> 7;
        x ^= x << 17;
        s.set(x);
        // Both low bits zero: probability 1/4.
        (x & 3) == 0
    })
}
/// Snapshot iterator over cloned `(K, V)` entries; supports iteration from
/// both ends over the window `front..back_exclusive`.
pub struct SkipListIter<K, V> {
    entries: Vec<(K, V)>,
    // Next index to yield from the front.
    front: usize,
    // One past the last index still available to yield from the back.
    back_exclusive: usize,
}
impl<K: Clone, V: Clone> Iterator for SkipListIter<K, V> {
    type Item = (K, V);
    /// Yields a clone of the next entry at the front of the remaining window.
    fn next(&mut self) -> Option<Self::Item> {
        if self.front >= self.back_exclusive {
            return None;
        }
        let idx = self.front;
        self.front = idx + 1;
        Some(self.entries[idx].clone())
    }
    /// Exact count of entries left between the two cursors.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let left = self.back_exclusive.saturating_sub(self.front);
        (left, Some(left))
    }
}
impl<K: Clone, V: Clone> DoubleEndedIterator for SkipListIter<K, V> {
    /// Yields a clone of the next entry at the back of the remaining window.
    fn next_back(&mut self) -> Option<Self::Item> {
        if self.front >= self.back_exclusive {
            return None;
        }
        self.back_exclusive -= 1;
        Some(self.entries[self.back_exclusive].clone())
    }
}
// `size_hint` is exact (lower == upper), so `ExactSizeIterator` is sound.
impl<K: Clone, V: Clone> ExactSizeIterator for SkipListIter<K, V> {}
#[cfg(test)]
mod tests {
    use super::*;
    /// Basic insert/lookup, a missing key, and the length counter.
    #[test]
    fn test_insert_and_get() {
        let sl = ConcurrentSkipList::new();
        sl.insert(10, "ten");
        sl.insert(5, "five");
        sl.insert(20, "twenty");
        assert_eq!(sl.get(&10), Some("ten"));
        assert_eq!(sl.get(&5), Some("five"));
        assert_eq!(sl.get(&20), Some("twenty"));
        assert_eq!(sl.get(&1), None);
        assert_eq!(sl.len(), 3);
    }
    /// Reverse-order inserts must still iterate in ascending key order.
    #[test]
    fn test_forward_iter() {
        let sl = ConcurrentSkipList::new();
        for i in (0..10).rev() {
            sl.insert(i, i * 10);
        }
        let items: Vec<_> = sl.iter().collect();
        assert_eq!(items.len(), 10);
        for (i, (k, v)) in items.iter().enumerate() {
            assert_eq!(*k, i as i32);
            assert_eq!(*v, (i as i32) * 10);
        }
    }
    /// `rev()` walks the snapshot back-to-front.
    #[test]
    fn test_reverse_iter() {
        let sl = ConcurrentSkipList::new();
        for i in 0..5 {
            sl.insert(i, i);
        }
        let items: Vec<_> = sl.iter().rev().collect();
        assert_eq!(items, vec![(4, 4), (3, 3), (2, 2), (1, 1), (0, 0)]);
    }
    /// Alternating next/next_back must meet in the middle with no overlap.
    #[test]
    fn test_bidirectional_iter() {
        let sl = ConcurrentSkipList::new();
        for i in 0..6 {
            sl.insert(i, i);
        }
        let mut it = sl.iter();
        assert_eq!(it.next(), Some((0, 0)));
        assert_eq!(it.next_back(), Some((5, 5)));
        assert_eq!(it.next(), Some((1, 1)));
        assert_eq!(it.next_back(), Some((4, 4)));
        assert_eq!(it.next(), Some((2, 2)));
        assert_eq!(it.next_back(), Some((3, 3)));
        assert_eq!(it.next(), None);
        assert_eq!(it.next_back(), None);
    }
    /// `range` honors exclusive, inclusive, and open-start bounds.
    #[test]
    fn test_range() {
        let sl = ConcurrentSkipList::new();
        for i in 0..10 {
            sl.insert(i, i);
        }
        let items: Vec<_> = sl.range(3..7).collect();
        assert_eq!(items, vec![(3, 3), (4, 4), (5, 5), (6, 6)]);
        let items: Vec<_> = sl.range(3..=7).collect();
        assert_eq!(items, vec![(3, 3), (4, 4), (5, 5), (6, 6), (7, 7)]);
        let items: Vec<_> = sl.range(..3).collect();
        assert_eq!(items, vec![(0, 0), (1, 1), (2, 2)]);
    }
    /// All accessors behave sanely on an empty list.
    #[test]
    fn test_empty() {
        let sl: ConcurrentSkipList<i32, i32> = ConcurrentSkipList::new();
        assert_eq!(sl.len(), 0);
        assert!(sl.is_empty());
        assert_eq!(sl.get(&0), None);
        assert_eq!(sl.iter().next(), None);
        assert_eq!(sl.iter().next_back(), None);
    }
    /// A single entry is reachable from both iteration directions.
    #[test]
    fn test_single_entry() {
        let sl = ConcurrentSkipList::new();
        sl.insert(42, "answer");
        assert_eq!(sl.len(), 1);
        assert_eq!(sl.get(&42), Some("answer"));
        let mut it = sl.iter();
        assert_eq!(it.next(), Some((42, "answer")));
        assert_eq!(it.next(), None);
        let mut it = sl.iter();
        assert_eq!(it.next_back(), Some((42, "answer")));
        assert_eq!(it.next_back(), None);
    }
    /// 10k reverse-order inserts: ordering and both-direction iteration at scale.
    #[test]
    fn test_large_dataset() {
        let sl = ConcurrentSkipList::new();
        for i in (0..10_000).rev() {
            sl.insert(i, i);
        }
        assert_eq!(sl.len(), 10_000);
        let items: Vec<_> = sl.iter().collect();
        assert_eq!(items.len(), 10_000);
        for (i, item) in items.iter().enumerate() {
            assert_eq!(*item, (i as i32, i as i32));
        }
        let rev: Vec<_> = sl.iter().rev().collect();
        assert_eq!(rev.len(), 10_000);
        assert_eq!(rev[0], (9999, 9999));
        assert_eq!(rev[9999], (0, 0));
    }
    /// Multiple reader threads see consistent results after single-threaded
    /// inserts (this exercises reads only, not concurrent writes).
    #[test]
    fn test_concurrent_reads() {
        use std::sync::Arc;
        use std::thread;
        let sl = Arc::new(ConcurrentSkipList::new());
        for i in 0..1000 {
            sl.insert(i, i * 2);
        }
        let mut handles = Vec::new();
        for _ in 0..8 {
            let sl = Arc::clone(&sl);
            handles.push(thread::spawn(move || {
                for i in 0..1000 {
                    assert_eq!(sl.get(&i), Some(i * 2));
                }
                let items: Vec<_> = sl.iter().collect();
                assert_eq!(items.len(), 1000);
            }));
        }
        for h in handles {
            h.join().unwrap();
        }
    }
    /// Byte-vector keys sort lexicographically; `get` works via `Borrow<[u8]>`.
    #[test]
    fn test_vec_u8_keys() {
        let sl = ConcurrentSkipList::new();
        sl.insert(b"banana".to_vec(), b"yellow".to_vec());
        sl.insert(b"apple".to_vec(), b"red".to_vec());
        sl.insert(b"cherry".to_vec(), b"dark_red".to_vec());
        let items: Vec<_> = sl.iter().collect();
        assert_eq!(items.len(), 3);
        assert_eq!(items[0].0, b"apple");
        assert_eq!(items[1].0, b"banana");
        assert_eq!(items[2].0, b"cherry");
        assert_eq!(sl.get(b"banana".as_slice()), Some(b"yellow".to_vec()));
    }
    /// Strict-predecessor queries, including before-first and past-last targets.
    #[test]
    fn test_seek_lt_raw() {
        let sl = ConcurrentSkipList::new();
        for i in [10, 20, 30, 40, 50] {
            sl.insert(i, i * 10);
        }
        let ptr = sl.seek_lt_raw(&30);
        assert!(!ptr.is_null());
        let (k, v) = unsafe { sl.node_kv(ptr) };
        assert_eq!(*k, 20);
        assert_eq!(*v, 200);
        let ptr = sl.seek_lt_raw(&10);
        assert!(ptr.is_null());
        let ptr = sl.seek_lt_raw(&11);
        assert!(!ptr.is_null());
        let (k, _) = unsafe { sl.node_kv(ptr) };
        assert_eq!(*k, 10);
        let ptr = sl.seek_lt_raw(&51);
        assert!(!ptr.is_null());
        let (k, _) = unsafe { sl.node_kv(ptr) };
        assert_eq!(*k, 50);
        let ptr = sl.seek_lt_raw(&0);
        assert!(ptr.is_null());
    }
    /// `seek_le_raw` prefers an exact match, else falls back to the predecessor.
    #[test]
    fn test_seek_le_raw() {
        let sl = ConcurrentSkipList::new();
        for i in [10, 20, 30, 40, 50] {
            sl.insert(i, i);
        }
        let ptr = sl.seek_le_raw(&30);
        assert!(!ptr.is_null());
        let (k, _) = unsafe { sl.node_kv(ptr) };
        assert_eq!(*k, 30);
        let ptr = sl.seek_le_raw(&25);
        assert!(!ptr.is_null());
        let (k, _) = unsafe { sl.node_kv(ptr) };
        assert_eq!(*k, 20);
        let ptr = sl.seek_le_raw(&9);
        assert!(ptr.is_null());
    }
    /// `tail_ptr` tracks the maximum key regardless of insert order.
    #[test]
    fn test_tail_ptr() {
        let sl = ConcurrentSkipList::new();
        assert!(sl.tail_ptr().is_null());
        sl.insert(10, 100);
        sl.insert(5, 50);
        sl.insert(20, 200);
        let ptr = sl.tail_ptr();
        assert!(!ptr.is_null());
        let (k, v) = unsafe { sl.node_kv(ptr) };
        assert_eq!(*k, 20);
        assert_eq!(*v, 200);
    }
    /// Predecessor queries against a 1000-element list of even keys.
    #[test]
    fn test_seek_lt_raw_large() {
        let sl = ConcurrentSkipList::new();
        for i in (0..1000).rev() {
            sl.insert(i * 2, i);
        }
        let ptr = sl.seek_lt_raw(&500);
        assert!(!ptr.is_null());
        let (k, _) = unsafe { sl.node_kv(ptr) };
        assert_eq!(*k, 498);
        let ptr = sl.seek_lt_raw(&1);
        assert!(!ptr.is_null());
        let (k, _) = unsafe { sl.node_kv(ptr) };
        assert_eq!(*k, 0);
        let ptr = sl.seek_lt_raw(&1999);
        assert!(!ptr.is_null());
        let (k, _) = unsafe { sl.node_kv(ptr) };
        assert_eq!(*k, 1998);
    }
}