Skip to main content

iqdb_quantize/
product.rs

1//! [`ProductQuantizer`] — product quantization (PQ).
2//!
3//! PQ splits each input vector into `M = n_subvectors` equal-length
4//! chunks and learns a small codebook of `K = n_centroids` centroids
5//! (with `K <= 256`) for each chunk via k-means. A vector compresses
6//! to `M` bytes — one centroid index per chunk — for a compression
7//! ratio of `(dim * 4) / M` (e.g. 768 dims at `M = 16` → 16 bytes,
8//! 192×). Reconstruction error trades off cleanly against `M` and `K`.
9//!
10//! Asymmetric distance computation (ADC) keeps the query in `f32`,
11//! precomputes a per-subvector distance table from the query to each
12//! of the `K` centroids, and scores a stored code with `M` table
//! lookups plus a single summation pass. The math is decomposable,
//! so for every supported metric **the ADC result of
//! [`Quantizer::distance`](crate::Quantizer::distance) equals what you
//! would get by reconstructing the code with
//! [`Quantizer::dequantize`](crate::Quantizer::dequantize) and passing
//! the query and the reconstruction to [`iqdb_distance::compute`]**.
19//!
20//! ## Supported metrics
21//!
22//! | Metric                                  | Supported | Why                                                              |
23//! |-----------------------------------------|-----------|------------------------------------------------------------------|
24//! | [`DistanceMetric::Euclidean`]           | yes       | `L2² = Σ_m L2²(q_m, c_m)`; take `sqrt` once at the end.          |
25//! | [`DistanceMetric::DotProduct`]          | yes       | `dot = Σ_m dot(q_m, c_m)`; raw inner product (matches SQ8).      |
26//! | [`DistanceMetric::Manhattan`]           | yes       | `L1 = Σ_m L1(q_m, c_m)`.                                         |
27//! | [`DistanceMetric::Cosine`]              | **no**    | Requires a global `‖c‖` PQ cannot recover per subvector. Returns `InvalidMetric`.|
28//! | [`DistanceMetric::Hamming`]             | **no**    | Meaningless on `f32` codes. Returns `InvalidMetric`.             |
29//!
30//! Production practice: L2-normalize vectors before training and use
31//! [`DistanceMetric::DotProduct`] when you want cosine semantics.
32//!
33//! [`DistanceMetric::Euclidean`]: iqdb_types::DistanceMetric::Euclidean
34//! [`DistanceMetric::DotProduct`]: iqdb_types::DistanceMetric::DotProduct
35//! [`DistanceMetric::Manhattan`]: iqdb_types::DistanceMetric::Manhattan
36//! [`DistanceMetric::Cosine`]: iqdb_types::DistanceMetric::Cosine
37//! [`DistanceMetric::Hamming`]: iqdb_types::DistanceMetric::Hamming
38
39use error_forge::ForgeError;
40use iqdb_distance::compute_batch;
41use iqdb_types::{DistanceMetric, IqdbError, Result};
42
43use crate::code::PqCode;
44use crate::train::{assign_to_cluster, squared_l2, train_codebook};
45use crate::traits::Quantizer;
46use crate::validate::{dim_eq, finite_non_empty, training_set};
47
/// Largest `n_centroids` a codebook may have: every code is a single
/// `u8`, so at most `2^8 = 256` centroids can be addressed.
const MAX_N_CENTROIDS: usize = 1 << u8::BITS;
/// Subvector count `M` that [`ProductQuantizer::new`] uses.
const DEFAULT_N_SUBVECTORS: usize = 8;
/// Per-subvector centroid count `K` that [`ProductQuantizer::new`]
/// uses — the largest codebook one byte can index.
const DEFAULT_N_CENTROIDS: usize = MAX_N_CENTROIDS;
/// Training seed that [`ProductQuantizer::new`] uses.
const DEFAULT_SEED: u64 = 0;
56
/// Everything [`ProductQuantizer::train`] learns from a training set.
///
/// The shape (`n_subvectors`, `n_centroids`) is copied in at training
/// time, so the encode/decode/score paths only need this one struct.
#[derive(Debug, Clone, PartialEq)]
struct PqCalibration {
    /// Length of the vectors seen during training. Training only
    /// accepts dimensions that split evenly, so this is always
    /// `n_subvectors * sub_dim`.
    dim: usize,
    /// Number of subvectors (`M`) each vector is cut into.
    n_subvectors: usize,
    /// Length of one subvector: `dim / n_subvectors`.
    sub_dim: usize,
    /// Number of centroids (`K`) in each subvector's codebook.
    n_centroids: usize,
    /// One codebook per subvector position: `codebooks[m][k]` is
    /// centroid `k` for subvector `m`, a `Vec<f32>` of `sub_dim`
    /// components.
    codebooks: Vec<Vec<Vec<f32>>>,
}
72
73/// Product quantizer: `M` subvectors × `K` centroids per subvector.
74///
75/// Build one with [`ProductQuantizer::new`] for the standard
76/// `M = 8, K = 256` shape, or [`ProductQuantizer::with_config`] to
77/// pick `M`, `K`, and the training `seed` explicitly. Train it once
78/// with a representative sample, then quantize and compare. The
79/// trained quantizer is callable from multiple threads — it owns its
80/// calibration by value and exposes no interior mutability.
81///
82/// # Examples
83///
84/// ```
85/// use iqdb_quantize::{ProductQuantizer, Quantizer};
86/// use iqdb_types::DistanceMetric;
87///
88/// let mut pq = ProductQuantizer::with_config(2, 4, 7);
89/// let training: Vec<Vec<f32>> = (0..16)
90///     .map(|i| {
91///         let f = i as f32;
92///         vec![f, f + 1.0, f + 2.0, f + 3.0]
93///     })
94///     .collect();
95/// let refs: Vec<&[f32]> = training.iter().map(Vec::as_slice).collect();
96/// pq.train(&refs).expect("training succeeds");
97///
98/// let code = pq.quantize(&[1.0_f32, 2.0, 3.0, 4.0]).expect("quantize");
99/// let d = pq
100///     .distance(&[1.0_f32, 2.0, 3.0, 4.0], &code, DistanceMetric::Euclidean)
101///     .expect("supported metric");
102/// assert!(d.is_finite());
103/// ```
104#[derive(Debug, Clone, PartialEq)]
105pub struct ProductQuantizer {
106    n_subvectors: usize,
107    n_centroids: usize,
108    seed: u64,
109    calibration: Option<PqCalibration>,
110}
111
112impl Default for ProductQuantizer {
113    fn default() -> Self {
114        Self::new()
115    }
116}
117
118impl ProductQuantizer {
119    /// Build an untrained PQ with the standard shape (`M = 8`,
120    /// `K = 256`, `seed = 0`).
121    ///
122    /// Every hot method returns [`IqdbError::InvalidConfig`] until
123    /// [`Quantizer::train`] succeeds. The trained dimension must be a
124    /// multiple of `M`, so `new()`'s `M = 8` works for the common
125    /// embedding dimensions (128, 256, 384, 512, 768, 1024, …) but
126    /// not for, say, dim 50; use [`ProductQuantizer::with_config`]
127    /// when that matters.
128    ///
129    /// # Examples
130    ///
131    /// ```
132    /// use iqdb_quantize::ProductQuantizer;
133    /// let pq = ProductQuantizer::new();
134    /// assert_eq!(pq.n_subvectors(), 8);
135    /// assert_eq!(pq.n_centroids(), 256);
136    /// ```
137    #[must_use]
138    pub fn new() -> Self {
139        Self::with_config(DEFAULT_N_SUBVECTORS, DEFAULT_N_CENTROIDS, DEFAULT_SEED)
140    }
141
142    /// Build an untrained PQ with the given shape and training seed.
143    ///
144    /// All three parameters take effect at [`Quantizer::train`] time;
145    /// invalid combinations (e.g. `n_centroids == 0`, `n_centroids >
146    /// 256`, training dim not divisible by `n_subvectors`) surface as
147    /// [`IqdbError::InvalidConfig`] from `train`. The constructor
148    /// itself is infallible — it just stores the configuration.
149    ///
150    /// # Examples
151    ///
152    /// ```
153    /// use iqdb_quantize::ProductQuantizer;
154    /// let pq = ProductQuantizer::with_config(16, 256, 42);
155    /// assert_eq!(pq.n_subvectors(), 16);
156    /// assert_eq!(pq.n_centroids(), 256);
157    /// assert_eq!(pq.seed(), 42);
158    /// ```
159    #[must_use]
160    pub fn with_config(n_subvectors: usize, n_centroids: usize, seed: u64) -> Self {
161        Self {
162            n_subvectors,
163            n_centroids,
164            seed,
165            calibration: None,
166        }
167    }
168
169    /// The trained dimension, if any.
170    ///
171    /// # Examples
172    ///
173    /// ```
174    /// use iqdb_quantize::{ProductQuantizer, Quantizer};
175    /// let mut pq = ProductQuantizer::with_config(2, 4, 7);
176    /// assert_eq!(pq.dim(), None);
177    /// let data: Vec<Vec<f32>> = (0..8).map(|i| vec![i as f32; 4]).collect();
178    /// let refs: Vec<&[f32]> = data.iter().map(Vec::as_slice).collect();
179    /// pq.train(&refs).expect("ok");
180    /// assert_eq!(pq.dim(), Some(4));
181    /// ```
182    #[must_use]
183    pub fn dim(&self) -> Option<usize> {
184        self.calibration.as_ref().map(|c| c.dim)
185    }
186
187    /// The configured number of subvectors `M`.
188    ///
189    /// # Examples
190    ///
191    /// ```
192    /// use iqdb_quantize::ProductQuantizer;
193    /// assert_eq!(ProductQuantizer::with_config(4, 16, 1).n_subvectors(), 4);
194    /// ```
195    #[must_use]
196    pub fn n_subvectors(&self) -> usize {
197        self.n_subvectors
198    }
199
200    /// The configured number of centroids per subvector codebook `K`.
201    ///
202    /// # Examples
203    ///
204    /// ```
205    /// use iqdb_quantize::ProductQuantizer;
206    /// assert_eq!(ProductQuantizer::with_config(4, 16, 1).n_centroids(), 16);
207    /// ```
208    #[must_use]
209    pub fn n_centroids(&self) -> usize {
210        self.n_centroids
211    }
212
213    /// The configured training seed.
214    ///
215    /// Same seed + same training data ⇒ byte-identical codebooks.
216    ///
217    /// # Examples
218    ///
219    /// ```
220    /// use iqdb_quantize::ProductQuantizer;
221    /// assert_eq!(ProductQuantizer::with_config(4, 16, 99).seed(), 99);
222    /// ```
223    #[must_use]
224    pub fn seed(&self) -> u64 {
225        self.seed
226    }
227
228    fn calibration(&self) -> Result<&PqCalibration> {
229        self.calibration.as_ref().ok_or(IqdbError::InvalidConfig {
230            reason: "ProductQuantizer has not been trained",
231        })
232    }
233
234    /// Validate the configured shape against the training-set dimension.
235    /// Returns `sub_dim = dim / n_subvectors` on success.
236    fn validate_shape(&self, dim: usize, training_count: usize) -> Result<usize> {
237        if self.n_subvectors == 0 {
238            return Err(IqdbError::InvalidConfig {
239                reason: "ProductQuantizer requires n_subvectors >= 1",
240            });
241        }
242        if self.n_centroids == 0 {
243            return Err(IqdbError::InvalidConfig {
244                reason: "ProductQuantizer requires n_centroids >= 1",
245            });
246        }
247        if self.n_centroids > MAX_N_CENTROIDS {
248            return Err(IqdbError::InvalidConfig {
249                reason: "ProductQuantizer requires n_centroids <= 256 (one byte per code)",
250            });
251        }
252        if dim == 0 || !dim.is_multiple_of(self.n_subvectors) {
253            return Err(IqdbError::InvalidConfig {
254                reason: "ProductQuantizer requires training dim to be a positive multiple of n_subvectors",
255            });
256        }
257        if training_count < self.n_centroids {
258            return Err(IqdbError::InvalidConfig {
259                reason: "ProductQuantizer requires training_set.len() >= n_centroids",
260            });
261        }
262        Ok(dim / self.n_subvectors)
263    }
264}
265
266impl Quantizer for ProductQuantizer {
267    type Quantized = PqCode;
268
269    #[tracing::instrument(
270        level = "info",
271        skip_all,
272        fields(
273            quantizer = "pq",
274            training_size = vectors.len(),
275            n_subvectors = self.n_subvectors,
276            n_centroids = self.n_centroids,
277        ),
278    )]
279    fn train(&mut self, vectors: &[&[f32]]) -> Result<()> {
280        let dim = training_set(vectors).inspect_err(|err: &IqdbError| {
281            tracing::error!(
282                error.kind = err.kind(),
283                error.reason = err.caption(),
284                "product quantizer training failed",
285            );
286        })?;
287        let sub_dim = self
288            .validate_shape(dim, vectors.len())
289            .inspect_err(|err: &IqdbError| {
290                tracing::error!(
291                    error.kind = err.kind(),
292                    error.reason = err.caption(),
293                    "product quantizer training failed",
294                );
295            })?;
296
297        // Build the per-subvector training slices and train one
298        // codebook per subvector position. The seed is per-subvector
299        // (`base_seed.wrapping_add(m as u64)`) so the M k-means runs
300        // don't all draw from the same PRNG state.
301        let mut codebooks: Vec<Vec<Vec<f32>>> = Vec::with_capacity(self.n_subvectors);
302        for m in 0..self.n_subvectors {
303            let start = m * sub_dim;
304            let end = start + sub_dim;
305            let slices: Vec<&[f32]> = vectors.iter().map(|v| &v[start..end]).collect();
306            let centroids = train_codebook(
307                sub_dim,
308                self.n_centroids,
309                self.seed.wrapping_add(m as u64),
310                &slices,
311            )
312            .inspect_err(|err: &IqdbError| {
313                tracing::error!(
314                    error.kind = err.kind(),
315                    error.reason = err.caption(),
316                    subvector = m,
317                    "product quantizer codebook training failed",
318                );
319            })?;
320            codebooks.push(centroids);
321        }
322
323        self.calibration = Some(PqCalibration {
324            dim,
325            n_subvectors: self.n_subvectors,
326            sub_dim,
327            n_centroids: self.n_centroids,
328            codebooks,
329        });
330        Ok(())
331    }
332
333    fn quantize(&self, vector: &[f32]) -> Result<Self::Quantized> {
334        let cal = self.calibration()?;
335        finite_non_empty(vector)?;
336        dim_eq(cal.dim, vector.len())?;
337        let mut codes: Vec<u8> = Vec::with_capacity(cal.n_subvectors);
338        for m in 0..cal.n_subvectors {
339            let start = m * cal.sub_dim;
340            let end = start + cal.sub_dim;
341            let idx = assign_to_cluster(&cal.codebooks[m], &vector[start..end]);
342            // `assign_to_cluster` returns an index in `0..n_centroids`,
343            // and `n_centroids <= 256` (enforced in `validate_shape`),
344            // so this cast cannot lose information.
345            codes.push(idx as u8);
346        }
347        Ok(PqCode {
348            codes,
349            dim: cal.dim,
350            n_subvectors: cal.n_subvectors,
351        })
352    }
353
354    fn dequantize(&self, quantized: &Self::Quantized) -> Result<Vec<f32>> {
355        let cal = self.calibration()?;
356        dim_eq(cal.dim, quantized.dim)?;
357        if quantized.n_subvectors != cal.n_subvectors {
358            return Err(IqdbError::DimensionMismatch {
359                expected: cal.n_subvectors,
360                found: quantized.n_subvectors,
361            });
362        }
363        let mut out: Vec<f32> = Vec::with_capacity(cal.dim);
364        for (m, &code) in quantized.codes.iter().enumerate() {
365            let centroid = &cal.codebooks[m][code as usize];
366            out.extend_from_slice(centroid);
367        }
368        Ok(out)
369    }
370
371    fn distance(
372        &self,
373        query: &[f32],
374        quantized: &Self::Quantized,
375        metric: DistanceMetric,
376    ) -> Result<f32> {
377        let tables = self.build_query_tables(query, metric)?;
378        tables.distance(quantized)
379    }
380}
381
382impl ProductQuantizer {
383    /// Build the ADC lookup tables for `(query, metric)` once so the
384    /// caller can score many [`PqCode`]s against the same query
385    /// without rebuilding the `M × K` table per call.
386    ///
387    /// This is the primitive that
388    /// [`Quantizer::distance`](crate::Quantizer::distance) is built
389    /// on; callers scoring a single code can keep using `distance`
390    /// directly. Use this method when scoring a batch — e.g.
391    /// IVF-PQ's intra-cluster scan, which builds the table once per
392    /// query and then scores every code in every probed cluster.
393    ///
394    /// # Errors
395    ///
396    /// Returns [`IqdbError::InvalidConfig`] if the quantizer is
397    /// untrained, [`IqdbError::InvalidVector`] if `query` is empty or
398    /// non-finite, [`IqdbError::DimensionMismatch`] if `query.len()`
399    /// doesn't match the trained dim, or [`IqdbError::InvalidMetric`]
400    /// for [`DistanceMetric::Cosine`] / [`DistanceMetric::Hamming`].
401    ///
402    /// # Examples
403    ///
404    /// ```
405    /// use iqdb_quantize::{ProductQuantizer, Quantizer};
406    /// use iqdb_types::DistanceMetric;
407    ///
408    /// let mut pq = ProductQuantizer::with_config(2, 4, 7);
409    /// let training: Vec<Vec<f32>> = (0..16)
410    ///     .map(|i| {
411    ///         let f = i as f32;
412    ///         vec![f, f + 1.0, f + 2.0, f + 3.0]
413    ///     })
414    ///     .collect();
415    /// let refs: Vec<&[f32]> = training.iter().map(Vec::as_slice).collect();
416    /// pq.train(&refs).expect("training succeeds");
417    ///
418    /// let code_a = pq.quantize(&[1.0_f32, 2.0, 3.0, 4.0]).expect("quantize");
419    /// let code_b = pq.quantize(&[5.0_f32, 6.0, 7.0, 8.0]).expect("quantize");
420    ///
421    /// // Build the table ONCE for this (query, metric), then score many codes.
422    /// let query = [1.0_f32, 2.0, 3.0, 4.0];
423    /// let tables = pq
424    ///     .build_query_tables(&query, DistanceMetric::Euclidean)
425    ///     .expect("supported metric");
426    /// let d_a = tables.distance(&code_a).expect("matching code shape");
427    /// let d_b = tables.distance(&code_b).expect("matching code shape");
428    /// assert!(d_a.is_finite() && d_b.is_finite());
429    /// ```
430    pub fn build_query_tables(&self, query: &[f32], metric: DistanceMetric) -> Result<PqAdcTables> {
431        let cal = self.calibration()?;
432        finite_non_empty(query)?;
433        dim_eq(cal.dim, query.len())?;
434        match metric {
435            DistanceMetric::Euclidean | DistanceMetric::DotProduct | DistanceMetric::Manhattan => {}
436            DistanceMetric::Cosine | DistanceMetric::Hamming => {
437                return Err(IqdbError::InvalidMetric);
438            }
439            // `DistanceMetric` is `#[non_exhaustive]` in published iqdb-types
440            // v1.0.0; any future variant defaults to InvalidMetric until PQ
441            // explicitly opts in. Behavior on the five existing variants is
442            // unchanged.
443            _ => return Err(IqdbError::InvalidMetric),
444        }
445        let table = build_adc_table_rows(query, metric, cal)?;
446        Ok(PqAdcTables {
447            table,
448            metric,
449            n_subvectors: cal.n_subvectors,
450            n_centroids: cal.n_centroids,
451            dim: cal.dim,
452        })
453    }
454}
455
456/// Per-`(query, metric)` precomputed ADC lookup tables built from a
457/// [`ProductQuantizer`].
458///
459/// Build once with [`ProductQuantizer::build_query_tables`], then
460/// score many [`PqCode`]s against it via [`PqAdcTables::distance`]
461/// without rebuilding the `M × K` table per call.
462///
463/// Row `m` of the internal table holds the distances from query
464/// subvector `q_m` to each of the `K` centroids of codebook `m`,
465/// packed row-major. For [`DistanceMetric::Euclidean`] the row holds
466/// **squared L2** values (so they sum decomposably across
467/// subvectors); [`PqAdcTables::distance`] takes `sqrt` of the total
468/// exactly once for Euclidean.
469#[derive(Debug, Clone)]
470pub struct PqAdcTables {
471    /// `n_subvectors * n_centroids` entries, row-major.
472    table: Vec<f32>,
473    metric: DistanceMetric,
474    n_subvectors: usize,
475    n_centroids: usize,
476    dim: usize,
477}
478
479impl PqAdcTables {
480    /// Score a single [`PqCode`] against the prepared tables.
481    ///
482    /// The returned value matches
483    /// [`Quantizer::distance`](crate::Quantizer::distance) for the
484    /// same `(query, code, metric)` — for [`DistanceMetric::Euclidean`]
485    /// the table holds squared L2 per subvector and this method
486    /// `sqrt`s the sum exactly once; the other supported metrics
487    /// (`DotProduct`, `Manhattan`) sum directly.
488    ///
489    /// # Errors
490    ///
491    /// Returns [`IqdbError::DimensionMismatch`] if `code` was produced
492    /// by a [`ProductQuantizer`] with a different `M` or trained `dim`
493    /// — typically the same quantizer that built the tables.
494    pub fn distance(&self, code: &PqCode) -> Result<f32> {
495        if code.n_subvectors != self.n_subvectors {
496            return Err(IqdbError::DimensionMismatch {
497                expected: self.n_subvectors,
498                found: code.n_subvectors,
499            });
500        }
501        if code.dim != self.dim {
502            return Err(IqdbError::DimensionMismatch {
503                expected: self.dim,
504                found: code.dim,
505            });
506        }
507        let total = score_code_rows(&self.table, code, self.n_centroids);
508        Ok(if self.metric == DistanceMetric::Euclidean {
509            total.sqrt()
510        } else {
511            total
512        })
513    }
514
515    /// The metric these tables were built for.
516    #[must_use]
517    pub fn metric(&self) -> DistanceMetric {
518        self.metric
519    }
520
521    /// The number of subvectors `M`.
522    #[must_use]
523    pub fn n_subvectors(&self) -> usize {
524        self.n_subvectors
525    }
526
527    /// The number of centroids per subvector codebook `K`.
528    #[must_use]
529    pub fn n_centroids(&self) -> usize {
530        self.n_centroids
531    }
532
533    /// The trained dimension these tables were built against.
534    #[must_use]
535    pub fn dim(&self) -> usize {
536        self.dim
537    }
538}
539
540fn build_adc_table_rows(
541    query: &[f32],
542    metric: DistanceMetric,
543    cal: &PqCalibration,
544) -> Result<Vec<f32>> {
545    let total_entries = cal.n_subvectors * cal.n_centroids;
546    let mut table: Vec<f32> = vec![0.0; total_entries];
547    let mut centroid_refs: Vec<&[f32]> = Vec::with_capacity(cal.n_centroids);
548    for m in 0..cal.n_subvectors {
549        let start = m * cal.sub_dim;
550        let end = start + cal.sub_dim;
551        let q_sub = &query[start..end];
552        let row_start = m * cal.n_centroids;
553        let row_end = row_start + cal.n_centroids;
554        let row = &mut table[row_start..row_end];
555
556        match metric {
557            DistanceMetric::Euclidean => {
558                // Squared L2 per centroid, summed decomposably across
559                // subvectors. The caller takes `sqrt` of the total in
560                // `PqAdcTables::distance`.
561                for (k, centroid) in cal.codebooks[m].iter().enumerate() {
562                    row[k] = squared_l2(q_sub, centroid);
563                }
564            }
565            DistanceMetric::DotProduct | DistanceMetric::Manhattan => {
566                centroid_refs.clear();
567                for centroid in &cal.codebooks[m] {
568                    centroid_refs.push(centroid.as_slice());
569                }
570                compute_batch(metric, q_sub, &centroid_refs, row)?;
571            }
572            DistanceMetric::Cosine | DistanceMetric::Hamming => {
573                // Rejected earlier in `build_query_tables` — expressing
574                // it as an error here keeps the match total without a
575                // panic if the upstream guard is ever relaxed.
576                return Err(IqdbError::InvalidMetric);
577            }
578            // `DistanceMetric` is `#[non_exhaustive]` in published iqdb-types
579            // v1.0.0; same defensive treatment as `build_query_tables`.
580            _ => return Err(IqdbError::InvalidMetric),
581        }
582    }
583    Ok(table)
584}
585
586fn score_code_rows(table: &[f32], code: &PqCode, n_centroids: usize) -> f32 {
587    let mut sum: f32 = 0.0;
588    for (m, &c) in code.codes.iter().enumerate() {
589        let row_start = m * n_centroids;
590        sum += table[row_start + c as usize];
591    }
592    sum
593}