use crate::simd;
use crate::RetrieveError;
pub struct TurboQuantizer {
dimension: usize,
bits_per_coordinate: usize,
rotation_matrix: Vec<Vec<f32>>, quantizers: Vec<ScalarQuantizer>, built: bool,
}
struct ScalarQuantizer {
min: f32,
max: f32,
num_levels: usize,
levels: Vec<f32>, }
impl ScalarQuantizer {
fn new(min: f32, max: f32, num_levels: usize) -> Self {
let step = (max - min) / num_levels as f32;
let levels: Vec<f32> = (0..=num_levels).map(|i| min + i as f32 * step).collect();
Self {
min,
max,
num_levels,
levels,
}
}
fn quantize(&self, value: f32) -> u8 {
let clamped = value.clamp(self.min, self.max);
let normalized = (clamped - self.min) / (self.max - self.min);
let level = (normalized * self.num_levels as f32).floor() as usize;
(level.min(self.num_levels - 1)).min(255) as u8
}
fn dequantize(&self, code: u8) -> f32 {
let code = code.min(self.num_levels as u8 - 1) as usize;
self.levels[code.min(self.levels.len() - 1)]
}
}
impl TurboQuantizer {
pub fn new(dimension: usize, bits_per_coordinate: usize) -> Result<Self, RetrieveError> {
if dimension == 0 || bits_per_coordinate == 0 {
return Err(RetrieveError::Other(
"Dimension and bits_per_coordinate must be greater than 0".to_string(),
));
}
let rotation_matrix = Self::generate_rotation_matrix(dimension)?;
Ok(Self {
dimension,
bits_per_coordinate,
rotation_matrix,
quantizers: Vec::new(),
built: false,
})
}
fn generate_rotation_matrix(dimension: usize) -> Result<Vec<Vec<f32>>, RetrieveError> {
use rand::Rng;
let mut rng = rand::rng();
let mut matrix = Vec::new();
for _ in 0..dimension {
let mut row = Vec::new();
let mut norm = 0.0;
for _ in 0..dimension {
let val = rng.random::<f32>() * 2.0 - 1.0;
norm += val * val;
row.push(val);
}
let norm = norm.sqrt();
if norm > 0.0 {
for val in &mut row {
*val /= norm;
}
}
matrix.push(row);
}
Ok(matrix)
}
pub fn fit(&mut self, vectors: &[f32], num_vectors: usize) -> Result<(), RetrieveError> {
if num_vectors == 0 {
return Err(RetrieveError::EmptyIndex);
}
let mut rotated_vectors = Vec::new();
for i in 0..num_vectors.min(1000) {
let vec = get_vector(vectors, self.dimension, i);
let rotated = self.apply_rotation(vec);
rotated_vectors.push(rotated);
}
let mut coordinate_mins = vec![f32::INFINITY; self.dimension];
let mut coordinate_maxs = vec![f32::NEG_INFINITY; self.dimension];
for rotated in &rotated_vectors {
for (coord_idx, &val) in rotated.iter().enumerate() {
coordinate_mins[coord_idx] = coordinate_mins[coord_idx].min(val);
coordinate_maxs[coord_idx] = coordinate_maxs[coord_idx].max(val);
}
}
let num_levels = 2usize.pow(self.bits_per_coordinate.min(8) as u32);
self.quantizers = coordinate_mins
.iter()
.zip(coordinate_maxs.iter())
.map(|(&min, &max)| ScalarQuantizer::new(min, max, num_levels))
.collect();
self.built = true;
Ok(())
}
fn apply_rotation(&self, vector: &[f32]) -> Vec<f32> {
let mut rotated = vec![0.0; self.dimension];
for (i, row) in self.rotation_matrix.iter().enumerate() {
rotated[i] = simd::dot(vector, row);
}
rotated
}
pub fn quantize(&self, vector: &[f32]) -> Result<Vec<u8>, RetrieveError> {
if !self.built {
return Err(RetrieveError::Other(
"Quantizer must be fit before quantization".to_string(),
));
}
if vector.len() != self.dimension {
return Err(RetrieveError::DimensionMismatch {
query_dim: self.dimension,
doc_dim: vector.len(),
});
}
let rotated = self.apply_rotation(vector);
let mut codes = Vec::with_capacity(self.dimension);
for (coord_idx, &val) in rotated.iter().enumerate() {
let code = self.quantizers[coord_idx].quantize(val);
codes.push(code);
}
Ok(codes)
}
pub fn dequantize(&self, codes: &[u8]) -> Result<Vec<f32>, RetrieveError> {
if codes.len() != self.dimension {
return Err(RetrieveError::DimensionMismatch {
query_dim: self.dimension,
doc_dim: codes.len(),
});
}
let mut rotated = Vec::with_capacity(self.dimension);
for (coord_idx, &code) in codes.iter().enumerate() {
let val = self.quantizers[coord_idx].dequantize(code);
rotated.push(val);
}
let mut vector = vec![0.0; self.dimension];
for (i, row) in self.rotation_matrix.iter().enumerate() {
for (j, &val) in row.iter().enumerate() {
vector[j] += rotated[i] * val;
}
}
Ok(vector)
}
pub fn approximate_distance(&self, query: &[f32], codes: &[u8]) -> Result<f32, RetrieveError> {
let dequantized = self.dequantize(codes)?;
let dist = 1.0 - simd::dot(query, &dequantized);
Ok(dist)
}
}
fn get_vector(vectors: &[f32], dimension: usize, idx: usize) -> &[f32] {
let start = idx * dimension;
let end = start + dimension;
&vectors[start..end]
}