Skip to main content

lance_index/vector/bq/
builder.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4use std::sync::Arc;
5
6use arrow::array::AsArray;
7use arrow::datatypes::{Float16Type, Float32Type, Float64Type};
8use arrow_array::{Array, ArrayRef, FixedSizeListArray, UInt8Array};
9use arrow_schema::{DataType, Field};
10use bitvec::prelude::{BitVec, Lsb0};
11use deepsize::DeepSizeOf;
12use lance_arrow::{ArrowFloatType, FixedSizeListArrayExt, FloatArray, FloatType};
13use lance_core::{Error, Result};
14use ndarray::{Axis, ShapeBuilder, s};
15use num_traits::{AsPrimitive, FromPrimitive};
16use rand_distr::Distribution;
17use rayon::prelude::*;
18
19use crate::vector::bq::storage::{
20    RABIT_CODE_COLUMN, RABIT_METADATA_KEY, RabitQuantizationMetadata, RabitQuantizationStorage,
21};
22use crate::vector::bq::transform::{ADD_FACTORS_FIELD, SCALE_FACTORS_FIELD};
23use crate::vector::bq::{
24    RQBuildParams, RQRotationType,
25    rotation::{apply_fast_rotation, random_fast_rotation_signs},
26};
27use crate::vector::quantizer::{Quantization, Quantizer, QuantizerBuildParams};
28
29/// Build parameters for RabitQuantizer.
30///
31/// num_bits: the number of bits per dimension.
32pub struct RabitBuildParams {
33    pub num_bits: u8,
34    pub rotation_type: RQRotationType,
35}
36
37impl Default for RabitBuildParams {
38    fn default() -> Self {
39        Self {
40            num_bits: 1,
41            rotation_type: RQRotationType::default(),
42        }
43    }
44}
45
46impl QuantizerBuildParams for RabitBuildParams {
47    fn sample_size(&self) -> usize {
48        // RabitQ doesn't need to sample any data
49        0
50    }
51}
52
53#[derive(Debug, Clone, DeepSizeOf)]
54pub struct RabitQuantizer {
55    metadata: RabitQuantizationMetadata,
56}
57
/// Packs the sign of every value in `rotated` into `codes`, one bit per value.
///
/// `codes` is zeroed first. Value `i` then sets bit `i % 8` of byte `i / 8`
/// (least-significant bit first) when its sign bit is clear; values whose sign
/// bit is set (including `-0.0`) leave their bit at zero.
///
/// # Panics
///
/// Panics if a value with a clear sign bit sits at an index that does not fit
/// into `codes` (i.e. `index / 8 >= codes.len()`).
#[inline]
fn pack_sign_bits(codes: &mut [u8], rotated: &[f32]) {
    const BITS_PER_BYTE: usize = u8::BITS as usize;

    codes.fill(0);
    rotated
        .iter()
        .enumerate()
        .filter(|(_, value)| value.is_sign_positive())
        .for_each(|(bit_idx, _)| {
            let (byte, shift) = (bit_idx / BITS_PER_BYTE, bit_idx % BITS_PER_BYTE);
            codes[byte] |= 1u8 << shift;
        });
}
67
68impl RabitQuantizer {
69    pub fn new<T: ArrowFloatType>(num_bits: u8, dim: i32) -> Self {
70        Self::new_with_rotation::<T>(num_bits, dim, RQRotationType::default())
71    }
72
73    pub fn new_with_rotation<T: ArrowFloatType>(
74        num_bits: u8,
75        dim: i32,
76        rotation_type: RQRotationType,
77    ) -> Self {
78        let code_dim = (dim * num_bits as i32) as usize;
79        let metadata = match rotation_type {
80            RQRotationType::Matrix => {
81                // we don't need to calculate the inverse of P, just take generated Q as P^{-1}
82                let rotate_mat = random_orthogonal::<T>(code_dim);
83                let (rotate_mat, _) = rotate_mat.into_raw_vec_and_offset();
84                let rotate_mat = match T::FLOAT_TYPE {
85                    FloatType::Float16 | FloatType::Float32 | FloatType::Float64 => {
86                        let rotate_mat = <T::ArrayType as FloatArray<T>>::from_values(rotate_mat);
87                        FixedSizeListArray::try_new_from_values(rotate_mat, code_dim as i32)
88                            .unwrap()
89                    }
90                    _ => unimplemented!("RabitQ does not support data type: {:?}", T::FLOAT_TYPE),
91                };
92                RabitQuantizationMetadata {
93                    rotate_mat: Some(rotate_mat),
94                    rotate_mat_position: None,
95                    fast_rotation_signs: None,
96                    rotation_type,
97                    code_dim: code_dim as u32,
98                    num_bits,
99                    packed: false,
100                }
101            }
102            RQRotationType::Fast => RabitQuantizationMetadata {
103                rotate_mat: None,
104                rotate_mat_position: None,
105                fast_rotation_signs: Some(random_fast_rotation_signs(code_dim)),
106                rotation_type,
107                code_dim: code_dim as u32,
108                num_bits,
109                packed: false,
110            },
111        };
112        Self { metadata }
113    }
114
115    pub fn num_bits(&self) -> u8 {
116        self.metadata.num_bits
117    }
118
119    pub fn rotation_type(&self) -> RQRotationType {
120        self.metadata.rotation_type
121    }
122
123    #[inline]
124    fn fast_rotation_signs(&self) -> &[u8] {
125        self.metadata
126            .fast_rotation_signs
127            .as_ref()
128            .expect("RabitQ fast rotation signs missing")
129            .as_slice()
130    }
131
132    #[inline]
133    fn rotate_mat_flat<T: ArrowFloatType>(&self) -> &[T::Native] {
134        let rotate_mat = self.metadata.rotate_mat.as_ref().unwrap();
135        rotate_mat
136            .values()
137            .as_any()
138            .downcast_ref::<T::ArrayType>()
139            .unwrap()
140            .as_slice()
141    }
142
143    #[inline]
144    fn rotate_mat<T: ArrowFloatType>(&'_ self) -> ndarray::ArrayView2<'_, T::Native> {
145        let code_dim = self.code_dim();
146        ndarray::ArrayView2::from_shape((code_dim, code_dim), self.rotate_mat_flat::<T>()).unwrap()
147    }
148
149    fn rotate_vectors<T: ArrowFloatType>(
150        &self,
151        vectors: ndarray::ArrayView2<'_, T::Native>,
152    ) -> ndarray::Array2<f32>
153    where
154        T::Native: AsPrimitive<f32>,
155    {
156        let dim = vectors.nrows();
157        let code_dim = self.code_dim();
158        match self.rotation_type() {
159            RQRotationType::Matrix => {
160                let rotate_mat = self.rotate_mat::<T>();
161                let rotate_mat = rotate_mat.slice(s![.., 0..dim]);
162                rotate_mat.dot(&vectors).mapv(|v| v.as_())
163            }
164            RQRotationType::Fast => {
165                let signs = self.fast_rotation_signs();
166                let ncols = vectors.ncols();
167                let mut rotated_data = vec![0.0f32; code_dim * ncols];
168                rotated_data
169                    .par_chunks_mut(code_dim)
170                    .enumerate()
171                    .for_each_init(
172                        || vec![0.0f32; code_dim],
173                        |scratch, (col_idx, dst)| {
174                            let column = vectors.column(col_idx);
175                            let input = column
176                                .as_slice()
177                                .expect("RabitQ input vectors should be contiguous");
178                            apply_fast_rotation(input, scratch, signs);
179                            dst.copy_from_slice(scratch);
180                        },
181                    );
182
183                ndarray::Array2::from_shape_vec((code_dim, ncols).f(), rotated_data).unwrap()
184            }
185        }
186    }
187
188    pub fn dim(&self) -> usize {
189        self.code_dim() / self.metadata.num_bits as usize
190    }
191
192    // compute the dot product of v_q * v_r
193    pub fn codes_res_dot_dists<T: ArrowFloatType>(
194        &self,
195        residual_vectors: &FixedSizeListArray,
196    ) -> Result<Vec<f32>>
197    where
198        T::Native: AsPrimitive<f32> + Sync,
199    {
200        let dim = self.dim();
201        if residual_vectors.value_length() as usize != dim {
202            return Err(Error::invalid_input(format!(
203                "Vector dimension mismatch: {} != {}",
204                residual_vectors.value_length(),
205                dim
206            )));
207        }
208
209        let sqrt_dim = (dim as f32 * self.metadata.num_bits as f32).sqrt();
210        let values = residual_vectors
211            .values()
212            .as_any()
213            .downcast_ref::<T::ArrayType>()
214            .unwrap()
215            .as_slice();
216
217        match self.rotation_type() {
218            RQRotationType::Matrix => {
219                // convert the vector to a dxN matrix
220                let vec_mat =
221                    ndarray::ArrayView2::from_shape((residual_vectors.len(), dim), values)
222                        .map_err(|e| Error::invalid_input(e.to_string()))?;
223                let vec_mat = vec_mat.t();
224                let rotated_vectors = self.rotate_vectors::<T>(vec_mat);
225                let norm_dists = rotated_vectors.mapv(f32::abs).sum_axis(Axis(0)) / sqrt_dim;
226                debug_assert_eq!(norm_dists.len(), residual_vectors.len());
227                Ok(norm_dists.to_vec())
228            }
229            RQRotationType::Fast => {
230                let code_dim = self.code_dim();
231                let signs = self.fast_rotation_signs();
232                let mut norm_dists = vec![0.0f32; residual_vectors.len()];
233                norm_dists
234                    .par_iter_mut()
235                    .zip(values.par_chunks_exact(dim))
236                    .for_each_init(
237                        || vec![0.0f32; code_dim],
238                        |scratch, (dst, input)| {
239                            apply_fast_rotation(input, scratch, signs);
240                            *dst = scratch.iter().map(|v| v.abs()).sum::<f32>() / sqrt_dim;
241                        },
242                    );
243                Ok(norm_dists)
244            }
245        }
246    }
247
248    fn transform<T: ArrowFloatType>(
249        &self,
250        residual_vectors: &FixedSizeListArray,
251    ) -> Result<ArrayRef>
252    where
253        T::Native: AsPrimitive<f32> + Sync,
254    {
255        // we don't need to normalize the residual vectors,
256        // because the sign of P^{-1} * v_r is the same as P^{-1} * v_r / ||v_r||
257        let n = residual_vectors.len();
258        let dim = self.dim();
259        debug_assert_eq!(residual_vectors.values().len(), n * dim);
260        let values = residual_vectors
261            .values()
262            .as_any()
263            .downcast_ref::<T::ArrayType>()
264            .unwrap()
265            .as_slice();
266        let code_dim = self.code_dim();
267        let code_bytes = code_dim / u8::BITS as usize;
268
269        match self.rotation_type() {
270            RQRotationType::Matrix => {
271                let vectors = ndarray::ArrayView2::from_shape((n, dim), values)
272                    .map_err(|e| Error::invalid_input(e.to_string()))?;
273                let vectors = vectors.t();
274                let rotated_vectors = self.rotate_vectors::<T>(vectors);
275
276                let quantized_vectors = rotated_vectors.t().mapv(|v| v.is_sign_positive());
277                let bv: BitVec<u8, Lsb0> = BitVec::from_iter(quantized_vectors);
278
279                let codes = UInt8Array::from(bv.into_vec());
280                debug_assert_eq!(codes.len(), n * code_bytes);
281                Ok(Arc::new(FixedSizeListArray::try_new_from_values(
282                    codes,
283                    code_bytes as i32, // num_bits -> num_bytes
284                )?))
285            }
286            RQRotationType::Fast => {
287                let signs = self.fast_rotation_signs();
288                let mut encoded_codes = vec![0u8; n * code_bytes];
289                encoded_codes
290                    .par_chunks_mut(code_bytes)
291                    .zip(values.par_chunks_exact(dim))
292                    .for_each_init(
293                        || vec![0.0f32; code_dim],
294                        |scratch, (code_dst, input)| {
295                            apply_fast_rotation(input, scratch, signs);
296                            pack_sign_bits(code_dst, scratch);
297                        },
298                    );
299                let codes = UInt8Array::from(encoded_codes);
300                debug_assert_eq!(codes.len(), n * code_bytes);
301                Ok(Arc::new(FixedSizeListArray::try_new_from_values(
302                    codes,
303                    code_bytes as i32,
304                )?))
305            }
306        }
307    }
308}
309
310impl Quantization for RabitQuantizer {
311    type BuildParams = RQBuildParams;
312    type Metadata = RabitQuantizationMetadata;
313    type Storage = RabitQuantizationStorage;
314
315    fn build(
316        data: &dyn Array,
317        _: lance_linalg::distance::DistanceType,
318        params: &Self::BuildParams,
319    ) -> Result<Self> {
320        let dim = data.as_fixed_size_list().value_length() as usize;
321        if !dim.is_multiple_of(u8::BITS as usize) {
322            return Err(Error::invalid_input(
323                "vector dimension must be divisible by 8 for IVF_RQ",
324            ));
325        }
326
327        let q = match data.as_fixed_size_list().value_type() {
328            DataType::Float16 => Self::new_with_rotation::<Float16Type>(
329                params.num_bits,
330                data.as_fixed_size_list().value_length(),
331                params.rotation_type,
332            ),
333            DataType::Float32 => Self::new_with_rotation::<Float32Type>(
334                params.num_bits,
335                data.as_fixed_size_list().value_length(),
336                params.rotation_type,
337            ),
338            DataType::Float64 => Self::new_with_rotation::<Float64Type>(
339                params.num_bits,
340                data.as_fixed_size_list().value_length(),
341                params.rotation_type,
342            ),
343            dt => {
344                return Err(Error::invalid_input(format!(
345                    "Unsupported data type: {:?}",
346                    dt
347                )));
348            }
349        };
350        Ok(q)
351    }
352
353    fn retrain(&mut self, _data: &dyn Array) -> Result<()> {
354        Ok(())
355    }
356
357    fn code_dim(&self) -> usize {
358        if self.metadata.code_dim > 0 {
359            self.metadata.code_dim as usize
360        } else {
361            self.metadata
362                .rotate_mat
363                .as_ref()
364                .map(|rotate_mat| rotate_mat.len())
365                .unwrap_or(0)
366        }
367    }
368
369    fn column(&self) -> &'static str {
370        RABIT_CODE_COLUMN
371    }
372
373    fn use_residual(_: lance_linalg::distance::DistanceType) -> bool {
374        true
375    }
376
377    fn quantize(&self, vectors: &dyn Array) -> Result<arrow_array::ArrayRef> {
378        let vectors = vectors.as_fixed_size_list();
379        match vectors.value_type() {
380            DataType::Float16 => self.transform::<Float16Type>(vectors),
381            DataType::Float32 => self.transform::<Float32Type>(vectors),
382            DataType::Float64 => self.transform::<Float64Type>(vectors),
383            value_type => Err(Error::invalid_input(format!(
384                "Unsupported data type: {:?}",
385                value_type
386            ))),
387        }
388    }
389
390    fn metadata_key() -> &'static str {
391        RABIT_METADATA_KEY
392    }
393
394    fn quantization_type() -> crate::vector::quantizer::QuantizationType {
395        crate::vector::quantizer::QuantizationType::Rabit
396    }
397
398    fn metadata(
399        &self,
400        args: Option<crate::vector::quantizer::QuantizationMetadata>,
401    ) -> Self::Metadata {
402        let mut metadata = self.metadata.clone();
403        metadata.packed = args.map(|args| args.transposed).unwrap_or_default();
404        metadata
405    }
406
407    fn from_metadata(
408        metadata: &Self::Metadata,
409        _: lance_linalg::distance::DistanceType,
410    ) -> Result<Quantizer> {
411        Ok(Quantizer::Rabit(Self {
412            metadata: metadata.clone(),
413        }))
414    }
415
416    fn field(&self) -> Field {
417        Field::new(
418            RABIT_CODE_COLUMN,
419            DataType::FixedSizeList(
420                Arc::new(Field::new("item", DataType::UInt8, true)),
421                self.code_dim() as i32 / u8::BITS as i32, // num_bits -> num_bytes
422            ),
423            true,
424        )
425    }
426
427    fn extra_fields(&self) -> Vec<Field> {
428        vec![ADD_FACTORS_FIELD.clone(), SCALE_FACTORS_FIELD.clone()]
429    }
430}
431
432impl TryFrom<Quantizer> for RabitQuantizer {
433    type Error = Error;
434
435    fn try_from(quantizer: Quantizer) -> Result<Self> {
436        match quantizer {
437            Quantizer::Rabit(quantizer) => Ok(quantizer),
438            _ => Err(Error::invalid_input(
439                "Cannot convert non-RabitQuantizer to RabitQuantizer",
440            )),
441        }
442    }
443}
444
445impl From<RabitQuantizer> for Quantizer {
446    fn from(quantizer: RabitQuantizer) -> Self {
447        Self::Rabit(quantizer)
448    }
449}
450
451fn random_normal_matrix(n: usize) -> ndarray::Array2<f64> {
452    let mut rng = rand::rng();
453    let normal = rand_distr::Normal::new(0.0, 1.0).unwrap();
454    ndarray::Array2::from_shape_simple_fn((n, n), || normal.sample(&mut rng))
455}
456
457// implement the householder qr decomposition referenced from https://en.wikipedia.org/wiki/Householder_transformation#QR_decomposition
458fn householder_qr(a: ndarray::Array2<f64>) -> (ndarray::Array2<f64>, ndarray::Array2<f64>) {
459    let (m, n) = a.dim();
460    let mut q = ndarray::Array2::eye(m);
461    let mut r = a;
462
463    for k in 0..n.min(m - 1) {
464        let mut x = r.slice(s![k.., k]).to_owned();
465        let x_norm = x.dot(&x).sqrt();
466
467        if x_norm < f64::EPSILON {
468            continue;
469        }
470
471        // Create Householder vector
472        let sign = if x[0] >= 0.0 { 1.0 } else { -1.0 };
473        x[0] += sign * x_norm;
474        let u = &x / x.dot(&x).sqrt();
475
476        // Apply Householder transformation to R
477        // Compute outer product manually
478        let mut u_outer = ndarray::Array2::zeros((m - k, m - k));
479        for i in 0..(m - k) {
480            for j in 0..(m - k) {
481                u_outer[[i, j]] = u[i] * u[j];
482            }
483        }
484        let h = ndarray::Array2::eye(m - k) - 2.0 * u_outer;
485
486        // Apply transformation to R
487        let r_block = r.slice(s![k.., k..]).to_owned();
488        let h_r = h.dot(&r_block);
489        r.slice_mut(s![k.., k..]).assign(&h_r);
490
491        // Apply transformation to Q
492        let q_block = q.slice(s![.., k..]).to_owned();
493        let q_h = q_block.dot(&h);
494        q.slice_mut(s![.., k..]).assign(&q_h);
495    }
496
497    (q, r)
498}
499
500fn random_orthogonal<T: ArrowFloatType>(n: usize) -> ndarray::Array2<T::Native>
501where
502    T::Native: FromPrimitive,
503{
504    let a = random_normal_matrix(n);
505    let (q, _) = householder_qr(a);
506
507    // cast f64 matrix to T::Native matrix
508    q.mapv(|v| T::Native::from_f64(v).unwrap())
509}
510
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use arrow::datatypes::Float32Type;
    use arrow_array::{FixedSizeListArray, Float32Array};
    use lance_linalg::distance::DistanceType;
    use rstest::rstest;

    /// The QR decomposition of a random square matrix must yield an orthogonal
    /// `Q`, an upper-triangular `R`, and a product `Q * R` equal to the input.
    #[rstest]
    #[case(8)]
    #[case(16)]
    #[case(32)]
    fn test_householder_qr(#[case] n: usize) {
        let input = random_normal_matrix(n);
        let (rows, cols) = input.dim();

        let (q, r) = householder_qr(input.clone());

        // Check Q is orthogonal: Q^T * Q should be identity
        let gram = q.t().dot(&q);
        for ((i, j), &actual) in gram.indexed_iter() {
            let expected = if i == j { 1.0 } else { 0.0 };
            assert_relative_eq!(actual, expected, epsilon = 1e-5);
        }

        // Check QR decomposition: Q * R should equal original matrix
        let product = q.dot(&r);
        for ((i, j), &actual) in product.indexed_iter() {
            assert_relative_eq!(actual, input[[i, j]], epsilon = 1e-5);
        }

        // Check R is upper triangular
        for row in 1..cols.min(rows) {
            for col in 0..row {
                assert_relative_eq!(r[[row, col]], 0.0, epsilon = 1e-5);
            }
        }

        // Additional check: Q should have shape (m, m) and R should have shape (m, n)
        assert_eq!(q.dim(), (rows, rows));
        assert_eq!(r.dim(), (rows, cols));
    }

    /// Both rotation types must be reported back unchanged, with the same
    /// vector dimension.
    #[test]
    fn test_rabit_quantizer_rotation_modes() {
        for rotation in [RQRotationType::Fast, RQRotationType::Matrix] {
            let quantizer = RabitQuantizer::new_with_rotation::<Float32Type>(1, 128, rotation);
            assert_eq!(quantizer.rotation_type(), rotation);
            assert_eq!(quantizer.dim(), 128);
        }
    }

    /// Building from vectors whose dimension is not a multiple of 8 must fail
    /// with the dedicated error message.
    #[test]
    fn test_rabit_quantizer_requires_dim_divisible_by_8() {
        let dim = 30;
        let values = Float32Array::from(vec![0.0f32; 4 * 30]);
        let fsl = FixedSizeListArray::try_new_from_values(values, dim).unwrap();
        let params = RQBuildParams::new(1);

        let err = RabitQuantizer::build(&fsl, DistanceType::L2, &params).unwrap_err();
        let message = err.to_string();
        assert!(
            message.contains("vector dimension must be divisible by 8 for IVF_RQ"),
            "{}",
            err
        );
    }
}