#![allow(unused_variables)]
use super::{Codebook, Encoder, Decoder, DistanceComputer, QuantizedVector, PqError, PqResult};
use crate::vector::Vector;
use serde::{Serialize, Deserialize};
use std::sync::Arc;
/// Tunable parameters for a product quantizer.
///
/// All fields are public, so a value can be built directly with a struct
/// literal; call `validate` before relying on the invariants described on
/// the individual fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProductQuantizerConfig {
/// Number of sub-quantizers the vector is split across. `validate`
/// requires this to be non-zero and to divide `dimension` evenly.
pub num_subquantizers: usize,
/// Number of centroids per sub-quantizer. `validate` requires a value in
/// `1..=256`.
pub num_centroids: usize,
/// Number of components in each full input vector.
pub dimension: usize,
/// Iteration count handed to codebook training; `validate` requires it to
/// be greater than zero.
pub training_iterations: usize,
/// Suggested minimum number of training vectors.
/// NOTE(review): nothing in this file treats this as a hard limit; confirm
/// whether the training module enforces it.
pub min_training_samples: usize,
}
impl ProductQuantizerConfig {
    /// Builds a configuration with the default training parameters for
    /// vectors of `dimension` components.
    ///
    /// The sub-quantizer count is derived from the dimension: 8 when
    /// `dimension >= 512`, 4 when `dimension >= 256`, and 2 otherwise. The
    /// remaining fields are fixed at 256 centroids, 25 training iterations
    /// and 10000 minimum training samples.
    ///
    /// # Errors
    ///
    /// Returns [`PqError::InvalidConfig`] when `dimension` is not a multiple
    /// of the selected sub-quantizer count.
    pub fn default_for_dimension(dimension: usize) -> PqResult<Self> {
        let num_subquantizers = if dimension >= 512 {
            8
        } else if dimension >= 256 {
            4
        } else {
            2
        };

        if dimension % num_subquantizers != 0 {
            return Err(PqError::InvalidConfig(format!(
                "Dimension {} must be divisible by num_subquantizers {}",
                dimension, num_subquantizers
            )));
        }

        Ok(Self {
            num_subquantizers,
            num_centroids: 256,
            dimension,
            training_iterations: 25,
            min_training_samples: 10000,
        })
    }

    /// Checks that the configuration is internally consistent.
    ///
    /// # Errors
    ///
    /// Returns [`PqError::InvalidConfig`] for the first violated rule, in
    /// this order:
    ///
    /// 1. `num_subquantizers` is zero;
    /// 2. `dimension` is not divisible by `num_subquantizers`;
    /// 3. `num_centroids` is outside `1..=256`;
    /// 4. `training_iterations` is zero.
    pub fn validate(&self) -> PqResult<()> {
        // This guard has to run before the divisibility test below: `%` with
        // a zero right-hand side panics instead of producing an error.
        if self.num_subquantizers == 0 {
            return Err(PqError::InvalidConfig(
                "num_subquantizers must be > 0".to_string(),
            ));
        }

        if self.dimension % self.num_subquantizers != 0 {
            return Err(PqError::InvalidConfig(format!(
                "Dimension {} must be divisible by num_subquantizers {}",
                self.dimension, self.num_subquantizers
            )));
        }

        if self.num_centroids == 0 || self.num_centroids > 256 {
            return Err(PqError::InvalidConfig(format!(
                "num_centroids must be between 1 and 256, got {}",
                self.num_centroids
            )));
        }

        if self.training_iterations == 0 {
            return Err(PqError::InvalidConfig(
                "training_iterations must be > 0".to_string(),
            ));
        }

        Ok(())
    }

    /// Returns the number of components handled by each sub-quantizer,
    /// i.e. `dimension / num_subquantizers`.
    ///
    /// # Panics
    ///
    /// Panics if `num_subquantizers` is zero. A configuration accepted by
    /// [`validate`](Self::validate) never triggers this.
    pub fn subvector_dimension(&self) -> usize {
        self.dimension / self.num_subquantizers
    }

    /// Builds a small, fast-to-train configuration for unit tests
    /// (32 centroids, 5 training iterations, 100 minimum samples).
    ///
    /// The preferred sub-quantizer count is 8 for `dimension >= 512`, 4 for
    /// `dimension >= 32`, and 2 otherwise. When the preferred count does not
    /// divide `dimension`, the first of 8, 4, 2, 1 that does is used instead.
    ///
    /// # Errors
    ///
    /// Returns [`PqError::InvalidConfig`] if no candidate divides
    /// `dimension`. Because 1 is among the candidates this cannot currently
    /// happen; the error path is kept as a safety net.
    #[cfg(test)]
    pub fn test_config(dimension: usize) -> PqResult<Self> {
        let preferred: usize = if dimension >= 512 {
            8
        } else if dimension >= 32 {
            4
        } else {
            2
        };

        let num_subquantizers = if dimension % preferred == 0 {
            preferred
        } else {
            [8usize, 4, 2, 1]
                .iter()
                .copied()
                .find(|&candidate| dimension % candidate == 0)
                .ok_or_else(|| {
                    PqError::InvalidConfig(format!(
                        "Dimension {} cannot be evenly divided into sub-quantizers",
                        dimension
                    ))
                })?
        };

        Ok(Self {
            num_subquantizers,
            num_centroids: 32,
            dimension,
            training_iterations: 5,
            min_training_samples: 100,
        })
    }
}
impl Default for ProductQuantizerConfig {
fn default() -> Self {
Self {
num_subquantizers: 8,
num_centroids: 256,
dimension: 768, training_iterations: 25,
min_training_samples: 10000,
}
}
}
/// Facade bundling a validated configuration, a codebook, and the helpers
/// that encode, decode and measure distances against that codebook.
pub struct ProductQuantizer {
/// Configuration this quantizer was constructed with.
config: ProductQuantizerConfig,
/// Codebook held behind an `Arc`; clones of the same `Arc` are handed to
/// the encoder, decoder and distance computer when the quantizer is built.
codebook: Arc<Codebook>,
/// Backs `encode` / `encode_batch`.
encoder: Encoder,
/// Backs `decode`.
decoder: Decoder,
/// Backs the distance and distance-table methods.
distance_computer: DistanceComputer,
}
impl ProductQuantizer {
    /// Wraps an existing codebook in a quantizer.
    ///
    /// The codebook is moved behind an [`Arc`] and that same allocation is
    /// shared with the encoder, decoder and distance computer.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `config.validate()`.
    pub fn new(config: ProductQuantizerConfig, codebook: Codebook) -> PqResult<Self> {
        config.validate()?;

        let shared = Arc::new(codebook);
        Ok(Self {
            encoder: Encoder::new(Arc::clone(&shared)),
            decoder: Decoder::new(Arc::clone(&shared)),
            distance_computer: DistanceComputer::new(Arc::clone(&shared)),
            codebook: shared,
            config,
        })
    }

    /// Trains a codebook from `training_vectors` and builds a quantizer
    /// around it.
    ///
    /// `config.num_centroids` is the only hard lower bound on the number of
    /// training vectors; `config.min_training_samples` is not enforced here.
    ///
    /// # Errors
    ///
    /// * whatever `config.validate()` reports;
    /// * [`PqError::InsufficientTrainingData`] when fewer than
    ///   `config.num_centroids` vectors are supplied;
    /// * [`PqError::DimensionMismatch`] for the first vector whose length
    ///   differs from `config.dimension`;
    /// * any error produced by codebook training.
    pub fn train(
        config: ProductQuantizerConfig,
        training_vectors: &[Vector],
    ) -> PqResult<Self> {
        config.validate()?;

        let supplied = training_vectors.len();
        let required = config.num_centroids;
        if supplied < required {
            return Err(PqError::InsufficientTrainingData(supplied, required));
        }

        let first_mismatch = training_vectors
            .iter()
            .find(|candidate| candidate.len() != config.dimension);
        if let Some(mismatched) = first_mismatch {
            return Err(PqError::DimensionMismatch {
                expected: config.dimension,
                actual: mismatched.len(),
            });
        }

        let trained = super::training::train_codebook(&config, training_vectors)?;
        Self::new(config, trained)
    }

    /// Quantizes a single vector by delegating to the encoder.
    pub fn encode(&self, vector: &Vector) -> PqResult<QuantizedVector> {
        self.encoder.encode(vector)
    }

    /// Quantizes every vector in `vectors`, in order, stopping at the first
    /// error.
    pub fn encode_batch(&self, vectors: &[Vector]) -> PqResult<Vec<QuantizedVector>> {
        vectors
            .iter()
            .map(|vector| self.encoder.encode(vector))
            .collect()
    }

    /// Reconstructs a vector from its quantized form by delegating to the
    /// decoder.
    pub fn decode(&self, quantized: &QuantizedVector) -> PqResult<Vector> {
        self.decoder.decode(quantized)
    }

    /// Computes the distance between `query` and `quantized` via the
    /// distance computer.
    pub fn compute_distance(
        &self,
        query: &Vector,
        quantized: &QuantizedVector,
    ) -> PqResult<f32> {
        self.distance_computer.compute_distance(query, quantized)
    }

    /// Builds the per-query distance table consumed by
    /// [`compute_distance_with_table`](Self::compute_distance_with_table).
    pub fn precompute_distance_table(&self, query: &Vector) -> PqResult<Vec<Vec<f32>>> {
        self.distance_computer.precompute_distance_table(query)
    }

    /// Computes a distance for `quantized` using a table previously obtained
    /// from [`precompute_distance_table`](Self::precompute_distance_table).
    pub fn compute_distance_with_table(
        &self,
        distance_table: &[Vec<f32>],
        quantized: &QuantizedVector,
    ) -> PqResult<f32> {
        let computer = &self.distance_computer;
        computer.compute_distance_with_table(distance_table, quantized)
    }

    /// Returns the configuration this quantizer was built with.
    pub fn config(&self) -> &ProductQuantizerConfig {
        &self.config
    }

    /// Returns another handle to the shared codebook (a reference-count
    /// bump, not a copy of the codebook).
    pub fn codebook(&self) -> Arc<Codebook> {
        Arc::clone(&self.codebook)
    }

    /// Ratio between the size of an uncompressed vector (`dimension`
    /// components of `f32`) and the size of its code (one `u8` per
    /// sub-quantizer).
    pub fn compression_ratio(&self) -> f32 {
        let raw_bytes = std::mem::size_of::<f32>() * self.config.dimension;
        let code_bytes = self.memory_per_vector();
        raw_bytes as f32 / code_bytes as f32
    }

    /// Size in bytes of one quantized code: one `u8` per sub-quantizer.
    pub fn memory_per_vector(&self) -> usize {
        std::mem::size_of::<u8>() * self.config.num_subquantizers
    }

    /// Returns whatever the codebook reports as its memory size.
    pub fn codebook_size(&self) -> usize {
        self.codebook.memory_size()
    }
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    /// Builds the 8 sub-quantizer / 256 centroid configuration shared by the
    /// validation tests, varying only the vector dimension.
    fn config_with_dimension(dimension: usize) -> ProductQuantizerConfig {
        ProductQuantizerConfig {
            num_subquantizers: 8,
            num_centroids: 256,
            dimension,
            training_iterations: 25,
            min_training_samples: 1000,
        }
    }

    #[test]
    fn test_config_validation() {
        let outcome = config_with_dimension(768).validate();
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_config_invalid_dimension() {
        // 100 is not a multiple of the 8 sub-quantizers.
        let outcome = config_with_dimension(100).validate();
        assert!(outcome.is_err());
    }

    #[test]
    fn test_config_subvector_dimension() {
        // 768 / 8
        assert_eq!(config_with_dimension(768).subvector_dimension(), 96);
    }

    #[test]
    fn test_config_default_for_dimension() {
        let derived = ProductQuantizerConfig::default_for_dimension(768).unwrap();
        assert_eq!(derived.num_subquantizers, 8);
        assert_eq!(derived.dimension, 768);
        assert_eq!(derived.subvector_dimension(), 96);
    }

    #[test]
    fn test_compression_ratio() {
        let derived = ProductQuantizerConfig::default_for_dimension(768).unwrap();
        let sub_dim = derived.subvector_dimension();
        let codebook = Codebook::new(derived.num_subquantizers, derived.num_centroids, sub_dim);
        let quantizer = ProductQuantizer::new(derived, codebook).unwrap();

        // 768 f32 components (3072 bytes) compressed into 8 one-byte codes.
        assert_eq!(quantizer.compression_ratio(), 384.0);
        assert_eq!(quantizer.memory_per_vector(), 8);
    }
}