// rabitq_rs/quantizer.rs

1use std::cmp::{Ordering, Reverse};
2use std::collections::BinaryHeap;
3
4use crate::math::{dot, l2_norm_sqr, subtract};
5use crate::simd;
6use crate::Metric;
7
/// Per-`ex_bits` lower bound, as a fraction of the enumeration end `t_end`,
/// from which the rescale-factor sweep starts (see `best_rescale_factor`).
const K_TIGHT_START: [f64; 9] = [0.0, 0.15, 0.20, 0.52, 0.59, 0.71, 0.75, 0.77, 0.81];
/// Small bias added before truncating `t * |o|` down to an integer code.
const K_EPS: f64 = 1e-5;
/// Extra enumeration headroom added on top of the maximum code value
/// when computing the sweep end `t_end`.
const K_NENUM: f64 = 10.0;
/// Scale applied to the concentration-based error-bound estimate.
const K_CONST_EPSILON: f32 = 1.9;
12
/// Configuration for RaBitQ quantisation.
#[derive(Debug, Clone, Copy, serde::Serialize, serde::Deserialize)]
pub struct RabitqConfig {
    /// Total bits per dimension: 1 sign bit plus `total_bits - 1` extended bits.
    pub total_bits: usize,
    /// Precomputed constant scaling factor for faster quantization.
    /// If None, compute optimal t for each vector (slower but more accurate).
    /// If Some(t_const), use this constant for all vectors (100-500x faster).
    pub t_const: Option<f32>,
}
22
23impl RabitqConfig {
24    pub fn new(total_bits: usize) -> Self {
25        RabitqConfig {
26            total_bits,
27            t_const: None, // Default to precise mode
28        }
29    }
30
31    /// Create a faster config with precomputed scaling factor.
32    /// This trades <1% accuracy for 100-500x faster quantization.
33    pub fn faster(dim: usize, total_bits: usize, seed: u64) -> Self {
34        let ex_bits = total_bits.saturating_sub(1);
35        let t_const = if ex_bits > 0 {
36            Some(compute_const_scaling_factor(dim, ex_bits, seed))
37        } else {
38            None
39        };
40
41        RabitqConfig {
42            total_bits,
43            t_const,
44        }
45    }
46}
47
48impl Default for RabitqConfig {
49    fn default() -> Self {
50        Self::new(7) // Default to 7 bits as per MSTG spec
51    }
52}
53
/// Quantised representation of a vector using packed format (aligned with C++ implementation).
///
/// Binary codes and extended codes are stored in bit-packed format for memory efficiency.
/// - `binary_code_packed`: 1 bit per dimension
/// - `ex_code_packed`: ex_bits per dimension
///
/// For performance-critical search operations, unpacked codes are cached to avoid
/// repeated unpacking overhead.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct QuantizedVector {
    /// Packed binary code (1 bit per element)
    pub binary_code_packed: Vec<u8>,
    /// Packed extended code (ex_bits per element)
    pub ex_code_packed: Vec<u8>,
    /// Number of extended bits per element
    pub ex_bits: u8,
    /// Original dimension (before padding)
    pub dim: usize,
    /// Reconstruction scale: a vector entry is `centroid + delta * code + vl`
    /// (see `reconstruct_into`).
    pub delta: f32,
    /// Reconstruction offset paired with `delta` (equals `delta * cb`).
    pub vl: f32,
    /// Additive factor of the 1-bit distance estimator.
    pub f_add: f32,
    /// Rescale factor of the 1-bit distance estimator.
    pub f_rescale: f32,
    /// Error-bound factor of the 1-bit estimator.
    pub f_error: f32,
    /// L2 norm of the residual (vector minus centroid).
    pub residual_norm: f32,
    /// Additive factor of the extended-bits estimator (0 when `ex_bits == 0`).
    pub f_add_ex: f32,
    /// Rescale factor of the extended-bits estimator (0 when `ex_bits == 0`).
    pub f_rescale_ex: f32,
}
81
82impl QuantizedVector {
83    /// Unpack binary code for computation
84    #[inline]
85    pub fn unpack_binary_code(&self) -> Vec<u8> {
86        let mut binary_code = vec![0u8; self.dim];
87        simd::unpack_binary_code(&self.binary_code_packed, &mut binary_code, self.dim);
88        binary_code
89    }
90
91    /// Unpack extended code for computation
92    #[inline]
93    pub fn unpack_ex_code(&self) -> Vec<u16> {
94        let mut ex_code = vec![0u16; self.dim];
95        simd::unpack_ex_code(&self.ex_code_packed, &mut ex_code, self.dim, self.ex_bits);
96        ex_code
97    }
98
99    /// Ensure unpacked caches are populated (No-op in memory-optimized version)
100    pub fn ensure_unpacked_cache(&mut self) {}
101
102    /// Calculate heap memory usage in bytes
103    pub fn heap_size(&self) -> usize {
104        self.binary_code_packed.capacity() * std::mem::size_of::<u8>()
105            + self.ex_code_packed.capacity() * std::mem::size_of::<u8>()
106    }
107}
108
109/// Quantise a vector relative to a centroid with custom configuration.
110pub fn quantize_with_centroid(
111    data: &[f32],
112    centroid: &[f32],
113    config: &RabitqConfig,
114    metric: Metric,
115) -> QuantizedVector {
116    assert_eq!(data.len(), centroid.len());
117    assert!((1..=16).contains(&config.total_bits));
118    let dim = data.len();
119    let ex_bits = config.total_bits.saturating_sub(1);
120
121    let residual = subtract(data, centroid);
122    let mut binary_code = vec![0u8; dim];
123    for (idx, &value) in residual.iter().enumerate() {
124        if value >= 0.0 {
125            binary_code[idx] = 1u8;
126        }
127    }
128
129    let (ex_code, ipnorm_inv) = if ex_bits > 0 {
130        ex_bits_code_with_inv(&residual, ex_bits, config.t_const)
131    } else {
132        (vec![0u16; dim], 1.0f32)
133    };
134
135    let mut total_code = vec![0u16; dim];
136    for i in 0..dim {
137        total_code[i] = ex_code[i] + ((binary_code[i] as u16) << ex_bits);
138    }
139
140    let (f_add, f_rescale, f_error, residual_norm) =
141        compute_one_bit_factors(&residual, centroid, &binary_code, metric);
142    let cb = -((1 << ex_bits) as f32 - 0.5);
143    let quantized_shifted: Vec<f32> = total_code.iter().map(|&code| code as f32 + cb).collect();
144    let norm_quan_sqr = l2_norm_sqr(&quantized_shifted);
145    let dot_residual_quant = dot(&residual, &quantized_shifted);
146
147    let norm_residual_sqr = l2_norm_sqr(&residual);
148    let norm_residual = norm_residual_sqr.sqrt();
149    let norm_quant = norm_quan_sqr.sqrt();
150    let denom = (norm_residual * norm_quant).max(f32::EPSILON);
151    let cos_similarity = (dot_residual_quant / denom).clamp(-1.0, 1.0);
152    let delta = if norm_quant <= f32::EPSILON {
153        0.0
154    } else {
155        (norm_residual / norm_quant) * cos_similarity
156    };
157    let vl = delta * cb;
158
159    let mut f_add_ex = 0.0f32;
160    let mut f_rescale_ex = 0.0f32;
161    if ex_bits > 0 {
162        let factors = compute_extended_factors(
163            &residual,
164            centroid,
165            &binary_code,
166            &ex_code,
167            ipnorm_inv,
168            metric,
169            ex_bits,
170        );
171        f_add_ex = factors.0;
172        f_rescale_ex = factors.1;
173    }
174
175    // Pack binary code and ex code into bit-packed format
176    let binary_code_packed_size = dim.div_ceil(8);
177    let mut binary_code_packed = vec![0u8; binary_code_packed_size];
178    simd::pack_binary_code(&binary_code, &mut binary_code_packed, dim);
179
180    // Use C++-compatible packing format for ex_code (Phase 4 optimization)
181    // This enables direct SIMD operations without unpacking overhead
182    let ex_code_packed_size = match ex_bits {
183        0 => dim / 16 * 2,  // 1-bit total (binary only), but allocate for consistency
184        1 => dim / 16 * 2,  // 2-bit total (1-bit ex-code) - not commonly used
185        2 => dim / 16 * 4,  // 3-bit total (2-bit ex-code)
186        6 => dim / 16 * 12, // 7-bit total (6-bit ex-code)
187        _ => (dim * ex_bits).div_ceil(8), // Fallback for other bit configs
188    };
189    let mut ex_code_packed = vec![0u8; ex_code_packed_size];
190
191    // Pack using C++-compatible format based on ex_bits
192    match ex_bits {
193        0 => {
194            // Binary-only: no ex-code to pack (all zeros)
195            // Keep packed array as zeros
196        }
197        1 => {
198            // 1-bit ex-code (2-bit total RaBitQ)
199            simd::pack_ex_code_1bit_cpp_compat(&ex_code, &mut ex_code_packed, dim);
200        }
201        2 => {
202            // 2-bit ex-code (3-bit total RaBitQ)
203            simd::pack_ex_code_2bit_cpp_compat(&ex_code, &mut ex_code_packed, dim);
204        }
205        6 => {
206            // 6-bit ex-code (7-bit total RaBitQ)
207            simd::pack_ex_code_6bit_cpp_compat(&ex_code, &mut ex_code_packed, dim);
208        }
209        _ => {
210            // Fallback to generic packing for unsupported bit configs
211            simd::pack_ex_code(&ex_code, &mut ex_code_packed, dim, ex_bits as u8);
212        }
213    }
214
215    QuantizedVector {
216        binary_code_packed,
217        ex_code_packed,
218        ex_bits: ex_bits as u8,
219        dim,
220        delta,
221        vl,
222        f_add,
223        f_rescale,
224        f_error,
225        residual_norm,
226        f_add_ex,
227        f_rescale_ex,
228    }
229}
230
/// Compute the 1-bit estimator factors for a residual vector.
///
/// Returns `(f_add, f_rescale, f_error, residual_norm)`: the additive and
/// rescale factors of the distance estimator for `metric`, an error-bound
/// term, and the L2 norm of `residual`.
fn compute_one_bit_factors(
    residual: &[f32],
    centroid: &[f32],
    binary_code: &[u8],
    metric: Metric,
) -> (f32, f32, f32, f32) {
    let dim = residual.len();
    // Map each sign bit {0, 1} to the shifted 1-bit codeword {-0.5, +0.5}.
    let xu_cb: Vec<f32> = binary_code.iter().map(|&bit| bit as f32 - 0.5f32).collect();
    let l2_sqr = l2_norm_sqr(residual);
    let l2_norm = l2_sqr.sqrt();
    let xu_cb_norm_sqr = l2_norm_sqr(&xu_cb);
    let ip_resi_xucb = dot(residual, &xu_cb);
    let ip_cent_xucb = dot(centroid, &xu_cb);
    let dot_residual_centroid = dot(residual, centroid);

    // Guard a (near-)zero inner product: an infinite denominator drives the
    // dependent terms to zero instead of producing NaN.
    let mut denom = ip_resi_xucb;
    if denom.abs() <= f32::EPSILON {
        denom = f32::INFINITY;
    }

    // Concentration-style error bound; only meaningful for dim > 1.
    let mut tmp_error = 0.0f32;
    if dim > 1 {
        let ratio = ((l2_sqr * xu_cb_norm_sqr) / (denom * denom)) - 1.0;
        if ratio.is_finite() && ratio > 0.0 {
            tmp_error = l2_norm * K_CONST_EPSILON * ((ratio / ((dim - 1) as f32)).max(0.0)).sqrt();
        }
    }

    // Metric-specific estimator factors; the L2 terms carry a factor of 2
    // relative to the inner-product ones.
    let (f_add, f_rescale, f_error) = match metric {
        Metric::L2 => {
            let f_add = l2_sqr + 2.0 * l2_sqr * ip_cent_xucb / denom;
            let f_rescale = -2.0 * l2_sqr / denom;
            let f_error = 2.0 * tmp_error;
            (f_add, f_rescale, f_error)
        }
        Metric::InnerProduct => {
            let f_add = 1.0 - dot_residual_centroid + l2_sqr * ip_cent_xucb / denom;
            let f_rescale = -l2_sqr / denom;
            let f_error = tmp_error;
            (f_add, f_rescale, f_error)
        }
    };

    (f_add, f_rescale, f_error, l2_norm)
}
276
277fn ex_bits_code_with_inv(
278    residual: &[f32],
279    ex_bits: usize,
280    t_const: Option<f32>,
281) -> (Vec<u16>, f32) {
282    let dim = residual.len();
283    let mut normalized_abs: Vec<f32> = residual.iter().map(|x| x.abs()).collect();
284    let norm = normalized_abs.iter().map(|x| x * x).sum::<f32>().sqrt();
285
286    if norm <= f32::EPSILON {
287        return (vec![0u16; dim], 1.0);
288    }
289
290    for value in normalized_abs.iter_mut() {
291        *value /= norm;
292    }
293
294    // Use precomputed t_const if available, otherwise compute optimal t
295    let t = if let Some(t) = t_const {
296        t as f64
297    } else {
298        best_rescale_factor(&normalized_abs, ex_bits)
299    };
300
301    quantize_ex_with_inv(&normalized_abs, residual, ex_bits, t)
302}
303
304fn best_rescale_factor(o_abs: &[f32], ex_bits: usize) -> f64 {
305    let dim = o_abs.len();
306    let max_o = o_abs.iter().cloned().fold(0.0f32, f32::max) as f64;
307    if max_o <= f64::EPSILON {
308        return 1.0;
309    }
310
311    let table_idx = ex_bits.min(K_TIGHT_START.len() - 1);
312    let t_end = (((1 << ex_bits) - 1) as f64 + K_NENUM) / max_o;
313    let t_start = t_end * K_TIGHT_START[table_idx];
314
315    let mut cur_o_bar = vec![0i32; dim];
316    let mut sqr_denominator = dim as f64 * 0.25;
317    let mut numerator = 0.0f64;
318
319    for (idx, &val) in o_abs.iter().enumerate() {
320        let cur = ((t_start * val as f64) + K_EPS) as i32;
321        cur_o_bar[idx] = cur;
322        sqr_denominator += (cur * cur + cur) as f64;
323        numerator += (cur as f64 + 0.5) * val as f64;
324    }
325
326    #[derive(Copy, Clone, Debug)]
327    struct HeapEntry {
328        t: f64,
329        idx: usize,
330    }
331
332    impl PartialEq for HeapEntry {
333        fn eq(&self, other: &Self) -> bool {
334            self.t.to_bits() == other.t.to_bits() && self.idx == other.idx
335        }
336    }
337
338    impl Eq for HeapEntry {}
339
340    impl PartialOrd for HeapEntry {
341        fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
342            Some(self.cmp(other))
343        }
344    }
345
346    impl Ord for HeapEntry {
347        fn cmp(&self, other: &Self) -> Ordering {
348            self.t
349                .total_cmp(&other.t)
350                .then_with(|| self.idx.cmp(&other.idx))
351        }
352    }
353
354    let mut heap: BinaryHeap<Reverse<HeapEntry>> = BinaryHeap::new();
355    for (idx, &val) in o_abs.iter().enumerate() {
356        if val > 0.0 {
357            let next_t = (cur_o_bar[idx] + 1) as f64 / val as f64;
358            heap.push(Reverse(HeapEntry { t: next_t, idx }));
359        }
360    }
361
362    let mut max_ip = 0.0f64;
363    let mut best_t = t_start;
364
365    while let Some(Reverse(HeapEntry { t: cur_t, idx })) = heap.pop() {
366        if cur_t >= t_end {
367            continue;
368        }
369
370        cur_o_bar[idx] += 1;
371        let update = cur_o_bar[idx];
372        sqr_denominator += 2.0 * update as f64;
373        numerator += o_abs[idx] as f64;
374
375        let cur_ip = numerator / sqr_denominator.sqrt();
376        if cur_ip > max_ip {
377            max_ip = cur_ip;
378            best_t = cur_t;
379        }
380
381        if update < (1 << ex_bits) - 1 && o_abs[idx] > 0.0 {
382            let t_next = (update + 1) as f64 / o_abs[idx] as f64;
383            if t_next < t_end {
384                heap.push(Reverse(HeapEntry { t: t_next, idx }));
385            }
386        }
387    }
388
389    if best_t <= 0.0 {
390        t_start.max(f64::EPSILON)
391    } else {
392        best_t
393    }
394}
395
396fn quantize_ex_with_inv(
397    o_abs: &[f32],
398    residual: &[f32],
399    ex_bits: usize,
400    t: f64,
401) -> (Vec<u16>, f32) {
402    let dim = o_abs.len();
403    if dim == 0 {
404        return (Vec::new(), 1.0);
405    }
406
407    let mut code = vec![0u16; dim];
408    let max_val = (1 << ex_bits) - 1;
409    let mut ipnorm = 0.0f64;
410
411    for i in 0..dim {
412        let mut cur = (t * o_abs[i] as f64 + K_EPS) as i32;
413        if cur > max_val {
414            cur = max_val;
415        }
416        code[i] = cur as u16;
417        ipnorm += (cur as f64 + 0.5) * o_abs[i] as f64;
418    }
419
420    let mut ipnorm_inv = if ipnorm.is_finite() && ipnorm > 0.0 {
421        (1.0 / ipnorm) as f32
422    } else {
423        1.0
424    };
425
426    let mask = max_val as u16;
427    if max_val > 0 {
428        for (idx, &res) in residual.iter().enumerate() {
429            if res < 0.0 {
430                code[idx] = (!code[idx]) & mask;
431            }
432        }
433    }
434
435    if !ipnorm_inv.is_finite() {
436        ipnorm_inv = 1.0;
437    }
438
439    (code, ipnorm_inv)
440}
441
/// Compute the extended-bits estimator factors `(f_add_ex, f_rescale_ex)`.
///
/// Mirrors `compute_one_bit_factors`, but operates on the full
/// (sign + extended) codeword and uses the precomputed `ipnorm_inv` in the
/// rescale term.
fn compute_extended_factors(
    residual: &[f32],
    centroid: &[f32],
    binary_code: &[u8],
    ex_code: &[u16],
    ipnorm_inv: f32,
    metric: Metric,
    ex_bits: usize,
) -> (f32, f32) {
    let dim = residual.len();
    // Shift combined codes to be symmetric around zero.
    let cb = -((1 << ex_bits) as f32 - 0.5);
    let xu_cb: Vec<f32> = (0..dim)
        .map(|i| {
            let total = ex_code[i] as u32 + ((binary_code[i] as u32) << ex_bits);
            total as f32 + cb
        })
        .collect();

    let l2_sqr = l2_norm_sqr(residual);
    let l2_norm = l2_sqr.sqrt();
    let xu_cb_norm_sqr = l2_norm_sqr(&xu_cb);
    let ip_resi_xucb = dot(residual, &xu_cb);
    let ip_cent_xucb = dot(centroid, &xu_cb);
    let dot_residual_centroid = dot(residual, centroid);

    // Squared inner product as denominator; infinity zeroes the dependent
    // terms when the inner product vanishes.
    let mut denom = ip_resi_xucb * ip_resi_xucb;
    if denom <= f32::EPSILON {
        denom = f32::INFINITY;
    }

    // Error bound analogous to the 1-bit case (currently unused below; see
    // the parity note at the end of the function).
    let mut tmp_error = 0.0f32;
    if dim > 1 {
        let ratio = ((l2_sqr * xu_cb_norm_sqr) / denom) - 1.0;
        if ratio > 0.0 {
            tmp_error = l2_norm * K_CONST_EPSILON * ((ratio / ((dim - 1) as f32)).max(0.0)).sqrt();
        }
    }

    // Un-squared denominator for the additive term, with the same guard.
    let safe_denom = if ip_resi_xucb.abs() <= f32::EPSILON {
        f32::INFINITY
    } else {
        ip_resi_xucb
    };

    let (f_add_ex, f_rescale_ex) = match metric {
        Metric::L2 => {
            let f_add = l2_sqr + 2.0 * l2_sqr * ip_cent_xucb / safe_denom;
            let f_rescale = -2.0 * l2_norm * ipnorm_inv;
            (f_add, f_rescale)
        }
        Metric::InnerProduct => {
            let f_add = 1.0 - dot_residual_centroid + l2_sqr * ip_cent_xucb / safe_denom;
            let f_rescale = -l2_norm * ipnorm_inv;
            (f_add, f_rescale)
        }
    };

    let _ = tmp_error; // retain structure parity; tmp_error may be used in future

    (f_add_ex, f_rescale_ex)
}
503
504/// Reconstruct a vector from its quantised representation and centroid (in rotated space).
505///
506/// This reconstructs the vector in the **rotated space**. To get the original vector,
507/// you need to apply inverse rotation to the result.
508#[allow(dead_code)]
509pub(crate) fn reconstruct_into(centroid: &[f32], quantized: &QuantizedVector, output: &mut [f32]) {
510    assert_eq!(centroid.len(), quantized.dim);
511    assert_eq!(output.len(), centroid.len());
512
513    let binary_code = quantized.unpack_binary_code();
514    let ex_code = quantized.unpack_ex_code();
515
516    for i in 0..centroid.len() {
517        let total_code =
518            (ex_code[i] as u32 + ((binary_code[i] as u32) << quantized.ex_bits)) as f32;
519        output[i] = centroid[i] + quantized.delta * total_code + quantized.vl;
520    }
521}
522
523/// Compute a constant scaling factor for faster quantization.
524///
525/// This function samples random normalized vectors and computes the average optimal
526/// scaling factor. Using this constant factor for all vectors is 100-500x faster
527/// than computing the optimal factor per-vector, with <1% accuracy loss.
528///
529/// # Arguments
530/// * `dim` - Vector dimensionality
531/// * `ex_bits` - Number of extended bits for quantization
532/// * `seed` - Random seed for reproducibility
533///
534/// # Returns
535/// Average optimal scaling factor across 100 random samples
536pub fn compute_const_scaling_factor(dim: usize, ex_bits: usize, seed: u64) -> f32 {
537    use rand::prelude::*;
538    use rand_distr::{Distribution, Normal};
539
540    const NUM_SAMPLES: usize = 100;
541
542    let mut rng = StdRng::seed_from_u64(seed);
543    let normal = Normal::new(0.0, 1.0).expect("failed to create normal distribution");
544
545    let mut sum_t = 0.0f64;
546
547    for _ in 0..NUM_SAMPLES {
548        // Generate random Gaussian vector
549        let vec: Vec<f32> = (0..dim).map(|_| normal.sample(&mut rng) as f32).collect();
550
551        // Normalize and take absolute value
552        let norm = vec.iter().map(|x| x * x).sum::<f32>().sqrt();
553        if norm <= f32::EPSILON {
554            continue;
555        }
556
557        let normalized_abs: Vec<f32> = vec.iter().map(|x| (x / norm).abs()).collect();
558
559        // Compute optimal scaling factor for this random vector
560        let t = best_rescale_factor(&normalized_abs, ex_bits);
561        sum_t += t;
562    }
563
564    (sum_t / NUM_SAMPLES as f64) as f32
565}