Skip to main content

nodedb_vector/quantize/
pq_codec.rs

1// SPDX-License-Identifier: Apache-2.0
2
3//! `VectorCodec` implementation for `PqCodec`.
4//!
5//! Wraps the existing Product Quantization codec as a dual-phase codec.
6//! The `Quantized` newtype holds a `UnifiedQuantizedVector` with
7//! `QuantMode::Pq` and M code bytes in the packed-bits region.
//! The `Query` type carries both the precomputed per-subspace distance table
//! (`Vec<Vec<f32>>`) and a copy of the original FP32 query; symmetric distance
//! is computed between two encoded vectors and does not read the query.
10
11use nodedb_codec::vector_quant::{
12    codec::{AdcLut, VectorCodec},
13    layout::{QuantHeader, QuantMode, UnifiedQuantizedVector},
14};
15
16use crate::quantize::pq::PqCodec;
17
18// ── Newtype ───────────────────────────────────────────────────────────────────
19
20/// Thin newtype wrapping `UnifiedQuantizedVector` for PQ-encoded vectors.
21pub struct PqQuantized(pub UnifiedQuantizedVector);
22
23impl AsRef<UnifiedQuantizedVector> for PqQuantized {
24    #[inline]
25    fn as_ref(&self) -> &UnifiedQuantizedVector {
26        &self.0
27    }
28}
29
30// ── Query type ────────────────────────────────────────────────────────────────
31
/// Prepared query for PQ asymmetric distance computation.
///
/// `distance_table[sub][centroid]` holds the precomputed L2 squared distance
/// from the query's sub-vector in `sub` to each of the K centroids.
///
/// `raw` is a copy of the original FP32 query as passed to `prepare_query`.
/// Symmetric distance (`fast_symmetric_distance`) compares two *encoded*
/// vectors and does not read this field. It is presumably retained for
/// callers that need the exact query (e.g. an FP32 rerank) — confirm before
/// relying on it elsewhere.
#[derive(Debug, Clone, PartialEq)]
pub struct PqQuery {
    /// Per-subspace distance table: `distance_table[sub][centroid] -> f32`.
    pub distance_table: Vec<Vec<f32>>,
    /// Copy of the original FP32 query.
    pub raw: Vec<f32>,
}
43
44// ── Helper ────────────────────────────────────────────────────────────────────
45
46#[inline]
47fn packed_bits_of(q: &PqQuantized) -> &[u8] {
48    q.0.packed_bits()
49}
50
51// ── VectorCodec impl ──────────────────────────────────────────────────────────
52
53impl VectorCodec for PqCodec {
54    type Quantized = PqQuantized;
55    /// Prepared query: precomputed distance table + original FP32 query.
56    type Query = PqQuery;
57
58    /// Encode an FP32 vector: one centroid index byte per subspace.
59    ///
60    /// # Panics
61    ///
62    /// `UnifiedQuantizedVector::new` fails only when the outlier bitmask does
63    /// not match the provided outlier slice. With `outlier_bitmask = 0` and an
64    /// empty slice this can never happen. The `expect` is therefore unreachable
65    /// in practice.
66    fn encode(&self, v: &[f32]) -> Self::Quantized {
67        let codes = self.encode(v);
68        let header = QuantHeader {
69            quant_mode: QuantMode::Pq as u16,
70            dim: self.dim as u16,
71            global_scale: 0.0,
72            residual_norm: 0.0,
73            dot_quantized: 0.0,
74            outlier_bitmask: 0,
75            reserved: [0; 8],
76        };
77        let uqv = UnifiedQuantizedVector::new(header, &codes, &[])
78            .expect("PqCodec::encode: layout construction is infallible (no outliers)");
79        PqQuantized(uqv)
80    }
81
82    /// Prepare the query by precomputing the M×K asymmetric distance table.
83    ///
84    /// The `VectorCodec` trait does not propagate errors.  A `PqCodec` used
85    /// via this trait path is created by `PqCodec::train` which sets no
86    /// governor; `build_distance_table` therefore always returns `Ok` here.
87    /// If a governor is attached and its budget is exhausted the caller that
88    /// constructed the codec is responsible for handling the error — this impl
89    /// panics with a descriptive message so the budget violation is never
90    /// silently ignored.
91    fn prepare_query(&self, q: &[f32]) -> Self::Query {
92        let distance_table = self.build_distance_table(q).expect(
93            "PqCodec::prepare_query: build_distance_table failed; \
94                     if a governor is attached ensure it has sufficient budget",
95        );
96        PqQuery {
97            distance_table,
98            raw: q.to_vec(),
99        }
100    }
101
102    /// Build the `AdcLut` from the precomputed distance table for use by
103    /// SIMD rerank kernels (`pshufb` / `vpermb`).
104    fn adc_lut(&self, q: &Self::Query) -> Option<AdcLut> {
105        let m = self.m as u16;
106        let k = self.k as u16;
107        let mut lut = AdcLut::new(m, k);
108        for (sub, sub_table) in q.distance_table.iter().enumerate() {
109            for (centroid, &dist) in sub_table.iter().enumerate() {
110                let idx = sub * self.k + centroid;
111                lut.table[idx] = dist;
112            }
113        }
114        Some(lut)
115    }
116
117    /// Symmetric distance between two PQ-encoded vectors.
118    ///
119    /// Both codes are decoded to approximate FP32 vectors via the codebook,
120    /// then the squared L2 difference is accumulated. This is the correct
121    /// definition of symmetric PQ distance: each vector is approximated by
122    /// its nearest centroids, and the distance is computed in FP32.
123    #[inline]
124    fn fast_symmetric_distance(&self, q: &Self::Quantized, v: &Self::Quantized) -> f32 {
125        let dq_a = self
126            .decode(packed_bits_of(q))
127            .expect("PqCodec::fast_symmetric_distance: decode failed");
128        let dq_b = self
129            .decode(packed_bits_of(v))
130            .expect("PqCodec::fast_symmetric_distance: decode failed");
131        dq_a.iter()
132            .zip(dq_b.iter())
133            .map(|(&a, &b)| {
134                let d = a - b;
135                d * d
136            })
137            .sum()
138    }
139
140    /// Asymmetric ADC distance: precomputed distance table vs stored code.
141    ///
142    /// O(M) per candidate — delegates to `PqCodec::asymmetric_distance`.
143    #[inline]
144    fn exact_asymmetric_distance(&self, q: &Self::Query, v: &Self::Quantized) -> f32 {
145        self.asymmetric_distance(&q.distance_table, packed_bits_of(v))
146    }
147}
148
149// ── Tests ─────────────────────────────────────────────────────────────────────
150
#[cfg(test)]
mod tests {
    use super::*;

    /// Train a small 4-dim codec (M = 2 subspaces, K = 8 centroids) on four
    /// loose clusters of 20 points each.
    fn make_codec() -> PqCodec {
        let vecs: Vec<Vec<f32>> = (0..80)
            .map(|i| {
                let c = (i / 20) as f32 * 5.0;
                vec![
                    c + (i as f32) * 0.1,
                    c - (i as f32) * 0.05,
                    c + 1.0,
                    c - 1.0,
                ]
            })
            .collect();
        let refs: Vec<&[f32]> = vecs.iter().map(|v| v.as_slice()).collect();
        PqCodec::train(&refs, 4, 2, 8, 10)
    }

    /// `encode` round-trip: packed_bits in the UQV must match the raw
    /// `PqCodec::encode` output.
    #[test]
    fn encode_packed_bits_matches_raw_encode() {
        let codec = make_codec();
        let v = vec![2.0f32, 1.0, 3.0, -1.0];
        let raw = codec.encode(&v);
        let quantized = <PqCodec as VectorCodec>::encode(&codec, &v);
        assert_eq!(quantized.as_ref().packed_bits(), raw.as_slice());
    }

    /// `fast_symmetric_distance` returns a non-negative finite value.
    #[test]
    fn fast_symmetric_distance_is_non_negative_finite() {
        let codec = make_codec();
        let a = <PqCodec as VectorCodec>::encode(&codec, &[0.5, 0.1, 1.0, -0.5]);
        let b = <PqCodec as VectorCodec>::encode(&codec, &[5.0, 4.0, 6.0, 4.5]);
        let d = codec.fast_symmetric_distance(&a, &b);
        assert!(d.is_finite(), "expected finite distance, got {d}");
        assert!(d >= 0.0, "expected non-negative distance, got {d}");
    }

    /// A code compared with itself decodes to identical centroids on both
    /// sides, so its symmetric distance is exactly zero.
    #[test]
    fn fast_symmetric_distance_of_identical_code_is_zero() {
        let codec = make_codec();
        let a = <PqCodec as VectorCodec>::encode(&codec, &[2.0, 1.0, 3.0, -1.0]);
        assert_eq!(codec.fast_symmetric_distance(&a, &a), 0.0);
    }

    /// `exact_asymmetric_distance` returns a non-negative finite value.
    #[test]
    fn exact_asymmetric_distance_is_non_negative_finite() {
        let codec = make_codec();
        let q = codec.prepare_query(&[0.5, 0.1, 1.0, -0.5]);
        let v = <PqCodec as VectorCodec>::encode(&codec, &[5.0, 4.0, 6.0, 4.5]);
        let d = codec.exact_asymmetric_distance(&q, &v);
        assert!(d.is_finite(), "expected finite distance, got {d}");
        assert!(d >= 0.0, "expected non-negative distance, got {d}");
    }

    /// `adc_lut` produces a table with the correct shape.
    #[test]
    fn adc_lut_has_correct_shape() {
        let codec = make_codec();
        let q = codec.prepare_query(&[0.5, 0.1, 1.0, -0.5]);
        let lut =
            <PqCodec as VectorCodec>::adc_lut(&codec, &q).expect("PqCodec must produce an AdcLut");
        assert_eq!(lut.subspace_count, codec.m as u16);
        assert_eq!(lut.centroids_per_subspace, codec.k as u16);
        assert_eq!(lut.table.len(), codec.m * codec.k);
        assert!(lut.table.iter().all(|v| v.is_finite()));
    }

    /// Every LUT entry mirrors the distance table at `sub * k + centroid`,
    /// which is the layout the SIMD rerank kernels index by.
    #[test]
    fn adc_lut_matches_distance_table() {
        let codec = make_codec();
        let q = codec.prepare_query(&[0.5, 0.1, 1.0, -0.5]);
        let lut =
            <PqCodec as VectorCodec>::adc_lut(&codec, &q).expect("PqCodec must produce an AdcLut");
        for (sub, row) in q.distance_table.iter().enumerate() {
            for (centroid, &dist) in row.iter().enumerate() {
                assert_eq!(
                    lut.table[sub * codec.k + centroid],
                    dist,
                    "LUT mismatch at sub={sub}, centroid={centroid}"
                );
            }
        }
    }

    /// Verify the trait impl compiles via a generic function.
    fn use_vector_codec<C: VectorCodec>(c: &C, q: &[f32], v: &[f32]) -> f32 {
        let qv = c.encode(v);
        let qq = c.prepare_query(q);
        c.fast_symmetric_distance(&qv, &qv) + c.exact_asymmetric_distance(&qq, &qv)
    }

    #[test]
    fn trait_bounds_compile() {
        let codec = make_codec();
        let result = use_vector_codec(&codec, &[0.5, 0.1, 1.0, -0.5], &[5.0, 4.0, 6.0, 4.5]);
        assert!(result.is_finite());
    }
}