1use error_forge::ForgeError;
21use iqdb_types::{DistanceMetric, IqdbError, Result};
22
23use crate::code::BqCode;
24use crate::traits::Quantizer;
25use crate::validate::{dim_eq, finite_non_empty, training_set};
26
/// Number of bits in one packed code word; bit codes are stored as `u64` words.
const BITS_PER_WORD: usize = u64::BITS as usize;
28
/// Calibration state produced by training: one threshold per dimension.
#[derive(Debug, Clone, PartialEq)]
struct BqCalibration {
    /// Per-dimension mean of the training vectors. A component is encoded as
    /// a set bit when it is greater than or equal to its mean.
    means: Vec<f32>,
}
35
/// Binary quantizer: encodes every dimension of a vector as a single bit,
/// set when the component is at or above that dimension's training mean.
///
/// A freshly constructed quantizer is untrained. Until `train` succeeds,
/// `quantize`, `dequantize` and `distance` fail with
/// `IqdbError::InvalidConfig`.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct BinaryQuantizer {
    /// `None` while untrained; holds the per-dimension means after training.
    calibration: Option<BqCalibration>,
}
64
65impl BinaryQuantizer {
66 #[must_use]
78 pub fn new() -> Self {
79 Self { calibration: None }
80 }
81
82 #[must_use]
95 pub fn dim(&self) -> Option<usize> {
96 self.calibration.as_ref().map(|c| c.means.len())
97 }
98
99 fn calibration(&self) -> Result<&BqCalibration> {
100 self.calibration.as_ref().ok_or(IqdbError::InvalidConfig {
101 reason: "BinaryQuantizer has not been trained",
102 })
103 }
104}
105
106impl Quantizer for BinaryQuantizer {
107 type Quantized = BqCode;
108
109 #[tracing::instrument(
110 level = "info",
111 skip_all,
112 fields(quantizer = "bq", training_size = vectors.len()),
113 )]
114 fn train(&mut self, vectors: &[&[f32]]) -> Result<()> {
115 let dim = training_set(vectors).inspect_err(|err: &IqdbError| {
116 tracing::error!(
117 error.kind = err.kind(),
118 error.reason = err.caption(),
119 "binary quantizer training failed",
120 );
121 })?;
122 let mut sums = vec![0.0_f64; dim];
123 for v in vectors {
124 for (i, &x) in v.iter().enumerate() {
125 sums[i] += f64::from(x);
126 }
127 }
128 let n = vectors.len() as f64;
129 let means: Vec<f32> = sums.iter().map(|s| (s / n) as f32).collect();
130 if means.iter().any(|m| !m.is_finite()) {
135 let err = IqdbError::InvalidVector;
136 tracing::error!(
137 error.kind = err.kind(),
138 error.reason = err.caption(),
139 "binary quantizer training failed: non-finite mean",
140 );
141 return Err(err);
142 }
143 self.calibration = Some(BqCalibration { means });
144 Ok(())
145 }
146
147 fn quantize(&self, vector: &[f32]) -> Result<Self::Quantized> {
148 let cal = self.calibration()?;
149 finite_non_empty(vector)?;
150 dim_eq(cal.means.len(), vector.len())?;
151 Ok(BqCode {
152 words: pack_bits(vector, &cal.means),
153 dim: vector.len(),
154 })
155 }
156
157 fn dequantize(&self, quantized: &Self::Quantized) -> Result<Vec<f32>> {
158 let cal = self.calibration()?;
159 dim_eq(cal.means.len(), quantized.dim)?;
160 let mut out = Vec::with_capacity(quantized.dim);
161 for i in 0..quantized.dim {
162 let word = quantized.words[i / BITS_PER_WORD];
163 let bit = (word >> (i % BITS_PER_WORD)) & 1;
164 out.push(if bit == 1 { 1.0_f32 } else { -1.0_f32 });
165 }
166 Ok(out)
167 }
168
169 fn distance(
170 &self,
171 query: &[f32],
172 quantized: &Self::Quantized,
173 metric: DistanceMetric,
174 ) -> Result<f32> {
175 let cal = self.calibration()?;
176 finite_non_empty(query)?;
177 dim_eq(cal.means.len(), query.len())?;
178 dim_eq(cal.means.len(), quantized.dim)?;
179 if metric != DistanceMetric::Hamming {
180 return Err(IqdbError::InvalidMetric);
181 }
182 let query_words = pack_bits(query, &cal.means);
185 let mut diff: u32 = 0;
186 for (q, c) in query_words.iter().zip(quantized.words.iter()) {
187 diff = diff.saturating_add((q ^ c).count_ones());
188 }
189 Ok(diff as f32)
190 }
191}
192
193fn pack_bits(vector: &[f32], means: &[f32]) -> Vec<u64> {
197 let dim = vector.len();
198 let words = dim.div_ceil(BITS_PER_WORD);
199 let mut out = vec![0_u64; words];
200 for i in 0..dim {
201 if vector[i] >= means[i] {
202 out[i / BITS_PER_WORD] |= 1_u64 << (i % BITS_PER_WORD);
203 }
204 }
205 out
206}
207
// Unit tests for `BinaryQuantizer`: error paths first, then Hamming-distance
// behaviour and bit-packing layout.
#[cfg(test)]
mod tests {
    // Tests unwrap freely: a panic is the desired failure mode here.
    #![allow(clippy::unwrap_used)]

    use super::*;
    use iqdb_types::{DistanceMetric, IqdbError};

    // Quantizer trained on [0, 1, 2] and [2, 1, 0], so every mean is 1.0 and
    // the trained dimension is 3.
    fn trained_unit() -> BinaryQuantizer {
        let mut bq = BinaryQuantizer::new();
        bq.train(&[&[0.0_f32, 1.0, 2.0][..], &[2.0_f32, 1.0, 0.0][..]])
            .unwrap();
        bq
    }

    // An untrained quantizer must refuse to quantize.
    #[test]
    fn quantize_before_train_returns_invalid_config() {
        let bq = BinaryQuantizer::new();
        let err = bq.quantize(&[0.5_f32, 0.5]).unwrap_err();
        assert!(
            matches!(err, IqdbError::InvalidConfig { .. }),
            "expected InvalidConfig, got {err:?}",
        );
    }

    // The "not trained" check comes before any input validation in `distance`.
    #[test]
    fn distance_before_train_returns_invalid_config() {
        let bq = BinaryQuantizer::new();
        let code = BqCode {
            words: vec![0],
            dim: 3,
        };
        let err = bq
            .distance(&[0.0_f32, 0.0, 0.0], &code, DistanceMetric::Hamming)
            .unwrap_err();
        assert!(
            matches!(err, IqdbError::InvalidConfig { .. }),
            "expected InvalidConfig, got {err:?}",
        );
    }

    // An untrained quantizer must refuse to dequantize.
    #[test]
    fn dequantize_before_train_returns_invalid_config() {
        let bq = BinaryQuantizer::new();
        let code = BqCode {
            words: vec![0],
            dim: 3,
        };
        let err = bq.dequantize(&code).unwrap_err();
        assert!(
            matches!(err, IqdbError::InvalidConfig { .. }),
            "expected InvalidConfig, got {err:?}",
        );
    }

    // Training on an empty set is rejected as a configuration error.
    #[test]
    fn train_empty_set_returns_invalid_config() {
        let mut bq = BinaryQuantizer::new();
        let empty: [&[f32]; 0] = [];
        let err = bq.train(&empty).unwrap_err();
        assert!(
            matches!(err, IqdbError::InvalidConfig { .. }),
            "expected InvalidConfig, got {err:?}",
        );
    }

    // Training vectors of differing lengths are rejected; the first vector's
    // length (3) is reported as the expected dimension.
    #[test]
    fn train_inconsistent_dim_returns_dimension_mismatch() {
        let mut bq = BinaryQuantizer::new();
        let a = [0.0_f32, 1.0, 2.0];
        let b = [1.0_f32, 0.0];
        let err = bq.train(&[&a[..], &b[..]]).unwrap_err();
        assert_eq!(
            err,
            IqdbError::DimensionMismatch {
                expected: 3,
                found: 2,
            },
        );
    }

    // A NaN component in the training data must fail training.
    #[test]
    fn train_non_finite_returns_invalid_vector() {
        let mut bq = BinaryQuantizer::new();
        let v = [1.0_f32, f32::NAN];
        assert_eq!(bq.train(&[&v[..]]).unwrap_err(), IqdbError::InvalidVector,);
    }

    // Quantizing a 2-element vector with a 3-dimensional quantizer is rejected.
    #[test]
    fn quantize_dim_mismatch_returns_dimension_mismatch() {
        let bq = trained_unit();
        let err = bq.quantize(&[0.5_f32, 0.5]).unwrap_err();
        assert_eq!(
            err,
            IqdbError::DimensionMismatch {
                expected: 3,
                found: 2,
            },
        );
    }

    // An infinite component in the input must fail quantization.
    #[test]
    fn quantize_non_finite_returns_invalid_vector() {
        let bq = trained_unit();
        let err = bq.quantize(&[0.5_f32, f32::NEG_INFINITY, 0.5]).unwrap_err();
        assert_eq!(err, IqdbError::InvalidVector);
    }

    // Only the Hamming metric is supported; every other metric is an error.
    #[test]
    fn distance_rejects_non_hamming_metrics() {
        let bq = trained_unit();
        let code = bq.quantize(&[0.5_f32, 0.5, 0.5]).unwrap();
        let q = [0.5_f32, 0.5, 0.5];
        for metric in [
            DistanceMetric::Cosine,
            DistanceMetric::DotProduct,
            DistanceMetric::Euclidean,
            DistanceMetric::Manhattan,
        ] {
            assert_eq!(
                bq.distance(&q, &code, metric).unwrap_err(),
                IqdbError::InvalidMetric,
                "metric {metric:?} must be rejected",
            );
        }
    }

    // A vector compared against its own code has distance zero.
    #[test]
    fn distance_self_consistency_is_zero() {
        let bq = trained_unit();
        // Components fall on both sides of the 1.0 means, so the code mixes
        // clear and set bits.
        let v = [0.4_f32, 1.1, 1.9];
        let code = bq.quantize(&v).unwrap();
        let d = bq.distance(&v, &code, DistanceMetric::Hamming).unwrap();
        assert_eq!(d, 0.0);
    }

    // Reference Hamming distance: popcount of the XOR of each word pair.
    fn naive_hamming(a: &[u64], b: &[u64]) -> u32 {
        a.iter()
            .zip(b.iter())
            .map(|(x, y)| (x ^ y).count_ones())
            .sum()
    }

    // `distance` must agree with the naive popcount reference.
    #[test]
    fn hamming_matches_naive_popcount_reference() {
        let mut bq = BinaryQuantizer::new();
        // 70 dimensions do not fit in one 64-bit word, so the code spans two
        // words and the second one is only partially used.
        let dim = 70;
        let a: Vec<f32> = (0..dim).map(|i| (i as f32).sin()).collect();
        let b: Vec<f32> = (0..dim).map(|i| (i as f32).cos()).collect();
        bq.train(&[&a[..], &b[..]]).unwrap();

        let query: Vec<f32> = (0..dim).map(|i| ((i as f32) * 0.5).sin()).collect();
        let code = bq.quantize(&b).unwrap();
        let d = bq.distance(&query, &code, DistanceMetric::Hamming).unwrap();

        let cal = bq.calibration.as_ref().unwrap();
        let query_words = pack_bits(&query, &cal.means);
        let expected = naive_hamming(&query_words, &code.words);
        assert_eq!(d as u32, expected);
    }

    // Bits beyond `dim` in the last word must stay zero.
    #[test]
    fn quantize_zeros_padding_bits_for_dim_not_multiple_of_64() {
        // Dimensions just below, at, and just above the word boundaries; the
        // padding check is skipped for the exact multiples of 64.
        let dims = [63_usize, 64, 65, 127, 128, 129];
        for &dim in &dims {
            let zeros = vec![0.0_f32; dim];
            // Training on all-zeros and all-ones gives means of 0.5, so
            // quantizing `ones` sets every one of the `dim` data bits.
            let ones = vec![1.0_f32; dim];
            let mut bq = BinaryQuantizer::new();
            bq.train(&[&zeros[..], &ones[..]]).unwrap();

            let code = bq.quantize(&ones).unwrap();
            assert_eq!(code.dim, dim);
            assert_eq!(code.words.len(), dim.div_ceil(BITS_PER_WORD));

            let used_in_last = dim % BITS_PER_WORD;
            if used_in_last != 0 {
                let last = *code.words.last().unwrap();
                // Mask selecting only the unused high bits of the last word.
                let padding_mask = !0_u64 << used_in_last;
                assert_eq!(
                    last & padding_mask,
                    0,
                    "dim={dim}: padding bits must be zero",
                );
            }
        }
    }
}