rabitq_rs/
quantizer.rs

1use std::cmp::{Ordering, Reverse};
2use std::collections::BinaryHeap;
3
4use crate::math::{dot, l2_norm_sqr, subtract};
5use crate::simd;
6use crate::Metric;
7
/// Per-`ex_bits` fraction of `t_end` at which `best_rescale_factor` starts its
/// search for the rescale factor (indexed by `ex_bits`, clamped to the last entry).
/// NOTE(review): values presumably mirror the reference C++ RaBitQ table — confirm.
const K_TIGHT_START: [f64; 9] = [0.0, 0.15, 0.20, 0.52, 0.59, 0.71, 0.75, 0.77, 0.81];
/// Small bias added to `t * |o|` before truncating to an integer code
/// (presumably to absorb floating-point error at integer boundaries).
const K_EPS: f64 = 1e-5;
/// Extra headroom added to the largest code value when computing the upper end
/// `t_end` of the rescale-factor search.
const K_NENUM: f64 = 10.0;
/// Multiplier applied when deriving the estimator error term (`f_error`) from the
/// residual/code alignment.
const K_CONST_EPSILON: f32 = 1.9;
12
/// Configuration for RaBitQ quantisation.
///
/// `total_bits` is split into one sign (binary) bit plus `total_bits - 1`
/// extended bits per dimension.
#[derive(Debug, Clone, Copy, serde::Serialize, serde::Deserialize)]
pub struct RabitqConfig {
    /// Total bits per dimension (sign bit + extended bits); the quantiser
    /// asserts this lies in `1..=16`.
    pub total_bits: usize,
    /// Precomputed constant scaling factor for faster quantization.
    /// If `None`, the optimal scale `t` is searched for each vector (slower, more accurate).
    /// If `Some(t_const)`, this constant is used for all vectors, skipping the per-vector search.
    pub t_const: Option<f32>,
}
22
23impl RabitqConfig {
24    pub fn new(total_bits: usize) -> Self {
25        RabitqConfig {
26            total_bits,
27            t_const: None, // Default to precise mode
28        }
29    }
30
31    /// Create a faster config with precomputed scaling factor.
32    /// This trades <1% accuracy for 100-500x faster quantization.
33    pub fn faster(dim: usize, total_bits: usize, seed: u64) -> Self {
34        let ex_bits = total_bits.saturating_sub(1);
35        let t_const = if ex_bits > 0 {
36            Some(compute_const_scaling_factor(dim, ex_bits, seed))
37        } else {
38            None
39        };
40
41        RabitqConfig {
42            total_bits,
43            t_const,
44        }
45    }
46}
47
48impl Default for RabitqConfig {
49    fn default() -> Self {
50        Self::new(7) // Default to 7 bits as per MSTG spec
51    }
52}
53
/// Quantised representation of a vector using packed format (aligned with C++ implementation).
///
/// Binary codes and extended codes are stored in bit-packed format for memory efficiency.
/// - `binary_code_packed`: 1 bit per dimension
/// - `ex_code_packed`: ex_bits per dimension (layout depends on `ex_bits`; 1-, 2- and
///   6-bit ex-codes use C++-compatible layouts, other widths a generic bit packing)
///
/// For performance-critical search operations, unpacked codes are cached to avoid
/// repeated unpacking overhead. The caches are not serialised; call
/// `ensure_unpacked_cache` after deserialisation to rebuild them.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct QuantizedVector {
    /// Full code per dimension: `ex_code | (binary_bit << ex_bits)`, kept for backward
    /// compatibility and used by `reconstruct_into`.
    pub code: Vec<u16>,
    /// Packed binary code (1 bit per element, `ceil(dim / 8)` bytes)
    pub binary_code_packed: Vec<u8>,
    /// Packed extended code (ex_bits per element)
    pub ex_code_packed: Vec<u8>,
    /// Cached unpacked binary code (for fast search); skipped by serde, so empty after
    /// deserialisation until rebuilt.
    #[serde(skip)]
    pub binary_code_unpacked: Vec<u8>,
    /// Cached unpacked ex code (for fast search); skipped by serde, so empty after
    /// deserialisation until rebuilt.
    #[serde(skip)]
    pub ex_code_unpacked: Vec<u16>,
    /// Number of extended bits per element (`total_bits - 1`)
    pub ex_bits: u8,
    /// Dimension of the quantised vector (the input length at quantisation time)
    pub dim: usize,
    /// Reconstruction scale: a dimension is approximated as `centroid + delta * code + vl`.
    pub delta: f32,
    /// Reconstruction offset, equal to `delta * cb` where `cb = -(2^ex_bits - 0.5)`.
    pub vl: f32,
    /// Additive factor of the one-bit distance estimator.
    pub f_add: f32,
    /// Multiplicative factor of the one-bit distance estimator.
    pub f_rescale: f32,
    /// Error-bound factor of the one-bit distance estimator.
    pub f_error: f32,
    /// L2 norm of the residual `data - centroid`.
    pub residual_norm: f32,
    /// Additive factor for the extended-code estimator (0 when `ex_bits == 0`).
    pub f_add_ex: f32,
    /// Multiplicative factor for the extended-code estimator (0 when `ex_bits == 0`).
    pub f_rescale_ex: f32,
}
89
90impl QuantizedVector {
91    /// Unpack binary code for computation (legacy method, prefer using cached binary_code_unpacked)
92    #[inline]
93    pub fn unpack_binary_code(&self) -> Vec<u8> {
94        // Return cached version if available
95        if !self.binary_code_unpacked.is_empty() {
96            return self.binary_code_unpacked.clone();
97        }
98        // Fallback to unpacking (for deserialized data without cache)
99        let mut binary_code = vec![0u8; self.dim];
100        simd::unpack_binary_code(&self.binary_code_packed, &mut binary_code, self.dim);
101        binary_code
102    }
103
104    /// Unpack extended code for computation (legacy method, prefer using cached ex_code_unpacked)
105    #[inline]
106    pub fn unpack_ex_code(&self) -> Vec<u16> {
107        // Return cached version if available
108        if !self.ex_code_unpacked.is_empty() {
109            return self.ex_code_unpacked.clone();
110        }
111        // Fallback to unpacking (for deserialized data without cache)
112        let mut ex_code = vec![0u16; self.dim];
113        simd::unpack_ex_code(&self.ex_code_packed, &mut ex_code, self.dim, self.ex_bits);
114        ex_code
115    }
116
117    /// Ensure unpacked caches are populated (call after deserialization)
118    pub fn ensure_unpacked_cache(&mut self) {
119        if self.binary_code_unpacked.is_empty() {
120            self.binary_code_unpacked = vec![0u8; self.dim];
121            simd::unpack_binary_code(
122                &self.binary_code_packed,
123                &mut self.binary_code_unpacked,
124                self.dim,
125            );
126        }
127        if self.ex_code_unpacked.is_empty() {
128            self.ex_code_unpacked = vec![0u16; self.dim];
129            simd::unpack_ex_code(
130                &self.ex_code_packed,
131                &mut self.ex_code_unpacked,
132                self.dim,
133                self.ex_bits,
134            );
135        }
136    }
137}
138
139/// Quantise a vector relative to a centroid with custom configuration.
140pub fn quantize_with_centroid(
141    data: &[f32],
142    centroid: &[f32],
143    config: &RabitqConfig,
144    metric: Metric,
145) -> QuantizedVector {
146    assert_eq!(data.len(), centroid.len());
147    assert!((1..=16).contains(&config.total_bits));
148    let dim = data.len();
149    let ex_bits = config.total_bits.saturating_sub(1);
150
151    let residual = subtract(data, centroid);
152    let mut binary_code = vec![0u8; dim];
153    for (idx, &value) in residual.iter().enumerate() {
154        if value >= 0.0 {
155            binary_code[idx] = 1u8;
156        }
157    }
158
159    let (ex_code, ipnorm_inv) = if ex_bits > 0 {
160        ex_bits_code_with_inv(&residual, ex_bits, config.t_const)
161    } else {
162        (vec![0u16; dim], 1.0f32)
163    };
164
165    let mut total_code = vec![0u16; dim];
166    for i in 0..dim {
167        total_code[i] = ex_code[i] + ((binary_code[i] as u16) << ex_bits);
168    }
169
170    let (f_add, f_rescale, f_error, residual_norm) =
171        compute_one_bit_factors(&residual, centroid, &binary_code, metric);
172    let cb = -((1 << ex_bits) as f32 - 0.5);
173    let quantized_shifted: Vec<f32> = total_code.iter().map(|&code| code as f32 + cb).collect();
174    let norm_quan_sqr = l2_norm_sqr(&quantized_shifted);
175    let dot_residual_quant = dot(&residual, &quantized_shifted);
176
177    let norm_residual_sqr = l2_norm_sqr(&residual);
178    let norm_residual = norm_residual_sqr.sqrt();
179    let norm_quant = norm_quan_sqr.sqrt();
180    let denom = (norm_residual * norm_quant).max(f32::EPSILON);
181    let cos_similarity = (dot_residual_quant / denom).clamp(-1.0, 1.0);
182    let delta = if norm_quant <= f32::EPSILON {
183        0.0
184    } else {
185        (norm_residual / norm_quant) * cos_similarity
186    };
187    let vl = delta * cb;
188
189    let mut f_add_ex = 0.0f32;
190    let mut f_rescale_ex = 0.0f32;
191    if ex_bits > 0 {
192        let factors = compute_extended_factors(
193            &residual,
194            centroid,
195            &binary_code,
196            &ex_code,
197            ipnorm_inv,
198            metric,
199            ex_bits,
200        );
201        f_add_ex = factors.0;
202        f_rescale_ex = factors.1;
203    }
204
205    // Pack binary code and ex code into bit-packed format
206    let binary_code_packed_size = (dim + 7) / 8;
207    let mut binary_code_packed = vec![0u8; binary_code_packed_size];
208    simd::pack_binary_code(&binary_code, &mut binary_code_packed, dim);
209
210    // Use C++-compatible packing format for ex_code (Phase 4 optimization)
211    // This enables direct SIMD operations without unpacking overhead
212    let ex_code_packed_size = match ex_bits {
213        0 => dim / 16 * 2,  // 1-bit total (binary only), but allocate for consistency
214        1 => dim / 16 * 2,  // 2-bit total (1-bit ex-code) - not commonly used
215        2 => dim / 16 * 4,  // 3-bit total (2-bit ex-code)
216        6 => dim / 16 * 12, // 7-bit total (6-bit ex-code)
217        _ => ((dim * ex_bits) + 7) / 8, // Fallback for other bit configs
218    };
219    let mut ex_code_packed = vec![0u8; ex_code_packed_size];
220
221    // Pack using C++-compatible format based on ex_bits
222    match ex_bits {
223        0 => {
224            // Binary-only: no ex-code to pack (all zeros)
225            // Keep packed array as zeros
226        }
227        1 => {
228            // 1-bit ex-code (2-bit total RaBitQ)
229            simd::pack_ex_code_1bit_cpp_compat(&ex_code, &mut ex_code_packed, dim);
230        }
231        2 => {
232            // 2-bit ex-code (3-bit total RaBitQ)
233            simd::pack_ex_code_2bit_cpp_compat(&ex_code, &mut ex_code_packed, dim);
234        }
235        6 => {
236            // 6-bit ex-code (7-bit total RaBitQ)
237            simd::pack_ex_code_6bit_cpp_compat(&ex_code, &mut ex_code_packed, dim);
238        }
239        _ => {
240            // Fallback to generic packing for unsupported bit configs
241            simd::pack_ex_code(&ex_code, &mut ex_code_packed, dim, ex_bits as u8);
242        }
243    }
244
245    QuantizedVector {
246        code: total_code,
247        binary_code_packed,
248        ex_code_packed,
249        binary_code_unpacked: binary_code, // Cache unpacked version for fast search
250        ex_code_unpacked: ex_code,         // Cache unpacked version for fast search
251        ex_bits: ex_bits as u8,
252        dim,
253        delta,
254        vl,
255        f_add,
256        f_rescale,
257        f_error,
258        residual_norm,
259        f_add_ex,
260        f_rescale_ex,
261    }
262}
263
264fn compute_one_bit_factors(
265    residual: &[f32],
266    centroid: &[f32],
267    binary_code: &[u8],
268    metric: Metric,
269) -> (f32, f32, f32, f32) {
270    let dim = residual.len();
271    let xu_cb: Vec<f32> = binary_code.iter().map(|&bit| bit as f32 - 0.5f32).collect();
272    let l2_sqr = l2_norm_sqr(residual);
273    let l2_norm = l2_sqr.sqrt();
274    let xu_cb_norm_sqr = l2_norm_sqr(&xu_cb);
275    let ip_resi_xucb = dot(residual, &xu_cb);
276    let ip_cent_xucb = dot(centroid, &xu_cb);
277    let dot_residual_centroid = dot(residual, centroid);
278
279    let mut denom = ip_resi_xucb;
280    if denom.abs() <= f32::EPSILON {
281        denom = f32::INFINITY;
282    }
283
284    let mut tmp_error = 0.0f32;
285    if dim > 1 {
286        let ratio = ((l2_sqr * xu_cb_norm_sqr) / (denom * denom)) - 1.0;
287        if ratio.is_finite() && ratio > 0.0 {
288            tmp_error = l2_norm * K_CONST_EPSILON * ((ratio / ((dim - 1) as f32)).max(0.0)).sqrt();
289        }
290    }
291
292    let (f_add, f_rescale, f_error) = match metric {
293        Metric::L2 => {
294            let f_add = l2_sqr + 2.0 * l2_sqr * ip_cent_xucb / denom;
295            let f_rescale = -2.0 * l2_sqr / denom;
296            let f_error = 2.0 * tmp_error;
297            (f_add, f_rescale, f_error)
298        }
299        Metric::InnerProduct => {
300            let f_add = 1.0 - dot_residual_centroid + l2_sqr * ip_cent_xucb / denom;
301            let f_rescale = -l2_sqr / denom;
302            let f_error = tmp_error;
303            (f_add, f_rescale, f_error)
304        }
305    };
306
307    (f_add, f_rescale, f_error, l2_norm)
308}
309
310fn ex_bits_code_with_inv(
311    residual: &[f32],
312    ex_bits: usize,
313    t_const: Option<f32>,
314) -> (Vec<u16>, f32) {
315    let dim = residual.len();
316    let mut normalized_abs: Vec<f32> = residual.iter().map(|x| x.abs()).collect();
317    let norm = normalized_abs.iter().map(|x| x * x).sum::<f32>().sqrt();
318
319    if norm <= f32::EPSILON {
320        return (vec![0u16; dim], 1.0);
321    }
322
323    for value in normalized_abs.iter_mut() {
324        *value /= norm;
325    }
326
327    // Use precomputed t_const if available, otherwise compute optimal t
328    let t = if let Some(t) = t_const {
329        t as f64
330    } else {
331        best_rescale_factor(&normalized_abs, ex_bits)
332    };
333
334    quantize_ex_with_inv(&normalized_abs, residual, ex_bits, t)
335}
336
/// Search for the rescale factor `t` that best aligns `o_abs` with its quantisation.
///
/// For a given `t`, each code is `floor(t * o_abs[i] + K_EPS)` and the quantised
/// vector is `code + 0.5`. The objective is `numerator / sqrt(sqr_denominator)`, where
/// `numerator = Σ (code + 0.5) * o_abs[i]` and
/// `sqr_denominator = Σ (code + 0.5)^2 = dim * 0.25 + Σ (code^2 + code)`.
/// Candidate `t` values are the points where some code increments. They are visited
/// in increasing order within `[t_start, t_end)`, and the objective is updated
/// incrementally. Returns `1.0` when all magnitudes are (near) zero.
fn best_rescale_factor(o_abs: &[f32], ex_bits: usize) -> f64 {
    let dim = o_abs.len();
    let max_o = o_abs.iter().cloned().fold(0.0f32, f32::max) as f64;
    if max_o <= f64::EPSILON {
        return 1.0;
    }

    // Search window: `t_end` lets the largest magnitude reach the max code plus
    // `K_NENUM` headroom; `t_start` is a per-`ex_bits` fraction of it.
    let table_idx = ex_bits.min(K_TIGHT_START.len() - 1);
    let t_end = (((1 << ex_bits) - 1) as f64 + K_NENUM) / max_o;
    let t_start = t_end * K_TIGHT_START[table_idx];

    // Codes and objective terms at `t_start`.
    // NOTE(review): these initial codes are not clamped to `(1 << ex_bits) - 1`;
    // `quantize_ex_with_inv` clamps when the final codes are produced.
    let mut cur_o_bar = vec![0i32; dim];
    let mut sqr_denominator = dim as f64 * 0.25;
    let mut numerator = 0.0f64;

    for (idx, &val) in o_abs.iter().enumerate() {
        let cur = ((t_start * val as f64) + K_EPS) as i32;
        cur_o_bar[idx] = cur;
        sqr_denominator += (cur * cur + cur) as f64;
        numerator += (cur as f64 + 0.5) * val as f64;
    }

    // Heap entry: the `t` at which dimension `idx` next increments its code.
    // Ordered by `t` (total order on f64), ties broken by index.
    #[derive(Copy, Clone, Debug)]
    struct HeapEntry {
        t: f64,
        idx: usize,
    }

    impl PartialEq for HeapEntry {
        fn eq(&self, other: &Self) -> bool {
            self.t.to_bits() == other.t.to_bits() && self.idx == other.idx
        }
    }

    impl Eq for HeapEntry {}

    impl PartialOrd for HeapEntry {
        fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
            Some(self.cmp(other))
        }
    }

    impl Ord for HeapEntry {
        fn cmp(&self, other: &Self) -> Ordering {
            self.t
                .total_cmp(&other.t)
                .then_with(|| self.idx.cmp(&other.idx))
        }
    }

    // `Reverse` turns the max-heap into a min-heap, so the smallest next `t` pops first.
    // Zero magnitudes never increment and are not queued.
    // NOTE(review): these initial entries are pushed without the max-code check used
    // below, so the search can consider codes above `(1 << ex_bits) - 1`.
    let mut heap: BinaryHeap<Reverse<HeapEntry>> = BinaryHeap::new();
    for (idx, &val) in o_abs.iter().enumerate() {
        if val > 0.0 {
            let next_t = (cur_o_bar[idx] + 1) as f64 / val as f64;
            heap.push(Reverse(HeapEntry { t: next_t, idx }));
        }
    }

    let mut max_ip = 0.0f64;
    let mut best_t = t_start;

    while let Some(Reverse(HeapEntry { t: cur_t, idx })) = heap.pop() {
        if cur_t >= t_end {
            continue;
        }

        // Incrementing code c -> c+1 adds 2*(c+1) to Σ(c^2 + c) and o_abs[idx] to the numerator.
        cur_o_bar[idx] += 1;
        let update = cur_o_bar[idx];
        sqr_denominator += 2.0 * update as f64;
        numerator += o_abs[idx] as f64;

        let cur_ip = numerator / sqr_denominator.sqrt();
        if cur_ip > max_ip {
            max_ip = cur_ip;
            best_t = cur_t;
        }

        // Queue this dimension's next increment while it stays below the max code and inside the window.
        if update < (1 << ex_bits) - 1 && o_abs[idx] > 0.0 {
            let t_next = (update + 1) as f64 / o_abs[idx] as f64;
            if t_next < t_end {
                heap.push(Reverse(HeapEntry { t: t_next, idx }));
            }
        }
    }

    // Guard against a non-positive scale (e.g. `t_start == 0` with no improving step).
    if best_t <= 0.0 {
        t_start.max(f64::EPSILON)
    } else {
        best_t
    }
}
428
429fn quantize_ex_with_inv(
430    o_abs: &[f32],
431    residual: &[f32],
432    ex_bits: usize,
433    t: f64,
434) -> (Vec<u16>, f32) {
435    let dim = o_abs.len();
436    if dim == 0 {
437        return (Vec::new(), 1.0);
438    }
439
440    let mut code = vec![0u16; dim];
441    let max_val = (1 << ex_bits) - 1;
442    let mut ipnorm = 0.0f64;
443
444    for i in 0..dim {
445        let mut cur = (t * o_abs[i] as f64 + K_EPS) as i32;
446        if cur > max_val {
447            cur = max_val;
448        }
449        code[i] = cur as u16;
450        ipnorm += (cur as f64 + 0.5) * o_abs[i] as f64;
451    }
452
453    let mut ipnorm_inv = if ipnorm.is_finite() && ipnorm > 0.0 {
454        (1.0 / ipnorm) as f32
455    } else {
456        1.0
457    };
458
459    let mask = max_val as u16;
460    if max_val > 0 {
461        for (idx, &res) in residual.iter().enumerate() {
462            if res < 0.0 {
463                code[idx] = (!code[idx]) & mask;
464            }
465        }
466    }
467
468    if !ipnorm_inv.is_finite() {
469        ipnorm_inv = 1.0;
470    }
471
472    (code, ipnorm_inv)
473}
474
475fn compute_extended_factors(
476    residual: &[f32],
477    centroid: &[f32],
478    binary_code: &[u8],
479    ex_code: &[u16],
480    ipnorm_inv: f32,
481    metric: Metric,
482    ex_bits: usize,
483) -> (f32, f32) {
484    let dim = residual.len();
485    let cb = -((1 << ex_bits) as f32 - 0.5);
486    let xu_cb: Vec<f32> = (0..dim)
487        .map(|i| {
488            let total = ex_code[i] as u32 + ((binary_code[i] as u32) << ex_bits);
489            total as f32 + cb
490        })
491        .collect();
492
493    let l2_sqr = l2_norm_sqr(residual);
494    let l2_norm = l2_sqr.sqrt();
495    let xu_cb_norm_sqr = l2_norm_sqr(&xu_cb);
496    let ip_resi_xucb = dot(residual, &xu_cb);
497    let ip_cent_xucb = dot(centroid, &xu_cb);
498    let dot_residual_centroid = dot(residual, centroid);
499
500    let mut denom = ip_resi_xucb * ip_resi_xucb;
501    if denom <= f32::EPSILON {
502        denom = f32::INFINITY;
503    }
504
505    let mut tmp_error = 0.0f32;
506    if dim > 1 {
507        let ratio = ((l2_sqr * xu_cb_norm_sqr) / denom) - 1.0;
508        if ratio > 0.0 {
509            tmp_error = l2_norm * K_CONST_EPSILON * ((ratio / ((dim - 1) as f32)).max(0.0)).sqrt();
510        }
511    }
512
513    let safe_denom = if ip_resi_xucb.abs() <= f32::EPSILON {
514        f32::INFINITY
515    } else {
516        ip_resi_xucb
517    };
518
519    let (f_add_ex, f_rescale_ex) = match metric {
520        Metric::L2 => {
521            let f_add = l2_sqr + 2.0 * l2_sqr * ip_cent_xucb / safe_denom;
522            let f_rescale = -2.0 * l2_norm * ipnorm_inv;
523            (f_add, f_rescale)
524        }
525        Metric::InnerProduct => {
526            let f_add = 1.0 - dot_residual_centroid + l2_sqr * ip_cent_xucb / safe_denom;
527            let f_rescale = -l2_norm * ipnorm_inv;
528            (f_add, f_rescale)
529        }
530    };
531
532    let _ = tmp_error; // retain structure parity; tmp_error may be used in future
533
534    (f_add_ex, f_rescale_ex)
535}
536
537/// Reconstruct a vector from its quantised representation and centroid.
538#[cfg_attr(not(test), allow(dead_code))]
539pub fn reconstruct_into(centroid: &[f32], quantized: &QuantizedVector, output: &mut [f32]) {
540    assert_eq!(centroid.len(), quantized.code.len());
541    assert_eq!(output.len(), centroid.len());
542    for i in 0..centroid.len() {
543        output[i] = centroid[i] + quantized.delta * quantized.code[i] as f32 + quantized.vl;
544    }
545}
546
547/// Compute a constant scaling factor for faster quantization.
548///
549/// This function samples random normalized vectors and computes the average optimal
550/// scaling factor. Using this constant factor for all vectors is 100-500x faster
551/// than computing the optimal factor per-vector, with <1% accuracy loss.
552///
553/// # Arguments
554/// * `dim` - Vector dimensionality
555/// * `ex_bits` - Number of extended bits for quantization
556/// * `seed` - Random seed for reproducibility
557///
558/// # Returns
559/// Average optimal scaling factor across 100 random samples
560pub fn compute_const_scaling_factor(dim: usize, ex_bits: usize, seed: u64) -> f32 {
561    use rand::prelude::*;
562    use rand_distr::{Distribution, Normal};
563
564    const NUM_SAMPLES: usize = 100;
565
566    let mut rng = StdRng::seed_from_u64(seed);
567    let normal = Normal::new(0.0, 1.0).expect("failed to create normal distribution");
568
569    let mut sum_t = 0.0f64;
570
571    for _ in 0..NUM_SAMPLES {
572        // Generate random Gaussian vector
573        let vec: Vec<f32> = (0..dim).map(|_| normal.sample(&mut rng) as f32).collect();
574
575        // Normalize and take absolute value
576        let norm = vec.iter().map(|x| x * x).sum::<f32>().sqrt();
577        if norm <= f32::EPSILON {
578            continue;
579        }
580
581        let normalized_abs: Vec<f32> = vec.iter().map(|x| (x / norm).abs()).collect();
582
583        // Compute optimal scaling factor for this random vector
584        let t = best_rescale_factor(&normalized_abs, ex_bits);
585        sum_t += t;
586    }
587
588    (sum_t / NUM_SAMPLES as f64) as f32
589}