Skip to main content

nodedb_vector/rerank/
codec.rs

1// SPDX-License-Identifier: Apache-2.0
2
3use std::sync::Arc;
4
5use super::types::RerankError;
6
/// Tag naming the rerank codec a sidecar was built with.
///
/// A search that asks for a different codec than the one recorded for the
/// sidecar can be detected by comparing these tags.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CodecName {
    /// The SQ8 codec; labelled `"sq8"`.
    Sq8,
    /// The PQ codec; labelled `"pq"`.
    Pq,
    /// The Binary codec; labelled `"binary"`.
    Binary,
    /// The RaBitQ codec; labelled `"rabitq"`.
    RaBitQ,
    /// The BBQ codec; labelled `"bbq"`.
    Bbq,
}

impl CodecName {
    /// Returns the fixed lowercase label for this codec (e.g. `"rabitq"`).
    pub fn as_str(self) -> &'static str {
        match self {
            Self::Sq8 => "sq8",
            Self::Pq => "pq",
            Self::Binary => "binary",
            Self::RaBitQ => "rabitq",
            Self::Bbq => "bbq",
        }
    }
}
29
/// Prepared query form: an opaque payload the caller holds between
/// `RerankCodec::prepare_query` and `RerankCodec::distance_prepared` calls.
///
/// Which variant is produced is codec-specific, so the value should be handed
/// back to the same codec that prepared it. More variants may be added as
/// further codec impls land.
#[derive(Debug, Clone, PartialEq)]
pub enum PreparedQuery {
    /// Raw full-precision query, for codecs whose prepared form is just the
    /// input vector (e.g. Binary, or RaBitQ with its rotation applied later).
    Raw(Vec<f32>),
    /// Per-subspace lookup table, for ADC-style codecs (PQ, OPQ).
    Lut(Vec<Vec<f32>>),
    /// Codec-specific opaque bytes, for codecs that fit neither shape above
    /// (e.g. BBQ carries a centroid + alpha).
    Bytes(Vec<u8>),
}
43
44/// Object-safe trait for asymmetric rerank codecs. Each impl wraps an existing
45/// `nodedb-codec::VectorCodec` and exposes a uniform shape so the sidecar can
46/// hold `Arc<dyn RerankCodec>` regardless of the underlying associated-type
47/// machinery.
48pub trait RerankCodec: Send + Sync {
49    /// Encode a full-precision vector. Returns fixed-width bytes for this codec.
50    fn encode(&self, v: &[f32]) -> Result<Vec<u8>, RerankError>;
51
52    /// Prepare a query once before repeated distance calls.
53    fn prepare_query(&self, q: &[f32]) -> Result<PreparedQuery, RerankError>;
54
55    /// Compute asymmetric distance from a prepared query to an encoded vector.
56    fn distance_prepared(
57        &self,
58        prepared: &PreparedQuery,
59        encoded: &[u8],
60    ) -> Result<f32, RerankError>;
61
62    /// Identity tag for mismatch detection.
63    fn name(&self) -> CodecName;
64
65    /// Train from a sample of vectors. Default no-op for codecs that don't need
66    /// training (e.g. Binary). Specific codec impls override this when needed.
67    fn train(&mut self, _samples: &[&[f32]]) -> Result<(), RerankError> {
68        Ok(())
69    }
70
71    /// Serialize trained state to bytes. Each codec uses its own magic header
72    /// (NDSQ / NDBIN / NDPQ / NDRBQ / NDBBQ). The bytes are codec-specific;
73    /// `rerank_codec_from_bytes` is used for restore, paired with `name()`.
74    fn to_bytes(&self) -> Result<Vec<u8>, RerankError>;
75}
76
77/// Reconstruct a `RerankCodec` from its byte form. The `name` tag tells us
78/// which wrapper to dispatch into; the bytes are the codec's own format.
79pub fn rerank_codec_from_bytes(
80    name: CodecName,
81    bytes: &[u8],
82) -> Result<Arc<dyn RerankCodec>, RerankError> {
83    use crate::quantize::pq::PqCodec;
84    use crate::quantize::sq8::Sq8Codec;
85    use crate::rerank::codecs::{BbqRerank, BinaryRerank, PqRerank, RaBitQRerank, Sq8Rerank};
86    use nodedb_codec::vector_quant::bbq::BbqCodec;
87    use nodedb_codec::vector_quant::rabitq::RaBitQCodec;
88
89    match name {
90        CodecName::Sq8 => {
91            let inner = Sq8Codec::from_bytes(bytes)
92                .map_err(|e| RerankError::BadInput(format!("sq8 from_bytes: {e}")))?;
93            Ok(Arc::new(Sq8Rerank::from_codec(inner)))
94        }
95        CodecName::Binary => {
96            // Format: [NDBIN\0 (6 bytes)][version u8 = 1][dim u32 LE (4 bytes)] — 11 bytes total.
97            if bytes.len() < 11 {
98                return Err(RerankError::BadInput("binary from_bytes: too short".into()));
99            }
100            if &bytes[..6] != b"NDBIN\0" {
101                return Err(RerankError::BadInput("binary from_bytes: bad magic".into()));
102            }
103            // bytes[6] is version, bytes[7..11] is dim as u32 LE.
104            let dim = u32::from_le_bytes([bytes[7], bytes[8], bytes[9], bytes[10]]) as usize;
105            Ok(Arc::new(BinaryRerank::new(dim)))
106        }
107        CodecName::Pq => {
108            let inner = PqCodec::from_bytes(bytes)
109                .map_err(|e| RerankError::BadInput(format!("pq from_bytes: {e}")))?;
110            Ok(Arc::new(PqRerank::from_codec(inner)))
111        }
112        CodecName::RaBitQ => {
113            let inner = RaBitQCodec::from_bytes(bytes)
114                .map_err(|e| RerankError::BadInput(format!("rabitq from_bytes: {e}")))?;
115            Ok(Arc::new(RaBitQRerank::from_codec(inner)))
116        }
117        CodecName::Bbq => {
118            let inner = BbqCodec::from_bytes(bytes)
119                .map_err(|e| RerankError::BadInput(format!("bbq from_bytes: {e}")))?;
120            Ok(Arc::new(BbqRerank::from_codec(inner)))
121        }
122    }
123}