hermes_core/structures/vector/quantization/
rabitq.rs1use rand::prelude::*;
14use serde::{Deserialize, Serialize};
15
16use super::super::ivf::cluster::QuantizedCode;
17use super::Quantizer;
18
19#[cfg(target_arch = "aarch64")]
20#[allow(unused_imports)]
21use std::arch::aarch64::*;
22
23#[cfg(target_arch = "x86_64")]
24#[allow(unused_imports)]
25use std::arch::x86_64::*;
26
27#[derive(Debug, Clone, Serialize, Deserialize)]
29pub struct RaBitQConfig {
30 pub dim: usize,
32 pub query_bits: u8,
34 pub seed: u64,
36}
37
38impl RaBitQConfig {
39 pub fn new(dim: usize) -> Self {
40 Self {
41 dim,
42 query_bits: 4,
43 seed: 42,
44 }
45 }
46
47 pub fn with_seed(mut self, seed: u64) -> Self {
48 self.seed = seed;
49 self
50 }
51}
52
53#[derive(Debug, Clone, Serialize, Deserialize)]
55pub struct QuantizedVector {
56 pub bits: Vec<u8>,
58 pub dist_to_centroid: f32,
60 pub self_dot: f32,
62 pub popcount: u32,
64}
65
66impl QuantizedCode for QuantizedVector {
67 fn size_bytes(&self) -> usize {
68 self.bits.len() + 4 + 4 + 4 }
70}
71
/// A query prepared by `RaBitQCodebook::prepare_query` for scoring many
/// [`QuantizedVector`] codes.
#[derive(Debug, Clone)]
pub struct QuantizedQuery {
    /// Per-dimension 4-bit levels, two per byte (even dimension in the low nibble).
    pub quantized: Vec<u8>,
    /// Norm of the query after subtracting the centroid (or of the raw query when no
    /// centroid was given).
    pub dist_to_centroid: f32,
    /// Value represented by level 0 (the smallest transformed coordinate).
    pub lower: f32,
    /// Value span between level 0 and level 15.
    pub width: f32,
    /// Sum of all per-dimension levels.
    pub sum: u32,
    /// One 16-entry table per group of four dimensions; entry `p` is the sum of the
    /// levels of the dimensions whose bit is set in the 4-bit pattern `p`.
    pub luts: Vec<[u16; 16]>,
}
88
89#[derive(Debug, Clone, Serialize, Deserialize)]
93pub struct RaBitQCodebook {
94 pub config: RaBitQConfig,
96 pub random_signs: Vec<i8>,
98 pub random_perm: Vec<u32>,
100 pub version: u64,
102}
103
104impl RaBitQCodebook {
105 pub fn new(config: RaBitQConfig) -> Self {
107 let dim = config.dim;
108 let mut rng = rand::rngs::StdRng::seed_from_u64(config.seed);
109
110 let random_signs: Vec<i8> = (0..dim)
112 .map(|_| if rng.random::<bool>() { 1 } else { -1 })
113 .collect();
114
115 let mut random_perm: Vec<u32> = (0..dim as u32).collect();
117 for i in (1..dim).rev() {
118 let j = rng.random_range(0..=i);
119 random_perm.swap(i, j);
120 }
121
122 let version = std::time::SystemTime::now()
123 .duration_since(std::time::UNIX_EPOCH)
124 .unwrap_or_default()
125 .as_millis() as u64;
126
127 Self {
128 config,
129 random_signs,
130 random_perm,
131 version,
132 }
133 }
134
135 pub fn encode(&self, vector: &[f32], centroid: Option<&[f32]>) -> QuantizedVector {
139 let dim = self.config.dim;
140
141 let centered: Vec<f32> = if let Some(c) = centroid {
143 vector.iter().zip(c).map(|(&v, &c)| v - c).collect()
144 } else {
145 vector.to_vec()
146 };
147
148 let norm: f32 = centered.iter().map(|x| x * x).sum::<f32>().sqrt();
149 let dist_to_centroid = norm;
150
151 let normalized: Vec<f32> = if norm > 1e-10 {
153 centered.iter().map(|x| x / norm).collect()
154 } else {
155 centered
156 };
157
158 let transformed: Vec<f32> = (0..dim)
160 .map(|i| {
161 let src_idx = self.random_perm[i] as usize;
162 normalized[src_idx] * self.random_signs[src_idx] as f32
163 })
164 .collect();
165
166 let num_bytes = dim.div_ceil(8);
168 let mut bits = vec![0u8; num_bytes];
169 let mut popcount = 0u32;
170
171 for i in 0..dim {
172 if transformed[i] >= 0.0 {
173 bits[i / 8] |= 1 << (i % 8);
174 popcount += 1;
175 }
176 }
177
178 let scale = 1.0 / (dim as f32).sqrt();
180 let mut self_dot = 0.0f32;
181 for i in 0..dim {
182 let o_bar_i = if (bits[i / 8] >> (i % 8)) & 1 == 1 {
183 scale
184 } else {
185 -scale
186 };
187 self_dot += transformed[i] * o_bar_i;
188 }
189
190 QuantizedVector {
191 bits,
192 dist_to_centroid,
193 self_dot,
194 popcount,
195 }
196 }
197
198 pub fn prepare_query(&self, query: &[f32], centroid: Option<&[f32]>) -> QuantizedQuery {
200 let dim = self.config.dim;
201
202 let centered: Vec<f32> = if let Some(c) = centroid {
204 query.iter().zip(c).map(|(&v, &c)| v - c).collect()
205 } else {
206 query.to_vec()
207 };
208
209 let norm: f32 = centered.iter().map(|x| x * x).sum::<f32>().sqrt();
210 let dist_to_centroid = norm;
211
212 let normalized: Vec<f32> = if norm > 1e-10 {
214 centered.iter().map(|x| x / norm).collect()
215 } else {
216 centered
217 };
218
219 let transformed: Vec<f32> = (0..dim)
221 .map(|i| {
222 let src_idx = self.random_perm[i] as usize;
223 normalized[src_idx] * self.random_signs[src_idx] as f32
224 })
225 .collect();
226
227 let min_val = transformed.iter().cloned().fold(f32::INFINITY, f32::min);
229 let max_val = transformed
230 .iter()
231 .cloned()
232 .fold(f32::NEG_INFINITY, f32::max);
233 let lower = min_val;
234 let width = if max_val > min_val {
235 max_val - min_val
236 } else {
237 1.0
238 };
239
240 let quantized_vals: Vec<u8> = transformed
242 .iter()
243 .map(|&x| {
244 let normalized = (x - lower) / width;
245 (normalized * 15.0).round().clamp(0.0, 15.0) as u8
246 })
247 .collect();
248
249 let num_bytes = dim.div_ceil(2);
251 let mut quantized = vec![0u8; num_bytes];
252 for i in 0..dim {
253 if i % 2 == 0 {
254 quantized[i / 2] |= quantized_vals[i];
255 } else {
256 quantized[i / 2] |= quantized_vals[i] << 4;
257 }
258 }
259
260 let sum: u32 = quantized_vals.iter().map(|&x| x as u32).sum();
262
263 let num_luts = dim.div_ceil(4);
265 let mut luts = vec![[0u16; 16]; num_luts];
266
267 for (lut_idx, lut) in luts.iter_mut().enumerate() {
268 let base_dim = lut_idx * 4;
269 for pattern in 0u8..16 {
270 let mut dot = 0u16;
271 for bit in 0..4 {
272 let dim_idx = base_dim + bit;
273 if dim_idx < dim && (pattern >> bit) & 1 == 1 {
274 dot += quantized_vals[dim_idx] as u16;
275 }
276 }
277 lut[pattern as usize] = dot;
278 }
279 }
280
281 QuantizedQuery {
282 quantized,
283 dist_to_centroid,
284 lower,
285 width,
286 sum,
287 luts,
288 }
289 }
290
291 pub fn estimate_distance(&self, query: &QuantizedQuery, code: &QuantizedVector) -> f32 {
293 let dim = self.config.dim;
294
295 let dot_sum = lut_dot_product_simd(&code.bits, &query.luts);
297
298 let scale = 1.0 / (dim as f32).sqrt();
299
300 let sum_positive = code.popcount as f32 * query.lower + dot_sum as f32 * query.width / 15.0;
302 let sum_all = dim as f32 * query.lower + query.sum as f32 * query.width / 15.0;
303
304 let q_obar_dot = scale * (2.0 * sum_positive - sum_all);
306
307 let q_o_estimate = if code.self_dot.abs() > 1e-6 {
309 q_obar_dot / code.self_dot
310 } else {
311 q_obar_dot
312 };
313
314 let q_o_clamped = q_o_estimate.clamp(-1.0, 1.0);
316
317 let dist_sq = code.dist_to_centroid * code.dist_to_centroid
319 + query.dist_to_centroid * query.dist_to_centroid
320 - 2.0 * code.dist_to_centroid * query.dist_to_centroid * q_o_clamped;
321
322 dist_sq.max(0.0)
323 }
324
325 pub fn size_bytes(&self) -> usize {
327 self.random_signs.len() + self.random_perm.len() * 4 + 64
328 }
329
330 pub fn estimated_memory_bytes(&self) -> usize {
332 self.size_bytes()
333 }
334}
335
336impl Quantizer for RaBitQCodebook {
337 type Code = QuantizedVector;
338 type Config = RaBitQConfig;
339 type QueryData = QuantizedQuery;
340
341 fn encode(&self, vector: &[f32], centroid: Option<&[f32]>) -> Self::Code {
342 self.encode(vector, centroid)
343 }
344
345 fn prepare_query(&self, query: &[f32], centroid: Option<&[f32]>) -> Self::QueryData {
346 self.prepare_query(query, centroid)
347 }
348
349 fn compute_distance(&self, query_data: &Self::QueryData, code: &Self::Code) -> f32 {
350 self.estimate_distance(query_data, code)
351 }
352
353 fn size_bytes(&self) -> usize {
354 self.size_bytes()
355 }
356}
357
358#[inline]
364fn lut_dot_product_simd(bits: &[u8], luts: &[[u16; 16]]) -> u32 {
365 #[cfg(target_arch = "aarch64")]
366 {
367 if let Some(result) = lut_dot_product_neon(bits, luts) {
368 return result;
369 }
370 }
371
372 #[cfg(target_arch = "x86_64")]
373 {
374 if is_x86_feature_detected!("ssse3") {
375 unsafe {
376 if let Some(result) = lut_dot_product_ssse3(bits, luts) {
377 return result;
378 }
379 }
380 }
381 }
382
383 lut_dot_product_scalar(bits, luts)
384}
385
/// Portable LUT dot product.
///
/// Group `g` covers code bits `4g..4g+4`, which sit in byte `g / 2`: the low nibble
/// for even `g` and the high nibble for odd `g`. Each nibble indexes that group's
/// 16-entry table. Missing code bytes read as 0 and therefore select entry 0.
#[inline]
fn lut_dot_product_scalar(bits: &[u8], luts: &[[u16; 16]]) -> u32 {
    luts.iter()
        .enumerate()
        .map(|(group, table)| {
            let byte = bits.get(group / 2).copied().unwrap_or(0);
            let nibble = if group % 2 == 0 { byte & 0x0F } else { byte >> 4 };
            u32::from(table[usize::from(nibble)])
        })
        .sum()
}
410
/// aarch64 LUT dot product.
///
/// Declines (returns `None`) for fewer than 8 LUTs so the caller falls back to the
/// scalar path. Despite the name, no NEON intrinsics are used. Each code byte drives
/// two consecutive tables: its low nibble indexes the even table and its high nibble
/// the odd one. Missing code bytes read as 0.
#[cfg(target_arch = "aarch64")]
#[inline]
fn lut_dot_product_neon(bits: &[u8], luts: &[[u16; 16]]) -> Option<u32> {
    if luts.len() < 8 {
        return None;
    }

    let mut acc = 0u32;
    for (byte_idx, pair) in luts.chunks(2).enumerate() {
        let byte = bits.get(byte_idx).copied().unwrap_or(0);
        acc += u32::from(pair[0][usize::from(byte & 0x0F)]);
        if let Some(odd) = pair.get(1) {
            acc += u32::from(odd[usize::from(byte >> 4)]);
        }
    }

    Some(acc)
}
475
476#[cfg(target_arch = "x86_64")]
478#[target_feature(enable = "ssse3")]
479#[inline]
480unsafe fn lut_dot_product_ssse3(bits: &[u8], luts: &[[u16; 16]]) -> Option<u32> {
481 if luts.len() < 8 {
482 return None;
483 }
484 Some(lut_dot_product_scalar(bits, luts))
485}
486
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_rabitq_codebook_basic() {
        let config = RaBitQConfig::new(128);
        let codebook = RaBitQCodebook::new(config);

        assert_eq!(codebook.random_signs.len(), 128);
        assert_eq!(codebook.random_perm.len(), 128);
    }

    #[test]
    fn test_encode_decode() {
        let config = RaBitQConfig::new(64);
        let codebook = RaBitQCodebook::new(config);

        let vector: Vec<f32> = (0..64).map(|i| (i as f32 - 32.0) / 32.0).collect();
        let code = codebook.encode(&vector, None);

        assert_eq!(code.bits.len(), 8);
        assert!(code.dist_to_centroid > 0.0);
    }

    #[test]
    fn test_distance_estimation() {
        let config = RaBitQConfig::new(64);
        let codebook = RaBitQCodebook::new(config);

        let mut rng = rand::rngs::StdRng::seed_from_u64(42);
        let v1: Vec<f32> = (0..64).map(|_| rng.random::<f32>() - 0.5).collect();
        let v2: Vec<f32> = (0..64).map(|_| rng.random::<f32>() - 0.5).collect();

        let code = codebook.encode(&v1, None);
        let query = codebook.prepare_query(&v2, None);

        let estimated = codebook.estimate_distance(&query, &code);
        assert!(estimated >= 0.0);
    }

    #[test]
    fn test_quantizer_trait() {
        let config = RaBitQConfig::new(32);
        let codebook = RaBitQCodebook::new(config);

        let vector: Vec<f32> = (0..32).map(|i| i as f32 / 32.0).collect();
        let query: Vec<f32> = (0..32).map(|i| (31 - i) as f32 / 32.0).collect();

        let code = Quantizer::encode(&codebook, &vector, None);
        let query_data = Quantizer::prepare_query(&codebook, &query, None);
        let dist = Quantizer::compute_distance(&codebook, &query_data, &code);

        assert!(dist >= 0.0);
    }

    /// The estimate should track the exact squared L2 distance. A vector against its
    /// own code should come out near zero. An unrelated vector should land within a
    /// loose relative error; the 50% margin is several times the expected RaBitQ error
    /// at dim = 64.
    #[test]
    fn test_distance_estimate_tracks_exact() {
        let dim = 64;
        let codebook = RaBitQCodebook::new(RaBitQConfig::new(dim));

        let mut rng = rand::rngs::StdRng::seed_from_u64(7);
        let v1: Vec<f32> = (0..dim).map(|_| rng.random::<f32>() - 0.5).collect();
        let v2: Vec<f32> = (0..dim).map(|_| rng.random::<f32>() - 0.5).collect();

        let code = codebook.encode(&v1, None);

        let norm_sq: f32 = v1.iter().map(|x| x * x).sum();
        let self_est = codebook.estimate_distance(&codebook.prepare_query(&v1, None), &code);
        assert!(
            self_est < 0.25 * norm_sq,
            "self distance {self_est} vs |v1|^2 {norm_sq}"
        );

        let exact: f32 = v1
            .iter()
            .zip(&v2)
            .map(|(x, y)| (x - y) * (x - y))
            .sum();
        let est = codebook.estimate_distance(&codebook.prepare_query(&v2, None), &code);
        assert!(
            (est - exact).abs() < 0.5 * exact,
            "estimated {est}, exact {exact}"
        );
    }

    /// Whatever accelerated path the dispatcher picks must agree with the scalar one,
    /// including odd LUT counts and counts below the accelerated-path threshold.
    #[test]
    fn test_lut_dispatch_matches_scalar() {
        let mut rng = rand::rngs::StdRng::seed_from_u64(3);
        for num_luts in [0usize, 1, 7, 8, 9, 33] {
            let luts: Vec<[u16; 16]> = (0..num_luts)
                .map(|_| {
                    let mut lut = [0u16; 16];
                    for entry in lut.iter_mut() {
                        *entry = rng.random_range(0..=60);
                    }
                    lut
                })
                .collect();
            let bits: Vec<u8> = (0..num_luts.div_ceil(2))
                .map(|_| rng.random::<u8>())
                .collect();

            assert_eq!(
                lut_dot_product_simd(&bits, &luts),
                lut_dot_product_scalar(&bits, &luts),
                "mismatch for {num_luts} luts"
            );
        }
    }
}