1use std::cmp::{Ordering, Reverse};
2use std::collections::BinaryHeap;
3
4use crate::math::{dot, l2_norm_sqr, subtract};
5use crate::simd;
6use crate::Metric;
7
/// Per-bit-width starting fraction of the rescale-factor scan range used by
/// `best_rescale_factor`; indexed by the number of extended bits (clamped to
/// the table length). Values appear tuned empirically — TODO confirm source.
const K_TIGHT_START: [f64; 9] = [0.0, 0.15, 0.20, 0.52, 0.59, 0.71, 0.75, 0.77, 0.81];
/// Small offset added before truncating `t * |o_i|` to an integer so values
/// sitting exactly on a code boundary round up deterministically.
const K_EPS: f64 = 1e-5;
/// Extra headroom added to `(2^ex_bits - 1)` when computing the upper bound
/// of the rescale-factor scan range.
const K_NENUM: f64 = 10.0;
/// Multiplier on the quantization error estimate in
/// `compute_one_bit_factors`; despite the name it is an error-bound scale,
/// not a small epsilon.
const K_CONST_EPSILON: f32 = 1.9;
12
/// Configuration for RaBitQ-style quantization.
#[derive(Debug, Clone, Copy, serde::Serialize, serde::Deserialize)]
pub struct RabitqConfig {
    /// Total bits per dimension: 1 sign bit plus `total_bits - 1` extended
    /// magnitude bits. Must lie in `1..=16` (asserted by
    /// `quantize_with_centroid`).
    pub total_bits: usize,
    /// Optional precomputed rescale factor shared by all vectors; when
    /// `None`, `best_rescale_factor` is run per vector.
    pub t_const: Option<f32>,
}
22
23impl RabitqConfig {
24 pub fn new(total_bits: usize) -> Self {
25 RabitqConfig {
26 total_bits,
27 t_const: None, }
29 }
30
31 pub fn faster(dim: usize, total_bits: usize, seed: u64) -> Self {
34 let ex_bits = total_bits.saturating_sub(1);
35 let t_const = if ex_bits > 0 {
36 Some(compute_const_scaling_factor(dim, ex_bits, seed))
37 } else {
38 None
39 };
40
41 RabitqConfig {
42 total_bits,
43 t_const,
44 }
45 }
46}
47
48impl Default for RabitqConfig {
49 fn default() -> Self {
50 Self::new(7) }
52}
53
/// A single vector quantized against its centroid, together with the
/// precomputed factors used for fast distance estimation.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct QuantizedVector {
    /// Sign bits of the residual (data - centroid), bit-packed 8 per byte.
    pub binary_code_packed: Vec<u8>,
    /// Extended magnitude code, packed at `ex_bits` bits per dimension
    /// (layout varies with `ex_bits`; see `quantize_with_centroid`).
    pub ex_code_packed: Vec<u8>,
    /// Bits per dimension in `ex_code_packed` (total bits minus the sign bit).
    pub ex_bits: u8,
    /// Dimensionality of the original vector.
    pub dim: usize,
    /// Scale applied to the integer code during reconstruction.
    pub delta: f32,
    /// Additive reconstruction offset (`delta * cb`).
    pub vl: f32,
    // One-bit estimator factors; see `compute_one_bit_factors`.
    pub f_add: f32,
    pub f_rescale: f32,
    pub f_error: f32,
    /// L2 norm of the residual.
    pub residual_norm: f32,
    // Extended-bits estimator factors; see `compute_extended_factors`.
    pub f_add_ex: f32,
    pub f_rescale_ex: f32,
}
81
82impl QuantizedVector {
83 #[inline]
85 pub fn unpack_binary_code(&self) -> Vec<u8> {
86 let mut binary_code = vec![0u8; self.dim];
87 simd::unpack_binary_code(&self.binary_code_packed, &mut binary_code, self.dim);
88 binary_code
89 }
90
91 #[inline]
93 pub fn unpack_ex_code(&self) -> Vec<u16> {
94 let mut ex_code = vec![0u16; self.dim];
95 simd::unpack_ex_code(&self.ex_code_packed, &mut ex_code, self.dim, self.ex_bits);
96 ex_code
97 }
98
99 pub fn ensure_unpacked_cache(&mut self) {}
101
102 pub fn heap_size(&self) -> usize {
104 self.binary_code_packed.capacity() * std::mem::size_of::<u8>()
105 + self.ex_code_packed.capacity() * std::mem::size_of::<u8>()
106 }
107}
108
/// Quantize `data` against `centroid` into a RaBitQ-style code (sign bit plus
/// `total_bits - 1` extended magnitude bits per dimension) and precompute the
/// factors needed for fast distance estimation.
///
/// # Panics
/// Panics if `data` and `centroid` differ in length, or if
/// `config.total_bits` is outside `1..=16`.
pub fn quantize_with_centroid(
    data: &[f32],
    centroid: &[f32],
    config: &RabitqConfig,
    metric: Metric,
) -> QuantizedVector {
    assert_eq!(data.len(), centroid.len());
    assert!((1..=16).contains(&config.total_bits));
    let dim = data.len();
    // One bit is the sign bit; the remainder are "extended" magnitude bits.
    let ex_bits = config.total_bits.saturating_sub(1);

    // Sign bits of the residual (data - centroid): 1 for non-negative entries.
    let residual = subtract(data, centroid);
    let mut binary_code = vec![0u8; dim];
    for (idx, &value) in residual.iter().enumerate() {
        if value >= 0.0 {
            binary_code[idx] = 1u8;
        }
    }

    // Extended magnitude codes plus the inverse of <code + 0.5, |o|>
    // (trivial identity values when there are no extended bits).
    let (ex_code, ipnorm_inv) = if ex_bits > 0 {
        ex_bits_code_with_inv(&residual, ex_bits, config.t_const)
    } else {
        (vec![0u16; dim], 1.0f32)
    };

    // Combined per-dimension code: sign bit on top, magnitude bits below it.
    let mut total_code = vec![0u16; dim];
    for i in 0..dim {
        total_code[i] = ex_code[i] + ((binary_code[i] as u16) << ex_bits);
    }

    let (f_add, f_rescale, f_error, residual_norm) =
        compute_one_bit_factors(&residual, centroid, &binary_code, metric);

    // Recenter codes symmetrically around zero: cb = -(2^ex_bits - 0.5).
    let cb = -((1 << ex_bits) as f32 - 0.5);
    let quantized_shifted: Vec<f32> = total_code.iter().map(|&code| code as f32 + cb).collect();
    let norm_quan_sqr = l2_norm_sqr(&quantized_shifted);
    let dot_residual_quant = dot(&residual, &quantized_shifted);

    // delta = <residual, q> / ||q||^2 (written via the clamped cosine): the
    // least-squares scale mapping the shifted code back to residual space.
    let norm_residual_sqr = l2_norm_sqr(&residual);
    let norm_residual = norm_residual_sqr.sqrt();
    let norm_quant = norm_quan_sqr.sqrt();
    let denom = (norm_residual * norm_quant).max(f32::EPSILON);
    let cos_similarity = (dot_residual_quant / denom).clamp(-1.0, 1.0);
    let delta = if norm_quant <= f32::EPSILON {
        0.0
    } else {
        (norm_residual / norm_quant) * cos_similarity
    };
    let vl = delta * cb;

    let mut f_add_ex = 0.0f32;
    let mut f_rescale_ex = 0.0f32;
    if ex_bits > 0 {
        let factors = compute_extended_factors(
            &residual,
            centroid,
            &binary_code,
            &ex_code,
            ipnorm_inv,
            metric,
            ex_bits,
        );
        f_add_ex = factors.0;
        f_rescale_ex = factors.1;
    }

    // Sign bits pack 8 per byte.
    let binary_code_packed_size = dim.div_ceil(8);
    let mut binary_code_packed = vec![0u8; binary_code_packed_size];
    simd::pack_binary_code(&binary_code, &mut binary_code_packed, dim);

    // Buffer size for the extended code. The 0/1/2/6-bit branches mirror the
    // C++-compatible packers below and assume `dim` is a multiple of 16 —
    // TODO confirm callers enforce that invariant.
    let ex_code_packed_size = match ex_bits {
        0 => dim / 16 * 2, 1 => dim / 16 * 2, 2 => dim / 16 * 4, 6 => dim / 16 * 12, _ => (dim * ex_bits).div_ceil(8), };
    let mut ex_code_packed = vec![0u8; ex_code_packed_size];

    match ex_bits {
        0 => {
            // No extended bits: the buffer stays zeroed.
        }
        1 => {
            simd::pack_ex_code_1bit_cpp_compat(&ex_code, &mut ex_code_packed, dim);
        }
        2 => {
            simd::pack_ex_code_2bit_cpp_compat(&ex_code, &mut ex_code_packed, dim);
        }
        6 => {
            simd::pack_ex_code_6bit_cpp_compat(&ex_code, &mut ex_code_packed, dim);
        }
        _ => {
            // Generic bit-serial packer for the remaining widths.
            simd::pack_ex_code(&ex_code, &mut ex_code_packed, dim, ex_bits as u8);
        }
    }

    QuantizedVector {
        binary_code_packed,
        ex_code_packed,
        ex_bits: ex_bits as u8,
        dim,
        delta,
        vl,
        f_add,
        f_rescale,
        f_error,
        residual_norm,
        f_add_ex,
        f_rescale_ex,
    }
}
230
231fn compute_one_bit_factors(
232 residual: &[f32],
233 centroid: &[f32],
234 binary_code: &[u8],
235 metric: Metric,
236) -> (f32, f32, f32, f32) {
237 let dim = residual.len();
238 let xu_cb: Vec<f32> = binary_code.iter().map(|&bit| bit as f32 - 0.5f32).collect();
239 let l2_sqr = l2_norm_sqr(residual);
240 let l2_norm = l2_sqr.sqrt();
241 let xu_cb_norm_sqr = l2_norm_sqr(&xu_cb);
242 let ip_resi_xucb = dot(residual, &xu_cb);
243 let ip_cent_xucb = dot(centroid, &xu_cb);
244 let dot_residual_centroid = dot(residual, centroid);
245
246 let mut denom = ip_resi_xucb;
247 if denom.abs() <= f32::EPSILON {
248 denom = f32::INFINITY;
249 }
250
251 let mut tmp_error = 0.0f32;
252 if dim > 1 {
253 let ratio = ((l2_sqr * xu_cb_norm_sqr) / (denom * denom)) - 1.0;
254 if ratio.is_finite() && ratio > 0.0 {
255 tmp_error = l2_norm * K_CONST_EPSILON * ((ratio / ((dim - 1) as f32)).max(0.0)).sqrt();
256 }
257 }
258
259 let (f_add, f_rescale, f_error) = match metric {
260 Metric::L2 => {
261 let f_add = l2_sqr + 2.0 * l2_sqr * ip_cent_xucb / denom;
262 let f_rescale = -2.0 * l2_sqr / denom;
263 let f_error = 2.0 * tmp_error;
264 (f_add, f_rescale, f_error)
265 }
266 Metric::InnerProduct => {
267 let f_add = 1.0 - dot_residual_centroid + l2_sqr * ip_cent_xucb / denom;
268 let f_rescale = -l2_sqr / denom;
269 let f_error = tmp_error;
270 (f_add, f_rescale, f_error)
271 }
272 };
273
274 (f_add, f_rescale, f_error, l2_norm)
275}
276
277fn ex_bits_code_with_inv(
278 residual: &[f32],
279 ex_bits: usize,
280 t_const: Option<f32>,
281) -> (Vec<u16>, f32) {
282 let dim = residual.len();
283 let mut normalized_abs: Vec<f32> = residual.iter().map(|x| x.abs()).collect();
284 let norm = normalized_abs.iter().map(|x| x * x).sum::<f32>().sqrt();
285
286 if norm <= f32::EPSILON {
287 return (vec![0u16; dim], 1.0);
288 }
289
290 for value in normalized_abs.iter_mut() {
291 *value /= norm;
292 }
293
294 let t = if let Some(t) = t_const {
296 t as f64
297 } else {
298 best_rescale_factor(&normalized_abs, ex_bits)
299 };
300
301 quantize_ex_with_inv(&normalized_abs, residual, ex_bits, t)
302}
303
304fn best_rescale_factor(o_abs: &[f32], ex_bits: usize) -> f64 {
305 let dim = o_abs.len();
306 let max_o = o_abs.iter().cloned().fold(0.0f32, f32::max) as f64;
307 if max_o <= f64::EPSILON {
308 return 1.0;
309 }
310
311 let table_idx = ex_bits.min(K_TIGHT_START.len() - 1);
312 let t_end = (((1 << ex_bits) - 1) as f64 + K_NENUM) / max_o;
313 let t_start = t_end * K_TIGHT_START[table_idx];
314
315 let mut cur_o_bar = vec![0i32; dim];
316 let mut sqr_denominator = dim as f64 * 0.25;
317 let mut numerator = 0.0f64;
318
319 for (idx, &val) in o_abs.iter().enumerate() {
320 let cur = ((t_start * val as f64) + K_EPS) as i32;
321 cur_o_bar[idx] = cur;
322 sqr_denominator += (cur * cur + cur) as f64;
323 numerator += (cur as f64 + 0.5) * val as f64;
324 }
325
326 #[derive(Copy, Clone, Debug)]
327 struct HeapEntry {
328 t: f64,
329 idx: usize,
330 }
331
332 impl PartialEq for HeapEntry {
333 fn eq(&self, other: &Self) -> bool {
334 self.t.to_bits() == other.t.to_bits() && self.idx == other.idx
335 }
336 }
337
338 impl Eq for HeapEntry {}
339
340 impl PartialOrd for HeapEntry {
341 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
342 Some(self.cmp(other))
343 }
344 }
345
346 impl Ord for HeapEntry {
347 fn cmp(&self, other: &Self) -> Ordering {
348 self.t
349 .total_cmp(&other.t)
350 .then_with(|| self.idx.cmp(&other.idx))
351 }
352 }
353
354 let mut heap: BinaryHeap<Reverse<HeapEntry>> = BinaryHeap::new();
355 for (idx, &val) in o_abs.iter().enumerate() {
356 if val > 0.0 {
357 let next_t = (cur_o_bar[idx] + 1) as f64 / val as f64;
358 heap.push(Reverse(HeapEntry { t: next_t, idx }));
359 }
360 }
361
362 let mut max_ip = 0.0f64;
363 let mut best_t = t_start;
364
365 while let Some(Reverse(HeapEntry { t: cur_t, idx })) = heap.pop() {
366 if cur_t >= t_end {
367 continue;
368 }
369
370 cur_o_bar[idx] += 1;
371 let update = cur_o_bar[idx];
372 sqr_denominator += 2.0 * update as f64;
373 numerator += o_abs[idx] as f64;
374
375 let cur_ip = numerator / sqr_denominator.sqrt();
376 if cur_ip > max_ip {
377 max_ip = cur_ip;
378 best_t = cur_t;
379 }
380
381 if update < (1 << ex_bits) - 1 && o_abs[idx] > 0.0 {
382 let t_next = (update + 1) as f64 / o_abs[idx] as f64;
383 if t_next < t_end {
384 heap.push(Reverse(HeapEntry { t: t_next, idx }));
385 }
386 }
387 }
388
389 if best_t <= 0.0 {
390 t_start.max(f64::EPSILON)
391 } else {
392 best_t
393 }
394}
395
396fn quantize_ex_with_inv(
397 o_abs: &[f32],
398 residual: &[f32],
399 ex_bits: usize,
400 t: f64,
401) -> (Vec<u16>, f32) {
402 let dim = o_abs.len();
403 if dim == 0 {
404 return (Vec::new(), 1.0);
405 }
406
407 let mut code = vec![0u16; dim];
408 let max_val = (1 << ex_bits) - 1;
409 let mut ipnorm = 0.0f64;
410
411 for i in 0..dim {
412 let mut cur = (t * o_abs[i] as f64 + K_EPS) as i32;
413 if cur > max_val {
414 cur = max_val;
415 }
416 code[i] = cur as u16;
417 ipnorm += (cur as f64 + 0.5) * o_abs[i] as f64;
418 }
419
420 let mut ipnorm_inv = if ipnorm.is_finite() && ipnorm > 0.0 {
421 (1.0 / ipnorm) as f32
422 } else {
423 1.0
424 };
425
426 let mask = max_val as u16;
427 if max_val > 0 {
428 for (idx, &res) in residual.iter().enumerate() {
429 if res < 0.0 {
430 code[idx] = (!code[idx]) & mask;
431 }
432 }
433 }
434
435 if !ipnorm_inv.is_finite() {
436 ipnorm_inv = 1.0;
437 }
438
439 (code, ipnorm_inv)
440}
441
442fn compute_extended_factors(
443 residual: &[f32],
444 centroid: &[f32],
445 binary_code: &[u8],
446 ex_code: &[u16],
447 ipnorm_inv: f32,
448 metric: Metric,
449 ex_bits: usize,
450) -> (f32, f32) {
451 let dim = residual.len();
452 let cb = -((1 << ex_bits) as f32 - 0.5);
453 let xu_cb: Vec<f32> = (0..dim)
454 .map(|i| {
455 let total = ex_code[i] as u32 + ((binary_code[i] as u32) << ex_bits);
456 total as f32 + cb
457 })
458 .collect();
459
460 let l2_sqr = l2_norm_sqr(residual);
461 let l2_norm = l2_sqr.sqrt();
462 let xu_cb_norm_sqr = l2_norm_sqr(&xu_cb);
463 let ip_resi_xucb = dot(residual, &xu_cb);
464 let ip_cent_xucb = dot(centroid, &xu_cb);
465 let dot_residual_centroid = dot(residual, centroid);
466
467 let mut denom = ip_resi_xucb * ip_resi_xucb;
468 if denom <= f32::EPSILON {
469 denom = f32::INFINITY;
470 }
471
472 let mut tmp_error = 0.0f32;
473 if dim > 1 {
474 let ratio = ((l2_sqr * xu_cb_norm_sqr) / denom) - 1.0;
475 if ratio > 0.0 {
476 tmp_error = l2_norm * K_CONST_EPSILON * ((ratio / ((dim - 1) as f32)).max(0.0)).sqrt();
477 }
478 }
479
480 let safe_denom = if ip_resi_xucb.abs() <= f32::EPSILON {
481 f32::INFINITY
482 } else {
483 ip_resi_xucb
484 };
485
486 let (f_add_ex, f_rescale_ex) = match metric {
487 Metric::L2 => {
488 let f_add = l2_sqr + 2.0 * l2_sqr * ip_cent_xucb / safe_denom;
489 let f_rescale = -2.0 * l2_norm * ipnorm_inv;
490 (f_add, f_rescale)
491 }
492 Metric::InnerProduct => {
493 let f_add = 1.0 - dot_residual_centroid + l2_sqr * ip_cent_xucb / safe_denom;
494 let f_rescale = -l2_norm * ipnorm_inv;
495 (f_add, f_rescale)
496 }
497 };
498
499 let _ = tmp_error; (f_add_ex, f_rescale_ex)
502}
503
504#[allow(dead_code)]
509pub(crate) fn reconstruct_into(centroid: &[f32], quantized: &QuantizedVector, output: &mut [f32]) {
510 assert_eq!(centroid.len(), quantized.dim);
511 assert_eq!(output.len(), centroid.len());
512
513 let binary_code = quantized.unpack_binary_code();
514 let ex_code = quantized.unpack_ex_code();
515
516 for i in 0..centroid.len() {
517 let total_code =
518 (ex_code[i] as u32 + ((binary_code[i] as u32) << quantized.ex_bits)) as f32;
519 output[i] = centroid[i] + quantized.delta * total_code + quantized.vl;
520 }
521}
522
523pub fn compute_const_scaling_factor(dim: usize, ex_bits: usize, seed: u64) -> f32 {
537 use rand::prelude::*;
538 use rand_distr::{Distribution, Normal};
539
540 const NUM_SAMPLES: usize = 100;
541
542 let mut rng = StdRng::seed_from_u64(seed);
543 let normal = Normal::new(0.0, 1.0).expect("failed to create normal distribution");
544
545 let mut sum_t = 0.0f64;
546
547 for _ in 0..NUM_SAMPLES {
548 let vec: Vec<f32> = (0..dim).map(|_| normal.sample(&mut rng) as f32).collect();
550
551 let norm = vec.iter().map(|x| x * x).sum::<f32>().sqrt();
553 if norm <= f32::EPSILON {
554 continue;
555 }
556
557 let normalized_abs: Vec<f32> = vec.iter().map(|x| (x / norm).abs()).collect();
558
559 let t = best_rescale_factor(&normalized_abs, ex_bits);
561 sum_t += t;
562 }
563
564 (sum_t / NUM_SAMPLES as f64) as f32
565}