use crate::distance::cosine_distance_normalized;
use crate::RetrieveError;
/// Segment-wise vector quantizer.
///
/// Splits each `dimension`-long vector into `num_segments` contiguous, equal-width segments,
/// distributes a bit budget across them, and encodes every segment as the index of a
/// codeword in that segment's own codebook.
#[derive(Debug, Clone)]
pub struct SAQQuantizer {
    /// Length of every input vector.
    dimension: usize,
    /// Number of contiguous segments each vector is split into.
    num_segments: usize,
    /// Bits assigned to each segment; a segment's codebook holds `2^min(bits, 8)` codewords.
    bits_per_segment: Vec<usize>,
    /// One codebook per segment (codewords have the segment's width). Empty until fitted.
    codebooks: Vec<Vec<Vec<f32>>>,
    /// Half-open `[start, end)` coordinate range of each segment.
    segment_bounds: Vec<(usize, usize)>,
    /// Reserved for an optional projection matrix. It is only ever initialised to `None` in
    /// this file and never read, hence the `dead_code` allowance.
    #[allow(dead_code)]
    pca_matrix: Option<Vec<Vec<f32>>>,
}
impl SAQQuantizer {
pub fn new(
dimension: usize,
num_segments: usize,
total_bits: usize,
) -> Result<Self, RetrieveError> {
if dimension == 0 || num_segments == 0 || total_bits == 0 {
return Err(RetrieveError::InvalidParameter(
"All parameters must be greater than 0".to_string(),
));
}
if dimension % num_segments != 0 {
return Err(RetrieveError::InvalidParameter(
"Dimension must be divisible by num_segments".to_string(),
));
}
let bits_per_segment = vec![total_bits / num_segments; num_segments];
let segment_dim = dimension / num_segments;
let mut segment_bounds = Vec::new();
for i in 0..num_segments {
segment_bounds.push((i * segment_dim, (i + 1) * segment_dim));
}
Ok(Self {
dimension,
num_segments,
bits_per_segment,
codebooks: Vec::new(),
segment_bounds,
pca_matrix: None,
})
}
pub fn fit(&mut self, vectors: &[f32], num_vectors: usize) -> Result<(), RetrieveError> {
self.optimize_segmentation(vectors, num_vectors)?;
self.train_codebooks(vectors, num_vectors)?;
Ok(())
}
fn optimize_segmentation(
&mut self,
vectors: &[f32],
num_vectors: usize,
) -> Result<(), RetrieveError> {
let _segment_dim = self.dimension / self.num_segments; let total_bits: usize = self.bits_per_segment.iter().sum();
let mut segment_variances = Vec::new();
for (start, end) in &self.segment_bounds {
let mut variance = 0.0;
for i in 0..num_vectors {
let vec = get_vector(vectors, self.dimension, i);
let segment = &vec[*start..*end];
let mean: f32 = segment.iter().sum::<f32>() / segment.len() as f32;
let var: f32 =
segment.iter().map(|&x| (x - mean).powi(2)).sum::<f32>() / segment.len() as f32;
variance += var;
}
variance /= num_vectors as f32;
segment_variances.push(variance);
}
let total_variance: f32 = segment_variances.iter().sum();
if total_variance > 0.0 {
self.bits_per_segment = segment_variances
.iter()
.map(|&var| {
let ratio = var / total_variance;
(ratio * total_bits as f32).ceil() as usize
})
.collect();
let allocated: usize = self.bits_per_segment.iter().sum();
if allocated > total_bits {
let diff = allocated - total_bits;
let mut sorted_indices: Vec<usize> = (0..self.num_segments).collect();
sorted_indices.sort_unstable_by(|&a, &b| {
segment_variances[a]
.partial_cmp(&segment_variances[b])
.unwrap_or(std::cmp::Ordering::Equal)
});
for &idx in sorted_indices.iter().take(diff) {
if self.bits_per_segment[idx] > 0 {
self.bits_per_segment[idx] -= 1;
}
}
}
}
Ok(())
}
fn train_codebooks(
&mut self,
vectors: &[f32],
num_vectors: usize,
) -> Result<(), RetrieveError> {
self.codebooks = Vec::new();
for (segment_idx, (start, end)) in self.segment_bounds.iter().enumerate() {
let segment_dim = end - start;
let codebook_size = 2usize.pow(self.bits_per_segment[segment_idx].min(8) as u32);
let mut subvectors = Vec::new();
for i in 0..num_vectors {
let vec = get_vector(vectors, self.dimension, i);
subvectors.push(vec[*start..*end].to_vec());
}
use rand::Rng;
let mut rng = rand::rng();
let mut codebook = Vec::new();
for _ in 0..codebook_size {
let centroid: Vec<f32> = (0..segment_dim)
.map(|_| rng.random::<f32>() * 2.0 - 1.0)
.collect();
codebook.push(crate::distance::normalize(¢roid));
}
self.codebooks.push(codebook);
}
Ok(())
}
pub fn quantize(&self, vector: &[f32]) -> Vec<Vec<u8>> {
let mut codes = Vec::new();
for (segment_idx, (start, end)) in self.segment_bounds.iter().enumerate() {
let segment = &vector[*start..*end];
let codebook = &self.codebooks[segment_idx];
let mut best_code = 0u8;
let mut best_dist = f32::INFINITY;
for (code, codeword) in codebook.iter().enumerate() {
let dist = cosine_distance_normalized(segment, codeword);
if dist < best_dist {
best_dist = dist;
best_code = code.min(255) as u8;
}
}
let refined_code = self.refine_code(segment, codebook, best_code);
codes.push(vec![refined_code]);
}
codes
}
fn refine_code(&self, segment: &[f32], codebook: &[Vec<f32>], initial_code: u8) -> u8 {
let mut best_code = initial_code;
let mut best_dist = f32::INFINITY;
if (initial_code as usize) < codebook.len() {
best_dist = cosine_distance_normalized(segment, &codebook[initial_code as usize]);
}
let check_range = 3u8;
let start = initial_code.saturating_sub(check_range);
let end = initial_code
.saturating_add(check_range)
.min(codebook.len() as u8);
for code in start..=end {
if (code as usize) < codebook.len() {
let dist = cosine_distance_normalized(segment, &codebook[code as usize]);
if dist < best_dist {
best_dist = dist;
best_code = code;
}
}
}
best_code
}
pub fn approximate_distance(&self, query: &[f32], codes: &[Vec<u8>]) -> f32 {
let mut total_dist = 0.0;
for (segment_idx, (start, end)) in self.segment_bounds.iter().enumerate() {
if let Some(code_vec) = codes.get(segment_idx) {
if let Some(&code) = code_vec.first() {
let query_segment = &query[*start..*end];
if (code as usize) < self.codebooks[segment_idx].len() {
let codeword = &self.codebooks[segment_idx][code as usize];
total_dist += cosine_distance_normalized(query_segment, codeword);
}
}
}
}
total_dist
}
}
/// Borrows the `idx`-th `dimension`-wide vector from a flat buffer of vectors stored
/// back-to-back.
///
/// # Panics
///
/// Panics if `vectors` does not contain a complete vector at position `idx`.
fn get_vector(vectors: &[f32], dimension: usize, idx: usize) -> &[f32] {
    let offset = idx * dimension;
    &vectors[offset..][..dimension]
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Scales `v` to unit length in place (an all-zero vector is left untouched) and returns
    /// its original Euclidean norm.
    fn normalize(v: &mut [f32]) -> f32 {
        let norm = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 0.0 {
            for x in v.iter_mut() {
                *x /= norm;
            }
        }
        norm
    }

    /// Builds `num_vectors` unit-norm vectors of width `dim`, stored back-to-back.
    ///
    /// Coordinates come from a fixed-seed xorshift64 generator rather than a thread RNG.
    /// Every run therefore trains on the same data, and failures are reproducible.
    fn make_training_data(num_vectors: usize, dim: usize) -> Vec<f32> {
        let mut state: u64 = 0x9E37_79B9_7F4A_7C15;
        let mut next_coord = move || {
            state ^= state << 13;
            state ^= state >> 7;
            state ^= state << 17;
            // The top 24 bits map exactly onto an f32 in [0, 1); rescale to [-1, 1).
            ((state >> 40) as f32 / (1u64 << 24) as f32) * 2.0 - 1.0
        };
        let mut data = Vec::with_capacity(num_vectors * dim);
        for _ in 0..num_vectors {
            let mut v: Vec<f32> = (0..dim).map(|_| next_coord()).collect();
            normalize(&mut v);
            data.extend_from_slice(&v);
        }
        data
    }

    #[test]
    fn saq_new_valid_params() {
        let q = SAQQuantizer::new(16, 4, 8);
        assert!(q.is_ok());
        let q = q.unwrap();
        assert_eq!(q.dimension, 16);
        assert_eq!(q.num_segments, 4);
        assert_eq!(q.segment_bounds.len(), 4);
    }

    #[test]
    fn saq_new_rejects_zero_params() {
        assert!(SAQQuantizer::new(0, 4, 8).is_err());
        assert!(SAQQuantizer::new(16, 0, 8).is_err());
        assert!(SAQQuantizer::new(16, 4, 0).is_err());
    }

    #[test]
    fn saq_new_rejects_indivisible_dimension() {
        assert!(SAQQuantizer::new(15, 4, 8).is_err());
    }

    #[test]
    fn saq_encode_decode_roundtrip_finite() {
        let dim = 16;
        let num_segments = 4;
        let total_bits = 16;
        let num_train = 50;
        let data = make_training_data(num_train, dim);
        let mut quantizer = SAQQuantizer::new(dim, num_segments, total_bits).unwrap();
        quantizer.fit(&data, num_train).unwrap();
        for i in 0..num_train {
            let vec = get_vector(&data, dim, i);
            let codes = quantizer.quantize(vec);
            assert_eq!(
                codes.len(),
                num_segments,
                "code count must equal num_segments"
            );
            let self_dist = quantizer.approximate_distance(vec, &codes);
            assert!(
                self_dist.is_finite() && self_dist >= 0.0,
                "Self-distance must be finite and non-negative, got {} for vector {}",
                self_dist,
                i
            );
        }
    }

    #[test]
    fn saq_approximate_distance_closer_for_similar_vectors() {
        let dim = 8;
        let num_segments = 2;
        let total_bits = 8;
        let num_train = 30;
        let data = make_training_data(num_train, dim);
        let mut quantizer = SAQQuantizer::new(dim, num_segments, total_bits).unwrap();
        quantizer.fit(&data, num_train).unwrap();
        let query = get_vector(&data, dim, 0);
        let self_codes = quantizer.quantize(query);
        let self_dist = quantizer.approximate_distance(query, &self_codes);
        let other_dists: Vec<f32> = (1..num_train)
            .map(|i| {
                let codes = quantizer.quantize(get_vector(&data, dim, i));
                quantizer.approximate_distance(query, &codes)
            })
            .collect();
        let avg_other: f32 = other_dists.iter().sum::<f32>() / other_dists.len() as f32;
        assert!(
            self_dist <= avg_other + 0.5,
            "Self-distance {} should be <= avg other distance {} (with margin)",
            self_dist,
            avg_other
        );
    }

    #[test]
    fn saq_single_vector() {
        let dim = 8;
        let num_segments = 2;
        let total_bits = 4;
        let mut v = vec![0.5, -0.3, 0.1, 0.7, -0.2, 0.4, -0.6, 0.8];
        normalize(&mut v);
        let data = v.clone();
        let mut quantizer = SAQQuantizer::new(dim, num_segments, total_bits).unwrap();
        quantizer.fit(&data, 1).unwrap();
        let codes = quantizer.quantize(&v);
        assert_eq!(codes.len(), num_segments);
        for code_vec in &codes {
            assert!(!code_vec.is_empty());
        }
    }

    #[test]
    fn saq_codes_index_into_codebooks() {
        let dim = 8;
        let num_train = 20;
        let data = make_training_data(num_train, dim);
        let mut quantizer = SAQQuantizer::new(dim, 2, 8).unwrap();
        quantizer.fit(&data, num_train).unwrap();
        for i in 0..num_train {
            let codes = quantizer.quantize(get_vector(&data, dim, i));
            for (segment_idx, code_vec) in codes.iter().enumerate() {
                assert_eq!(code_vec.len(), 1, "each segment must yield exactly one code");
                assert!(
                    (code_vec[0] as usize) < quantizer.codebooks[segment_idx].len(),
                    "code {} out of range for segment {}",
                    code_vec[0],
                    segment_idx
                );
            }
        }
    }

    #[test]
    fn saq_approximate_distance_without_codes_is_zero() {
        let q = SAQQuantizer::new(8, 2, 4).unwrap();
        let query = vec![0.0f32; 8];
        assert_eq!(q.approximate_distance(&query, &[]), 0.0);
    }

    #[test]
    fn saq_segment_bounds_correct() {
        let q = SAQQuantizer::new(12, 3, 6).unwrap();
        assert_eq!(q.segment_bounds, vec![(0, 4), (4, 8), (8, 12)]);
    }

    #[test]
    fn cosine_distance_fn_identical_vectors() {
        let v = vec![0.6, 0.8];
        let d = cosine_distance_normalized(&v, &v);
        assert!(
            d.abs() < 1e-5,
            "cosine distance to self should be ~0, got {}",
            d
        );
    }

    #[test]
    fn cosine_distance_fn_opposite_vectors() {
        let a = vec![1.0, 0.0];
        let b = vec![-1.0, 0.0];
        let d = cosine_distance_normalized(&a, &b);
        assert!(
            (d - 2.0).abs() < 1e-5,
            "cosine distance for opposite vectors should be ~2.0, got {}",
            d
        );
    }

    #[test]
    fn get_vector_fn() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        assert_eq!(get_vector(&data, 3, 0), &[1.0, 2.0, 3.0]);
        assert_eq!(get_vector(&data, 3, 1), &[4.0, 5.0, 6.0]);
        assert_eq!(get_vector(&data, 2, 2), &[5.0, 6.0]);
    }
}