use std::io;
use super::codec_helpers::{serialize_with_header, validate_and_split_header};
use super::QuantizationCodec;
/// A vector compressed with 8-bit scalar quantization (SQ8).
///
/// Each component is stored as one `u8` code on a uniform 256-level grid over
/// `[min, max]`; code `c` decodes to `c * (max - min) / 255 + min`. When
/// `max - min < f32::EPSILON` the vector is treated as constant and every
/// component decodes to `min`.
#[derive(Debug, Clone)]
pub struct QuantizedVector {
    /// One quantization code per dimension.
    pub data: Vec<u8>,
    /// Lower bound of the quantization grid; code 0 decodes to this value.
    pub min: f32,
    /// Upper bound of the quantization grid; code 255 decodes to
    /// (approximately) this value.
    pub max: f32,
}

impl QuantizedVector {
    /// Quantizes `vector` to one byte per component.
    ///
    /// The grid spans the vector's own minimum and maximum. Each component is
    /// mapped to the nearest code and clamped into `0..=255`. If the range is
    /// below `f32::EPSILON`, every code is `128` and decoding yields `min`.
    ///
    /// # Panics
    ///
    /// Panics if `vector` is empty and debug assertions are enabled. With
    /// debug assertions disabled, an empty input yields an empty vector with
    /// `min == max == 0.0`.
    #[must_use]
    // The only cast in this function is `as u8` on a value that has already been
    // rounded and clamped to [0, 255], so truncation and sign loss cannot occur.
    #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
    pub fn from_f32(vector: &[f32]) -> Self {
        debug_assert!(!vector.is_empty(), "Cannot quantize empty vector");
        if vector.is_empty() {
            // Without this guard the folds below return their seeds (+inf for
            // `min`, -inf for `max`). Those bounds make the downstream
            // decode/distance math produce NaN (e.g. `0.0 * inf`).
            return Self {
                data: Vec::new(),
                min: 0.0,
                max: 0.0,
            };
        }
        // `f32::min` / `f32::max` return the non-NaN operand, so a NaN
        // component is skipped when computing the bounds.
        let min = vector.iter().copied().fold(f32::INFINITY, f32::min);
        let max = vector.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let range = max - min;
        let data = if range < f32::EPSILON {
            // Constant vector: codes are irrelevant for decoding (which returns
            // `min`), so store the grid midpoint.
            vec![128u8; vector.len()]
        } else {
            let scale = 255.0 / range;
            vector
                .iter()
                .map(|&v| {
                    let normalized = (v - min) * scale;
                    normalized.round().clamp(0.0, 255.0) as u8
                })
                .collect()
        };
        Self { data, min, max }
    }

    /// Decodes the codes back to approximate `f32` values.
    ///
    /// Each code maps to `code * (max - min) / 255 + min`. A constant vector
    /// (range below `f32::EPSILON`) decodes to `min` in every component.
    #[must_use]
    pub fn to_f32(&self) -> Vec<f32> {
        let range = self.max - self.min;
        if range < f32::EPSILON {
            vec![self.min; self.data.len()]
        } else {
            let scale = range / 255.0;
            self.data
                .iter()
                .map(|&v| f32::from(v) * scale + self.min)
                .collect()
        }
    }

    /// Number of dimensions (one code per dimension).
    #[must_use]
    pub fn dimension(&self) -> usize {
        self.data.len()
    }

    /// Payload size in bytes: one byte per code plus the two `f32` bounds.
    /// `Vec` bookkeeping and spare capacity are not counted.
    #[must_use]
    pub fn memory_size(&self) -> usize {
        self.data.len() + 2 * std::mem::size_of::<f32>()
    }
}
/// Serialized header length for [`QuantizedVector`]: `min` followed by `max`,
/// each encoded as a little-endian `f32`.
const SQ8_HEADER_SIZE: usize = 8;

impl QuantizationCodec for QuantizedVector {
    /// Builds the 8-byte header (`min`, then `max`, little-endian) and hands it,
    /// together with the raw codes, to `serialize_with_header`.
    fn to_bytes(&self) -> Vec<u8> {
        let mut header = [0u8; SQ8_HEADER_SIZE];
        let (min_bytes, max_bytes) = header.split_at_mut(4);
        min_bytes.copy_from_slice(&self.min.to_le_bytes());
        max_bytes.copy_from_slice(&self.max.to_le_bytes());
        serialize_with_header(&header, &self.data)
    }

    /// Inverse of `to_bytes`: `validate_and_split_header` separates the
    /// header from the payload, the two bounds are read back as little-endian
    /// `f32`s, and the payload is copied verbatim as the codes.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `validate_and_split_header`.
    fn from_bytes(bytes: &[u8]) -> io::Result<Self> {
        let (header, payload) =
            validate_and_split_header(bytes, SQ8_HEADER_SIZE, "QuantizedVector")?;
        let read_f32 = |at: usize| {
            f32::from_le_bytes([header[at], header[at + 1], header[at + 2], header[at + 3]])
        };
        Ok(Self {
            min: read_f32(0),
            max: read_f32(4),
            data: payload.to_vec(),
        })
    }
}
/// Affine dequantization parameters: a code `c` decodes to `c * scale + offset`.
struct DequantParams {
    /// Value step between adjacent codes, `(max - min) / 255`.
    scale: f32,
    /// Value of code 0, i.e. the vector's `min`.
    offset: f32,
}

/// Builds the affine decode mapping for `quantized`.
///
/// Returns `None` for a degenerate vector whose `max - min` is below
/// `f32::EPSILON`; callers then treat every component as `min`. A NaN range
/// is not "below epsilon" and therefore yields `Some`.
fn dequant_params(quantized: &QuantizedVector) -> Option<DequantParams> {
    let (lo, hi) = (quantized.min, quantized.max);
    let span = hi - lo;
    if span < f32::EPSILON {
        None
    } else {
        Some(DequantParams {
            scale: span / 255.0,
            offset: lo,
        })
    }
}
#[must_use]
pub fn dot_product_quantized(query: &[f32], quantized: &QuantizedVector) -> f32 {
debug_assert_eq!(
query.len(),
quantized.data.len(),
"Dimension mismatch in dot_product_quantized"
);
let Some(params) = dequant_params(quantized) else {
return query.iter().sum::<f32>() * quantized.min;
};
query
.iter()
.zip(quantized.data.iter())
.map(|(&q, &v)| q * (f32::from(v) * params.scale + params.offset))
.sum()
}
#[must_use]
pub fn euclidean_squared_quantized(query: &[f32], quantized: &QuantizedVector) -> f32 {
debug_assert_eq!(
query.len(),
quantized.data.len(),
"Dimension mismatch in euclidean_squared_quantized"
);
let Some(params) = dequant_params(quantized) else {
let value = quantized.min;
return query.iter().map(|&q| (q - value).powi(2)).sum();
};
query
.iter()
.zip(quantized.data.iter())
.map(|(&q, &v)| {
let dequantized = f32::from(v) * params.scale + params.offset;
(q - dequantized).powi(2)
})
.sum()
}
/// Cosine similarity between `query` and an SQ8-encoded vector, using
/// [`dot_product_quantized`] for the numerator. Returns `0.0` when either
/// norm is below `f32::EPSILON`.
#[must_use]
pub fn cosine_similarity_quantized(query: &[f32], quantized: &QuantizedVector) -> f32 {
    let dot = dot_product_quantized(query, quantized);
    cosine_from_dot(dot, query, quantized)
}
fn cosine_from_dot(dot: f32, query: &[f32], quantized: &QuantizedVector) -> f32 {
use crate::simd_native;
let query_norm = simd_native::norm_native(query);
let quantized_norm = quantized_vector_norm(quantized);
if query_norm < f32::EPSILON || quantized_norm < f32::EPSILON {
return 0.0;
}
dot / (query_norm * quantized_norm)
}
#[inline]
fn quantized_vector_norm(quantized: &QuantizedVector) -> f32 {
let Some(params) = dequant_params(quantized) else {
let value = quantized.min;
#[allow(clippy::cast_precision_loss)]
return value.abs() * (quantized.data.len() as f32).sqrt();
};
let len = quantized.data.len();
let chunks = len / 4;
let remainder = len % 4;
let mut sum0: f32 = 0.0;
let mut sum1: f32 = 0.0;
let mut sum2: f32 = 0.0;
let mut sum3: f32 = 0.0;
for i in 0..chunks {
let base = i * 4;
let d0 = f32::from(quantized.data[base]) * params.scale + params.offset;
let d1 = f32::from(quantized.data[base + 1]) * params.scale + params.offset;
let d2 = f32::from(quantized.data[base + 2]) * params.scale + params.offset;
let d3 = f32::from(quantized.data[base + 3]) * params.scale + params.offset;
sum0 += d0 * d0;
sum1 += d1 * d1;
sum2 += d2 * d2;
sum3 += d3 * d3;
}
let base = chunks * 4;
for i in 0..remainder {
let d = f32::from(quantized.data[base + i]) * params.scale + params.offset;
sum0 += d * d;
}
(sum0 + sum1 + sum2 + sum3).sqrt()
}
/// Variant of [`dot_product_quantized`] that delegates the decode-and-multiply
/// loop to `dot_product_dequant_unrolled_8`.
///
/// # Panics
///
/// With debug assertions enabled, panics if `query.len()` differs from
/// `quantized.data.len()`. In the non-degenerate case, the helper also panics
/// if `quantized.data` is shorter than `query`.
#[must_use]
pub fn dot_product_quantized_simd(query: &[f32], quantized: &QuantizedVector) -> f32 {
    debug_assert_eq!(
        query.len(),
        quantized.data.len(),
        "Dimension mismatch in dot_product_quantized_simd"
    );
    match dequant_params(quantized) {
        Some(p) => dot_product_dequant_unrolled_8(query, &quantized.data, p.scale, p.offset),
        // Constant encoding: every component is `min`.
        None => query.iter().sum::<f32>() * quantized.min,
    }
}
/// Dot product of `query` with codes decoded as `code * scale + offset`.
///
/// Only the first `query.len()` codes are used. Products are added to a
/// single running sum in element order, starting from `0.0`. Despite the name,
/// no independent accumulators are used, so the result matches a plain
/// sequential loop bit for bit.
///
/// # Panics
///
/// Panics if `data` is shorter than `query`.
#[inline]
fn dot_product_dequant_unrolled_8(query: &[f32], data: &[u8], scale: f32, offset: f32) -> f32 {
    // Bounding `data` up front panics under exactly the same condition as
    // per-element indexing would, and ignores any extra trailing codes.
    let codes = &data[..query.len()];
    query
        .iter()
        .zip(codes)
        .fold(0.0f32, |acc, (&q, &code)| {
            acc + q * (f32::from(code) * scale + offset)
        })
}
#[must_use]
pub fn euclidean_squared_quantized_simd(query: &[f32], quantized: &QuantizedVector) -> f32 {
debug_assert_eq!(
query.len(),
quantized.data.len(),
"Dimension mismatch in euclidean_squared_quantized_simd"
);
let Some(params) = dequant_params(quantized) else {
let value = quantized.min;
return query.iter().map(|&q| (q - value).powi(2)).sum();
};
let len = query.len();
let chunks = len / 4;
let remainder = len % 4;
let mut sum = 0.0f32;
for i in 0..chunks {
let base = i * 4;
let d0 = f32::from(quantized.data[base]) * params.scale + params.offset;
let d1 = f32::from(quantized.data[base + 1]) * params.scale + params.offset;
let d2 = f32::from(quantized.data[base + 2]) * params.scale + params.offset;
let d3 = f32::from(quantized.data[base + 3]) * params.scale + params.offset;
let diff0 = query[base] - d0;
let diff1 = query[base + 1] - d1;
let diff2 = query[base + 2] - d2;
let diff3 = query[base + 3] - d3;
sum += diff0 * diff0 + diff1 * diff1 + diff2 * diff2 + diff3 * diff3;
}
let base = chunks * 4;
for i in 0..remainder {
let dequant = f32::from(quantized.data[base + i]) * params.scale + params.offset;
let diff = query[base + i] - dequant;
sum += diff * diff;
}
sum
}
/// Cosine similarity between `query` and an SQ8-encoded vector, using
/// [`dot_product_quantized_simd`] for the numerator. Returns `0.0` when
/// either norm is below `f32::EPSILON`.
#[must_use]
pub fn cosine_similarity_quantized_simd(query: &[f32], quantized: &QuantizedVector) -> f32 {
    let dot = dot_product_quantized_simd(query, quantized);
    cosine_from_dot(dot, query, quantized)
}