use crate::error::{Error, Result};
use nalgebra::DMatrix;
use rand::{rngs::StdRng, SeedableRng};
use rand_distr::{Distribution, StandardNormal};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TurboQuantConfig {
pub dimension: usize,
pub bit_width: usize,
pub seed: u64,
}
impl Default for TurboQuantConfig {
fn default() -> Self {
Self {
dimension: 384,
bit_width: 3,
seed: 42,
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TurboQuantIndex {
config: TurboQuantConfig,
rotation_matrix: Vec<Vec<f32>>,
codebook: Vec<f32>,
quantized_vectors: HashMap<i64, Vec<u8>>,
vector_norms: HashMap<i64, f32>,
}
pub struct LinearScanIndex {
config: TurboQuantConfig,
vectors: HashMap<i64, Vec<f32>>,
}
impl LinearScanIndex {
pub fn new(config: TurboQuantConfig) -> Result<Self> {
Ok(Self {
config,
vectors: HashMap::new(),
})
}
pub fn add_vector(&mut self, entity_id: i64, vector: &[f32]) -> Result<()> {
if vector.len() != self.config.dimension {
return Err(Error::InvalidVectorDimension {
expected: self.config.dimension,
actual: vector.len(),
});
}
self.vectors.insert(entity_id, vector.to_vec());
Ok(())
}
pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<(i64, f32)>> {
if query.len() != self.config.dimension {
return Err(Error::InvalidVectorDimension {
expected: self.config.dimension,
actual: query.len(),
});
}
let query_norm: f32 = query.iter().map(|x| x * x).sum::<f32>().sqrt();
let mut results: Vec<(i64, f32)> = self
.vectors
.iter()
.map(|(&entity_id, vector)| {
let dot_product: f32 = query.iter().zip(vector.iter()).map(|(a, b)| a * b).sum();
let target_norm: f32 = vector.iter().map(|x| x * x).sum::<f32>().sqrt();
let similarity = if query_norm > 0.0 && target_norm > 0.0 {
dot_product / (query_norm * target_norm)
} else {
0.0
};
(entity_id, similarity)
})
.collect();
results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
results.truncate(k);
Ok(results)
}
pub fn stats(&self) -> LinearScanStats {
LinearScanStats {
num_vectors: self.vectors.len(),
dimension: self.config.dimension,
bytes_per_vector: self.config.dimension * 4, }
}
pub fn clear(&mut self) {
self.vectors.clear();
}
pub fn len(&self) -> usize {
self.vectors.len()
}
pub fn is_empty(&self) -> bool {
self.vectors.is_empty()
}
}
#[derive(Debug, Clone)]
pub struct LinearScanStats {
pub num_vectors: usize,
pub dimension: usize,
pub bytes_per_vector: usize,
}
impl TurboQuantIndex {
pub fn new(config: TurboQuantConfig) -> Result<Self> {
if config.bit_width < 1 || config.bit_width > 8 {
return Err(Error::InvalidInput(
"bit_width must be between 1 and 8".to_string(),
));
}
let mut rng = StdRng::seed_from_u64(config.seed);
let rotation_matrix = Self::generate_rotation_matrix(config.dimension, &mut rng);
let codebook = Self::compute_codebook(config.bit_width);
Ok(Self {
config,
rotation_matrix,
codebook,
quantized_vectors: HashMap::new(),
vector_norms: HashMap::new(),
})
}
fn generate_rotation_matrix(d: usize, rng: &mut StdRng) -> Vec<Vec<f32>> {
let data: Vec<f64> = (0..d * d).map(|_| StandardNormal.sample(rng)).collect();
let matrix = DMatrix::from_vec(d, d, data);
let qr = matrix.qr();
let q = qr.q();
(0..d)
.map(|i| (0..d).map(|j| q[(i, j)] as f32).collect())
.collect()
}
fn compute_codebook(bit_width: usize) -> Vec<f32> {
let k = 1usize << bit_width; let mut rng = StdRng::seed_from_u64(0xc0de_b007);
let num_samples = 50_000usize;
let std_dev = (1.0_f32 / 384_f32).sqrt();
let samples: Vec<f32> = (0..num_samples)
.map(|_| {
let n: f64 = StandardNormal.sample(&mut rng);
(n as f32 * std_dev).clamp(-1.0, 1.0)
})
.collect();
let mut centroids: Vec<f32> = (0..k)
.map(|i| {
if k == 1 {
0.0
} else {
-1.0 + 2.0 * i as f32 / (k - 1) as f32
}
})
.collect();
for _ in 0..100 {
let mut sums = vec![0.0f64; k];
let mut counts = vec![0usize; k];
for &x in &samples {
let nearest = centroids
.iter()
.enumerate()
.min_by(|(_, a), (_, b)| {
(x - *a)
.abs()
.partial_cmp(&(x - *b).abs())
.unwrap_or(std::cmp::Ordering::Equal)
})
.map(|(i, _)| i)
.unwrap_or(0);
sums[nearest] += x as f64;
counts[nearest] += 1;
}
let prev = centroids.clone();
for i in 0..k {
if counts[i] > 0 {
centroids[i] = (sums[i] / counts[i] as f64) as f32;
}
}
let converged = centroids
.iter()
.zip(prev.iter())
.all(|(a, b)| (a - b).abs() < 1e-6);
if converged {
break;
}
}
centroids.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
centroids
}
pub fn add_vector(&mut self, entity_id: i64, vector: &[f32]) -> Result<()> {
if vector.len() != self.config.dimension {
return Err(Error::InvalidVectorDimension {
expected: self.config.dimension,
actual: vector.len(),
});
}
let norm: f32 = vector.iter().map(|x| x * x).sum();
let norm = norm.sqrt();
self.vector_norms.insert(entity_id, norm);
let rotated = self.apply_rotation(vector);
let quantized = self.quantize_vector(&rotated);
self.quantized_vectors.insert(entity_id, quantized);
Ok(())
}
fn apply_rotation(&self, vector: &[f32]) -> Vec<f32> {
let d = self.config.dimension;
let mut rotated = vec![0.0f32; d];
for (i, rot_row) in self.rotation_matrix.iter().enumerate().take(d) {
for (j, &val) in vector.iter().enumerate().take(d) {
rotated[i] += rot_row[j] * val;
}
}
rotated
}
fn quantize_vector(&self, vector: &[f32]) -> Vec<u8> {
vector
.iter()
.map(|&val| {
let mut best_idx = 0;
let mut best_dist = f32::MAX;
for (idx, ¢roid) in self.codebook.iter().enumerate() {
let dist = (val - centroid).abs();
if dist < best_dist {
best_dist = dist;
best_idx = idx;
}
}
best_idx as u8
})
.collect()
}
#[allow(dead_code)]
fn dequantize_vector(&self, quantized: &[u8]) -> Vec<f32> {
quantized
.iter()
.map(|&idx| self.codebook[idx as usize])
.collect()
}
pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<(i64, f32)>> {
if query.len() != self.config.dimension {
return Err(Error::InvalidVectorDimension {
expected: self.config.dimension,
actual: query.len(),
});
}
let rotated_query = self.apply_rotation(query);
let quantized_query = self.quantize_vector(&rotated_query);
let query_norm: f32 = query.iter().map(|x| x * x).sum();
let query_norm = query_norm.sqrt();
let mut results: Vec<(i64, f32)> = self
.quantized_vectors
.iter()
.map(|(&entity_id, quantized_vec)| {
let similarity = self.compute_similarity(
&quantized_query,
quantized_vec,
query_norm,
self.vector_norms.get(&entity_id).copied().unwrap_or(1.0),
);
(entity_id, similarity)
})
.collect();
results.sort_by(|a, b| {
b.1.partial_cmp(&a.1)
.unwrap_or(std::cmp::Ordering::Equal)
.then_with(|| a.0.cmp(&b.0))
});
results.truncate(k);
Ok(results)
}
fn compute_similarity(
&self,
query: &[u8],
target: &[u8],
_query_norm: f32,
_target_norm: f32,
) -> f32 {
if query.len() != target.len() {
return 0.0;
}
let mut dot_product = 0.0f32;
let mut query_sq = 0.0f32;
let mut target_sq = 0.0f32;
for i in 0..query.len() {
let q_val = self.codebook[query[i] as usize];
let t_val = self.codebook[target[i] as usize];
dot_product += q_val * t_val;
query_sq += q_val * q_val;
target_sq += t_val * t_val;
}
let denom = query_sq.sqrt() * target_sq.sqrt();
if denom > 0.0 {
dot_product / denom
} else {
0.0
}
}
pub fn add_vectors_batch(&mut self, vectors: &[(i64, Vec<f32>)]) -> Result<()> {
for (entity_id, vector) in vectors {
self.add_vector(*entity_id, vector)?;
}
Ok(())
}
pub fn stats(&self) -> TurboQuantStats {
TurboQuantStats {
num_vectors: self.quantized_vectors.len(),
dimension: self.config.dimension,
bit_width: self.config.bit_width,
bytes_per_vector: self.config.dimension, compression_ratio: 32.0 / self.config.bit_width as f32, }
}
pub fn remove_vector(&mut self, entity_id: i64) -> Result<()> {
self.quantized_vectors.remove(&entity_id);
self.vector_norms.remove(&entity_id);
Ok(())
}
pub fn clear(&mut self) {
self.quantized_vectors.clear();
self.vector_norms.clear();
}
pub fn len(&self) -> usize {
self.quantized_vectors.len()
}
pub fn is_empty(&self) -> bool {
self.quantized_vectors.is_empty()
}
pub fn save<P: AsRef<std::path::Path>>(&self, path: P) -> Result<()> {
let serialized = serde_json::to_string(self)?;
std::fs::write(path, serialized)?;
Ok(())
}
pub fn load<P: AsRef<std::path::Path>>(path: P) -> Result<Self> {
let contents = std::fs::read_to_string(path)?;
let index: Self = serde_json::from_str(&contents)?;
Ok(index)
}
pub fn to_bytes(&self) -> Result<Vec<u8>> {
Ok(serde_json::to_vec(self)?)
}
pub fn from_bytes(bytes: &[u8]) -> Result<Self> {
Ok(serde_json::from_slice(bytes)?)
}
pub fn config(&self) -> &TurboQuantConfig {
&self.config
}
pub fn search_batch(&self, queries: &[Vec<f32>], k: usize) -> Result<Vec<Vec<(i64, f32)>>> {
queries.iter().map(|query| self.search(query, k)).collect()
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TurboQuantStats {
pub num_vectors: usize,
pub dimension: usize,
pub bit_width: usize,
pub bytes_per_vector: usize,
pub compression_ratio: f32,
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_create_index() {
let config = TurboQuantConfig {
dimension: 128,
bit_width: 3,
seed: 42,
};
let index = TurboQuantIndex::new(config).unwrap();
assert_eq!(index.config.dimension, 128);
assert_eq!(index.config.bit_width, 3);
}
#[test]
fn test_add_and_search() {
let config = TurboQuantConfig {
dimension: 128,
bit_width: 3,
seed: 42,
};
let mut index = TurboQuantIndex::new(config).unwrap();
let vec1: Vec<f32> = (0..128).map(|i| (i as f32) / 128.0).collect();
let vec2: Vec<f32> = (0..128).map(|i| ((i + 64) % 128) as f32 / 128.0).collect();
let vec3: Vec<f32> = (0..128).map(|i| 1.0 - (i as f32) / 128.0).collect();
index.add_vector(1, &vec1).unwrap();
index.add_vector(2, &vec2).unwrap();
index.add_vector(3, &vec3).unwrap();
let results = index.search(&vec1, 2).unwrap();
assert_eq!(results.len(), 2);
assert_eq!(results[0].0, 1); }
#[test]
fn test_compression_ratio() {
let config = TurboQuantConfig {
dimension: 384,
bit_width: 3,
seed: 42,
};
let index = TurboQuantIndex::new(config).unwrap();
let stats = index.stats();
assert!(stats.compression_ratio > 10.0);
}
#[test]
fn test_stats() {
let config = TurboQuantConfig {
dimension: 384,
bit_width: 3,
seed: 42,
};
let mut index = TurboQuantIndex::new(config).unwrap();
let vec: Vec<f32> = vec![0.1; 384];
index.add_vector(1, &vec).unwrap();
index.add_vector(2, &vec).unwrap();
let stats = index.stats();
assert_eq!(stats.num_vectors, 2);
assert_eq!(stats.dimension, 384);
}
#[test]
fn test_to_bytes_from_bytes_roundtrip() {
let config = TurboQuantConfig {
dimension: 64,
bit_width: 3,
seed: 42,
};
let mut index = TurboQuantIndex::new(config).unwrap();
let vec: Vec<f32> = (0..64).map(|i| i as f32 / 64.0).collect();
index.add_vector(1, &vec).unwrap();
index.add_vector(2, &vec).unwrap();
let bytes = index.to_bytes().unwrap();
assert!(!bytes.is_empty());
let restored = TurboQuantIndex::from_bytes(&bytes).unwrap();
assert_eq!(restored.config.dimension, 64);
assert_eq!(restored.config.bit_width, 3);
assert_eq!(restored.len(), 2);
let query: Vec<f32> = (0..64).map(|i| i as f32 / 64.0).collect();
let original_results = index.search(&query, 2).unwrap();
let restored_results = restored.search(&query, 2).unwrap();
assert_eq!(original_results, restored_results);
}
}