use core::{iter, marker::PhantomData, ptr::NonNull};
use arrayvec::ArrayVec;
use crate::node::{
Node,
visitor::{Step, Visitor, VisitorMut},
};
/// Search state for locating the node at position `target`, starting from a head node.
///
/// While descending, the visitor records for every level the last node at which it stopped
/// advancing on that level (its "precursor") together with that node's position. The
/// precursors, their positions and the final node are handed out by `into_parts`.
pub(crate) struct IndexMutVisitor<T, const N: usize> {
/// Node the visitor is currently positioned on.
current: NonNull<Node<T, N>>,
/// Position of `current`, counting the head passed to `new` as position 0. It grows by a
/// link's distance when a link is followed, and by 1 on a `next()` hop.
index: usize,
/// Exclusive upper bound of the link levels `step` still examines. It starts at the head's
/// level and never increases.
level: usize,
/// Position being searched for.
target: usize,
/// Per level `l`: the last node at which the level-`l` link was missing or would not have
/// stayed strictly before `target`. Every slot is initialised to the head.
precursors: ArrayVec<NonNull<Node<T, N>>, N>,
/// Position (`index`) of each entry of `precursors`, level for level; initialised to 0.
precursor_distances: ArrayVec<usize, N>,
/// Ties the visitor to `*mut Node<T, N>`, which also keeps the type `!Send`/`!Sync`.
_marker: PhantomData<*mut Node<T, N>>,
}
impl<T, const N: usize> IndexMutVisitor<T, N> {
    /// Creates a visitor that searches for the node at position `target`, starting at `head`.
    ///
    /// `head` counts as position 0. Every precursor slot starts out as `head` at position 0,
    /// so a level the search never touches still holds a valid precursor.
    ///
    /// # Panics
    ///
    /// Panics if `head`'s level exceeds `N`: the precursor buffers are `ArrayVec`s with
    /// capacity `N`, and collecting more elements than that panics.
    pub(crate) fn new(head: NonNull<Node<T, N>>, target: usize) -> Self {
        // SAFETY: NOTE(review) — relies on the caller handing in a `head` that points at a live,
        // properly initialised node; nothing in this function can check that.
        let depth = unsafe { head.as_ref() }.level();

        // One slot per level of the head, all pointing back at the head for now.
        let precursors = iter::repeat_n(head, depth).collect();
        let precursor_distances = iter::repeat_n(0, depth).collect();

        Self {
            current: head,
            index: 0,
            level: depth,
            target,
            precursors,
            precursor_distances,
            _marker: PhantomData,
        }
    }

    /// Positions of the recorded precursors, one entry per level (level 0 first).
    pub(crate) fn precursor_distances(&self) -> &[usize] {
        self.precursor_distances.as_slice()
    }

    /// Consumes the visitor and returns, in order: the node it stopped on, the per-level
    /// precursors, and the per-level precursor positions.
    #[expect(clippy::type_complexity, reason = "internal code")]
    pub(crate) fn into_parts(
        self,
    ) -> (
        NonNull<Node<T, N>>,
        ArrayVec<NonNull<Node<T, N>>, N>,
        ArrayVec<usize, N>,
    ) {
        let Self {
            current,
            precursors,
            precursor_distances,
            ..
        } = self;
        (current, precursors, precursor_distances)
    }
}
impl<T, const N: usize> Visitor for IndexMutVisitor<T, N> {
    type NodeRef = NonNull<Node<T, N>>;

    /// The node the visitor is positioned on.
    fn current(&self) -> Self::NodeRef {
        self.current
    }

    /// Exclusive upper bound of the link levels the next `step` will examine.
    fn level(&self) -> usize {
        self.level
    }

    /// `true` once the visitor's position equals the requested `target`.
    fn found(&self) -> bool {
        self.target == self.index
    }

    /// Performs one move of the search.
    ///
    /// The links of the current node are tried from the highest remaining level down. The first
    /// link that lands strictly before `target` is followed. Every level whose link is missing
    /// or would land on/after `target` records the current node (and its position) as that
    /// level's precursor. When no link is usable, the search drops to level 0 and takes a
    /// single `next()` hop, advancing the position by one. The result is:
    /// - `FoundTarget` if already at `target`;
    /// - `Advanced(node)` after moving;
    /// - `Exhausted` when there is no successor.
    #[expect(
        clippy::indexing_slicing,
        reason = "`level` comes from `(0..self.level).rev()` where `self.level` is \
                  initialised to `max_levels = precursors.len()` and only decreases, \
                  so `level < max_levels == precursors.len()` is always true"
    )]
    fn step(&mut self) -> Step<Self::NodeRef> {
        if self.found() {
            return Step::FoundTarget;
        }

        // SAFETY: NOTE(review) — assumes `current` always points at a live node of the list the
        // visitor was created on, and that the list is not freed or relinked while the visitor
        // is in use; this visitor does not enforce that itself.
        let node: &Node<T, N> = unsafe { self.current.as_ref() };
        let links = node.links();

        for level in (0..self.level).rev() {
            // A hop is only usable when the link exists and lands strictly before the target,
            // so the node preceding the target is what ends up recorded on every level.
            let hop = links
                .get(level)
                .and_then(|slot| slot.as_ref())
                .map(|link| (link, self.index.saturating_add(link.distance().get())))
                .filter(|&(_, landing)| landing < self.target);

            match hop {
                Some((link, landing)) => {
                    let dest = link.node();
                    self.current = dest;
                    // Levels above `level` are settled; keep `level` itself in play.
                    self.level = level.saturating_add(1);
                    self.index = landing;
                    return Step::Advanced(dest);
                }
                None => {
                    self.precursors[level] = self.current;
                    self.precursor_distances[level] = self.index;
                }
            }
        }

        // No link was usable: fall back to a single base-level hop.
        self.level = 0;
        // SAFETY: same assumption as above — `current` still points at a live node.
        let successor = unsafe { self.current.as_ref() }.next();
        let Some(successor) = successor else {
            return Step::Exhausted;
        };
        self.current = successor;
        self.index = self.index.saturating_add(1);
        Step::Advanced(successor)
    }
}
impl<T, const N: usize> VisitorMut for IndexMutVisitor<T, N> {
    type NodeMut = NonNull<Node<T, N>>;
    type Precursor = NonNull<Node<T, N>>;

    /// The node the visitor is positioned on. This is the same `NonNull` handle that
    /// `Visitor::current` returns, because the visitor only stores raw handles.
    fn current_mut(&mut self) -> Self::NodeMut {
        self.current
    }

    /// The recorded precursor of every level, level 0 first.
    fn precursors(&self) -> &[Self::Precursor] {
        self.precursors.as_slice()
    }
}
#[expect(
    clippy::undocumented_unsafe_blocks,
    reason = "test code, safety guarantees can be relaxed"
)]
#[cfg(test)]
mod tests {
    use core::ptr::NonNull;

    use anyhow::Result;
    use pretty_assertions::assert_eq;
    use rstest::rstest;

    use super::IndexMutVisitor;
    use crate::node::{
        Node,
        tests::{MAX_LEVELS, skiplist},
        visitor::{Step, Visitor, VisitorMut},
    };

    /// Raw handle to the head of the `skiplist` fixture.
    type Head = NonNull<Node<u8, MAX_LEVELS>>;

    /// Frees the fixture by reclaiming the boxed head node.
    fn release(head: Head) {
        unsafe { drop(Box::from_raw(head.as_ptr())) };
    }

    #[rstest]
    fn find_index_2(skiplist: Result<Head>) -> Result<()> {
        let head = skiplist?;
        let mut visitor = IndexMutVisitor::new(head, 2);
        let hit = visitor.traverse();
        assert!(visitor.found());
        let payload = hit.map(|node| unsafe { node.as_ref() }.value().copied());
        assert_eq!(payload, Some(Some(20)));
        release(head);
        Ok(())
    }

    #[rstest]
    fn find_index_not_found(skiplist: Result<Head>) -> Result<()> {
        let head = skiplist?;
        let mut visitor = IndexMutVisitor::new(head, 5);
        let hit = visitor.traverse();
        assert!(!visitor.found());
        assert!(hit.is_none());
        release(head);
        Ok(())
    }

    #[rstest]
    fn precursors_are_before_target(skiplist: Result<Head>) -> Result<()> {
        let head = skiplist?;
        let mut visitor = IndexMutVisitor::new(head, 3);
        while matches!(visitor.step(), Step::Advanced(_)) {}
        for dist in visitor.precursor_distances().iter().copied() {
            assert!(dist < 3, "precursor distance {dist} should be < 3");
        }
        release(head);
        Ok(())
    }

    #[rstest]
    fn exhausted_when_target_out_of_range(skiplist: Result<Head>) -> Result<()> {
        let head = skiplist?;
        let mut visitor = IndexMutVisitor::new(head, 99);
        loop {
            match visitor.step() {
                Step::FoundTarget => panic!("should not find target 99"),
                Step::Exhausted => {
                    assert!(!visitor.found());
                    break;
                }
                Step::Advanced(_) => {}
            }
        }
        release(head);
        Ok(())
    }

    #[rstest]
    fn current_mut_matches_current(skiplist: Result<Head>) -> Result<()> {
        let head = skiplist?;
        let mut visitor = IndexMutVisitor::new(head, 2);
        visitor.traverse();
        let shared = visitor.current();
        assert_eq!(shared, visitor.current_mut());
        release(head);
        Ok(())
    }

    #[rstest]
    fn precursors_length(skiplist: Result<Head>) -> Result<()> {
        let head = skiplist?;
        let expected = unsafe { head.as_ref() }.level();
        let mut visitor = IndexMutVisitor::new(head, 2);
        visitor.traverse();
        assert_eq!(visitor.precursors().len(), expected);
        assert_eq!(visitor.precursor_distances().len(), expected);
        release(head);
        Ok(())
    }
}