use core::ptr::NonNull;
use crate::{
level_generator::LevelGenerator,
node::{
Node,
link::Link,
visitor::{IndexMutVisitor, Visitor},
},
skip_list::SkipList,
};
// Positional (index-based) mutation API for `SkipList`: `insert`, `remove`
// and `swap`. Ranks are 1-based (`rank = index + 1`). Judging by the
// `rank - pred_rank` distance computations below, the head appears to sit at
// rank 0. NOTE(review): confirm against `IndexMutVisitor`.
impl<T, G: LevelGenerator, const N: usize> SkipList<T, N, G> {
/// Inserts `value` so that it ends up at position `index`, shifting every
/// element at or after `index` one position towards the back.
///
/// The new node's height is drawn from the list's level generator. On the
/// levels the new node occupies, the link that passed over the insertion
/// point is split around the new node. On higher levels, that link's
/// distance is incremented by one.
///
/// # Panics
///
/// Panics if `index > len`.
#[expect(
clippy::expect_used,
reason = "Link::new distances are computed to be ≥ 1; \
increment_distance overflow requires > usize::MAX nodes; \
all expects fire only on internal invariant violations, not user input"
)]
#[expect(
clippy::indexing_slicing,
reason = "l < height ≤ max_levels = new_raw.links.len(); \
pred_ptr was reached at level l during traversal so pred_ptr.links.len() > l; \
all accesses are bounded by max_levels = head.links.len()"
)]
#[expect(
clippy::multiple_unsafe_ops_per_block,
reason = "insertion and link wiring touch provably disjoint heap nodes; \
splitting across blocks would require unsafe-crossing raw-pointer variables"
)]
#[inline]
pub fn insert(&mut self, index: usize, value: T) {
assert!(
index <= self.len,
"insertion index (is {index}) should be <= len (is {})",
self.len
);
// Number of levels the new node will occupy.
let height = self.generator.level();
// 1-based rank the new element will hold once inserted.
let new_rank = index.saturating_add(1);
// Walk from the head towards `new_rank`. For each level, this presumably
// records the last node visited before `new_rank` (`precursors`) and that
// node's rank (`precursor_distances`). NOTE(review): confirm in
// `IndexMutVisitor`.
let (current, precursors, precursor_distances) = {
let mut visitor = IndexMutVisitor::new(self.head, new_rank);
visitor.traverse();
visitor.into_parts()
};
// SAFETY: `current` and every entry of `precursors` were produced by the
// traversal above, starting from `self.head`. Holding `&mut self` rules out
// any other access to the list while they are used. They are assumed to be
// live nodes of this list (NOTE(review): relies on `IndexMutVisitor`'s
// contract). `new_raw` is the node just built by `Node::with_value` and
// linked in by `Node::insert_after`, so it is distinct from every
// precursor. The `links_mut` borrows taken on `pred_ptr` and `new_raw`
// therefore never alias.
let new_node_nonnull: NonNull<Node<T, N>> = unsafe {
// If an element already sits at `index`, `current` is that element (it gets
// pushed back by one), so the new node is linked after its predecessor.
// Otherwise we are appending: `current` is presumably the last node reached
// (the head when the list is empty), and the new node goes right after it.
let insert_after_ptr = if index < self.len {
(*current.as_ptr())
.prev()
.expect("node at rank >= 1 always has a predecessor")
} else {
current
};
let new_raw: *mut Node<T, N> =
Node::insert_after(insert_after_ptr, Node::with_value(height, value)).as_ptr();
// Rewire every recorded level: `l` is the level index, `pred_nn` is the
// precursor on that level, and `pred_rank` is its rank.
for (l, (pred_nn, pred_rank)) in precursors
.iter()
.copied()
.zip(precursor_distances.iter().copied())
.enumerate()
{
let pred_ptr = pred_nn.as_ptr();
if l < height {
// The new node occupies level `l`, so split `pred -> old` into
// `pred -> new -> old`. Every node after the insertion point moves back
// one rank, so `new -> old` spans `old_distance - distance + 1`.
let distance = new_rank.saturating_sub(pred_rank);
let old_link = (*pred_ptr).links_mut()[l].take();
(*pred_ptr).links_mut()[l] = Some(
Link::new(NonNull::new_unchecked(new_raw), distance)
.expect("distance >= 1"),
);
(*new_raw).links_mut()[l] = if let Some(old) = old_link {
let new_d = old
.distance()
.get()
.saturating_sub(distance)
.saturating_add(1);
Some(Link::new(old.node(), new_d).expect("new_d >= 1"))
} else {
None
};
} else if let Some(link) = (*pred_ptr).links_mut()[l].as_mut() {
// Level above the new node's height: this link now passes over one more
// element.
link.increment_distance()
.expect("distance overflow requires > usize::MAX nodes");
}
}
NonNull::new_unchecked(new_raw)
};
// Appending at the end makes the new node the tail.
if index == self.len {
self.tail = Some(new_node_nonnull);
}
self.len = self.len.saturating_add(1);
}
/// Removes and returns the element at position `index`, shifting every
/// later element one position towards the front.
///
/// On the levels the removed node occupies, its predecessor is linked
/// directly to its successor, and the two distances are merged. On higher
/// levels, the link that passes over it has its distance decremented by one.
///
/// # Panics
///
/// Panics if `index >= len`.
#[expect(
clippy::expect_used,
reason = "precursors[0].links[0] exists because index < len guarantees a node at target_rank; \
take_value is Some for any body/tail node; \
Link::new distance is computed to be >= 1; \
all expects fire only on internal invariant violations, not user input"
)]
#[expect(
clippy::indexing_slicing,
reason = "l < target_height <= max_levels = target.links.len(); \
precursors[l] was reached at level l so precursors[l].links.len() > l; \
all accesses are bounded by max_levels = head.links.len()"
)]
#[expect(
clippy::multiple_unsafe_ops_per_block,
reason = "link rewiring and node pop touch provably disjoint heap nodes; \
splitting across blocks would require unsafe-crossing raw-pointer variables"
)]
#[inline]
pub fn remove(&mut self, index: usize) -> T {
assert!(
index < self.len,
"removal index (is {index}) should be < len (is {})",
self.len
);
// 1-based rank of the element being removed.
let target_rank = index.saturating_add(1);
// Walk to the target, recording for each level the precursor and its rank.
// NOTE(review): same visitor contract as in `insert`.
let (target_node, precursors, precursor_distances) = {
let mut visitor = IndexMutVisitor::new(self.head, target_rank);
visitor.traverse();
visitor.into_parts()
};
// SAFETY: `target_node` and the `precursors` come from the traversal above.
// Holding `&mut self` rules out other access to the list, and they are
// assumed to be live nodes of this list (NOTE(review): relies on
// `IndexMutVisitor`'s contract). `index < len` (asserted above) means the
// target is a real element. Each precursor is expected to rank strictly
// before it, so `pred_ptr` and `target_ptr` never name the same node.
// `target_ptr` is not dereferenced again after `pop`.
let (value, new_tail) = unsafe {
let target_ptr: *mut Node<T, N> = target_node.as_ptr();
let target_height = (*target_ptr).level();
for (l, (pred_nn, pred_rank)) in precursors
.iter()
.copied()
.zip(precursor_distances.iter().copied())
.enumerate()
{
let pred_ptr = pred_nn.as_ptr();
if l < target_height {
// The target occupies level `l`, so the precursor's link here points at
// it. Replace that link with the target's own link, merging the
// distances: `pred -> target + target -> next - 1`, because the removed
// element no longer counts. A target that ends the level leaves the
// precursor ending it too.
let pred_to_target = target_rank.saturating_sub(pred_rank);
let old_link = (*target_ptr).links_mut()[l].take();
(*pred_ptr).links_mut()[l] = if let Some(target_link) = old_link {
let new_dist = pred_to_target
.saturating_add(target_link.distance().get())
.saturating_sub(1);
Some(Link::new(target_link.node(), new_dist).expect("new_dist >= 1"))
} else {
None
};
} else if let Some(link) = (*pred_ptr).links_mut()[l].as_mut() {
// Level above the target's height: this link passes over the removed
// element, so it shrinks by one.
link.decrement_distance()
.expect("skip list invariant: distance >= 2 before decrement");
}
}
// Capture the level-0 predecessor before detaching, so it can become the
// new tail. `pop` presumably unlinks the node and returns ownership of it
// (NOTE(review): confirm in `Node::pop`). The value is then moved out.
let new_tail = (*target_ptr).prev();
let mut popped = (*target_ptr).pop();
(
popped.take_value().expect("target node always has a value"),
new_tail,
)
};
// Removed the last element: its predecessor becomes the tail, unless the
// list is now empty, in which case `tail` is cleared.
if index.saturating_add(1) == self.len {
self.tail = if self.len == 1 { None } else { new_tail };
}
self.len = self.len.saturating_sub(1);
value
}
/// Swaps the elements at positions `a` and `b`.
///
/// Only the stored values are exchanged; nodes and links stay where they
/// are. Swapping an index with itself is a no-op, but both indices are
/// bounds-checked first.
///
/// # Panics
///
/// Panics if `a >= len` or `b >= len`.
#[expect(
clippy::expect_used,
reason = "the expect calls fire only on internal invariant violations; \
a < len and b < len are asserted at the top of this function, \
guaranteeing that the target nodes and their values exist"
)]
#[expect(
clippy::multiple_unsafe_ops_per_block,
reason = "two value_mut accesses and a ptr::swap on provably distinct heap nodes; \
splitting across blocks would require unsafe-crossing raw-pointer variables"
)]
#[inline]
pub fn swap(&mut self, a: usize, b: usize) {
assert!(
a < self.len,
"swap index a (is {a}) should be < len (is {})",
self.len
);
assert!(
b < self.len,
"swap index b (is {b}) should be < len (is {})",
self.len
);
if a == b {
return;
}
let ptr_a = self.node_ptr_at(a);
let ptr_b = self.node_ptr_at(b);
// SAFETY: `a != b` (checked above), so `ptr_a` and `ptr_b` are assumed to be
// distinct live nodes returned by `node_ptr_at`. Each `&mut T` obtained from
// `value_mut` is coerced to a raw pointer immediately, so the two mutable
// borrows never overlap. `core::ptr::swap` exchanges the two values in
// place without dropping either.
unsafe {
let val_a: *mut T = (*ptr_a.as_ptr()).value_mut().expect("node a has a value");
let val_b: *mut T = (*ptr_b.as_ptr()).value_mut().expect("node b has a value");
core::ptr::swap(val_a, val_b);
}
}
}
#[cfg(test)]
mod tests {
    use pretty_assertions::assert_eq;

    use super::super::SkipList;

    #[test]
    fn insert_into_empty() {
        let mut list = SkipList::<i32>::with_capacity(1);
        list.insert(0, 42);
        assert_eq!(list.len(), 1);
        assert!(!list.is_empty());
        assert_eq!(
            list.head_ref().next_as_ref().and_then(|n| n.value()),
            Some(&42)
        );
        assert!(
            list.head_ref()
                .next_as_ref()
                .and_then(|n| n.next_as_ref())
                .is_none()
        );
    }

    #[test]
    fn insert_at_front() {
        let mut list = SkipList::<i32>::with_capacity(1);
        list.push_back(2);
        list.push_back(3);
        list.insert(0, 1);
        assert_eq!(list.len(), 3);
        let n1 = list.head_ref().next_as_ref().expect("n1");
        assert_eq!(n1.value(), Some(&1));
        let n2 = n1.next_as_ref().expect("n2");
        assert_eq!(n2.value(), Some(&2));
        let n3 = n2.next_as_ref().expect("n3");
        assert_eq!(n3.value(), Some(&3));
        assert!(n3.next_as_ref().is_none());
    }

    #[test]
    fn insert_at_back() {
        let mut list = SkipList::<i32>::with_capacity(1);
        list.push_back(1);
        list.push_back(2);
        list.insert(2, 3);
        assert_eq!(list.len(), 3);
        let n1 = list.head_ref().next_as_ref().expect("n1");
        assert_eq!(n1.value(), Some(&1));
        let n2 = n1.next_as_ref().expect("n2");
        assert_eq!(n2.value(), Some(&2));
        let n3 = n2.next_as_ref().expect("n3");
        assert_eq!(n3.value(), Some(&3));
        assert!(n3.next_as_ref().is_none());
    }

    #[test]
    fn insert_in_middle() {
        let mut list = SkipList::<i32>::with_capacity(1);
        list.push_back(1);
        list.push_back(3);
        list.insert(1, 2);
        assert_eq!(list.len(), 3);
        let n1 = list.head_ref().next_as_ref().expect("n1");
        assert_eq!(n1.value(), Some(&1));
        let n2 = n1.next_as_ref().expect("n2");
        assert_eq!(n2.value(), Some(&2));
        let n3 = n2.next_as_ref().expect("n3");
        assert_eq!(n3.value(), Some(&3));
        assert!(n3.next_as_ref().is_none());
    }

    #[test]
    fn insert_len_increments() {
        let mut list = SkipList::<usize>::new();
        for i in 0..50_usize {
            list.insert(0, i);
            assert_eq!(list.len(), i.saturating_add(1));
        }
    }

    #[test]
    #[should_panic(expected = "insertion index (is 5) should be <= len (is 3)")]
    fn insert_out_of_bounds() {
        let mut list = SkipList::<i32>::with_capacity(1);
        list.push_back(1);
        list.push_back(2);
        list.push_back(3);
        list.insert(5, 99);
    }

    #[test]
    fn insert_interleaved_with_pop() {
        let mut list = SkipList::<i32>::with_capacity(1);
        list.push_back(1);
        list.push_back(3);
        list.insert(1, 2);
        assert_eq!(list.pop_front(), Some(1));
        list.insert(0, 0);
        assert_eq!(list.pop_back(), Some(3));
        list.insert(2, 4);
        assert_eq!(list.len(), 3);
        let n1 = list.head_ref().next_as_ref().expect("n1");
        assert_eq!(n1.value(), Some(&0));
        let n2 = n1.next_as_ref().expect("n2");
        assert_eq!(n2.value(), Some(&2));
        let n3 = n2.next_as_ref().expect("n3");
        assert_eq!(n3.value(), Some(&4));
        assert!(n3.next_as_ref().is_none());
    }

    #[test]
    fn insert_multiple_positions() {
        let mut list = SkipList::<i32>::with_capacity(1);
        list.insert(0, 2);
        list.insert(0, 0);
        list.insert(1, 1);
        list.insert(3, 4);
        list.insert(3, 3);
        assert_eq!(list.len(), 5);
        let mut node = list.head_ref().next_as_ref().expect("first");
        for expected in 0..5_i32 {
            assert_eq!(node.value(), Some(&expected));
            if expected < 4 {
                node = node.next_as_ref().expect("next");
            }
        }
    }

    /// Repeated front insertion into a default (multi-level) list. After every
    /// insertion, each position is looked up through `get`, which presumably
    /// routes through the upper levels. This checks the distance bookkeeping
    /// in `insert`, not only the level-0 chain that the tests above walk.
    #[test]
    fn insert_at_front_keeps_ranks_consistent() {
        let n = 100_usize;
        let mut list = SkipList::<usize>::new();
        for first in (0..n).rev() {
            list.insert(0, first);
            assert_eq!(list.len(), n.saturating_sub(first));
            for offset in 0..list.len() {
                assert_eq!(list.get(offset), Some(&first.saturating_add(offset)));
            }
        }
    }

    /// Fills the gaps between even numbers with the odd numbers via
    /// mid-list `insert`, then checks every position through `get`.
    #[test]
    fn insert_in_middle_keeps_ranks_consistent() {
        let n = 50_usize;
        let mut list = SkipList::<usize>::new();
        for k in 0..n {
            list.push_back(k.saturating_mul(2));
        }
        // Before the k-th insertion the list is `0..=2k` followed by the
        // remaining evens, so index `2k + 1` is exactly where `2k + 1` belongs.
        for k in 0..n.saturating_sub(1) {
            let odd = k.saturating_mul(2).saturating_add(1);
            list.insert(odd, odd);
        }
        let total = n.saturating_mul(2).saturating_sub(1);
        assert_eq!(list.len(), total);
        for i in 0..total {
            assert_eq!(list.get(i), Some(&i));
        }
    }

    #[test]
    fn remove_only_element() {
        let mut list = SkipList::<i32>::with_capacity(1);
        list.push_back(42);
        assert_eq!(list.remove(0), 42);
        assert_eq!(list.len(), 0);
        assert!(list.is_empty());
        assert!(list.head_ref().next_as_ref().is_none());
    }

    #[test]
    fn remove_at_front() {
        let mut list = SkipList::<i32>::with_capacity(1);
        list.push_back(1);
        list.push_back(2);
        list.push_back(3);
        assert_eq!(list.remove(0), 1);
        assert_eq!(list.len(), 2);
        let n1 = list.head_ref().next_as_ref().expect("n1");
        assert_eq!(n1.value(), Some(&2));
        let n2 = n1.next_as_ref().expect("n2");
        assert_eq!(n2.value(), Some(&3));
        assert!(n2.next_as_ref().is_none());
    }

    #[test]
    fn remove_at_back() {
        let mut list = SkipList::<i32>::with_capacity(1);
        list.push_back(1);
        list.push_back(2);
        list.push_back(3);
        assert_eq!(list.remove(2), 3);
        assert_eq!(list.len(), 2);
        let n1 = list.head_ref().next_as_ref().expect("n1");
        assert_eq!(n1.value(), Some(&1));
        let n2 = n1.next_as_ref().expect("n2");
        assert_eq!(n2.value(), Some(&2));
        assert!(n2.next_as_ref().is_none());
    }

    #[test]
    fn remove_in_middle() {
        let mut list = SkipList::<i32>::with_capacity(1);
        list.push_back(1);
        list.push_back(2);
        list.push_back(3);
        assert_eq!(list.remove(1), 2);
        assert_eq!(list.len(), 2);
        let n1 = list.head_ref().next_as_ref().expect("n1");
        assert_eq!(n1.value(), Some(&1));
        let n2 = n1.next_as_ref().expect("n2");
        assert_eq!(n2.value(), Some(&3));
        assert!(n2.next_as_ref().is_none());
    }

    #[test]
    fn remove_len_decrements() {
        let mut list = SkipList::<usize>::new();
        for i in 0..50_usize {
            list.push_back(i);
        }
        for i in (0..50_usize).rev() {
            assert_eq!(list.len(), i.saturating_add(1));
            list.remove(0);
        }
        assert_eq!(list.len(), 0);
    }

    #[test]
    #[should_panic(expected = "removal index (is 3) should be < len (is 3)")]
    fn remove_out_of_bounds() {
        let mut list = SkipList::<i32>::with_capacity(1);
        list.push_back(1);
        list.push_back(2);
        list.push_back(3);
        list.remove(3);
    }

    #[test]
    #[should_panic(expected = "removal index (is 0) should be < len (is 0)")]
    fn remove_from_empty() {
        let mut list = SkipList::<i32>::new();
        list.remove(0);
    }

    #[test]
    fn remove_interleaved_with_insert() {
        let mut list = SkipList::<i32>::with_capacity(1);
        list.push_back(1);
        list.push_back(2);
        list.push_back(3);
        assert_eq!(list.remove(1), 2);
        list.insert(1, 4);
        assert_eq!(list.remove(0), 1);
        list.push_back(5);
        assert_eq!(list.remove(2), 5);
        assert_eq!(list.len(), 2);
        let n1 = list.head_ref().next_as_ref().expect("n1");
        assert_eq!(n1.value(), Some(&4));
        let n2 = n1.next_as_ref().expect("n2");
        assert_eq!(n2.value(), Some(&3));
        assert!(n2.next_as_ref().is_none());
    }

    #[test]
    fn remove_all_elements() {
        let mut list = SkipList::<i32>::with_capacity(1);
        for i in 1..=5_i32 {
            list.push_back(i);
        }
        assert_eq!(list.remove(2), 3);
        assert_eq!(list.remove(0), 1);
        assert_eq!(list.remove(2), 5);
        assert_eq!(list.remove(1), 4);
        assert_eq!(list.remove(0), 2);
        assert_eq!(list.len(), 0);
        assert!(list.is_empty());
    }

    #[test]
    #[expect(
        clippy::integer_division,
        clippy::integer_division_remainder_used,
        reason = "removing from middle"
    )]
    fn remove_all_in_order_from_middle() {
        let n = 50_usize;
        let mut list = SkipList::<usize>::new();
        for i in 0..n {
            list.push_back(i);
        }
        while !list.is_empty() {
            list.remove(list.len() / 2);
        }
        assert_eq!(list.len(), 0);
        assert!(list.head_ref().next_as_ref().is_none());
    }

    /// Removes every odd element from a default (multi-level) list. It then
    /// checks the survivors through `get`, which exercises `remove`'s
    /// upper-level distance merging. Finally it appends to confirm the tail
    /// was moved correctly when the last element was removed.
    #[test]
    fn remove_every_other_keeps_ranks_consistent() {
        let pairs = 50_usize;
        let n = pairs.saturating_mul(2);
        let mut list = SkipList::<usize>::new();
        for i in 0..n {
            list.push_back(i);
        }
        // After k removals the list is `0, 2, ..., 2k` followed by
        // `2k + 1, 2k + 2, ...`, so index `k + 1` holds `2k + 1`.
        for k in 0..pairs {
            assert_eq!(
                list.remove(k.saturating_add(1)),
                k.saturating_mul(2).saturating_add(1)
            );
        }
        assert_eq!(list.len(), pairs);
        for i in 0..pairs {
            assert_eq!(list.get(i), Some(&i.saturating_mul(2)));
        }
        // The final removal took the last element; appending must follow the
        // new tail.
        list.push_back(n);
        assert_eq!(list.len(), pairs.saturating_add(1));
        assert_eq!(list.get(pairs), Some(&n));
    }

    /// `insert` at `len` and `remove` of the last element both move the tail.
    /// `push_back`/`pop_back` afterwards must see the updated tail, including
    /// after the list has been emptied by `remove`.
    #[test]
    fn tail_follows_insert_and_remove_at_end() {
        let mut list = SkipList::<i32>::new();
        list.insert(0, 1);
        list.insert(1, 2);
        list.push_back(3);
        assert_eq!(list.remove(2), 3);
        list.push_back(4);
        assert_eq!(list.pop_back(), Some(4));
        assert_eq!(list.pop_back(), Some(2));
        assert_eq!(list.remove(0), 1);
        assert!(list.is_empty());
        list.push_back(5);
        assert_eq!(list.len(), 1);
        assert_eq!(list.get(0), Some(&5));
    }

    #[test]
    fn swap_basic() {
        let mut list = SkipList::<i32>::new();
        list.push_back(1);
        list.push_back(2);
        list.push_back(3);
        list.swap(0, 2);
        assert_eq!(list.get(0), Some(&3));
        assert_eq!(list.get(1), Some(&2));
        assert_eq!(list.get(2), Some(&1));
    }

    #[test]
    fn swap_same_index() {
        let mut list = SkipList::<i32>::new();
        list.push_back(1);
        list.push_back(2);
        list.push_back(3);
        list.swap(1, 1);
        assert_eq!(list.get(0), Some(&1));
        assert_eq!(list.get(1), Some(&2));
        assert_eq!(list.get(2), Some(&3));
    }

    #[test]
    fn swap_adjacent() {
        let mut list = SkipList::<i32>::new();
        list.push_back(1);
        list.push_back(2);
        list.push_back(3);
        list.swap(1, 2);
        assert_eq!(list.get(0), Some(&1));
        assert_eq!(list.get(1), Some(&3));
        assert_eq!(list.get(2), Some(&2));
    }

    #[test]
    fn swap_two_elements() {
        let mut list = SkipList::<i32>::new();
        list.push_back(10);
        list.push_back(20);
        list.swap(0, 1);
        assert_eq!(list.get(0), Some(&20));
        assert_eq!(list.get(1), Some(&10));
    }

    #[test]
    fn swap_front_back() {
        let mut list = SkipList::<i32>::new();
        for i in 1..=5_i32 {
            list.push_back(i);
        }
        list.swap(0, 4);
        assert_eq!(list.get(0), Some(&5));
        assert_eq!(list.get(1), Some(&2));
        assert_eq!(list.get(2), Some(&3));
        assert_eq!(list.get(3), Some(&4));
        assert_eq!(list.get(4), Some(&1));
    }

    #[test]
    fn swap_preserves_len() {
        let mut list = SkipList::<i32>::new();
        list.push_back(1);
        list.push_back(2);
        list.push_back(3);
        list.swap(0, 2);
        assert_eq!(list.len(), 3);
    }

    #[test]
    #[should_panic(expected = "swap index a (is 3) should be < len (is 3)")]
    fn swap_out_of_bounds_a() {
        let mut list = SkipList::<i32>::new();
        list.push_back(1);
        list.push_back(2);
        list.push_back(3);
        list.swap(3, 0);
    }

    #[test]
    #[should_panic(expected = "swap index b (is 3) should be < len (is 3)")]
    fn swap_out_of_bounds_b() {
        let mut list = SkipList::<i32>::new();
        list.push_back(1);
        list.push_back(2);
        list.push_back(3);
        list.swap(0, 3);
    }

    #[test]
    #[should_panic(expected = "swap index a (is 0) should be < len (is 0)")]
    fn swap_empty() {
        let mut list = SkipList::<i32>::new();
        list.swap(0, 0);
    }

    #[test]
    fn swap_large() {
        let n: usize = 100;
        let mut list = SkipList::<usize>::new();
        for i in 0..n {
            list.push_back(i);
        }
        #[expect(
            clippy::integer_division,
            clippy::integer_division_remainder_used,
            reason = "swapping across midpoint"
        )]
        for i in 0..(n / 2) {
            list.swap(i, n - 1 - i);
        }
        for i in 0..n {
            assert_eq!(list.get(i), Some(&(n - 1 - i)));
        }
        assert_eq!(list.len(), n);
    }
}