use crate::{Result, StorageError};
use std::fs::File;
use std::io::{Read, Write};
use std::path::Path;
/// Scalar quantizer that compresses fixed-dimension `f32` vectors to one byte
/// per component (SQ8).
///
/// The quantizer carries no per-dataset state: the value range used to encode a
/// vector is stored alongside its codes in [`QuantizedVector`]. The only thing
/// fixed here is the dimension every vector must have.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SQ8Quantizer {
    /// Number of components every vector handled by this quantizer must have.
    dimension: usize,
}
/// A vector compressed to 8-bit codes together with the value range needed to
/// decode it.
///
/// A code `c` decodes to `min + c * (max - min) / 255`. When `max - min` is
/// (almost) zero, every component decodes to `min` regardless of its code.
#[derive(Debug, Clone, PartialEq)]
pub struct QuantizedVector {
    /// One code per vector component, in component order.
    pub codes: Vec<u8>,
    /// Smallest component of the original vector; the value code `0` decodes to.
    pub min: f32,
    /// Largest component of the original vector; the value code `255` decodes to.
    pub max: f32,
}
impl SQ8Quantizer {
    /// Magic bytes that open a serialized quantizer file.
    const MAGIC: [u8; 4] = *b"SQ8\0";
    /// Ranges and norms below this threshold are treated as zero.
    const EPSILON: f32 = 1e-8;
    /// Largest code value, as a float, used to scale values into `0..=255`.
    const MAX_CODE: f32 = 255.0;

    /// Creates a quantizer for vectors with exactly `dimension` components.
    pub fn new(dimension: usize) -> Self {
        Self { dimension }
    }

    /// Compresses `vector` to one byte per component.
    ///
    /// The vector's own minimum and maximum define the encoded range: the
    /// minimum maps to code `0`, the maximum to code `255`, and everything in
    /// between is rounded to the nearest code. A vector whose range is below
    /// `1e-8` is stored as all-zero codes and decodes to `min` everywhere. An
    /// empty vector (dimension `0`) yields no codes and `min == max == 0.0`.
    ///
    /// # Errors
    ///
    /// Returns [`StorageError::InvalidData`] when
    /// * `vector.len()` differs from the quantizer's dimension,
    /// * any component is NaN or infinite (it could not be decoded back), or
    /// * `max - min` overflows `f32`.
    pub fn quantize(&self, vector: &[f32]) -> Result<QuantizedVector> {
        if vector.len() != self.dimension {
            return Err(StorageError::InvalidData(format!(
                "Vector dimension mismatch: expected {}, got {}",
                self.dimension,
                vector.len()
            )));
        }

        // NaN/inf would poison min/max and produce codes that decode to NaN,
        // so refuse them instead of storing garbage.
        if let Some(index) = vector.iter().position(|value| !value.is_finite()) {
            return Err(StorageError::InvalidData(format!(
                "Vector contains non-finite value {} at index {}",
                vector[index], index
            )));
        }

        if vector.is_empty() {
            // No components: avoid leaking the +inf/-inf fold seeds as bounds.
            return Ok(QuantizedVector {
                codes: Vec::new(),
                min: 0.0,
                max: 0.0,
            });
        }

        let (min, max) = vector.iter().fold(
            (f32::INFINITY, f32::NEG_INFINITY),
            |(lo, hi), &value| (lo.min(value), hi.max(value)),
        );

        let range = max - min;
        if !range.is_finite() {
            return Err(StorageError::InvalidData(format!(
                "Vector value range [{}, {}] overflows f32",
                min, max
            )));
        }

        let codes = if range < Self::EPSILON {
            // (Near-)constant vector: the codes carry no information.
            vec![0u8; self.dimension]
        } else {
            let scale = Self::MAX_CODE / range;
            vector
                .iter()
                .map(|&value| ((value - min) * scale).round().clamp(0.0, Self::MAX_CODE) as u8)
                .collect()
        };

        Ok(QuantizedVector { codes, min, max })
    }

    /// Reconstructs an approximate `f32` vector from `qvec`.
    ///
    /// If `qvec` does not have this quantizer's dimension, a zero vector of the
    /// quantizer's dimension is returned instead of an error.
    pub fn dequantize(&self, qvec: &QuantizedVector) -> Vec<f32> {
        if qvec.codes.len() != self.dimension {
            return vec![0.0; self.dimension];
        }

        let (scale, offset) = Self::decode_params(qvec);
        qvec.codes
            .iter()
            .map(|&code| f32::from(code) * scale + offset)
            .collect()
    }

    /// Writes the quantizer to `path`, creating or truncating the file.
    ///
    /// Layout: the 4 magic bytes `SQ8\0` followed by the dimension as a
    /// little-endian `u64`. The width is fixed so files are identical on every
    /// platform (and byte-compatible with files written on 64-bit targets by
    /// earlier versions, which stored a native `usize`).
    ///
    /// # Errors
    ///
    /// Returns [`StorageError::Io`] if the file cannot be created or written.
    pub fn save(&self, path: impl AsRef<Path>) -> Result<()> {
        let mut encoded = [0u8; 12];
        encoded[..4].copy_from_slice(&Self::MAGIC);
        // usize is at most 64 bits wide on supported targets, so this widening
        // cast is lossless.
        encoded[4..].copy_from_slice(&(self.dimension as u64).to_le_bytes());

        let mut file = File::create(path).map_err(StorageError::Io)?;
        file.write_all(&encoded).map_err(StorageError::Io)?;
        Ok(())
    }

    /// Reads a quantizer previously written by [`SQ8Quantizer::save`].
    ///
    /// # Errors
    ///
    /// Returns [`StorageError::Io`] if the file cannot be opened or is too
    /// short, and [`StorageError::InvalidData`] if the magic bytes are wrong or
    /// the stored dimension does not fit in `usize` on this platform.
    pub fn load(path: impl AsRef<Path>) -> Result<Self> {
        let mut file = File::open(path).map_err(StorageError::Io)?;

        let mut magic = [0u8; 4];
        file.read_exact(&mut magic).map_err(StorageError::Io)?;
        if magic != Self::MAGIC {
            return Err(StorageError::InvalidData(
                "Invalid SQ8 file magic".to_string(),
            ));
        }

        let mut dim_bytes = [0u8; 8];
        file.read_exact(&mut dim_bytes).map_err(StorageError::Io)?;
        let stored = u64::from_le_bytes(dim_bytes);

        // Fully qualified so this compiles regardless of whether `TryFrom` is
        // in the prelude of the crate's edition.
        let dimension = <usize as std::convert::TryFrom<u64>>::try_from(stored).map_err(|_| {
            StorageError::InvalidData(format!(
                "SQ8 dimension {} does not fit in usize on this platform",
                stored
            ))
        })?;

        Ok(Self { dimension })
    }

    /// Returns the number of components this quantizer expects per vector.
    pub fn dimension(&self) -> usize {
        self.dimension
    }

    /// Cosine distance (`1 - cosine similarity`) between an uncompressed
    /// `query` and a quantized `data` vector, decoding `data` on the fly
    /// without allocating.
    ///
    /// The result matches computing the cosine distance against
    /// [`SQ8Quantizer::dequantize`]'s output and lies in `[0, 2]`, except:
    /// * `f32::MAX` is returned when either input has the wrong dimension;
    /// * `1.0` is returned when either vector has a norm below `1e-8`, because
    ///   the angle to a zero vector is undefined.
    pub fn asymmetric_distance_cosine(
        &self,
        query: &[f32],
        data: &QuantizedVector,
    ) -> f32 {
        if query.len() != self.dimension || data.codes.len() != self.dimension {
            return f32::MAX;
        }

        let (scale, offset) = Self::decode_params(data);

        let mut dot_product = 0.0f32;
        let mut query_norm_sq = 0.0f32;
        let mut data_norm_sq = 0.0f32;
        for (&q, &code) in query.iter().zip(&data.codes) {
            let d = f32::from(code) * scale + offset;
            dot_product += q * d;
            query_norm_sq += q * q;
            data_norm_sq += d * d;
        }

        let query_norm = query_norm_sq.sqrt();
        let data_norm = data_norm_sq.sqrt();
        if query_norm < Self::EPSILON || data_norm < Self::EPSILON {
            return 1.0;
        }

        // Rounding can push the ratio slightly outside [-1, 1].
        let cosine_sim = (dot_product / (query_norm * data_norm)).clamp(-1.0, 1.0);
        1.0 - cosine_sim
    }

    /// Returns `(scale, offset)` such that a code `c` decodes to
    /// `c * scale + offset`.
    ///
    /// A (near-)zero range yields `scale == 0.0`, so every code decodes to
    /// `min`, mirroring how `quantize` stores constant vectors.
    fn decode_params(qvec: &QuantizedVector) -> (f32, f32) {
        let range = qvec.max - qvec.min;
        let scale = if range < Self::EPSILON {
            0.0
        } else {
            range / Self::MAX_CODE
        };
        (scale, qvec.min)
    }
}
impl QuantizedVector {
    /// Bytes occupied by the serialized `min` and `max` (two little-endian
    /// `f32` values) that precede the codes.
    const HEADER_LEN: usize = 8;

    /// Serializes the vector as `min` (4 bytes LE), `max` (4 bytes LE), then
    /// the codes, one byte each.
    pub fn to_bytes(&self) -> Vec<u8> {
        let mut bytes = Vec::with_capacity(self.size());
        bytes.extend_from_slice(&self.min.to_le_bytes());
        bytes.extend_from_slice(&self.max.to_le_bytes());
        bytes.extend_from_slice(&self.codes);
        bytes
    }

    /// Parses a vector of `dimension` codes from the layout produced by
    /// [`QuantizedVector::to_bytes`].
    ///
    /// # Errors
    ///
    /// Returns [`StorageError::InvalidData`] if `bytes` is not exactly
    /// `dimension + 8` bytes long, including when that sum is not
    /// representable as `usize`.
    pub fn from_bytes(bytes: &[u8], dimension: usize) -> Result<Self> {
        // checked_add: a huge `dimension` must be rejected rather than overflow
        // and pass the length check by wrapping around.
        let expected = dimension.checked_add(Self::HEADER_LEN);
        if expected != Some(bytes.len()) {
            let message = match expected {
                Some(len) => format!(
                    "Invalid quantized vector size: expected {}, got {}",
                    len,
                    bytes.len()
                ),
                None => format!(
                    "Invalid quantized vector dimension {}: encoded size overflows usize",
                    dimension
                ),
            };
            return Err(StorageError::InvalidData(message));
        }

        // The length check guarantees at least HEADER_LEN bytes are present.
        let (header, codes) = bytes.split_at(Self::HEADER_LEN);
        let min = f32::from_le_bytes([header[0], header[1], header[2], header[3]]);
        let max = f32::from_le_bytes([header[4], header[5], header[6], header[7]]);

        Ok(Self {
            codes: codes.to_vec(),
            min,
            max,
        })
    }

    /// Returns the number of bytes [`QuantizedVector::to_bytes`] produces.
    pub fn size(&self) -> usize {
        self.codes.len() + Self::HEADER_LEN
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Reference cosine distance on plain `f32` slices, used to cross-check
    /// the asymmetric (quantized) distance. Returns `1.0` when either vector
    /// has a norm below `1e-8`.
    fn cosine_distance(a: &[f32], b: &[f32]) -> f32 {
        let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
        let norm_a = a.iter().map(|x| x * x).sum::<f32>().sqrt();
        let norm_b = b.iter().map(|y| y * y).sum::<f32>().sqrt();
        if norm_a < 1e-8 || norm_b < 1e-8 {
            return 1.0;
        }
        1.0 - (dot / (norm_a * norm_b)).clamp(-1.0, 1.0)
    }

    #[test]
    fn test_sq8_basic() {
        let quantizer = SQ8Quantizer::new(4);
        let vector = vec![1.0, 2.0, 3.0, 4.0];
        let qvec = quantizer.quantize(&vector).unwrap();
        assert_eq!(qvec.codes.len(), 4);

        let reconstructed = quantizer.dequantize(&qvec);
        assert_eq!(reconstructed.len(), 4);
        for (original, decoded) in vector.iter().zip(&reconstructed) {
            let error = (original - decoded).abs();
            assert!(error < 0.02, "Error too large: {}", error);
        }
    }

    #[test]
    fn test_sq8_normalized() {
        let quantizer = SQ8Quantizer::new(3);
        let vector = vec![0.577, 0.577, 0.577];
        let qvec = quantizer.quantize(&vector).unwrap();

        let reconstructed = quantizer.dequantize(&qvec);
        for (original, decoded) in vector.iter().zip(&reconstructed) {
            let error = (original - decoded).abs();
            assert!(error < 0.005, "Normalized vector error: {}", error);
        }
    }

    #[test]
    fn test_sq8_constant_vector() {
        let quantizer = SQ8Quantizer::new(3);
        let vector = vec![5.0, 5.0, 5.0];
        let qvec = quantizer.quantize(&vector).unwrap();

        let reconstructed = quantizer.dequantize(&qvec);
        assert_eq!(reconstructed.len(), 3);
        for decoded in &reconstructed {
            assert!((decoded - 5.0).abs() < 0.01);
        }
    }

    #[test]
    fn test_sq8_dimension_mismatch() {
        let quantizer = SQ8Quantizer::new(3);
        // Wrong input length is reported as an error by `quantize`...
        assert!(quantizer.quantize(&[1.0, 2.0]).is_err());

        // ...while the infallible APIs fall back to sentinel results.
        let short = QuantizedVector {
            codes: vec![1, 2],
            min: 0.0,
            max: 1.0,
        };
        assert_eq!(quantizer.dequantize(&short), vec![0.0; 3]);
        assert_eq!(
            quantizer.asymmetric_distance_cosine(&[1.0, 0.0, 0.0], &short),
            f32::MAX
        );
    }

    #[test]
    fn test_sq8_serialization() {
        let quantizer = SQ8Quantizer::new(4);
        let vector = vec![1.0, 2.0, 3.0, 4.0];
        let qvec = quantizer.quantize(&vector).unwrap();

        let bytes = qvec.to_bytes();
        assert_eq!(bytes.len(), qvec.size());

        let qvec2 = QuantizedVector::from_bytes(&bytes, 4).unwrap();
        assert_eq!(qvec.codes, qvec2.codes);
        assert_eq!(qvec.min, qvec2.min);
        assert_eq!(qvec.max, qvec2.max);
    }

    #[test]
    fn test_sq8_serialization_rejects_wrong_length() {
        assert!(QuantizedVector::from_bytes(&[0u8; 5], 4).is_err());
        assert!(QuantizedVector::from_bytes(&[0u8; 12], 3).is_err());
    }

    #[test]
    fn test_sq8_save_load() {
        let quantizer = SQ8Quantizer::new(128);
        // Include the process id so concurrent test runs do not share a file.
        let temp_path =
            std::env::temp_dir().join(format!("sq8_test_{}.bin", std::process::id()));

        quantizer.save(&temp_path).unwrap();
        let loaded = SQ8Quantizer::load(&temp_path);
        // Clean up before asserting so a failure does not leave the file behind.
        std::fs::remove_file(&temp_path).ok();

        assert_eq!(quantizer.dimension(), loaded.unwrap().dimension());
    }

    #[test]
    fn test_compression_ratio() {
        let quantizer = SQ8Quantizer::new(128);
        let vector = vec![0.5; 128];
        let qvec = quantizer.quantize(&vector).unwrap();

        let original_size = 128 * 4;
        let compressed_size = qvec.size();
        println!("Original: {} bytes", original_size);
        println!("Compressed: {} bytes", compressed_size);
        println!(
            "Compression ratio: {:.2}x",
            original_size as f32 / compressed_size as f32
        );
        assert!(compressed_size < original_size);
    }

    #[test]
    fn test_asymmetric_distance() {
        let quantizer = SQ8Quantizer::new(4);
        let query = vec![1.0, 0.0, 0.0, 0.0];
        let data1 = vec![0.9, 0.1, 0.0, 0.0];
        let data2 = vec![0.0, 1.0, 0.0, 0.0];
        let qdata1 = quantizer.quantize(&data1).unwrap();
        let qdata2 = quantizer.quantize(&data2).unwrap();

        let dist1 = quantizer.asymmetric_distance_cosine(&query, &qdata1);
        let dist2 = quantizer.asymmetric_distance_cosine(&query, &qdata2);
        assert!(dist1 < dist2, "Similar vectors should have smaller distance");

        // The on-the-fly distance must agree with decode-then-compare.
        let data1_deq = quantizer.dequantize(&qdata1);
        let traditional_dist1 = cosine_distance(&query, &data1_deq);
        let error = (dist1 - traditional_dist1).abs();
        assert!(error < 0.05, "Asymmetric distance error too large: {}", error);
        println!(
            "Asymmetric dist: {:.4}, Traditional dist: {:.4}, Error: {:.4}",
            dist1, traditional_dist1, error
        );
    }

    #[test]
    fn test_asymmetric_distance_normalized() {
        let quantizer = SQ8Quantizer::new(128);
        let query = vec![0.577; 128];
        let data = vec![0.577; 128];
        let qdata = quantizer.quantize(&data).unwrap();

        let dist = quantizer.asymmetric_distance_cosine(&query, &qdata);
        assert!(dist < 0.01, "Same vector distance too large: {}", dist);
    }

    #[test]
    fn test_asymmetric_distance_orthogonal() {
        let quantizer = SQ8Quantizer::new(4);
        let query = vec![1.0, 0.0, 0.0, 0.0];
        let data = vec![0.0, 1.0, 0.0, 0.0];
        let qdata = quantizer.quantize(&data).unwrap();

        let dist = quantizer.asymmetric_distance_cosine(&query, &qdata);
        assert!(
            (dist - 1.0).abs() < 0.1,
            "Orthogonal distance incorrect: {}",
            dist
        );
    }
}