iqdb_quantize/product.rs
1//! [`ProductQuantizer`] — product quantization (PQ).
2//!
3//! PQ splits each input vector into `M = n_subvectors` equal-length
4//! chunks and learns a small codebook of `K = n_centroids` centroids
5//! (with `K <= 256`) for each chunk via k-means. A vector compresses
6//! to `M` bytes — one centroid index per chunk — for a compression
7//! ratio of `(dim * 4) / M` (e.g. 768 dims at `M = 16` → 16 bytes,
8//! 192×). Reconstruction error trades off cleanly against `M` and `K`.
9//!
10//! Asymmetric distance computation (ADC) keeps the query in `f32`,
11//! precomputes a per-subvector distance table from the query to each
12//! of the `K` centroids, and scores a stored code with `M` table
//! lookups plus a single summation pass. Because each supported metric
//! decomposes across subvectors, **ADC scoring via
//! [`Quantizer::distance`](crate::Quantizer::distance) matches running
//! [`iqdb_distance::compute`] on the output of
//! [`Quantizer::dequantize`](crate::Quantizer::dequantize)** for every
//! supported metric, up to `f32` rounding from the per-subvector summation.
19//!
20//! ## Supported metrics
21//!
22//! | Metric | Supported | Why |
23//! |-----------------------------------------|-----------|------------------------------------------------------------------|
24//! | [`DistanceMetric::Euclidean`] | yes | `L2² = Σ_m L2²(q_m, c_m)`; take `sqrt` once at the end. |
25//! | [`DistanceMetric::DotProduct`] | yes | `dot = Σ_m dot(q_m, c_m)`; raw inner product (matches SQ8). |
26//! | [`DistanceMetric::Manhattan`] | yes | `L1 = Σ_m L1(q_m, c_m)`. |
27//! | [`DistanceMetric::Cosine`] | **no** | Requires a global `‖c‖` PQ cannot recover per subvector. Returns `InvalidMetric`.|
28//! | [`DistanceMetric::Hamming`] | **no** | Meaningless on `f32` codes. Returns `InvalidMetric`. |
29//!
30//! Production practice: L2-normalize vectors before training and use
31//! [`DistanceMetric::DotProduct`] when you want cosine semantics.
32//!
33//! [`DistanceMetric::Euclidean`]: iqdb_types::DistanceMetric::Euclidean
34//! [`DistanceMetric::DotProduct`]: iqdb_types::DistanceMetric::DotProduct
35//! [`DistanceMetric::Manhattan`]: iqdb_types::DistanceMetric::Manhattan
36//! [`DistanceMetric::Cosine`]: iqdb_types::DistanceMetric::Cosine
37//! [`DistanceMetric::Hamming`]: iqdb_types::DistanceMetric::Hamming
38
39use error_forge::ForgeError;
40use iqdb_distance::compute_batch;
41use iqdb_types::{DistanceMetric, IqdbError, Result};
42
43use crate::code::PqCode;
44use crate::train::{assign_to_cluster, squared_l2, train_codebook};
45use crate::traits::Quantizer;
46use crate::validate::{dim_eq, finite_non_empty, training_set};
47
48/// Default number of subvectors used by [`ProductQuantizer::new`].
49const DEFAULT_N_SUBVECTORS: usize = 8;
50/// Default number of centroids per subvector used by [`ProductQuantizer::new`].
51const DEFAULT_N_CENTROIDS: usize = 256;
52/// Upper bound on `n_centroids`: codes are stored as `u8`.
53const MAX_N_CENTROIDS: usize = 256;
54/// Default seed used by [`ProductQuantizer::new`].
55const DEFAULT_SEED: u64 = 0;
56
57/// Calibration learned during [`ProductQuantizer::train`].
58#[derive(Debug, Clone, PartialEq)]
59struct PqCalibration {
60 /// The trained input dimension; equals `n_subvectors * sub_dim`.
61 dim: usize,
62 /// `M`, the number of subvectors.
63 n_subvectors: usize,
64 /// `dim / n_subvectors`.
65 sub_dim: usize,
66 /// `K`, the number of centroids per subvector codebook.
67 n_centroids: usize,
68 /// `codebooks[m][k]` is the `k`-th centroid of subvector `m`,
69 /// stored as a `Vec<f32>` of length `sub_dim`.
70 codebooks: Vec<Vec<Vec<f32>>>,
71}
72
73/// Product quantizer: `M` subvectors × `K` centroids per subvector.
74///
75/// Build one with [`ProductQuantizer::new`] for the standard
76/// `M = 8, K = 256` shape, or [`ProductQuantizer::with_config`] to
77/// pick `M`, `K`, and the training `seed` explicitly. Train it once
78/// with a representative sample, then quantize and compare. The
79/// trained quantizer is callable from multiple threads — it owns its
80/// calibration by value and exposes no interior mutability.
81///
82/// # Examples
83///
84/// ```
85/// use iqdb_quantize::{ProductQuantizer, Quantizer};
86/// use iqdb_types::DistanceMetric;
87///
88/// let mut pq = ProductQuantizer::with_config(2, 4, 7);
89/// let training: Vec<Vec<f32>> = (0..16)
90/// .map(|i| {
91/// let f = i as f32;
92/// vec![f, f + 1.0, f + 2.0, f + 3.0]
93/// })
94/// .collect();
95/// let refs: Vec<&[f32]> = training.iter().map(Vec::as_slice).collect();
96/// pq.train(&refs).expect("training succeeds");
97///
98/// let code = pq.quantize(&[1.0_f32, 2.0, 3.0, 4.0]).expect("quantize");
99/// let d = pq
100/// .distance(&[1.0_f32, 2.0, 3.0, 4.0], &code, DistanceMetric::Euclidean)
101/// .expect("supported metric");
102/// assert!(d.is_finite());
103/// ```
104#[derive(Debug, Clone, PartialEq)]
105pub struct ProductQuantizer {
106 n_subvectors: usize,
107 n_centroids: usize,
108 seed: u64,
109 calibration: Option<PqCalibration>,
110}
111
112impl Default for ProductQuantizer {
113 fn default() -> Self {
114 Self::new()
115 }
116}
117
118impl ProductQuantizer {
119 /// Build an untrained PQ with the standard shape (`M = 8`,
120 /// `K = 256`, `seed = 0`).
121 ///
122 /// Every hot method returns [`IqdbError::InvalidConfig`] until
123 /// [`Quantizer::train`] succeeds. The trained dimension must be a
124 /// multiple of `M`, so `new()`'s `M = 8` works for the common
125 /// embedding dimensions (128, 256, 384, 512, 768, 1024, …) but
126 /// not for, say, dim 50; use [`ProductQuantizer::with_config`]
127 /// when that matters.
128 ///
129 /// # Examples
130 ///
131 /// ```
132 /// use iqdb_quantize::ProductQuantizer;
133 /// let pq = ProductQuantizer::new();
134 /// assert_eq!(pq.n_subvectors(), 8);
135 /// assert_eq!(pq.n_centroids(), 256);
136 /// ```
137 #[must_use]
138 pub fn new() -> Self {
139 Self::with_config(DEFAULT_N_SUBVECTORS, DEFAULT_N_CENTROIDS, DEFAULT_SEED)
140 }
141
142 /// Build an untrained PQ with the given shape and training seed.
143 ///
144 /// All three parameters take effect at [`Quantizer::train`] time;
145 /// invalid combinations (e.g. `n_centroids == 0`, `n_centroids >
146 /// 256`, training dim not divisible by `n_subvectors`) surface as
147 /// [`IqdbError::InvalidConfig`] from `train`. The constructor
148 /// itself is infallible — it just stores the configuration.
149 ///
150 /// # Examples
151 ///
152 /// ```
153 /// use iqdb_quantize::ProductQuantizer;
154 /// let pq = ProductQuantizer::with_config(16, 256, 42);
155 /// assert_eq!(pq.n_subvectors(), 16);
156 /// assert_eq!(pq.n_centroids(), 256);
157 /// assert_eq!(pq.seed(), 42);
158 /// ```
159 #[must_use]
160 pub fn with_config(n_subvectors: usize, n_centroids: usize, seed: u64) -> Self {
161 Self {
162 n_subvectors,
163 n_centroids,
164 seed,
165 calibration: None,
166 }
167 }
168
169 /// The trained dimension, if any.
170 ///
171 /// # Examples
172 ///
173 /// ```
174 /// use iqdb_quantize::{ProductQuantizer, Quantizer};
175 /// let mut pq = ProductQuantizer::with_config(2, 4, 7);
176 /// assert_eq!(pq.dim(), None);
177 /// let data: Vec<Vec<f32>> = (0..8).map(|i| vec![i as f32; 4]).collect();
178 /// let refs: Vec<&[f32]> = data.iter().map(Vec::as_slice).collect();
179 /// pq.train(&refs).expect("ok");
180 /// assert_eq!(pq.dim(), Some(4));
181 /// ```
182 #[must_use]
183 pub fn dim(&self) -> Option<usize> {
184 self.calibration.as_ref().map(|c| c.dim)
185 }
186
187 /// The configured number of subvectors `M`.
188 ///
189 /// # Examples
190 ///
191 /// ```
192 /// use iqdb_quantize::ProductQuantizer;
193 /// assert_eq!(ProductQuantizer::with_config(4, 16, 1).n_subvectors(), 4);
194 /// ```
195 #[must_use]
196 pub fn n_subvectors(&self) -> usize {
197 self.n_subvectors
198 }
199
200 /// The configured number of centroids per subvector codebook `K`.
201 ///
202 /// # Examples
203 ///
204 /// ```
205 /// use iqdb_quantize::ProductQuantizer;
206 /// assert_eq!(ProductQuantizer::with_config(4, 16, 1).n_centroids(), 16);
207 /// ```
208 #[must_use]
209 pub fn n_centroids(&self) -> usize {
210 self.n_centroids
211 }
212
213 /// The configured training seed.
214 ///
215 /// Same seed + same training data ⇒ byte-identical codebooks.
216 ///
217 /// # Examples
218 ///
219 /// ```
220 /// use iqdb_quantize::ProductQuantizer;
221 /// assert_eq!(ProductQuantizer::with_config(4, 16, 99).seed(), 99);
222 /// ```
223 #[must_use]
224 pub fn seed(&self) -> u64 {
225 self.seed
226 }
227
228 fn calibration(&self) -> Result<&PqCalibration> {
229 self.calibration.as_ref().ok_or(IqdbError::InvalidConfig {
230 reason: "ProductQuantizer has not been trained",
231 })
232 }
233
234 /// Validate the configured shape against the training-set dimension.
235 /// Returns `sub_dim = dim / n_subvectors` on success.
236 fn validate_shape(&self, dim: usize, training_count: usize) -> Result<usize> {
237 if self.n_subvectors == 0 {
238 return Err(IqdbError::InvalidConfig {
239 reason: "ProductQuantizer requires n_subvectors >= 1",
240 });
241 }
242 if self.n_centroids == 0 {
243 return Err(IqdbError::InvalidConfig {
244 reason: "ProductQuantizer requires n_centroids >= 1",
245 });
246 }
247 if self.n_centroids > MAX_N_CENTROIDS {
248 return Err(IqdbError::InvalidConfig {
249 reason: "ProductQuantizer requires n_centroids <= 256 (one byte per code)",
250 });
251 }
252 if dim == 0 || !dim.is_multiple_of(self.n_subvectors) {
253 return Err(IqdbError::InvalidConfig {
254 reason: "ProductQuantizer requires training dim to be a positive multiple of n_subvectors",
255 });
256 }
257 if training_count < self.n_centroids {
258 return Err(IqdbError::InvalidConfig {
259 reason: "ProductQuantizer requires training_set.len() >= n_centroids",
260 });
261 }
262 Ok(dim / self.n_subvectors)
263 }
264}
265
266impl Quantizer for ProductQuantizer {
267 type Quantized = PqCode;
268
269 #[tracing::instrument(
270 level = "info",
271 skip_all,
272 fields(
273 quantizer = "pq",
274 training_size = vectors.len(),
275 n_subvectors = self.n_subvectors,
276 n_centroids = self.n_centroids,
277 ),
278 )]
279 fn train(&mut self, vectors: &[&[f32]]) -> Result<()> {
280 let dim = training_set(vectors).inspect_err(|err: &IqdbError| {
281 tracing::error!(
282 error.kind = err.kind(),
283 error.reason = err.caption(),
284 "product quantizer training failed",
285 );
286 })?;
287 let sub_dim = self
288 .validate_shape(dim, vectors.len())
289 .inspect_err(|err: &IqdbError| {
290 tracing::error!(
291 error.kind = err.kind(),
292 error.reason = err.caption(),
293 "product quantizer training failed",
294 );
295 })?;
296
297 // Build the per-subvector training slices and train one
298 // codebook per subvector position. The seed is per-subvector
299 // (`base_seed.wrapping_add(m as u64)`) so the M k-means runs
300 // don't all draw from the same PRNG state.
301 let mut codebooks: Vec<Vec<Vec<f32>>> = Vec::with_capacity(self.n_subvectors);
302 for m in 0..self.n_subvectors {
303 let start = m * sub_dim;
304 let end = start + sub_dim;
305 let slices: Vec<&[f32]> = vectors.iter().map(|v| &v[start..end]).collect();
306 let centroids = train_codebook(
307 sub_dim,
308 self.n_centroids,
309 self.seed.wrapping_add(m as u64),
310 &slices,
311 )
312 .inspect_err(|err: &IqdbError| {
313 tracing::error!(
314 error.kind = err.kind(),
315 error.reason = err.caption(),
316 subvector = m,
317 "product quantizer codebook training failed",
318 );
319 })?;
320 codebooks.push(centroids);
321 }
322
323 self.calibration = Some(PqCalibration {
324 dim,
325 n_subvectors: self.n_subvectors,
326 sub_dim,
327 n_centroids: self.n_centroids,
328 codebooks,
329 });
330 Ok(())
331 }
332
333 fn quantize(&self, vector: &[f32]) -> Result<Self::Quantized> {
334 let cal = self.calibration()?;
335 finite_non_empty(vector)?;
336 dim_eq(cal.dim, vector.len())?;
337 let mut codes: Vec<u8> = Vec::with_capacity(cal.n_subvectors);
338 for m in 0..cal.n_subvectors {
339 let start = m * cal.sub_dim;
340 let end = start + cal.sub_dim;
341 let idx = assign_to_cluster(&cal.codebooks[m], &vector[start..end]);
342 // `assign_to_cluster` returns an index in `0..n_centroids`,
343 // and `n_centroids <= 256` (enforced in `validate_shape`),
344 // so this cast cannot lose information.
345 codes.push(idx as u8);
346 }
347 Ok(PqCode {
348 codes,
349 dim: cal.dim,
350 n_subvectors: cal.n_subvectors,
351 })
352 }
353
354 fn dequantize(&self, quantized: &Self::Quantized) -> Result<Vec<f32>> {
355 let cal = self.calibration()?;
356 dim_eq(cal.dim, quantized.dim)?;
357 if quantized.n_subvectors != cal.n_subvectors {
358 return Err(IqdbError::DimensionMismatch {
359 expected: cal.n_subvectors,
360 found: quantized.n_subvectors,
361 });
362 }
363 let mut out: Vec<f32> = Vec::with_capacity(cal.dim);
364 for (m, &code) in quantized.codes.iter().enumerate() {
365 let centroid = &cal.codebooks[m][code as usize];
366 out.extend_from_slice(centroid);
367 }
368 Ok(out)
369 }
370
371 fn distance(
372 &self,
373 query: &[f32],
374 quantized: &Self::Quantized,
375 metric: DistanceMetric,
376 ) -> Result<f32> {
377 let tables = self.build_query_tables(query, metric)?;
378 tables.distance(quantized)
379 }
380}
381
382impl ProductQuantizer {
383 /// Build the ADC lookup tables for `(query, metric)` once so the
384 /// caller can score many [`PqCode`]s against the same query
385 /// without rebuilding the `M × K` table per call.
386 ///
387 /// This is the primitive that
388 /// [`Quantizer::distance`](crate::Quantizer::distance) is built
389 /// on; callers scoring a single code can keep using `distance`
390 /// directly. Use this method when scoring a batch — e.g.
391 /// IVF-PQ's intra-cluster scan, which builds the table once per
392 /// query and then scores every code in every probed cluster.
393 ///
394 /// # Errors
395 ///
396 /// Returns [`IqdbError::InvalidConfig`] if the quantizer is
397 /// untrained, [`IqdbError::InvalidVector`] if `query` is empty or
398 /// non-finite, [`IqdbError::DimensionMismatch`] if `query.len()`
399 /// doesn't match the trained dim, or [`IqdbError::InvalidMetric`]
400 /// for [`DistanceMetric::Cosine`] / [`DistanceMetric::Hamming`].
401 ///
402 /// # Examples
403 ///
404 /// ```
405 /// use iqdb_quantize::{ProductQuantizer, Quantizer};
406 /// use iqdb_types::DistanceMetric;
407 ///
408 /// let mut pq = ProductQuantizer::with_config(2, 4, 7);
409 /// let training: Vec<Vec<f32>> = (0..16)
410 /// .map(|i| {
411 /// let f = i as f32;
412 /// vec![f, f + 1.0, f + 2.0, f + 3.0]
413 /// })
414 /// .collect();
415 /// let refs: Vec<&[f32]> = training.iter().map(Vec::as_slice).collect();
416 /// pq.train(&refs).expect("training succeeds");
417 ///
418 /// let code_a = pq.quantize(&[1.0_f32, 2.0, 3.0, 4.0]).expect("quantize");
419 /// let code_b = pq.quantize(&[5.0_f32, 6.0, 7.0, 8.0]).expect("quantize");
420 ///
421 /// // Build the table ONCE for this (query, metric), then score many codes.
422 /// let query = [1.0_f32, 2.0, 3.0, 4.0];
423 /// let tables = pq
424 /// .build_query_tables(&query, DistanceMetric::Euclidean)
425 /// .expect("supported metric");
426 /// let d_a = tables.distance(&code_a).expect("matching code shape");
427 /// let d_b = tables.distance(&code_b).expect("matching code shape");
428 /// assert!(d_a.is_finite() && d_b.is_finite());
429 /// ```
430 pub fn build_query_tables(&self, query: &[f32], metric: DistanceMetric) -> Result<PqAdcTables> {
431 let cal = self.calibration()?;
432 finite_non_empty(query)?;
433 dim_eq(cal.dim, query.len())?;
434 match metric {
435 DistanceMetric::Euclidean | DistanceMetric::DotProduct | DistanceMetric::Manhattan => {}
436 DistanceMetric::Cosine | DistanceMetric::Hamming => {
437 return Err(IqdbError::InvalidMetric);
438 }
439 // `DistanceMetric` is `#[non_exhaustive]` in published iqdb-types
440 // v1.0.0; any future variant defaults to InvalidMetric until PQ
441 // explicitly opts in. Behavior on the five existing variants is
442 // unchanged.
443 _ => return Err(IqdbError::InvalidMetric),
444 }
445 let table = build_adc_table_rows(query, metric, cal)?;
446 Ok(PqAdcTables {
447 table,
448 metric,
449 n_subvectors: cal.n_subvectors,
450 n_centroids: cal.n_centroids,
451 dim: cal.dim,
452 })
453 }
454}
455
456/// Per-`(query, metric)` precomputed ADC lookup tables built from a
457/// [`ProductQuantizer`].
458///
459/// Build once with [`ProductQuantizer::build_query_tables`], then
460/// score many [`PqCode`]s against it via [`PqAdcTables::distance`]
461/// without rebuilding the `M × K` table per call.
462///
463/// Row `m` of the internal table holds the distances from query
464/// subvector `q_m` to each of the `K` centroids of codebook `m`,
465/// packed row-major. For [`DistanceMetric::Euclidean`] the row holds
466/// **squared L2** values (so they sum decomposably across
467/// subvectors); [`PqAdcTables::distance`] takes `sqrt` of the total
468/// exactly once for Euclidean.
469#[derive(Debug, Clone)]
470pub struct PqAdcTables {
471 /// `n_subvectors * n_centroids` entries, row-major.
472 table: Vec<f32>,
473 metric: DistanceMetric,
474 n_subvectors: usize,
475 n_centroids: usize,
476 dim: usize,
477}
478
479impl PqAdcTables {
480 /// Score a single [`PqCode`] against the prepared tables.
481 ///
482 /// The returned value matches
483 /// [`Quantizer::distance`](crate::Quantizer::distance) for the
484 /// same `(query, code, metric)` — for [`DistanceMetric::Euclidean`]
485 /// the table holds squared L2 per subvector and this method
486 /// `sqrt`s the sum exactly once; the other supported metrics
487 /// (`DotProduct`, `Manhattan`) sum directly.
488 ///
489 /// # Errors
490 ///
491 /// Returns [`IqdbError::DimensionMismatch`] if `code` was produced
492 /// by a [`ProductQuantizer`] with a different `M` or trained `dim`
493 /// — typically the same quantizer that built the tables.
494 pub fn distance(&self, code: &PqCode) -> Result<f32> {
495 if code.n_subvectors != self.n_subvectors {
496 return Err(IqdbError::DimensionMismatch {
497 expected: self.n_subvectors,
498 found: code.n_subvectors,
499 });
500 }
501 if code.dim != self.dim {
502 return Err(IqdbError::DimensionMismatch {
503 expected: self.dim,
504 found: code.dim,
505 });
506 }
507 let total = score_code_rows(&self.table, code, self.n_centroids);
508 Ok(if self.metric == DistanceMetric::Euclidean {
509 total.sqrt()
510 } else {
511 total
512 })
513 }
514
515 /// The metric these tables were built for.
516 #[must_use]
517 pub fn metric(&self) -> DistanceMetric {
518 self.metric
519 }
520
521 /// The number of subvectors `M`.
522 #[must_use]
523 pub fn n_subvectors(&self) -> usize {
524 self.n_subvectors
525 }
526
527 /// The number of centroids per subvector codebook `K`.
528 #[must_use]
529 pub fn n_centroids(&self) -> usize {
530 self.n_centroids
531 }
532
533 /// The trained dimension these tables were built against.
534 #[must_use]
535 pub fn dim(&self) -> usize {
536 self.dim
537 }
538}
539
540fn build_adc_table_rows(
541 query: &[f32],
542 metric: DistanceMetric,
543 cal: &PqCalibration,
544) -> Result<Vec<f32>> {
545 let total_entries = cal.n_subvectors * cal.n_centroids;
546 let mut table: Vec<f32> = vec![0.0; total_entries];
547 let mut centroid_refs: Vec<&[f32]> = Vec::with_capacity(cal.n_centroids);
548 for m in 0..cal.n_subvectors {
549 let start = m * cal.sub_dim;
550 let end = start + cal.sub_dim;
551 let q_sub = &query[start..end];
552 let row_start = m * cal.n_centroids;
553 let row_end = row_start + cal.n_centroids;
554 let row = &mut table[row_start..row_end];
555
556 match metric {
557 DistanceMetric::Euclidean => {
558 // Squared L2 per centroid, summed decomposably across
559 // subvectors. The caller takes `sqrt` of the total in
560 // `PqAdcTables::distance`.
561 for (k, centroid) in cal.codebooks[m].iter().enumerate() {
562 row[k] = squared_l2(q_sub, centroid);
563 }
564 }
565 DistanceMetric::DotProduct | DistanceMetric::Manhattan => {
566 centroid_refs.clear();
567 for centroid in &cal.codebooks[m] {
568 centroid_refs.push(centroid.as_slice());
569 }
570 compute_batch(metric, q_sub, ¢roid_refs, row)?;
571 }
572 DistanceMetric::Cosine | DistanceMetric::Hamming => {
573 // Rejected earlier in `build_query_tables` — expressing
574 // it as an error here keeps the match total without a
575 // panic if the upstream guard is ever relaxed.
576 return Err(IqdbError::InvalidMetric);
577 }
578 // `DistanceMetric` is `#[non_exhaustive]` in published iqdb-types
579 // v1.0.0; same defensive treatment as `build_query_tables`.
580 _ => return Err(IqdbError::InvalidMetric),
581 }
582 }
583 Ok(table)
584}
585
586fn score_code_rows(table: &[f32], code: &PqCode, n_centroids: usize) -> f32 {
587 let mut sum: f32 = 0.0;
588 for (m, &c) in code.codes.iter().enumerate() {
589 let row_start = m * n_centroids;
590 sum += table[row_start + c as usize];
591 }
592 sum
593}