use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use common::DistanceMetric;
/// Bit width used to encode each vector component.
///
/// `SQ8` is the default variant (marked with `#[default]`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum QuantizationType {
/// 4 bits per dimension (codes 0..=15); two dimensions are packed into each byte.
SQ4,
/// 8 bits per dimension (codes 0..=255), one byte per dimension.
#[default]
SQ8,
/// 16 bits per dimension (codes 0..=65535), stored as two little-endian bytes.
SQ16,
}
/// Configuration for an [`SQIndex`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SQConfig {
/// Bit width used for each stored component.
pub quantization_type: QuantizationType,
/// Vector dimensionality; `0` means it is taken from the first training batch
/// (see `SQIndex::train`).
pub dimensions: usize,
/// Similarity metric used to score search results.
pub metric: DistanceMetric,
/// When `true`, full-precision copies of added vectors are kept and used to compute the
/// `score` reported by `SQIndex::search`.
pub store_originals: bool,
}
impl Default for SQConfig {
fn default() -> Self {
Self {
quantization_type: QuantizationType::SQ8,
dimensions: 0,
metric: DistanceMetric::Cosine,
store_originals: false,
}
}
}
/// Memory-usage summary produced by `SQIndex::stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SQStats {
/// Number of vectors currently stored.
pub num_vectors: usize,
/// Bytes the stored vectors would need as plain `f32` data (4 bytes per component).
pub original_memory_bytes: usize,
/// Bytes occupied by the quantized codes only (ids and bookkeeping are not counted).
pub quantized_memory_bytes: usize,
/// `original_memory_bytes / quantized_memory_bytes`, or `0.0` when no quantized bytes are stored.
pub compression_ratio: f32,
/// Quantization width the index is configured with.
pub quantization_type: QuantizationType,
}
/// Learned affine mapping for one dimension: a value is encoded as
/// `(value - min_val) * scale` (then clamped to the code range) and decoded as
/// `min_val + code / scale`.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct DimensionParams {
/// Smallest value seen for this dimension during training.
min_val: f32,
/// Largest value seen for this dimension during training.
max_val: f32,
/// Codes per unit of value: the maximum code divided by the training range
/// (the range is floored at 1e-10 so constant dimensions do not divide by zero).
scale: f32,
}
/// Scalar-quantized vector index with brute-force (scan-all) search.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SQIndex {
/// Configuration; `dimensions` is filled in by training when it starts at 0.
config: SQConfig,
/// Per-dimension quantization ranges; empty until the index is trained.
dimension_params: Vec<DimensionParams>,
/// Encoded vectors; position `i` belongs to `ids[i]`.
quantized_vectors: Vec<Vec<u8>>,
/// External ids, parallel to `quantized_vectors`.
ids: Vec<String>,
/// Full-precision copies parallel to `quantized_vectors`; `Some` only when
/// `config.store_originals` was set at construction.
original_vectors: Option<Vec<Vec<f32>>>,
/// Maps each id to its position in the parallel vectors above.
id_to_index: HashMap<String, usize>,
/// Whether `dimension_params` has been learned.
trained: bool,
}
/// One hit returned by `SQIndex::search`; higher scores are better for every metric.
#[derive(Debug, Clone)]
pub struct SQSearchResult {
/// Id of the matched vector.
pub id: String,
/// Final similarity: computed from the stored original vector when originals are kept,
/// otherwise equal to `quantized_score`.
pub score: f32,
/// Similarity computed from the quantized representation; candidates are selected by it.
pub quantized_score: f32,
}
impl SQIndex {
pub fn new(config: SQConfig) -> Self {
Self {
dimension_params: Vec::new(),
quantized_vectors: Vec::new(),
ids: Vec::new(),
original_vectors: if config.store_originals {
Some(Vec::new())
} else {
None
},
id_to_index: HashMap::new(),
trained: false,
config,
}
}
pub fn train(&mut self, vectors: &[Vec<f32>]) -> Result<(), String> {
if vectors.is_empty() {
return Err("Cannot train on empty vector set".to_string());
}
let dimensions = vectors[0].len();
if self.config.dimensions == 0 {
self.config.dimensions = dimensions;
} else if self.config.dimensions != dimensions {
return Err(format!(
"Dimension mismatch: expected {}, got {}",
self.config.dimensions, dimensions
));
}
let mut dimension_params = Vec::with_capacity(dimensions);
for dim in 0..dimensions {
let mut min_val = f32::MAX;
let mut max_val = f32::MIN;
for vector in vectors {
let val = vector[dim];
min_val = min_val.min(val);
max_val = max_val.max(val);
}
let range = (max_val - min_val).max(1e-10);
let scale = self.get_max_quantized_value() / range;
dimension_params.push(DimensionParams {
min_val,
max_val,
scale,
});
}
self.dimension_params = dimension_params;
self.trained = true;
Ok(())
}
fn get_max_quantized_value(&self) -> f32 {
match self.config.quantization_type {
QuantizationType::SQ4 => 15.0,
QuantizationType::SQ8 => 255.0,
QuantizationType::SQ16 => 65535.0,
}
}
fn quantize_vector(&self, vector: &[f32]) -> Vec<u8> {
match self.config.quantization_type {
QuantizationType::SQ8 => self.quantize_sq8(vector),
QuantizationType::SQ4 => self.quantize_sq4(vector),
QuantizationType::SQ16 => self.quantize_sq16(vector),
}
}
fn quantize_sq8(&self, vector: &[f32]) -> Vec<u8> {
vector
.iter()
.enumerate()
.map(|(i, &val)| {
let params = &self.dimension_params[i];
let normalized = (val - params.min_val) * params.scale;
normalized.clamp(0.0, 255.0) as u8
})
.collect()
}
fn quantize_sq4(&self, vector: &[f32]) -> Vec<u8> {
let mut result = Vec::with_capacity(vector.len().div_ceil(2));
for chunk in vector.chunks(2) {
let low = {
let params = &self.dimension_params[result.len() * 2];
let normalized = (chunk[0] - params.min_val) * params.scale;
(normalized.clamp(0.0, 15.0) as u8) & 0x0F
};
let high = if chunk.len() > 1 {
let params = &self.dimension_params[result.len() * 2 + 1];
let normalized = (chunk[1] - params.min_val) * params.scale;
((normalized.clamp(0.0, 15.0) as u8) & 0x0F) << 4
} else {
0
};
result.push(low | high);
}
result
}
fn quantize_sq16(&self, vector: &[f32]) -> Vec<u8> {
let mut result = Vec::with_capacity(vector.len() * 2);
for (i, &val) in vector.iter().enumerate() {
let params = &self.dimension_params[i];
let normalized = (val - params.min_val) * params.scale;
let quantized = normalized.clamp(0.0, 65535.0) as u16;
result.extend_from_slice(&quantized.to_le_bytes());
}
result
}
pub fn dequantize_vector(&self, quantized: &[u8]) -> Vec<f32> {
match self.config.quantization_type {
QuantizationType::SQ8 => self.dequantize_sq8(quantized),
QuantizationType::SQ4 => self.dequantize_sq4(quantized),
QuantizationType::SQ16 => self.dequantize_sq16(quantized),
}
}
fn dequantize_sq8(&self, quantized: &[u8]) -> Vec<f32> {
quantized
.iter()
.enumerate()
.map(|(i, &val)| {
let params = &self.dimension_params[i];
params.min_val + (val as f32 / params.scale)
})
.collect()
}
fn dequantize_sq4(&self, quantized: &[u8]) -> Vec<f32> {
let mut result = Vec::with_capacity(self.config.dimensions);
for (byte_idx, &byte) in quantized.iter().enumerate() {
let dim_idx = byte_idx * 2;
if dim_idx < self.config.dimensions {
let low = byte & 0x0F;
let params = &self.dimension_params[dim_idx];
result.push(params.min_val + (low as f32 / params.scale));
}
if dim_idx + 1 < self.config.dimensions {
let high = (byte >> 4) & 0x0F;
let params = &self.dimension_params[dim_idx + 1];
result.push(params.min_val + (high as f32 / params.scale));
}
}
result
}
fn dequantize_sq16(&self, quantized: &[u8]) -> Vec<f32> {
quantized
.chunks(2)
.enumerate()
.map(|(i, bytes)| {
let val = u16::from_le_bytes([bytes[0], bytes[1]]);
let params = &self.dimension_params[i];
params.min_val + (val as f32 / params.scale)
})
.collect()
}
pub fn add(&mut self, ids: &[String], vectors: &[Vec<f32>]) -> Result<(), String> {
if !self.trained {
self.train(vectors)?;
}
for (id, vector) in ids.iter().zip(vectors.iter()) {
if vector.len() != self.config.dimensions {
return Err(format!(
"Dimension mismatch for {}: expected {}, got {}",
id,
self.config.dimensions,
vector.len()
));
}
if let Some(&existing_idx) = self.id_to_index.get(id) {
self.quantized_vectors[existing_idx] = self.quantize_vector(vector);
if let Some(ref mut originals) = self.original_vectors {
originals[existing_idx] = vector.clone();
}
} else {
let idx = self.quantized_vectors.len();
self.quantized_vectors.push(self.quantize_vector(vector));
self.ids.push(id.clone());
self.id_to_index.insert(id.clone(), idx);
if let Some(ref mut originals) = self.original_vectors {
originals.push(vector.clone());
}
}
}
Ok(())
}
pub fn search(&self, query: &[f32], top_k: usize) -> Result<Vec<SQSearchResult>, String> {
if !self.trained {
return Err("Index not trained".to_string());
}
if query.len() != self.config.dimensions {
return Err(format!(
"Query dimension mismatch: expected {}, got {}",
self.config.dimensions,
query.len()
));
}
let quantized_query = self.quantize_vector(query);
let mut scores: Vec<(usize, f32)> = self
.quantized_vectors
.iter()
.enumerate()
.map(|(idx, qv)| {
let score = self.quantized_distance(&quantized_query, qv);
(idx, score)
})
.collect();
scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
let results: Vec<SQSearchResult> = scores
.into_iter()
.take(top_k)
.map(|(idx, quantized_score)| {
let final_score = if let Some(ref originals) = self.original_vectors {
self.float_similarity(query, &originals[idx])
} else {
quantized_score
};
SQSearchResult {
id: self.ids[idx].clone(),
score: final_score,
quantized_score,
}
})
.collect();
Ok(results)
}
fn quantized_distance(&self, a: &[u8], b: &[u8]) -> f32 {
match self.config.quantization_type {
QuantizationType::SQ8 => self.sq8_distance(a, b),
QuantizationType::SQ4 => self.sq4_distance(a, b),
QuantizationType::SQ16 => self.sq16_distance(a, b),
}
}
fn sq8_distance(&self, a: &[u8], b: &[u8]) -> f32 {
match self.config.metric {
DistanceMetric::Cosine | DistanceMetric::DotProduct => {
let dot: i32 = a
.iter()
.zip(b.iter())
.map(|(&x, &y)| x as i32 * y as i32)
.sum();
let norm_a: i32 = a.iter().map(|&x| x as i32 * x as i32).sum();
let norm_b: i32 = b.iter().map(|&x| x as i32 * x as i32).sum();
let denom = ((norm_a as f32).sqrt() * (norm_b as f32).sqrt()).max(1e-10);
dot as f32 / denom
}
DistanceMetric::Euclidean => {
let dist_sq: i32 = a
.iter()
.zip(b.iter())
.map(|(&x, &y)| {
let diff = x as i32 - y as i32;
diff * diff
})
.sum();
-(dist_sq as f32).sqrt()
}
}
}
fn sq4_distance(&self, a: &[u8], b: &[u8]) -> f32 {
let a_unpacked = self.unpack_sq4(a);
let b_unpacked = self.unpack_sq4(b);
self.sq8_distance(&a_unpacked, &b_unpacked)
}
fn unpack_sq4(&self, packed: &[u8]) -> Vec<u8> {
let mut result = Vec::with_capacity(self.config.dimensions);
for &byte in packed {
result.push(byte & 0x0F);
if result.len() < self.config.dimensions {
result.push((byte >> 4) & 0x0F);
}
}
result
}
fn sq16_distance(&self, a: &[u8], b: &[u8]) -> f32 {
match self.config.metric {
DistanceMetric::Cosine | DistanceMetric::DotProduct => {
let mut dot: i64 = 0;
let mut norm_a: i64 = 0;
let mut norm_b: i64 = 0;
for i in (0..a.len()).step_by(2) {
let va = u16::from_le_bytes([a[i], a[i + 1]]) as i64;
let vb = u16::from_le_bytes([b[i], b[i + 1]]) as i64;
dot += va * vb;
norm_a += va * va;
norm_b += vb * vb;
}
let denom = ((norm_a as f64).sqrt() * (norm_b as f64).sqrt()).max(1e-10);
(dot as f64 / denom) as f32
}
DistanceMetric::Euclidean => {
let mut dist_sq: i64 = 0;
for i in (0..a.len()).step_by(2) {
let va = u16::from_le_bytes([a[i], a[i + 1]]) as i64;
let vb = u16::from_le_bytes([b[i], b[i + 1]]) as i64;
let diff = va - vb;
dist_sq += diff * diff;
}
-((dist_sq as f64).sqrt() as f32)
}
}
}
fn float_similarity(&self, a: &[f32], b: &[f32]) -> f32 {
match self.config.metric {
DistanceMetric::Cosine => {
let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
dot / (norm_a * norm_b).max(1e-10)
}
DistanceMetric::DotProduct => a.iter().zip(b.iter()).map(|(x, y)| x * y).sum(),
DistanceMetric::Euclidean => {
let dist_sq: f32 = a.iter().zip(b.iter()).map(|(x, y)| (x - y).powi(2)).sum();
-dist_sq.sqrt()
}
}
}
pub fn delete(&mut self, ids: &[String]) -> usize {
let mut deleted = 0;
for id in ids {
if let Some(idx) = self.id_to_index.remove(id) {
let last_idx = self.quantized_vectors.len() - 1;
if idx != last_idx {
self.quantized_vectors.swap(idx, last_idx);
self.ids.swap(idx, last_idx);
if let Some(ref mut originals) = self.original_vectors {
originals.swap(idx, last_idx);
}
self.id_to_index.insert(self.ids[idx].clone(), idx);
}
self.quantized_vectors.pop();
self.ids.pop();
if let Some(ref mut originals) = self.original_vectors {
originals.pop();
}
deleted += 1;
}
}
deleted
}
pub fn stats(&self) -> SQStats {
let bytes_per_quantized = match self.config.quantization_type {
QuantizationType::SQ4 => self.config.dimensions.div_ceil(2),
QuantizationType::SQ8 => self.config.dimensions,
QuantizationType::SQ16 => self.config.dimensions * 2,
};
let original_memory = self.quantized_vectors.len() * self.config.dimensions * 4;
let quantized_memory = self.quantized_vectors.len() * bytes_per_quantized;
SQStats {
num_vectors: self.quantized_vectors.len(),
original_memory_bytes: original_memory,
quantized_memory_bytes: quantized_memory,
compression_ratio: if quantized_memory > 0 {
original_memory as f32 / quantized_memory as f32
} else {
0.0
},
quantization_type: self.config.quantization_type,
}
}
pub fn len(&self) -> usize {
self.quantized_vectors.len()
}
pub fn is_empty(&self) -> bool {
self.quantized_vectors.is_empty()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    fn create_test_vectors() -> Vec<Vec<f32>> {
        vec![
            vec![1.0, 0.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0, 0.0],
            vec![0.0, 0.0, 1.0, 0.0],
            vec![0.5, 0.5, 0.0, 0.0],
            vec![0.0, 0.5, 0.5, 0.0],
        ]
    }

    /// Ids `v0`, `v1`, ... matching the positions of `create_test_vectors`.
    fn test_ids(count: usize) -> Vec<String> {
        (0..count).map(|i| format!("v{}", i)).collect()
    }

    /// Cosine-metric config with the given width and dimensionality.
    fn cosine_config(
        quantization_type: QuantizationType,
        dimensions: usize,
        store_originals: bool,
    ) -> SQConfig {
        SQConfig {
            quantization_type,
            dimensions,
            metric: DistanceMetric::Cosine,
            store_originals,
        }
    }

    /// Builds a 4-dimensional index over `create_test_vectors` with ids `v0..v4`.
    fn populated_index(quantization_type: QuantizationType, store_originals: bool) -> SQIndex {
        let mut index = SQIndex::new(cosine_config(quantization_type, 4, store_originals));
        let vectors = create_test_vectors();
        index.add(&test_ids(vectors.len()), &vectors).unwrap();
        index
    }

    /// Trains on values spread over distinct per-dimension ranges (a negative range, a unit
    /// range, and a constant dimension; three dimensions exercise odd SQ4 packing) and checks
    /// that every reconstruction lies within one quantization step of the original value.
    fn assert_roundtrip_within_one_step(quantization_type: QuantizationType) {
        let mut index = SQIndex::new(cosine_config(quantization_type, 3, false));
        let training: Vec<Vec<f32>> = (0..=10)
            .map(|i| {
                let t = i as f32 / 10.0;
                vec![2.0 * t - 1.0, t, 0.25]
            })
            .collect();
        index.train(&training).unwrap();

        let max_code = match quantization_type {
            QuantizationType::SQ4 => 15.0,
            QuantizationType::SQ8 => 255.0,
            QuantizationType::SQ16 => 65535.0,
        };
        let ranges = [2.0_f32, 1.0, 0.0];
        for vector in &training {
            let decoded = index.dequantize_vector(&index.quantize_vector(vector));
            assert_eq!(decoded.len(), vector.len());
            for ((orig, deq), range) in vector.iter().zip(&decoded).zip(ranges) {
                let tolerance = range / max_code * 1.01 + 1e-6;
                assert!(
                    (orig - deq).abs() <= tolerance,
                    "{:?}: {} decoded as {} (tolerance {})",
                    quantization_type,
                    orig,
                    deq,
                    tolerance
                );
            }
        }
    }

    #[test]
    fn test_sq8_basic() {
        let index = populated_index(QuantizationType::SQ8, false);
        assert_eq!(index.len(), 5);
        let results = index.search(&create_test_vectors()[0], 3).unwrap();
        assert_eq!(results.len(), 3);
        assert_eq!(results[0].id, "v0");
    }

    #[test]
    fn test_sq4_compression() {
        let config = SQConfig {
            quantization_type: QuantizationType::SQ4,
            dimensions: 8,
            metric: DistanceMetric::Cosine,
            store_originals: false,
        };
        let mut index = SQIndex::new(config);
        let vectors = vec![
            vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8],
            vec![0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1],
        ];
        let ids = vec!["a".to_string(), "b".to_string()];
        index.add(&ids, &vectors).unwrap();
        let stats = index.stats();
        assert!(stats.compression_ratio > 6.0);
    }

    #[test]
    fn test_sq16_accuracy() {
        let index = populated_index(QuantizationType::SQ16, true);
        let results = index.search(&create_test_vectors()[0], 2).unwrap();
        assert_eq!(results.len(), 2);
        assert!(results[0].score > 0.99);
    }

    #[test]
    fn test_delete() {
        let mut index = populated_index(QuantizationType::SQ8, false);
        assert_eq!(index.len(), 5);
        let deleted = index.delete(&["v0".to_string(), "v2".to_string()]);
        assert_eq!(deleted, 2);
        assert_eq!(index.len(), 3);
    }

    #[test]
    fn test_delete_keeps_id_mapping_consistent() {
        let mut index = populated_index(QuantizationType::SQ8, true);
        // Deleting v0 moves the last entry (v4) into slot 0; a later update of v4 must hit
        // that slot instead of appending a duplicate.
        assert_eq!(index.delete(&["v0".to_string()]), 1);
        assert_eq!(
            index.delete(&["v0".to_string(), "missing".to_string()]),
            0
        );
        index
            .add(&["v4".to_string()], &[vec![0.0, 0.0, 0.5, 0.5]])
            .unwrap();
        assert_eq!(index.len(), 4);

        let mut ids: Vec<String> = index
            .search(&[1.0, 1.0, 1.0, 1.0], 10)
            .unwrap()
            .into_iter()
            .map(|r| r.id)
            .collect();
        ids.sort();
        assert_eq!(ids, ["v1", "v2", "v3", "v4"]);
    }

    #[test]
    fn test_dequantize_roundtrip() {
        assert_roundtrip_within_one_step(QuantizationType::SQ8);
    }

    #[test]
    fn test_sq4_roundtrip() {
        assert_roundtrip_within_one_step(QuantizationType::SQ4);
    }

    #[test]
    fn test_sq16_roundtrip() {
        assert_roundtrip_within_one_step(QuantizationType::SQ16);
    }

    #[test]
    fn test_update_existing() {
        let mut index = SQIndex::new(cosine_config(QuantizationType::SQ8, 4, false));
        let ids = vec!["v1".to_string()];
        index.add(&ids, &[vec![1.0, 0.0, 0.0, 0.0]]).unwrap();
        assert_eq!(index.len(), 1);
        index.add(&ids, &[vec![0.0, 1.0, 0.0, 0.0]]).unwrap();
        assert_eq!(index.len(), 1);
        let results = index.search(&[0.0, 1.0, 0.0, 0.0], 1).unwrap();
        assert_eq!(results[0].id, "v1");
    }

    #[test]
    fn test_search_validates_input() {
        let untrained = SQIndex::new(cosine_config(QuantizationType::SQ8, 4, false));
        assert!(untrained.search(&[1.0, 0.0, 0.0, 0.0], 1).is_err());

        let index = populated_index(QuantizationType::SQ8, false);
        assert!(index.search(&[1.0, 0.0], 1).is_err());
        assert!(index.search(&[1.0, 0.0, 0.0, 0.0], 0).unwrap().is_empty());
        assert_eq!(index.search(&[1.0, 0.0, 0.0, 0.0], 100).unwrap().len(), 5);
    }

    #[test]
    fn test_train_rejects_bad_input() {
        let mut index = SQIndex::new(cosine_config(QuantizationType::SQ8, 4, false));
        assert!(index.train(&[]).is_err());
        assert!(index.train(&[vec![1.0, 2.0, 3.0]]).is_err());
    }
}