1use std::collections::HashMap;
4use std::sync::Arc;
5
6use super::codec::{CodecName, PreparedQuery, RerankCodec};
7use super::types::RerankError;
8
/// Four-byte tag (ASCII `NDCC`) written at the very start of a serialized sidecar.
const SIDECAR_MAGIC: [u8; 4] = [b'N', b'D', b'C', b'C'];
/// Format version byte written directly after the magic; deserialization
/// rejects any other value.
const SIDECAR_VERSION: u8 = 1;
11
12pub struct CodecSidecar {
17 codec: Arc<dyn RerankCodec>,
18 encoded: HashMap<u32, Vec<u8>>,
19}
20
21impl CodecSidecar {
22 pub fn new(codec: Arc<dyn RerankCodec>) -> Self {
23 Self {
24 codec,
25 encoded: HashMap::new(),
26 }
27 }
28
29 pub fn codec_name(&self) -> CodecName {
30 self.codec.name()
31 }
32
33 pub fn encode_and_insert(&mut self, id: u32, vector: &[f32]) -> Result<(), RerankError> {
35 let bytes = self.codec.encode(vector)?;
36 self.encoded.insert(id, bytes);
37 Ok(())
38 }
39
40 pub fn remove(&mut self, id: u32) {
41 self.encoded.remove(&id);
42 }
43
44 pub fn get(&self, id: u32) -> Option<&[u8]> {
45 self.encoded.get(&id).map(|v| v.as_slice())
46 }
47
48 pub fn len(&self) -> usize {
49 self.encoded.len()
50 }
51
52 pub fn is_empty(&self) -> bool {
53 self.encoded.is_empty()
54 }
55
56 pub fn to_bytes(&self) -> Result<Vec<u8>, RerankError> {
60 let codec_bytes = self.codec.to_bytes()?;
61 let codec_name_byte = codec_name_to_u8(self.codec.name());
62
63 #[derive(
64 serde::Serialize, serde::Deserialize, zerompk::ToMessagePack, zerompk::FromMessagePack,
65 )]
66 struct Payload {
67 codec_name: u8,
68 codec_bytes: Vec<u8>,
69 encoded: Vec<(u32, Vec<u8>)>,
70 }
71
72 let payload = Payload {
73 codec_name: codec_name_byte,
74 codec_bytes,
75 encoded: self.encoded.iter().map(|(k, v)| (*k, v.clone())).collect(),
76 };
77
78 let body = zerompk::to_msgpack_vec(&payload)
79 .map_err(|e| RerankError::BadInput(format!("sidecar serialize: {e}")))?;
80 let mut buf = Vec::with_capacity(5 + body.len());
81 buf.extend_from_slice(&SIDECAR_MAGIC);
82 buf.push(SIDECAR_VERSION);
83 buf.extend_from_slice(&body);
84 Ok(buf)
85 }
86
87 pub fn from_bytes(bytes: &[u8]) -> Result<Self, RerankError> {
89 if bytes.len() < 5 {
90 return Err(RerankError::BadInput(
91 "sidecar from_bytes: too short".into(),
92 ));
93 }
94 if bytes[..4] != SIDECAR_MAGIC {
95 return Err(RerankError::BadInput(
96 "sidecar from_bytes: bad magic".into(),
97 ));
98 }
99 let version = bytes[4];
100 if version != SIDECAR_VERSION {
101 return Err(RerankError::BadInput(format!(
102 "sidecar from_bytes: unknown version {version}"
103 )));
104 }
105
106 #[derive(
107 serde::Serialize, serde::Deserialize, zerompk::ToMessagePack, zerompk::FromMessagePack,
108 )]
109 struct Payload {
110 codec_name: u8,
111 codec_bytes: Vec<u8>,
112 encoded: Vec<(u32, Vec<u8>)>,
113 }
114
115 let payload: Payload = zerompk::from_msgpack(&bytes[5..])
116 .map_err(|e| RerankError::BadInput(format!("sidecar deserialize: {e}")))?;
117
118 let codec_name = codec_name_from_u8(payload.codec_name).ok_or_else(|| {
119 RerankError::BadInput(format!(
120 "sidecar from_bytes: unknown codec_name byte {}",
121 payload.codec_name
122 ))
123 })?;
124 let codec = super::codec::rerank_codec_from_bytes(codec_name, &payload.codec_bytes)?;
125 let encoded = payload.encoded.into_iter().collect();
126 Ok(CodecSidecar { codec, encoded })
127 }
128
129 pub fn prepare_query(&self, query: &[f32]) -> Result<PreparedQuery, RerankError> {
130 self.codec.prepare_query(query)
131 }
132
133 pub fn distance_prepared(
137 &self,
138 prepared: &PreparedQuery,
139 id: u32,
140 ) -> Result<Option<f32>, RerankError> {
141 match self.encoded.get(&id) {
142 None => Ok(None),
143 Some(bytes) => self.codec.distance_prepared(prepared, bytes).map(Some),
144 }
145 }
146}
147
148fn codec_name_to_u8(name: CodecName) -> u8 {
149 match name {
150 CodecName::Sq8 => 0,
151 CodecName::Pq => 1,
152 CodecName::Binary => 2,
153 CodecName::RaBitQ => 3,
154 CodecName::Bbq => 4,
155 }
156}
157
158fn codec_name_from_u8(b: u8) -> Option<CodecName> {
159 match b {
160 0 => Some(CodecName::Sq8),
161 1 => Some(CodecName::Pq),
162 2 => Some(CodecName::Binary),
163 3 => Some(CodecName::RaBitQ),
164 4 => Some(CodecName::Bbq),
165 _ => None,
166 }
167}
168
#[cfg(test)]
mod tests {
    use super::*;
    use crate::rerank::codec::CodecName;
    use crate::rerank::codecs::{BbqRerank, BinaryRerank, PqRerank, RaBitQRerank, Sq8Rerank};

    /// Vector dimension used by the round-trip tests.
    const DIM: usize = 16;

    /// Test-only codec: stores vectors as raw little-endian `f32` bytes and
    /// measures Euclidean distance. It reports itself as `Binary` and refuses
    /// to serialize.
    struct StubCodec;

    impl RerankCodec for StubCodec {
        fn encode(&self, v: &[f32]) -> Result<Vec<u8>, RerankError> {
            let mut out = Vec::with_capacity(v.len() * 4);
            for x in v {
                out.extend_from_slice(&x.to_le_bytes());
            }
            Ok(out)
        }

        fn prepare_query(&self, q: &[f32]) -> Result<PreparedQuery, RerankError> {
            Ok(PreparedQuery::Raw(q.to_vec()))
        }

        fn distance_prepared(
            &self,
            prepared: &PreparedQuery,
            encoded: &[u8],
        ) -> Result<f32, RerankError> {
            let query = match prepared {
                PreparedQuery::Raw(v) => v,
                _ => {
                    return Err(RerankError::BadInput(
                        "StubCodec expects Raw prepared query".into(),
                    ));
                }
            };
            // Decode lazily; trailing bytes that do not fill a whole f32 are ignored.
            let stored = encoded
                .chunks_exact(4)
                .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]));
            if query.len() != stored.len() {
                return Err(RerankError::BadInput("dimension mismatch".into()));
            }
            let sum_sq: f32 = query
                .iter()
                .zip(stored)
                .map(|(a, b)| {
                    let d = a - b;
                    d * d
                })
                .sum();
            Ok(sum_sq.sqrt())
        }

        fn name(&self) -> CodecName {
            CodecName::Binary
        }

        fn to_bytes(&self) -> Result<Vec<u8>, RerankError> {
            Err(RerankError::BadInput(
                "StubCodec does not support serialization".into(),
            ))
        }
    }

    /// Empty sidecar backed by [`StubCodec`].
    fn make_sidecar() -> CodecSidecar {
        CodecSidecar::new(Arc::new(StubCodec))
    }

    /// Deterministic pseudo-data: component `j` of vector `i` is
    /// `((i * 31 + j) % 100) / 100`.
    fn det_vec(i: usize, dim: usize) -> Vec<f32> {
        let mut v = Vec::with_capacity(dim);
        for j in 0..dim {
            v.push(((i * 31 + j) % 100) as f32 / 100.0);
        }
        v
    }

    /// The first `count` deterministic vectors, used as training input.
    fn training_samples(count: usize, dim: usize) -> Vec<Vec<f32>> {
        (0..count).map(|i| det_vec(i, dim)).collect()
    }

    /// Fills a sidecar with ids `0..5`, round-trips it through
    /// `to_bytes`/`from_bytes`, and checks the codec name and every stored
    /// byte string survive unchanged.
    fn assert_roundtrip(codec: Arc<dyn RerankCodec>, dim: usize, expected: CodecName) {
        let mut s = CodecSidecar::new(codec);
        for i in 0..5u32 {
            s.encode_and_insert(i, &det_vec(i as usize, dim)).unwrap();
        }
        let bytes = s.to_bytes().expect("to_bytes");
        let s2 = CodecSidecar::from_bytes(&bytes).expect("from_bytes");
        assert_eq!(s2.codec_name(), expected);
        for i in 0..5u32 {
            assert_eq!(s.get(i), s2.get(i), "encoded bytes differ for id {i}");
        }
    }

    /// Serialized form of an empty sidecar over a 4-dimensional binary codec.
    fn empty_binary_sidecar_bytes() -> Vec<u8> {
        CodecSidecar::new(Arc::new(BinaryRerank::new(4)))
            .to_bytes()
            .unwrap()
    }

    #[test]
    fn insert_and_get() {
        let mut s = make_sidecar();
        assert!(s.is_empty());
        let rows: [(u32, [f32; 2]); 3] = [(1, [1.0, 2.0]), (2, [3.0, 4.0]), (3, [5.0, 6.0])];
        for (id, v) in rows.iter() {
            s.encode_and_insert(*id, v).unwrap();
        }
        assert_eq!(s.len(), 3);

        let mut expected_1: Vec<u8> = Vec::new();
        expected_1.extend_from_slice(&1.0f32.to_le_bytes());
        expected_1.extend_from_slice(&2.0f32.to_le_bytes());
        assert_eq!(s.get(1), Some(expected_1.as_slice()));
    }

    #[test]
    fn remove_returns_none() {
        let mut s = make_sidecar();
        s.encode_and_insert(10, &[1.0]).unwrap();
        s.remove(10);
        assert_eq!(s.get(10), None);

        let prepared = s.prepare_query(&[1.0]).unwrap();
        let dist = s.distance_prepared(&prepared, 10).unwrap();
        assert_eq!(dist, None);
    }

    #[test]
    fn distance_prepared_correct() {
        let mut s = make_sidecar();
        s.encode_and_insert(5, &[0.0, 0.0]).unwrap();
        let prepared = s.prepare_query(&[3.0, 4.0]).unwrap();
        let dist = s.distance_prepared(&prepared, 5).unwrap().unwrap();
        assert!((dist - 5.0).abs() < 1e-5, "expected L2=5.0, got {dist}");
    }

    #[test]
    fn codec_name_passthrough() {
        let s = make_sidecar();
        let name = s.codec_name();
        assert_eq!(name, CodecName::Binary);
        assert_eq!(s.codec_name().as_str(), "binary");
    }

    #[test]
    fn len_and_is_empty() {
        let mut s = make_sidecar();
        assert!(s.is_empty());
        s.encode_and_insert(1, &[1.0]).unwrap();
        assert!(!s.is_empty());
        assert_eq!(s.len(), 1);
        s.remove(1);
        assert!(s.is_empty());
    }

    #[test]
    fn sidecar_roundtrip_sq8() {
        let mut codec = Sq8Rerank::new(DIM);
        let samples = training_samples(20, DIM);
        let refs: Vec<&[f32]> = samples.iter().map(|v| v.as_slice()).collect();
        codec.train(&refs).unwrap();

        assert_roundtrip(Arc::new(codec), DIM, CodecName::Sq8);
    }

    #[test]
    fn sidecar_roundtrip_binary() {
        assert_roundtrip(Arc::new(BinaryRerank::new(DIM)), DIM, CodecName::Binary);
    }

    #[test]
    fn sidecar_roundtrip_pq() {
        let m = 4;
        let k = 8;
        let mut codec = PqRerank::new(DIM, m, k);
        let samples = training_samples(32, DIM);
        let refs: Vec<&[f32]> = samples.iter().map(|v| v.as_slice()).collect();
        codec.train(&refs).unwrap();

        assert_roundtrip(Arc::new(codec), DIM, CodecName::Pq);
    }

    #[test]
    fn sidecar_roundtrip_rabitq() {
        let mut codec = RaBitQRerank::new(DIM, 0xDEADBEEF_C0FFEE42);
        let samples = training_samples(20, DIM);
        let refs: Vec<&[f32]> = samples.iter().map(|v| v.as_slice()).collect();
        codec.train(&refs).unwrap();

        assert_roundtrip(Arc::new(codec), DIM, CodecName::RaBitQ);
    }

    #[test]
    fn sidecar_roundtrip_bbq() {
        let mut codec = BbqRerank::new(DIM, 4);
        let samples = training_samples(20, DIM);
        let refs: Vec<&[f32]> = samples.iter().map(|v| v.as_slice()).collect();
        codec.train(&refs).unwrap();

        assert_roundtrip(Arc::new(codec), DIM, CodecName::Bbq);
    }

    #[test]
    fn sidecar_bad_magic_returns_error() {
        let mut bytes = empty_binary_sidecar_bytes();
        // Corrupt the first magic byte.
        bytes[0] = b'X';
        assert!(CodecSidecar::from_bytes(&bytes).is_err());
    }

    #[test]
    fn sidecar_bad_version_returns_error() {
        let mut bytes = empty_binary_sidecar_bytes();
        // Byte 4 is the version, directly after the 4-byte magic.
        bytes[4] = 99;
        assert!(CodecSidecar::from_bytes(&bytes).is_err());
    }

    #[test]
    fn sidecar_distance_matches_after_roundtrip() {
        let mut codec = Sq8Rerank::new(DIM);
        let samples = training_samples(20, DIM);
        let refs: Vec<&[f32]> = samples.iter().map(|v| v.as_slice()).collect();
        codec.train(&refs).unwrap();

        let mut s = CodecSidecar::new(Arc::new(codec));
        s.encode_and_insert(1, &det_vec(3, DIM)).unwrap();
        let query_vec = det_vec(7, DIM);

        let prepared_orig = s.prepare_query(&query_vec).unwrap();
        let d_orig = s.distance_prepared(&prepared_orig, 1).unwrap().unwrap();

        let s2 = CodecSidecar::from_bytes(&s.to_bytes().unwrap()).unwrap();
        let prepared_rest = s2.prepare_query(&query_vec).unwrap();
        let d_rest = s2.distance_prepared(&prepared_rest, 1).unwrap().unwrap();

        assert!(
            (d_orig - d_rest).abs() < 1e-5,
            "distance diverged: {d_orig} vs {d_rest}"
        );
    }
}