use std::cmp::{Ordering, Reverse};
use std::collections::BinaryHeap;

use crate::math::{dot, l2_norm_sqr, subtract};
use crate::simd;
use crate::Metric;

/// Per-`ex_bits` fraction of the upper search bound used as the starting
/// point when scanning for the best rescale factor `t`.
const K_TIGHT_START: [f64; 9] = [0.0, 0.15, 0.20, 0.52, 0.59, 0.71, 0.75, 0.77, 0.81];
/// Small positive offset applied before truncating `t * |o|` to an integer code.
const K_EPS: f64 = 1e-5;
/// Padding added to the maximum code value when computing the upper bound of
/// the rescale-factor search range.
const K_NENUM: f64 = 10.0;
/// Multiplier used in the one-bit and extended error-bound estimates.
const K_CONST_EPSILON: f32 = 1.9;

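/// Configuration for RaBitQ-style quantization: the total number of bits per
/// dimension plus an optional precomputed rescale factor `t_const` shared by
/// all vectors.
///
/// A minimal sketch of the two construction paths (the dimension, bit count,
/// and seed below are illustrative values):
///
/// ```ignore
/// let per_vector = RabitqConfig::new(7);         // t searched per vector
/// let shared = RabitqConfig::faster(128, 7, 42); // one t estimated up front
/// ```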
#[derive(Debug, Clone, Copy, serde::Serialize, serde::Deserialize)]
pub struct RabitqConfig {
    pub total_bits: usize,
    pub t_const: Option<f32>,
}

impl RabitqConfig {
    pub fn new(total_bits: usize) -> Self {
        RabitqConfig {
            total_bits,
            t_const: None,
        }
    }

    pub fn faster(dim: usize, total_bits: usize, seed: u64) -> Self {
        let ex_bits = total_bits.saturating_sub(1);
        let t_const = if ex_bits > 0 {
            Some(compute_const_scaling_factor(dim, ex_bits, seed))
        } else {
            None
        };

        RabitqConfig {
            total_bits,
            t_const,
        }
    }
}

impl Default for RabitqConfig {
    fn default() -> Self {
        Self::new(7)
    }
}

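/// A RaBitQ-quantized vector: packed sign and extended codes, optional
/// unpacked caches (skipped by serde), and the scalar factors (`delta`,
/// `vl`, and the `f_*` terms) needed to estimate distances without
/// dequantizing.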
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct QuantizedVector {
    pub code: Vec<u16>,
    pub binary_code_packed: Vec<u8>,
    pub ex_code_packed: Vec<u8>,
    #[serde(skip)]
    pub binary_code_unpacked: Vec<u8>,
    #[serde(skip)]
    pub ex_code_unpacked: Vec<u16>,
    pub ex_bits: u8,
    pub dim: usize,
    pub delta: f32,
    pub vl: f32,
    pub f_add: f32,
    pub f_rescale: f32,
    pub f_error: f32,
    pub residual_norm: f32,
    pub f_add_ex: f32,
    pub f_rescale_ex: f32,
}

impl QuantizedVector {
    /// Return the unpacked one-bit code, using the cache when present.
    #[inline]
    pub fn unpack_binary_code(&self) -> Vec<u8> {
        if !self.binary_code_unpacked.is_empty() {
            return self.binary_code_unpacked.clone();
        }
        let mut binary_code = vec![0u8; self.dim];
        simd::unpack_binary_code(&self.binary_code_packed, &mut binary_code, self.dim);
        binary_code
    }

    /// Return the unpacked extended code, using the cache when present.
    #[inline]
    pub fn unpack_ex_code(&self) -> Vec<u16> {
        if !self.ex_code_unpacked.is_empty() {
            return self.ex_code_unpacked.clone();
        }
        let mut ex_code = vec![0u16; self.dim];
        simd::unpack_ex_code(&self.ex_code_packed, &mut ex_code, self.dim, self.ex_bits);
        ex_code
    }

    /// Populate the unpacked caches in place. The caches are skipped during
    /// serialization, so call this after deserializing if unpacked access is hot.
    pub fn ensure_unpacked_cache(&mut self) {
        if self.binary_code_unpacked.is_empty() {
            self.binary_code_unpacked = vec![0u8; self.dim];
            simd::unpack_binary_code(
                &self.binary_code_packed,
                &mut self.binary_code_unpacked,
                self.dim,
            );
        }
        if self.ex_code_unpacked.is_empty() {
            self.ex_code_unpacked = vec![0u16; self.dim];
            simd::unpack_ex_code(
                &self.ex_code_packed,
                &mut self.ex_code_unpacked,
                self.dim,
                self.ex_bits,
            );
        }
    }
}

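/// Quantize `data` against `centroid`: the sign of each residual component
/// becomes the one-bit code, the remaining `total_bits - 1` bits refine the
/// magnitude, and the scalar factors for distance estimation are derived.
///
/// A minimal usage sketch (dimension, values, and metric are illustrative):
///
/// ```ignore
/// let data = vec![0.2f32, -0.7, 0.1, 0.4];
/// let centroid = vec![0.0f32; 4];
/// let qv = quantize_with_centroid(&data, &centroid, &RabitqConfig::new(7), Metric::L2);
/// assert_eq!(qv.dim, 4);
/// ```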
pub fn quantize_with_centroid(
    data: &[f32],
    centroid: &[f32],
    config: &RabitqConfig,
    metric: Metric,
) -> QuantizedVector {
    assert_eq!(data.len(), centroid.len());
    assert!((1..=16).contains(&config.total_bits));
    let dim = data.len();
    let ex_bits = config.total_bits.saturating_sub(1);

    // One-bit code: the sign of each residual component.
    let residual = subtract(data, centroid);
    let mut binary_code = vec![0u8; dim];
    for (idx, &value) in residual.iter().enumerate() {
        if value >= 0.0 {
            binary_code[idx] = 1u8;
        }
    }

    let (ex_code, ipnorm_inv) = if ex_bits > 0 {
        ex_bits_code_with_inv(&residual, ex_bits, config.t_const)
    } else {
        (vec![0u16; dim], 1.0f32)
    };

    // Total code: sign bit in the top position, extended bits below it.
    let mut total_code = vec![0u16; dim];
    for i in 0..dim {
        total_code[i] = ex_code[i] + ((binary_code[i] as u16) << ex_bits);
    }

    let (f_add, f_rescale, f_error, residual_norm) =
        compute_one_bit_factors(&residual, centroid, &binary_code, metric);

    // Shift the codes to be symmetric around zero, then fit the scale `delta`
    // between the shifted codes and the residual via their cosine similarity.
    let cb = -((1 << ex_bits) as f32 - 0.5);
    let quantized_shifted: Vec<f32> = total_code.iter().map(|&code| code as f32 + cb).collect();
    let norm_quan_sqr = l2_norm_sqr(&quantized_shifted);
    let dot_residual_quant = dot(&residual, &quantized_shifted);

    let norm_residual_sqr = l2_norm_sqr(&residual);
    let norm_residual = norm_residual_sqr.sqrt();
    let norm_quant = norm_quan_sqr.sqrt();
    let denom = (norm_residual * norm_quant).max(f32::EPSILON);
    let cos_similarity = (dot_residual_quant / denom).clamp(-1.0, 1.0);
    let delta = if norm_quant <= f32::EPSILON {
        0.0
    } else {
        (norm_residual / norm_quant) * cos_similarity
    };
    let vl = delta * cb;

    let mut f_add_ex = 0.0f32;
    let mut f_rescale_ex = 0.0f32;
    if ex_bits > 0 {
        let factors = compute_extended_factors(
            &residual,
            centroid,
            &binary_code,
            &ex_code,
            ipnorm_inv,
            metric,
            ex_bits,
        );
        f_add_ex = factors.0;
        f_rescale_ex = factors.1;
    }

    let binary_code_packed_size = (dim + 7) / 8;
    let mut binary_code_packed = vec![0u8; binary_code_packed_size];
    simd::pack_binary_code(&binary_code, &mut binary_code_packed, dim);

    // Buffer sizes mirror the C++-compatible packed layouts used by the
    // pack_*_cpp_compat kernels below.
    let ex_code_packed_size = match ex_bits {
        0 => dim / 16 * 2,
        1 => dim / 16 * 2,
        2 => dim / 16 * 4,
        6 => dim / 16 * 12,
        _ => ((dim * ex_bits) + 7) / 8,
    };
    let mut ex_code_packed = vec![0u8; ex_code_packed_size];

    match ex_bits {
        0 => {
            // No extended bits to pack.
        }
        1 => {
            simd::pack_ex_code_1bit_cpp_compat(&ex_code, &mut ex_code_packed, dim);
        }
        2 => {
            simd::pack_ex_code_2bit_cpp_compat(&ex_code, &mut ex_code_packed, dim);
        }
        6 => {
            simd::pack_ex_code_6bit_cpp_compat(&ex_code, &mut ex_code_packed, dim);
        }
        _ => {
            simd::pack_ex_code(&ex_code, &mut ex_code_packed, dim, ex_bits as u8);
        }
    }

    QuantizedVector {
        code: total_code,
        binary_code_packed,
        ex_code_packed,
        binary_code_unpacked: binary_code,
        ex_code_unpacked: ex_code,
        ex_bits: ex_bits as u8,
        dim,
        delta,
        vl,
        f_add,
        f_rescale,
        f_error,
        residual_norm,
        f_add_ex,
        f_rescale_ex,
    }
}

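/// Compute the scalar factors for the one-bit code: the additive term
/// `f_add`, the rescale term `f_rescale`, the error bound `f_error`, and the
/// residual norm, for either metric. `xu_cb` shifts each bit to ±0.5 so the
/// code becomes a direction whose inner product with the residual drives the
/// estimate.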
fn compute_one_bit_factors(
    residual: &[f32],
    centroid: &[f32],
    binary_code: &[u8],
    metric: Metric,
) -> (f32, f32, f32, f32) {
    let dim = residual.len();
    let xu_cb: Vec<f32> = binary_code.iter().map(|&bit| bit as f32 - 0.5f32).collect();
    let l2_sqr = l2_norm_sqr(residual);
    let l2_norm = l2_sqr.sqrt();
    let xu_cb_norm_sqr = l2_norm_sqr(&xu_cb);
    let ip_resi_xucb = dot(residual, &xu_cb);
    let ip_cent_xucb = dot(centroid, &xu_cb);
    let dot_residual_centroid = dot(residual, centroid);

    // Guard against division by ~zero: an infinite denominator zeroes the
    // dependent terms instead of producing NaNs.
    let mut denom = ip_resi_xucb;
    if denom.abs() <= f32::EPSILON {
        denom = f32::INFINITY;
    }

    let mut tmp_error = 0.0f32;
    if dim > 1 {
        let ratio = ((l2_sqr * xu_cb_norm_sqr) / (denom * denom)) - 1.0;
        if ratio.is_finite() && ratio > 0.0 {
            tmp_error = l2_norm * K_CONST_EPSILON * ((ratio / ((dim - 1) as f32)).max(0.0)).sqrt();
        }
    }

    let (f_add, f_rescale, f_error) = match metric {
        Metric::L2 => {
            let f_add = l2_sqr + 2.0 * l2_sqr * ip_cent_xucb / denom;
            let f_rescale = -2.0 * l2_sqr / denom;
            let f_error = 2.0 * tmp_error;
            (f_add, f_rescale, f_error)
        }
        Metric::InnerProduct => {
            let f_add = 1.0 - dot_residual_centroid + l2_sqr * ip_cent_xucb / denom;
            let f_rescale = -l2_sqr / denom;
            let f_error = tmp_error;
            (f_add, f_rescale, f_error)
        }
    };

    (f_add, f_rescale, f_error, l2_norm)
}

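/// Quantize the magnitudes of `residual` into `ex_bits` extra bits per
/// dimension. Returns the extended code and the inverse of the inner product
/// between the (shifted) code and the normalized magnitudes, which later
/// rescales the extended distance estimate.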
fn ex_bits_code_with_inv(
    residual: &[f32],
    ex_bits: usize,
    t_const: Option<f32>,
) -> (Vec<u16>, f32) {
    let dim = residual.len();
    let mut normalized_abs: Vec<f32> = residual.iter().map(|x| x.abs()).collect();
    let norm = normalized_abs.iter().map(|x| x * x).sum::<f32>().sqrt();

    if norm <= f32::EPSILON {
        return (vec![0u16; dim], 1.0);
    }

    for value in normalized_abs.iter_mut() {
        *value /= norm;
    }

    let t = if let Some(t) = t_const {
        t as f64
    } else {
        best_rescale_factor(&normalized_abs, ex_bits)
    };

    quantize_ex_with_inv(&normalized_abs, residual, ex_bits, t)
}

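/// Find the rescale factor `t` maximizing the cosine between the quantized
/// magnitudes and `o_abs`. The objective only changes at breakpoints
/// `t = (code_i + 1) / o_abs[i]`, where some component's code increments, so
/// the search walks those breakpoints in ascending order with a min-heap and
/// updates the cosine's numerator and denominator incrementally; no dense
/// scan over candidate `t` values is ever materialized.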
fn best_rescale_factor(o_abs: &[f32], ex_bits: usize) -> f64 {
    let dim = o_abs.len();
    let max_o = o_abs.iter().cloned().fold(0.0f32, f32::max) as f64;
    if max_o <= f64::EPSILON {
        return 1.0;
    }

    let table_idx = ex_bits.min(K_TIGHT_START.len() - 1);
    let t_end = (((1 << ex_bits) - 1) as f64 + K_NENUM) / max_o;
    let t_start = t_end * K_TIGHT_START[table_idx];

    // Initialize codes at t_start and the running cosine statistics:
    // numerator = sum((code + 0.5) * |o|), sqr_denominator = sum((code + 0.5)^2).
    let mut cur_o_bar = vec![0i32; dim];
    let mut sqr_denominator = dim as f64 * 0.25;
    let mut numerator = 0.0f64;

    for (idx, &val) in o_abs.iter().enumerate() {
        let cur = ((t_start * val as f64) + K_EPS) as i32;
        cur_o_bar[idx] = cur;
        sqr_denominator += (cur * cur + cur) as f64;
        numerator += (cur as f64 + 0.5) * val as f64;
    }

    // Heap entries are the upcoming breakpoints at which one component's
    // code increments; `total_cmp` gives a total order over the f64 keys.
    #[derive(Copy, Clone, Debug)]
    struct HeapEntry {
        t: f64,
        idx: usize,
    }

    impl PartialEq for HeapEntry {
        fn eq(&self, other: &Self) -> bool {
            self.t.to_bits() == other.t.to_bits() && self.idx == other.idx
        }
    }

    impl Eq for HeapEntry {}

    impl PartialOrd for HeapEntry {
        fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
            Some(self.cmp(other))
        }
    }

    impl Ord for HeapEntry {
        fn cmp(&self, other: &Self) -> Ordering {
            self.t
                .total_cmp(&other.t)
                .then_with(|| self.idx.cmp(&other.idx))
        }
    }

    let mut heap: BinaryHeap<Reverse<HeapEntry>> = BinaryHeap::new();
    for (idx, &val) in o_abs.iter().enumerate() {
        if val > 0.0 {
            let next_t = (cur_o_bar[idx] + 1) as f64 / val as f64;
            heap.push(Reverse(HeapEntry { t: next_t, idx }));
        }
    }

    let mut max_ip = 0.0f64;
    let mut best_t = t_start;

    while let Some(Reverse(HeapEntry { t: cur_t, idx })) = heap.pop() {
        if cur_t >= t_end {
            continue;
        }

        // Cross this breakpoint: one component's code increments by 1.
        cur_o_bar[idx] += 1;
        let update = cur_o_bar[idx];
        sqr_denominator += 2.0 * update as f64;
        numerator += o_abs[idx] as f64;

        let cur_ip = numerator / sqr_denominator.sqrt();
        if cur_ip > max_ip {
            max_ip = cur_ip;
            best_t = cur_t;
        }

        // Schedule this component's next breakpoint unless its code is
        // already saturated at the maximum representable value.
        if update < (1 << ex_bits) - 1 && o_abs[idx] > 0.0 {
            let t_next = (update + 1) as f64 / o_abs[idx] as f64;
            if t_next < t_end {
                heap.push(Reverse(HeapEntry { t: t_next, idx }));
            }
        }
    }

    if best_t <= 0.0 {
        t_start.max(f64::EPSILON)
    } else {
        best_t
    }
}

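/// Quantize the normalized magnitudes at a fixed rescale factor `t`,
/// clamping each code to the `ex_bits` range. Negative residual components
/// store the bitwise complement of their code within the `ex_bits` mask,
/// which (as read from the code construction in `quantize_with_centroid`)
/// makes the combined sign-plus-extended code monotone in the signed value.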
fn quantize_ex_with_inv(
    o_abs: &[f32],
    residual: &[f32],
    ex_bits: usize,
    t: f64,
) -> (Vec<u16>, f32) {
    let dim = o_abs.len();
    if dim == 0 {
        return (Vec::new(), 1.0);
    }

    let mut code = vec![0u16; dim];
    let max_val = (1 << ex_bits) - 1;
    let mut ipnorm = 0.0f64;

    for i in 0..dim {
        let mut cur = (t * o_abs[i] as f64 + K_EPS) as i32;
        if cur > max_val {
            cur = max_val;
        }
        code[i] = cur as u16;
        ipnorm += (cur as f64 + 0.5) * o_abs[i] as f64;
    }

    let mut ipnorm_inv = if ipnorm.is_finite() && ipnorm > 0.0 {
        (1.0 / ipnorm) as f32
    } else {
        1.0
    };

    // Negative components store the bitwise complement within the mask.
    let mask = max_val as u16;
    if max_val > 0 {
        for (idx, &res) in residual.iter().enumerate() {
            if res < 0.0 {
                code[idx] = (!code[idx]) & mask;
            }
        }
    }

    if !ipnorm_inv.is_finite() {
        ipnorm_inv = 1.0;
    }

    (code, ipnorm_inv)
}

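/// Derive the additive and rescale terms for the extended (multi-bit) code,
/// mirroring `compute_one_bit_factors` but with the full shifted code
/// `total + cb` in place of the ±0.5 one-bit code, and with the rescale term
/// driven by `ipnorm_inv` from quantization.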
fn compute_extended_factors(
    residual: &[f32],
    centroid: &[f32],
    binary_code: &[u8],
    ex_code: &[u16],
    ipnorm_inv: f32,
    metric: Metric,
    ex_bits: usize,
) -> (f32, f32) {
    let dim = residual.len();
    let cb = -((1 << ex_bits) as f32 - 0.5);
    let xu_cb: Vec<f32> = (0..dim)
        .map(|i| {
            let total = ex_code[i] as u32 + ((binary_code[i] as u32) << ex_bits);
            total as f32 + cb
        })
        .collect();

    let l2_sqr = l2_norm_sqr(residual);
    let l2_norm = l2_sqr.sqrt();
    let xu_cb_norm_sqr = l2_norm_sqr(&xu_cb);
    let ip_resi_xucb = dot(residual, &xu_cb);
    let ip_cent_xucb = dot(centroid, &xu_cb);
    let dot_residual_centroid = dot(residual, centroid);

    let mut denom = ip_resi_xucb * ip_resi_xucb;
    if denom <= f32::EPSILON {
        denom = f32::INFINITY;
    }

    let mut tmp_error = 0.0f32;
    if dim > 1 {
        let ratio = ((l2_sqr * xu_cb_norm_sqr) / denom) - 1.0;
        if ratio > 0.0 {
            tmp_error = l2_norm * K_CONST_EPSILON * ((ratio / ((dim - 1) as f32)).max(0.0)).sqrt();
        }
    }

    let safe_denom = if ip_resi_xucb.abs() <= f32::EPSILON {
        f32::INFINITY
    } else {
        ip_resi_xucb
    };

    let (f_add_ex, f_rescale_ex) = match metric {
        Metric::L2 => {
            let f_add = l2_sqr + 2.0 * l2_sqr * ip_cent_xucb / safe_denom;
            let f_rescale = -2.0 * l2_norm * ipnorm_inv;
            (f_add, f_rescale)
        }
        Metric::InnerProduct => {
            let f_add = 1.0 - dot_residual_centroid + l2_sqr * ip_cent_xucb / safe_denom;
            let f_rescale = -l2_norm * ipnorm_inv;
            (f_add, f_rescale)
        }
    };

    // The extended error bound is computed for parity with the one-bit path
    // but is not currently part of the returned factors.
    let _ = tmp_error;
    (f_add_ex, f_rescale_ex)
}

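/// Reconstruct an approximation of the original vector from its centroid and
/// quantized code: `output[i] = centroid[i] + delta * code[i] + vl`.
///
/// A minimal round-trip sketch (`data` and `centroid` are illustrative):
///
/// ```ignore
/// let qv = quantize_with_centroid(&data, &centroid, &RabitqConfig::new(7), Metric::L2);
/// let mut approx = vec![0.0f32; data.len()];
/// reconstruct_into(&centroid, &qv, &mut approx);
/// // `approx` tracks `data`, with error shrinking as total_bits grows.
/// ```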
#[cfg_attr(not(test), allow(dead_code))]
pub fn reconstruct_into(centroid: &[f32], quantized: &QuantizedVector, output: &mut [f32]) {
    assert_eq!(centroid.len(), quantized.code.len());
    assert_eq!(output.len(), centroid.len());
    for i in 0..centroid.len() {
        output[i] = centroid[i] + quantized.delta * quantized.code[i] as f32 + quantized.vl;
    }
}

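/// Estimate a single rescale factor by averaging `best_rescale_factor` over
/// random Gaussian vectors of the given dimension. Gaussian directions are
/// rotationally symmetric, so the averaged factor transfers reasonably to
/// arbitrary unit residuals; `RabitqConfig::faster` uses this to skip the
/// per-vector search.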
pub fn compute_const_scaling_factor(dim: usize, ex_bits: usize, seed: u64) -> f32 {
    use rand::prelude::*;
    use rand_distr::{Distribution, Normal};

    const NUM_SAMPLES: usize = 100;

    let mut rng = StdRng::seed_from_u64(seed);
    let normal = Normal::new(0.0, 1.0).expect("failed to create normal distribution");

    let mut sum_t = 0.0f64;

    for _ in 0..NUM_SAMPLES {
        let vec: Vec<f32> = (0..dim).map(|_| normal.sample(&mut rng) as f32).collect();

        let norm = vec.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm <= f32::EPSILON {
            continue;
        }

        let normalized_abs: Vec<f32> = vec.iter().map(|x| (x / norm).abs()).collect();

        let t = best_rescale_factor(&normalized_abs, ex_bits);
        sum_t += t;
    }

    (sum_t / NUM_SAMPLES as f64) as f32
}