lance_index/vector/
bq.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4//! Binary Quantization (BQ)
5
6use std::iter::once;
7use std::sync::Arc;
8
9use arrow_array::types::Float32Type;
10use arrow_array::{cast::AsArray, Array, ArrayRef, UInt8Array};
11use lance_core::{Error, Result};
12use num_traits::Float;
13use snafu::location;
14
15use crate::vector::quantizer::QuantizerBuildParams;
16
17pub mod builder;
18pub mod storage;
19pub mod transform;
20
21#[derive(Clone, Default)]
22pub struct BinaryQuantization {}
23
24impl BinaryQuantization {
25    /// Transform an array of float vectors to binary vectors.
26    pub fn transform(&self, data: &dyn Array) -> Result<ArrayRef> {
27        let fsl = data
28            .as_fixed_size_list_opt()
29            .ok_or(Error::Index {
30                message: format!(
31                    "Expect to be a float vector array, got: {:?}",
32                    data.data_type()
33                ),
34                location: location!(),
35            })?
36            .clone();
37
38        let data = fsl
39            .values()
40            .as_primitive_opt::<Float32Type>()
41            .ok_or(Error::Index {
42                message: format!(
43                    "Expect to be a float32 vector array, got: {:?}",
44                    fsl.values().data_type()
45                ),
46                location: location!(),
47            })?;
48        let dim = fsl.value_length() as usize;
49        let code = data
50            .values()
51            .chunks_exact(dim)
52            .flat_map(binary_quantization)
53            .collect::<Vec<_>>();
54
55        Ok(Arc::new(UInt8Array::from(code)))
56    }
57}
58
59/// Binary quantization.
60///
61/// Use the sign bit of the float vector to represent the binary vector.
62fn binary_quantization<T: Float>(data: &[T]) -> impl Iterator<Item = u8> + '_ {
63    let iter = data.chunks_exact(8);
64    iter.clone()
65        .map(|c| {
66            // Auto vectorized.
67            // Before changing this code, please check the assembly output.
68            let mut bits: u8 = 0;
69            c.iter().enumerate().for_each(|(idx, v)| {
70                bits |= (v.is_sign_positive() as u8) << idx;
71            });
72            bits
73        })
74        .chain(once(0).map(move |_| {
75            let mut bits: u8 = 0;
76            iter.remainder().iter().enumerate().for_each(|(idx, v)| {
77                bits |= (v.is_sign_positive() as u8) << idx;
78            });
79            bits
80        }))
81}
82
83#[derive(Clone, Debug, PartialEq, Eq)]
84pub struct RQBuildParams {
85    pub num_bits: u8,
86}
87
88impl RQBuildParams {
89    pub fn new(num_bits: u8) -> Self {
90        Self { num_bits }
91    }
92}
93
94impl QuantizerBuildParams for RQBuildParams {
95    fn sample_size(&self) -> usize {
96        0
97    }
98}
99
100impl Default for RQBuildParams {
101    fn default() -> Self {
102        Self { num_bits: 1 }
103    }
104}
105
106#[cfg(test)]
107mod tests {
108    use super::*;
109
110    use half::{bf16, f16};
111
112    fn test_bq<T: Float>() {
113        let data: Vec<T> = [1.0, -1.0, 1.0, -5.0, -7.0, -1.0, 1.0, -1.0, -0.2, 1.2, 3.2]
114            .iter()
115            .map(|&v| T::from(v).unwrap())
116            .collect();
117        let expected = vec![0b01000101, 0b00000110];
118        let result = binary_quantization(&data).collect::<Vec<_>>();
119        assert_eq!(result, expected);
120    }
121
122    #[test]
123    fn test_binary_quantization() {
124        test_bq::<bf16>();
125        test_bq::<f16>();
126        test_bq::<f32>();
127        test_bq::<f64>();
128    }
129}