Skip to main content

nodedb_vector/rerank/codecs/
binary.rs

1// SPDX-License-Identifier: Apache-2.0
2
3//! `RerankCodec` wrapper for binary (sign-bit) quantization.
4//!
5//! Bridges `BinaryCodec` (which implements `VectorCodec` with associated types)
6//! into the object-safe `RerankCodec` trait used by the rerank sidecar.
7//!
8//! Binary has no learned state: `train()` is satisfied by the default no-op.
9
10use nodedb_codec::vector_quant::layout::UnifiedQuantizedVectorRef;
11
12use crate::{
13    quantize::binary_codec::BinaryCodec,
14    rerank::codec::{CodecName, PreparedQuery, RerankCodec},
15    rerank::types::RerankError,
16};
17
18// ── packed_bits_len helper ────────────────────────────────────────────────────
19
/// Number of bytes needed to store `dim` sign bits at one bit per dimension.
///
/// Equals `ceil(dim / 8)`: each complete group of eight dimensions fills one
/// byte, and any leftover dimensions occupy one additional byte.
#[inline]
fn binary_packed_bits_len(dim: usize) -> usize {
    let full_bytes = dim / 8;
    let has_partial_byte = dim % 8 != 0;
    full_bytes + usize::from(has_partial_byte)
}
25
26// ── BinaryRerank ──────────────────────────────────────────────────────────────
27
28/// Object-safe `RerankCodec` wrapper around `BinaryCodec`.
29///
30/// Binary has no learned parameters. All instances with the same `dim` are
31/// equivalent. `train()` is the default no-op.
32pub struct BinaryRerank {
33    codec: BinaryCodec,
34    dim: usize,
35}
36
37impl BinaryRerank {
38    /// Create a binary rerank codec for vectors of length `dim`.
39    pub fn new(dim: usize) -> Self {
40        Self {
41            codec: BinaryCodec { dim },
42            dim,
43        }
44    }
45}
46
47impl RerankCodec for BinaryRerank {
48    /// Encode a full-precision vector to binary sign bits.
49    ///
50    /// The serialized form is the raw `UnifiedQuantizedVector` buffer
51    /// (`as_bytes()`): 32-byte `QuantHeader` followed by `ceil(dim/8)` bytes
52    /// of packed sign bits.
53    fn encode(&self, v: &[f32]) -> Result<Vec<u8>, RerankError> {
54        if v.len() != self.dim {
55            return Err(RerankError::BadInput(format!(
56                "binary encode: vector len {} != codec dim {}",
57                v.len(),
58                self.dim
59            )));
60        }
61        use nodedb_codec::vector_quant::codec::VectorCodec as _;
62        let quantized = self.codec.encode(v);
63        Ok(quantized.as_ref().as_bytes().to_vec())
64    }
65
66    /// Prepare the query for repeated distance calls.
67    ///
68    /// Binary encodes both the query and candidates to sign bits and computes
69    /// Hamming distance. The prepared form is `PreparedQuery::Bytes` holding
70    /// the packed query bits.
71    fn prepare_query(&self, q: &[f32]) -> Result<PreparedQuery, RerankError> {
72        if q.len() != self.dim {
73            return Err(RerankError::BadInput(format!(
74                "binary prepare_query: query len {} != codec dim {}",
75                q.len(),
76                self.dim
77            )));
78        }
79        use nodedb_codec::vector_quant::codec::VectorCodec as _;
80        let query_bits = self.codec.prepare_query(q);
81        Ok(PreparedQuery::Bytes(query_bits))
82    }
83
84    /// Compute Hamming distance from a prepared query to a binary-encoded
85    /// candidate.
86    fn distance_prepared(
87        &self,
88        prepared: &PreparedQuery,
89        encoded: &[u8],
90    ) -> Result<f32, RerankError> {
91        let q_bits = match prepared {
92            PreparedQuery::Bytes(b) => b,
93            _ => {
94                return Err(RerankError::BadInput(
95                    "binary distance: expected PreparedQuery::Bytes".to_string(),
96                ));
97            }
98        };
99
100        let packed_len = binary_packed_bits_len(self.dim);
101        let uqv_ref = UnifiedQuantizedVectorRef::from_bytes(encoded, packed_len).map_err(|e| {
102            RerankError::BadInput(format!(
103                "binary distance: failed to parse encoded bytes: {e}"
104            ))
105        })?;
106
107        let packed = uqv_ref.packed_bits();
108        // Compute Hamming distance directly via the public helper.
109        let dist = crate::quantize::binary::hamming_distance(q_bits, packed) as f32;
110        Ok(dist)
111    }
112
113    fn name(&self) -> CodecName {
114        CodecName::Binary
115    }
116
117    /// Serialize binary codec state.
118    ///
119    /// Format: `[NDBIN\0 (6 bytes)][version: u8 = 1][dim: u32 LE (4 bytes)]` — 11 bytes total.
120    /// `dim` is stored so `rerank_codec_from_bytes` can reconstruct a stateless `BinaryRerank`.
121    fn to_bytes(&self) -> Result<Vec<u8>, RerankError> {
122        let mut buf = Vec::with_capacity(11);
123        buf.extend_from_slice(b"NDBIN\0");
124        buf.push(1u8); // version
125        buf.extend_from_slice(&(self.dim as u32).to_le_bytes());
126        Ok(buf)
127    }
128
129    // train() is the default no-op — Binary has no learned state.
130}
131
132// ── Tests ─────────────────────────────────────────────────────────────────────
133
#[cfg(test)]
mod tests {
    use super::*;

    /// Dimensionality shared by every test vector (packs into two bytes).
    const DIM: usize = 16;

    /// Build a `DIM`-long vector whose components all equal `value`.
    fn uniform(value: f32) -> Vec<f32> {
        vec![value; DIM]
    }

    /// Encoding one vector and querying with another yields a usable distance.
    #[test]
    fn round_trip_returns_finite_distance() {
        let rerank = BinaryRerank::new(DIM);

        let candidate = rerank.encode(&uniform(1.0)).expect("encode v1");
        let query = rerank
            .prepare_query(&uniform(-1.0))
            .expect("prepare_query v2");
        let dist = rerank
            .distance_prepared(&query, &candidate)
            .expect("distance_prepared");

        assert!(dist.is_finite(), "expected finite distance, got {dist}");
        assert!(dist >= 0.0, "expected non-negative distance, got {dist}");
    }

    /// Vectors with every sign flipped differ in all `DIM` bits.
    #[test]
    fn opposite_vectors_have_max_distance() {
        let rerank = BinaryRerank::new(DIM);

        let candidate = rerank.encode(&uniform(1.0)).expect("encode pos");
        let query = rerank
            .prepare_query(&uniform(-1.0))
            .expect("prepare_query neg");
        let dist = rerank
            .distance_prepared(&query, &candidate)
            .expect("distance_prepared");

        let gap = (dist - DIM as f32).abs();
        assert!(
            gap < f32::EPSILON,
            "opposite vectors should have Hamming distance == dim ({DIM}), got {dist}"
        );
    }

    /// A vector compared against itself differs in no bits.
    #[test]
    fn identical_vectors_zero_distance() {
        let rerank = BinaryRerank::new(DIM);
        let sample = uniform(1.0);

        let candidate = rerank.encode(&sample).expect("encode");
        let query = rerank.prepare_query(&sample).expect("prepare_query");
        let dist = rerank
            .distance_prepared(&query, &candidate)
            .expect("distance_prepared");

        assert!(
            dist < f32::EPSILON,
            "identical vectors must have zero Hamming distance, got {dist}"
        );
    }

    /// Passing a non-`Bytes` prepared query is rejected with a message that
    /// names the expected variant.
    #[test]
    fn wrong_prepared_query_variant_returns_bad_input() {
        let rerank = BinaryRerank::new(DIM);
        let candidate = rerank.encode(&uniform(1.0)).expect("encode");
        let raw_query = PreparedQuery::Raw(vec![0.0f32; DIM]);

        let outcome = rerank.distance_prepared(&raw_query, &candidate);
        assert!(outcome.is_err(), "expected BadInput error");

        let msg = outcome.unwrap_err().to_string();
        assert!(
            msg.contains("Bytes"),
            "error message should mention Bytes, got: {msg}"
        );
    }

    /// The codec reports itself as the binary codec.
    #[test]
    fn name_returns_binary() {
        assert_eq!(BinaryRerank::new(DIM).name(), CodecName::Binary);
    }

    /// Encoding a vector one component too long fails.
    #[test]
    fn wrong_dim_encode_returns_error() {
        let rerank = BinaryRerank::new(DIM);
        let oversized = vec![0.0f32; DIM + 1];
        assert!(rerank.encode(&oversized).is_err());
    }
}