Skip to main content

nodedb_vector/rerank/
sidecar.rs

1// SPDX-License-Identifier: Apache-2.0
2
3use std::collections::HashMap;
4use std::sync::Arc;
5
6use super::codec::{CodecName, PreparedQuery, RerankCodec};
7use super::types::RerankError;
8
/// Leading magic bytes ("NDCC") that identify a serialized `CodecSidecar`.
const SIDECAR_MAGIC: [u8; 4] = [b'N', b'D', b'C', b'C'];
/// Format version byte written immediately after the magic; deserialization
/// rejects any other value.
const SIDECAR_VERSION: u8 = 1;
11
12/// Per-collection encoded-vector storage keyed by surrogate id, paired with
13/// the trained codec. Encoded vectors live alongside (not inside) the HNSW
14/// index — HNSW keeps full-precision vectors for graph traversal; the sidecar
15/// is consulted only during base-layer rerank.
16pub struct CodecSidecar {
17    codec: Arc<dyn RerankCodec>,
18    encoded: HashMap<u32, Vec<u8>>,
19}
20
21impl CodecSidecar {
22    pub fn new(codec: Arc<dyn RerankCodec>) -> Self {
23        Self {
24            codec,
25            encoded: HashMap::new(),
26        }
27    }
28
29    pub fn codec_name(&self) -> CodecName {
30        self.codec.name()
31    }
32
33    /// Encode a vector and insert it under `id`. Overwrites any existing entry.
34    pub fn encode_and_insert(&mut self, id: u32, vector: &[f32]) -> Result<(), RerankError> {
35        let bytes = self.codec.encode(vector)?;
36        self.encoded.insert(id, bytes);
37        Ok(())
38    }
39
40    pub fn remove(&mut self, id: u32) {
41        self.encoded.remove(&id);
42    }
43
44    pub fn get(&self, id: u32) -> Option<&[u8]> {
45        self.encoded.get(&id).map(|v| v.as_slice())
46    }
47
48    pub fn len(&self) -> usize {
49        self.encoded.len()
50    }
51
52    pub fn is_empty(&self) -> bool {
53        self.encoded.is_empty()
54    }
55
56    /// Serialize the sidecar (codec state + all encoded vectors) to bytes.
57    ///
58    /// Format: `[NDCC (4 bytes)][version: u8 = 1][msgpack payload]`
59    pub fn to_bytes(&self) -> Result<Vec<u8>, RerankError> {
60        let codec_bytes = self.codec.to_bytes()?;
61        let codec_name_byte = codec_name_to_u8(self.codec.name());
62
63        #[derive(
64            serde::Serialize, serde::Deserialize, zerompk::ToMessagePack, zerompk::FromMessagePack,
65        )]
66        struct Payload {
67            codec_name: u8,
68            codec_bytes: Vec<u8>,
69            encoded: Vec<(u32, Vec<u8>)>,
70        }
71
72        let payload = Payload {
73            codec_name: codec_name_byte,
74            codec_bytes,
75            encoded: self.encoded.iter().map(|(k, v)| (*k, v.clone())).collect(),
76        };
77
78        let body = zerompk::to_msgpack_vec(&payload)
79            .map_err(|e| RerankError::BadInput(format!("sidecar serialize: {e}")))?;
80        let mut buf = Vec::with_capacity(5 + body.len());
81        buf.extend_from_slice(&SIDECAR_MAGIC);
82        buf.push(SIDECAR_VERSION);
83        buf.extend_from_slice(&body);
84        Ok(buf)
85    }
86
87    /// Deserialize a sidecar from bytes produced by [`Self::to_bytes`].
88    pub fn from_bytes(bytes: &[u8]) -> Result<Self, RerankError> {
89        if bytes.len() < 5 {
90            return Err(RerankError::BadInput(
91                "sidecar from_bytes: too short".into(),
92            ));
93        }
94        if bytes[..4] != SIDECAR_MAGIC {
95            return Err(RerankError::BadInput(
96                "sidecar from_bytes: bad magic".into(),
97            ));
98        }
99        let version = bytes[4];
100        if version != SIDECAR_VERSION {
101            return Err(RerankError::BadInput(format!(
102                "sidecar from_bytes: unknown version {version}"
103            )));
104        }
105
106        #[derive(
107            serde::Serialize, serde::Deserialize, zerompk::ToMessagePack, zerompk::FromMessagePack,
108        )]
109        struct Payload {
110            codec_name: u8,
111            codec_bytes: Vec<u8>,
112            encoded: Vec<(u32, Vec<u8>)>,
113        }
114
115        let payload: Payload = zerompk::from_msgpack(&bytes[5..])
116            .map_err(|e| RerankError::BadInput(format!("sidecar deserialize: {e}")))?;
117
118        let codec_name = codec_name_from_u8(payload.codec_name).ok_or_else(|| {
119            RerankError::BadInput(format!(
120                "sidecar from_bytes: unknown codec_name byte {}",
121                payload.codec_name
122            ))
123        })?;
124        let codec = super::codec::rerank_codec_from_bytes(codec_name, &payload.codec_bytes)?;
125        let encoded = payload.encoded.into_iter().collect();
126        Ok(CodecSidecar { codec, encoded })
127    }
128
129    pub fn prepare_query(&self, query: &[f32]) -> Result<PreparedQuery, RerankError> {
130        self.codec.prepare_query(query)
131    }
132
133    /// Compute distance from a prepared query to the encoded vector at `id`.
134    /// Returns `Ok(None)` when the id isn't in the sidecar (lost / not yet
135    /// encoded); returns the distance otherwise.
136    pub fn distance_prepared(
137        &self,
138        prepared: &PreparedQuery,
139        id: u32,
140    ) -> Result<Option<f32>, RerankError> {
141        match self.encoded.get(&id) {
142            None => Ok(None),
143            Some(bytes) => self.codec.distance_prepared(prepared, bytes).map(Some),
144        }
145    }
146}
147
148fn codec_name_to_u8(name: CodecName) -> u8 {
149    match name {
150        CodecName::Sq8 => 0,
151        CodecName::Pq => 1,
152        CodecName::Binary => 2,
153        CodecName::RaBitQ => 3,
154        CodecName::Bbq => 4,
155    }
156}
157
158fn codec_name_from_u8(b: u8) -> Option<CodecName> {
159    match b {
160        0 => Some(CodecName::Sq8),
161        1 => Some(CodecName::Pq),
162        2 => Some(CodecName::Binary),
163        3 => Some(CodecName::RaBitQ),
164        4 => Some(CodecName::Bbq),
165        _ => None,
166    }
167}
168
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal codec that stores raw little-endian f32s and scores with L2.
    /// It deliberately refuses to serialize so error propagation is testable.
    struct StubCodec;

    impl RerankCodec for StubCodec {
        fn encode(&self, v: &[f32]) -> Result<Vec<u8>, RerankError> {
            Ok(v.iter().flat_map(|x| x.to_le_bytes()).collect())
        }

        fn prepare_query(&self, q: &[f32]) -> Result<PreparedQuery, RerankError> {
            Ok(PreparedQuery::Raw(q.to_vec()))
        }

        fn distance_prepared(
            &self,
            prepared: &PreparedQuery,
            encoded: &[u8],
        ) -> Result<f32, RerankError> {
            let query = match prepared {
                PreparedQuery::Raw(v) => v,
                _ => {
                    return Err(RerankError::BadInput(
                        "StubCodec expects Raw prepared query".into(),
                    ));
                }
            };
            let encoded_floats: Vec<f32> = encoded
                .chunks_exact(4)
                .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
                .collect();
            if query.len() != encoded_floats.len() {
                return Err(RerankError::BadInput("dimension mismatch".into()));
            }
            let dist = query
                .iter()
                .zip(encoded_floats.iter())
                .map(|(a, b)| (a - b) * (a - b))
                .sum::<f32>()
                .sqrt();
            Ok(dist)
        }

        fn name(&self) -> CodecName {
            CodecName::Binary
        }

        fn to_bytes(&self) -> Result<Vec<u8>, RerankError> {
            Err(RerankError::BadInput(
                "StubCodec does not support serialization".into(),
            ))
        }
    }

    fn make_sidecar() -> CodecSidecar {
        CodecSidecar::new(Arc::new(StubCodec))
    }

    /// Little-endian byte image of `values`, matching `StubCodec::encode`.
    fn le_bytes(values: &[f32]) -> Vec<u8> {
        values.iter().flat_map(|x| x.to_le_bytes()).collect()
    }

    #[test]
    fn insert_and_get() {
        let mut s = make_sidecar();
        assert!(s.is_empty());
        s.encode_and_insert(1, &[1.0, 2.0]).unwrap();
        s.encode_and_insert(2, &[3.0, 4.0]).unwrap();
        s.encode_and_insert(3, &[5.0, 6.0]).unwrap();
        assert_eq!(s.len(), 3);

        let expected_1 = le_bytes(&[1.0, 2.0]);
        assert_eq!(s.get(1), Some(expected_1.as_slice()));
    }

    #[test]
    fn insert_overwrites_existing() {
        let mut s = make_sidecar();
        s.encode_and_insert(7, &[1.0]).unwrap();
        s.encode_and_insert(7, &[2.0]).unwrap();
        assert_eq!(s.len(), 1);
        let expected = le_bytes(&[2.0]);
        assert_eq!(s.get(7), Some(expected.as_slice()));
    }

    #[test]
    fn remove_returns_none() {
        let mut s = make_sidecar();
        s.encode_and_insert(10, &[1.0]).unwrap();
        s.remove(10);
        assert_eq!(s.get(10), None);
        let prepared = s.prepare_query(&[1.0]).unwrap();
        assert_eq!(s.distance_prepared(&prepared, 10).unwrap(), None);
    }

    #[test]
    fn distance_prepared_correct() {
        let mut s = make_sidecar();
        s.encode_and_insert(5, &[0.0, 0.0]).unwrap();
        let prepared = s.prepare_query(&[3.0, 4.0]).unwrap();
        let dist = s.distance_prepared(&prepared, 5).unwrap().unwrap();
        assert!((dist - 5.0).abs() < 1e-5, "expected L2=5.0, got {dist}");
    }

    #[test]
    fn distance_prepared_propagates_codec_error() {
        let mut s = make_sidecar();
        s.encode_and_insert(5, &[0.0, 0.0]).unwrap();
        // StubCodec rejects a query whose dimension differs from the stored vector.
        let prepared = s.prepare_query(&[1.0]).unwrap();
        assert!(s.distance_prepared(&prepared, 5).is_err());
    }

    #[test]
    fn codec_name_passthrough() {
        let s = make_sidecar();
        assert_eq!(s.codec_name(), CodecName::Binary);
        assert_eq!(s.codec_name().as_str(), "binary");
    }

    #[test]
    fn len_and_is_empty() {
        let mut s = make_sidecar();
        assert!(s.is_empty());
        s.encode_and_insert(1, &[1.0]).unwrap();
        assert!(!s.is_empty());
        assert_eq!(s.len(), 1);
        s.remove(1);
        assert!(s.is_empty());
    }

    #[test]
    fn codec_name_byte_roundtrip() {
        for byte in 0u8..=4 {
            let name = codec_name_from_u8(byte).expect("known codec byte");
            assert_eq!(codec_name_to_u8(name), byte);
        }
        assert!(codec_name_from_u8(5).is_none());
        assert!(codec_name_from_u8(u8::MAX).is_none());
    }

    // ── Sidecar serialization tests ────────────────────────────────────────────

    fn det_vec(i: usize, dim: usize) -> Vec<f32> {
        (0..dim)
            .map(|j| ((i * 31 + j) % 100) as f32 / 100.0)
            .collect()
    }

    /// `n` deterministic training vectors of dimension `dim`.
    fn training_set(n: usize, dim: usize) -> Vec<Vec<f32>> {
        (0..n).map(|i| det_vec(i, dim)).collect()
    }

    /// Borrow each training vector as a slice, the shape `train` accepts.
    fn as_refs(samples: &[Vec<f32>]) -> Vec<&[f32]> {
        samples.iter().map(Vec::as_slice).collect()
    }

    /// Insert five vectors under `codec`, roundtrip through bytes, and check
    /// the restored sidecar has the expected codec and identical encodings.
    fn assert_roundtrip(codec: Arc<dyn RerankCodec>, dim: usize, expected: CodecName) {
        let mut s = CodecSidecar::new(codec);
        for i in 0..5u32 {
            s.encode_and_insert(i, &det_vec(i as usize, dim)).unwrap();
        }
        let bytes = s.to_bytes().expect("to_bytes");
        let s2 = CodecSidecar::from_bytes(&bytes).expect("from_bytes");
        assert_eq!(s2.codec_name(), expected);
        assert_eq!(s2.len(), s.len());
        for i in 0..5u32 {
            assert_eq!(s.get(i), s2.get(i), "encoded bytes differ for id {i}");
        }
    }

    #[test]
    fn sidecar_roundtrip_sq8() {
        use crate::rerank::codecs::Sq8Rerank;
        let dim = 16;
        let mut codec = Sq8Rerank::new(dim);
        let samples = training_set(20, dim);
        codec.train(&as_refs(&samples)).unwrap();
        assert_roundtrip(Arc::new(codec), dim, CodecName::Sq8);
    }

    #[test]
    fn sidecar_roundtrip_binary() {
        use crate::rerank::codecs::BinaryRerank;
        let dim = 16;
        assert_roundtrip(Arc::new(BinaryRerank::new(dim)), dim, CodecName::Binary);
    }

    #[test]
    fn sidecar_roundtrip_pq() {
        use crate::rerank::codecs::PqRerank;
        let dim = 16;
        let m = 4;
        let k = 8;
        let mut codec = PqRerank::new(dim, m, k);
        let samples = training_set(32, dim);
        codec.train(&as_refs(&samples)).unwrap();
        assert_roundtrip(Arc::new(codec), dim, CodecName::Pq);
    }

    #[test]
    fn sidecar_roundtrip_rabitq() {
        use crate::rerank::codecs::RaBitQRerank;
        let dim = 16;
        let mut codec = RaBitQRerank::new(dim, 0xDEADBEEF_C0FFEE42);
        let samples = training_set(20, dim);
        codec.train(&as_refs(&samples)).unwrap();
        assert_roundtrip(Arc::new(codec), dim, CodecName::RaBitQ);
    }

    #[test]
    fn sidecar_roundtrip_bbq() {
        use crate::rerank::codecs::BbqRerank;
        let dim = 16;
        let mut codec = BbqRerank::new(dim, 4);
        let samples = training_set(20, dim);
        codec.train(&as_refs(&samples)).unwrap();
        assert_roundtrip(Arc::new(codec), dim, CodecName::Bbq);
    }

    #[test]
    fn sidecar_roundtrip_empty() {
        use crate::rerank::codecs::BinaryRerank;
        let s = CodecSidecar::new(Arc::new(BinaryRerank::new(4)));
        let s2 = CodecSidecar::from_bytes(&s.to_bytes().unwrap()).unwrap();
        assert!(s2.is_empty());
        assert_eq!(s2.codec_name(), CodecName::Binary);
    }

    #[test]
    fn to_bytes_writes_header() {
        use crate::rerank::codecs::BinaryRerank;
        let s = CodecSidecar::new(Arc::new(BinaryRerank::new(4)));
        let bytes = s.to_bytes().unwrap();
        assert_eq!(&bytes[..4], &SIDECAR_MAGIC[..]);
        assert_eq!(bytes[4], SIDECAR_VERSION);
    }

    #[test]
    fn to_bytes_propagates_codec_error() {
        // StubCodec::to_bytes always fails; the sidecar must surface that.
        assert!(make_sidecar().to_bytes().is_err());
    }

    #[test]
    fn sidecar_too_short_returns_error() {
        assert!(CodecSidecar::from_bytes(&[]).is_err());
        assert!(CodecSidecar::from_bytes(b"NDCC").is_err());
    }

    #[test]
    fn sidecar_bad_magic_returns_error() {
        use crate::rerank::codecs::BinaryRerank;
        let s = CodecSidecar::new(Arc::new(BinaryRerank::new(4)));
        let mut bytes = s.to_bytes().unwrap();
        bytes[0] = b'X';
        assert!(CodecSidecar::from_bytes(&bytes).is_err());
    }

    #[test]
    fn sidecar_bad_version_returns_error() {
        use crate::rerank::codecs::BinaryRerank;
        let s = CodecSidecar::new(Arc::new(BinaryRerank::new(4)));
        let mut bytes = s.to_bytes().unwrap();
        bytes[4] = 99;
        assert!(CodecSidecar::from_bytes(&bytes).is_err());
    }

    #[test]
    fn sidecar_empty_payload_returns_error() {
        // Valid header, but no msgpack body at all.
        let header_only = [b'N', b'D', b'C', b'C', SIDECAR_VERSION];
        assert!(CodecSidecar::from_bytes(&header_only).is_err());
    }

    #[test]
    fn sidecar_distance_matches_after_roundtrip() {
        use crate::rerank::codecs::Sq8Rerank;
        let dim = 16;
        let mut codec = Sq8Rerank::new(dim);
        let samples = training_set(20, dim);
        codec.train(&as_refs(&samples)).unwrap();

        let mut s = CodecSidecar::new(Arc::new(codec));
        s.encode_and_insert(1, &det_vec(3, dim)).unwrap();
        let query_vec = det_vec(7, dim);
        let prepared_orig = s.prepare_query(&query_vec).unwrap();
        let d_orig = s.distance_prepared(&prepared_orig, 1).unwrap().unwrap();

        let bytes = s.to_bytes().unwrap();
        let s2 = CodecSidecar::from_bytes(&bytes).unwrap();
        let prepared_rest = s2.prepare_query(&query_vec).unwrap();
        let d_rest = s2.distance_prepared(&prepared_rest, 1).unwrap().unwrap();

        assert!(
            (d_orig - d_rest).abs() < 1e-5,
            "distance diverged: {d_orig} vs {d_rest}"
        );
    }
}