use std::ptr::NonNull;
use diskann_utils::{Reborrow, ReborrowMut};
use thiserror::Error;
use crate::{
alloc::{AllocatorCore, AllocatorError, GlobalAllocator, Poly},
bits::{
AsMutPtr, AsPtr, BitSlice, BitSliceBase, Dense, MutBitSlice, MutSlicePtr,
PermutationStrategy, Representation, SlicePtr,
},
ownership::{CopyMut, CopyRef, Mut, Owned, Ref},
};
/// A bit-packed slice (`bits`) bundled with a single piece of metadata (`meta`).
///
/// Type parameters:
/// * `NBITS`, `Repr`, `Perm` - forwarded unchanged to [`BitSliceBase`]; `Perm` defaults to
///   [`Dense`]. (`NBITS` is presumably the bit-width of each packed element - confirm against
///   `BitSliceBase`.)
/// * `Ptr` - the storage handle for the packed bytes; it must expose `u8` data via [`AsPtr`].
/// * `T` - the container holding the metadata value. The aliases in this module instantiate
///   it with `Owned<_>`, `Ref<'_, _>` or `Mut<'_, _>`.
#[derive(Debug, Clone, Copy)]
pub struct VectorBase<const NBITS: usize, Repr, Ptr, T, Perm = Dense>
where
Ptr: AsPtr<Type = u8>,
Repr: Representation<NBITS>,
Perm: PermutationStrategy<NBITS>,
{
/// The packed vector payload.
bits: BitSliceBase<NBITS, Repr, Ptr, Perm>,
/// The metadata stored alongside the payload.
meta: T,
}
impl<const NBITS: usize, Repr, Ptr, T, Perm> VectorBase<NBITS, Repr, Ptr, T, Perm>
where
    Ptr: AsPtr<Type = u8>,
    Repr: Representation<NBITS>,
    Perm: PermutationStrategy<NBITS>,
{
    /// Number of bytes the packed payload occupies for `count` elements.
    ///
    /// This simply forwards to `BitSliceBase::bytes_for`.
    pub fn slice_bytes(count: usize) -> usize {
        <BitSliceBase<NBITS, Repr, Ptr, Perm>>::bytes_for(count)
    }

    /// Number of bytes of a canonical encoding holding `count` elements: the packed payload
    /// (see [`Self::slice_bytes`]) plus `size_of` the metadata value.
    pub fn canonical_bytes(count: usize) -> usize
    where
        T: CopyRef,
        T::Target: bytemuck::Pod,
    {
        let metadata_bytes = std::mem::size_of::<T::Target>();
        metadata_bytes + Self::slice_bytes(count)
    }

    /// Assembles a vector from an existing bit slice and anything convertible into the
    /// metadata container `T`.
    pub fn new<M>(bits: BitSliceBase<NBITS, Repr, Ptr, Perm>, meta: M) -> Self
    where
        M: Into<T>,
    {
        let meta: T = meta.into();
        Self { bits, meta }
    }

    /// Length of the underlying bit slice, as reported by `BitSliceBase::len`.
    pub fn len(&self) -> usize {
        self.bits.len()
    }

    /// Whether the underlying bit slice reports itself as empty.
    pub fn is_empty(&self) -> bool {
        self.bits.is_empty()
    }

    /// Returns a copy of the metadata value.
    pub fn meta(&self) -> T::Target
    where
        T: CopyRef,
    {
        self.meta.copy_ref()
    }

    /// Borrows the packed payload as an immutable [`BitSlice`].
    pub fn vector(&self) -> BitSlice<'_, NBITS, Repr, Perm> {
        self.bits.reborrow()
    }

    /// Borrows the packed payload as a [`MutBitSlice`].
    ///
    /// Only available when the storage handle supports mutable access.
    pub fn vector_mut(&mut self) -> MutBitSlice<'_, NBITS, Repr, Perm>
    where
        Ptr: AsMutPtr,
    {
        self.bits.reborrow_mut()
    }

    /// Overwrites the metadata with `value`.
    pub fn set_meta(&mut self, value: T::Target)
    where
        Ptr: AsMutPtr,
        T: CopyMut,
    {
        self.meta.copy_mut(value);
    }
}
impl<const NBITS: usize, Repr, Perm, T>
    VectorBase<NBITS, Repr, Poly<[u8], GlobalAllocator>, Owned<T>, Perm>
where
    Repr: Representation<NBITS>,
    Perm: PermutationStrategy<NBITS>,
    T: Default,
{
    /// Creates an owning vector of `len` elements.
    ///
    /// The payload comes from `BitSliceBase::new_boxed(len)` and the metadata starts out as
    /// `Owned::default()`.
    pub fn new_boxed(len: usize) -> Self {
        let bits = BitSliceBase::new_boxed(len);
        let meta = Owned::default();
        Self { bits, meta }
    }
}
impl<const NBITS: usize, Repr, Perm, T, A> VectorBase<NBITS, Repr, Poly<[u8], A>, Owned<T>, Perm>
where
    Repr: Representation<NBITS>,
    Perm: PermutationStrategy<NBITS>,
    T: Default,
    A: AllocatorCore,
{
    /// Creates an owning vector of `len` elements whose payload is obtained from `allocator`.
    ///
    /// The metadata starts out as `Owned::default()`.
    ///
    /// # Errors
    ///
    /// Propagates the [`AllocatorError`] produced by `BitSliceBase::new_in`.
    pub fn new_in(len: usize, allocator: A) -> Result<Self, AllocatorError> {
        let bits = BitSliceBase::new_in(len, allocator)?;
        Ok(Self {
            bits,
            meta: Owned::default(),
        })
    }
}
/// A [`VectorBase`] that borrows its packed bytes through `SlicePtr<'a, u8>` and its metadata
/// through `Ref<'a, T>`.
pub type VectorRef<'a, const NBITS: usize, Repr, T, Perm = Dense> =
VectorBase<NBITS, Repr, SlicePtr<'a, u8>, Ref<'a, T>, Perm>;
/// A [`VectorBase`] that borrows its packed bytes through `MutSlicePtr<'a, u8>` and its
/// metadata through `Mut<'a, T>`.
pub type VectorMut<'a, const NBITS: usize, Repr, T, Perm = Dense> =
VectorBase<NBITS, Repr, MutSlicePtr<'a, u8>, Mut<'a, T>, Perm>;
/// A [`VectorBase`] whose packed bytes live in a `Poly<[u8], GlobalAllocator>` and whose
/// metadata is held in an `Owned<T>`.
pub type Vector<const NBITS: usize, Repr, T, Perm = Dense> =
VectorBase<NBITS, Repr, Poly<[u8], GlobalAllocator>, Owned<T>, Perm>;
/// Like [`Vector`], but the `Poly` storage is parameterized over the allocator `A`.
///
/// `Perm` must be spelled out here: it has no default (defaulted parameters must be trailing,
/// and `A` follows it).
pub type PolyVector<const NBITS: usize, Repr, T, Perm, A> =
VectorBase<NBITS, Repr, Poly<[u8], A>, Owned<T>, Perm>;
impl<'this, const NBITS: usize, Repr, Ptr, T, Perm> Reborrow<'this>
for VectorBase<NBITS, Repr, Ptr, T, Perm>
where
Ptr: AsPtr<Type = u8>,
Repr: Representation<NBITS>,
Perm: PermutationStrategy<NBITS>,
T: CopyRef + Reborrow<'this, Target = Ref<'this, <T as CopyRef>::Target>>,
{
type Target = VectorRef<'this, NBITS, Repr, <T as CopyRef>::Target, Perm>;
fn reborrow(&'this self) -> Self::Target {
Self::Target {
bits: self.bits.reborrow(),
meta: self.meta.reborrow(),
}
}
}
impl<'this, const NBITS: usize, Repr, Ptr, T, Perm> ReborrowMut<'this>
for VectorBase<NBITS, Repr, Ptr, T, Perm>
where
Ptr: AsMutPtr<Type = u8>,
Repr: Representation<NBITS>,
Perm: PermutationStrategy<NBITS>,
T: CopyMut + ReborrowMut<'this, Target = Mut<'this, <T as CopyRef>::Target>>,
{
type Target = VectorMut<'this, NBITS, Repr, <T as CopyRef>::Target, Perm>;
fn reborrow_mut(&'this mut self) -> Self::Target {
Self::Target {
bits: self.bits.reborrow_mut(),
meta: self.meta.reborrow_mut(),
}
}
}
/// Error returned by the `from_canonical_*` constructors when a byte slice cannot be
/// interpreted as a canonical vector encoding.
#[derive(Debug, Error, PartialEq, Clone, Copy)]
pub enum NotCanonical {
/// The slice has the wrong number of bytes. Field `0` is the expected length and field `1`
/// is the length of the slice that was actually supplied.
#[error("expected a slice length of {0} bytes but instead got {1} bytes")]
WrongLength(usize, usize),
}
impl<'a, const NBITS: usize, Repr, T, Perm> VectorRef<'a, NBITS, Repr, T, Perm>
where
    Repr: Representation<NBITS>,
    Perm: PermutationStrategy<NBITS>,
    T: bytemuck::Pod,
{
    /// Interprets `data` as a vector of `dim` elements in "front" layout: the bytes of the
    /// metadata `T` come first and the packed payload follows.
    ///
    /// # Errors
    ///
    /// Returns [`NotCanonical::WrongLength`] unless `data.len()` equals
    /// `Self::canonical_bytes(dim)`.
    pub fn from_canonical_front(data: &'a [u8], dim: usize) -> Result<Self, NotCanonical> {
        let (expected, actual) = (Self::canonical_bytes(dim), data.len());
        if actual == expected {
            // SAFETY: the length requirement of the unchecked constructor was just verified.
            Ok(unsafe { Self::from_canonical_unchecked(data, dim) })
        } else {
            Err(NotCanonical::WrongLength(expected, actual))
        }
    }

    /// Interprets `data` as a vector of `dim` elements in "back" layout: the packed payload
    /// comes first and the bytes of the metadata `T` occupy the tail.
    ///
    /// # Errors
    ///
    /// Returns [`NotCanonical::WrongLength`] unless `data.len()` equals
    /// `Self::canonical_bytes(dim)`.
    pub fn from_canonical_back(data: &'a [u8], dim: usize) -> Result<Self, NotCanonical> {
        let (expected, actual) = (Self::canonical_bytes(dim), data.len());
        if actual == expected {
            // SAFETY: the length requirement of the unchecked constructor was just verified.
            Ok(unsafe { Self::from_canonical_back_unchecked(data, dim) })
        } else {
            Err(NotCanonical::WrongLength(expected, actual))
        }
    }

    /// Unchecked counterpart of [`Self::from_canonical_front`].
    ///
    /// # Safety
    ///
    /// The caller must guarantee `data.len() == Self::canonical_bytes(dim)`; this is only
    /// verified by a `debug_assert`.
    pub unsafe fn from_canonical_unchecked(data: &'a [u8], dim: usize) -> Self {
        debug_assert_eq!(data.len(), Self::canonical_bytes(dim));
        let meta_len = std::mem::size_of::<T>();
        // SAFETY: by the caller's contract `data` holds at least `meta_len` bytes, so the
        // start of the range is in bounds.
        let payload = unsafe { data.get_unchecked(meta_len..) };
        // SAFETY: `payload` is everything after the metadata, i.e. `Self::slice_bytes(dim)`
        // bytes - presumably the length `BitSlice::new_unchecked` expects for `dim` elements
        // (confirm against its documentation).
        let bits = unsafe { BitSlice::new_unchecked(payload, dim) };
        // SAFETY: a pointer obtained from a slice reference is never null, and the first
        // `meta_len` bytes belong to `data`. NOTE(review): nothing here aligns the pointer
        // for `T`, so `Ref` presumably tolerates unaligned targets - confirm.
        let meta =
            unsafe { Ref::new(NonNull::new_unchecked(data.as_ptr().cast_mut()).cast::<T>()) };
        Self { bits, meta }
    }

    /// Unchecked counterpart of [`Self::from_canonical_back`].
    ///
    /// # Safety
    ///
    /// The caller must guarantee `data.len() == Self::canonical_bytes(dim)`; this is only
    /// verified by a `debug_assert`.
    pub unsafe fn from_canonical_back_unchecked(data: &'a [u8], dim: usize) -> Self {
        debug_assert_eq!(data.len(), Self::canonical_bytes(dim));
        let payload_len = data.len() - std::mem::size_of::<T>();
        // SAFETY: `payload_len <= data.len()` because, by the caller's contract, `data` holds
        // at least `size_of::<T>()` bytes.
        let (payload, trailer) = unsafe { data.split_at_unchecked(payload_len) };
        // SAFETY: `payload` is `Self::slice_bytes(dim)` bytes long - presumably the length
        // `BitSlice::new_unchecked` expects for `dim` elements (confirm against its
        // documentation).
        let bits = unsafe { BitSlice::new_unchecked(payload, dim) };
        // SAFETY: a pointer obtained from a slice reference is never null, and `trailer`
        // spans exactly `size_of::<T>()` bytes. NOTE(review): nothing here aligns the pointer
        // for `T`, so `Ref` presumably tolerates unaligned targets - confirm.
        let meta =
            unsafe { Ref::new(NonNull::new_unchecked(trailer.as_ptr().cast_mut()).cast::<T>()) };
        Self { bits, meta }
    }
}
impl<'a, const NBITS: usize, Repr, T, Perm> VectorMut<'a, NBITS, Repr, T, Perm>
where
    Repr: Representation<NBITS>,
    Perm: PermutationStrategy<NBITS>,
    T: bytemuck::Pod,
{
    /// Interprets `data` as a mutable vector of `dim` elements in "front" layout: the bytes
    /// of the metadata `T` come first and the packed payload follows.
    ///
    /// # Errors
    ///
    /// Returns [`NotCanonical::WrongLength`] unless `data.len()` equals
    /// `Self::canonical_bytes(dim)`.
    pub fn from_canonical_front_mut(data: &'a mut [u8], dim: usize) -> Result<Self, NotCanonical> {
        let (expected, actual) = (Self::canonical_bytes(dim), data.len());
        if actual == expected {
            // SAFETY: the length requirement of the unchecked constructor was just verified.
            Ok(unsafe { Self::from_canonical_front_mut_unchecked(data, dim) })
        } else {
            Err(NotCanonical::WrongLength(expected, actual))
        }
    }

    /// Unchecked counterpart of [`Self::from_canonical_front_mut`].
    ///
    /// # Safety
    ///
    /// The caller must guarantee `data.len() == Self::canonical_bytes(dim)`; this is only
    /// verified by a `debug_assert`.
    pub unsafe fn from_canonical_front_mut_unchecked(data: &'a mut [u8], dim: usize) -> Self {
        debug_assert_eq!(data.len(), Self::canonical_bytes(dim));
        let meta_len = std::mem::size_of::<T>();
        // SAFETY: by the caller's contract `data` holds at least `meta_len` bytes, so the
        // split point is in bounds.
        let (header, payload) = unsafe { data.split_at_mut_unchecked(meta_len) };
        // SAFETY: `payload` is `Self::slice_bytes(dim)` bytes long - presumably the length
        // `MutBitSlice::new_unchecked` expects for `dim` elements (confirm against its
        // documentation).
        let bits = unsafe { MutBitSlice::new_unchecked(payload, dim) };
        // SAFETY: a pointer obtained from a slice reference is never null, and `header` is
        // disjoint from `payload`. NOTE(review): nothing here aligns the pointer for `T`, so
        // `Mut` presumably tolerates unaligned targets - confirm.
        let meta = unsafe { Mut::new(NonNull::new_unchecked(header.as_mut_ptr()).cast::<T>()) };
        Self { bits, meta }
    }

    /// Interprets `data` as a mutable vector of `dim` elements in "back" layout: the packed
    /// payload comes first and the bytes of the metadata `T` occupy the tail.
    ///
    /// # Errors
    ///
    /// Returns [`NotCanonical::WrongLength`] when `data` is shorter than
    /// `Self::slice_bytes(dim)` or when the bytes after the payload are not exactly
    /// `size_of::<T>()` long. The reported expected length is `Self::canonical_bytes(dim)`.
    pub fn from_canonical_back_mut(data: &'a mut [u8], dim: usize) -> Result<Self, NotCanonical> {
        let total = data.len();
        // Built lazily so the expected size is only computed on the error paths.
        let wrong_length = || NotCanonical::WrongLength(Self::canonical_bytes(dim), total);

        let Some((payload, trailer)) = data.split_at_mut_checked(Self::slice_bytes(dim)) else {
            return Err(wrong_length());
        };
        if trailer.len() != std::mem::size_of::<T>() {
            return Err(wrong_length());
        }

        // SAFETY: `payload` is exactly `Self::slice_bytes(dim)` bytes long - presumably the
        // length `MutBitSlice::new_unchecked` expects for `dim` elements (confirm against its
        // documentation).
        let bits = unsafe { MutBitSlice::new_unchecked(payload, dim) };
        // SAFETY: a pointer obtained from a slice reference is never null, and `trailer` was
        // just checked to span `size_of::<T>()` bytes disjoint from `payload`.
        // NOTE(review): nothing here aligns the pointer for `T`, so `Mut` presumably
        // tolerates unaligned targets - confirm.
        let meta = unsafe { Mut::new(NonNull::new_unchecked(trailer.as_mut_ptr()).cast::<T>()) };
        Ok(Self { bits, meta })
    }
}
#[cfg(test)]
mod tests {
use diskann_utils::{Reborrow, ReborrowMut};
use rand::{
Rng, SeedableRng,
distr::{Distribution, StandardUniform, Uniform},
rngs::StdRng,
};
use super::*;
use crate::bits::{BoxedBitSlice, Representation, Unsigned};
/// Plain-old-data metadata used by the tests: two `u32` fields in `repr(C)` layout.
#[derive(Default, Debug, Clone, Copy, PartialEq, bytemuck::Zeroable, bytemuck::Pod)]
#[repr(C)]
struct Metadata {
    a: u32,
    b: u32,
}

impl Metadata {
    /// Convenience constructor taking both fields.
    fn new(a: u32, b: u32) -> Self {
        Metadata { a, b }
    }
}
/// Round-trips metadata and element values through an owning `Vector` and its mutable and
/// immutable reborrows.
#[test]
fn test_vector() {
    const LEN: usize = 20;
    let updated = Metadata::new(1, 2);

    let mut owned = Vector::<7, Unsigned, Metadata>::new_boxed(LEN);
    assert_eq!(owned.len(), LEN);
    assert_eq!(owned.meta(), Metadata::default());
    assert!(!owned.is_empty());

    // Write the metadata and every element through a mutable reborrow.
    {
        let mut view = owned.reborrow_mut();
        assert_eq!(view.len(), LEN);
        view.set_meta(updated);

        let mut packed = view.vector_mut();
        assert_eq!(packed.len(), LEN);
        for index in 0..packed.len() {
            packed.set(index, index as i64).unwrap();
        }
    }

    // The writes must be visible through the owner...
    assert_eq!(owned.meta(), updated);
    assert_eq!(owned.len(), LEN);
    let packed = owned.vector();
    for index in 0..packed.len() {
        assert_eq!(packed.get(index).unwrap(), index as i64);
    }

    // ...and through an immutable reborrow.
    {
        let view = owned.reborrow();
        assert_eq!(view.len(), LEN);
        assert_eq!(view.meta(), updated);

        let packed = view.vector();
        for index in 0..packed.len() {
            assert_eq!(packed.get(index).unwrap(), index as i64);
        }
    }
}
/// Builds a `VectorMut` over externally owned storage and metadata, then checks that writes
/// made through the view land in the originals.
#[test]
fn test_compensated_mut() {
    let len = 30;
    let mut storage = BoxedBitSlice::<7, Unsigned>::new_boxed(len);
    let mut metadata = Metadata::default();

    let mut view = VectorMut::new(storage.reborrow_mut(), &mut metadata);
    assert_eq!(view.len(), len);
    view.set_meta(Metadata::new(200, 5));
    for index in 0..view.len() {
        view.vector_mut().set(index, index as i64).unwrap();
    }

    // The view is no longer used, so the backing values can be inspected directly.
    assert_eq!(metadata, Metadata::new(200, 5));
    for index in 0..len {
        assert_eq!(storage.get(index).unwrap(), index as i64);
    }
}
/// Shorthand for an immutable `Unsigned` vector view carrying the test `Metadata`.
type TestVectorRef<'a, const NBITS: usize> = VectorRef<'a, NBITS, Unsigned, Metadata>;
/// Shorthand for a mutable `Unsigned` vector view carrying the test `Metadata`.
type TestVectorMut<'a, const NBITS: usize> = VectorMut<'a, NBITS, Unsigned, Metadata>;
fn check_canonicalization<const NBITS: usize, R>(dim: usize, ntrials: usize, rng: &mut R)
where
Unsigned: Representation<NBITS>,
R: Rng,
{
let bytes = TestVectorRef::<NBITS>::canonical_bytes(dim);
assert_eq!(
bytes,
std::mem::size_of::<Metadata>() + BitSlice::<NBITS, Unsigned>::bytes_for(dim)
);
let mut buffer_front = vec![u8::default(); bytes + std::mem::size_of::<Metadata>() + 1];
let mut buffer_back = vec![u8::default(); bytes + std::mem::size_of::<Metadata>() + 1];
let mut expected = vec![i64::default(); dim];
let uniform = Uniform::try_from(Unsigned::domain_const::<NBITS>()).unwrap();
for _ in 0..ntrials {
let offset = Uniform::new(0, std::mem::size_of::<Metadata>())
.unwrap()
.sample(rng);
let a: u32 = StandardUniform.sample(rng);
let b: u32 = StandardUniform.sample(rng);
expected.iter_mut().for_each(|i| *i = uniform.sample(rng));
{
let set = |mut cv: TestVectorMut<NBITS>| {
cv.set_meta(Metadata::new(a, b));
let mut vector = cv.vector_mut();
for (i, e) in expected.iter().enumerate() {
vector.set(i, *e).unwrap();
}
};
let cv = TestVectorMut::<NBITS>::from_canonical_front_mut(
&mut buffer_front[offset..offset + bytes],
dim,
)
.unwrap();
set(cv);
let cv = TestVectorMut::<NBITS>::from_canonical_back_mut(
&mut buffer_back[offset..offset + bytes],
dim,
)
.unwrap();
set(cv);
}
{
let check = |cv: TestVectorRef<NBITS>| {
assert_eq!(cv.meta(), Metadata::new(a, b));
let vector = cv.vector();
for (i, e) in expected.iter().enumerate() {
assert_eq!(vector.get(i).unwrap(), *e);
}
};
let cv = TestVectorRef::<NBITS>::from_canonical_front(
&buffer_front[offset..offset + bytes],
dim,
)
.unwrap();
check(cv);
let cv = TestVectorRef::<NBITS>::from_canonical_back(
&buffer_back[offset..offset + bytes],
dim,
)
.unwrap();
check(cv);
}
}
{
let err = TestVectorMut::<NBITS>::from_canonical_front_mut(
&mut buffer_front[..bytes - 1],
dim,
)
.unwrap_err();
assert!(matches!(err, NotCanonical::WrongLength(_, _)));
let err =
TestVectorMut::<NBITS>::from_canonical_back_mut(&mut buffer_back[..bytes - 1], dim)
.unwrap_err();
assert!(matches!(err, NotCanonical::WrongLength(_, _)));
let err = TestVectorMut::<NBITS>::from_canonical_front_mut(&mut [], dim).unwrap_err();
assert!(matches!(err, NotCanonical::WrongLength(_, _)));
let err = TestVectorMut::<NBITS>::from_canonical_back_mut(&mut [], dim).unwrap_err();
assert!(matches!(err, NotCanonical::WrongLength(_, _)));
let err = TestVectorMut::<NBITS>::from_canonical_front_mut(
&mut buffer_front[..bytes + 1],
dim,
)
.unwrap_err();
assert!(matches!(err, NotCanonical::WrongLength(_, _)));
let err =
TestVectorMut::<NBITS>::from_canonical_back_mut(&mut buffer_back[..bytes + 1], dim)
.unwrap_err();
assert!(matches!(err, NotCanonical::WrongLength(_, _)));
}
{
let err = TestVectorRef::<NBITS>::from_canonical_front(&buffer_front[..bytes - 1], dim)
.unwrap_err();
assert!(matches!(err, NotCanonical::WrongLength(_, _)));
let err = TestVectorRef::<NBITS>::from_canonical_back(&buffer_back[..bytes - 1], dim)
.unwrap_err();
assert!(matches!(err, NotCanonical::WrongLength(_, _)));
let err = TestVectorRef::<NBITS>::from_canonical_front(&[], dim).unwrap_err();
assert!(matches!(err, NotCanonical::WrongLength(_, _)));
let err = TestVectorRef::<NBITS>::from_canonical_back(&[], dim).unwrap_err();
assert!(matches!(err, NotCanonical::WrongLength(_, _)));
let err = TestVectorRef::<NBITS>::from_canonical_front(&buffer_front[..bytes + 1], dim)
.unwrap_err();
assert!(matches!(err, NotCanonical::WrongLength(_, _)));
let err = TestVectorRef::<NBITS>::from_canonical_back(&buffer_back[..bytes + 1], dim)
.unwrap_err();
assert!(matches!(err, NotCanonical::WrongLength(_, _)));
}
}
fn check_canonicalization_zst<const NBITS: usize, R>(dim: usize, ntrials: usize, rng: &mut R)
where
Unsigned: Representation<NBITS>,
R: Rng,
{
let bytes = VectorRef::<NBITS, Unsigned, ()>::canonical_bytes(dim);
assert_eq!(bytes, BitSlice::<NBITS, Unsigned>::bytes_for(dim));
let max_offset = 10;
let mut buffer_front = vec![u8::default(); bytes + max_offset];
let mut buffer_back = vec![u8::default(); bytes + max_offset];
let mut expected = vec![i64::default(); dim];
let uniform = Uniform::try_from(Unsigned::domain_const::<NBITS>()).unwrap();
for _ in 0..ntrials {
let offset = Uniform::new(0, max_offset).unwrap().sample(rng);
expected.iter_mut().for_each(|i| *i = uniform.sample(rng));
{
let set = |mut cv: VectorMut<NBITS, Unsigned, ()>| {
cv.set_meta(());
let mut vector = cv.vector_mut();
for (i, e) in expected.iter().enumerate() {
vector.set(i, *e).unwrap();
}
};
let cv = VectorMut::<NBITS, Unsigned, ()>::from_canonical_front_mut(
&mut buffer_front[offset..offset + bytes],
dim,
)
.unwrap();
set(cv);
let cv = VectorMut::<NBITS, Unsigned, ()>::from_canonical_back_mut(
&mut buffer_back[offset..offset + bytes],
dim,
)
.unwrap();
set(cv);
}
{
let check = |cv: VectorRef<NBITS, Unsigned, ()>| {
let vector = cv.vector();
for (i, e) in expected.iter().enumerate() {
assert_eq!(vector.get(i).unwrap(), *e);
}
};
let cv = VectorRef::<NBITS, Unsigned, ()>::from_canonical_front(
&buffer_front[offset..offset + bytes],
dim,
)
.unwrap();
check(cv);
let cv = VectorRef::<NBITS, Unsigned, ()>::from_canonical_back(
&buffer_back[offset..offset + bytes],
dim,
)
.unwrap();
check(cv);
}
}
{
if dim >= 1 {
let err = VectorMut::<NBITS, Unsigned, ()>::from_canonical_front_mut(
&mut buffer_front[..bytes - 1],
dim,
)
.unwrap_err();
assert!(matches!(err, NotCanonical::WrongLength(_, _)));
let err = VectorMut::<NBITS, Unsigned, ()>::from_canonical_back_mut(
&mut buffer_back[..bytes - 1],
dim,
)
.unwrap_err();
assert!(matches!(err, NotCanonical::WrongLength(_, _)));
}
if dim >= 1 {
let err = VectorMut::<NBITS, Unsigned, ()>::from_canonical_front_mut(&mut [], dim)
.unwrap_err();
assert!(matches!(err, NotCanonical::WrongLength(_, _)));
let err = VectorMut::<NBITS, Unsigned, ()>::from_canonical_back_mut(&mut [], dim)
.unwrap_err();
assert!(matches!(err, NotCanonical::WrongLength(_, _)));
}
{
let err = VectorMut::<NBITS, Unsigned, ()>::from_canonical_front_mut(
&mut buffer_front[..bytes + 1],
dim,
)
.unwrap_err();
assert!(matches!(err, NotCanonical::WrongLength(_, _)));
let err = VectorMut::<NBITS, Unsigned, ()>::from_canonical_back_mut(
&mut buffer_back[..bytes + 1],
dim,
)
.unwrap_err();
assert!(matches!(err, NotCanonical::WrongLength(_, _)));
}
}
{
if dim >= 1 {
let err = VectorRef::<NBITS, Unsigned, ()>::from_canonical_front(
&buffer_front[..bytes - 1],
dim,
)
.unwrap_err();
assert!(matches!(err, NotCanonical::WrongLength(_, _)));
let err = VectorRef::<NBITS, Unsigned, ()>::from_canonical_back(
&buffer_back[..bytes - 1],
dim,
)
.unwrap_err();
assert!(matches!(err, NotCanonical::WrongLength(_, _)));
}
let err = VectorMut::<NBITS, Unsigned, ()>::from_canonical_front_mut(
&mut buffer_front[..bytes + 1],
dim,
)
.unwrap_err();
assert!(matches!(err, NotCanonical::WrongLength(_, _)));
let err = VectorMut::<NBITS, Unsigned, ()>::from_canonical_back_mut(
&mut buffer_back[..bytes + 1],
dim,
)
.unwrap_err();
assert!(matches!(err, NotCanonical::WrongLength(_, _)));
}
{
if dim >= 1 {
let err =
VectorRef::<NBITS, Unsigned, ()>::from_canonical_front(&[], dim).unwrap_err();
assert!(matches!(err, NotCanonical::WrongLength(_, _)));
let err =
VectorRef::<NBITS, Unsigned, ()>::from_canonical_back(&[], dim).unwrap_err();
assert!(matches!(err, NotCanonical::WrongLength(_, _)));
}
{
let err = VectorRef::<NBITS, Unsigned, ()>::from_canonical_front(
&buffer_front[..bytes + 1],
dim,
)
.unwrap_err();
assert!(matches!(err, NotCanonical::WrongLength(_, _)));
let err = VectorRef::<NBITS, Unsigned, ()>::from_canonical_back(
&buffer_back[..bytes + 1],
dim,
)
.unwrap_err();
assert!(matches!(err, NotCanonical::WrongLength(_, _)));
}
}
}
// Sizing of the canonicalization tests: `MAX_DIM` is the exclusive upper bound on the tested
// dimensions and `TRIALS_PER_DIM` the number of randomized rounds per dimension.
cfg_if::cfg_if! {
if #[cfg(miri)] {
// Reduced workload when running under Miri (presumably because interpretation is slow).
const MAX_DIM: usize = 37;
const TRIALS_PER_DIM: usize = 1;
} else {
const MAX_DIM: usize = 256;
const TRIALS_PER_DIM: usize = 20;
}
}
// Generates a `#[test]` function called `$name` that seeds a `StdRng` with `$seed` and, for
// every `dim` in `0..MAX_DIM`, runs `check_canonicalization` and `check_canonicalization_zst`
// with `$nbits` as the bit-width parameter and `TRIALS_PER_DIM` rounds each.
macro_rules! test_canonical {
($name:ident, $nbits:literal, $seed:literal) => {
#[test]
fn $name() {
let mut rng = StdRng::seed_from_u64($seed);
for dim in 0..MAX_DIM {
check_canonicalization::<$nbits, _>(dim, TRIALS_PER_DIM, &mut rng);
check_canonicalization_zst::<$nbits, _>(dim, TRIALS_PER_DIM, &mut rng);
}
}
};
}
// One test per bit-width from 8 down to 1, each with its own fixed seed so runs are
// reproducible.
test_canonical!(canonical_8bit, 8, 0xe64518a00ee99e2f);
test_canonical!(canonical_7bit, 7, 0x3907123f8c38def2);
test_canonical!(canonical_6bit, 6, 0xeccaeb83965ff6a1);
test_canonical!(canonical_5bit, 5, 0x9691fe59e49bfb96);
test_canonical!(canonical_4bit, 4, 0xc4d3e9bc699a7e6f);
test_canonical!(canonical_3bit, 3, 0x8a01b2ccdca8fb2b);
test_canonical!(canonical_2bit, 2, 0x3a07429e8184b67f);
test_canonical!(canonical_1bit, 1, 0x93fddb26059c115c);
}