// lance_index/vector/bq/builder.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4use std::sync::Arc;
5
6use arrow::array::AsArray;
7use arrow::datatypes::{Float16Type, Float32Type, Float64Type};
8use arrow_array::{Array, ArrayRef, FixedSizeListArray, UInt8Array};
9use arrow_schema::{DataType, Field};
10use bitvec::prelude::{BitVec, Lsb0};
11use deepsize::DeepSizeOf;
12use lance_arrow::{ArrowFloatType, FixedSizeListArrayExt, FloatArray, FloatType};
13use lance_core::{Error, Result};
14use ndarray::{s, Axis};
15use num_traits::{AsPrimitive, FromPrimitive};
16use rand_distr::Distribution;
17use snafu::location;
18
19use crate::vector::bq::storage::{
20    RabitQuantizationMetadata, RabitQuantizationStorage, RABIT_CODE_COLUMN, RABIT_METADATA_KEY,
21};
22use crate::vector::bq::transform::{ADD_FACTORS_FIELD, SCALE_FACTORS_FIELD};
23use crate::vector::bq::RQBuildParams;
24use crate::vector::quantizer::{Quantization, Quantizer, QuantizerBuildParams};
25
/// Parameters used when building a `RabitQuantizer`.
pub struct RabitBuildParams {
    /// Number of bits used to encode each vector dimension.
    pub num_bits: u8,
}
32
33impl Default for RabitBuildParams {
34    fn default() -> Self {
35        Self { num_bits: 1 }
36    }
37}
38
39impl QuantizerBuildParams for RabitBuildParams {
40    fn sample_size(&self) -> usize {
41        // RabitQ doesn't need to sample any data
42        0
43    }
44}
45
46#[derive(Debug, Clone, DeepSizeOf)]
47pub struct RabitQuantizer {
48    metadata: RabitQuantizationMetadata,
49}
50
51impl RabitQuantizer {
52    pub fn new<T: ArrowFloatType>(num_bits: u8, dim: i32) -> Self {
53        // we don't need to calculate the inverse of P,
54        // just take the generated matrix as P^{-1}
55        let code_dim = dim * num_bits as i32;
56        let rotate_mat = random_orthogonal::<T>(code_dim as usize);
57        let (rotate_mat, _) = rotate_mat.into_raw_vec_and_offset();
58
59        let rotate_mat = match T::FLOAT_TYPE {
60            FloatType::Float16 | FloatType::Float32 | FloatType::Float64 => {
61                let rotate_mat = T::ArrayType::from(rotate_mat);
62                FixedSizeListArray::try_new_from_values(rotate_mat, code_dim).unwrap()
63            }
64            _ => unimplemented!("RabitQ does not support data type: {:?}", T::FLOAT_TYPE),
65        };
66
67        let metadata = RabitQuantizationMetadata {
68            rotate_mat: Some(rotate_mat),
69            rotate_mat_position: 0,
70            num_bits,
71            packed: false,
72        };
73        Self { metadata }
74    }
75
76    pub fn num_bits(&self) -> u8 {
77        self.metadata.num_bits
78    }
79
80    #[inline]
81    fn rotate_mat_flat<T: ArrowFloatType>(&self) -> &[T::Native] {
82        let rotate_mat = self.metadata.rotate_mat.as_ref().unwrap();
83        rotate_mat
84            .values()
85            .as_any()
86            .downcast_ref::<T::ArrayType>()
87            .unwrap()
88            .as_slice()
89    }
90
91    #[inline]
92    fn rotate_mat<T: ArrowFloatType>(&'_ self) -> ndarray::ArrayView2<'_, T::Native> {
93        let code_dim = self.code_dim();
94        ndarray::ArrayView2::from_shape((code_dim, code_dim), self.rotate_mat_flat::<T>()).unwrap()
95    }
96
97    pub fn dim(&self) -> usize {
98        self.code_dim() / self.metadata.num_bits as usize
99    }
100
101    // compute the dot product of v_q * v_r
102    pub fn codes_res_dot_dists<T: ArrowFloatType>(
103        &self,
104        residual_vectors: &FixedSizeListArray,
105    ) -> Result<Vec<f32>>
106    where
107        T::Native: AsPrimitive<f32>,
108    {
109        let dim = self.dim();
110        if residual_vectors.value_length() as usize != dim {
111            return Err(Error::invalid_input(
112                format!(
113                    "Vector dimension mismatch: {} != {}",
114                    residual_vectors.value_length(),
115                    dim
116                ),
117                location!(),
118            ));
119        }
120
121        // convert the vector to a dxN matrix
122        let vec_mat = ndarray::ArrayView2::from_shape(
123            (residual_vectors.len(), dim),
124            residual_vectors
125                .values()
126                .as_any()
127                .downcast_ref::<T::ArrayType>()
128                .unwrap()
129                .as_slice(),
130        )
131        .map_err(|e| Error::invalid_input(e.to_string(), location!()))?;
132        let vec_mat = vec_mat.t();
133
134        let rotate_mat = self.rotate_mat::<T>();
135        // slice to (code_dim, dim)
136        let rotate_mat = rotate_mat.slice(s![.., 0..dim]);
137        let rotated_vectors = rotate_mat.dot(&vec_mat);
138        let sqrt_dim = (dim as f32 * self.metadata.num_bits as f32).sqrt();
139        let norm_dists = rotated_vectors.mapv(|v| v.as_().abs()).sum_axis(Axis(0)) / sqrt_dim;
140        debug_assert_eq!(norm_dists.len(), residual_vectors.len());
141        Ok(norm_dists.to_vec())
142    }
143
144    fn transform<T: ArrowFloatType>(
145        &self,
146        residual_vectors: &FixedSizeListArray,
147    ) -> Result<ArrayRef>
148    where
149        T::Native: AsPrimitive<f32>,
150    {
151        // we don't need to normalize the residual vectors,
152        // because the sign of P^{-1} * v_r is the same as P^{-1} * v_r / ||v_r||
153        let n = residual_vectors.len();
154        let dim = self.dim();
155        debug_assert_eq!(residual_vectors.values().len(), n * dim);
156
157        let vectors = ndarray::ArrayView2::from_shape(
158            (n, dim),
159            residual_vectors
160                .values()
161                .as_any()
162                .downcast_ref::<T::ArrayType>()
163                .unwrap()
164                .as_slice(),
165        )
166        .map_err(|e| Error::invalid_input(e.to_string(), location!()))?;
167        let vectors = vectors.t();
168        let rotate_mat = self.rotate_mat::<T>();
169        let rotate_mat = rotate_mat.slice(s![.., 0..dim]);
170        let rotated_vectors = rotate_mat.dot(&vectors);
171
172        let quantized_vectors = rotated_vectors.t().mapv(|v| v.as_().is_sign_positive());
173        let bv: BitVec<u8, Lsb0> = BitVec::from_iter(quantized_vectors);
174
175        let codes = UInt8Array::from(bv.into_vec());
176        debug_assert_eq!(codes.len(), n * self.code_dim() / u8::BITS as usize);
177        Ok(Arc::new(FixedSizeListArray::try_new_from_values(
178            codes,
179            self.code_dim() as i32 / u8::BITS as i32, // num_bits -> num_bytes
180        )?))
181    }
182}
183
184impl Quantization for RabitQuantizer {
185    type BuildParams = RQBuildParams;
186    type Metadata = RabitQuantizationMetadata;
187    type Storage = RabitQuantizationStorage;
188
189    fn build(
190        data: &dyn Array,
191        _: lance_linalg::distance::DistanceType,
192        params: &Self::BuildParams,
193    ) -> Result<Self> {
194        let q = match data.as_fixed_size_list().value_type() {
195            DataType::Float16 => {
196                Self::new::<Float16Type>(params.num_bits, data.as_fixed_size_list().value_length())
197            }
198            DataType::Float32 => {
199                Self::new::<Float32Type>(params.num_bits, data.as_fixed_size_list().value_length())
200            }
201            DataType::Float64 => {
202                Self::new::<Float64Type>(params.num_bits, data.as_fixed_size_list().value_length())
203            }
204            dt => {
205                return Err(Error::invalid_input(
206                    format!("Unsupported data type: {:?}", dt),
207                    location!(),
208                ))
209            }
210        };
211        Ok(q)
212    }
213
214    fn retrain(&mut self, _data: &dyn Array) -> Result<()> {
215        Ok(())
216    }
217
218    fn code_dim(&self) -> usize {
219        self.metadata
220            .rotate_mat
221            .as_ref()
222            .map(|inv_p| inv_p.len())
223            .unwrap_or(0)
224    }
225
226    fn column(&self) -> &'static str {
227        RABIT_CODE_COLUMN
228    }
229
230    fn use_residual(_: lance_linalg::distance::DistanceType) -> bool {
231        true
232    }
233
234    fn quantize(&self, vectors: &dyn Array) -> Result<arrow_array::ArrayRef> {
235        let vectors = vectors.as_fixed_size_list();
236        match vectors.value_type() {
237            DataType::Float16 => self.transform::<Float16Type>(vectors),
238            DataType::Float32 => self.transform::<Float32Type>(vectors),
239            DataType::Float64 => self.transform::<Float64Type>(vectors),
240            value_type => Err(Error::invalid_input(
241                format!("Unsupported data type: {:?}", value_type),
242                location!(),
243            )),
244        }
245    }
246
247    fn metadata_key() -> &'static str {
248        RABIT_METADATA_KEY
249    }
250
251    fn quantization_type() -> crate::vector::quantizer::QuantizationType {
252        crate::vector::quantizer::QuantizationType::Rabit
253    }
254
255    fn metadata(
256        &self,
257        args: Option<crate::vector::quantizer::QuantizationMetadata>,
258    ) -> Self::Metadata {
259        let mut metadata = self.metadata.clone();
260        metadata.packed = args.map(|args| args.transposed).unwrap_or_default();
261        metadata
262    }
263
264    fn from_metadata(
265        metadata: &Self::Metadata,
266        _: lance_linalg::distance::DistanceType,
267    ) -> Result<Quantizer> {
268        Ok(Quantizer::Rabit(Self {
269            metadata: metadata.clone(),
270        }))
271    }
272
273    fn field(&self) -> Field {
274        Field::new(
275            RABIT_CODE_COLUMN,
276            DataType::FixedSizeList(
277                Arc::new(Field::new("item", DataType::UInt8, true)),
278                self.code_dim() as i32 / u8::BITS as i32, // num_bits -> num_bytes
279            ),
280            true,
281        )
282    }
283
284    fn extra_fields(&self) -> Vec<Field> {
285        vec![ADD_FACTORS_FIELD.clone(), SCALE_FACTORS_FIELD.clone()]
286    }
287}
288
289impl TryFrom<Quantizer> for RabitQuantizer {
290    type Error = Error;
291
292    fn try_from(quantizer: Quantizer) -> Result<Self> {
293        match quantizer {
294            Quantizer::Rabit(quantizer) => Ok(quantizer),
295            _ => Err(Error::invalid_input(
296                "Cannot convert non-RabitQuantizer to RabitQuantizer",
297                location!(),
298            )),
299        }
300    }
301}
302
303impl From<RabitQuantizer> for Quantizer {
304    fn from(quantizer: RabitQuantizer) -> Self {
305        Self::Rabit(quantizer)
306    }
307}
308
309fn random_normal_matrix(n: usize) -> ndarray::Array2<f64> {
310    let mut rng = rand::rng();
311    let normal = rand_distr::Normal::new(0.0, 1.0).unwrap();
312    ndarray::Array2::from_shape_simple_fn((n, n), || normal.sample(&mut rng))
313}
314
315// implement the householder qr decomposition referenced from https://en.wikipedia.org/wiki/Householder_transformation#QR_decomposition
316fn householder_qr(a: ndarray::Array2<f64>) -> (ndarray::Array2<f64>, ndarray::Array2<f64>) {
317    let (m, n) = a.dim();
318    let mut q = ndarray::Array2::eye(m);
319    let mut r = a;
320
321    for k in 0..n.min(m - 1) {
322        let mut x = r.slice(s![k.., k]).to_owned();
323        let x_norm = x.dot(&x).sqrt();
324
325        if x_norm < f64::EPSILON {
326            continue;
327        }
328
329        // Create Householder vector
330        let sign = if x[0] >= 0.0 { 1.0 } else { -1.0 };
331        x[0] += sign * x_norm;
332        let u = &x / x.dot(&x).sqrt();
333
334        // Apply Householder transformation to R
335        // Compute outer product manually
336        let mut u_outer = ndarray::Array2::zeros((m - k, m - k));
337        for i in 0..(m - k) {
338            for j in 0..(m - k) {
339                u_outer[[i, j]] = u[i] * u[j];
340            }
341        }
342        let h = ndarray::Array2::eye(m - k) - 2.0 * u_outer;
343
344        // Apply transformation to R
345        let r_block = r.slice(s![k.., k..]).to_owned();
346        let h_r = h.dot(&r_block);
347        r.slice_mut(s![k.., k..]).assign(&h_r);
348
349        // Apply transformation to Q
350        let q_block = q.slice(s![.., k..]).to_owned();
351        let q_h = q_block.dot(&h);
352        q.slice_mut(s![.., k..]).assign(&q_h);
353    }
354
355    (q, r)
356}
357
358fn random_orthogonal<T: ArrowFloatType>(n: usize) -> ndarray::Array2<T::Native>
359where
360    T::Native: FromPrimitive,
361{
362    let a = random_normal_matrix(n);
363    let (q, _) = householder_qr(a);
364
365    // cast f64 matrix to T::Native matrix
366    q.mapv(|v| T::Native::from_f64(v).unwrap())
367}
368
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use rstest::rstest;

    /// Verifies the QR decomposition of random square matrices.
    #[rstest]
    #[case(8)]
    #[case(16)]
    #[case(32)]
    fn test_householder_qr(#[case] n: usize) {
        let input = random_normal_matrix(n);
        let (rows, cols) = input.dim();

        let (q, r) = householder_qr(input.clone());

        // Q must be orthogonal, i.e. Q^T * Q is the identity matrix.
        let gram = q.t().dot(&q);
        for ((i, j), &actual) in gram.indexed_iter() {
            let expected = if i == j { 1.0 } else { 0.0 };
            assert_relative_eq!(actual, expected, epsilon = 1e-5);
        }

        // Q * R must reproduce the original matrix.
        let product = q.dot(&r);
        for ((i, j), &actual) in product.indexed_iter() {
            assert_relative_eq!(actual, input[[i, j]], epsilon = 1e-5);
        }

        // R must be upper triangular: everything below the diagonal is zero.
        let checked_rows = cols.min(rows);
        for ((i, j), &below_diagonal) in r.indexed_iter() {
            if i < checked_rows && j < i {
                assert_relative_eq!(below_diagonal, 0.0, epsilon = 1e-5);
            }
        }

        // Q has shape (rows, rows) and R has shape (rows, cols).
        assert_eq!(q.dim(), (rows, rows));
        assert_eq!(r.dim(), (rows, cols));
    }
}