use std::num::NonZeroUsize;
use diskann_utils::views::{DenseData, MutDenseData};
use std::ops::{Index, IndexMut};
use thiserror::Error;
/// A validated chunking schema that partitions `0..dim` into contiguous, non-empty chunks.
///
/// Chunk `i` spans `offsets[i]..offsets[i + 1]`. Instances are produced by
/// [`ChunkOffsetsBase::new`], which checks that the offsets start at `0` and are
/// strictly increasing.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct ChunkOffsetsBase<T: DenseData<Elem = usize>> {
    /// The final offset: the total number of elements covered by all chunks.
    dim: NonZeroUsize,
    /// The chunk boundaries, including the leading `0` and the trailing `dim`.
    offsets: T,
}
#[derive(Error, Debug)]
#[non_exhaustive]
pub enum ChunkOffsetError {
#[error("offsets must have a length of at least 2, found {0}")]
LengthNotAtLeastTwo(usize),
#[error("offsets must begin at 0, not {0}")]
DoesNotBeginWithZero(usize),
#[error(
"offsets must be strictly increasing, \
instead entry {start_val} at position {start} is followed by {next_val}"
)]
NonMonotonic {
start_val: usize,
start: usize,
next_val: usize,
},
}
impl<T> ChunkOffsetsBase<T>
where
    T: DenseData<Elem = usize>,
{
    /// Validate `offsets` and construct a chunking schema from them.
    ///
    /// A valid schema has at least two entries, begins with `0`, and is strictly
    /// increasing. Chunk `i` then covers `offsets[i]..offsets[i + 1]`, and the last
    /// entry becomes [`Self::dim`].
    ///
    /// # Errors
    ///
    /// * [`ChunkOffsetError::LengthNotAtLeastTwo`] if fewer than two offsets are given.
    /// * [`ChunkOffsetError::DoesNotBeginWithZero`] if the first offset is not `0`.
    /// * [`ChunkOffsetError::NonMonotonic`] for the first adjacent pair that does not
    ///   strictly increase.
    pub fn new(offsets: T) -> Result<Self, ChunkOffsetError> {
        let slice = offsets.as_slice();
        let len = slice.len();
        if len < 2 {
            return Err(ChunkOffsetError::LengthNotAtLeastTwo(len));
        }

        let first = slice[0];
        if first != 0 {
            return Err(ChunkOffsetError::DoesNotBeginWithZero(first));
        }

        // Report the first adjacent pair that fails to strictly increase.
        if let Some(start) = slice.windows(2).position(|pair| pair[0] >= pair[1]) {
            return Err(ChunkOffsetError::NonMonotonic {
                start_val: slice[start],
                start,
                next_val: slice[start + 1],
            });
        }

        // A strictly increasing sequence that starts at 0 and has at least two entries
        // must end at a value of at least 1, so this can never be zero.
        let dim = NonZeroUsize::new(slice[len - 1])
            .expect("strictly increasing offsets starting at 0 end with a non-zero value");

        Ok(Self { dim, offsets })
    }

    /// Return the number of chunks, which is one less than the number of offsets.
    pub fn len(&self) -> usize {
        debug_assert!(self.offsets.as_slice().len() >= 2);
        self.offsets.as_slice().len() - 1
    }

    /// Always returns `false`: a successfully constructed schema has at least one chunk.
    pub fn is_empty(&self) -> bool {
        false
    }

    /// Return the total number of elements spanned by all chunks (the final offset).
    pub fn dim(&self) -> usize {
        self.dim.get()
    }

    /// Return [`Self::dim`] as a [`NonZeroUsize`].
    pub fn dim_nonzero(&self) -> NonZeroUsize {
        self.dim
    }

    /// Return the half-open range of element indices covered by chunk `i`.
    ///
    /// # Panics
    ///
    /// Panics if `i >= self.len()`. Because of `#[track_caller]`, the reported panic
    /// location is the caller's rather than this method's.
    #[track_caller]
    pub fn at(&self, i: usize) -> core::ops::Range<usize> {
        let len = self.len();
        assert!(i < len, "index {i} must be less than len {len}");
        let slice = self.offsets.as_slice();
        slice[i]..slice[i + 1]
    }

    /// Borrow this schema as a [`ChunkOffsetsView`] without copying the offsets.
    pub fn as_view(&self) -> ChunkOffsetsView<'_> {
        ChunkOffsetsBase {
            dim: self.dim,
            offsets: self.offsets.as_slice(),
        }
    }

    /// Copy the offsets into a new boxed slice, returning an owning [`ChunkOffsets`].
    ///
    /// The offsets were validated when `self` was built, so they are not re-checked.
    pub fn to_owned(&self) -> ChunkOffsets {
        ChunkOffsetsBase {
            dim: self.dim,
            offsets: self.offsets.as_slice().into(),
        }
    }

    /// Return the raw offsets, including the leading `0` and the trailing [`Self::dim`].
    pub fn as_slice(&self) -> &[usize] {
        self.offsets.as_slice()
    }
}
pub type ChunkOffsetsView<'a> = ChunkOffsetsBase<&'a [usize]>;
pub type ChunkOffsets = ChunkOffsetsBase<Box<[usize]>>;
impl<'a> From<ChunkOffsetsView<'a>> for &'a [usize] {
fn from(view: ChunkOffsetsView<'a>) -> Self {
view.offsets
}
}
/// A view over `data` that is partitioned into chunks by a [`ChunkOffsetsView`].
///
/// Indexing with `view[i]` yields the sub-slice of `data` covered by chunk `i`.
#[derive(Debug, Clone, Copy)]
pub struct ChunkViewImpl<'a, T: DenseData> {
    /// The underlying elements; `ChunkViewImpl::new` checks that their length equals `offsets.dim()`.
    data: T,
    /// The chunking schema applied to `data`.
    offsets: ChunkOffsetsView<'a>,
}
/// Returned by `ChunkViewImpl::new` when the data length differs from the schema's dimension.
#[derive(Error, Debug)]
#[non_exhaustive]
#[error(
    "error in chunk view construction, \
     got a slice of length {got} but the provided chunking schema expects a length of {should}"
)]
pub struct ChunkViewError {
    /// Length of the slice that was supplied.
    got: usize,
    /// Length required by the chunking schema (its `dim`).
    should: usize,
}
impl<'a, T> ChunkViewImpl<'a, T>
where
    T: DenseData,
{
    /// Build a chunked view of `data` using the schema `offsets`.
    ///
    /// `data` is first converted into `T` through its `From` implementation.
    ///
    /// # Errors
    ///
    /// Returns [`ChunkViewError`] when the converted data's length is not equal to
    /// `offsets.dim()`.
    pub fn new<U>(data: U, offsets: ChunkOffsetsView<'a>) -> Result<Self, ChunkViewError>
    where
        T: From<U>,
    {
        let data = <T as From<U>>::from(data);
        let got = data.as_slice().len();
        let should = offsets.dim();
        if got == should {
            return Ok(Self { data, offsets });
        }
        Err(ChunkViewError { got, should })
    }

    /// Return the number of chunks in this view (the schema's chunk count).
    pub fn len(&self) -> usize {
        let schema = &self.offsets;
        schema.len()
    }

    /// Return whether the view has no chunks; delegates to the schema's `is_empty`.
    pub fn is_empty(&self) -> bool {
        let schema = &self.offsets;
        schema.is_empty()
    }
}
/// `view[i]` borrows the elements that make up chunk `i`.
///
/// # Panics
///
/// Panics if `i` is not a valid chunk index (the check lives in `ChunkOffsetsBase::at`).
impl<T> Index<usize> for ChunkViewImpl<'_, T>
where
    T: DenseData,
{
    type Output = [T::Elem];

    fn index(&self, i: usize) -> &Self::Output {
        let elements = self.data.as_slice();
        let span = self.offsets.at(i);
        &elements[span]
    }
}
/// `view[i]` mutably borrows the elements that make up chunk `i`.
///
/// # Panics
///
/// Panics if `i` is not a valid chunk index (the check lives in `ChunkOffsetsBase::at`).
impl<T> IndexMut<usize> for ChunkViewImpl<'_, T>
where
    T: MutDenseData,
{
    fn index_mut(&mut self, i: usize) -> &mut Self::Output {
        // `data` and `offsets` are disjoint fields, so a mutable borrow of one and a
        // shared borrow of the other can coexist.
        let elements = self.data.as_mut_slice();
        let span = self.offsets.at(i);
        &mut elements[span]
    }
}
/// A chunked view over a shared slice of `T`.
pub type ChunkView<'a, T> = ChunkViewImpl<'a, &'a [T]>;
/// A chunked view over a mutable slice of `T`; individual chunks can be modified via `view[i]`.
pub type MutChunkView<'a, T> = ChunkViewImpl<'a, &'a mut [T]>;
#[cfg(test)]
mod tests {
    use super::*;
    use diskann_utils::lazy_format;

    /// Compile-time check that `T` is `Copy`; always returns `true`.
    fn is_copyable<T: Copy>(_x: T) -> bool {
        true
    }

    #[test]
    fn chunk_offset_happy_path() {
        let offsets_raw: Vec<usize> = vec![0, 1, 3, 6, 10, 12, 13, 14];
        let offsets = ChunkOffsetsView::new(offsets_raw.as_slice()).unwrap();
        assert_eq!(offsets.len(), offsets_raw.len() - 1);
        assert_eq!(offsets.dim(), *offsets_raw.last().unwrap());
        assert!(!offsets.is_empty());
        assert_eq!(offsets.at(0), 0..1);
        assert_eq!(offsets.at(1), 1..3);
        assert_eq!(offsets.at(2), 3..6);
        assert_eq!(offsets.at(3), 6..10);
        assert_eq!(offsets.at(4), 10..12);
        assert_eq!(offsets.at(5), 12..13);
        assert_eq!(offsets.at(6), 13..14);
        assert!(is_copyable(offsets));
        assert_eq!(offsets.as_slice(), offsets_raw.as_slice());

        // `to_owned` must copy the offsets into a fresh allocation.
        let offsets_owned = offsets.to_owned();
        assert_eq!(offsets_owned.as_slice(), offsets_raw.as_slice());
        assert_ne!(
            offsets_owned.as_slice().as_ptr(),
            offsets_raw.as_slice().as_ptr()
        );
        assert_eq!(offsets_owned.dim, offsets.dim);

        // `as_view` must borrow the owned offsets without copying.
        let offsets_view = offsets_owned.as_view();
        assert_eq!(offsets_view, offsets);
        assert_eq!(
            offsets_view.as_slice().as_ptr(),
            offsets_owned.as_slice().as_ptr()
        );
    }

    #[test]
    fn chunk_offset_minimal_schema() {
        // The smallest valid schema: a single chunk covering a single element.
        let offsets_raw: [usize; 2] = [0, 1];
        let offsets = ChunkOffsetsView::new(offsets_raw.as_slice()).unwrap();
        assert_eq!(offsets.len(), 1);
        assert!(!offsets.is_empty());
        assert_eq!(offsets.dim(), 1);
        assert_eq!(offsets.dim_nonzero().get(), 1);
        assert_eq!(offsets.at(0), 0..1);
    }

    #[test]
    fn chunk_offset_view_into_slice() {
        let offsets_raw: Vec<usize> = vec![0, 2, 5];
        let offsets = ChunkOffsetsView::new(offsets_raw.as_slice()).unwrap();
        let back: &[usize] = offsets.into();
        assert_eq!(back, offsets_raw.as_slice());
        // The conversion hands back the original borrow rather than a copy.
        assert_eq!(back.as_ptr(), offsets_raw.as_ptr());
    }

    #[test]
    #[should_panic(expected = "index 5 must be less than len 3")]
    fn chunk_offset_indexing_panic() {
        let offsets = ChunkOffsets::new(Box::new([0, 1, 2, 3])).unwrap();
        let _ = offsets.at(5);
    }

    #[test]
    fn chunk_offset_construction_errors() {
        let offsets = ChunkOffsets::new(Box::new([]));
        assert_eq!(
            offsets.unwrap_err().to_string(),
            "offsets must have a length of at least 2, found 0"
        );
        let offsets = ChunkOffsets::new(Box::new([0]));
        assert_eq!(
            offsets.unwrap_err().to_string(),
            "offsets must have a length of at least 2, found 1"
        );
        let offsets = ChunkOffsets::new(Box::new([10, 11, 12, 13]));
        assert_eq!(
            offsets.unwrap_err().to_string(),
            "offsets must begin at 0, not 10"
        );
        let offsets = ChunkOffsets::new(Box::new([0, 10, 20, 30, 30, 40, 41]));
        assert_eq!(
            offsets.unwrap_err().to_string(),
            "offsets must be strictly increasing, instead entry 30 at position 3 \
             is followed by 30"
        );
        let offsets = ChunkOffsets::new(Box::new([0, 10, 9, 10, 20]));
        assert_eq!(
            offsets.unwrap_err().to_string(),
            "offsets must be strictly increasing, instead entry 10 at position 1 \
             is followed by 9"
        );
        let offsets = ChunkOffsets::new(Box::new([0, 10, 11, 12, 0]));
        assert_eq!(
            offsets.unwrap_err().to_string(),
            "offsets must be strictly increasing, instead entry 12 at position 3 \
             is followed by 0"
        );
        let offsets = ChunkOffsets::new(Box::new([0, 0, 11, 12, 20]));
        assert_eq!(
            offsets.unwrap_err().to_string(),
            "offsets must be strictly increasing, instead entry 0 at position 0 \
             is followed by 0"
        );
    }

    /// Check that every chunk of `view` matches the corresponding range of `data`.
    fn check_chunk_view<T>(
        view: &ChunkViewImpl<'_, T>,
        data: &[i32],
        offsets: &[usize],
        context: &dyn std::fmt::Display,
    ) where
        T: DenseData<Elem = i32>,
    {
        assert_eq!(view.len(), offsets.len() - 1, "{}", context);
        for i in 0..view.len() {
            let context = lazy_format!("start = {}, {}", i, context);
            let start = offsets[i];
            let stop = offsets[i + 1];
            let expected = &data[start..stop];
            let retrieved = &view[i];
            assert_eq!(retrieved, expected, "{}", context);
        }
    }

    #[test]
    fn test_immutable_chunkview() {
        let data: Vec<i32> = vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9];
        let offsets: Vec<usize> = vec![0, 3, 5, 9, 10];
        let chunks = ChunkOffsetsView::new(offsets.as_slice()).unwrap();
        let chunk_view = ChunkView::new(data.as_slice(), chunks).unwrap();
        assert_eq!(chunk_view.len(), offsets.len() - 1);
        assert_eq!(chunk_view.len(), chunks.len());
        assert!(!chunk_view.is_empty());
        assert!(is_copyable(chunk_view));
        let context = lazy_format!("chunkview happy path");
        check_chunk_view(&chunk_view, &data, &offsets, &context);
    }

    #[test]
    #[should_panic(expected = "index 4 must be less than len 4")]
    fn test_chunkview_index_panic() {
        let data: Vec<i32> = vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9];
        let offsets: Vec<usize> = vec![0, 3, 5, 9, 10];
        let chunks = ChunkOffsetsView::new(offsets.as_slice()).unwrap();
        let chunk_view = ChunkView::new(data.as_slice(), chunks).unwrap();
        let _ = &chunk_view[4];
    }

    #[test]
    fn test_chunkview_construction_error() {
        let data: Vec<i32> = vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9];
        let offsets: Vec<usize> = vec![0, 3, 5, 9];
        let chunks = ChunkOffsetsView::new(offsets.as_slice()).unwrap();
        let chunk_view = ChunkView::new(data.as_slice(), chunks);
        assert!(chunk_view.is_err());
        assert_eq!(
            chunk_view.unwrap_err().to_string(),
            "error in chunk view construction, got a slice of length 10 but \
             the provided chunking schema expects a length of 9"
        );
    }

    #[test]
    fn test_mutable_chunkview() {
        let mut data: Vec<i32> = vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9];
        let offsets: Vec<usize> = vec![0, 3, 5, 9, 10];
        let data_clone = data.clone();
        let chunks = ChunkOffsetsView::new(offsets.as_slice()).unwrap();
        let mut chunk_view = MutChunkView::new(data.as_mut_slice(), chunks).unwrap();
        assert_eq!(chunk_view.len(), offsets.len() - 1);
        assert_eq!(chunk_view.len(), chunks.len());
        let context = lazy_format!("mutchunkview happy path");
        check_chunk_view(&chunk_view, &data_clone, &offsets, &context);

        // Fill every chunk with its own index and check the writes landed in `data`.
        for i in 0..chunk_view.len() {
            let i_i32: i32 = i.try_into().unwrap();
            chunk_view[i].iter_mut().for_each(|d| *d = i_i32);
        }
        assert_eq!(data, [0, 0, 0, 1, 1, 2, 2, 2, 2, 3]);
    }
}