use super::{Codebook, QuantizedVector, PqError, PqResult};
use crate::vector::Vector;
use std::sync::Arc;
/// Computes approximate Euclidean distances between full-precision query
/// vectors and product-quantized vectors, using the centroids of a
/// [`Codebook`].
pub struct DistanceComputer {
    /// Centroids that stand in for each quantized sub-vector. Held behind an
    /// `Arc` so the same codebook can be handed to other components as well.
    codebook: Arc<Codebook>,
}
impl DistanceComputer {
    /// Creates a distance computer backed by `codebook`.
    pub fn new(codebook: Arc<Codebook>) -> Self {
        Self { codebook }
    }

    /// Computes the approximate Euclidean distance between `query` and the
    /// vector described by `quantized`.
    ///
    /// For each sub-quantizer, the matching slice of `query` is compared with
    /// the centroid selected by that sub-quantizer's code. The squared
    /// distances are summed, and the square root of the total is returned.
    ///
    /// # Errors
    ///
    /// - [`PqError::DimensionMismatch`] if `query.len()` differs from the
    ///   codebook dimension.
    /// - [`PqError::EncodingError`] if `quantized` does not hold exactly one
    ///   code per sub-quantizer, or if a sub-quantizer's range falls outside
    ///   `query`.
    /// - Any error propagated from `Codebook::get_centroid`.
    pub fn compute_distance(
        &self,
        query: &Vector,
        quantized: &QuantizedVector,
    ) -> PqResult<f32> {
        self.check_query_dimension(query)?;
        self.check_code_count(quantized.codes.len())?;

        let subvector_dim = self.codebook.subvector_dimension();
        let mut distance_squared = 0.0_f32;
        // The code count was validated above, so enumerating the codes visits
        // every sub-quantizer exactly once, in order.
        for (sq_idx, code) in quantized.codes.iter().copied().enumerate() {
            let query_subvector = Self::query_subvector(query, sq_idx, subvector_dim)?;
            let centroid = self.codebook.get_centroid(sq_idx, code as usize)?;
            distance_squared += l2_distance_squared(query_subvector, centroid);
        }
        Ok(distance_squared.sqrt())
    }

    /// Precomputes, for one query, the squared L2 distance from each query
    /// sub-vector to every centroid of the corresponding sub-quantizer.
    ///
    /// The result has one row per sub-quantizer and one column per centroid:
    /// `table[s][c]` is the squared distance between the `s`-th query
    /// sub-vector and centroid `c` of sub-quantizer `s`. Pass it to
    /// [`Self::compute_distance_with_table`] or
    /// [`Self::compute_distances_batch`] to score many quantized vectors
    /// against the same query without touching the codebook again.
    ///
    /// # Errors
    ///
    /// - [`PqError::DimensionMismatch`] if `query.len()` differs from the
    ///   codebook dimension.
    /// - [`PqError::EncodingError`] if a sub-quantizer's range falls outside
    ///   `query`.
    /// - Any error propagated from `Codebook::get_centroid`.
    pub fn precompute_distance_table(&self, query: &Vector) -> PqResult<Vec<Vec<f32>>> {
        self.check_query_dimension(query)?;

        let num_subquantizers = self.codebook.num_subquantizers();
        let num_centroids = self.codebook.num_centroids();
        let subvector_dim = self.codebook.subvector_dimension();

        let mut table = Vec::with_capacity(num_subquantizers);
        for sq_idx in 0..num_subquantizers {
            let query_subvector = Self::query_subvector(query, sq_idx, subvector_dim)?;
            let row = (0..num_centroids)
                .map(|c_idx| -> PqResult<f32> {
                    let centroid = self.codebook.get_centroid(sq_idx, c_idx)?;
                    Ok(l2_distance_squared(query_subvector, centroid))
                })
                .collect::<PqResult<Vec<f32>>>()?;
            table.push(row);
        }
        Ok(table)
    }

    /// Computes the approximate Euclidean distance for `quantized` by looking
    /// up each code's precomputed squared distance in `distance_table`.
    ///
    /// The squared distances are summed in sub-quantizer order, so the result
    /// matches [`Self::compute_distance`] for the query the table was built
    /// from.
    ///
    /// # Errors
    ///
    /// - [`PqError::EncodingError`] if `distance_table` does not have one row
    ///   per sub-quantizer, or `quantized` does not hold one code per
    ///   sub-quantizer.
    /// - [`PqError::InvalidCentroidIndex`] if a code is not a valid column of
    ///   its table row.
    pub fn compute_distance_with_table(
        &self,
        distance_table: &[Vec<f32>],
        quantized: &QuantizedVector,
    ) -> PqResult<f32> {
        let num_subquantizers = self.codebook.num_subquantizers();
        if distance_table.len() != num_subquantizers {
            return Err(PqError::EncodingError(format!(
                "Distance table has wrong number of rows: expected {}, got {}",
                num_subquantizers,
                distance_table.len()
            )));
        }
        self.check_code_count(quantized.codes.len())?;

        let mut distance_squared = 0.0_f32;
        // Both lengths were validated above, so zip pairs every row with its code.
        for (row, code) in distance_table.iter().zip(quantized.codes.iter().copied()) {
            let code = code as usize;
            let entry = row
                .get(code)
                .copied()
                .ok_or(PqError::InvalidCentroidIndex(code))?;
            distance_squared += entry;
        }
        Ok(distance_squared.sqrt())
    }

    /// Scores every vector in `quantized_vectors` against `distance_table`
    /// and returns the distances in input order.
    ///
    /// # Errors
    ///
    /// Stops at, and returns, the first error reported by
    /// [`Self::compute_distance_with_table`].
    pub fn compute_distances_batch(
        &self,
        distance_table: &[Vec<f32>],
        quantized_vectors: &[QuantizedVector],
    ) -> PqResult<Vec<f32>> {
        quantized_vectors
            .iter()
            .map(|qv| self.compute_distance_with_table(distance_table, qv))
            .collect()
    }

    /// Returns a shared handle to the underlying codebook. This clones the
    /// `Arc` only, not the codebook data.
    pub fn codebook(&self) -> Arc<Codebook> {
        Arc::clone(&self.codebook)
    }

    /// Fails with [`PqError::DimensionMismatch`] unless `query` has exactly
    /// the codebook's dimension.
    fn check_query_dimension(&self, query: &[f32]) -> PqResult<()> {
        let dimension = self.codebook.dimension();
        if query.len() == dimension {
            Ok(())
        } else {
            Err(PqError::DimensionMismatch {
                expected: dimension,
                actual: query.len(),
            })
        }
    }

    /// Fails with [`PqError::EncodingError`] unless there is exactly one code
    /// per sub-quantizer.
    fn check_code_count(&self, num_codes: usize) -> PqResult<()> {
        let num_subquantizers = self.codebook.num_subquantizers();
        if num_codes == num_subquantizers {
            Ok(())
        } else {
            Err(PqError::EncodingError(format!(
                "Expected {} codes, got {}",
                num_subquantizers, num_codes
            )))
        }
    }

    /// Returns the slice of `query` handled by sub-quantizer `sq_idx`.
    ///
    /// Uses checked slicing. If the codebook's sub-quantizer layout does not
    /// fit inside `query`, this returns an error instead of panicking.
    fn query_subvector(query: &[f32], sq_idx: usize, subvector_dim: usize) -> PqResult<&[f32]> {
        let start = sq_idx * subvector_dim;
        let end = start + subvector_dim;
        query.get(start..end).ok_or_else(|| {
            PqError::EncodingError(format!(
                "Sub-quantizer {} covers query elements {}..{}, but query has length {}",
                sq_idx,
                start,
                end,
                query.len()
            ))
        })
    }
}
/// Returns the squared Euclidean (L2) distance between `a` and `b`.
///
/// Elements are paired by position via `zip`. If the slices differ in length,
/// the trailing elements of the longer one are ignored.
#[inline]
fn l2_distance_squared(a: &[f32], b: &[f32]) -> f32 {
    let squared_diffs = a.iter().zip(b).map(|(&lhs, &rhs)| {
        let diff = lhs - rhs;
        diff.powi(2)
    });
    squared_diffs.sum::<f32>()
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;
    use super::super::Encoder;

    /// Absolute tolerance used when comparing `f32` distances.
    const TOLERANCE: f32 = 1e-4;

    /// Asserts that `actual` is within [`TOLERANCE`] of `expected`.
    fn assert_close(actual: f32, expected: f32) {
        assert!(
            (actual - expected).abs() < TOLERANCE,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    /// Builds the shared fixture. Centroids, indexed `[sub_quantizer][centroid]`:
    /// `[0][0] = [0, 0]`, `[0][1] = [1, 1]`, `[1][0] = [2, 2]`, `[1][1] = [3, 3]`.
    fn create_test_codebook() -> Codebook {
        let mut codebook = Codebook::new(2, 2, 2);
        codebook.set_centroid(0, 0, vec![0.0, 0.0]).unwrap();
        codebook.set_centroid(0, 1, vec![1.0, 1.0]).unwrap();
        codebook.set_centroid(1, 0, vec![2.0, 2.0]).unwrap();
        codebook.set_centroid(1, 1, vec![3.0, 3.0]).unwrap();
        codebook
    }

    #[test]
    fn test_distance_computation_basic() {
        let codebook = Arc::new(create_test_codebook());
        let distance_computer = DistanceComputer::new(Arc::clone(&codebook));
        let encoder = Encoder::new(codebook);
        let query = vec![0.5, 0.5, 2.5, 2.5];
        let db_vector = vec![1.0, 1.0, 3.0, 3.0];
        let quantized = encoder.encode(&db_vector).unwrap();
        let pq_distance = distance_computer
            .compute_distance(&query, &quantized)
            .unwrap();
        // db_vector coincides with centroids [0][1] and [1][1], so the PQ
        // distance should equal the exact Euclidean distance.
        let actual_distance = ((0.5_f32 - 1.0_f32).powi(2)
            + (0.5_f32 - 1.0_f32).powi(2)
            + (2.5_f32 - 3.0_f32).powi(2)
            + (2.5_f32 - 3.0_f32).powi(2))
        .sqrt();
        assert_close(pq_distance, actual_distance);
        assert_close(pq_distance, 1.0);
    }

    #[test]
    fn test_precompute_distance_table() {
        let codebook = Arc::new(create_test_codebook());
        let distance_computer = DistanceComputer::new(codebook);
        let query = vec![0.5, 0.5, 2.5, 2.5];
        let table = distance_computer
            .precompute_distance_table(&query)
            .unwrap();
        assert_eq!(table.len(), 2);
        // Each query sub-vector lies halfway between its two centroids, so
        // every entry is 0.5² + 0.5² = 0.5.
        let expected: f32 = 0.5_f32.powi(2) + 0.5_f32.powi(2);
        for row in &table {
            assert_eq!(row.len(), 2);
            for &entry in row {
                assert_close(entry, expected);
            }
        }
    }

    #[test]
    fn test_distance_with_table() {
        let codebook = Arc::new(create_test_codebook());
        let distance_computer = DistanceComputer::new(Arc::clone(&codebook));
        let encoder = Encoder::new(codebook);
        let query = vec![0.5, 0.5, 2.5, 2.5];
        let db_vector = vec![1.0, 1.0, 3.0, 3.0];
        let table = distance_computer
            .precompute_distance_table(&query)
            .unwrap();
        let quantized = encoder.encode(&db_vector).unwrap();
        let distance_with_table = distance_computer
            .compute_distance_with_table(&table, &quantized)
            .unwrap();
        let distance_without_table = distance_computer
            .compute_distance(&query, &quantized)
            .unwrap();
        assert_close(distance_with_table, distance_without_table);
    }

    #[test]
    fn test_distance_batch() {
        let codebook = Arc::new(create_test_codebook());
        let distance_computer = DistanceComputer::new(Arc::clone(&codebook));
        let encoder = Encoder::new(codebook);
        let query = vec![0.5, 0.5, 2.5, 2.5];
        let db_vectors = vec![
            vec![0.0, 0.0, 2.0, 2.0],
            vec![1.0, 1.0, 3.0, 3.0],
        ];
        let table = distance_computer
            .precompute_distance_table(&query)
            .unwrap();
        let quantized: Vec<_> = db_vectors
            .iter()
            .map(|v| encoder.encode(v).unwrap())
            .collect();
        let distances = distance_computer
            .compute_distances_batch(&table, &quantized)
            .unwrap();
        assert_eq!(distances.len(), 2);
        // Both database vectors coincide with centroids; each sub-vector is
        // 0.5² + 0.5² away from the query, giving sqrt(0.5 + 0.5) = 1.
        for distance in distances {
            assert!(distance.is_finite());
            assert_close(distance, 1.0);
        }
    }

    #[test]
    fn test_distance_batch_empty() {
        let codebook = Arc::new(create_test_codebook());
        let distance_computer = DistanceComputer::new(codebook);
        let query = vec![0.5, 0.5, 2.5, 2.5];
        let table = distance_computer
            .precompute_distance_table(&query)
            .unwrap();
        let distances = distance_computer
            .compute_distances_batch(&table, &[])
            .unwrap();
        assert!(distances.is_empty());
    }

    #[test]
    fn test_query_dimension_mismatch_is_rejected() {
        let codebook = Arc::new(create_test_codebook());
        let distance_computer = DistanceComputer::new(Arc::clone(&codebook));
        let encoder = Encoder::new(codebook);
        let db_vector = vec![1.0, 1.0, 3.0, 3.0];
        let quantized = encoder.encode(&db_vector).unwrap();
        let short_query = vec![0.5, 0.5];
        assert!(matches!(
            distance_computer.compute_distance(&short_query, &quantized),
            Err(PqError::DimensionMismatch { expected: 4, actual: 2 })
        ));
        assert!(matches!(
            distance_computer.precompute_distance_table(&short_query),
            Err(PqError::DimensionMismatch { expected: 4, actual: 2 })
        ));
    }

    #[test]
    fn test_table_with_wrong_row_count_is_rejected() {
        let codebook = Arc::new(create_test_codebook());
        let distance_computer = DistanceComputer::new(Arc::clone(&codebook));
        let encoder = Encoder::new(codebook);
        let query = vec![0.5, 0.5, 2.5, 2.5];
        let db_vector = vec![1.0, 1.0, 3.0, 3.0];
        let quantized = encoder.encode(&db_vector).unwrap();
        let table = distance_computer
            .precompute_distance_table(&query)
            .unwrap();
        assert!(matches!(
            distance_computer.compute_distance_with_table(&table[..1], &quantized),
            Err(PqError::EncodingError(_))
        ));
    }

    #[test]
    fn test_code_outside_table_row_is_rejected() {
        let codebook = Arc::new(create_test_codebook());
        let distance_computer = DistanceComputer::new(Arc::clone(&codebook));
        let encoder = Encoder::new(codebook);
        let db_vector = vec![1.0, 1.0, 3.0, 3.0];
        let quantized = encoder.encode(&db_vector).unwrap();
        // Right number of rows, but no columns: any code is out of range.
        let empty_rows: Vec<Vec<f32>> = vec![Vec::new(); 2];
        assert!(matches!(
            distance_computer.compute_distance_with_table(&empty_rows, &quantized),
            Err(PqError::InvalidCentroidIndex(_))
        ));
    }

    #[test]
    fn test_l2_distance_squared() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![4.0, 5.0, 6.0];
        let dist_sq = l2_distance_squared(&a, &b);
        assert_eq!(dist_sq, 27.0);
    }
}