Skip to main content

nodedb_codec/vector_quant/
opq.rs

// SPDX-License-Identifier: Apache-2.0

//! Optimized Product Quantization (OPQ) — Non-Para OPQ via iterative
//! SVD-Procrustes rotation that minimizes PQ reconstruction error, yielding
//! 10–20% recall improvement over vanilla PQ at equal memory.
//!
//! # Algorithm
//!
//! OPQ wraps standard PQ with a learned rotation matrix `R` (dim × dim,
//! row-major) that is applied both before codebook training and at query time:
//!
//! ```text
//! encode(v)      = PQ_encode(R · v)
//! distance(q, v) = ADC(R · q, PQ_code(v))
//! ```
//!
//! ## Non-Para OPQ (Ge et al., CVPR 2013)
//!
//! Starting from `R = I`, the rotation is learned by alternating two steps:
//!
//! 1. **Codebook step** — hold `R` fixed and train PQ codebooks on the rotated
//!    training set `R · X` via Lloyd's k-means.
//! 2. **Procrustes step** — hold the codebooks fixed and update `R` to minimize
//!    the Frobenius reconstruction error ‖R·X − Y‖_F, where
//!    `Y = reconstruct(quantize(R·X))`. The closed-form solution uses an SVD:
//!    - Let `X` be the original training vectors and `Y` their dequantized
//!      reconstructions, both arranged as dim × N matrices.
//!    - Compute the cross-correlation `M = X · Yᵀ` (dim × dim).
//!    - SVD: `M = U · Σ · Vᵀ`.
//!    - New rotation: `R = V · Uᵀ`.
//!
//! This alternation is repeated for `opq_iters` iterations (default 5; values
//! below 1 are treated as 1). The rotation is not updated after the final
//! codebook step, so the stored codebooks always match the stored rotation.
//!
//! ## Storage format
//!
//! `QuantMode::Pq` is reused in `UnifiedQuantizedVector` headers. OPQ is
//! structurally PQ applied after a rotation, so it needs no new on-disk
//! discriminant. The rotation matrix is stored in `OpqCodec` and applied
//! transparently.

39use nalgebra::{DMatrix, SVD};
40
41use crate::vector_quant::codec::{AdcLut, VectorCodec};
42use crate::vector_quant::layout::{QuantHeader, QuantMode, UnifiedQuantizedVector};
43use crate::vector_quant::opq_kmeans::l2_sq;
44use crate::vector_quant::opq_kmeans::lloyd;
45
46// ── OpqCodec ──────────────────────────────────────────────────────────────────
47
/// Optimized Product Quantization codec.
///
/// Holds a learned `dim × dim` rotation `R` (row-major) together with PQ
/// codebooks fitted to the rotated training set. Vectors are rotated by `R`
/// before encoding, and queries are rotated the same way before ADC lookups.
pub struct OpqCodec {
    /// Input vector dimensionality.
    pub dim: usize,
    /// Number of PQ subspaces (one code byte per subspace).
    pub m: usize,
    /// Centroids per subspace (256 for u8 codes).
    pub k: usize,
    /// Coordinates per subspace, i.e. `dim / m`.
    pub sub_dim: usize,
    /// Learned rotation matrix `R`, `dim × dim`, row-major.
    rotation: Vec<f32>,
    /// PQ codebooks trained on `R·v`, laid out as `[M][K][sub_dim]`.
    codebooks: Vec<Vec<Vec<f32>>>,
}
64
65impl OpqCodec {
66    /// Train an OPQ codec using the Non-Para OPQ algorithm.
67    ///
68    /// Alternates between a codebook step (Lloyd's k-means on the rotated
69    /// training set) and a Procrustes step (SVD-based rotation update to
70    /// minimize reconstruction error) for `opq_iters` iterations.
71    ///
72    /// - `opq_iters`: number of alternating Procrustes+codebook iterations.
73    /// - `kmeans_iters`: Lloyd's k-means iterations per subspace per OPQ iter.
74    pub fn train(
75        vectors: &[&[f32]],
76        dim: usize,
77        m: usize,
78        k: usize,
79        opq_iters: usize,
80        kmeans_iters: usize,
81    ) -> Self {
82        assert!(!vectors.is_empty(), "training set must be non-empty");
83        assert!(dim > 0 && m > 0 && k > 0, "dim/m/k must be positive");
84        assert!(
85            dim.is_multiple_of(m),
86            "dim ({dim}) must be divisible by m ({m})"
87        );
88        let sub_dim = dim / m;
89        let seed = dim as u64 ^ ((m as u64) << 16) ^ ((k as u64) << 32);
90
91        let mut rotation = identity(dim);
92        let mut codebooks: Vec<Vec<Vec<f32>>> = Vec::new();
93
94        let iters = opq_iters.max(1);
95
96        for iter in 0..iters {
97            // Codebook step: train PQ on the current rotated training set.
98            let rotated: Vec<Vec<f32>> =
99                vectors.iter().map(|v| matvec(&rotation, v, dim)).collect();
100            codebooks = train_codebooks(&rotated, m, k, sub_dim, kmeans_iters, seed ^ iter as u64);
101
102            // Procrustes step: find R minimising ‖R·X - Y‖_F where Y is
103            // the dequantized reconstruction of R·X.
104            //
105            // Closed-form solution (Ge et al. CVPR 2013, §3.2):
106            //   M = X · Yᵀ   (dim × dim)
107            //   SVD(M) = U Σ Vᵀ
108            //   R_new = V · Uᵀ
109            //
110            // Skip rotation update on the last iteration — codebooks were
111            // already retrained with the current R.
112            if iter + 1 < iters {
113                let n = vectors.len();
114                // Build dim×N matrices X (original) and Y (reconstructed).
115                // DMatrix is column-major; we store column j = vector j.
116                let x_mat = DMatrix::from_fn(dim, n, |row, col| vectors[col][row]);
117                let y_mat = {
118                    let recon: Vec<Vec<f32>> = rotated
119                        .iter()
120                        .map(|rv| {
121                            let codes = pq_encode(rv, &codebooks, m, sub_dim);
122                            dequantize_codes(&codes, &codebooks)
123                        })
124                        .collect();
125                    DMatrix::from_fn(dim, n, |row, col| recon[col][row])
126                };
127
128                // M = X · Yᵀ  (dim × dim)
129                let m_mat = &x_mat * y_mat.transpose();
130
131                // Guard: skip rotation update if M contains NaN (degenerate
132                // training data or all-zero reconstructions on early iters).
133                let has_nan = m_mat.iter().any(|x| x.is_nan());
134                if !has_nan {
135                    let svd = SVD::new(m_mat, true, true);
136                    if let (Some(u), Some(v_t)) = (svd.u, svd.v_t) {
137                        // R = V · Uᵀ  →  in nalgebra: V = v_tᵀ, so R = v_tᵀ · uᵀ
138                        let r_new = v_t.transpose() * u.transpose();
139                        // Convert column-major DMatrix to row-major Vec<f32>.
140                        let mut buf = Vec::with_capacity(dim * dim);
141                        for i in 0..dim {
142                            for j in 0..dim {
143                                buf.push(r_new[(i, j)]);
144                            }
145                        }
146                        rotation = buf;
147                    }
148                }
149            }
150        }
151
152        Self {
153            dim,
154            m,
155            k,
156            sub_dim,
157            rotation,
158            codebooks,
159        }
160    }
161
162    /// Apply the rotation matrix to `v`, returning `R · v`.
163    pub fn apply_rotation(&self, v: &[f32]) -> Vec<f32> {
164        matvec(&self.rotation, v, self.dim)
165    }
166
167    fn encode_inner(&self, v: &[f32]) -> (Vec<u8>, UnifiedQuantizedVector) {
168        let rotated = self.apply_rotation(v);
169        let codes = pq_encode(&rotated, &self.codebooks, self.m, self.sub_dim);
170        let uqv = make_uqv(&codes, self.dim as u16);
171        (codes, uqv)
172    }
173
174    fn dequantize(&self, codes: &[u8]) -> Vec<f32> {
175        dequantize_codes(codes, &self.codebooks)
176    }
177}
178
179// ── Internal helpers ──────────────────────────────────────────────────────────
180
/// Build the `dim × dim` identity matrix in row-major order.
fn identity(dim: usize) -> Vec<f32> {
    let mut eye = vec![0.0f32; dim * dim];
    // In row-major layout the diagonal entries sit every `dim + 1` slots.
    eye.iter_mut()
        .step_by(dim + 1)
        .for_each(|cell| *cell = 1.0);
    eye
}
189
/// Dequantize PQ codes into a reconstructed vector in rotated space.
///
/// Concatenates the centroid `codebooks[s][codes[s]]` for each subspace `s`,
/// in order. Returns an empty vector when `codes` is empty.
///
/// # Panics
///
/// Panics if `codes` has more entries than `codebooks` has subspaces, or if a
/// code indexes past the end of its subspace codebook.
fn dequantize_codes(codes: &[u8], codebooks: &[Vec<Vec<f32>>]) -> Vec<f32> {
    // Derive the per-subspace width without assuming a centroid exists, so
    // empty codebooks no longer panic before any work is done.
    let sub_dim = codebooks
        .first()
        .and_then(|book| book.first())
        .map_or(0, Vec::len);
    let mut out = Vec::with_capacity(codes.len() * sub_dim);
    for (s, &c) in codes.iter().enumerate() {
        out.extend_from_slice(&codebooks[s][c as usize]);
    }
    out
}
198
/// Multiply the row-major `dim × dim` matrix `r` by `v`, returning `R · v`.
///
/// Output element `i` is the dot product of row `i` of `r` with `v`.
#[inline]
fn matvec(r: &[f32], v: &[f32], dim: usize) -> Vec<f32> {
    (0..dim)
        .map(|row_idx| {
            let start = row_idx * dim;
            r[start..start + dim]
                .iter()
                .zip(v)
                .map(|(coef, x)| coef * x)
                .sum::<f32>()
        })
        .collect()
}
209
210fn pq_encode(v: &[f32], codebooks: &[Vec<Vec<f32>>], m: usize, sub_dim: usize) -> Vec<u8> {
211    let mut codes = Vec::with_capacity(m);
212    #[allow(clippy::needless_range_loop)]
213    for s in 0..m {
214        let offset = s * sub_dim;
215        let sub = &v[offset..offset + sub_dim];
216        let best = codebooks[s]
217            .iter()
218            .enumerate()
219            .min_by(|(_, a), (_, b)| {
220                l2_sq(sub, a)
221                    .partial_cmp(&l2_sq(sub, b))
222                    .unwrap_or(std::cmp::Ordering::Equal)
223            })
224            .map(|(i, _)| i)
225            .unwrap_or(0);
226        codes.push(best as u8);
227    }
228    codes
229}
230
231fn train_codebooks(
232    rotated: &[Vec<f32>],
233    m: usize,
234    k: usize,
235    sub_dim: usize,
236    kmeans_iters: usize,
237    seed: u64,
238) -> Vec<Vec<Vec<f32>>> {
239    let mut codebooks = Vec::with_capacity(m);
240    for s in 0..m {
241        let offset = s * sub_dim;
242        let sub_vecs: Vec<Vec<f32>> = rotated
243            .iter()
244            .map(|v| v[offset..offset + sub_dim].to_vec())
245            .collect();
246        let centroids = lloyd(
247            &sub_vecs,
248            sub_dim,
249            k,
250            kmeans_iters,
251            seed ^ (s as u64 * 0x1234567),
252        );
253        codebooks.push(centroids);
254    }
255    codebooks
256}
257
258fn make_uqv(codes: &[u8], dim: u16) -> UnifiedQuantizedVector {
259    let header = QuantHeader {
260        quant_mode: QuantMode::Pq as u16,
261        dim,
262        global_scale: 1.0,
263        residual_norm: 0.0,
264        dot_quantized: 0.0,
265        outlier_bitmask: 0,
266        reserved: [0; 8],
267    };
268    UnifiedQuantizedVector::new(header, codes, &[])
269        .expect("make_uqv: layout construction must not fail for valid inputs")
270}
271
272// ── VectorCodec wrapper types ─────────────────────────────────────────────────
273
274/// Quantized form returned by [`OpqCodec::encode`].
275pub struct OpqQuantized {
276    codes: Vec<u8>,
277    uqv: UnifiedQuantizedVector,
278}
279
280impl AsRef<UnifiedQuantizedVector> for OpqQuantized {
281    fn as_ref(&self) -> &UnifiedQuantizedVector {
282        &self.uqv
283    }
284}
285
/// Query prepared for ADC: the rotated query plus its flat distance table.
///
/// `distance_table[s * k + c]` holds the distance between subspace `s` of the
/// rotated query and centroid `c` of that subspace (`M × K`, row-major).
pub struct OpqQuery {
    pub distance_table: Vec<f32>,
    // Rotated query `R · q`. It is retained alongside the table, but no code
    // visible in this module reads it, so the dead-code lint is silenced.
    #[allow(dead_code)]
    rotated: Vec<f32>,
}
292
293// ── VectorCodec impl ──────────────────────────────────────────────────────────
294
295impl VectorCodec for OpqCodec {
296    type Quantized = OpqQuantized;
297    type Query = OpqQuery;
298
299    fn encode(&self, v: &[f32]) -> Self::Quantized {
300        let (codes, uqv) = self.encode_inner(v);
301        OpqQuantized { codes, uqv }
302    }
303
304    /// Rotate the query, then build flat ADC distance table `[M × K]`.
305    fn prepare_query(&self, q: &[f32]) -> Self::Query {
306        let rotated = self.apply_rotation(q);
307        let mut table = vec![0.0f32; self.m * self.k];
308        for s in 0..self.m {
309            let offset = s * self.sub_dim;
310            let sub_q = &rotated[offset..offset + self.sub_dim];
311            for c in 0..self.k {
312                table[s * self.k + c] = l2_sq(sub_q, &self.codebooks[s][c]);
313            }
314        }
315        OpqQuery {
316            distance_table: table,
317            rotated,
318        }
319    }
320
321    fn adc_lut(&self, q: &Self::Query) -> Option<AdcLut> {
322        let mut lut = AdcLut::new(self.m as u16, self.k as u16);
323        lut.table.copy_from_slice(&q.distance_table);
324        Some(lut)
325    }
326
327    /// Symmetric: dequantize both sides in rotated space, compute L2.
328    fn fast_symmetric_distance(&self, q: &Self::Quantized, v: &Self::Quantized) -> f32 {
329        let qv = self.dequantize(&q.codes);
330        let vv = self.dequantize(&v.codes);
331        l2_sq(&qv, &vv)
332    }
333
334    /// Asymmetric: O(M) ADC table lookups — one per subspace.
335    fn exact_asymmetric_distance(&self, q: &Self::Query, v: &Self::Quantized) -> f32 {
336        v.codes
337            .iter()
338            .enumerate()
339            .map(|(s, &code)| q.distance_table[s * self.k + code as usize])
340            .sum()
341    }
342}
343
344// ── Tests ─────────────────────────────────────────────────────────────────────
345
#[cfg(test)]
mod tests {
    use super::*;

    /// Ten 8-d vectors whose two halves grow linearly with the row index.
    /// The set is tiny but still has structure for PQ to capture.
    fn tiny_dataset() -> Vec<Vec<f32>> {
        let mut rows = Vec::with_capacity(10);
        for i in 0..10 {
            let b = i as f32 * 2.0;
            let h = b * 0.5;
            rows.push(vec![b, b + 0.1, b - 0.1, b + 0.2, h, h + 0.1, h - 0.1, h + 0.05]);
        }
        rows
    }

    /// Borrow each owned vector as a slice, matching `OpqCodec::train`'s input.
    fn as_refs(vecs: &[Vec<f32>]) -> Vec<&[f32]> {
        vecs.iter().map(Vec::as_slice).collect()
    }

    /// Train with dim=8, m=2, k=4, 10 OPQ iterations and 30 k-means iterations.
    fn train_tiny() -> OpqCodec {
        let vecs = tiny_dataset();
        OpqCodec::train(&as_refs(&vecs), 8, 2, 4, 10, 30)
    }

    #[test]
    fn encode_produces_m_bytes() {
        let codec = train_tiny();
        for v in tiny_dataset() {
            assert_eq!(codec.encode(&v).codes.len(), codec.m);
        }
    }

    #[test]
    fn distance_is_non_negative() {
        let codec = train_tiny();
        for v in tiny_dataset() {
            let encoded = codec.encode(&v);
            let query = codec.prepare_query(&v);
            let asym = codec.exact_asymmetric_distance(&query, &encoded);
            let sym = codec.fast_symmetric_distance(&encoded, &encoded);
            assert!(
                asym >= 0.0,
                "asymmetric distance must be non-negative, got {asym}"
            );
            assert!(
                sym >= 0.0,
                "symmetric distance must be non-negative, got {sym}"
            );
        }
    }

    #[test]
    fn top1_recall_on_training_set() {
        let vecs = tiny_dataset();
        let codec = train_tiny();
        let encoded: Vec<OpqQuantized> = vecs.iter().map(|v| codec.encode(v)).collect();

        let correct = vecs
            .iter()
            .enumerate()
            .filter(|(i, v)| {
                let query = codec.prepare_query(v.as_slice());
                let nearest = encoded
                    .iter()
                    .map(|e| codec.exact_asymmetric_distance(&query, e))
                    .enumerate()
                    .min_by(|(_, a), (_, b)| {
                        a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)
                    })
                    .map_or(usize::MAX, |(idx, _)| idx);
                nearest == *i
            })
            .count();

        let recall = correct as f64 / vecs.len() as f64;
        // This synthetic set is deliberately minimal (n=10, dim=8, m=2, k=4:
        // 4 bits per vector), so codespace collisions are unavoidable and
        // ~70% is the expected ceiling. Realistic settings (SIFT1M, m=8,
        // k=256, dim=128) are covered by the bench harness, not this test.
        assert!(
            recall >= 0.70,
            "top-1 recall on training set too low: {correct}/{} = {recall:.2}",
            vecs.len()
        );
    }

    #[test]
    fn more_iterations_reduce_reconstruction_error() {
        let vecs = tiny_dataset();
        let refs = as_refs(&vecs);

        let codec_1 = OpqCodec::train(&refs, 8, 2, 4, 1, 10);
        let codec_5 = OpqCodec::train(&refs, 8, 2, 4, 5, 10);

        // Mean squared reconstruction error in each codec's rotated space.
        let mean_recon_error = |codec: &OpqCodec| -> f32 {
            let total: f32 = refs
                .iter()
                .map(|v| {
                    let rotated = codec.apply_rotation(v);
                    let codes = pq_encode(&rotated, &codec.codebooks, codec.m, codec.sub_dim);
                    l2_sq(&rotated, &dequantize_codes(&codes, &codec.codebooks))
                })
                .sum();
            total / refs.len() as f32
        };

        let err_1 = mean_recon_error(&codec_1);
        let err_5 = mean_recon_error(&codec_5);

        // A 5% tolerance absorbs k-means noise on this tiny set.
        assert!(
            err_5 <= err_1 * 1.05,
            "5-iter OPQ (err={err_5:.4}) should have ≤ reconstruction error than 1-iter (err={err_1:.4})"
        );
    }
}