Skip to main content

lance_index/vector/
bq.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4//! Binary Quantization (BQ)
5
6use std::iter::once;
7use std::str::FromStr;
8use std::sync::Arc;
9
10use crate::pb::vector_index_details::RabitQuantization;
11use arrow_array::types::Float32Type;
12use arrow_array::{Array, ArrayRef, UInt8Array, cast::AsArray};
13use lance_core::{Error, Result};
14use num_traits::Float;
15use serde::{Deserialize, Serialize};
16
17use crate::vector::quantizer::QuantizerBuildParams;
18
19pub mod builder;
20pub mod rotation;
21pub mod storage;
22pub mod transform;
23
/// Stateless binary quantizer for float vectors.
///
/// The type has no fields; all of the work happens in
/// [`BinaryQuantization::transform`], which keeps one sign bit per vector
/// element.
#[derive(Clone, Default)]
pub struct BinaryQuantization {}
26
27impl BinaryQuantization {
28    /// Transform an array of float vectors to binary vectors.
29    pub fn transform(&self, data: &dyn Array) -> Result<ArrayRef> {
30        let fsl = data
31            .as_fixed_size_list_opt()
32            .ok_or(Error::index(format!(
33                "Expect to be a float vector array, got: {:?}",
34                data.data_type()
35            )))?
36            .clone();
37
38        let data = fsl
39            .values()
40            .as_primitive_opt::<Float32Type>()
41            .ok_or(Error::index(format!(
42                "Expect to be a float32 vector array, got: {:?}",
43                fsl.values().data_type()
44            )))?;
45        let dim = fsl.value_length() as usize;
46        let code = data
47            .values()
48            .chunks_exact(dim)
49            .flat_map(binary_quantization)
50            .collect::<Vec<_>>();
51
52        Ok(Arc::new(UInt8Array::from(code)))
53    }
54}
55
/// Binary quantization.
///
/// Use the sign bit of the float vector to represent the binary vector.
///
/// Each element contributes a single bit: `1` when `is_sign_positive()` is
/// true, `0` otherwise. Elements are packed eight per byte, with the first
/// element of each group stored in the least-significant bit.
///
/// The returned iterator yields one byte per full group of eight elements,
/// followed by exactly one more byte built from the remaining elements, so it
/// always produces `data.len() / 8 + 1` bytes.
///
/// NOTE(review): the trailing byte is emitted even when `data.len()` is a
/// multiple of 8 (it is then `0`) — confirm that consumers expect this layout.
fn binary_quantization<T: Float>(data: &[T]) -> impl Iterator<Item = u8> + '_ {
    // `chunks_exact` keeps the leftover elements available via `remainder()`;
    // the iterator is cloned below so the original can be moved into the
    // trailing closure.
    let iter = data.chunks_exact(8);
    iter.clone()
        .map(|c| {
            // Auto vectorized.
            // Before changing this code, please check the assembly output.
            let mut bits: u8 = 0;
            c.iter().enumerate().for_each(|(idx, v)| {
                // Element `idx` of the chunk lands in bit `idx` of the byte.
                bits |= (v.is_sign_positive() as u8) << idx;
            });
            bits
        })
        // The value produced by `once(0)` is ignored; it only makes the
        // closure run a single time to emit the byte for the remainder.
        .chain(once(0).map(move |_| {
            let mut bits: u8 = 0;
            iter.remainder().iter().enumerate().for_each(|(idx, v)| {
                bits |= (v.is_sign_positive() as u8) << idx;
            });
            bits
        }))
}
79
/// Rotation type carried by [`RQBuildParams`].
///
/// With serde the variants are (de)serialized using their snake_case names.
/// The `FromStr` implementation additionally accepts a few aliases, matched
/// case-insensitively.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum RQRotationType {
    /// The default rotation type. Parsed from `fast`, `fht_kac` or `fht-kac`.
    // NOTE(review): the aliases suggest a fast-Hadamard-transform based
    // rotation — confirm against the `rotation` module.
    #[default]
    Fast,
    /// Parsed from `matrix` or `dense`.
    // NOTE(review): presumably a dense rotation matrix — confirm against the
    // `rotation` module.
    Matrix,
}
87
88impl FromStr for RQRotationType {
89    type Err = Error;
90
91    fn from_str(value: &str) -> std::result::Result<Self, Self::Err> {
92        match value.to_lowercase().as_str() {
93            "fast" | "fht_kac" | "fht-kac" => Ok(Self::Fast),
94            "matrix" | "dense" => Ok(Self::Matrix),
95            _ => Err(Error::invalid_input(format!(
96                "Unknown RQ rotation type: {}. Expected one of: fast, matrix",
97                value
98            ))),
99        }
100    }
101}
102
/// Build parameters for RQ; convertible into the protobuf
/// `RabitQuantization` message via `From<&RQBuildParams>`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct RQBuildParams {
    /// Number of bits; `Default` uses `1`.
    // NOTE(review): presumably the number of bits per vector dimension —
    // confirm against the builder.
    pub num_bits: u8,
    /// Rotation type; defaults to `RQRotationType::default()`.
    pub rotation_type: RQRotationType,
}
108
109impl RQBuildParams {
110    pub fn new(num_bits: u8) -> Self {
111        Self {
112            num_bits,
113            rotation_type: RQRotationType::default(),
114        }
115    }
116
117    pub fn with_rotation_type(num_bits: u8, rotation_type: RQRotationType) -> Self {
118        Self {
119            num_bits,
120            rotation_type,
121        }
122    }
123}
124
125impl From<&RQBuildParams> for RabitQuantization {
126    fn from(value: &RQBuildParams) -> Self {
127        use crate::pb::vector_index_details::rabit_quantization::RotationType;
128        Self {
129            num_bits: value.num_bits as u32,
130            rotation_type: match value.rotation_type {
131                RQRotationType::Fast => RotationType::Fast as i32,
132                RQRotationType::Matrix => RotationType::Matrix as i32,
133            },
134        }
135    }
136}
137
impl QuantizerBuildParams for RQBuildParams {
    /// Always reports a sample size of `0`.
    // NOTE(review): presumably this means no training samples are required —
    // confirm against the `QuantizerBuildParams` contract.
    fn sample_size(&self) -> usize {
        0
    }
}
143
144impl Default for RQBuildParams {
145    fn default() -> Self {
146        Self {
147            num_bits: 1,
148            rotation_type: RQRotationType::default(),
149        }
150    }
151}
152
#[cfg(test)]
mod tests {
    use super::*;

    use half::{bf16, f16};

    /// Quantizes an 11-element vector: one full group of eight elements plus a
    /// three-element remainder, so two bytes are expected.
    fn test_bq<T: Float>() {
        let data: Vec<T> = [1.0, -1.0, 1.0, -5.0, -7.0, -1.0, 1.0, -1.0, -0.2, 1.2, 3.2]
            .iter()
            .map(|&v| T::from(v).unwrap())
            .collect();
        // Bit `i` of each byte is set when element `i` of the group is positive.
        let expected = vec![0b01000101, 0b00000110];
        let result = binary_quantization(&data).collect::<Vec<_>>();
        assert_eq!(result, expected);
    }

    #[test]
    fn test_binary_quantization() {
        test_bq::<bf16>();
        test_bq::<f16>();
        test_bq::<f32>();
        test_bq::<f64>();
    }

    #[test]
    fn test_rotation_type_parse() {
        // Canonical names, aliases, and mixed-case spellings.
        let cases = [
            ("fast", RQRotationType::Fast),
            ("fht_kac", RQRotationType::Fast),
            ("fht-kac", RQRotationType::Fast),
            ("FAST", RQRotationType::Fast),
            ("Fht_Kac", RQRotationType::Fast),
            ("matrix", RQRotationType::Matrix),
            ("dense", RQRotationType::Matrix),
            ("MATRIX", RQRotationType::Matrix),
            ("Dense", RQRotationType::Matrix),
        ];
        for &(input, expected) in &cases {
            assert_eq!(
                input.parse::<RQRotationType>().unwrap(),
                expected,
                "input: {}",
                input
            );
        }

        assert!("invalid".parse::<RQRotationType>().is_err());
        assert!("".parse::<RQRotationType>().is_err());
    }

    #[test]
    fn test_build_params_defaults() {
        assert_eq!(RQRotationType::default(), RQRotationType::Fast);
        assert_eq!(RQBuildParams::default(), RQBuildParams::new(1));
        assert_eq!(
            RQBuildParams::new(4),
            RQBuildParams::with_rotation_type(4, RQRotationType::Fast)
        );
    }
}