Skip to main content

oxirs_vec/
pq_encoder.rs

1//! Product quantization encoding and decoding (v1.1.0 round 16).
2//!
3//! Product Quantization (PQ) decomposes a high-dimensional space into `M`
4//! independent sub-spaces of dimension `d/M` and quantizes each sub-space
5//! separately into `K` centroids.
6//!
7//! Reference: Jégou et al., "Product Quantization for Nearest Neighbor Search",
8//! IEEE TPAMI 2011. <https://doi.org/10.1109/TPAMI.2010.57>
9
10// ──────────────────────────────────────────────────────────────────────────────
11// PqConfig
12// ──────────────────────────────────────────────────────────────────────────────
13
/// Configuration parameters for a product quantizer.
///
/// Prefer [`PqConfig::new`], which validates the parameters. The fields are
/// public, so a hand-built value can bypass those checks.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct PqConfig {
    /// Number of sub-spaces `M`.
    pub num_subspaces: usize,
    /// Number of centroids per sub-space `K` (typically 256).
    pub num_centroids: usize,
    /// Full dimensionality of the input vectors.
    pub dimension: usize,
}

impl PqConfig {
    /// Create a new, validated `PqConfig`.
    ///
    /// # Errors
    ///
    /// Returns an error if `num_subspaces`, `num_centroids` or `dimension` is
    /// zero, or if `dimension` is not divisible by `num_subspaces`. The checks
    /// run in that order, and only the first failure is reported.
    pub fn new(
        dimension: usize,
        num_subspaces: usize,
        num_centroids: usize,
    ) -> Result<Self, String> {
        if num_subspaces == 0 {
            return Err("num_subspaces must be > 0".to_string());
        }
        if num_centroids == 0 {
            return Err("num_centroids must be > 0".to_string());
        }
        if dimension == 0 {
            return Err("dimension must be > 0".to_string());
        }
        // `num_subspaces` is known to be non-zero here, so `%` cannot panic.
        if dimension % num_subspaces != 0 {
            return Err(format!(
                "dimension ({}) must be divisible by num_subspaces ({})",
                dimension, num_subspaces
            ));
        }
        Ok(Self {
            num_subspaces,
            num_centroids,
            dimension,
        })
    }

    /// Return the dimensionality of a single sub-space: `dimension / num_subspaces`.
    ///
    /// # Panics
    ///
    /// Panics if `num_subspaces` is zero. [`PqConfig::new`] rejects that value,
    /// but a hand-built config could still contain it.
    pub fn subspace_dim(&self) -> usize {
        self.dimension / self.num_subspaces
    }
}
62
63// ──────────────────────────────────────────────────────────────────────────────
64// PqEncoder
65// ──────────────────────────────────────────────────────────────────────────────
66
67/// A product quantizer with pre-trained codebooks.
68///
69/// The codebook `codebooks[m][k]` is the `k`-th centroid vector for sub-space
70/// `m`, with length `config.subspace_dim()`.
71pub struct PqEncoder {
72    /// Configuration used to build this quantizer.
73    config: PqConfig,
74    /// Codebooks: `M × K × subspace_dim`.
75    codebooks: Vec<Vec<Vec<f32>>>,
76}
77
78impl PqEncoder {
79    /// Create a `PqEncoder` with randomly initialised codebooks using a
80    /// deterministic LCG so tests are reproducible.
81    pub fn new_random(config: PqConfig) -> Self {
82        let sub_dim = config.subspace_dim();
83        let mut seed: u64 = 0xdeadbeef_cafebabe;
84        let mut codebooks: Vec<Vec<Vec<f32>>> = Vec::with_capacity(config.num_subspaces);
85
86        for _ in 0..config.num_subspaces {
87            let mut centroids: Vec<Vec<f32>> = Vec::with_capacity(config.num_centroids);
88            for _ in 0..config.num_centroids {
89                let centroid: Vec<f32> = (0..sub_dim)
90                    .map(|_| {
91                        // LCG: a=6364136223846793005, c=1442695040888963407 (Knuth)
92                        seed = seed
93                            .wrapping_mul(6_364_136_223_846_793_005)
94                            .wrapping_add(1_442_695_040_888_963_407);
95                        // Map to [-1, 1]
96                        let bits = (seed >> 11) as f32;
97                        bits / (1u64 << 53) as f32 * 2.0 - 1.0
98                    })
99                    .collect();
100                centroids.push(centroid);
101            }
102            codebooks.push(centroids);
103        }
104
105        Self { config, codebooks }
106    }
107
108    /// Encode a vector into `M` centroid indices (one per sub-space).
109    ///
110    /// Returns an error if `vector.len() != config.dimension`.
111    pub fn encode(&self, vector: &[f32]) -> Result<Vec<usize>, String> {
112        if vector.len() != self.config.dimension {
113            return Err(format!(
114                "Vector length {} does not match configured dimension {}",
115                vector.len(),
116                self.config.dimension
117            ));
118        }
119        let sub_dim = self.config.subspace_dim();
120        let mut codes = Vec::with_capacity(self.config.num_subspaces);
121
122        for m in 0..self.config.num_subspaces {
123            let sub_vec = &vector[m * sub_dim..(m + 1) * sub_dim];
124            let best = self.nearest_centroid(m, sub_vec);
125            codes.push(best);
126        }
127        Ok(codes)
128    }
129
130    /// Decode `M` centroid indices back to an approximate reconstructed vector.
131    ///
132    /// Returns an error if `codes.len() != config.num_subspaces` or any code
133    /// index is out of bounds.
134    pub fn decode(&self, codes: &[usize]) -> Result<Vec<f32>, String> {
135        if codes.len() != self.config.num_subspaces {
136            return Err(format!(
137                "Expected {} codes, got {}",
138                self.config.num_subspaces,
139                codes.len()
140            ));
141        }
142        let sub_dim = self.config.subspace_dim();
143        let mut result = vec![0.0f32; self.config.dimension];
144
145        for (m, &code) in codes.iter().enumerate() {
146            if code >= self.config.num_centroids {
147                return Err(format!(
148                    "Code {} in sub-space {} exceeds num_centroids {}",
149                    code, m, self.config.num_centroids
150                ));
151            }
152            let centroid = &self.codebooks[m][code];
153            let offset = m * sub_dim;
154            result[offset..offset + sub_dim].copy_from_slice(centroid);
155        }
156        Ok(result)
157    }
158
159    /// Compute the asymmetric distance between a query vector and encoded codes.
160    ///
161    /// The asymmetric distance is the sum of squared Euclidean distances
162    /// between each query sub-vector and its assigned centroid.
163    ///
164    /// Returns an error if `query.len() != config.dimension` or codes are invalid.
165    pub fn asymmetric_distance(&self, query: &[f32], codes: &[usize]) -> Result<f32, String> {
166        if query.len() != self.config.dimension {
167            return Err(format!(
168                "Query length {} does not match configured dimension {}",
169                query.len(),
170                self.config.dimension
171            ));
172        }
173        if codes.len() != self.config.num_subspaces {
174            return Err(format!(
175                "Expected {} codes, got {}",
176                self.config.num_subspaces,
177                codes.len()
178            ));
179        }
180        let sub_dim = self.config.subspace_dim();
181        let mut total_dist = 0.0f32;
182
183        for (m, &code) in codes.iter().enumerate() {
184            if code >= self.config.num_centroids {
185                return Err(format!(
186                    "Code {} in sub-space {} exceeds num_centroids {}",
187                    code, m, self.config.num_centroids
188                ));
189            }
190            let centroid = &self.codebooks[m][code];
191            let sub_query = &query[m * sub_dim..(m + 1) * sub_dim];
192            let sq_dist: f32 = sub_query
193                .iter()
194                .zip(centroid.iter())
195                .map(|(q, c)| (q - c) * (q - c))
196                .sum();
197            total_dist += sq_dist;
198        }
199        Ok(total_dist)
200    }
201
202    /// Return a reference to the encoder's configuration.
203    pub fn config(&self) -> &PqConfig {
204        &self.config
205    }
206
207    // ── Internal helpers ──────────────────────────────────────────────────────
208
209    /// Return the index of the nearest centroid in sub-space `m` to `sub_vec`.
210    fn nearest_centroid(&self, m: usize, sub_vec: &[f32]) -> usize {
211        let centroids = &self.codebooks[m];
212        let mut best_idx = 0usize;
213        let mut best_dist = f32::MAX;
214
215        for (k, centroid) in centroids.iter().enumerate() {
216            let dist: f32 = sub_vec
217                .iter()
218                .zip(centroid.iter())
219                .map(|(a, b)| (a - b) * (a - b))
220                .sum();
221            if dist < best_dist {
222                best_dist = dist;
223                best_idx = k;
224            }
225        }
226        best_idx
227    }
228}
229
230// ──────────────────────────────────────────────────────────────────────────────
231// Tests
232// ──────────────────────────────────────────────────────────────────────────────
233
#[cfg(test)]
mod tests {
    use super::*;

    fn make_encoder(dim: usize, m: usize, k: usize) -> PqEncoder {
        let cfg = PqConfig::new(dim, m, k).expect("valid config");
        PqEncoder::new_random(cfg)
    }

    // ── PqConfig ──────────────────────────────────────────────────────────────

    #[test]
    fn test_config_valid() {
        let cfg = PqConfig::new(64, 4, 256).expect("ok");
        assert_eq!(cfg.dimension, 64);
        assert_eq!(cfg.num_subspaces, 4);
        assert_eq!(cfg.num_centroids, 256);
    }

    #[test]
    fn test_config_subspace_dim() {
        let cfg = PqConfig::new(64, 4, 256).expect("ok");
        assert_eq!(cfg.subspace_dim(), 16);
    }

    #[test]
    fn test_config_subspace_dim_small() {
        let cfg = PqConfig::new(8, 2, 4).expect("ok");
        assert_eq!(cfg.subspace_dim(), 4);
    }

    #[test]
    fn test_config_invalid_not_divisible() {
        let result = PqConfig::new(7, 4, 256);
        assert!(result.is_err());
    }

    #[test]
    fn test_config_invalid_zero_subspaces() {
        let result = PqConfig::new(64, 0, 256);
        assert!(result.is_err());
    }

    #[test]
    fn test_config_invalid_zero_centroids() {
        let result = PqConfig::new(64, 4, 0);
        assert!(result.is_err());
    }

    #[test]
    fn test_config_invalid_zero_dimension() {
        let result = PqConfig::new(0, 4, 256);
        assert!(result.is_err());
    }

    #[test]
    fn test_config_single_subspace() {
        let cfg = PqConfig::new(16, 1, 8).expect("ok");
        assert_eq!(cfg.subspace_dim(), 16);
    }

    // ── encode ────────────────────────────────────────────────────────────────

    #[test]
    fn test_encode_returns_m_codes() {
        let enc = make_encoder(16, 4, 8);
        let vec: Vec<f32> = (0..16).map(|i| i as f32).collect();
        let codes = enc.encode(&vec).expect("encode ok");
        assert_eq!(codes.len(), 4);
    }

    #[test]
    fn test_encode_codes_in_range() {
        let enc = make_encoder(16, 4, 8);
        let vec: Vec<f32> = (0..16).map(|i| i as f32 * 0.5).collect();
        let codes = enc.encode(&vec).expect("encode ok");
        for code in codes {
            assert!(code < 8, "code {} should be < 8", code);
        }
    }

    #[test]
    fn test_encode_wrong_dimension_error() {
        let enc = make_encoder(16, 4, 8);
        let result = enc.encode(&[1.0, 2.0, 3.0]);
        assert!(result.is_err());
    }

    #[test]
    fn test_encode_zero_vector() {
        let enc = make_encoder(8, 2, 4);
        let vec = vec![0.0f32; 8];
        let codes = enc.encode(&vec).expect("encode ok");
        assert_eq!(codes.len(), 2);
    }

    #[test]
    fn test_encode_deterministic() {
        let enc = make_encoder(16, 4, 8);
        let vec: Vec<f32> = (0..16).map(|i| i as f32).collect();
        let codes1 = enc.encode(&vec).expect("ok");
        let codes2 = enc.encode(&vec).expect("ok");
        assert_eq!(codes1, codes2);
    }

    #[test]
    fn test_encode_minimizes_asymmetric_distance() {
        // Encoding picks the nearest centroid per sub-space independently, so
        // no other code combination may yield a smaller asymmetric distance.
        let enc = make_encoder(8, 2, 4);
        let vec: Vec<f32> = (0..8).map(|i| (i as f32 * 0.5).cos()).collect();
        let codes = enc.encode(&vec).expect("encode ok");
        let best = enc.asymmetric_distance(&vec, &codes).expect("dist ok");
        for a in 0..4 {
            for b in 0..4 {
                let d = enc.asymmetric_distance(&vec, &[a, b]).expect("dist ok");
                assert!(best <= d, "codes {:?} beat encoder choice {:?}", [a, b], codes);
            }
        }
    }

    // ── decode ────────────────────────────────────────────────────────────────

    #[test]
    fn test_decode_returns_full_dimension() {
        let enc = make_encoder(16, 4, 8);
        let codes = vec![0usize; 4];
        let decoded = enc.decode(&codes).expect("decode ok");
        assert_eq!(decoded.len(), 16);
    }

    #[test]
    fn test_decode_wrong_code_count_error() {
        let enc = make_encoder(16, 4, 8);
        let codes = vec![0usize; 3]; // should be 4
        assert!(enc.decode(&codes).is_err());
    }

    #[test]
    fn test_decode_out_of_range_code_error() {
        let enc = make_encoder(16, 4, 8);
        let codes = vec![0, 0, 0, 100]; // 100 >= num_centroids=8
        assert!(enc.decode(&codes).is_err());
    }

    #[test]
    fn test_encode_decode_roundtrip_shape() {
        let enc = make_encoder(32, 4, 16);
        let vec: Vec<f32> = (0..32).map(|i| i as f32).collect();
        let codes = enc.encode(&vec).expect("encode ok");
        let decoded = enc.decode(&codes).expect("decode ok");
        assert_eq!(decoded.len(), 32);
        assert_eq!(codes.len(), 4);
    }

    // ── asymmetric_distance ───────────────────────────────────────────────────

    #[test]
    fn test_asymmetric_distance_non_negative() {
        let enc = make_encoder(16, 4, 8);
        let vec: Vec<f32> = (0..16).map(|i| i as f32).collect();
        let codes = enc.encode(&vec).expect("encode ok");
        let dist = enc.asymmetric_distance(&vec, &codes).expect("dist ok");
        assert!(dist >= 0.0);
    }

    #[test]
    fn test_asymmetric_distance_zero_for_centroid_query() {
        let enc = make_encoder(8, 2, 4);
        // Build a query that lies exactly on the chosen centroids: encode an
        // arbitrary vector, then decode its codes back into centroid space.
        let vec: Vec<f32> = (0..8).map(|i| i as f32 * 0.25 - 1.0).collect();
        let codes = enc.encode(&vec).expect("encode ok");
        let on_centroids = enc.decode(&codes).expect("decode ok");
        let dist = enc
            .asymmetric_distance(&on_centroids, &codes)
            .expect("dist ok");
        assert_eq!(dist, 0.0);
        // Re-encoding the reconstruction must reproduce the same codes.
        assert_eq!(enc.encode(&on_centroids).expect("encode ok"), codes);
    }

    #[test]
    fn test_asymmetric_distance_equals_reconstruction_error() {
        let enc = make_encoder(16, 4, 8);
        let vec: Vec<f32> = (0..16).map(|i| (i as f32 * 0.37).sin()).collect();
        let codes = enc.encode(&vec).expect("encode ok");
        let decoded = enc.decode(&codes).expect("decode ok");
        let expected: f32 = vec
            .iter()
            .zip(&decoded)
            .map(|(a, b)| (a - b) * (a - b))
            .sum();
        let dist = enc.asymmetric_distance(&vec, &codes).expect("dist ok");
        // Summation order differs (per sub-space vs. flat), so allow rounding.
        assert!(
            (dist - expected).abs() <= 1e-5 * expected.max(1.0),
            "{} vs {}",
            dist,
            expected
        );
    }

    #[test]
    fn test_asymmetric_distance_wrong_query_dim() {
        let enc = make_encoder(16, 4, 8);
        let codes = vec![0usize; 4];
        let result = enc.asymmetric_distance(&[1.0, 2.0], &codes);
        assert!(result.is_err());
    }

    #[test]
    fn test_asymmetric_distance_wrong_code_count() {
        let enc = make_encoder(16, 4, 8);
        let vec = vec![0.0f32; 16];
        let result = enc.asymmetric_distance(&vec, &[0, 0]);
        assert!(result.is_err());
    }

    #[test]
    fn test_asymmetric_distance_out_of_range_code_error() {
        let enc = make_encoder(16, 4, 8);
        let vec = vec![0.0f32; 16];
        // 8 == num_centroids, so it is one past the last valid index.
        assert!(enc.asymmetric_distance(&vec, &[0, 8, 0, 0]).is_err());
    }

    // ── config accessor ───────────────────────────────────────────────────────

    #[test]
    fn test_config_accessor() {
        let enc = make_encoder(32, 8, 16);
        let cfg = enc.config();
        assert_eq!(cfg.dimension, 32);
        assert_eq!(cfg.num_subspaces, 8);
        assert_eq!(cfg.num_centroids, 16);
        assert_eq!(cfg.subspace_dim(), 4);
    }

    // ── new_random reproducibility ────────────────────────────────────────────

    #[test]
    fn test_new_random_reproducible() {
        let cfg1 = PqConfig::new(16, 4, 8).expect("ok");
        let cfg2 = PqConfig::new(16, 4, 8).expect("ok");
        let enc1 = PqEncoder::new_random(cfg1);
        let enc2 = PqEncoder::new_random(cfg2);
        let vec: Vec<f32> = (0..16).map(|i| i as f32).collect();
        assert_eq!(
            enc1.encode(&vec).expect("ok"),
            enc2.encode(&vec).expect("ok")
        );
    }
}