Skip to main content

nodedb_vector/quantize/
pq.rs

1// SPDX-License-Identifier: Apache-2.0
2
3//! Product Quantization (PQ): 8-16x compression for large datasets.
4//!
5//! Splits D-dimensional vectors into M subvectors, clusters each subspace
6//! with K=256 centroids via k-means. Each vector is encoded as M bytes
7//! (one centroid index per subvector).
8//!
9//! Distance is computed via precomputed lookup tables: for each query,
10//! build a `[M][K]` table of distances from the query's subvectors to
11//! all centroids. Then the distance to any encoded vector is just M
12//! table lookups + additions — O(M) per candidate vs O(D) for FP32.
13//!
14//! Trade-off: 2-5% recall loss vs SQ8's <1%, but 2-4x more compression
15//! (8-16x total vs 4x for SQ8). Best for cost-sensitive large datasets.
16
17use std::mem::size_of;
18use std::sync::Arc;
19
20use nodedb_mem::{EngineId, MemoryGovernor};
21use serde::{Deserialize, Serialize};
22
23use crate::error::VectorError;
24
25/// Reserve `bytes` from `governor` for `EngineId::Vector`, or succeed silently
26/// when no governor is configured. The returned guard (if any) must be kept
27/// alive for the duration of the allocation it covers.
28#[inline]
29fn try_reserve_or_skip(
30    governor: &Option<Arc<MemoryGovernor>>,
31    bytes: usize,
32) -> Result<Option<nodedb_mem::BudgetGuard>, VectorError> {
33    match governor {
34        Some(g) => Ok(Some(g.reserve(EngineId::Vector, bytes)?)),
35        None => Ok(None),
36    }
37}
38
/// PQ codec with trained codebooks.
///
/// Produced by [`PqCodec::train`] or restored from bytes with
/// [`PqCodec::from_bytes`]. The shape fields are public; the codebooks and
/// the optional memory governor are private to this module.
#[derive(
    Clone, Debug, Serialize, Deserialize, zerompk::ToMessagePack, zerompk::FromMessagePack,
)]
pub struct PqCodec {
    /// Original vector dimensionality.
    pub dim: usize,
    /// Number of subvectors (subspaces).
    pub m: usize,
    /// Centroids per subvector (fixed at 256 for u8 encoding).
    ///
    /// This is the value requested at training time. Training clamps the
    /// centroid count to the number of training vectors, so an individual
    /// codebook may hold fewer than `k` entries.
    pub k: usize,
    /// Dimensions per subvector: `dim / m`.
    pub sub_dim: usize,
    /// Codebooks: `codebooks[subspace][centroid][sub_dim_component]`.
    /// Total: M × K × sub_dim floats.
    codebooks: Vec<Vec<Vec<f32>>>,

    /// Optional memory governor. Skipped during serialization — it is a
    /// runtime concern only, not part of the on-disk format.
    // Excluded from both the serde and the msgpack representations.
    #[serde(skip, default)]
    #[msgpack(ignore)]
    governor: Option<Arc<MemoryGovernor>>,
}
62
63impl PqCodec {
64    /// Attach a memory governor to this codec.
65    ///
66    /// Once set, heap-significant operations (`train`, `encode_batch`,
67    /// `build_distance_table`, `decode`, `to_bytes`) will charge the
68    /// `EngineId::Vector` budget before allocating and release the reservation
69    /// when the returned value is dropped (RAII).  When no governor is set
70    /// those operations proceed unconditionally, preserving backward
71    /// compatibility with callers that do not use the memory governor.
72    ///
73    /// The governor is a runtime concern only — it is **not** serialized.
74    pub fn with_governor(mut self, governor: Arc<MemoryGovernor>) -> Self {
75        self.governor = Some(governor);
76        self
77    }
78
79    /// Train PQ codebooks from a set of training vectors via k-means.
80    ///
81    /// `m` = number of subvectors (must divide `dim` evenly).
82    /// `k` = centroids per subvector (typically 256).
83    /// `max_iter` = k-means iterations (20 is usually sufficient).
84    pub fn train(vectors: &[&[f32]], dim: usize, m: usize, k: usize, max_iter: usize) -> Self {
85        assert!(!vectors.is_empty());
86        assert!(dim > 0 && m > 0 && k > 0);
87        assert!(
88            dim.is_multiple_of(m),
89            "dim ({dim}) must be divisible by m ({m})"
90        );
91
92        let sub_dim = dim / m;
93        let mut codebooks = Vec::with_capacity(m);
94
95        for sub in 0..m {
96            let offset = sub * sub_dim;
97            // Extract sub-vectors for this subspace.
98            let sub_vectors: Vec<&[f32]> = vectors
99                .iter()
100                .map(|v| &v[offset..offset + sub_dim])
101                .collect();
102
103            let centroids = kmeans(&sub_vectors, sub_dim, k, max_iter);
104            codebooks.push(centroids);
105        }
106
107        Self {
108            dim,
109            m,
110            k,
111            sub_dim,
112            codebooks,
113            governor: None,
114        }
115    }
116
117    /// Encode a vector: for each subvector, find the nearest centroid index.
118    ///
119    /// This is a per-vector hot-path operation.  Governor charging is
120    /// intentionally skipped here to avoid atomic overhead on every candidate
121    /// during search; use [`encode_batch`] for bulk encoding with budget
122    /// enforcement.
123    pub fn encode(&self, vector: &[f32]) -> Vec<u8> {
124        debug_assert_eq!(vector.len(), self.dim);
125        let mut code = Vec::with_capacity(self.m);
126        for sub in 0..self.m {
127            let offset = sub * self.sub_dim;
128            let sub_vec = &vector[offset..offset + self.sub_dim];
129            let nearest = self.nearest_centroid(sub, sub_vec);
130            code.push(nearest as u8);
131        }
132        code
133    }
134
135    /// Batch encode all vectors into a contiguous byte array.
136    ///
137    /// Charges `m * vectors.len()` bytes to the governor budget (if set)
138    /// before allocating the output buffer.  The guard is released at
139    /// the end of this call — the buffer itself remains alive.
140    pub fn encode_batch(&self, vectors: &[&[f32]]) -> Result<Vec<u8>, VectorError> {
141        let capacity = self.m * vectors.len();
142        let _g = try_reserve_or_skip(&self.governor, capacity * size_of::<u8>())?;
143        let mut out = Vec::with_capacity(capacity);
144        for v in vectors {
145            out.extend(self.encode(v));
146        }
147        Ok(out)
148    }
149
150    /// Build an asymmetric distance table for a query vector.
151    ///
152    /// Returns `table[sub][centroid]` = distance from query's sub-vector
153    /// to each centroid. Pre-computing this table makes distance evaluation
154    /// O(M) per candidate instead of O(D).
155    ///
156    /// Charges `m * k * size_of::<f32>()` bytes to the governor (if set)
157    /// before allocating the table.
158    pub fn build_distance_table(&self, query: &[f32]) -> Result<Vec<Vec<f32>>, VectorError> {
159        debug_assert_eq!(query.len(), self.dim);
160        let total_bytes = self.m * self.k * size_of::<f32>();
161        let _g = try_reserve_or_skip(&self.governor, total_bytes)?;
162        let mut table = Vec::with_capacity(self.m);
163        for sub in 0..self.m {
164            let offset = sub * self.sub_dim;
165            let sub_query = &query[offset..offset + self.sub_dim];
166            let mut dists = Vec::with_capacity(self.k);
167            for centroid in &self.codebooks[sub] {
168                let d = l2_sub(sub_query, centroid);
169                dists.push(d);
170            }
171            table.push(dists);
172        }
173        Ok(table)
174    }
175
176    /// Compute asymmetric distance using a precomputed distance table.
177    ///
178    /// O(M) per candidate — just M table lookups and additions.
179    #[inline]
180    pub fn asymmetric_distance(&self, table: &[Vec<f32>], code: &[u8]) -> f32 {
181        debug_assert_eq!(code.len(), self.m);
182        let mut dist = 0.0f32;
183        for (sub, &c) in code.iter().enumerate() {
184            dist += table[sub][c as usize];
185        }
186        dist
187    }
188
189    /// Decode a PQ code back to an approximate FP32 vector.
190    ///
191    /// Charges `dim * size_of::<f32>()` bytes to the governor (if set)
192    /// before allocating the output buffer.
193    pub fn decode(&self, code: &[u8]) -> Result<Vec<f32>, VectorError> {
194        debug_assert_eq!(code.len(), self.m);
195        let _g = try_reserve_or_skip(&self.governor, self.dim * size_of::<f32>())?;
196        let mut out = Vec::with_capacity(self.dim);
197        for (sub, &c) in code.iter().enumerate() {
198            out.extend_from_slice(&self.codebooks[sub][c as usize]);
199        }
200        Ok(out)
201    }
202
203    /// Serialize the codec to bytes with a versioned magic header.
204    ///
205    /// Format: `[NDPQ\0\0 (6 bytes)][version: u8 = 1][msgpack payload]`
206    ///
207    /// Charges the estimated serialized size to the governor (if set) before
208    /// allocating the output buffer.  The estimate is conservative:
209    /// `m * k * sub_dim * size_of::<f32>() + 64` (header + framing overhead).
210    pub fn to_bytes(&self) -> Result<Vec<u8>, VectorError> {
211        const MAGIC: &[u8; 6] = b"NDPQ\0\0";
212        const VERSION: u8 = 1;
213        let estimated = self.m * self.k * self.sub_dim * size_of::<f32>() + 64;
214        let _g = try_reserve_or_skip(&self.governor, estimated)?;
215        let payload = zerompk::to_msgpack_vec(self).unwrap_or_default();
216        let mut out = Vec::with_capacity(7 + payload.len());
217        out.extend_from_slice(MAGIC);
218        out.push(VERSION);
219        out.extend_from_slice(&payload);
220        Ok(out)
221    }
222
223    /// Deserialize the codec from bytes produced by [`Self::to_bytes`].
224    ///
225    /// Returns `VectorError::InvalidMagic` if the header does not match
226    /// `NDPQ\0\0`, and `VectorError::UnsupportedVersion` for unknown versions.
227    pub fn from_bytes(bytes: &[u8]) -> Result<Self, VectorError> {
228        const MAGIC: &[u8; 6] = b"NDPQ\0\0";
229        const PQ_FORMAT_VERSION: u8 = 1;
230
231        if bytes.len() < 7 || &bytes[0..6] != MAGIC {
232            return Err(VectorError::InvalidMagic);
233        }
234        let version = bytes[6];
235        if version != PQ_FORMAT_VERSION {
236            return Err(VectorError::UnsupportedVersion {
237                found: version,
238                expected: PQ_FORMAT_VERSION,
239            });
240        }
241        zerompk::from_msgpack::<Self>(&bytes[7..])
242            .map_err(|e| VectorError::DeserializationFailed(e.to_string()))
243    }
244
245    fn nearest_centroid(&self, subspace: usize, sub_vec: &[f32]) -> usize {
246        let mut best_idx = 0;
247        let mut best_dist = f32::MAX;
248        for (i, centroid) in self.codebooks[subspace].iter().enumerate() {
249            let d = l2_sub(sub_vec, centroid);
250            if d < best_dist {
251                best_dist = d;
252                best_idx = i;
253            }
254        }
255        best_idx
256    }
257}
258
/// Squared Euclidean distance between `a` and the first `a.len()` components
/// of `b` (shared by k-means training and encoding).
///
/// Components of `b` beyond `a.len()` are ignored.
///
/// # Panics
///
/// Panics if `b` is shorter than `a`.
#[inline]
fn l2_sub(a: &[f32], b: &[f32]) -> f32 {
    // Slicing up front keeps the "b must cover a" panic while letting the
    // loop itself run without per-element bounds checks.
    let b = &b[..a.len()];
    a.iter().zip(b).fold(0.0f32, |acc, (&x, &y)| {
        let diff = x - y;
        acc + diff * diff
    })
}
269
270/// Simple k-means clustering for PQ codebook training.
271///
272/// Uses proper k-means++ initialization (weighted d² sampling) with a
273/// deterministic seed so training is reproducible across runs.
274fn kmeans(data: &[&[f32]], dim: usize, k: usize, max_iter: usize) -> Vec<Vec<f32>> {
275    let n = data.len();
276    if n == 0 || k == 0 {
277        return Vec::new();
278    }
279    let k = k.min(n); // Can't have more centroids than data points.
280
281    // K-means++ initialization with deterministic xorshift.
282    let mut rng = crate::hnsw::Xorshift64::new(0xC0FF_EEDE_ADBE_EF42);
283
284    let mut centroids: Vec<Vec<f32>> = Vec::with_capacity(k);
285    centroids.push(data[0].to_vec());
286
287    let mut min_dists = vec![f32::MAX; n];
288    // Update against the first centroid.
289    for (i, point) in data.iter().enumerate() {
290        let d = l2_sub(point, &centroids[0]);
291        if d < min_dists[i] {
292            min_dists[i] = d;
293        }
294    }
295
296    for _ in 1..k {
297        let total: f64 = min_dists.iter().map(|&d| d as f64).sum();
298        let next_idx = if total < f64::EPSILON {
299            // All points coincide with existing centroids.
300            0
301        } else {
302            let target = rng.next_f64() * total;
303            let mut acc = 0.0f64;
304            let mut chosen = n - 1;
305            for (i, &d) in min_dists.iter().enumerate() {
306                acc += d as f64;
307                if acc >= target {
308                    chosen = i;
309                    break;
310                }
311            }
312            chosen
313        };
314        centroids.push(data[next_idx].to_vec());
315        // Incrementally update min_dists against the new centroid.
316        let last = centroids.last().expect("just pushed");
317        for (i, point) in data.iter().enumerate() {
318            let d = l2_sub(point, last);
319            if d < min_dists[i] {
320                min_dists[i] = d;
321            }
322        }
323    }
324
325    // K-means iterations.
326    let mut assignments = vec![0usize; n];
327    for _ in 0..max_iter {
328        // Assignment step.
329        let mut changed = false;
330        for (i, point) in data.iter().enumerate() {
331            let mut best = 0;
332            let mut best_d = f32::MAX;
333            for (c, centroid) in centroids.iter().enumerate() {
334                let d = l2_sub(point, centroid);
335                if d < best_d {
336                    best_d = d;
337                    best = c;
338                }
339            }
340            if assignments[i] != best {
341                assignments[i] = best;
342                changed = true;
343            }
344        }
345        if !changed {
346            break;
347        }
348
349        // Update step: recompute centroids as means.
350        let mut sums = vec![vec![0.0f32; dim]; k];
351        let mut counts = vec![0usize; k];
352        for (i, point) in data.iter().enumerate() {
353            let c = assignments[i];
354            counts[c] += 1;
355            for d in 0..dim {
356                sums[c][d] += point[d];
357            }
358        }
359        for c in 0..k {
360            if counts[c] > 0 {
361                for d in 0..dim {
362                    centroids[c][d] = sums[c][d] / counts[c] as f32;
363                }
364            }
365        }
366    }
367
368    centroids
369}
370
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::*;

    /// Deterministic fixture: 4 clusters in 4D space, 50 points each.
    fn make_clustered_data() -> Vec<Vec<f32>> {
        (0..4)
            .flat_map(|cluster| {
                let center = cluster as f32 * 10.0;
                (0..50).map(move |i| {
                    vec![
                        center + (i as f32) * 0.1,
                        center + (i as f32) * 0.05,
                        center - (i as f32) * 0.1,
                        center + (i as f32) * 0.02,
                    ]
                })
            })
            .collect()
    }

    /// Borrow every owned vector as a slice, the shape `PqCodec` expects.
    fn as_slices(vecs: &[Vec<f32>]) -> Vec<&[f32]> {
        vecs.iter().map(Vec::as_slice).collect()
    }

    /// Train the configuration shared by the tests: dim=4, M=2, K=16, 10 iterations.
    fn train_fixture_codec(vecs: &[Vec<f32>]) -> PqCodec {
        PqCodec::train(&as_slices(vecs), 4, 2, 16, 10)
    }

    #[test]
    fn encode_decode_roundtrip() {
        let vecs = make_clustered_data();
        let codec = train_fixture_codec(&vecs);

        for v in &vecs {
            let code = codec.encode(v);
            assert_eq!(code.len(), 2); // M=2 bytes
            let decoded = codec.decode(&code).unwrap();
            assert_eq!(decoded.len(), 4);
        }
    }

    #[test]
    fn distance_table_gives_correct_ordering() {
        let vecs = make_clustered_data();
        let codec = train_fixture_codec(&vecs);

        let codes: Vec<Vec<u8>> = vecs.iter().map(|v| codec.encode(v)).collect();
        let query = [5.0f32; 4];
        let table = codec.build_distance_table(&query).unwrap();

        // Rank every vector by its PQ (asymmetric) distance.
        let mut pq_dists: Vec<(usize, f32)> = codes
            .iter()
            .map(|c| codec.asymmetric_distance(&table, c))
            .enumerate()
            .collect();
        pq_dists.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());

        // Rank every vector by its exact squared L2 distance.
        let mut exact_dists: Vec<(usize, f32)> = vecs
            .iter()
            .map(|v| l2_sub(&query, v))
            .enumerate()
            .collect();
        exact_dists.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());

        // Top-5 from PQ should have significant overlap with exact top-10.
        let pq_top: HashSet<usize> = pq_dists[..5].iter().map(|&(i, _)| i).collect();
        let exact_top: HashSet<usize> = exact_dists[..10].iter().map(|&(i, _)| i).collect();
        let overlap = pq_top.intersection(&exact_top).count();
        assert!(overlap >= 3, "PQ recall too low: {overlap}/5 in top-10");
    }

    #[test]
    fn batch_encode() {
        let vecs = make_clustered_data();
        let refs = as_slices(&vecs);
        let codec = PqCodec::train(&refs, 4, 2, 16, 10);

        let batch = codec.encode_batch(&refs).unwrap();
        assert_eq!(batch.len(), 2 * 200); // M=2, N=200
    }

    // Golden format test — verifies the on-disk layout is stable.
    #[test]
    fn pq_codec_golden_format() {
        let vecs = make_clustered_data();
        let codec = train_fixture_codec(&vecs);

        let bytes = codec.to_bytes().unwrap();
        let (header, payload) = bytes.split_at(7);

        // Magic header.
        assert_eq!(&header[0..6], b"NDPQ\0\0", "magic mismatch");
        // Version byte.
        assert_eq!(header[6], 1u8, "version must be 1");
        // Payload at offset 7 must decode back to a valid PqCodec.
        let restored = zerompk::from_msgpack::<PqCodec>(payload)
            .expect("msgpack payload at offset 7 must decode");
        assert_eq!(restored.dim, codec.dim);
        assert_eq!(restored.m, codec.m);
    }

    #[test]
    fn pq_version_mismatch_returns_error() {
        // Correct magic, version byte 0 (unsupported), then a minimal valid
        // msgpack map as payload.
        let crafted: Vec<u8> = [&b"NDPQ\0\0"[..], &[0u8][..], &b"\x80"[..]].concat();

        let err = PqCodec::from_bytes(&crafted).unwrap_err();
        assert!(
            matches!(
                err,
                VectorError::UnsupportedVersion {
                    found: 0,
                    expected: 1
                }
            ),
            "expected UnsupportedVersion, got: {err:?}"
        );
    }

    #[test]
    fn pq_invalid_magic_returns_error() {
        let bad: &[u8] = b"JUNK\0\0\x01some-payload";
        let err = PqCodec::from_bytes(bad).unwrap_err();
        assert!(
            matches!(err, VectorError::InvalidMagic),
            "expected InvalidMagic, got: {err:?}"
        );
    }
}