use core::{cmp::Ordering, iter, marker::PhantomData, ptr::NonNull};
use arrayvec::ArrayVec;
use crate::node::{
Node,
visitor::{Step, Visitor, VisitorMut},
};
/// Cursor that walks a skip list from its head towards `target`, only ever
/// moving onto nodes whose value compares `Less` than `target` (or onto the
/// value-less node), and stopping on a node that compares `Equal`.
///
/// While descending it records, for every level, the last node it stood on
/// before dropping below that level (the "precursor") together with that
/// node's rank, so callers can later splice nodes in or out.
pub(crate) struct OrdIndexMutVisitor<'a, T, Q, F, const N: usize>
where
    Q: ?Sized,
    F: Fn(&T, &Q) -> Ordering,
{
    // ---- navigation state -------------------------------------------------
    /// Node the cursor currently sits on; starts at the head.
    current: NonNull<Node<T, N>>,
    /// Exclusive upper bound of the link levels still to scan from `current`.
    level: usize,
    /// Sum of link distances travelled from the head (the head itself is 0).
    rank: usize,
    /// Set once a step lands on a node comparing `Equal` to `target`.
    found: bool,

    // ---- search key -------------------------------------------------------
    /// Key being searched for.
    target: &'a Q,
    /// Comparator, called as `cmp(node_value, target)`.
    cmp: F,

    // ---- per-level bookkeeping --------------------------------------------
    /// Precursor node for each level, indexed by level.
    precursors: ArrayVec<NonNull<Node<T, N>>, N>,
    /// Rank (as accumulated in `rank`) of the matching entry in `precursors`.
    precursor_distances: ArrayVec<usize, N>,
    /// Raw-pointer marker tying the visitor to `Node<T, N>`; being
    /// `*mut`-based it also keeps the type neither `Send` nor `Sync`.
    _marker: PhantomData<*mut Node<T, N>>,
}
impl<'a, T, Q, F, const N: usize> OrdIndexMutVisitor<'a, T, Q, F, N>
where
    Q: ?Sized,
    F: Fn(&T, &Q) -> Ordering,
{
    /// Builds a visitor positioned on `head`, ready to search for `target`.
    ///
    /// One precursor slot is allocated per level reported by `head.level()`.
    /// Every slot starts out pointing at `head` with a recorded rank of 0.
    /// `cmp` is invoked as `cmp(node_value, target)`.
    ///
    /// # Panics
    ///
    /// Panics if `head.level()` exceeds `N`: the precursor buffers are
    /// `ArrayVec`s of capacity `N`, and collecting more items than that panics.
    pub(crate) fn new(head: NonNull<Node<T, N>>, target: &'a Q, cmp: F) -> Self {
        // SAFETY: callers are expected to pass a pointer to a live head node.
        // NOTE(review): nothing here checks that; confirm at the call sites.
        let height = unsafe { head.as_ref() }.level();

        Self {
            current: head,
            level: height,
            found: false,
            target,
            cmp,
            precursors: iter::repeat_n(head, height).collect(),
            rank: 0,
            precursor_distances: iter::repeat_n(0_usize, height).collect(),
            _marker: PhantomData,
        }
    }

    /// Rank recorded for each precursor, indexed by level.
    ///
    /// Ranks are measured the same way as
    /// [`current_rank_internal`](Self::current_rank_internal).
    pub(crate) fn precursor_distances(&self) -> &[usize] {
        self.precursor_distances.as_slice()
    }

    /// Zero-based position of the current node.
    ///
    /// This is the internal rank minus one, clamped at 0. As a result the
    /// head and the first element both report 0.
    pub(crate) fn rank(&self) -> usize {
        self.current_rank_internal().saturating_sub(1)
    }

    /// Raw accumulated rank of the current node, where the head is 0.
    pub(crate) fn current_rank_internal(&self) -> usize {
        self.rank
    }

    /// Consumes the visitor, returning `(current, found, precursors,
    /// precursor_distances)`.
    #[expect(clippy::type_complexity, reason = "internal code")]
    pub(crate) fn into_parts(
        self,
    ) -> (
        NonNull<Node<T, N>>,
        bool,
        ArrayVec<NonNull<Node<T, N>>, N>,
        ArrayVec<usize, N>,
    ) {
        let Self {
            current,
            found,
            precursors,
            precursor_distances,
            ..
        } = self;
        (current, found, precursors, precursor_distances)
    }
}
impl<T, Q, F, const N: usize> Visitor for OrdIndexMutVisitor<'_, T, Q, F, N>
where
    Q: ?Sized,
    F: Fn(&T, &Q) -> Ordering,
{
    type NodeRef = NonNull<Node<T, N>>;

    /// Node the cursor currently sits on.
    fn current(&self) -> Self::NodeRef {
        self.current
    }

    /// Exclusive upper bound of the link levels still to be scanned.
    fn level(&self) -> usize {
        self.level
    }

    /// Whether the cursor has landed on a node comparing `Equal` to the target.
    fn found(&self) -> bool {
        self.found
    }

    /// Advances the cursor by at most one node.
    ///
    /// Links are scanned from the highest pending level downwards. The first
    /// link leading to a node that orders `Less` than the target is followed.
    /// Every level passed over without following a link records `current` as
    /// that level's precursor, along with the current rank.
    ///
    /// Once all levels have been passed over, the cursor looks at
    /// `current.next()`:
    /// - no successor, or a successor ordering `Greater`: the walk is
    ///   exhausted;
    /// - otherwise: the cursor moves onto the successor, and `found` is set
    ///   when the successor compares `Equal`.
    #[expect(
        clippy::indexing_slicing,
        reason = "`level` comes from `(0..self.level).rev()` where `self.level` is \
                  initialised to `max_levels = precursors.len()` and only decreases, \
                  so `level < max_levels == precursors.len()` is always true"
    )]
    fn step(&mut self) -> Step<Self::NodeRef> {
        /// Orders `node` against `target`; a node without a value always
        /// orders `Less`, so the walk may pass through it.
        fn order_against<T, Q, F, const N: usize>(
            node: NonNull<Node<T, N>>,
            cmp: &F,
            target: &Q,
        ) -> Ordering
        where
            Q: ?Sized,
            F: Fn(&T, &Q) -> Ordering,
        {
            // SAFETY: `node` was read from a link or `next()` pointer of a node
            // the visitor already stands on. The visitor relies on the list
            // staying alive while it walks (assumption, not type-enforced).
            unsafe { node.as_ref() }
                .value()
                .map_or(Ordering::Less, |value| cmp(value, target))
        }

        if self.found {
            return Step::FoundTarget;
        }

        // SAFETY: `self.current` is either the head given to `new` or a node
        // reached from it. Both are assumed live for the visitor's lifetime.
        let lanes = unsafe { self.current.as_ref() }.links();
        for level in (0..self.level).rev() {
            let hop = lanes
                .get(level)
                .and_then(|slot| slot.as_ref())
                .map(|link| (link, link.node()))
                .filter(|&(_, candidate)| {
                    order_against(candidate, &self.cmp, self.target) == Ordering::Less
                });

            if let Some((link, candidate)) = hop {
                self.current = candidate;
                self.rank = self.rank.saturating_add(link.distance().get());
                // Resume on this same level from the new node next time.
                self.level = level.saturating_add(1);
                return Step::Advanced(candidate);
            }

            // Nothing to follow here: `current` is the last node before the
            // target on this level.
            self.precursors[level] = self.current;
            self.precursor_distances[level] = self.rank;
        }

        // Every link level has been passed over. From here on, only the
        // successor pointer is consulted.
        self.level = 0;

        // SAFETY: see above; `self.current` is a node reached from the head.
        let Some(successor) = unsafe { self.current.as_ref() }.next() else {
            return Step::Exhausted;
        };

        let order = order_against(successor, &self.cmp, self.target);
        match order {
            Ordering::Greater => Step::Exhausted,
            Ordering::Less | Ordering::Equal => {
                self.current = successor;
                self.rank = self.rank.saturating_add(1);
                self.found = order == Ordering::Equal;
                Step::Advanced(successor)
            }
        }
    }
}
impl<T, Q: ?Sized, F: Fn(&T, &Q) -> Ordering, const N: usize> VisitorMut
for OrdIndexMutVisitor<'_, T, Q, F, N>
{
type NodeMut = NonNull<Node<T, N>>;
type Precursor = NonNull<Node<T, N>>;
fn current_mut(&mut self) -> Self::NodeMut {
self.current
}
fn precursors(&self) -> &[Self::Precursor] {
&self.precursors
}
}
#[expect(
    clippy::undocumented_unsafe_blocks,
    reason = "test code, covered by miri, so safety guarantees can be relaxed"
)]
#[cfg(test)]
mod tests {
    //! The `skiplist` fixture is expected to hold 10, 20, 30 and 40, in that
    //! order, at ranks 0 to 3. Each test frees the fixture by dropping the head
    //! box once the visitor is gone.

    use core::ptr::NonNull;

    use anyhow::Result;
    use pretty_assertions::assert_eq;
    use rstest::rstest;

    use super::OrdIndexMutVisitor;
    use crate::node::{
        Node,
        tests::{MAX_LEVELS, skiplist},
        visitor::{Step, Visitor, VisitorMut},
    };

    /// Every stored value is found. The public and internal ranks agree with
    /// the value's position.
    #[rstest]
    #[case::first(10, 0)]
    #[case::second(20, 1)]
    #[case::third(30, 2)]
    #[case::last(40, 3)]
    fn finds_existing_value_with_rank(
        skiplist: Result<NonNull<Node<u8, MAX_LEVELS>>>,
        #[case] target: u8,
        #[case] expected_rank: usize,
    ) -> Result<()> {
        let head = skiplist?;
        let mut visitor = OrdIndexMutVisitor::new(head, &target, Ord::cmp);
        let found = visitor.traverse();
        assert!(visitor.found());
        let value = found.map(|ptr| unsafe { ptr.as_ref() }.value().copied());
        assert_eq!(value, Some(Some(target)));
        assert_eq!(visitor.rank(), expected_rank);
        // The internal rank counts the head as 0, so it is one ahead.
        assert_eq!(
            visitor.current_rank_internal(),
            expected_rank.saturating_add(1)
        );
        drop(visitor);
        unsafe { drop(Box::from_raw(head.as_ptr())) };
        Ok(())
    }

    /// Targets below, between, or above the stored values are not found.
    #[rstest]
    #[case::before_first(5)]
    #[case::between_values(25)]
    #[case::beyond_last(99)]
    fn missing_value_is_not_found(
        skiplist: Result<NonNull<Node<u8, MAX_LEVELS>>>,
        #[case] target: u8,
    ) -> Result<()> {
        let head = skiplist?;
        let mut visitor = OrdIndexMutVisitor::new(head, &target, Ord::cmp);
        let found = visitor.traverse();
        assert!(!visitor.found());
        assert!(found.is_none());
        drop(visitor);
        unsafe { drop(Box::from_raw(head.as_ptr())) };
        Ok(())
    }

    #[rstest]
    fn precursors_are_before_target(skiplist: Result<NonNull<Node<u8, MAX_LEVELS>>>) -> Result<()> {
        let head = skiplist?;
        let mut visitor = OrdIndexMutVisitor::new(head, &30_u8, Ord::cmp);
        while let Step::Advanced(_) = visitor.step() {}
        for &ptr in visitor.precursors() {
            let value = unsafe { ptr.as_ref() }.value().copied();
            assert!(
                value.is_none_or(|v| v < 30),
                "precursor value {value:?} should be < 30"
            );
        }
        drop(visitor);
        unsafe { drop(Box::from_raw(head.as_ptr())) };
        Ok(())
    }

    /// The lowest-level precursor of a found value is its immediate
    /// predecessor, recorded with a rank exactly one below the found node's.
    #[rstest]
    fn lowest_precursor_is_immediate_predecessor(
        skiplist: Result<NonNull<Node<u8, MAX_LEVELS>>>,
    ) -> Result<()> {
        let head = skiplist?;
        let mut visitor = OrdIndexMutVisitor::new(head, &30_u8, Ord::cmp);
        visitor.traverse();
        assert!(visitor.found());
        let lowest_value = visitor
            .precursors()
            .first()
            .map(|ptr| unsafe { ptr.as_ref() }.value().copied());
        assert_eq!(lowest_value, Some(Some(20)));
        assert_eq!(
            visitor.precursor_distances().first().copied(),
            Some(visitor.current_rank_internal().saturating_sub(1))
        );
        drop(visitor);
        unsafe { drop(Box::from_raw(head.as_ptr())) };
        Ok(())
    }

    #[rstest]
    fn precursor_distances_at_most_found_rank(
        skiplist: Result<NonNull<Node<u8, MAX_LEVELS>>>,
    ) -> Result<()> {
        let head = skiplist?;
        let mut visitor = OrdIndexMutVisitor::new(head, &30_u8, Ord::cmp);
        visitor.traverse();
        assert!(visitor.found());
        let internal_rank = visitor.current_rank_internal();
        for &dist in visitor.precursor_distances() {
            assert!(
                dist <= internal_rank,
                "precursor_distance {dist} should be <= found internal rank {internal_rank}"
            );
        }
        drop(visitor);
        unsafe { drop(Box::from_raw(head.as_ptr())) };
        Ok(())
    }

    #[rstest]
    fn exhausted_when_target_out_of_range(
        skiplist: Result<NonNull<Node<u8, MAX_LEVELS>>>,
    ) -> Result<()> {
        let head = skiplist?;
        let mut visitor = OrdIndexMutVisitor::new(head, &99_u8, Ord::cmp);
        loop {
            match visitor.step() {
                Step::Advanced(_) => (),
                Step::Exhausted => {
                    assert!(!visitor.found());
                    break;
                }
                Step::FoundTarget => panic!("should not find target 99"),
            }
        }
        drop(visitor);
        unsafe { drop(Box::from_raw(head.as_ptr())) };
        Ok(())
    }

    #[rstest]
    fn current_mut_matches_current(skiplist: Result<NonNull<Node<u8, MAX_LEVELS>>>) -> Result<()> {
        let head = skiplist?;
        let mut visitor = OrdIndexMutVisitor::new(head, &20_u8, Ord::cmp);
        visitor.traverse();
        assert_eq!(visitor.current(), visitor.current_mut());
        drop(visitor);
        unsafe { drop(Box::from_raw(head.as_ptr())) };
        Ok(())
    }

    #[rstest]
    fn precursors_length(skiplist: Result<NonNull<Node<u8, MAX_LEVELS>>>) -> Result<()> {
        let head = skiplist?;
        let max_levels = unsafe { head.as_ref() }.level();
        let mut visitor = OrdIndexMutVisitor::new(head, &20_u8, Ord::cmp);
        visitor.traverse();
        assert_eq!(visitor.precursors().len(), max_levels);
        assert_eq!(visitor.precursor_distances().len(), max_levels);
        drop(visitor);
        unsafe { drop(Box::from_raw(head.as_ptr())) };
        Ok(())
    }
}