hermes_core/structures/vector/quantization/
rabitq.rs1use rand::prelude::*;
14use serde::{Deserialize, Serialize};
15
16use super::super::ivf::cluster::QuantizedCode;
17use super::Quantizer;
18
19#[cfg(target_arch = "aarch64")]
20#[allow(unused_imports)]
21use std::arch::aarch64::*;
22
23#[cfg(target_arch = "x86_64")]
24#[allow(unused_imports)]
25use std::arch::x86_64::*;
26
/// Configuration for a [`RaBitQCodebook`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RaBitQConfig {
    /// Number of dimensions of the vectors handled by the codebook; also the
    /// length of the codebook's sign and permutation tables.
    pub dim: usize,
    /// Bits per query dimension; [`RaBitQConfig::new`] sets this to 4.
    ///
    /// NOTE(review): nothing in the visible code reads this field —
    /// `prepare_query` always quantizes to 16 levels (0..=15). Confirm whether
    /// it is consumed elsewhere before relying on it.
    pub query_bits: u8,
    /// Seed for the RNG that generates the codebook's random signs and
    /// permutation.
    pub seed: u64,
}
37
38impl RaBitQConfig {
39 pub fn new(dim: usize) -> Self {
40 Self {
41 dim,
42 query_bits: 4,
43 seed: 42,
44 }
45 }
46
47 pub fn with_seed(mut self, seed: u64) -> Self {
48 self.seed = seed;
49 self
50 }
51}
52
/// A stored vector compressed to one sign bit per dimension, plus the scalar
/// factors `estimate_distance` needs to turn the bits back into a distance.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizedVector {
    /// Sign bits of the rotated unit vector, packed LSB-first
    /// (`dim.div_ceil(8)` bytes). Bit `i` is set when component `i` is `>= 0`.
    pub bits: Vec<u8>,
    /// Euclidean norm of the vector after subtracting the centroid (or of the
    /// raw vector when no centroid was given).
    pub dist_to_centroid: f32,
    /// Dot product between the rotated unit vector and its own
    /// `±1/sqrt(dim)` sign vector.
    pub self_dot: f32,
    /// Number of set bits in `bits`.
    pub popcount: u32,
}
65
66impl QuantizedCode for QuantizedVector {
67 fn size_bytes(&self) -> usize {
68 self.bits.len() + 4 + 4 + 4 }
70}
71
/// A query pre-processed by `RaBitQCodebook::prepare_query` so that distances
/// to many [`QuantizedVector`]s can be estimated with table lookups.
#[derive(Debug, Clone)]
pub struct QuantizedQuery {
    /// 4-bit quantization levels of the rotated query, packed two per byte:
    /// even dimensions in the low nibble, odd dimensions in the high nibble.
    pub quantized: Vec<u8>,
    /// Euclidean norm of the query after subtracting the centroid (or of the
    /// raw query when no centroid was given).
    pub dist_to_centroid: f32,
    /// Smallest rotated component; the value represented by level 0.
    pub lower: f32,
    /// Range covered by the 16 levels (`max - min`), or `1.0` when the maximum
    /// is not greater than the minimum.
    pub width: f32,
    /// Sum of all per-dimension quantization levels.
    pub sum: u32,
    /// One 16-entry table per group of four consecutive dimensions; entry `p`
    /// is the sum of the levels of the dimensions whose bit is set in `p`.
    pub luts: Vec<[u16; 16]>,
}
88
/// Random transform (sign flips plus a permutation) shared by every vector
/// and query quantized against it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RaBitQCodebook {
    /// Configuration the codebook was built from.
    pub config: RaBitQConfig,
    /// One random `+1` / `-1` per dimension.
    pub random_signs: Vec<i8>,
    /// Random permutation of `0..dim`; output dimension `i` reads input
    /// dimension `random_perm[i]`.
    pub random_perm: Vec<u32>,
    /// Creation time in milliseconds since the Unix epoch (0 if the system
    /// clock reports a time before the epoch).
    pub version: u64,
}
103
104impl RaBitQCodebook {
105 pub fn new(config: RaBitQConfig) -> Self {
107 let dim = config.dim;
108 let mut rng = rand::rngs::StdRng::seed_from_u64(config.seed);
109
110 let random_signs: Vec<i8> = (0..dim)
112 .map(|_| if rng.random::<bool>() { 1 } else { -1 })
113 .collect();
114
115 let mut random_perm: Vec<u32> = (0..dim as u32).collect();
117 for i in (1..dim).rev() {
118 let j = rng.random_range(0..=i);
119 random_perm.swap(i, j);
120 }
121
122 let version = std::time::SystemTime::now()
123 .duration_since(std::time::UNIX_EPOCH)
124 .unwrap_or_default()
125 .as_millis() as u64;
126
127 Self {
128 config,
129 random_signs,
130 random_perm,
131 version,
132 }
133 }
134
135 pub fn encode(&self, vector: &[f32], centroid: Option<&[f32]>) -> QuantizedVector {
139 let dim = self.config.dim;
140
141 let centered: Vec<f32> = if let Some(c) = centroid {
143 vector.iter().zip(c).map(|(&v, &c)| v - c).collect()
144 } else {
145 vector.to_vec()
146 };
147
148 let norm: f32 = centered.iter().map(|x| x * x).sum::<f32>().sqrt();
149 let dist_to_centroid = norm;
150
151 let normalized: Vec<f32> = if norm > 1e-10 {
153 centered.iter().map(|x| x / norm).collect()
154 } else {
155 centered
156 };
157
158 let transformed: Vec<f32> = (0..dim)
160 .map(|i| {
161 let src_idx = self.random_perm[i] as usize;
162 normalized[src_idx] * self.random_signs[src_idx] as f32
163 })
164 .collect();
165
166 let num_bytes = dim.div_ceil(8);
168 let mut bits = vec![0u8; num_bytes];
169 let mut popcount = 0u32;
170
171 for i in 0..dim {
172 if transformed[i] >= 0.0 {
173 bits[i / 8] |= 1 << (i % 8);
174 popcount += 1;
175 }
176 }
177
178 let scale = 1.0 / (dim as f32).sqrt();
180 let mut self_dot = 0.0f32;
181 for i in 0..dim {
182 let o_bar_i = if (bits[i / 8] >> (i % 8)) & 1 == 1 {
183 scale
184 } else {
185 -scale
186 };
187 self_dot += transformed[i] * o_bar_i;
188 }
189
190 QuantizedVector {
191 bits,
192 dist_to_centroid,
193 self_dot,
194 popcount,
195 }
196 }
197
198 pub fn prepare_query(&self, query: &[f32], centroid: Option<&[f32]>) -> QuantizedQuery {
200 let dim = self.config.dim;
201
202 let centered: Vec<f32> = if let Some(c) = centroid {
204 query.iter().zip(c).map(|(&v, &c)| v - c).collect()
205 } else {
206 query.to_vec()
207 };
208
209 let norm: f32 = centered.iter().map(|x| x * x).sum::<f32>().sqrt();
210 let dist_to_centroid = norm;
211
212 let normalized: Vec<f32> = if norm > 1e-10 {
214 centered.iter().map(|x| x / norm).collect()
215 } else {
216 centered
217 };
218
219 let transformed: Vec<f32> = (0..dim)
221 .map(|i| {
222 let src_idx = self.random_perm[i] as usize;
223 normalized[src_idx] * self.random_signs[src_idx] as f32
224 })
225 .collect();
226
227 let min_val = transformed.iter().cloned().fold(f32::INFINITY, f32::min);
229 let max_val = transformed
230 .iter()
231 .cloned()
232 .fold(f32::NEG_INFINITY, f32::max);
233 let lower = min_val;
234 let width = if max_val > min_val {
235 max_val - min_val
236 } else {
237 1.0
238 };
239
240 let quantized_vals: Vec<u8> = transformed
242 .iter()
243 .map(|&x| {
244 let normalized = (x - lower) / width;
245 (normalized * 15.0).round().clamp(0.0, 15.0) as u8
246 })
247 .collect();
248
249 let num_bytes = dim.div_ceil(2);
251 let mut quantized = vec![0u8; num_bytes];
252 for i in 0..dim {
253 if i % 2 == 0 {
254 quantized[i / 2] |= quantized_vals[i];
255 } else {
256 quantized[i / 2] |= quantized_vals[i] << 4;
257 }
258 }
259
260 let sum: u32 = quantized_vals.iter().map(|&x| x as u32).sum();
262
263 let num_luts = dim.div_ceil(4);
265 let mut luts = vec![[0u16; 16]; num_luts];
266
267 for (lut_idx, lut) in luts.iter_mut().enumerate() {
268 let base_dim = lut_idx * 4;
269 for pattern in 0u8..16 {
270 let mut dot = 0u16;
271 for bit in 0..4 {
272 let dim_idx = base_dim + bit;
273 if dim_idx < dim && (pattern >> bit) & 1 == 1 {
274 dot += quantized_vals[dim_idx] as u16;
275 }
276 }
277 lut[pattern as usize] = dot;
278 }
279 }
280
281 QuantizedQuery {
282 quantized,
283 dist_to_centroid,
284 lower,
285 width,
286 sum,
287 luts,
288 }
289 }
290
291 pub fn estimate_distance(&self, query: &QuantizedQuery, code: &QuantizedVector) -> f32 {
293 let dim = self.config.dim;
294
295 let dot_sum = lut_dot_product_simd(&code.bits, &query.luts);
297
298 let scale = 1.0 / (dim as f32).sqrt();
299
300 let sum_positive = code.popcount as f32 * query.lower + dot_sum as f32 * query.width / 15.0;
302 let sum_all = dim as f32 * query.lower + query.sum as f32 * query.width / 15.0;
303
304 let q_obar_dot = scale * (2.0 * sum_positive - sum_all);
306
307 let q_o_estimate = if code.self_dot.abs() > 1e-6 {
309 q_obar_dot / code.self_dot
310 } else {
311 q_obar_dot
312 };
313
314 let q_o_clamped = q_o_estimate.clamp(-1.0, 1.0);
316
317 let dist_sq = code.dist_to_centroid * code.dist_to_centroid
319 + query.dist_to_centroid * query.dist_to_centroid
320 - 2.0 * code.dist_to_centroid * query.dist_to_centroid * q_o_clamped;
321
322 dist_sq.max(0.0)
323 }
324
325 pub fn size_bytes(&self) -> usize {
327 self.random_signs.len() + self.random_perm.len() * 4 + 64
328 }
329}
330
331impl Quantizer for RaBitQCodebook {
332 type Code = QuantizedVector;
333 type Config = RaBitQConfig;
334 type QueryData = QuantizedQuery;
335
336 fn encode(&self, vector: &[f32], centroid: Option<&[f32]>) -> Self::Code {
337 self.encode(vector, centroid)
338 }
339
340 fn prepare_query(&self, query: &[f32], centroid: Option<&[f32]>) -> Self::QueryData {
341 self.prepare_query(query, centroid)
342 }
343
344 fn compute_distance(&self, query_data: &Self::QueryData, code: &Self::Code) -> f32 {
345 self.estimate_distance(query_data, code)
346 }
347
348 fn size_bytes(&self) -> usize {
349 self.size_bytes()
350 }
351}
352
353#[inline]
359fn lut_dot_product_simd(bits: &[u8], luts: &[[u16; 16]]) -> u32 {
360 #[cfg(target_arch = "aarch64")]
361 {
362 if let Some(result) = lut_dot_product_neon(bits, luts) {
363 return result;
364 }
365 }
366
367 #[cfg(target_arch = "x86_64")]
368 {
369 if is_x86_feature_detected!("ssse3") {
370 unsafe {
371 if let Some(result) = lut_dot_product_ssse3(bits, luts) {
372 return result;
373 }
374 }
375 }
376 }
377
378 lut_dot_product_scalar(bits, luts)
379}
380
/// Portable lookup-table dot product.
///
/// Table `i` is indexed by the `i`-th 4-bit group of `bits`: the low nibble of
/// byte `i / 2` for even `i`, the high nibble for odd `i`. Bytes past the end
/// of `bits` are treated as zero. Returns the sum of the selected entries.
#[inline]
fn lut_dot_product_scalar(bits: &[u8], luts: &[[u16; 16]]) -> u32 {
    luts.iter()
        .enumerate()
        .map(|(lut_idx, lut)| {
            let byte = bits.get(lut_idx / 2).copied().unwrap_or(0);
            // Groups are 4 bits wide, so a group never straddles two bytes:
            // it is always exactly one nibble.
            let pattern = if lut_idx % 2 == 0 {
                byte & 0x0F
            } else {
                byte >> 4
            };
            u32::from(lut[usize::from(pattern)])
        })
        .sum()
}
405
/// aarch64 entry point for the lookup-table dot product.
///
/// Returns `None` when fewer than 8 tables are supplied, so the caller can
/// fall back to another routine; otherwise returns the same sum as
/// `lut_dot_product_scalar`.
///
/// Despite the name this routine uses no NEON intrinsics. Because every 4-bit
/// group is exactly one nibble of one byte, the scalar routine already does
/// the minimal work, so this simply delegates to it (as the SSSE3 variant
/// does) instead of carrying a hand-unrolled copy of the same loop.
#[cfg(target_arch = "aarch64")]
#[inline]
fn lut_dot_product_neon(bits: &[u8], luts: &[[u16; 16]]) -> Option<u32> {
    if luts.len() < 8 {
        return None;
    }
    Some(lut_dot_product_scalar(bits, luts))
}
470
471#[cfg(target_arch = "x86_64")]
473#[target_feature(enable = "ssse3")]
474#[inline]
475unsafe fn lut_dot_product_ssse3(bits: &[u8], luts: &[[u16; 16]]) -> Option<u32> {
476 if luts.len() < 8 {
477 return None;
478 }
479 Some(lut_dot_product_scalar(bits, luts))
480}
481
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a codebook with the default configuration for `dim` dimensions.
    fn default_codebook(dim: usize) -> RaBitQCodebook {
        RaBitQCodebook::new(RaBitQConfig::new(dim))
    }

    /// The sign and permutation tables have one entry per dimension.
    #[test]
    fn test_rabitq_codebook_basic() {
        let cb = default_codebook(128);

        assert_eq!(cb.random_signs.len(), 128);
        assert_eq!(cb.random_perm.len(), 128);
    }

    /// Encoding a 64-dimensional vector yields 8 bytes of sign bits and a
    /// positive norm.
    #[test]
    fn test_encode_decode() {
        let cb = default_codebook(64);

        let input: Vec<f32> = (0..64).map(|i| (i as f32 - 32.0) / 32.0).collect();
        let encoded = cb.encode(&input, None);

        assert_eq!(encoded.bits.len(), 8);
        assert!(encoded.dist_to_centroid > 0.0);
    }

    /// The estimated squared distance between two random vectors is
    /// non-negative.
    #[test]
    fn test_distance_estimation() {
        let cb = default_codebook(64);

        let mut rng = rand::rngs::StdRng::seed_from_u64(42);
        let mut sample = || -> Vec<f32> { (0..64).map(|_| rng.random::<f32>() - 0.5).collect() };
        let stored = sample();
        let probe = sample();

        let encoded = cb.encode(&stored, None);
        let prepared = cb.prepare_query(&probe, None);

        assert!(cb.estimate_distance(&prepared, &encoded) >= 0.0);
    }

    /// The `Quantizer` trait methods work end to end and give a non-negative
    /// distance.
    #[test]
    fn test_quantizer_trait() {
        let cb = default_codebook(32);

        let ascending: Vec<f32> = (0..32).map(|i| i as f32 / 32.0).collect();
        let descending: Vec<f32> = (0..32).map(|i| (31 - i) as f32 / 32.0).collect();

        let encoded = Quantizer::encode(&cb, &ascending, None);
        let prepared = Quantizer::prepare_query(&cb, &descending, None);
        let distance = Quantizer::compute_distance(&cb, &prepared, &encoded);

        assert!(distance >= 0.0);
    }
}