use serde::{Deserialize, Serialize};
/// Block size constant, numerically equal to [`SIMD_LANES`].
///
/// NOTE(review): presumably the size in bytes of one `SimdBlock` (the unit
/// tests assert `size_of::<SimdBlock>() == 32`); it is not referenced in the
/// visible part of this file — confirm against callers.
pub const SIMD_BLOCK_SIZE: usize = 32;
/// Number of `i8` lanes stored in every `SimdBlock`; also the granularity to
/// which quantized vectors are zero-padded.
pub const SIMD_LANES: usize = 32;
/// A fixed-size group of [`SIMD_LANES`] signed 8-bit values.
///
/// `#[repr(C, align(32))]` pins the field layout and forces 32-byte alignment.
/// With 32 one-byte lanes the struct is exactly 32 bytes with no padding
/// (asserted by the unit tests), so a contiguous run of blocks can be viewed
/// as a flat `[i8]`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
#[repr(C, align(32))]
pub struct SimdBlock {
/// Lane values, in order.
pub data: [i8; SIMD_LANES],
}
impl SimdBlock {
    /// Creates a block whose lanes are all zero.
    #[inline]
    pub fn zeros() -> Self {
        Self {
            data: [0; SIMD_LANES],
        }
    }

    /// Creates a block by copying the contents of `slice`.
    ///
    /// # Panics
    ///
    /// Panics if `slice` does not contain exactly [`SIMD_LANES`] elements.
    #[inline]
    pub fn from_slice(slice: &[i8]) -> Self {
        assert_eq!(
            slice.len(),
            SIMD_LANES,
            "Slice must have exactly {} elements",
            SIMD_LANES
        );
        // Start from the zero block and overwrite every lane in one copy.
        let mut block = Self::zeros();
        block.data.copy_from_slice(slice);
        block
    }

    /// Number of lanes in the block; always [`SIMD_LANES`].
    #[inline]
    pub fn len(&self) -> usize {
        self.data.len()
    }

    /// Always `false`: a block has a fixed, non-zero number of lanes.
    #[inline]
    pub fn is_empty(&self) -> bool {
        false
    }

    /// Borrows the lanes as an immutable slice.
    #[inline]
    pub fn as_slice(&self) -> &[i8] {
        &self.data[..]
    }

    /// Borrows the lanes as a mutable slice.
    #[inline]
    pub fn as_mut_slice(&mut self) -> &mut [i8] {
        &mut self.data[..]
    }
}
impl Default for SimdBlock {
fn default() -> Self {
Self::zeros()
}
}
impl AsRef<[i8]> for SimdBlock {
    /// Exposes the lanes as a plain `i8` slice.
    fn as_ref(&self) -> &[i8] {
        self.data.as_slice()
    }
}
impl AsMut<[i8]> for SimdBlock {
    /// Exposes the lanes as a mutable `i8` slice.
    fn as_mut(&mut self) -> &mut [i8] {
        self.data.as_mut_slice()
    }
}
/// Per-vector parameters for the affine `f32` <-> `i8` mapping, plus cached
/// statistics of the original vector.
///
/// Four `f32` fields (16 bytes) plus 16 explicit padding bytes give a 32-byte,
/// 32-byte-aligned struct (asserted by the unit tests).
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
#[repr(C, align(32))]
pub struct Int8QuantizedVectorMetadata {
/// Multiplier applied to a value during quantization (`value * scale + bias`).
pub scale: f32,
/// Offset added after scaling during quantization.
pub bias: f32,
/// Sum of the source vector's components, as accumulated by `from_vector`.
pub sum: f32,
/// Sum of the squared components of the source vector; `norm()` is its square root.
pub squared_sum: f32,
/// Explicit filler bytes; excluded from (de)serialization via `#[serde(skip)]`.
#[serde(skip)]
pub _padding: [u8; 16],
}
impl Int8QuantizedVectorMetadata {
    /// Builds metadata from explicit parameters; the padding bytes are zeroed.
    #[inline]
    pub fn new(scale: f32, bias: f32, sum: f32, squared_sum: f32) -> Self {
        Self {
            scale,
            bias,
            sum,
            squared_sum,
            _padding: [0; 16],
        }
    }

    /// Derives quantization parameters and statistics from `vector`.
    ///
    /// The value range `[min, max]` is mapped onto `[-127, 127]`:
    /// `scale = 254 / max(max - min, 1e-9)` and `bias = -min * scale - 127`.
    /// `sum` and `squared_sum` are accumulated in `f32`, in iteration order.
    ///
    /// An empty `vector` yields the identity mapping (`scale = 1`, `bias = 0`)
    /// with zero sums. Without this special case the min/max fold would start
    /// and end at `(+inf, -inf)` and produce `bias = -inf`, making every later
    /// `quantize`/`dequantize` call return saturated or infinite values.
    ///
    /// Note: if `max - min` overflows `f32` (for example a vector holding both
    /// `f32::MAX` and `f32::MIN`), the range is infinite and `scale` becomes 0.
    ///
    /// # Panics
    ///
    /// Panics if any element is NaN or infinite; the message names the first
    /// offending value and its index.
    pub fn from_vector(vector: &[f32]) -> Self {
        let first_bad = vector
            .iter()
            .copied()
            .enumerate()
            .find(|&(_, v)| !v.is_finite());
        if let Some((idx, value)) = first_bad {
            panic!(
                "Cannot quantize non-finite value ({}) at index {}. Vector must contain only finite f32 values.",
                value, idx
            );
        }

        // No elements means no range to derive; fall back to finite identity
        // parameters instead of folding over infinities.
        if vector.is_empty() {
            return Self::new(1.0, 0.0, 0.0, 0.0);
        }

        let (min, max, sum, squared_sum) = vector.iter().fold(
            (f32::INFINITY, f32::NEG_INFINITY, 0.0f32, 0.0f32),
            |(min, max, sum, sq_sum), &v| (min.min(v), max.max(v), sum + v, sq_sum + v * v),
        );
        // The 1e-9 floor keeps the division finite for constant vectors.
        let scale = 254.0 / (max - min).max(1e-9);
        let bias = -min * scale - 127.0;
        Self::new(scale, bias, sum, squared_sum)
    }

    /// Maps `value` to `i8`: `round(value * scale + bias)`, clamped to
    /// `[-128, 127]` before the cast.
    #[inline]
    pub fn quantize(&self, value: f32) -> i8 {
        let scaled = value * self.scale + self.bias;
        scaled.round().clamp(-128.0, 127.0) as i8
    }

    /// Inverse of the affine mapping: `(quantized - bias) / scale`.
    ///
    /// The result is not finite when `scale` is 0.
    #[inline]
    pub fn dequantize(&self, quantized: i8) -> f32 {
        (quantized as f32 - self.bias) / self.scale
    }

    /// Euclidean norm of the source vector, i.e. `sqrt(squared_sum)`.
    #[inline]
    pub fn norm(&self) -> f32 {
        self.squared_sum.sqrt()
    }

    /// Squared Euclidean norm of the source vector (the stored `squared_sum`).
    #[inline]
    pub fn norm_squared(&self) -> f32 {
        self.squared_sum
    }
}
impl Default for Int8QuantizedVectorMetadata {
    /// Identity mapping (`scale = 1`, `bias = 0`) with zeroed statistics and padding.
    fn default() -> Self {
        Self::new(1.0, 0.0, 0.0, 0.0)
    }
}
/// A vector quantized to `i8`, stored as a sequence of [`SimdBlock`]s together
/// with the parameters needed to map it back to `f32`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct Int8QuantizedVector {
/// Quantized values in 32-lane blocks; `Int8QuantizedVector::new` zero-fills
/// the unused tail of the last block.
pub blocks: Vec<SimdBlock>,
/// Scale/bias and cached statistics used for dequantization.
pub metadata: Int8QuantizedVectorMetadata,
/// Logical element count reported by `len()`; `to_f32` yields at most this
/// many values. It is stored as supplied by the caller and is not checked
/// against the block contents.
pub dimension: usize,
}
impl Int8QuantizedVector {
    /// Packs `data` into [`SimdBlock`]s of [`SIMD_LANES`] values each.
    ///
    /// A trailing partial chunk is stored in a block whose remaining lanes are
    /// zero. `dimension` is recorded as given; it is not validated against
    /// `data.len()`.
    pub fn new(data: Vec<i8>, metadata: Int8QuantizedVectorMetadata, dimension: usize) -> Self {
        let blocks = data
            .chunks(SIMD_LANES)
            .map(|chunk| {
                // Start from zeros so a short final chunk is implicitly zero-padded.
                let mut block = SimdBlock::zeros();
                block.data[..chunk.len()].copy_from_slice(chunk);
                block
            })
            .collect();
        Self::from_blocks(blocks, metadata, dimension)
    }

    /// Assembles a vector from already-packed blocks without copying or validation.
    #[inline]
    pub fn from_blocks(
        blocks: Vec<SimdBlock>,
        metadata: Int8QuantizedVectorMetadata,
        dimension: usize,
    ) -> Self {
        Self {
            blocks,
            metadata,
            dimension,
        }
    }

    /// Views all blocks as one flat slice of `num_blocks() * SIMD_LANES` values.
    ///
    /// The slice includes any padding lanes of the last block, so it can be
    /// longer than `len()`.
    #[inline]
    pub fn as_slice(&self) -> &[i8] {
        let lane_count = self.blocks.len() * SIMD_LANES;
        let first_lane = self.blocks.as_ptr() as *const i8;
        // SAFETY: `SimdBlock` is `#[repr(C)]` and holds only `[i8; SIMD_LANES]`,
        // whose size already equals the 32-byte alignment, so the blocks in the
        // `Vec` form `lane_count` contiguous, initialized `i8`s. The pointer is
        // non-null and aligned even for an empty `Vec`, and the returned borrow
        // is tied to `&self`, so the storage outlives the slice.
        unsafe { std::slice::from_raw_parts(first_lane, lane_count) }
    }

    /// Number of stored blocks.
    #[inline]
    pub fn num_blocks(&self) -> usize {
        self.blocks.len()
    }

    /// Logical dimension of the vector (excludes padding lanes).
    #[inline]
    pub fn len(&self) -> usize {
        self.dimension
    }

    /// `true` when the logical dimension is zero.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.dimension == 0
    }

    /// Size of the struct itself plus the bytes occupied by the stored blocks
    /// (based on the block count, not the `Vec`'s capacity).
    pub fn memory_bytes(&self) -> usize {
        let block_bytes = std::mem::size_of_val(self.blocks.as_slice());
        std::mem::size_of::<Self>() + block_bytes
    }

    /// Dequantizes the first `dimension` stored values back to `f32`.
    pub fn to_f32(&self) -> Vec<f32> {
        let meta = &self.metadata;
        self.as_slice()
            .iter()
            .copied()
            .take(self.dimension)
            .map(|q| meta.dequantize(q))
            .collect()
    }
}
/// Number of [`SIMD_LANES`]-wide blocks needed to hold `dimension` values
/// (ceiling division; `0` for a zero dimension).
#[inline]
pub fn blocks_for_dimension(dimension: usize) -> usize {
    usize::div_ceil(dimension, SIMD_LANES)
}
/// Rounds `dimension` up to the next multiple of [`SIMD_LANES`], i.e. the
/// number of lanes occupied once the values are packed into whole blocks.
#[inline]
pub fn padded_dimension(dimension: usize) -> usize {
    let block_count = blocks_for_dimension(dimension);
    block_count * SIMD_LANES
}
// Unit tests for the SIMD block layout, the metadata derivation and the
// quantize/dequantize round trip.
#[cfg(test)]
mod tests {
use super::*;
// `quantize()` / `dequantize()` used on vectors below presumably come from
// these traits, which are defined outside this file.
use crate::search::quantization::quantization::{Dequantize, Quantize};
// The block must be exactly 32 bytes with 32-byte alignment.
#[test]
fn test_simd_block_alignment() {
assert_eq!(
std::mem::align_of::<SimdBlock>(),
32,
"SimdBlock must have 32-byte alignment"
);
assert_eq!(
std::mem::size_of::<SimdBlock>(),
32,
"SimdBlock must be 32 bytes"
);
}
// `zeros()` yields 32 zero lanes; `from_slice` copies lanes in order.
#[test]
fn test_simd_block_creation() {
let block = SimdBlock::zeros();
assert_eq!(block.len(), 32);
assert!(!block.is_empty());
assert!(block.data.iter().all(|&x| x == 0));
let data: Vec<i8> = (0..32).map(|i| i as i8).collect();
let block2 = SimdBlock::from_slice(&data);
assert_eq!(block2.data[0], 0);
assert_eq!(block2.data[31], 31);
}
// Four f32 fields plus 16 padding bytes must give a 32-byte, 32-aligned struct.
#[test]
fn test_int8_quantized_vector_metadata_alignment() {
assert_eq!(
std::mem::align_of::<Int8QuantizedVectorMetadata>(),
32,
"Int8QuantizedVectorMetadata must have 32-byte alignment"
);
assert_eq!(
std::mem::size_of::<Int8QuantizedVectorMetadata>(),
32,
"Int8QuantizedVectorMetadata must be 32 bytes"
);
}
// For [0, 0.5, 1] the range is 1, so scale = 254 and bias = -127;
// sum = 1.5 and squared_sum = 0 + 0.25 + 1 = 1.25.
#[test]
fn test_metadata_from_vector() {
let vector = vec![0.0f32, 0.5, 1.0];
let metadata = Int8QuantizedVectorMetadata::from_vector(&vector);
assert!((metadata.scale - 254.0).abs() < 1e-6);
assert!((metadata.bias - (-127.0)).abs() < 1e-6);
assert!((metadata.sum - 1.5).abs() < 1e-6);
assert!((metadata.squared_sum - 1.25).abs() < 1e-6);
}
// 64 values fill exactly two 32-lane blocks.
#[test]
fn test_int8_quantized_vector_creation() {
let data: Vec<i8> = (0..64).map(|i| (i % 256) as i8).collect();
let metadata = Int8QuantizedVectorMetadata::default();
let qv = Int8QuantizedVector::new(data, metadata, 64);
assert_eq!(qv.len(), 64);
assert_eq!(qv.num_blocks(), 2);
}
// The three tests below pin the panic message produced for non-finite input.
#[test]
#[should_panic(expected = "non-finite value (NaN)")]
fn test_rejects_nan_values() {
let vector = vec![0.1f32, f32::NAN, 0.3];
let _metadata = Int8QuantizedVectorMetadata::from_vector(&vector);
}
#[test]
#[should_panic(expected = "non-finite value (inf)")]
fn test_rejects_positive_infinity() {
let vector = vec![0.1f32, f32::INFINITY, 0.3];
let _metadata = Int8QuantizedVectorMetadata::from_vector(&vector);
}
#[test]
#[should_panic(expected = "non-finite value (-inf)")]
fn test_rejects_negative_infinity() {
let vector = vec![0.1f32, f32::NEG_INFINITY, 0.3];
let _metadata = Int8QuantizedVectorMetadata::from_vector(&vector);
}
// Only checks that extreme but finite values do not trigger the non-finite
// panic; the resulting metadata is not inspected.
#[test]
fn test_accepts_finite_values() {
let vector = vec![
f32::MAX,
f32::MIN,
0.0,
-0.0,
1e-38, -1e-38, 1e38, -1e38, ];
let _metadata = Int8QuantizedVectorMetadata::from_vector(&vector);
}
// The flat view over the blocks must match the input data and be stable
// across repeated calls.
#[test]
fn test_as_slice_safety() {
let data: Vec<i8> = (0..64).map(|i| i as i8).collect();
let metadata = Int8QuantizedVectorMetadata::default();
let qv = Int8QuantizedVector::new(data.clone(), metadata, 64);
let slice = qv.as_slice();
assert_eq!(slice.len(), 64);
assert_eq!(&slice[..64], &data[..64]);
let slice2 = qv.as_slice();
assert_eq!(slice, slice2);
let first_val = slice[0];
assert_eq!(first_val, data[0]);
}
// Quantizing an empty vector must produce zero blocks and an empty slice.
#[test]
fn test_empty_vector_handling() {
let empty: Vec<f32> = vec![];
let qv = empty.quantize();
assert_eq!(qv.len(), 0);
assert_eq!(qv.as_slice().len(), 0);
assert_eq!(qv.num_blocks(), 0);
}
// Round trip at several sizes, including ones that are not multiples of 32
// (1000, 10000); the RMSE against the original must stay below 0.01.
#[test]
fn test_large_dimension_vectors() {
let dimensions = vec![1000, 2048, 4096, 8192, 10000];
for dim in dimensions {
let vector: Vec<f32> = (0..dim).map(|i| (i % 100) as f32 * 0.01).collect();
let quantized = vector.quantize();
let dequantized = quantized.dequantize();
assert_eq!(dequantized.len(), dim, "Dimension mismatch for {}", dim);
let mse: f32 = vector
.iter()
.zip(dequantized.iter())
.map(|(o, d)| (o - d).powi(2))
.sum::<f32>()
/ dim as f32;
let rmse = mse.sqrt();
assert!(rmse < 0.01, "RMSE too high for dimension {}: {}", dim, rmse);
}
}
// Dimensions 1..=32 all fit in a single block; only the round-trip length is checked.
#[test]
fn test_very_small_dimension_vectors() {
for dim in 1..=32 {
let vector: Vec<f32> = (0..dim).map(|i| (i as f32) * 0.1 - 0.5).collect();
let quantized = vector.quantize();
let dequantized = quantized.dequantize();
assert_eq!(dequantized.len(), dim);
}
}
// A constant vector has zero range; the round trip must still return ~0.5.
#[test]
fn test_uniform_vector_quantization() {
let uniform = vec![0.5f32; 100];
let quantized = uniform.quantize();
let dequantized = quantized.dequantize();
for (i, &val) in dequantized.iter().enumerate() {
assert!(
(val - 0.5).abs() < 0.01,
"Uniform vector dequantization failed at index {}: got {}",
i,
val
);
}
}
// An all-zero vector must dequantize back to values within 0.01 of zero.
#[test]
fn test_zero_vector_quantization() {
let zeros = vec![0.0f32; 128];
let quantized = zeros.quantize();
let dequantized = quantized.dequantize();
assert_eq!(dequantized.len(), 128);
for val in &dequantized {
assert!(
val.abs() < 0.01f32,
"Zero vector dequantized to non-zero: {}",
val
);
}
}
}