Skip to main content

lance_index/vector/
bq.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4//! Binary Quantization (BQ)
5
6use std::iter::once;
7use std::str::FromStr;
8use std::sync::Arc;
9
10use arrow_array::types::Float32Type;
11use arrow_array::{Array, ArrayRef, UInt8Array, cast::AsArray};
12use lance_core::{Error, Result};
13use num_traits::Float;
14use serde::{Deserialize, Serialize};
15
16use crate::vector::quantizer::QuantizerBuildParams;
17
18pub mod builder;
19pub mod rotation;
20pub mod storage;
21pub mod transform;
22
/// Binary quantizer that encodes float vectors as packed bit codes.
///
/// The type holds no state; all of the work happens in
/// [`BinaryQuantization::transform`].
#[derive(Clone, Debug, Default)]
pub struct BinaryQuantization {}
25
26impl BinaryQuantization {
27    /// Transform an array of float vectors to binary vectors.
28    pub fn transform(&self, data: &dyn Array) -> Result<ArrayRef> {
29        let fsl = data
30            .as_fixed_size_list_opt()
31            .ok_or(Error::index(format!(
32                "Expect to be a float vector array, got: {:?}",
33                data.data_type()
34            )))?
35            .clone();
36
37        let data = fsl
38            .values()
39            .as_primitive_opt::<Float32Type>()
40            .ok_or(Error::index(format!(
41                "Expect to be a float32 vector array, got: {:?}",
42                fsl.values().data_type()
43            )))?;
44        let dim = fsl.value_length() as usize;
45        let code = data
46            .values()
47            .chunks_exact(dim)
48            .flat_map(binary_quantization)
49            .collect::<Vec<_>>();
50
51        Ok(Arc::new(UInt8Array::from(code)))
52    }
53}
54
55/// Binary quantization.
56///
57/// Use the sign bit of the float vector to represent the binary vector.
58fn binary_quantization<T: Float>(data: &[T]) -> impl Iterator<Item = u8> + '_ {
59    let iter = data.chunks_exact(8);
60    iter.clone()
61        .map(|c| {
62            // Auto vectorized.
63            // Before changing this code, please check the assembly output.
64            let mut bits: u8 = 0;
65            c.iter().enumerate().for_each(|(idx, v)| {
66                bits |= (v.is_sign_positive() as u8) << idx;
67            });
68            bits
69        })
70        .chain(once(0).map(move |_| {
71            let mut bits: u8 = 0;
72            iter.remainder().iter().enumerate().for_each(|(idx, v)| {
73                bits |= (v.is_sign_positive() as u8) << idx;
74            });
75            bits
76        }))
77}
78
79#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
80#[serde(rename_all = "snake_case")]
81pub enum RQRotationType {
82    #[default]
83    Fast,
84    Matrix,
85}
86
87impl FromStr for RQRotationType {
88    type Err = Error;
89
90    fn from_str(value: &str) -> std::result::Result<Self, Self::Err> {
91        match value.to_lowercase().as_str() {
92            "fast" | "fht_kac" | "fht-kac" => Ok(Self::Fast),
93            "matrix" | "dense" => Ok(Self::Matrix),
94            _ => Err(Error::invalid_input(format!(
95                "Unknown RQ rotation type: {}. Expected one of: fast, matrix",
96                value
97            ))),
98        }
99    }
100}
101
102#[derive(Clone, Debug, PartialEq, Eq)]
103pub struct RQBuildParams {
104    pub num_bits: u8,
105    pub rotation_type: RQRotationType,
106}
107
108impl RQBuildParams {
109    pub fn new(num_bits: u8) -> Self {
110        Self {
111            num_bits,
112            rotation_type: RQRotationType::default(),
113        }
114    }
115
116    pub fn with_rotation_type(num_bits: u8, rotation_type: RQRotationType) -> Self {
117        Self {
118            num_bits,
119            rotation_type,
120        }
121    }
122}
123
124impl QuantizerBuildParams for RQBuildParams {
125    fn sample_size(&self) -> usize {
126        0
127    }
128}
129
130impl Default for RQBuildParams {
131    fn default() -> Self {
132        Self {
133            num_bits: 1,
134            rotation_type: RQRotationType::default(),
135        }
136    }
137}
138
#[cfg(test)]
mod tests {
    use super::*;

    use half::{bf16, f16};

    /// Quantize a fixed 11-element vector as `T` and compare with the expected codes.
    fn test_bq<T: Float>() {
        let data: Vec<T> = [1.0, -1.0, 1.0, -5.0, -7.0, -1.0, 1.0, -1.0, -0.2, 1.2, 3.2]
            .iter()
            .map(|&v| T::from(v).unwrap())
            .collect();
        let expected = vec![0b01000101, 0b00000110];
        let result = binary_quantization(&data).collect::<Vec<_>>();
        assert_eq!(result, expected);
    }

    #[test]
    fn test_binary_quantization() {
        test_bq::<bf16>();
        test_bq::<f16>();
        test_bq::<f32>();
        test_bq::<f64>();
    }

    #[test]
    fn test_binary_quantization_trailing_byte() {
        // Pins the current output length: one byte per full group of eight,
        // plus one trailing byte even when no values are left over.
        let full = [1.0_f32; 8];
        assert_eq!(
            binary_quantization(&full).collect::<Vec<_>>(),
            vec![0xFF_u8, 0]
        );

        let empty: [f32; 0] = [];
        assert_eq!(binary_quantization(&empty).collect::<Vec<_>>(), vec![0_u8]);
    }

    #[test]
    fn test_rotation_type_parse() {
        assert_eq!(
            "fast".parse::<RQRotationType>().unwrap(),
            RQRotationType::Fast
        );
        assert_eq!(
            "matrix".parse::<RQRotationType>().unwrap(),
            RQRotationType::Matrix
        );
        assert!("invalid".parse::<RQRotationType>().is_err());
    }

    #[test]
    fn test_rotation_type_parse_aliases() {
        // Aliases and mixed case resolve to the same variants.
        for name in ["FAST", "Fast", "fht_kac", "fht-kac", "FHT-KAC"] {
            assert_eq!(
                name.parse::<RQRotationType>().unwrap(),
                RQRotationType::Fast
            );
        }
        for name in ["MATRIX", "Matrix", "dense", "DENSE"] {
            assert_eq!(
                name.parse::<RQRotationType>().unwrap(),
                RQRotationType::Matrix
            );
        }

        // Input is not trimmed, and the empty string is not a valid name.
        assert!(" fast".parse::<RQRotationType>().is_err());
        assert!("".parse::<RQRotationType>().is_err());
    }

    #[test]
    fn test_rq_build_params_constructors() {
        assert_eq!(RQBuildParams::default(), RQBuildParams::new(1));

        let params = RQBuildParams::with_rotation_type(4, RQRotationType::Matrix);
        assert_eq!(params.num_bits, 4);
        assert_eq!(params.rotation_type, RQRotationType::Matrix);

        assert_eq!(RQBuildParams::new(4).rotation_type, RQRotationType::Fast);
        assert_eq!(RQBuildParams::new(4).sample_size(), 0);
    }
}
175}