Skip to main content

nodedb_vector/rerank/codecs/
pq.rs

1// SPDX-License-Identifier: Apache-2.0
2
3//! `RerankCodec` wrapper for Product Quantization (PQ).
4//!
5//! PQ is a training-based codec: `train()` runs k-means over a sample of
6//! vectors to learn per-subspace codebooks. Until training is complete,
7//! `encode` and `prepare_query` return `RerankError::NotTrained`.
8//!
9//! Distance uses the ADC (Asymmetric Distance Computation) model: the query
10//! is kept in FP32 and a per-subspace lookup table is precomputed once via
11//! `prepare_query`; each candidate lookup is then O(M) table additions.
12//! The prepared form maps directly to `PreparedQuery::Lut` — the existing
13//! variant already holds `Vec<Vec<f32>>` which is exactly the PQ distance
14//! table (`lut[sub][centroid]`).
15
16use nodedb_codec::vector_quant::layout::UnifiedQuantizedVectorRef;
17
18use crate::{
19    quantize::pq::PqCodec,
20    rerank::codec::{CodecName, PreparedQuery, RerankCodec},
21    rerank::types::RerankError,
22};
23
24// ── packed_bits_len helper ────────────────────────────────────────────────────
25
/// Length in bytes of the packed code section of a PQ-encoded vector.
///
/// Every one of the `m` subspaces contributes exactly one `u8` centroid
/// index, so the packed section is `m` bytes long.
#[inline]
fn pq_packed_bits_len(m: usize) -> usize {
    let bytes_per_subspace = core::mem::size_of::<u8>();
    m * bytes_per_subspace
}
31
32// ── PqRerank ──────────────────────────────────────────────────────────────────
33
34/// Object-safe `RerankCodec` wrapper around `PqCodec`.
35///
36/// The codec starts untrained. `encode` and `prepare_query` return
37/// `RerankError::NotTrained` until `train()` has been called with a
38/// representative sample of vectors.
39///
40/// `from_codec` accepts a pre-trained `PqCodec` (used when restoring from
41/// a snapshot).
42pub struct PqRerank {
43    codec: Option<PqCodec>,
44    dim: usize,
45    m: usize,
46    k: usize,
47    max_iter: usize,
48}
49
50impl PqRerank {
51    /// Construct an untrained PQ codec configuration.
52    ///
53    /// `m` is the number of subspaces; `k` is centroids per subspace.
54    /// Defaults used by higher-level callers: `m = 8`, `k = 256`.
55    /// `encode` / `distance_prepared` return `RerankError::NotTrained` until
56    /// `train()` has been called.
57    pub fn new(dim: usize, m: usize, k: usize) -> Self {
58        Self {
59            codec: None,
60            dim,
61            m,
62            k,
63            max_iter: 25,
64        }
65    }
66
67    /// Construct from a pre-trained codec (used when restoring from snapshot).
68    pub fn from_codec(codec: PqCodec) -> Self {
69        let dim = codec.dim;
70        let m = codec.m;
71        let k = codec.k;
72        Self {
73            codec: Some(codec),
74            dim,
75            m,
76            k,
77            max_iter: 25,
78        }
79    }
80}
81
82impl RerankCodec for PqRerank {
83    /// Encode a full-precision vector to PQ bytes (one centroid index per subspace).
84    ///
85    /// The serialized form is the raw `UnifiedQuantizedVector` buffer
86    /// (`as_bytes()`): 32-byte `QuantHeader` followed by `m` code bytes.
87    fn encode(&self, v: &[f32]) -> Result<Vec<u8>, RerankError> {
88        if v.len() != self.dim {
89            return Err(RerankError::BadInput(format!(
90                "pq encode: vector len {} != codec dim {}",
91                v.len(),
92                self.dim
93            )));
94        }
95        let codec = self.codec.as_ref().ok_or_else(|| {
96            RerankError::NotTrained(
97                "pq: codec must be trained before encoding (call train() with a sample of vectors)"
98                    .to_string(),
99            )
100        })?;
101        use nodedb_codec::vector_quant::codec::VectorCodec;
102        let quantized = <PqCodec as VectorCodec>::encode(codec, v);
103        Ok(quantized.as_ref().as_bytes().to_vec())
104    }
105
106    /// Prepare the query by precomputing the M×K asymmetric distance table.
107    ///
108    /// The prepared form is `PreparedQuery::Lut` where `lut[sub][centroid]`
109    /// holds the squared L2 distance from the query's sub-vector to each
110    /// centroid of subspace `sub`. This is the standard ADC lookup table.
111    fn prepare_query(&self, q: &[f32]) -> Result<PreparedQuery, RerankError> {
112        if q.len() != self.dim {
113            return Err(RerankError::BadInput(format!(
114                "pq prepare_query: query len {} != codec dim {}",
115                q.len(),
116                self.dim
117            )));
118        }
119        let codec = self.codec.as_ref().ok_or_else(|| {
120            RerankError::NotTrained(
121                "pq: codec must be trained before prepare_query (call train() with a sample of vectors)"
122                    .to_string(),
123            )
124        })?;
125        use nodedb_codec::vector_quant::codec::VectorCodec;
126        let pq_query = <PqCodec as VectorCodec>::prepare_query(codec, q);
127        Ok(PreparedQuery::Lut(pq_query.distance_table))
128    }
129
130    /// Compute asymmetric ADC distance from a prepared query to a PQ-encoded
131    /// candidate.
132    ///
133    /// Expects `PreparedQuery::Lut` produced by `prepare_query`.
134    fn distance_prepared(
135        &self,
136        prepared: &PreparedQuery,
137        encoded: &[u8],
138    ) -> Result<f32, RerankError> {
139        let lut = match prepared {
140            PreparedQuery::Lut(t) => t,
141            _ => {
142                return Err(RerankError::BadInput(
143                    "pq distance: expected PreparedQuery::Lut".to_string(),
144                ));
145            }
146        };
147
148        let packed_len = pq_packed_bits_len(self.m);
149        let uqv_ref = UnifiedQuantizedVectorRef::from_bytes(encoded, packed_len).map_err(|e| {
150            RerankError::BadInput(format!("pq distance: failed to parse encoded bytes: {e}"))
151        })?;
152
153        let packed = uqv_ref.packed_bits();
154        // ADC: sum lut[sub][code[sub]] for each subspace.
155        let dist = packed
156            .iter()
157            .enumerate()
158            .map(|(sub, &code)| {
159                lut.get(sub)
160                    .and_then(|row| row.get(code as usize).copied())
161                    .unwrap_or(0.0)
162            })
163            .sum();
164        Ok(dist)
165    }
166
167    fn name(&self) -> CodecName {
168        CodecName::Pq
169    }
170
171    fn to_bytes(&self) -> Result<Vec<u8>, RerankError> {
172        let codec = self.codec.as_ref().ok_or_else(|| {
173            RerankError::NotTrained("pq sidecar serialize: codec not trained".to_string())
174        })?;
175        codec
176            .to_bytes()
177            .map_err(|e| RerankError::BadInput(format!("pq to_bytes: {e}")))
178    }
179
180    /// Train PQ codebooks via k-means on a sample of vectors.
181    ///
182    /// Validates that:
183    /// - `samples` is non-empty.
184    /// - Every sample has length `self.dim`.
185    /// - `self.dim % self.m == 0` (PQ requires divisible dimensionality).
186    /// - At least `self.k` samples are provided (k-means needs ≥ k points).
187    ///
188    /// On success, stores the trained codec and subsequent `encode` /
189    /// `distance_prepared` calls will succeed.
190    fn train(&mut self, samples: &[&[f32]]) -> Result<(), RerankError> {
191        if samples.is_empty() {
192            return Err(RerankError::BadInput(
193                "pq train: empty sample set".to_string(),
194            ));
195        }
196        for s in samples {
197            if s.len() != self.dim {
198                return Err(RerankError::BadInput(format!(
199                    "pq train: sample has len {} but codec dim is {}",
200                    s.len(),
201                    self.dim
202                )));
203            }
204        }
205        if !self.dim.is_multiple_of(self.m) {
206            return Err(RerankError::BadInput(format!(
207                "pq train: dim ({}) must be divisible by m ({})",
208                self.dim, self.m
209            )));
210        }
211        if samples.len() < self.k {
212            return Err(RerankError::BadInput(format!(
213                "pq train: need >= k samples for k-means, got {}",
214                samples.len()
215            )));
216        }
217        let codec = PqCodec::train(samples, self.dim, self.m, self.k, self.max_iter);
218        self.codec = Some(codec);
219        Ok(())
220    }
221}
222
223// ── Tests ─────────────────────────────────────────────────────────────────────
224
#[cfg(test)]
mod tests {
    use super::*;

    const DIM: usize = 32;
    const M: usize = 4;
    const K: usize = 8;
    const N: usize = 64;

    /// Deterministic pseudo-vector number `i` of length `dim`, values in [0, 0.99].
    fn det_vec(i: usize, dim: usize) -> Vec<f32> {
        (0..dim)
            .map(|j| ((i * 31 + j) % 100) as f32 / 100.0)
            .collect()
    }

    /// A codec trained on `N` deterministic vectors of length `DIM`.
    fn trained() -> PqRerank {
        let vecs: Vec<Vec<f32>> = (0..N).map(|i| det_vec(i, DIM)).collect();
        let refs: Vec<&[f32]> = vecs.iter().map(|v| v.as_slice()).collect();
        let mut codec = PqRerank::new(DIM, M, K);
        codec.train(&refs).expect("train must succeed");
        codec
    }

    #[test]
    fn train_then_encode_roundtrip() {
        let codec = trained();
        let v = det_vec(0, DIM);
        let enc = codec.encode(&v).expect("encode");
        let prep = codec.prepare_query(&v).expect("prepare_query");
        let dist = codec.distance_prepared(&prep, &enc).expect("distance");
        assert!(dist.is_finite(), "distance must be finite, got {dist}");
        assert!(dist >= 0.0, "distance must be non-negative, got {dist}");
        // Self-distance should be small for ADC on identical vector.
        assert!(dist < 1.0, "self-distance too large: {dist}");
    }

    #[test]
    fn encode_before_train_returns_not_trained() {
        let codec = PqRerank::new(DIM, M, K);
        let err = codec.encode(&det_vec(0, DIM)).unwrap_err();
        assert!(
            matches!(err, RerankError::NotTrained(_)),
            "expected NotTrained, got: {err}"
        );
    }

    #[test]
    fn prepare_query_before_train_returns_not_trained() {
        let codec = PqRerank::new(DIM, M, K);
        let Err(err) = codec.prepare_query(&det_vec(0, DIM)) else {
            panic!("prepare_query must fail before train()");
        };
        assert!(
            matches!(err, RerankError::NotTrained(_)),
            "expected NotTrained, got: {err}"
        );
    }

    #[test]
    fn to_bytes_before_train_returns_not_trained() {
        let codec = PqRerank::new(DIM, M, K);
        let err = codec.to_bytes().unwrap_err();
        assert!(
            matches!(err, RerankError::NotTrained(_)),
            "expected NotTrained, got: {err}"
        );
    }

    #[test]
    fn encode_with_wrong_dim_returns_bad_input() {
        let codec = trained();
        let err = codec.encode(&det_vec(0, DIM + 1)).unwrap_err();
        assert!(
            matches!(err, RerankError::BadInput(_)),
            "expected BadInput, got: {err}"
        );
    }

    #[test]
    fn train_with_wrong_dim_sample_fails() {
        let vecs: Vec<Vec<f32>> = (0..N).map(|i| det_vec(i, DIM)).collect();
        let mut refs: Vec<&[f32]> = vecs.iter().map(|v| v.as_slice()).collect();
        let bad = det_vec(0, DIM + 4);
        refs.push(bad.as_slice());
        let mut codec = PqRerank::new(DIM, M, K);
        let err = codec.train(&refs).unwrap_err();
        assert!(
            matches!(err, RerankError::BadInput(_)),
            "expected BadInput, got: {err}"
        );
    }

    #[test]
    fn train_with_indivisible_dim_fails() {
        // dim=33, m=4: 33 % 4 != 0
        let vecs: Vec<Vec<f32>> = (0..16).map(|i| det_vec(i, 33)).collect();
        let refs: Vec<&[f32]> = vecs.iter().map(|v| v.as_slice()).collect();
        let mut codec = PqRerank::new(33, 4, 8);
        let err = codec.train(&refs).unwrap_err();
        let msg = format!("{err}");
        assert!(
            matches!(err, RerankError::BadInput(_)) && msg.contains("divisible"),
            "expected divisibility error, got: {msg}"
        );
    }

    #[test]
    fn train_with_too_few_samples_fails() {
        // k=8 but only 4 samples
        let vecs: Vec<Vec<f32>> = (0..4).map(|i| det_vec(i, DIM)).collect();
        let refs: Vec<&[f32]> = vecs.iter().map(|v| v.as_slice()).collect();
        let mut codec = PqRerank::new(DIM, M, 8);
        let err = codec.train(&refs).unwrap_err();
        let msg = format!("{err}");
        assert!(
            matches!(err, RerankError::BadInput(_)) && msg.contains("k samples"),
            "expected sample count error, got: {msg}"
        );
    }

    #[test]
    fn name_is_pq() {
        let codec = PqRerank::new(DIM, M, K);
        assert_eq!(codec.name(), CodecName::Pq);
    }
}