// nodedb_vector/quantize/pq.rs

// SPDX-License-Identifier: Apache-2.0

//! Product Quantization (PQ): 8-16x compression for large datasets.
//!
//! Splits D-dimensional vectors into M subvectors, clusters each subspace
//! with K=256 centroids via k-means. Each vector is encoded as M bytes
//! (one centroid index per subvector).
//!
//! Distance is computed via precomputed lookup tables: for each query,
//! build a `[M][K]` table of distances from the query's subvectors to
//! all centroids. Then the distance to any encoded vector is just M
//! table lookups + additions — O(M) per candidate vs O(D) for FP32.
//!
//! Trade-off: 2-5% recall loss vs SQ8's <1%, but 2-4x more compression
//! (8-16x total vs 4x for SQ8). Best for cost-sensitive large datasets.
17use std::mem::size_of;
18use std::sync::Arc;
19
20use nodedb_mem::{EngineId, MemoryGovernor};
21use serde::{Deserialize, Serialize};
22
23use crate::error::VectorError;
24
25/// Reserve `bytes` from `governor` for `EngineId::Vector`, or succeed silently
26/// when no governor is configured. The returned guard (if any) must be kept
27/// alive for the duration of the allocation it covers.
28#[inline]
29fn try_reserve_or_skip(
30    governor: &Option<Arc<MemoryGovernor>>,
31    bytes: usize,
32) -> Result<Option<nodedb_mem::BudgetGuard>, VectorError> {
33    match governor {
34        Some(g) => Ok(Some(g.reserve(EngineId::Vector, bytes)?)),
35        None => Ok(None),
36    }
37}
38
/// PQ codec with trained codebooks.
///
/// Produced by [`PqCodec::train`]; `encode`/`decode` and the distance-table
/// methods then operate against the stored codebooks.
#[derive(
    Clone, Debug, Serialize, Deserialize, zerompk::ToMessagePack, zerompk::FromMessagePack,
)]
pub struct PqCodec {
    /// Original vector dimensionality.
    pub dim: usize,
    /// Number of subvectors (subspaces).
    pub m: usize,
    /// Centroids per subvector (typically 256 so each index fits in a `u8`;
    /// `encode` casts indices with `as u8`, so `k` must not exceed 256).
    pub k: usize,
    /// Dimensions per subvector: `dim / m`.
    pub sub_dim: usize,
    /// Codebooks: `codebooks[subspace][centroid][sub_dim_component]`.
    /// Total: M × K × sub_dim floats.
    codebooks: Vec<Vec<Vec<f32>>>,

    /// Optional memory governor. Skipped during serialization — it is a
    /// runtime concern only, not part of the on-disk format.
    #[serde(skip, default)]
    #[msgpack(ignore)]
    governor: Option<Arc<MemoryGovernor>>,
}
62
impl PqCodec {
    /// Attach a memory governor to this codec.
    ///
    /// Once set, heap-significant operations (`encode_batch`,
    /// `build_distance_table`, `decode`, `to_bytes`) will charge the
    /// `EngineId::Vector` budget before allocating and release the reservation
    /// when the returned value is dropped (RAII).  When no governor is set
    /// those operations proceed unconditionally, preserving backward
    /// compatibility with callers that do not use the memory governor.
    /// (`train` is a constructor and always runs ungoverned — the governor
    /// can only be attached to an already-trained codec.)
    ///
    /// The governor is a runtime concern only — it is **not** serialized.
    pub fn with_governor(mut self, governor: Arc<MemoryGovernor>) -> Self {
        self.governor = Some(governor);
        self
    }

    /// Train PQ codebooks from a set of training vectors via k-means.
    ///
    /// `m` = number of subvectors (must divide `dim` evenly).
    /// `k` = centroids per subvector (typically 256).
    /// `max_iter` = k-means iterations (20 is usually sufficient).
    ///
    /// # Panics
    ///
    /// Panics if `vectors` is empty, if any of `dim`/`m`/`k` is zero, or if
    /// `m` does not divide `dim`. Slicing also panics if any training vector
    /// is shorter than `dim`.
    pub fn train(vectors: &[&[f32]], dim: usize, m: usize, k: usize, max_iter: usize) -> Self {
        assert!(!vectors.is_empty());
        assert!(dim > 0 && m > 0 && k > 0);
        assert!(
            dim.is_multiple_of(m),
            "dim ({dim}) must be divisible by m ({m})"
        );

        let sub_dim = dim / m;
        // no-governor: cold codebook training; m subspaces (small, ≤ dim/sub_dim), one-time build
        let mut codebooks = Vec::with_capacity(m);

        for sub in 0..m {
            let offset = sub * sub_dim;
            // Extract sub-vectors for this subspace.
            let sub_vectors: Vec<&[f32]> = vectors
                .iter()
                .map(|v| &v[offset..offset + sub_dim])
                .collect();

            // One independent codebook (k centroids of sub_dim floats) per subspace.
            let centroids = kmeans(&sub_vectors, sub_dim, k, max_iter);
            codebooks.push(centroids);
        }

        Self {
            dim,
            m,
            k,
            sub_dim,
            codebooks,
            // Runtime-only; attach later via `with_governor` if desired.
            governor: None,
        }
    }

    /// Encode a vector: for each subvector, find the nearest centroid index.
    ///
    /// Returns `m` bytes, one centroid index per subspace. The `as u8` cast
    /// assumes `k <= 256`; larger indices would be truncated silently.
    ///
    /// This is a per-vector hot-path operation.  Governor charging is
    /// intentionally skipped here to avoid atomic overhead on every candidate
    /// during search; use [`encode_batch`] for bulk encoding with budget
    /// enforcement.
    pub fn encode(&self, vector: &[f32]) -> Vec<u8> {
        debug_assert_eq!(vector.len(), self.dim);
        // no-governor: hot-path per-vector encode; doc comment above intentionally skips governor
        let mut code = Vec::with_capacity(self.m);
        for sub in 0..self.m {
            let offset = sub * self.sub_dim;
            let sub_vec = &vector[offset..offset + self.sub_dim];
            let nearest = self.nearest_centroid(sub, sub_vec);
            code.push(nearest as u8);
        }
        code
    }

    /// Batch encode all vectors into a contiguous byte array.
    ///
    /// Charges `m * vectors.len()` bytes to the governor budget (if set)
    /// before allocating the output buffer.  The guard is released at
    /// the end of this call — the buffer itself remains alive.
    ///
    /// # Errors
    ///
    /// Returns an error if the governor rejects the budget reservation.
    pub fn encode_batch(&self, vectors: &[&[f32]]) -> Result<Vec<u8>, VectorError> {
        let capacity = self.m * vectors.len();
        let _g = try_reserve_or_skip(&self.governor, capacity * size_of::<u8>())?;
        // no-governor: governed via try_reserve_or_skip on preceding line
        let mut out = Vec::with_capacity(capacity);
        for v in vectors {
            out.extend(self.encode(v));
        }
        Ok(out)
    }

    /// Build an asymmetric distance table for a query vector.
    ///
    /// Returns `table[sub][centroid]` = distance from query's sub-vector
    /// to each centroid. Pre-computing this table makes distance evaluation
    /// O(M) per candidate instead of O(D).
    ///
    /// Charges `m * k * size_of::<f32>()` bytes to the governor (if set)
    /// before allocating the table.
    ///
    /// # Errors
    ///
    /// Returns an error if the governor rejects the budget reservation.
    pub fn build_distance_table(&self, query: &[f32]) -> Result<Vec<Vec<f32>>, VectorError> {
        debug_assert_eq!(query.len(), self.dim);
        let total_bytes = self.m * self.k * size_of::<f32>();
        let _g = try_reserve_or_skip(&self.governor, total_bytes)?;
        // no-governor: governed via try_reserve_or_skip on preceding line
        let mut table = Vec::with_capacity(self.m);
        for sub in 0..self.m {
            let offset = sub * self.sub_dim;
            let sub_query = &query[offset..offset + self.sub_dim];
            // no-governor: inner per-subspace vec; covered by outer reservation m*k*size_of::<f32>()
            let mut dists = Vec::with_capacity(self.k);
            for centroid in &self.codebooks[sub] {
                let d = l2_sub(sub_query, centroid);
                dists.push(d);
            }
            table.push(dists);
        }
        Ok(table)
    }

    /// Compute asymmetric distance using a precomputed distance table.
    ///
    /// O(M) per candidate — just M table lookups and additions.
    ///
    /// `table` must come from [`Self::build_distance_table`] on this codec:
    /// indexing panics if `table` has fewer than `m` rows or a code byte
    /// exceeds a row's length.
    #[inline]
    pub fn asymmetric_distance(&self, table: &[Vec<f32>], code: &[u8]) -> f32 {
        debug_assert_eq!(code.len(), self.m);
        let mut dist = 0.0f32;
        for (sub, &c) in code.iter().enumerate() {
            dist += table[sub][c as usize];
        }
        dist
    }

    /// Decode a PQ code back to an approximate FP32 vector.
    ///
    /// Each code byte selects a centroid; the centroids are concatenated to
    /// reconstruct a `dim`-length approximation of the original vector.
    ///
    /// Charges `dim * size_of::<f32>()` bytes to the governor (if set)
    /// before allocating the output buffer.
    ///
    /// # Errors
    ///
    /// Returns an error if the governor rejects the budget reservation.
    pub fn decode(&self, code: &[u8]) -> Result<Vec<f32>, VectorError> {
        debug_assert_eq!(code.len(), self.m);
        let _g = try_reserve_or_skip(&self.governor, self.dim * size_of::<f32>())?;
        // no-governor: governed via try_reserve_or_skip on preceding line
        let mut out = Vec::with_capacity(self.dim);
        for (sub, &c) in code.iter().enumerate() {
            out.extend_from_slice(&self.codebooks[sub][c as usize]);
        }
        Ok(out)
    }

    /// Serialize the codec to bytes with a versioned magic header.
    ///
    /// Format: `[NDPQ\0\0 (6 bytes)][version: u8 = 1][msgpack payload]`
    ///
    /// Charges the estimated serialized size to the governor (if set) before
    /// allocating the output buffer.  The estimate is conservative:
    /// `m * k * sub_dim * size_of::<f32>() + 64` (header + framing overhead).
    pub fn to_bytes(&self) -> Result<Vec<u8>, VectorError> {
        const MAGIC: &[u8; 6] = b"NDPQ\0\0";
        const VERSION: u8 = 1;
        let estimated = self.m * self.k * self.sub_dim * size_of::<f32>() + 64;
        let _g = try_reserve_or_skip(&self.governor, estimated)?;
        // NOTE(review): a msgpack serialization failure is silently collapsed
        // to an empty payload here, producing a header-only blob that only
        // fails later in `from_bytes`. Presumably encoding this type cannot
        // fail — confirm, or propagate the error instead of defaulting.
        let payload = zerompk::to_msgpack_vec(self).unwrap_or_default();
        // no-governor: governed via try_reserve_or_skip on preceding line
        let mut out = Vec::with_capacity(7 + payload.len());
        out.extend_from_slice(MAGIC);
        out.push(VERSION);
        out.extend_from_slice(&payload);
        Ok(out)
    }

    /// Deserialize the codec from bytes produced by [`Self::to_bytes`].
    ///
    /// # Errors
    ///
    /// Returns `VectorError::InvalidMagic` if the header does not match
    /// `NDPQ\0\0` (or the input is shorter than the 7-byte header),
    /// `VectorError::UnsupportedVersion` for unknown versions, and
    /// `VectorError::DeserializationFailed` if the msgpack payload is invalid.
    pub fn from_bytes(bytes: &[u8]) -> Result<Self, VectorError> {
        const MAGIC: &[u8; 6] = b"NDPQ\0\0";
        const PQ_FORMAT_VERSION: u8 = 1;

        if bytes.len() < 7 || &bytes[0..6] != MAGIC {
            return Err(VectorError::InvalidMagic);
        }
        let version = bytes[6];
        if version != PQ_FORMAT_VERSION {
            return Err(VectorError::UnsupportedVersion {
                found: version,
                expected: PQ_FORMAT_VERSION,
            });
        }
        zerompk::from_msgpack::<Self>(&bytes[7..])
            .map_err(|e| VectorError::DeserializationFailed(e.to_string()))
    }

    /// Index of the centroid in `subspace` closest (squared L2) to `sub_vec`.
    /// Linear scan; the first centroid wins on exact ties (strict `<`).
    fn nearest_centroid(&self, subspace: usize, sub_vec: &[f32]) -> usize {
        let mut best_idx = 0;
        let mut best_dist = f32::MAX;
        for (i, centroid) in self.codebooks[subspace].iter().enumerate() {
            let d = l2_sub(sub_vec, centroid);
            if d < best_dist {
                best_dist = d;
                best_idx = i;
            }
        }
        best_idx
    }
}
265
/// Squared Euclidean (L2²) distance between two equal-length sub-vectors.
/// Used by k-means training, encoding, and distance-table construction.
#[inline]
fn l2_sub(a: &[f32], b: &[f32]) -> f32 {
    let mut acc = 0.0f32;
    for (i, &x) in a.iter().enumerate() {
        // Indexing `b` preserves the original panic if `b` is shorter than `a`.
        let diff = x - b[i];
        acc += diff * diff;
    }
    acc
}
276
277/// Simple k-means clustering for PQ codebook training.
278///
279/// Uses proper k-means++ initialization (weighted d² sampling) with a
280/// deterministic seed so training is reproducible across runs.
281fn kmeans(data: &[&[f32]], dim: usize, k: usize, max_iter: usize) -> Vec<Vec<f32>> {
282    let n = data.len();
283    if n == 0 || k == 0 {
284        return Vec::new();
285    }
286    let k = k.min(n); // Can't have more centroids than data points.
287
288    // K-means++ initialization with deterministic xorshift.
289    let mut rng = crate::hnsw::Xorshift64::new(0xC0FF_EEDE_ADBE_EF42);
290
291    // no-governor: cold k-means++ training; one-time codebook build, governed at call site
292    let mut centroids: Vec<Vec<f32>> = Vec::with_capacity(k);
293    centroids.push(data[0].to_vec());
294
295    let mut min_dists = vec![f32::MAX; n];
296    // Update against the first centroid.
297    for (i, point) in data.iter().enumerate() {
298        let d = l2_sub(point, &centroids[0]);
299        if d < min_dists[i] {
300            min_dists[i] = d;
301        }
302    }
303
304    for _ in 1..k {
305        let total: f64 = min_dists.iter().map(|&d| d as f64).sum();
306        let next_idx = if total < f64::EPSILON {
307            // All points coincide with existing centroids.
308            0
309        } else {
310            let target = rng.next_f64() * total;
311            let mut acc = 0.0f64;
312            let mut chosen = n - 1;
313            for (i, &d) in min_dists.iter().enumerate() {
314                acc += d as f64;
315                if acc >= target {
316                    chosen = i;
317                    break;
318                }
319            }
320            chosen
321        };
322        centroids.push(data[next_idx].to_vec());
323        // Incrementally update min_dists against the new centroid.
324        let last = centroids.last().expect("just pushed");
325        for (i, point) in data.iter().enumerate() {
326            let d = l2_sub(point, last);
327            if d < min_dists[i] {
328                min_dists[i] = d;
329            }
330        }
331    }
332
333    // K-means iterations.
334    let mut assignments = vec![0usize; n];
335    for _ in 0..max_iter {
336        // Assignment step.
337        let mut changed = false;
338        for (i, point) in data.iter().enumerate() {
339            let mut best = 0;
340            let mut best_d = f32::MAX;
341            for (c, centroid) in centroids.iter().enumerate() {
342                let d = l2_sub(point, centroid);
343                if d < best_d {
344                    best_d = d;
345                    best = c;
346                }
347            }
348            if assignments[i] != best {
349                assignments[i] = best;
350                changed = true;
351            }
352        }
353        if !changed {
354            break;
355        }
356
357        // Update step: recompute centroids as means.
358        let mut sums = vec![vec![0.0f32; dim]; k];
359        let mut counts = vec![0usize; k];
360        for (i, point) in data.iter().enumerate() {
361            let c = assignments[i];
362            counts[c] += 1;
363            for d in 0..dim {
364                sums[c][d] += point[d];
365            }
366        }
367        for c in 0..k {
368            if counts[c] > 0 {
369                for d in 0..dim {
370                    centroids[c][d] = sums[c][d] / counts[c] as f32;
371                }
372            }
373        }
374    }
375
376    centroids
377}
378
#[cfg(test)]
mod tests {
    use super::*;

    /// 200 points in 4D: four well-separated clusters of 50 points each,
    /// centered at 0, 10, 20, 30 with small per-point offsets.
    fn make_clustered_data() -> Vec<Vec<f32>> {
        (0..4)
            .flat_map(|cluster| {
                let center = cluster as f32 * 10.0;
                (0..50).map(move |i| {
                    let t = i as f32;
                    vec![
                        center + t * 0.1,
                        center + t * 0.05,
                        center - t * 0.1,
                        center + t * 0.02,
                    ]
                })
            })
            .collect()
    }

    #[test]
    fn encode_decode_roundtrip() {
        let vecs = make_clustered_data();
        let refs: Vec<&[f32]> = vecs.iter().map(Vec::as_slice).collect();
        let codec = PqCodec::train(&refs, 4, 2, 16, 10);

        for v in &vecs {
            let code = codec.encode(v);
            // M = 2 subvectors → 2-byte code.
            assert_eq!(code.len(), 2);
            // Decoding restores the original dimensionality.
            assert_eq!(codec.decode(&code).unwrap().len(), 4);
        }
    }

    #[test]
    fn distance_table_gives_correct_ordering() {
        let vecs = make_clustered_data();
        let refs: Vec<&[f32]> = vecs.iter().map(Vec::as_slice).collect();
        let codec = PqCodec::train(&refs, 4, 2, 16, 10);

        let codes: Vec<Vec<u8>> = vecs.iter().map(|v| codec.encode(v)).collect();
        let query = &[5.0, 5.0, 5.0, 5.0];
        let table = codec.build_distance_table(query).unwrap();

        // Rank all candidates by PQ (approximate) distance.
        let mut by_pq: Vec<(usize, f32)> = codes
            .iter()
            .map(|c| codec.asymmetric_distance(&table, c))
            .enumerate()
            .collect();
        by_pq.sort_by(|a, b| a.1.total_cmp(&b.1));

        // Rank all candidates by exact L2² distance.
        let mut by_exact: Vec<(usize, f32)> = vecs
            .iter()
            .map(|v| l2_sub(query, v))
            .enumerate()
            .collect();
        by_exact.sort_by(|a, b| a.1.total_cmp(&b.1));

        // PQ's top-5 should mostly land inside the exact top-10.
        let pq_top: std::collections::HashSet<usize> = by_pq[..5].iter().map(|x| x.0).collect();
        let exact_top: std::collections::HashSet<usize> =
            by_exact[..10].iter().map(|x| x.0).collect();
        let overlap = pq_top.intersection(&exact_top).count();
        assert!(overlap >= 3, "PQ recall too low: {overlap}/5 in top-10");
    }

    #[test]
    fn batch_encode() {
        let vecs = make_clustered_data();
        let refs: Vec<&[f32]> = vecs.iter().map(Vec::as_slice).collect();
        let codec = PqCodec::train(&refs, 4, 2, 16, 10);

        // N = 200 vectors × M = 2 bytes each.
        assert_eq!(codec.encode_batch(&refs).unwrap().len(), 400);
    }

    // golden format test — verifies the on-disk layout is stable.
    #[test]
    fn pq_codec_golden_format() {
        let vecs = make_clustered_data();
        let refs: Vec<&[f32]> = vecs.iter().map(Vec::as_slice).collect();
        let codec = PqCodec::train(&refs, 4, 2, 16, 10);

        let bytes = codec.to_bytes().unwrap();

        // Magic header, then the version byte.
        assert_eq!(&bytes[0..6], b"NDPQ\0\0", "magic mismatch");
        assert_eq!(bytes[6], 1u8, "version must be 1");
        // Everything after the 7-byte header is the msgpack payload.
        let restored = zerompk::from_msgpack::<PqCodec>(&bytes[7..])
            .expect("msgpack payload at offset 7 must decode");
        assert_eq!(restored.dim, codec.dim);
        assert_eq!(restored.m, codec.m);
    }

    #[test]
    fn pq_version_mismatch_returns_error() {
        // Correct magic, unsupported version byte (0), minimal valid payload.
        let mut crafted = b"NDPQ\0\0".to_vec();
        crafted.push(0u8); // wrong version
        crafted.extend_from_slice(b"\x80"); // minimal valid msgpack map

        let err = PqCodec::from_bytes(&crafted).unwrap_err();
        let is_version_err = matches!(
            err,
            VectorError::UnsupportedVersion {
                found: 0,
                expected: 1
            }
        );
        assert!(is_version_err, "expected UnsupportedVersion, got: {err:?}");
    }

    #[test]
    fn pq_invalid_magic_returns_error() {
        let bad: &[u8] = b"JUNK\0\0\x01some-payload";
        match PqCodec::from_bytes(bad) {
            Err(VectorError::InvalidMagic) => {}
            other => panic!("expected InvalidMagic, got: {other:?}"),
        }
    }
}