nodedb_vector/rerank/codecs/
binary.rs1use nodedb_codec::vector_quant::layout::UnifiedQuantizedVectorRef;
11
12use crate::{
13 quantize::binary_codec::BinaryCodec,
14 rerank::codec::{CodecName, PreparedQuery, RerankCodec},
15 rerank::types::RerankError,
16};
17
/// Number of bytes needed to store `dim` bits packed eight per byte.
///
/// A trailing partial byte counts as a full byte (ceiling division), so
/// `0 -> 0`, `1..=8 -> 1`, `9..=16 -> 2`, and so on.
#[inline]
fn binary_packed_bits_len(dim: usize) -> usize {
    let whole_bytes = dim / 8;
    let has_partial_byte = dim % 8 != 0;
    whole_bytes + usize::from(has_partial_byte)
}
25
26pub struct BinaryRerank {
33 codec: BinaryCodec,
34 dim: usize,
35}
36
37impl BinaryRerank {
38 pub fn new(dim: usize) -> Self {
40 Self {
41 codec: BinaryCodec { dim },
42 dim,
43 }
44 }
45}
46
47impl RerankCodec for BinaryRerank {
48 fn encode(&self, v: &[f32]) -> Result<Vec<u8>, RerankError> {
54 if v.len() != self.dim {
55 return Err(RerankError::BadInput(format!(
56 "binary encode: vector len {} != codec dim {}",
57 v.len(),
58 self.dim
59 )));
60 }
61 use nodedb_codec::vector_quant::codec::VectorCodec as _;
62 let quantized = self.codec.encode(v);
63 Ok(quantized.as_ref().as_bytes().to_vec())
64 }
65
66 fn prepare_query(&self, q: &[f32]) -> Result<PreparedQuery, RerankError> {
72 if q.len() != self.dim {
73 return Err(RerankError::BadInput(format!(
74 "binary prepare_query: query len {} != codec dim {}",
75 q.len(),
76 self.dim
77 )));
78 }
79 use nodedb_codec::vector_quant::codec::VectorCodec as _;
80 let query_bits = self.codec.prepare_query(q);
81 Ok(PreparedQuery::Bytes(query_bits))
82 }
83
84 fn distance_prepared(
87 &self,
88 prepared: &PreparedQuery,
89 encoded: &[u8],
90 ) -> Result<f32, RerankError> {
91 let q_bits = match prepared {
92 PreparedQuery::Bytes(b) => b,
93 _ => {
94 return Err(RerankError::BadInput(
95 "binary distance: expected PreparedQuery::Bytes".to_string(),
96 ));
97 }
98 };
99
100 let packed_len = binary_packed_bits_len(self.dim);
101 let uqv_ref = UnifiedQuantizedVectorRef::from_bytes(encoded, packed_len).map_err(|e| {
102 RerankError::BadInput(format!(
103 "binary distance: failed to parse encoded bytes: {e}"
104 ))
105 })?;
106
107 let packed = uqv_ref.packed_bits();
108 let dist = crate::quantize::binary::hamming_distance(q_bits, packed) as f32;
110 Ok(dist)
111 }
112
113 fn name(&self) -> CodecName {
114 CodecName::Binary
115 }
116
117 fn to_bytes(&self) -> Result<Vec<u8>, RerankError> {
122 let mut buf = Vec::with_capacity(11);
123 buf.extend_from_slice(b"NDBIN\0");
124 buf.push(1u8); buf.extend_from_slice(&(self.dim as u32).to_le_bytes());
126 Ok(buf)
127 }
128
129 }
131
#[cfg(test)]
mod tests {
    use super::*;

    const DIM: usize = 16;

    fn all_pos() -> Vec<f32> {
        vec![1.0f32; DIM]
    }

    fn all_neg() -> Vec<f32> {
        vec![-1.0f32; DIM]
    }

    #[test]
    fn round_trip_returns_finite_distance() {
        let codec = BinaryRerank::new(DIM);
        let v1 = all_pos();
        let v2 = all_neg();

        let enc = codec.encode(&v1).expect("encode v1");
        let prepared = codec.prepare_query(&v2).expect("prepare_query v2");
        let dist = codec
            .distance_prepared(&prepared, &enc)
            .expect("distance_prepared");
        assert!(dist.is_finite(), "expected finite distance, got {dist}");
        assert!(dist >= 0.0, "expected non-negative distance, got {dist}");
    }

    #[test]
    fn opposite_vectors_have_max_distance() {
        let codec = BinaryRerank::new(DIM);
        let pos = all_pos();
        let neg = all_neg();

        let enc = codec.encode(&pos).expect("encode pos");
        let prepared = codec.prepare_query(&neg).expect("prepare_query neg");
        let dist = codec
            .distance_prepared(&prepared, &enc)
            .expect("distance_prepared");
        assert!(
            (dist - DIM as f32).abs() < f32::EPSILON,
            "opposite vectors should have Hamming distance == dim ({DIM}), got {dist}"
        );
    }

    #[test]
    fn identical_vectors_zero_distance() {
        let codec = BinaryRerank::new(DIM);
        let v = all_pos();

        let enc = codec.encode(&v).expect("encode");
        let prepared = codec.prepare_query(&v).expect("prepare_query");
        let dist = codec
            .distance_prepared(&prepared, &enc)
            .expect("distance_prepared");
        assert!(
            dist < f32::EPSILON,
            "identical vectors must have zero Hamming distance, got {dist}"
        );
    }

    #[test]
    fn wrong_prepared_query_variant_returns_bad_input() {
        let codec = BinaryRerank::new(DIM);
        let v = all_pos();
        let enc = codec.encode(&v).expect("encode");
        let bad_prepared = PreparedQuery::Raw(vec![0.0f32; DIM]);

        let result = codec.distance_prepared(&bad_prepared, &enc);
        assert!(result.is_err(), "expected BadInput error");
        let msg = format!("{}", result.unwrap_err());
        assert!(
            msg.contains("Bytes"),
            "error message should mention Bytes, got: {msg}"
        );
    }

    #[test]
    fn name_returns_binary() {
        let codec = BinaryRerank::new(DIM);
        assert_eq!(codec.name(), CodecName::Binary);
    }

    #[test]
    fn wrong_dim_encode_returns_error() {
        let codec = BinaryRerank::new(DIM);
        let bad = vec![0.0f32; DIM + 1];
        assert!(codec.encode(&bad).is_err());
    }

    // `prepare_query` performs its own length check, separate from `encode`.
    #[test]
    fn wrong_dim_prepare_query_returns_error() {
        let codec = BinaryRerank::new(DIM);
        let bad = vec![0.0f32; DIM - 1];
        assert!(codec.prepare_query(&bad).is_err());
    }

    // Pins the serialized header: magic, version byte, little-endian u32 dim.
    #[test]
    fn to_bytes_header_layout() {
        let codec = BinaryRerank::new(DIM);
        let bytes = codec.to_bytes().expect("to_bytes");
        assert_eq!(bytes.len(), 11);
        assert_eq!(&bytes[..6], b"NDBIN\0");
        assert_eq!(bytes[6], 1u8);
        assert_eq!(&bytes[7..], &(DIM as u32).to_le_bytes());
    }
}