use std::fmt;
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum QuantizationError {
InvalidDimension {
dimension: usize,
},
DimensionMismatch {
expected: usize,
actual: usize,
},
ByteLengthMismatch {
expected: usize,
actual: usize,
},
}
impl fmt::Display for QuantizationError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::InvalidDimension { dimension } => {
write!(
f,
"dimension must be divisible by 8, got {dimension}. \
Try using a dimension like 128, 384, 768, 1024, or 1536."
)
}
Self::DimensionMismatch { expected, actual } => {
write!(f, "dimension mismatch: expected {expected}, got {actual}")
}
Self::ByteLengthMismatch { expected, actual } => {
write!(f, "byte length mismatch: expected {expected}, got {actual}")
}
}
}
}
impl std::error::Error for QuantizationError {}
#[derive(Clone, PartialEq, Eq)]
pub struct BinaryVector {
data: Vec<u8>,
dimension: usize,
}
impl BinaryVector {
pub fn quantize(vector: &[f32]) -> Result<Self, QuantizationError> {
let dimension = vector.len();
if dimension == 0 || dimension % 8 != 0 {
return Err(QuantizationError::InvalidDimension { dimension });
}
let bytes = dimension / 8;
let mut data = vec![0u8; bytes];
for (i, &value) in vector.iter().enumerate() {
if value > 0.0 {
let byte_idx = i / 8;
let bit_idx = i % 8;
data[byte_idx] |= 1 << bit_idx;
}
}
Ok(Self { data, dimension })
}
pub fn from_bytes(data: Vec<u8>, dimension: usize) -> Result<Self, QuantizationError> {
if dimension == 0 || dimension % 8 != 0 {
return Err(QuantizationError::InvalidDimension { dimension });
}
let expected_bytes = dimension / 8;
if data.len() != expected_bytes {
return Err(QuantizationError::ByteLengthMismatch {
expected: expected_bytes,
actual: data.len(),
});
}
Ok(Self { data, dimension })
}
#[must_use]
#[inline]
pub fn dimension(&self) -> usize {
self.dimension
}
#[must_use]
#[inline]
pub fn bytes(&self) -> usize {
self.data.len()
}
#[must_use]
#[inline]
pub fn data(&self) -> &[u8] {
&self.data
}
pub fn hamming_distance(&self, other: &Self) -> Result<u32, QuantizationError> {
if self.dimension != other.dimension {
return Err(QuantizationError::DimensionMismatch {
expected: self.dimension,
actual: other.dimension,
});
}
Ok(crate::simd::popcount::simd_popcount_xor(
&self.data,
&other.data,
))
}
#[allow(clippy::cast_precision_loss)] pub fn similarity(&self, other: &Self) -> Result<f32, QuantizationError> {
let distance = self.hamming_distance(other)?;
Ok(1.0 - (distance as f32 / self.dimension as f32))
}
}
impl fmt::Debug for BinaryVector {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("BinaryVector")
.field("dimension", &self.dimension)
.field("bytes", &self.data.len())
.finish()
}
}
#[cfg(test)]
#[allow(clippy::cast_precision_loss)]
mod tests {
use super::*;
#[test]
fn test_quantize_128d() {
let v = vec![1.0f32; 128];
let bv = BinaryVector::quantize(&v).unwrap();
assert_eq!(bv.dimension(), 128);
assert_eq!(bv.bytes(), 16);
assert!(bv.data().iter().all(|&b| b == 0xFF));
}
#[test]
fn test_quantize_384d() {
let v = vec![-1.0f32; 384];
let bv = BinaryVector::quantize(&v).unwrap();
assert_eq!(bv.dimension(), 384);
assert_eq!(bv.bytes(), 48);
assert!(bv.data().iter().all(|&b| b == 0x00));
}
#[test]
fn test_quantize_768d() {
let v: Vec<f32> = (0..768)
.map(|i| if i % 2 == 0 { 1.0 } else { -1.0 })
.collect();
let bv = BinaryVector::quantize(&v).unwrap();
assert_eq!(bv.dimension(), 768);
assert_eq!(bv.bytes(), 96);
assert!(bv.data().iter().all(|&b| b == 0x55));
}
#[test]
fn test_quantize_1024d() {
let v = vec![1.0f32; 1024];
let bv = BinaryVector::quantize(&v).unwrap();
assert_eq!(bv.dimension(), 1024);
assert_eq!(bv.bytes(), 128);
}
#[test]
fn test_quantize_1536d() {
let v = vec![1.0f32; 1536];
let bv = BinaryVector::quantize(&v).unwrap();
assert_eq!(bv.dimension(), 1536);
assert_eq!(bv.bytes(), 192);
}
#[test]
fn test_invalid_dimension_not_divisible_by_8() {
let v = vec![1.0f32; 100];
let result = BinaryVector::quantize(&v);
assert!(matches!(
result,
Err(QuantizationError::InvalidDimension { dimension: 100 })
));
}
#[test]
fn test_invalid_dimension_zero() {
let v: Vec<f32> = vec![];
let result = BinaryVector::quantize(&v);
assert!(matches!(
result,
Err(QuantizationError::InvalidDimension { dimension: 0 })
));
}
#[test]
fn test_from_bytes_valid() {
let data = vec![0xFF; 16];
let bv = BinaryVector::from_bytes(data, 128).unwrap();
assert_eq!(bv.dimension(), 128);
assert_eq!(bv.bytes(), 16);
}
#[test]
fn test_from_bytes_length_mismatch() {
let data = vec![0xFF; 16];
let result = BinaryVector::from_bytes(data, 256);
assert!(matches!(
result,
Err(QuantizationError::ByteLengthMismatch {
expected: 32,
actual: 16
})
));
}
#[test]
fn test_hamming_distance_zero() {
let v = vec![1.0f32; 128];
let bv = BinaryVector::quantize(&v).unwrap();
assert_eq!(bv.hamming_distance(&bv).unwrap(), 0);
}
#[test]
fn test_hamming_distance_max() {
let v1 = vec![1.0f32; 128];
let v2 = vec![-1.0f32; 128];
let bv1 = BinaryVector::quantize(&v1).unwrap();
let bv2 = BinaryVector::quantize(&v2).unwrap();
assert_eq!(bv1.hamming_distance(&bv2).unwrap(), 128);
}
#[test]
fn test_hamming_distance_symmetric() {
let v1: Vec<f32> = (0..128).map(|i| (i as f32).sin()).collect();
let v2: Vec<f32> = (0..128).map(|i| (i as f32).cos()).collect();
let bv1 = BinaryVector::quantize(&v1).unwrap();
let bv2 = BinaryVector::quantize(&v2).unwrap();
assert_eq!(
bv1.hamming_distance(&bv2).unwrap(),
bv2.hamming_distance(&bv1).unwrap()
);
}
#[test]
fn test_hamming_distance_dimension_mismatch() {
let v1 = vec![1.0f32; 128];
let v2 = vec![1.0f32; 256];
let bv1 = BinaryVector::quantize(&v1).unwrap();
let bv2 = BinaryVector::quantize(&v2).unwrap();
let result = bv1.hamming_distance(&bv2);
assert!(matches!(
result,
Err(QuantizationError::DimensionMismatch {
expected: 128,
actual: 256
})
));
}
#[test]
fn test_similarity_identical() {
let v = vec![1.0f32; 128];
let bv = BinaryVector::quantize(&v).unwrap();
let sim = bv.similarity(&bv).unwrap();
assert!((sim - 1.0).abs() < f32::EPSILON);
}
#[test]
fn test_similarity_opposite() {
let v1 = vec![1.0f32; 128];
let v2 = vec![-1.0f32; 128];
let bv1 = BinaryVector::quantize(&v1).unwrap();
let bv2 = BinaryVector::quantize(&v2).unwrap();
let sim = bv1.similarity(&bv2).unwrap();
assert!(sim.abs() < f32::EPSILON);
}
#[test]
fn test_similarity_half() {
let v1 = vec![1.0f32; 128];
let v2: Vec<f32> = (0..128).map(|i| if i < 64 { 1.0 } else { -1.0 }).collect();
let bv1 = BinaryVector::quantize(&v1).unwrap();
let bv2 = BinaryVector::quantize(&v2).unwrap();
let sim = bv1.similarity(&bv2).unwrap();
assert!((sim - 0.5).abs() < f32::EPSILON);
}
#[test]
fn test_edge_case_nan() {
let mut v = vec![1.0f32; 128];
v[0] = f32::NAN;
let bv = BinaryVector::quantize(&v).unwrap();
assert_eq!(bv.data()[0] & 0x01, 0);
}
#[test]
fn test_edge_case_infinity() {
let mut v = vec![0.0f32; 128];
v[0] = f32::INFINITY;
v[1] = f32::NEG_INFINITY;
let bv = BinaryVector::quantize(&v).unwrap();
assert_eq!(bv.data()[0] & 0x01, 1);
assert_eq!(bv.data()[0] & 0x02, 0);
}
#[test]
fn test_quantize_deterministic() {
let v: Vec<f32> = (0..128).map(|i| (i as f32).sin()).collect();
let bv1 = BinaryVector::quantize(&v).unwrap();
let bv2 = BinaryVector::quantize(&v).unwrap();
assert_eq!(bv1, bv2);
}
#[test]
fn test_debug_format() {
let v = vec![1.0f32; 128];
let bv = BinaryVector::quantize(&v).unwrap();
let debug = format!("{bv:?}");
assert!(debug.contains("BinaryVector"));
assert!(debug.contains("dimension: 128"));
assert!(debug.contains("bytes: 16"));
}
#[test]
fn test_error_display() {
let err = QuantizationError::InvalidDimension { dimension: 100 };
let msg = err.to_string();
assert!(msg.contains("100"));
assert!(msg.contains("divisible by 8"));
}
}