Skip to main content

nodedb_vector/rerank/codecs/
sq8.rs

// SPDX-License-Identifier: Apache-2.0

//! `RerankCodec` wrapper for SQ8 scalar quantization.
//!
//! Bridges `Sq8Codec` (which implements `VectorCodec` with associated types)
//! into the object-safe `RerankCodec` trait used by the rerank sidecar.
7
8use nodedb_codec::vector_quant::layout::UnifiedQuantizedVectorRef;
9
10use crate::{
11    quantize::sq8::Sq8Codec,
12    rerank::codec::{CodecName, PreparedQuery, RerankCodec},
13    rerank::types::RerankError,
14};
15
16// ── packed_bits_len helper ────────────────────────────────────────────────────
17
/// Byte length of the packed code section for an SQ8 vector with `dim`
/// dimensions.
///
/// SQ8 spends 8 bits (one byte) per dimension, so the packed length is
/// exactly `dim` bytes.
#[inline]
fn sq8_packed_bits_len(dim: usize) -> usize {
    const BYTES_PER_CODE: usize = 1;
    dim * BYTES_PER_CODE
}
23
24// ── Sq8Rerank ─────────────────────────────────────────────────────────────────
25
26/// Object-safe `RerankCodec` wrapper around `Sq8Codec`.
27///
28/// `train()` calls `Sq8Codec::calibrate` to fit per-dimension min/max from a
29/// sample of vectors. Subsequent `encode` / `distance_prepared` calls use the
30/// calibrated codec.
31pub struct Sq8Rerank {
32    codec: Sq8Codec,
33    dim: usize,
34}
35
36impl Sq8Rerank {
37    /// Create an untrained wrapper with a default-calibrated codec.
38    ///
39    /// The default codec treats every dimension's min as 0.0 and max as 1.0,
40    /// which is suitable for normalized embeddings. For best accuracy call
41    /// `train()` with representative samples before encoding.
42    pub fn new(dim: usize) -> Self {
43        // Build a minimal calibration over the unit range so encoding is
44        // functional before train() is called.
45        let lo = vec![0.0f32; dim];
46        let hi = vec![1.0f32; dim];
47        let samples: Vec<&[f32]> = vec![lo.as_slice(), hi.as_slice()];
48        let codec = Sq8Codec::calibrate(&samples, dim);
49        Self { codec, dim }
50    }
51
52    /// Wrap an already-trained `Sq8Codec`.
53    pub fn from_codec(codec: Sq8Codec) -> Self {
54        let dim = codec.dim;
55        Self { codec, dim }
56    }
57}
58
59impl RerankCodec for Sq8Rerank {
60    /// Encode a full-precision vector to SQ8 bytes.
61    ///
62    /// The serialized form is the raw `UnifiedQuantizedVector` buffer
63    /// (`as_bytes()`), which embeds a 32-byte `QuantHeader` followed by
64    /// `dim` packed INT8 codes.
65    fn encode(&self, v: &[f32]) -> Result<Vec<u8>, RerankError> {
66        if v.len() != self.dim {
67            return Err(RerankError::BadInput(format!(
68                "sq8 encode: vector len {} != codec dim {}",
69                v.len(),
70                self.dim
71            )));
72        }
73        use nodedb_codec::vector_quant::codec::VectorCodec as _;
74        let quantized = self.codec.encode(v);
75        Ok(quantized.as_ref().as_bytes().to_vec())
76    }
77
78    /// Prepare the query for repeated distance calls.
79    ///
80    /// SQ8 is asymmetric: the query is kept in full FP32 precision while
81    /// candidates are INT8. The prepared form is therefore `PreparedQuery::Raw`.
82    fn prepare_query(&self, q: &[f32]) -> Result<PreparedQuery, RerankError> {
83        if q.len() != self.dim {
84            return Err(RerankError::BadInput(format!(
85                "sq8 prepare_query: query len {} != codec dim {}",
86                q.len(),
87                self.dim
88            )));
89        }
90        Ok(PreparedQuery::Raw(q.to_vec()))
91    }
92
93    /// Compute asymmetric L2 distance from a prepared FP32 query to an
94    /// SQ8-encoded candidate.
95    fn distance_prepared(
96        &self,
97        prepared: &PreparedQuery,
98        encoded: &[u8],
99    ) -> Result<f32, RerankError> {
100        let q = match prepared {
101            PreparedQuery::Raw(q) => q,
102            _ => {
103                return Err(RerankError::BadInput(
104                    "sq8 distance: expected PreparedQuery::Raw".to_string(),
105                ));
106            }
107        };
108
109        let packed_len = sq8_packed_bits_len(self.dim);
110        let uqv_ref = UnifiedQuantizedVectorRef::from_bytes(encoded, packed_len).map_err(|e| {
111            RerankError::BadInput(format!("sq8 distance: failed to parse encoded bytes: {e}"))
112        })?;
113
114        let packed = uqv_ref.packed_bits();
115        let dist = self.codec.asymmetric_l2(q, packed);
116        Ok(dist)
117    }
118
119    fn name(&self) -> CodecName {
120        CodecName::Sq8
121    }
122
123    fn to_bytes(&self) -> Result<Vec<u8>, RerankError> {
124        Ok(self.codec.to_bytes())
125    }
126
127    /// Calibrate from a sample of vectors.
128    ///
129    /// Replaces the current codec state. Requires at least one sample.
130    fn train(&mut self, samples: &[&[f32]]) -> Result<(), RerankError> {
131        if samples.is_empty() {
132            return Err(RerankError::BadInput(
133                "sq8 train: empty sample set".to_string(),
134            ));
135        }
136        self.codec = Sq8Codec::calibrate(samples, self.dim);
137        Ok(())
138    }
139}
140
141// ── Tests ─────────────────────────────────────────────────────────────────────
142
#[cfg(test)]
mod tests {
    use super::*;

    const DIM: usize = 16;
    const EPS: f32 = 1e-2;

    /// Deterministic test vector: `base` plus a small per-dimension ramp.
    fn make_vec(base: f32) -> Vec<f32> {
        (0..DIM).map(|i| base + i as f32 * 0.01).collect()
    }

    /// A wrapper calibrated on 50 ramp vectors with bases 0.0, 0.1, ..., 4.9.
    fn trained_codec() -> Sq8Rerank {
        let samples: Vec<Vec<f32>> = (0..50).map(|i| make_vec(i as f32 * 0.1)).collect();
        let refs: Vec<&[f32]> = samples.iter().map(|v| v.as_slice()).collect();
        let mut codec = Sq8Rerank::new(DIM);
        codec.train(&refs).expect("train must succeed");
        codec
    }

    #[test]
    fn round_trip_returns_finite_distance() {
        let codec = trained_codec();
        let v1 = make_vec(0.5);
        let v2 = make_vec(1.0);

        let enc = codec.encode(&v1).expect("encode v1");
        let prepared = codec.prepare_query(&v2).expect("prepare_query v2");
        let dist = codec
            .distance_prepared(&prepared, &enc)
            .expect("distance_prepared");
        assert!(dist.is_finite(), "expected finite distance, got {dist}");
        assert!(dist >= 0.0, "expected non-negative distance, got {dist}");
    }

    #[test]
    fn identical_vectors_small_distance() {
        let codec = trained_codec();
        let v = make_vec(0.5);

        let enc = codec.encode(&v).expect("encode");
        let prepared = codec.prepare_query(&v).expect("prepare_query");
        let dist = codec
            .distance_prepared(&prepared, &enc)
            .expect("distance_prepared");
        assert!(dist.is_finite());
        assert!(
            dist < EPS,
            "identical vectors should have near-zero distance, got {dist}"
        );
    }

    #[test]
    fn wrong_prepared_query_variant_returns_bad_input() {
        let codec = trained_codec();
        let v = make_vec(0.5);
        let enc = codec.encode(&v).expect("encode");
        let bad_prepared = PreparedQuery::Bytes(vec![0u8; 8]);

        let result = codec.distance_prepared(&bad_prepared, &enc);
        assert!(result.is_err(), "expected BadInput error");
        let msg = format!("{}", result.unwrap_err());
        assert!(
            msg.contains("Raw"),
            "error message should mention Raw, got: {msg}"
        );
    }

    #[test]
    fn name_returns_sq8() {
        let codec = Sq8Rerank::new(DIM);
        assert_eq!(codec.name(), CodecName::Sq8);
    }

    #[test]
    fn train_calibrates_without_error() {
        let mut codec = Sq8Rerank::new(DIM);
        let samples: Vec<Vec<f32>> = (0..20).map(|i| make_vec(i as f32 * 0.05)).collect();
        let refs: Vec<&[f32]> = samples.iter().map(|v| v.as_slice()).collect();
        codec.train(&refs).expect("train must succeed");

        // After training, encode + distance must still work.
        let v = make_vec(0.5);
        let enc = codec.encode(&v).expect("encode after train");
        let prep = codec.prepare_query(&v).expect("prepare after train");
        let dist = codec
            .distance_prepared(&prep, &enc)
            .expect("distance after train");
        assert!(dist.is_finite());
    }

    #[test]
    fn empty_train_returns_error() {
        let mut codec = Sq8Rerank::new(DIM);
        assert!(codec.train(&[]).is_err(), "empty sample set must be rejected");
    }

    #[test]
    fn wrong_dim_encode_returns_error() {
        let codec = Sq8Rerank::new(DIM);
        let bad = vec![0.0f32; DIM + 1];
        assert!(codec.encode(&bad).is_err());
    }

    #[test]
    fn wrong_dim_prepare_query_returns_error() {
        let codec = Sq8Rerank::new(DIM);
        let too_short = vec![0.0f32; DIM - 1];
        let too_long = vec![0.0f32; DIM + 1];
        assert!(codec.prepare_query(&too_short).is_err());
        assert!(codec.prepare_query(&too_long).is_err());
    }
}