use error_forge::ForgeError;
use iqdb_distance::compute_batch;
use iqdb_types::{DistanceMetric, IqdbError, Result};
use crate::code::PqCode;
use crate::train::{assign_to_cluster, squared_l2, train_codebook};
use crate::traits::Quantizer;
use crate::validate::{dim_eq, finite_non_empty, training_set};
/// Number of subvectors used by `ProductQuantizer::new`.
const DEFAULT_N_SUBVECTORS: usize = 8;
/// Number of centroids per codebook used by `ProductQuantizer::new`.
const DEFAULT_N_CENTROIDS: usize = 256;
/// Largest `n_centroids` accepted at training time; each code is stored in one `u8`.
const MAX_N_CENTROIDS: usize = 256;
/// Seed used by `ProductQuantizer::new`.
const DEFAULT_SEED: u64 = 0;
/// Trained state of a `ProductQuantizer`; stored at the end of a successful `train` call.
#[derive(Debug, Clone, PartialEq)]
struct PqCalibration {
/// Dimensionality of the vectors the quantizer was trained on.
dim: usize,
/// Number of subvectors each vector is split into.
n_subvectors: usize,
/// Length of each subvector (`dim / n_subvectors`).
sub_dim: usize,
/// Number of centroids requested for each codebook.
n_centroids: usize,
/// One codebook per subvector, indexed as `codebooks[subvector][centroid]`.
/// NOTE(review): each centroid presumably has `sub_dim` components; that is
/// up to `train_codebook`, which is defined elsewhere.
codebooks: Vec<Vec<Vec<f32>>>,
}
/// Product quantizer: splits a vector into `n_subvectors` contiguous slices and
/// encodes each slice as the index of a centroid in that slice's codebook.
///
/// The configuration is fixed at construction; the codebooks only exist after
/// `train` has succeeded.
#[derive(Debug, Clone, PartialEq)]
pub struct ProductQuantizer {
/// Number of subvectors (and therefore codes) per vector.
n_subvectors: usize,
/// Number of centroids requested per codebook.
n_centroids: usize,
/// Base seed; training offsets it by the subvector index for each codebook.
seed: u64,
/// `None` until `train` succeeds, then the trained codebooks and their shape.
calibration: Option<PqCalibration>,
}
impl Default for ProductQuantizer {
fn default() -> Self {
Self::new()
}
}
impl ProductQuantizer {
    /// Creates an untrained quantizer with the default subvector count,
    /// centroid count and seed.
    #[must_use]
    pub fn new() -> Self {
        Self::with_config(DEFAULT_N_SUBVECTORS, DEFAULT_N_CENTROIDS, DEFAULT_SEED)
    }

    /// Creates an untrained quantizer with an explicit configuration.
    ///
    /// Nothing is validated here; the configuration is checked when training
    /// starts.
    #[must_use]
    pub fn with_config(n_subvectors: usize, n_centroids: usize, seed: u64) -> Self {
        let calibration = None;
        Self {
            n_subvectors,
            n_centroids,
            seed,
            calibration,
        }
    }

    /// Dimensionality learned during training, or `None` if untrained.
    #[must_use]
    pub fn dim(&self) -> Option<usize> {
        let trained = self.calibration.as_ref()?;
        Some(trained.dim)
    }

    /// Configured number of subvectors.
    #[must_use]
    pub fn n_subvectors(&self) -> usize {
        self.n_subvectors
    }

    /// Configured number of centroids per codebook.
    #[must_use]
    pub fn n_centroids(&self) -> usize {
        self.n_centroids
    }

    /// Configured base seed.
    #[must_use]
    pub fn seed(&self) -> u64 {
        self.seed
    }

    /// Borrows the trained state.
    ///
    /// # Errors
    /// Returns `IqdbError::InvalidConfig` when the quantizer has not been trained.
    fn calibration(&self) -> Result<&PqCalibration> {
        match &self.calibration {
            Some(trained) => Ok(trained),
            None => Err(IqdbError::InvalidConfig {
                reason: "ProductQuantizer has not been trained",
            }),
        }
    }

    /// Checks the configuration against the training data and returns the
    /// per-subvector length (`dim / n_subvectors`).
    ///
    /// # Errors
    /// Returns `IqdbError::InvalidConfig` for the first rule that fails, in
    /// this order: zero subvectors, zero centroids, too many centroids, `dim`
    /// not a positive multiple of `n_subvectors`, fewer training vectors than
    /// centroids.
    fn validate_shape(&self, dim: usize, training_count: usize) -> Result<usize> {
        let subvectors = self.n_subvectors;
        let centroids = self.n_centroids;

        // The branch order matters: the divisibility test is only reached
        // once `subvectors` is known to be non-zero.
        let violation = if subvectors == 0 {
            Some("ProductQuantizer requires n_subvectors >= 1")
        } else if centroids == 0 {
            Some("ProductQuantizer requires n_centroids >= 1")
        } else if centroids > MAX_N_CENTROIDS {
            Some("ProductQuantizer requires n_centroids <= 256 (one byte per code)")
        } else if dim == 0 || !dim.is_multiple_of(subvectors) {
            Some("ProductQuantizer requires training dim to be a positive multiple of n_subvectors")
        } else if training_count < centroids {
            Some("ProductQuantizer requires training_set.len() >= n_centroids")
        } else {
            None
        };

        match violation {
            Some(reason) => Err(IqdbError::InvalidConfig { reason }),
            None => Ok(dim / subvectors),
        }
    }
}
impl Quantizer for ProductQuantizer {
    type Quantized = PqCode;

    /// Trains one codebook per subvector and stores the result as the
    /// quantizer's calibration.
    ///
    /// The calibration is only replaced once every codebook has trained
    /// successfully; on error any previous calibration is left untouched.
    ///
    /// # Errors
    /// Propagates errors from `training_set`, from the shape validation, and
    /// from `train_codebook`. Each failure is also logged at error level.
    #[tracing::instrument(
        level = "info",
        skip_all,
        fields(
            quantizer = "pq",
            training_size = vectors.len(),
            n_subvectors = self.n_subvectors,
            n_centroids = self.n_centroids,
        ),
    )]
    fn train(&mut self, vectors: &[&[f32]]) -> Result<()> {
        let dim = training_set(vectors).inspect_err(log_training_failure)?;
        let sub_dim = self
            .validate_shape(dim, vectors.len())
            .inspect_err(log_training_failure)?;

        let mut codebooks: Vec<Vec<Vec<f32>>> = Vec::with_capacity(self.n_subvectors);
        for m in 0..self.n_subvectors {
            let start = m * sub_dim;
            let end = start + sub_dim;
            let slices: Vec<&[f32]> = vectors.iter().map(|v| &v[start..end]).collect();
            // Each codebook gets its own seed, derived from the base seed.
            let centroids = train_codebook(
                sub_dim,
                self.n_centroids,
                self.seed.wrapping_add(m as u64),
                &slices,
            )
            .inspect_err(|err: &IqdbError| {
                tracing::error!(
                    error.kind = err.kind(),
                    error.reason = err.caption(),
                    subvector = m,
                    "product quantizer codebook training failed",
                );
            })?;
            codebooks.push(centroids);
        }

        self.calibration = Some(PqCalibration {
            dim,
            n_subvectors: self.n_subvectors,
            sub_dim,
            n_centroids: self.n_centroids,
            codebooks,
        });
        Ok(())
    }

    /// Encodes `vector` as one centroid index per subvector.
    ///
    /// # Errors
    /// Fails when the quantizer is untrained, when `finite_non_empty` rejects
    /// the vector, or when its length differs from the trained dimension.
    fn quantize(&self, vector: &[f32]) -> Result<Self::Quantized> {
        let cal = self.calibration()?;
        finite_non_empty(vector)?;
        dim_eq(cal.dim, vector.len())?;

        let mut codes: Vec<u8> = Vec::with_capacity(cal.n_subvectors);
        for m in 0..cal.n_subvectors {
            let start = m * cal.sub_dim;
            let end = start + cal.sub_dim;
            let idx = assign_to_cluster(&cal.codebooks[m], &vector[start..end]);
            // Training rejects n_centroids > 256, so a codebook index is
            // expected to fit in one byte.
            codes.push(idx as u8);
        }

        Ok(PqCode {
            codes,
            dim: cal.dim,
            n_subvectors: cal.n_subvectors,
        })
    }

    /// Reconstructs an approximate vector by concatenating the centroid
    /// selected by each code.
    ///
    /// # Errors
    /// - Fails when the quantizer is untrained or `quantized.dim` differs
    ///   from the trained dimension.
    /// - `IqdbError::DimensionMismatch` when `quantized.n_subvectors` or the
    ///   number of stored codes differs from the trained subvector count.
    /// - `IqdbError::InvalidConfig` when a code refers to a centroid that the
    ///   matching codebook does not contain.
    fn dequantize(&self, quantized: &Self::Quantized) -> Result<Vec<f32>> {
        let cal = self.calibration()?;
        dim_eq(cal.dim, quantized.dim)?;
        if quantized.n_subvectors != cal.n_subvectors {
            return Err(IqdbError::DimensionMismatch {
                expected: cal.n_subvectors,
                found: quantized.n_subvectors,
            });
        }
        // The declared subvector count and the stored codes can disagree in a
        // malformed code; reject it rather than index out of bounds or return
        // a vector shorter than `dim`.
        if quantized.codes.len() != cal.n_subvectors {
            return Err(IqdbError::DimensionMismatch {
                expected: cal.n_subvectors,
                found: quantized.codes.len(),
            });
        }

        let mut out: Vec<f32> = Vec::with_capacity(cal.dim);
        for (codebook, &code) in cal.codebooks.iter().zip(&quantized.codes) {
            // A code produced by a quantizer with more centroids can exceed
            // this codebook; report it instead of panicking.
            let centroid = codebook
                .get(usize::from(code))
                .ok_or(IqdbError::InvalidConfig {
                    reason: "PqCode references a centroid outside the trained codebook",
                })?;
            out.extend_from_slice(centroid);
        }
        Ok(out)
    }

    /// Distance between `query` and an encoded vector.
    ///
    /// Builds the per-query lookup tables on every call; to score many codes
    /// against one query, call `build_query_tables` once and reuse the result.
    ///
    /// # Errors
    /// Propagates errors from `build_query_tables` and `PqAdcTables::distance`.
    fn distance(
        &self,
        query: &[f32],
        quantized: &Self::Quantized,
        metric: DistanceMetric,
    ) -> Result<f32> {
        let tables = self.build_query_tables(query, metric)?;
        tables.distance(quantized)
    }
}

/// Logs a failure that aborted training before any codebook was built.
fn log_training_failure(err: &IqdbError) {
    tracing::error!(
        error.kind = err.kind(),
        error.reason = err.caption(),
        "product quantizer training failed",
    );
}
impl ProductQuantizer {
    /// Precomputes the lookup table for `query` under `metric`, so that many
    /// codes can be scored without touching the codebooks again.
    ///
    /// # Errors
    /// - Fails when the quantizer is untrained, when `finite_non_empty`
    ///   rejects the query, or when its length differs from the trained
    ///   dimension.
    /// - `IqdbError::InvalidMetric` for any metric other than Euclidean,
    ///   DotProduct or Manhattan.
    /// - Propagates errors from building the table rows.
    pub fn build_query_tables(&self, query: &[f32], metric: DistanceMetric) -> Result<PqAdcTables> {
        let cal = self.calibration()?;
        finite_non_empty(query)?;
        dim_eq(cal.dim, query.len())?;

        let supported = matches!(
            metric,
            DistanceMetric::Euclidean | DistanceMetric::DotProduct | DistanceMetric::Manhattan
        );
        if !supported {
            return Err(IqdbError::InvalidMetric);
        }

        build_adc_table_rows(query, metric, cal).map(|table| PqAdcTables {
            table,
            metric,
            n_subvectors: cal.n_subvectors,
            n_centroids: cal.n_centroids,
            dim: cal.dim,
        })
    }
}
/// Per-query lookup table produced by `ProductQuantizer::build_query_tables`.
///
/// Holds one row per subvector and one entry per centroid, so scoring a code
/// is a table lookup per subvector followed by a sum.
#[derive(Debug, Clone)]
pub struct PqAdcTables {
/// Flattened rows: the entry for subvector `m`, centroid `k` lives at
/// `m * n_centroids + k`.
table: Vec<f32>,
/// Metric the table was built for.
metric: DistanceMetric,
/// Number of rows (subvectors).
n_subvectors: usize,
/// Row length (centroids per codebook).
n_centroids: usize,
/// Dimensionality of the query the table was built from.
dim: usize,
}
impl PqAdcTables {
    /// Scores an encoded vector against the query this table was built from.
    ///
    /// Sums one table entry per subvector. For the Euclidean metric the sum
    /// is passed through `sqrt`; other metrics return the raw sum.
    ///
    /// # Errors
    /// - `IqdbError::DimensionMismatch` when the code's `n_subvectors`, `dim`
    ///   or number of stored codes does not match this table.
    /// - `IqdbError::InvalidConfig` when a code byte is not a valid centroid
    ///   index for this table.
    pub fn distance(&self, code: &PqCode) -> Result<f32> {
        if code.n_subvectors != self.n_subvectors {
            return Err(IqdbError::DimensionMismatch {
                expected: self.n_subvectors,
                found: code.n_subvectors,
            });
        }
        if code.dim != self.dim {
            return Err(IqdbError::DimensionMismatch {
                expected: self.dim,
                found: code.dim,
            });
        }
        // The declared subvector count and the stored codes can disagree in a
        // malformed code; extra codes would index past the end of the table.
        if code.codes.len() != self.n_subvectors {
            return Err(IqdbError::DimensionMismatch {
                expected: self.n_subvectors,
                found: code.codes.len(),
            });
        }
        // A code byte >= n_centroids would read from the next subvector's row
        // (silently wrong) or past the table. With 256 centroids every byte
        // is valid, so the scan is skipped in that case.
        if self.n_centroids <= usize::from(u8::MAX)
            && code
                .codes
                .iter()
                .any(|&c| usize::from(c) >= self.n_centroids)
        {
            return Err(IqdbError::InvalidConfig {
                reason: "PqCode references a centroid outside the trained codebook",
            });
        }

        let total = score_code_rows(&self.table, code, self.n_centroids);
        Ok(if self.metric == DistanceMetric::Euclidean {
            total.sqrt()
        } else {
            total
        })
    }

    /// Metric this table was built for.
    #[must_use]
    pub fn metric(&self) -> DistanceMetric {
        self.metric
    }

    /// Number of subvectors (table rows).
    #[must_use]
    pub fn n_subvectors(&self) -> usize {
        self.n_subvectors
    }

    /// Number of centroids per subvector (row length).
    #[must_use]
    pub fn n_centroids(&self) -> usize {
        self.n_centroids
    }

    /// Dimensionality of the query the table was built from.
    #[must_use]
    pub fn dim(&self) -> usize {
        self.dim
    }
}
/// Fills the flattened lookup table for `query`: one row per subvector, one
/// entry per centroid, laid out as `table[m * n_centroids + k]`.
///
/// Euclidean rows are filled with `squared_l2` between the query slice and
/// each centroid; DotProduct and Manhattan rows are filled by `compute_batch`.
///
/// # Errors
/// - `IqdbError::InvalidMetric` for any other metric (including Cosine and
///   Hamming).
/// - Propagates errors from `compute_batch`.
fn build_adc_table_rows(
    query: &[f32],
    metric: DistanceMetric,
    cal: &PqCalibration,
) -> Result<Vec<f32>> {
    let total_entries = cal.n_subvectors * cal.n_centroids;
    let mut table: Vec<f32> = vec![0.0; total_entries];
    // Scratch buffer of centroid slices, reused across subvectors.
    let mut centroid_refs: Vec<&[f32]> = Vec::with_capacity(cal.n_centroids);

    for m in 0..cal.n_subvectors {
        let start = m * cal.sub_dim;
        let end = start + cal.sub_dim;
        let q_sub = &query[start..end];

        let row_start = m * cal.n_centroids;
        let row_end = row_start + cal.n_centroids;
        let row = &mut table[row_start..row_end];

        match metric {
            DistanceMetric::Euclidean => {
                for (k, centroid) in cal.codebooks[m].iter().enumerate() {
                    row[k] = squared_l2(q_sub, centroid);
                }
            }
            DistanceMetric::DotProduct | DistanceMetric::Manhattan => {
                centroid_refs.clear();
                centroid_refs.extend(cal.codebooks[m].iter().map(Vec::as_slice));
                compute_batch(metric, q_sub, &centroid_refs, row)?;
            }
            // Cosine, Hamming and any other metric are not supported here.
            _ => return Err(IqdbError::InvalidMetric),
        }
    }

    Ok(table)
}
/// Sums the table entry selected by each code byte, reading the entry for
/// subvector `m` and code `c` at `table[m * n_centroids + c]`.
///
/// Indexing is unchecked beyond the slice's own bounds check, so this panics
/// if a computed offset falls outside `table`.
fn score_code_rows(table: &[f32], code: &PqCode, n_centroids: usize) -> f32 {
    code.codes
        .iter()
        .enumerate()
        // Explicit fold from 0.0 keeps the left-to-right accumulation order.
        .fold(0.0_f32, |acc, (subvector, &byte)| {
            let offset = subvector * n_centroids + usize::from(byte);
            acc + table[offset]
        })
}