use crate::{
AsFunctor, CompressInto,
bits::{Binary, MutBitSlice, Representation},
distances::Hamming,
};
/// Stateless one-bit quantizer.
///
/// When used through its `CompressInto` implementation, every input element is
/// reduced to a single bit: elements strictly greater than the element type's
/// `Default` value are stored as the encoding of `1`, all others as the encoding
/// of `-1`. The matching distance functor (see `AsFunctor`) is `Hamming`.
#[derive(Clone, Copy, Debug)]
pub struct BinaryQuantizer;
/// Sign-style compression of a slice into a one-bit-per-element [`MutBitSlice`].
///
/// An element `v` for which `v > T::default()` holds is stored as the encoding of
/// `1`; every other element (equal to, less than, or not comparable with the
/// default) is stored as the encoding of `-1`.
impl<T> CompressInto<&[T], MutBitSlice<'_, 1, Binary>> for BinaryQuantizer
where
    T: PartialOrd + Default,
{
    /// This implementation never returns an error; a length mismatch panics instead.
    type Error = std::convert::Infallible;

    /// Nothing is produced besides the bits written into the destination.
    type Output = ();

    /// Writes one bit for each element of `from` into the same position of `into`.
    ///
    /// # Panics
    ///
    /// Panics if `from.len() != into.len()`.
    fn compress_into(
        &self,
        from: &[T],
        mut into: MutBitSlice<'_, 1, Binary>,
    ) -> Result<(), Self::Error> {
        // This check is also what makes the unchecked writes below in-bounds.
        assert_eq!(from.len(), into.len());

        // The comparison threshold is loop-invariant, so build it once rather
        // than once per element.
        let threshold = T::default();

        for (i, v) in from.iter().enumerate() {
            let bit: u8 = if *v > threshold {
                Binary::encode_unchecked(1)
            } else {
                Binary::encode_unchecked(-1)
            };
            // SAFETY: `i` is an index into `from`, and the assertion above
            // guarantees `from.len() == into.len()`, so `i < into.len()`.
            // `bit` comes straight from `Binary::encode_unchecked` applied to
            // `1` or `-1`. NOTE(review): this presumes those are valid inputs
            // for the `Binary` representation and that the result fits the
            // 1-bit slot `set_unchecked` expects -- confirm against its contract.
            unsafe { into.set_unchecked(i, bit) };
        }

        Ok(())
    }
}
/// Associates [`BinaryQuantizer`] with the [`Hamming`] distance functor.
impl AsFunctor<Hamming> for BinaryQuantizer {
/// Returns a `Hamming` value; no state from `self` is read.
fn as_functor(&self) -> Hamming {
Hamming
}
}
#[cfg(test)]
mod tests {
    use diskann_utils::{ReborrowMut, views::Matrix};
    use rand::{SeedableRng, rngs::StdRng, seq::SliceRandom};

    use super::*;
    use crate::bits::{Binary, BoxedBitSlice};

    /// Compresses rows of length `len` and checks every resulting bit.
    ///
    /// Each column of the test matrix is filled from a freshly shuffled copy of a
    /// small value set containing negative, zero, and positive entries. Every row
    /// is then compressed and each decoded position is compared with the expected
    /// value: `1` for strictly positive inputs, `-1` otherwise.
    fn test_compression_impl(len: usize, rng: &mut StdRng) {
        let mut values = [-10, -1, 0, 1, 10];
        let mut pattern = Matrix::<i32>::new(0, values.len(), len);

        for col in 0..len {
            // Reshuffle per column so the rows mix signs differently each time.
            values.shuffle(rng);
            for row in 0..pattern.nrows() {
                pattern[(row, col)] = values[row];
            }
        }

        let quantizer = BinaryQuantizer;
        // One destination buffer is reused (and fully overwritten) for every row.
        let mut compressed = BoxedBitSlice::<1, Binary>::new_boxed(len);

        for row in pattern.row_iter() {
            quantizer
                .compress_into(row, compressed.reborrow_mut())
                .unwrap();

            for (i, r) in row.iter().enumerate() {
                assert_eq!(compressed.get(i).unwrap(), if *r > 0 { 1 } else { -1 });
            }
        }
    }

    /// Runs the compression check for lengths 1 through 16 with a fixed seed.
    #[test]
    fn test_compression() {
        let mut rng = StdRng::seed_from_u64(0x9673d0890bbb7231);
        for len in 1..=16 {
            test_compression_impl(len, &mut rng);
        }
    }
}