use crate::RetrieveError;
/// Quantizer that cuts a vector into contiguous segments and encodes each
/// segment as the index of a codeword in that segment's codebook.
pub struct SAQQuantizer {
    /// Number of components in every input vector.
    dimension: usize,
    /// Number of contiguous segments each vector is split into.
    num_segments: usize,
    /// Bits assigned to each segment; one entry per segment.
    bits_per_segment: Vec<usize>,
    /// Codewords indexed as `codebooks[segment][code]`; empty until `fit` runs.
    codebooks: Vec<Vec<Vec<f32>>>,
    /// Half-open `(start, end)` component range covered by each segment.
    segment_bounds: Vec<(usize, usize)>,
    /// Optional projection matrix; `new` sets it to `None`.
    /// NOTE(review): nothing in the visible code reads this field — confirm intent.
    pca_matrix: Option<Vec<Vec<f32>>>,
}
impl SAQQuantizer {
/// Creates an unfitted quantizer with equally sized segments.
///
/// `dimension` is split into `num_segments` contiguous segments of equal
/// length, and `total_bits` is spread as evenly as possible across them:
/// every segment receives `total_bits / num_segments` bits and the first
/// `total_bits % num_segments` segments receive one extra bit, so the whole
/// budget is preserved. Codebooks stay empty until [`SAQQuantizer::fit`] is
/// called.
///
/// # Errors
///
/// Returns `RetrieveError::Other` when any argument is zero or when
/// `dimension` is not a multiple of `num_segments`.
pub fn new(
    dimension: usize,
    num_segments: usize,
    total_bits: usize,
) -> Result<Self, RetrieveError> {
    if dimension == 0 || num_segments == 0 || total_bits == 0 {
        return Err(RetrieveError::Other(
            "All parameters must be greater than 0".to_string(),
        ));
    }
    if dimension % num_segments != 0 {
        return Err(RetrieveError::Other(
            "Dimension must be divisible by num_segments".to_string(),
        ));
    }

    // Plain integer division would silently drop the remainder (and hand out
    // zero bits everywhere when `total_bits < num_segments`), so the leftover
    // bits go to the leading segments instead.
    let base_bits = total_bits / num_segments;
    let extra_bits = total_bits % num_segments;
    let bits_per_segment: Vec<usize> = (0..num_segments)
        .map(|i| base_bits + usize::from(i < extra_bits))
        .collect();

    let segment_dim = dimension / num_segments;
    let segment_bounds: Vec<(usize, usize)> = (0..num_segments)
        .map(|i| (i * segment_dim, (i + 1) * segment_dim))
        .collect();

    Ok(Self {
        dimension,
        num_segments,
        bits_per_segment,
        codebooks: Vec::new(),
        segment_bounds,
        pca_matrix: None,
    })
}
pub fn fit(&mut self, vectors: &[f32], num_vectors: usize) -> Result<(), RetrieveError> {
self.optimize_segmentation(vectors, num_vectors)?;
self.train_codebooks(vectors, num_vectors)?;
Ok(())
}
/// Re-distributes the current bit budget across segments in proportion to
/// each segment's variance.
///
/// The budget is the sum of the current `bits_per_segment`. For every
/// segment, the variance of the segment's components is computed inside
/// each training vector and then averaged over all vectors. Each segment
/// gets `ceil(share * budget)` bits; because rounding up can overshoot, the
/// surplus is removed one bit at a time from the lowest-variance segments
/// that still own at least one bit.
///
/// The allocation is left untouched when there are no training vectors or
/// when the total variance is not strictly positive (including NaN).
///
/// # Panics
///
/// Panics if `vectors` holds fewer than `num_vectors * dimension` values.
fn optimize_segmentation(
    &mut self,
    vectors: &[f32],
    num_vectors: usize,
) -> Result<(), RetrieveError> {
    if num_vectors == 0 {
        // Nothing to measure; averaging over zero vectors would only yield NaN.
        return Ok(());
    }
    let total_bits: usize = self.bits_per_segment.iter().sum();

    let segment_variances: Vec<f32> = self
        .segment_bounds
        .iter()
        .map(|&(start, end)| {
            let summed: f32 = (0..num_vectors)
                .map(|i| {
                    let segment = &get_vector(vectors, self.dimension, i)[start..end];
                    let len = segment.len() as f32;
                    let mean = segment.iter().sum::<f32>() / len;
                    segment.iter().map(|&x| (x - mean).powi(2)).sum::<f32>() / len
                })
                .sum();
            summed / num_vectors as f32
        })
        .collect();

    let total_variance: f32 = segment_variances.iter().sum();
    // `> 0.0` is also false for NaN, so degenerate input keeps the old allocation.
    if total_variance > 0.0 {
        self.bits_per_segment = segment_variances
            .iter()
            .map(|&var| ((var / total_variance) * total_bits as f32).ceil() as usize)
            .collect();

        let allocated: usize = self.bits_per_segment.iter().sum();
        let mut excess = allocated.saturating_sub(total_bits);
        if excess > 0 {
            let mut by_variance: Vec<usize> = (0..segment_variances.len()).collect();
            by_variance.sort_by(|&a, &b| segment_variances[a].total_cmp(&segment_variances[b]));

            // Only count a segment against the surplus when a bit was actually
            // taken from it; segments already at zero bits are skipped so the
            // budget is really restored.
            for idx in by_variance {
                if excess == 0 {
                    break;
                }
                if self.bits_per_segment[idx] > 0 {
                    self.bits_per_segment[idx] -= 1;
                    excess -= 1;
                }
            }
        }
    }
    Ok(())
}
/// Rebuilds one codebook per segment, replacing any existing codebooks.
///
/// A segment with `b` bits gets `2^min(b, 8)` codewords, so every code fits
/// in a `u8`. Each codeword is drawn uniformly from `[-1, 1)` per component
/// and then scaled to unit length (an all-zero draw is left as is).
///
/// NOTE(review): the codewords are random and do not depend on the training
/// data. `vectors` and `num_vectors` are accepted for interface stability
/// but are not consulted; the previous implementation copied every
/// sub-vector into a buffer that was never read. Confirm whether a
/// data-driven step (e.g. k-means) is intended here.
///
/// # Panics
///
/// Panics if `bits_per_segment` has fewer entries than `segment_bounds`.
fn train_codebooks(
    &mut self,
    vectors: &[f32],
    num_vectors: usize,
) -> Result<(), RetrieveError> {
    use rand::Rng;

    // Intentionally unused for now; see the note in the doc comment.
    let _ = (vectors, num_vectors);

    let mut rng = rand::rng();
    let mut codebooks = Vec::with_capacity(self.segment_bounds.len());
    for (segment_idx, &(start, end)) in self.segment_bounds.iter().enumerate() {
        let segment_dim = end - start;
        // Cap at 8 bits so that codes always fit into a single byte.
        let codebook_size = 1usize << self.bits_per_segment[segment_idx].min(8);

        let mut codebook = Vec::with_capacity(codebook_size);
        for _ in 0..codebook_size {
            let mut centroid: Vec<f32> = (0..segment_dim)
                .map(|_| rng.random::<f32>() * 2.0 - 1.0)
                .collect();
            let norm = centroid.iter().map(|v| v * v).sum::<f32>().sqrt();
            if norm > 0.0 {
                for val in &mut centroid {
                    *val /= norm;
                }
            }
            codebook.push(centroid);
        }
        codebooks.push(codebook);
    }
    self.codebooks = codebooks;
    Ok(())
}
/// Encodes `vector` as one single-byte code per segment.
///
/// For every segment the codeword with the smallest `cosine_distance` to
/// the matching slice of `vector` is selected (the first one wins on ties),
/// then passed through `refine_code`. The result has one inner `Vec` per
/// segment, each holding exactly one code.
///
/// # Panics
///
/// Panics if `vector` is shorter than a segment's end bound, or if there
/// are fewer codebooks than segments (for example before `fit` was called).
pub fn quantize(&self, vector: &[f32]) -> Vec<Vec<u8>> {
    self.segment_bounds
        .iter()
        .enumerate()
        .map(|(seg, &(lo, hi))| {
            let part = &vector[lo..hi];
            let book = &self.codebooks[seg];
            // Strict `<` keeps the earliest codeword on ties and ignores NaN distances.
            let (nearest, _) = book.iter().enumerate().fold(
                (0u8, f32::INFINITY),
                |(winner, winner_dist), (idx, codeword)| {
                    let dist = cosine_distance(part, codeword);
                    if dist < winner_dist {
                        (idx.min(255) as u8, dist)
                    } else {
                        (winner, winner_dist)
                    }
                },
            );
            vec![self.refine_code(part, book, nearest)]
        })
        .collect()
}
/// Looks for a closer codeword among the codes adjacent to `initial_code`.
///
/// Codes within three positions on either side of `initial_code` are
/// compared, limited to codes that exist in `codebook` and fit in a `u8`.
/// A neighbour replaces the current choice only when its distance is
/// strictly smaller. If `initial_code` is not a valid index, the first
/// in-window codeword with a finite distance is adopted; if no such
/// codeword exists, `initial_code` is returned unchanged.
fn refine_code(&self, segment: &[f32], codebook: &[Vec<f32>], initial_code: u8) -> u8 {
    const CHECK_RANGE: usize = 3;

    let initial = usize::from(initial_code);
    let mut best_code = initial_code;
    let mut best_dist = codebook
        .get(initial)
        .map_or(f32::INFINITY, |codeword| cosine_distance(segment, codeword));

    // The window is computed in `usize`: narrowing `codebook.len()` to `u8`
    // would turn a 256-entry codebook into 0 and collapse the search window.
    let first = initial.saturating_sub(CHECK_RANGE);
    let last = (initial + CHECK_RANGE).min(usize::from(u8::MAX));
    for (code, codeword) in codebook
        .iter()
        .enumerate()
        .skip(first)
        .take(last - first + 1)
    {
        let dist = cosine_distance(segment, codeword);
        if dist < best_dist {
            best_dist = dist;
            // `code <= last <= u8::MAX`, so this cast cannot truncate.
            best_code = code as u8;
        }
    }
    best_code
}
/// Approximates the distance between `query` and an encoded vector.
///
/// Adds up, segment by segment, the `cosine_distance` between the query's
/// slice and the codeword selected by the first code stored for that
/// segment. A segment contributes nothing when `codes` has no entry for it,
/// the entry is empty, or the code is outside the segment's codebook.
///
/// # Panics
///
/// Panics if a segment that has a code cannot be sliced out of `query`, or
/// if there is no codebook for that segment.
pub fn approximate_distance(&self, query: &[f32], codes: &[Vec<u8>]) -> f32 {
    self.segment_bounds
        .iter()
        .enumerate()
        .filter_map(|(seg, &(lo, hi))| {
            let code = usize::from(*codes.get(seg)?.first()?);
            let query_part = &query[lo..hi];
            self.codebooks[seg]
                .get(code)
                .map(|codeword| cosine_distance(query_part, codeword))
        })
        // Explicit fold from 0.0 keeps the left-to-right accumulation order.
        .fold(0.0, |acc, dist| acc + dist)
}
}
/// Distance between two slices, delegated unchanged to
/// `crate::distance::cosine_distance_normalized`.
///
/// NOTE(review): judging by the callee's name it presumably expects
/// unit-length inputs; callers in this file pass raw vector segments —
/// confirm against the implementation in `crate::distance`.
fn cosine_distance(a: &[f32], b: &[f32]) -> f32 {
crate::distance::cosine_distance_normalized(a, b)
}
/// Returns the `idx`-th vector from a flat buffer in which vectors are
/// stored back to back, `dimension` values each.
///
/// # Panics
///
/// Panics if the buffer does not contain `dimension` values starting at
/// offset `idx * dimension`.
fn get_vector(vectors: &[f32], dimension: usize, idx: usize) -> &[f32] {
    let offset = idx * dimension;
    // Skip to the vector's first value, then keep exactly `dimension` values.
    &vectors[offset..][..dimension]
}