use std::{
ops::{Deref, DerefMut},
ptr::NonNull,
};
use diskann_utils::{Reborrow, ReborrowMut};
use thiserror::Error;
use crate::{
alloc::{AllocatorCore, AllocatorError, Poly},
num::PowerOfTwo,
ownership::{Mut, Owned, Ref},
};
/// Element storage `T` paired with out-of-band metadata `M`.
///
/// `T` is typically a slice-like handle (`&[U]`, `&mut [U]`, or an owned
/// `Poly<[U], A>`) and `M` an ownership wrapper (`Owned`, `Ref`, `Mut`)
/// around the metadata value — see the `SliceRef`/`SliceMut`/`PolySlice`
/// aliases below.
#[derive(Debug, Clone, Copy)]
pub struct Slice<T, M> {
    // The element storage handle.
    slice: T,
    // The metadata, behind an ownership marker.
    meta: M,
}
/// Returns the stricter of the alignments of `T` and `M`.
///
/// This is the alignment a canonical byte buffer must satisfy so that both
/// the metadata header (`M`, at offset 0) and the element payload (`T`)
/// are properly aligned.
const fn canonical_align<T, M>() -> PowerOfTwo {
    let align_t = PowerOfTwo::alignment_of::<T>();
    let align_m = PowerOfTwo::alignment_of::<M>();
    if align_t.raw() < align_m.raw() {
        align_m
    } else {
        align_t
    }
}
/// Number of bytes reserved at the front of a canonical buffer for the
/// metadata header `M`, padded so that the `T` payload that follows starts
/// on a `T`-aligned boundary.
///
/// A zero-sized `M` reserves no space at all.
const fn canonical_metadata_bytes<T, M>() -> usize {
    match std::mem::size_of::<M>() {
        0 => 0,
        size => size.next_multiple_of(std::mem::align_of::<T>()),
    }
}
/// Total byte size of a canonical buffer: the padded metadata header for
/// `M` followed by `count` contiguous elements of `T`.
const fn canonical_bytes<T, M>(count: usize) -> usize {
    let payload = std::mem::size_of::<T>() * count;
    canonical_metadata_bytes::<T, M>() + payload
}
impl<T, M> Slice<T, M> {
    /// Builds a `Slice` from element storage and anything convertible into
    /// the metadata wrapper `M`.
    pub fn new<U>(slice: T, meta: U) -> Self
    where
        U: Into<M>,
    {
        let meta = meta.into();
        Self { slice, meta }
    }

    /// Shared access to the metadata value behind the wrapper.
    pub fn meta(&self) -> &M::Target
    where
        M: Deref,
    {
        self.meta.deref()
    }

    /// Exclusive access to the metadata value behind the wrapper.
    pub fn meta_mut(&mut self) -> &mut M::Target
    where
        M: DerefMut,
    {
        self.meta.deref_mut()
    }
}
impl<T, M, U, V> Slice<T, M>
where
    T: Deref<Target = [U]>,
    M: Deref<Target = V>,
{
    /// Number of elements in the slice.
    pub fn len(&self) -> usize {
        self.vector().len()
    }

    /// `true` when the slice holds no elements.
    pub fn is_empty(&self) -> bool {
        self.vector().is_empty()
    }

    /// Shared view of the element data.
    pub fn vector(&self) -> &[U] {
        self.slice.deref()
    }

    /// Exclusive view of the element data.
    pub fn vector_mut(&mut self) -> &mut [U]
    where
        T: DerefMut,
    {
        self.slice.deref_mut()
    }

    /// Alignment a canonical byte buffer must have for element type `U`
    /// and metadata type `V`.
    pub const fn canonical_align() -> PowerOfTwo {
        canonical_align::<U, V>()
    }

    /// Byte length of a canonical buffer holding `count` elements of `U`
    /// behind a `V` metadata header.
    pub const fn canonical_bytes(count: usize) -> usize {
        canonical_bytes::<U, V>(count)
    }
}
impl<T, A, M> Slice<Poly<[T], A>, Owned<M>>
where
    A: AllocatorCore,
    T: Default,
    M: Default,
{
    /// Allocates an owned slice of `len` default-initialized elements in
    /// `allocator`, with default-initialized metadata.
    ///
    /// # Errors
    ///
    /// Returns [`AllocatorError`] when the underlying allocation fails.
    pub fn new_in(len: usize, allocator: A) -> Result<Self, AllocatorError> {
        let slice = Poly::from_iter((0..len).map(|_| T::default()), allocator)?;
        let meta = Owned::default();
        Ok(Self { slice, meta })
    }
}
/// A `Slice` borrowing both its elements and its metadata immutably.
pub type SliceRef<'a, T, M> = Slice<&'a [T], Ref<'a, M>>;
/// A `Slice` borrowing both its elements and its metadata mutably.
pub type SliceMut<'a, T, M> = Slice<&'a mut [T], Mut<'a, M>>;
/// A `Slice` owning its elements (allocated in `A`) and its metadata.
pub type PolySlice<T, M, A> = Slice<Poly<[T], A>, Owned<M>>;
impl<'a, T, A, M> Reborrow<'a> for Slice<Poly<[T], A>, Owned<M>>
where
    A: AllocatorCore,
    M: 'static,
{
    type Target = SliceRef<'a, T, M>;

    /// Reborrows an owned slice as a cheap immutable view over the same
    /// elements and metadata.
    fn reborrow(&'a self) -> Self::Target {
        let slice: &[T] = &self.slice;
        let meta = Ref::from(&self.meta.0);
        Slice { slice, meta }
    }
}
impl<'a, T, A, M> ReborrowMut<'a> for Slice<Poly<[T], A>, Owned<M>>
where
    A: AllocatorCore,
    M: 'static,
{
    type Target = SliceMut<'a, T, M>;

    /// Reborrows an owned slice as a cheap mutable view over the same
    /// elements and metadata.
    fn reborrow_mut(&'a mut self) -> Self::Target {
        let slice: &mut [T] = &mut self.slice;
        let meta = Mut::from(&mut self.meta.0);
        Slice { slice, meta }
    }
}
/// Reasons a raw byte buffer fails to satisfy the canonical slice layout.
#[derive(Debug, Error, PartialEq, Clone, Copy)]
pub enum NotCanonical {
    /// The buffer length differs from the canonical byte count
    /// (`expected`, `actual`).
    #[error("expected a slice length of {0} bytes but instead got {1} bytes")]
    WrongLength(usize, usize),
    /// The buffer's base pointer does not meet the required alignment.
    #[error("expected a base pointer alignment of at least {0}")]
    NotAligned(usize),
}
impl<'a, T, M> SliceRef<'a, T, M>
where
    T: bytemuck::Pod,
    M: bytemuck::Pod,
{
    /// Reinterprets a canonical byte buffer — an `M` header padded up to
    /// `T`'s alignment, followed by `dim` elements of `T` — as a
    /// `SliceRef`, validating alignment and length first.
    ///
    /// # Errors
    ///
    /// * [`NotCanonical::NotAligned`] if `data`'s base pointer is not a
    ///   multiple of `Self::canonical_align()`.
    /// * [`NotCanonical::WrongLength`] if `data.len()` differs from
    ///   `Self::canonical_bytes(dim)`.
    pub fn from_canonical(data: &'a [u8], dim: usize) -> Result<Self, NotCanonical> {
        let expected_align = Self::canonical_align().raw();
        let expected_len = Self::canonical_bytes(dim);
        if !(data.as_ptr() as usize).is_multiple_of(expected_align) {
            Err(NotCanonical::NotAligned(expected_align))
        } else if data.len() != expected_len {
            Err(NotCanonical::WrongLength(expected_len, data.len()))
        } else {
            // SAFETY: alignment and length were just validated above.
            Ok(unsafe { Self::from_canonical_unchecked(data, dim) })
        }
    }

    /// Like [`Self::from_canonical`] but skips the alignment and length
    /// checks (still asserted in debug builds).
    ///
    /// # Safety
    ///
    /// `data` must be aligned to `Self::canonical_align()` and be exactly
    /// `Self::canonical_bytes(dim)` bytes long. Since both `T` and `M` are
    /// `bytemuck::Pod`, any bit pattern in the buffer is a valid value.
    pub unsafe fn from_canonical_unchecked(data: &'a [u8], dim: usize) -> Self {
        debug_assert_eq!(data.len(), Self::canonical_bytes(dim));
        // Elements start after the metadata header, padded to `T`'s alignment.
        let offset = canonical_metadata_bytes::<T, M>();
        // SAFETY: per the caller's contract the buffer holds `dim` properly
        // aligned `T`s starting at `offset`, all within `data`'s bounds.
        let slice =
            unsafe { std::slice::from_raw_parts(data.as_ptr().add(offset).cast::<T>(), dim) };
        // SAFETY: the base pointer of a valid slice is non-null, and the
        // caller guarantees it is aligned for `M`. The `cast_mut` only
        // satisfies `NonNull`'s signature — NOTE(review): this assumes `Ref`
        // never writes through the pointer; verify against `ownership::Ref`.
        let meta =
            unsafe { Ref::new(NonNull::new_unchecked(data.as_ptr().cast_mut()).cast::<M>()) };
        Self { slice, meta }
    }
}
impl<'a, T, M> SliceMut<'a, T, M>
where
    T: bytemuck::Pod,
    M: bytemuck::Pod,
{
    /// Mutable counterpart of [`SliceRef::from_canonical`]: reinterprets a
    /// canonical byte buffer as a `SliceMut`, validating alignment and
    /// length first.
    ///
    /// # Errors
    ///
    /// * [`NotCanonical::NotAligned`] if `data`'s base pointer is not a
    ///   multiple of `Self::canonical_align()`.
    /// * [`NotCanonical::WrongLength`] if `data.len()` differs from
    ///   `Self::canonical_bytes(dim)`.
    pub fn from_canonical_mut(data: &'a mut [u8], dim: usize) -> Result<Self, NotCanonical> {
        let expected_align = Self::canonical_align().raw();
        let expected_len = Self::canonical_bytes(dim);
        if !(data.as_ptr() as usize).is_multiple_of(expected_align) {
            return Err(NotCanonical::NotAligned(expected_align));
        } else if data.len() != expected_len {
            return Err(NotCanonical::WrongLength(expected_len, data.len()));
        }
        // Split the buffer into the metadata header and the element payload.
        let offset = canonical_metadata_bytes::<T, M>();
        // SAFETY: `offset <= expected_len == data.len()` by construction of
        // `canonical_bytes`, so the split point is in bounds.
        let (meta, slice) = unsafe { data.split_at_mut_unchecked(offset) };
        // SAFETY: the payload starts `offset` (a multiple of `T`'s alignment)
        // bytes past an adequately aligned base, and the validated length
        // guarantees room for exactly `dim` elements; `T: Pod` accepts any
        // bit pattern.
        let slice = unsafe { std::slice::from_raw_parts_mut(slice.as_mut_ptr().cast::<T>(), dim) };
        // SAFETY: `meta` is the non-null, `M`-aligned front of the validated
        // buffer; `M: Pod` accepts any bit pattern.
        let meta = unsafe { Mut::new(NonNull::new_unchecked(meta.as_mut_ptr()).cast::<M>()) };
        Ok(Self { slice, meta })
    }
}
#[cfg(test)]
mod tests {
    use std::fmt::Debug;
    use rand::{
        SeedableRng,
        distr::{Distribution, Uniform},
        rngs::StdRng,
    };
    use super::*;
    use crate::{
        alloc::{AlignedAllocator, GlobalAllocator},
        num::PowerOfTwo,
    };

    /// Example POD metadata used to exercise read/write round-trips.
    #[derive(Default, Debug, Clone, Copy, PartialEq, bytemuck::Zeroable, bytemuck::Pod)]
    #[repr(C)]
    struct Metadata {
        a: u32,
        b: u32,
    }

    impl Metadata {
        fn new(a: u32, b: u32) -> Metadata {
            Self { a, b }
        }
    }

    // Smoke test for the accessors on an owned `PolySlice`.
    #[test]
    fn test_vector() {
        let len = 20;
        let mut base = PolySlice::<f32, Metadata, _>::new_in(len, GlobalAllocator).unwrap();
        assert_eq!(base.len(), len);
        assert_eq!(*base.meta(), Metadata::default());
        assert!(!base.is_empty());
        {
            // Mutate both the metadata and the element payload.
            *base.meta_mut() = Metadata::new(1, 2);
            let v = base.vector_mut();
            assert_eq!(v.len(), len);
            v.iter_mut().enumerate().for_each(|(i, v)| *v = i as f32);
        }
        {
            // Both mutations must be visible through the shared accessors.
            let expected_metadata = Metadata::new(1, 2);
            assert_eq!(*base.meta(), expected_metadata);
            assert_eq!(base.len(), len);
            let v = base.vector();
            v.iter().enumerate().for_each(|(i, v)| {
                assert_eq!(*v, i as f32);
            })
        }
    }

    /// Zero-sized POD type: the canonical layout must reserve no bytes for it.
    #[derive(Debug, Clone, Copy, PartialEq, bytemuck::Zeroable, bytemuck::Pod)]
    #[repr(C)]
    struct Zst;

    // `TryFrom<usize>` lets `check_canonicalization` fill values generically.
    #[expect(clippy::infallible_try_from)]
    impl TryFrom<usize> for Zst {
        type Error = std::convert::Infallible;
        fn try_from(_: usize) -> Result<Self, Self::Error> {
            Ok(Self)
        }
    }

    /// Zero-sized POD type with inflated alignment to exercise header padding
    /// and the alignment checks.
    #[derive(Debug, Clone, Copy, PartialEq, bytemuck::Zeroable, bytemuck::Pod)]
    #[repr(C, align(16))]
    struct ZstAligned;

    #[expect(clippy::infallible_try_from)]
    impl TryFrom<usize> for ZstAligned {
        type Error = std::convert::Infallible;
        fn try_from(_: usize) -> Result<Self, Self::Error> {
            Ok(Self)
        }
    }

    /// Round-trips random data through `from_canonical_mut`/`from_canonical`
    /// for element type `T` and metadata type `M`, then checks the length and
    /// alignment error paths.
    ///
    /// `align`, `slope`, and `offset` encode the expected canonical layout:
    /// the buffer must be `align`-aligned and `slope * dim + offset` bytes.
    fn check_canonicalization<T, M>(
        dim: usize,
        align: usize,
        slope: usize,
        offset: usize,
        ntrials: usize,
        rng: &mut StdRng,
    ) where
        T: bytemuck::Pod + TryFrom<usize, Error: Debug> + Debug + PartialEq,
        M: bytemuck::Pod + TryFrom<usize, Error: Debug> + Debug + PartialEq,
    {
        let bytes = SliceRef::<T, M>::canonical_bytes(dim);
        assert_eq!(
            bytes,
            slope * dim + offset,
            "computed bytes did not match the expected formula"
        );
        let expected_align = std::mem::align_of::<T>().max(std::mem::align_of::<M>());
        assert_eq!(SliceRef::<T, M>::canonical_align().raw(), align);
        assert_eq!(SliceRef::<T, M>::canonical_align().raw(), expected_align);
        // Over-allocate by one alignment unit so the misalignment loop below
        // can slide the window without running past the end of the buffer.
        let mut buffer = Poly::broadcast(
            0u8,
            bytes + expected_align,
            AlignedAllocator::new(PowerOfTwo::new(expected_align).unwrap()),
        )
        .unwrap();
        let mut expected = vec![usize::default(); dim];
        let dist = Uniform::new(0, 255).unwrap();
        for _ in 0..ntrials {
            let m: usize = dist.sample(rng);
            expected.iter_mut().for_each(|i| *i = dist.sample(rng));
            {
                // Write metadata and elements through the mutable view.
                let mut v =
                    SliceMut::<T, M>::from_canonical_mut(&mut buffer[..bytes], dim).unwrap();
                *v.meta_mut() = m.try_into().unwrap();
                assert_eq!(v.vector().len(), dim);
                assert_eq!(v.vector_mut().len(), dim);
                std::iter::zip(v.vector_mut().iter_mut(), expected.iter_mut()).for_each(
                    |(v, e)| {
                        *v = (*e).try_into().unwrap();
                    },
                );
            }
            {
                // Read everything back through the shared view.
                let v = SliceRef::<T, M>::from_canonical(&buffer[..bytes], dim).unwrap();
                assert_eq!(*v.meta(), m.try_into().unwrap());
                assert_eq!(v.vector().len(), dim);
                std::iter::zip(v.vector().iter(), expected.iter()).for_each(|(v, e)| {
                    assert_eq!(*v, (*e).try_into().unwrap());
                });
            }
        }
        {
            // Any length other than the canonical one must be rejected —
            // every shorter prefix, and one byte too long.
            for len in 0..bytes {
                let err =
                    SliceMut::<T, M>::from_canonical_mut(&mut buffer[..len], dim).unwrap_err();
                assert!(matches!(err, NotCanonical::WrongLength(_, _)));
                let err = SliceRef::<T, M>::from_canonical(&buffer[..len], dim).unwrap_err();
                assert!(matches!(err, NotCanonical::WrongLength(_, _)));
            }
            let err =
                SliceMut::<T, M>::from_canonical_mut(&mut buffer[..bytes + 1], dim).unwrap_err();
            assert!(matches!(err, NotCanonical::WrongLength(_, _)));
            let err = SliceRef::<T, M>::from_canonical(&buffer[..bytes + 1], dim).unwrap_err();
            assert!(matches!(err, NotCanonical::WrongLength(_, _)));
        }
        {
            // Every base pointer short of full alignment must be rejected;
            // the buffer base is `expected_align`-aligned, so base + offset
            // (1 <= offset < expected_align) is guaranteed misaligned.
            for offset in 1..expected_align {
                let err =
                    SliceMut::<T, M>::from_canonical_mut(&mut buffer[offset..offset + bytes], dim)
                        .unwrap_err();
                assert!(matches!(err, NotCanonical::NotAligned(_)));
                let err = SliceRef::<T, M>::from_canonical(&buffer[offset..offset + bytes], dim)
                    .unwrap_err();
                assert!(matches!(err, NotCanonical::NotAligned(_)));
            }
        }
    }

    // Keep the exhaustive dimension sweep cheap under Miri.
    cfg_if::cfg_if! {
        if #[cfg(miri)] {
            const MAX_DIM: usize = 10;
            const TRIALS_PER_DIM: usize = 1;
        } else {
            const MAX_DIM: usize = 256;
            const TRIALS_PER_DIM: usize = 20;
        }
    }

    /// Instantiates a canonicalization test for a `(metadata, element)` pair
    /// with the expected alignment and byte formula `slope * dim + offset`.
    /// Note the argument order: metadata type `$M` comes before element `$T`.
    macro_rules! test_canonical {
        ($name:ident, $M:ty, $T:ty, $align:literal, $slope:literal, $offset:literal, $seed:literal) => {
            #[test]
            fn $name() {
                let mut rng = StdRng::seed_from_u64($seed);
                for dim in 0..MAX_DIM {
                    check_canonicalization::<$T, $M>(
                        dim,
                        $align,
                        $slope,
                        $offset,
                        TRIALS_PER_DIM,
                        &mut rng,
                    );
                }
            }
        };
    }

    test_canonical!(canonical_u8_u32, u8, u32, 4, 4, 4, 0x60884b7a4ca28f49);
    test_canonical!(canonical_u32_u8, u32, u8, 4, 1, 4, 0x874aa5d8f40ec5ef);
    test_canonical!(canonical_u32_u32, u32, u32, 4, 4, 4, 0x516c550e7be19acc);
    test_canonical!(canonical_zst_u32, Zst, u32, 4, 4, 0, 0x908682ebda7c0fb9);
    test_canonical!(canonical_u32_zst, u32, Zst, 4, 0, 4, 0xf223385881819c1c);
    test_canonical!(
        canonical_zstaligned_u32,
        ZstAligned,
        u32,
        16,
        4,
        0,
        0x1811ee0fd078a173
    );
    test_canonical!(
        canonical_u32_zstaligned,
        u32,
        ZstAligned,
        16,
        0,
        16,
        0x6c9a67b09c0b6c0f
    );
}