use diskann_quantization::{
alloc::{Allocator, AllocatorError, BumpAllocator, CompoundError, Poly},
num::PowerOfTwo,
poly,
};
/// Minimal object-safe trait used to exercise `Poly<dyn Trait, A>` coercion
/// and dynamic dispatch in the tests of this file.
trait DoTheThing {
    /// Returns an owned string chosen by the implementor.
    fn do_the_thing(&self) -> String;
}
/// Selects which `Transform` variant `Quantizer::new` constructs.
#[derive(Clone, Copy, Debug)]
enum TransformKind {
    /// Build `Transform::Hadamard`, allocating one `u32` sign per dimension.
    Hadamard,
    /// Build `Transform::Null`, which owns no allocation.
    Null,
}
/// Allocator-backed transform state held by a `Quantizer`.
enum Transform<A: Allocator> {
    /// Sign storage; `Quantizer::new` allocates one zeroed `u32` per dimension.
    Hadamard { _signs: Poly<[u32], A> },
    /// No transform; holds no allocation.
    Null,
}
/// Construction parameters for `Quantizer::new`: the dimensionality plus the
/// transform to build. `Copy`, so tests can reuse it after passing it by value.
#[derive(Clone, Copy, Debug)]
struct DimKind {
    /// Element count of the centroid (and of the Hadamard sign buffer).
    dim: usize,
    /// Which transform variant to construct.
    kind: TransformKind,
}
/// Test quantizer whose centroid and optional transform storage are
/// allocated through the allocator type `A`.
struct Quantizer<A: Allocator> {
    /// Per-dimension centroid; `Quantizer::new` fills it with zeros.
    centroid: Poly<[f32], A>,
    /// Scale factor; set to `0.0` by `Quantizer::new` and unused in this file.
    _scale: f32,
    /// Transform state selected by `DimKind::kind`.
    transform: Transform<A>,
}
impl<A> Quantizer<A>
where
A: Allocator + Clone,
{
fn new(
dim_kind: DimKind,
allocator: A,
) -> Result<Poly<Self, A>, CompoundError<AllocatorError>> {
Poly::new_with(
|allocator| {
let centroid = Poly::from_iter((0..dim_kind.dim).map(|_| 0.0), allocator.clone())?;
let transform = match dim_kind.kind {
TransformKind::Hadamard => Transform::Hadamard {
_signs: Poly::from_iter((0..dim_kind.dim).map(|_| 0), allocator.clone())?,
},
TransformKind::Null => Transform::Null,
};
Ok(Self {
centroid,
_scale: 0.0,
transform,
})
},
allocator,
)
}
}
impl<A: Allocator> DoTheThing for Quantizer<A> {
    /// Returns the fixed marker string `"foo"` that the tests compare against.
    fn do_the_thing(&self) -> String {
        String::from("foo")
    }
}
/// Null transform: the quantizer sits at the start of the bump arena, and its
/// centroid follows immediately after it, rounded up to `f32` alignment.
#[test]
fn miri_q1_no_transform() {
    let allocator = BumpAllocator::new(4096, PowerOfTwo::new(4096).unwrap()).unwrap();
    let quantizer = Quantizer::new(
        DimKind {
            dim: 128,
            kind: TransformKind::Null,
        },
        allocator.clone(),
    )
    .unwrap();

    let arena_start = allocator.as_ptr();
    let centroid_offset = std::mem::size_of::<Quantizer<BumpAllocator>>()
        .next_multiple_of(std::mem::align_of::<f32>());

    assert_eq!(Poly::as_ptr(&quantizer).cast::<u8>(), arena_start);
    assert_eq!(
        quantizer.centroid.as_ptr().cast::<u8>(),
        arena_start.wrapping_add(centroid_offset)
    );
    assert!(matches!(quantizer.transform, Transform::Null));
}
/// Hadamard transform: the quantizer/centroid layout is the same as in the
/// null-transform case. The transform must also really be
/// `Transform::Hadamard`, with one sign per dimension.
#[test]
fn miri_q1_transform() {
    let dim_kind = DimKind {
        dim: 128,
        kind: TransformKind::Hadamard,
    };
    let allocator = BumpAllocator::new(4096, PowerOfTwo::new(4096).unwrap()).unwrap();
    let object = Quantizer::new(dim_kind, allocator.clone()).unwrap();
    let base = allocator.as_ptr();
    assert_eq!(Poly::as_ptr(&object).cast::<u8>(), base);
    assert_eq!(
        object.centroid.as_ptr().cast::<u8>(),
        base.wrapping_add(
            std::mem::size_of::<Quantizer<BumpAllocator>>()
                .next_multiple_of(std::mem::align_of::<f32>()),
        )
    );
    assert_eq!(object.centroid.len(), dim_kind.dim);

    // Check the variant explicitly. Otherwise a regression that ignored
    // `TransformKind::Hadamard` (e.g. always building `Transform::Null`)
    // would still pass the layout assertions above.
    let Transform::Hadamard { _signs: signs } = &object.transform else {
        panic!("expected Transform::Hadamard for TransformKind::Hadamard");
    };
    assert_eq!(signs.len(), dim_kind.dim);
}
/// Nests a `Poly<dyn DoTheThing>` (built from a `Quantizer`) inside an outer
/// `Poly` from the same bump allocator. Checks that the outer `Poly` is the
/// first allocation in the arena. Checks that dynamic dispatch works both via
/// the safe `Deref` path and via the raw pointer returned by `Poly::as_ptr`.
#[test]
fn miri_trait_object_as_base() {
    let dim_kind = DimKind {
        dim: 128,
        kind: TransformKind::Hadamard,
    };
    let allocator = BumpAllocator::new(4096, PowerOfTwo::new(4096).unwrap()).unwrap();
    let poly: Poly<Poly<dyn DoTheThing, BumpAllocator>, BumpAllocator> = Poly::new_with(
        |allocator| -> Result<_, std::alloc::LayoutError> {
            // An allocation failure in the inner quantizer panics here via
            // `unwrap`; it is not propagated through this closure's `Result`.
            let object = Quantizer::new(dim_kind, allocator).unwrap();
            Ok(poly!(DoTheThing, object))
        },
        allocator.clone(),
    )
    .unwrap();

    let x: &dyn DoTheThing = &**poly;
    assert_eq!(x.do_the_thing(), "foo");
    {
        // The outer `Poly` must be the first allocation in the arena.
        let base = allocator.as_ptr();
        assert_eq!(Poly::as_ptr(&poly).cast::<u8>(), base);
    }
    {
        let ptr = Poly::as_ptr(&poly);
        // SAFETY: `ptr` was obtained from `poly`, which owns the pointee and stays
        // alive, unmoved and unmutated for the rest of this scope. `Poly::as_ptr`
        // points at that owned, initialized value, so a shared reference to it
        // is valid while `object` is in use.
        let object = unsafe { &*ptr };
        assert_eq!(object.do_the_thing(), "foo");
        // The raw-pointer path must alias the same value as the safe `Deref` path.
        assert!(std::ptr::eq(object, &*poly));
    }
}