use crate::error::Error;
use serde::{Deserialize, Serialize};
use std::borrow::Cow;
use super::pq_kmeans::{kmeans_train, l2_squared, nearest_centroid};
/// Per-subspace k-means codebooks used by a [`ProductQuantizer`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PQCodebook {
    /// Centroids indexed as `centroids[subspace][code]`. Each centroid is a sub-vector of its
    /// subspace (presumably `subspace_dim` long, since k-means is trained on sub-vectors of
    /// that width — confirm against `kmeans_train`).
    pub centroids: Vec<Vec<Vec<f32>>>,
    /// Length of the full (unsplit) vectors this codebook encodes.
    pub dimension: usize,
    /// Number of contiguous sub-vectors each vector is split into.
    pub num_subspaces: usize,
    /// Number of centroids requested per subspace; also the per-subspace stride of the
    /// query lookup table.
    pub num_centroids: usize,
    /// Width of each sub-vector (`dimension / num_subspaces` when built by `train`).
    pub subspace_dim: usize,
}
/// A vector encoded by a [`ProductQuantizer`]: one centroid index per subspace.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PQVector {
    /// `codes[s]` indexes into `centroids[s]` of the codebook. Stored as `u16`, so an index
    /// can be at most `u16::MAX`.
    pub codes: Vec<u16>,
}
/// Product quantizer: a trained [`PQCodebook`] plus an optional pre-rotation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProductQuantizer {
    /// Trained per-subspace codebooks.
    pub codebook: PQCodebook,
    /// Optional rotation applied to inputs before encoding and before building query lookup
    /// tables. It is read as a row-major `d x d` matrix, where `d` is the input length.
    /// `train` leaves this as `None`.
    pub rotation: Option<Vec<f32>>,
}
/// Validates training inputs and derives the vector and subspace dimensions.
///
/// Returns `(dimension, subspace_dim)`. `dimension` is the length of the first training
/// vector and `subspace_dim = dimension / num_subspaces`.
///
/// # Errors
///
/// Propagates the first `Error::InvalidQuantizerConfig` raised by the basic-parameter or
/// dimension checks.
pub(super) fn validate_train_params(
    vectors: &[Vec<f32>],
    num_subspaces: usize,
    num_centroids: usize,
) -> Result<(usize, usize), Error> {
    validate_basic_params(vectors, num_subspaces, num_centroids)?;
    // `validate_basic_params` rejected an empty dataset, so the first vector exists and
    // defines the dimension every other vector must match.
    let dimension = vectors[0].len();
    validate_dimension(vectors, dimension, num_subspaces, num_centroids)
        .map(|()| (dimension, dimension / num_subspaces))
}
fn validate_basic_params(
vectors: &[Vec<f32>],
num_subspaces: usize,
num_centroids: usize,
) -> Result<(), Error> {
if vectors.is_empty() {
return Err(Error::InvalidQuantizerConfig(
"cannot train PQ with empty dataset".into(),
));
}
if num_subspaces == 0 {
return Err(Error::InvalidQuantizerConfig(
"num_subspaces must be > 0".into(),
));
}
if num_centroids == 0 {
return Err(Error::InvalidQuantizerConfig(
"num_centroids must be > 0".into(),
));
}
if u16::try_from(num_centroids).is_err() {
return Err(Error::InvalidQuantizerConfig(
"num_centroids must fit in u16 (max 65535)".into(),
));
}
Ok(())
}
fn validate_dimension(
vectors: &[Vec<f32>],
dimension: usize,
num_subspaces: usize,
num_centroids: usize,
) -> Result<(), Error> {
if dimension == 0 {
return Err(Error::InvalidQuantizerConfig(
"vectors must have non-zero dimension".into(),
));
}
if !vectors.iter().all(|v| v.len() == dimension) {
return Err(Error::InvalidQuantizerConfig(
"all vectors must share the same dimension".into(),
));
}
if !dimension.is_multiple_of(num_subspaces) {
return Err(Error::InvalidQuantizerConfig(
"dimension must be divisible by num_subspaces".into(),
));
}
if num_centroids > vectors.len() {
return Err(Error::InvalidQuantizerConfig(format!(
"num_centroids ({num_centroids}) exceeds number of training vectors ({})",
vectors.len()
)));
}
Ok(())
}
impl ProductQuantizer {
pub fn train(
vectors: &[Vec<f32>],
num_subspaces: usize,
num_centroids: usize,
) -> Result<Self, Error> {
let (dimension, subspace_dim) =
validate_train_params(vectors, num_subspaces, num_centroids)?;
let centroids =
train_subspace_centroids(vectors, num_subspaces, subspace_dim, num_centroids);
#[cfg(debug_assertions)]
check_degenerate_centroids(¢roids);
let lut_size = num_subspaces * num_centroids * 4;
if lut_size > 8192 {
tracing::warn!("PQ LUT size {lut_size} bytes exceeds L1-friendly 8KB threshold");
}
Ok(Self {
codebook: PQCodebook {
centroids,
dimension,
num_subspaces,
num_centroids,
subspace_dim,
},
rotation: None,
})
}
pub fn quantize(&self, vector: &[f32]) -> Result<PQVector, Error> {
if vector.len() != self.codebook.dimension {
return Err(Error::InvalidQuantizerConfig(format!(
"vector dimension mismatch: expected {}, got {}",
self.codebook.dimension,
vector.len()
)));
}
let rotated = self.apply_rotation(vector);
let effective: &[f32] = &rotated;
let mut codes = Vec::with_capacity(self.codebook.num_subspaces);
for subspace in 0..self.codebook.num_subspaces {
let start = subspace * self.codebook.subspace_dim;
let end = start + self.codebook.subspace_dim;
let code = nearest_centroid(&effective[start..end], &self.codebook.centroids[subspace]);
#[allow(clippy::cast_possible_truncation)]
codes.push(code as u16);
}
Ok(PQVector { codes })
}
pub fn reconstruct(&self, pq_vector: &PQVector) -> Result<Vec<f32>, Error> {
if pq_vector.codes.len() != self.codebook.num_subspaces {
return Err(Error::InvalidQuantizerConfig(format!(
"code count mismatch: expected {}, got {}",
self.codebook.num_subspaces,
pq_vector.codes.len()
)));
}
let mut reconstructed = Vec::with_capacity(self.codebook.dimension);
for (subspace, &code) in pq_vector.codes.iter().enumerate() {
let code_idx = usize::from(code);
if code_idx >= self.codebook.centroids[subspace].len() {
return Err(Error::InvalidQuantizerConfig(format!(
"code index {code_idx} out of range for subspace {subspace} \
(max {})",
self.codebook.centroids[subspace].len() - 1
)));
}
let centroid = &self.codebook.centroids[subspace][code_idx];
reconstructed.extend_from_slice(centroid);
}
Ok(reconstructed)
}
}
/// Runs k-means on one subspace's slice of every training vector.
///
/// Each vector is cut to `[subspace * subspace_dim, subspace * subspace_dim + subspace_dim)`.
/// `kmeans_train` then runs on those slices with `num_centroids` clusters and a fixed seed
/// derived from the subspace index (`42 + subspace`).
fn train_single_subspace(
    vectors: &[Vec<f32>],
    subspace: usize,
    subspace_dim: usize,
    num_centroids: usize,
    #[cfg(feature = "gpu")] gpu_ctx: Option<&crate::gpu::PqGpuContext>,
) -> Vec<Vec<f32>> {
    let lo = subspace * subspace_dim;
    let hi = lo + subspace_dim;
    let projected: Vec<Vec<f32>> = vectors
        .iter()
        .map(|vector| vector[lo..hi].to_vec())
        .collect();
    // Presumably the k-means iteration budget; confirm against `kmeans_train`.
    let iterations = 50;
    #[allow(clippy::cast_possible_truncation)]
    let seed = 42u64.wrapping_add(subspace as u64);
    kmeans_train(
        &projected,
        num_centroids,
        iterations,
        seed,
        #[cfg(feature = "gpu")]
        gpu_ctx,
    )
}
/// Trains one codebook per subspace and returns them in subspace order.
///
/// With the `persistence` feature, subspaces are trained in parallel through rayon;
/// otherwise they are trained sequentially. With the `gpu` feature, one context from
/// `PqGpuContext::new()` is shared by every subspace. That call appears to return an
/// `Option`; confirm.
fn train_subspace_centroids(
    vectors: &[Vec<f32>],
    num_subspaces: usize,
    subspace_dim: usize,
    num_centroids: usize,
) -> Vec<Vec<Vec<f32>>> {
    #[cfg(feature = "gpu")]
    let gpu_ctx = crate::gpu::PqGpuContext::new();

    // Shared by both the parallel and the sequential driver below.
    let train_one = |subspace: usize| {
        train_single_subspace(
            vectors,
            subspace,
            subspace_dim,
            num_centroids,
            #[cfg(feature = "gpu")]
            gpu_ctx.as_ref(),
        )
    };

    #[cfg(feature = "persistence")]
    {
        use rayon::prelude::*;
        (0..num_subspaces).into_par_iter().map(train_one).collect()
    }
    #[cfg(not(feature = "persistence"))]
    {
        (0..num_subspaces).map(train_one).collect()
    }
}
/// Debug-build diagnostic for near-duplicate centroids.
///
/// Compares every pair of centroids within the same subspace. It logs a warning for each pair
/// whose `l2_squared` distance is below `1e-6`. Costs O(k^2) comparisons per subspace.
#[cfg(debug_assertions)]
fn check_degenerate_centroids(centroids: &[Vec<Vec<f32>>]) {
    for (subspace, sub_centroids) in centroids.iter().enumerate() {
        for (i, first) in sub_centroids.iter().enumerate() {
            // Only visit each unordered pair once.
            for (j, second) in sub_centroids.iter().enumerate().skip(i + 1) {
                let dist = l2_squared(first, second);
                if dist < 1e-6 {
                    tracing::warn!(
                        "degenerate centroids detected in subspace {subspace}: \
                         centroids {i} and {j} distance {dist}"
                    );
                }
            }
        }
    }
}
impl ProductQuantizer {
    /// Builds the asymmetric-distance lookup table for `query`.
    ///
    /// The query is passed through [`Self::apply_rotation`] first. The table is laid out
    /// subspace-major. Entry `subspace * num_centroids + code` holds the `l2_squared` distance
    /// between the query's sub-vector for `subspace` and centroid `code` of that subspace.
    ///
    /// # Panics
    ///
    /// Panics if `query` is shorter than `num_subspaces * subspace_dim`. It also panics if a
    /// configured rotation is not `query.len() x query.len()`. Debug builds additionally
    /// assert `query.len() == dimension`.
    #[must_use]
    pub fn precompute_lut(&self, query: &[f32]) -> Vec<f32> {
        debug_assert_eq!(
            query.len(),
            self.codebook.dimension,
            "query dimension does not match the codebook dimension"
        );
        let query = self.apply_rotation(query);
        let m = self.codebook.num_subspaces;
        let k = self.codebook.num_centroids;
        let sd = self.codebook.subspace_dim;
        let mut lut = Vec::with_capacity(m * k);
        for subspace in 0..m {
            let q_sub = &query[subspace * sd..(subspace + 1) * sd];
            for centroid in &self.codebook.centroids[subspace] {
                lut.push(l2_squared(q_sub, centroid));
            }
        }
        lut
    }

    /// Applies the optional rotation to `vector`.
    ///
    /// Without a rotation the input is returned borrowed. Otherwise the rotation is read as a
    /// row-major `d x d` matrix with `d = vector.len()`, and the matrix-vector product is
    /// returned as a new vector.
    ///
    /// # Panics
    ///
    /// Panics if a rotation is configured and its length is not `d * d`. Previously an
    /// oversized matrix silently produced wrong values.
    pub(crate) fn apply_rotation<'a>(&self, vector: &'a [f32]) -> Cow<'a, [f32]> {
        let Some(matrix) = &self.rotation else {
            return Cow::Borrowed(vector);
        };
        let d = vector.len();
        assert_eq!(
            matrix.len(),
            d * d,
            "rotation matrix has {} elements, expected {d}x{d}",
            matrix.len()
        );
        if d == 0 {
            // `chunks_exact(0)` panics; an empty vector rotates to an empty vector.
            return Cow::Owned(Vec::new());
        }
        // Row `i` of the row-major matrix produces output element `i`. The left-to-right fold
        // keeps the same accumulation order as an indexed `+=` loop.
        let rotated: Vec<f32> = matrix
            .chunks_exact(d)
            .map(|row| {
                row.iter()
                    .zip(vector)
                    .fold(0.0_f32, |acc, (&m, &v)| acc + m * v)
            })
            .collect();
        Cow::Owned(rotated)
    }
}
/// Approximate L2 distance between a raw query and a PQ-encoded vector.
///
/// Builds a fresh lookup table for the query and delegates to [`distance_pq_l2_with_lut`].
/// When scoring many codes against one query, build the table once and call that directly.
#[must_use]
#[cfg_attr(not(feature = "persistence"), allow(dead_code))]
pub(crate) fn distance_pq_l2(
    query_vector: &[f32],
    pq_vector: &PQVector,
    quantizer: &ProductQuantizer,
) -> f32 {
    let codebook = &quantizer.codebook;
    debug_assert_eq!(query_vector.len(), codebook.dimension);
    debug_assert_eq!(pq_vector.codes.len(), codebook.num_subspaces);
    let table = quantizer.precompute_lut(query_vector);
    distance_pq_l2_with_lut(pq_vector, &table, codebook.num_centroids)
}
/// Square root of the sum of the lookup-table entries selected by `pq_vector`'s codes.
///
/// `lut` is indexed as `subspace * num_centroids + code`, which is the layout built by
/// `ProductQuantizer::precompute_lut`.
///
/// # Panics
///
/// Panics if a selected index falls outside `lut`.
#[must_use]
#[cfg_attr(not(feature = "persistence"), allow(dead_code))]
pub(crate) fn distance_pq_l2_with_lut(
    pq_vector: &PQVector,
    lut: &[f32],
    num_centroids: usize,
) -> f32 {
    let lut_index = |subspace: usize, code: u16| subspace * num_centroids + usize::from(code);
    let squared_total: f32 = pq_vector
        .codes
        .iter()
        .zip(0_usize..)
        .map(|(&code, subspace)| lut[lut_index(subspace, code)])
        .sum();
    squared_total.sqrt()
}
#[cfg_attr(not(feature = "persistence"), allow(dead_code))]
const ADC_SIMD_BATCH_THRESHOLD: usize = 8;
#[cfg_attr(not(feature = "persistence"), allow(dead_code))]
pub(crate) fn pq_adc_batch_rescore(
quantizer: &ProductQuantizer,
query: &[f32],
pq_vectors: &[&PQVector],
) -> crate::error::Result<Vec<f32>> {
if pq_vectors.is_empty() {
return Ok(Vec::new());
}
let m = quantizer.codebook.num_subspaces;
if pq_vectors.len() < ADC_SIMD_BATCH_THRESHOLD {
let lut = quantizer.precompute_lut(query);
let k = quantizer.codebook.num_centroids;
return Ok(pq_vectors
.iter()
.map(|pq_vec| distance_pq_l2_with_lut(pq_vec, &lut, k))
.collect());
}
let lut = quantizer.precompute_lut(query);
let code_slices: Vec<&[u16]> = pq_vectors
.iter()
.map(|pq_vec| pq_vec.codes.as_slice())
.collect();
let squared_dists = crate::simd_native::adc::adc_distances_batch(&lut, &code_slices, m)?;
Ok(squared_dists.iter().map(|&d| d.sqrt()).collect())
}
// Unit tests for this module are kept in `pq_tests.rs`, resolved relative to this file's
// directory via the `#[path]` attribute.
#[cfg(test)]
#[path = "pq_tests.rs"]
mod tests;