Skip to main content

lance_index/vector/
sq.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4use std::{ops::Range, sync::Arc};
5
6use arrow::array::AsArray;
7use arrow::datatypes::{Float16Type, Float32Type, Float64Type};
8use arrow_array::{Array, ArrayRef, FixedSizeListArray, UInt8Array};
9
10use arrow_schema::{DataType, Field};
11use builder::SQBuildParams;
12use deepsize::DeepSizeOf;
13use itertools::Itertools;
14use lance_arrow::*;
15use lance_core::{Error, Result};
16use lance_linalg::distance::DistanceType;
17use num_traits::*;
18use storage::{SQ_METADATA_KEY, ScalarQuantizationMetadata, ScalarQuantizationStorage};
19
20use super::SQ_CODE_COLUMN;
21use super::quantizer::{Quantization, QuantizationMetadata, QuantizationType, Quantizer};
22
23pub mod builder;
24pub mod storage;
25pub mod transform;
26
/// Scalar Quantization, optimized for Apache Arrow buffer memory layout.
///
/// The quantizer keeps only its [`ScalarQuantizationMetadata`] (bit width, vector dimension
/// and the value bounds observed so far) and uses those bounds to map float vectors onto
/// `u8` codes.
//
// TODO: move this to be pub(crate) once we have a better way to test it.
#[derive(Debug, Clone)]
pub struct ScalarQuantizer {
    /// Quantization parameters: `num_bits`, `dim` and the value `bounds`.
    metadata: ScalarQuantizationMetadata,
}
35
/// Memory accounting for [`ScalarQuantizer`].
impl DeepSizeOf for ScalarQuantizer {
    /// Reports zero bytes owned through children.
    ///
    /// NOTE(review): this presumes `ScalarQuantizationMetadata` holds only inline scalar
    /// fields (`num_bits`, `dim`, `bounds`) and no heap allocations — confirm against its
    /// definition in the `storage` module.
    fn deep_size_of_children(&self, _context: &mut deepsize::Context) -> usize {
        0
    }
}
41
42impl ScalarQuantizer {
43    pub fn new(num_bits: u16, dim: usize) -> Self {
44        Self {
45            metadata: ScalarQuantizationMetadata {
46                num_bits,
47                dim,
48                bounds: Range::<f64> {
49                    start: f64::MAX,
50                    end: f64::MIN,
51                },
52            },
53        }
54    }
55
56    pub fn with_bounds(num_bits: u16, dim: usize, bounds: Range<f64>) -> Self {
57        let mut sq = Self::new(num_bits, dim);
58        sq.metadata.bounds = bounds;
59        sq
60    }
61
62    pub fn num_bits(&self) -> u16 {
63        self.metadata.num_bits
64    }
65
66    pub fn update_bounds<T: ArrowFloatType>(
67        &mut self,
68        vectors: &FixedSizeListArray,
69    ) -> Result<Range<f64>> {
70        let data = vectors
71            .values()
72            .as_any()
73            .downcast_ref::<T::ArrayType>()
74            .ok_or(Error::index(format!(
75                "Expect to be a float vector array, got: {:?}",
76                vectors.value_type()
77            )))?
78            .as_slice();
79
80        self.metadata.bounds = data.iter().fold(self.metadata.bounds.clone(), |f, v| {
81            f.start.min(v.as_())..f.end.max(v.as_())
82        });
83
84        Ok(self.metadata.bounds.clone())
85    }
86
87    pub fn transform<T: ArrowFloatType>(&self, data: &dyn Array) -> Result<ArrayRef> {
88        let fsl = data
89            .as_fixed_size_list_opt()
90            .ok_or(Error::index(format!(
91                "Expect to be a FixedSizeList<float> vector array, got: {:?} array",
92                data.data_type()
93            )))?
94            .clone();
95        let data = fsl
96            .values()
97            .as_any()
98            .downcast_ref::<T::ArrayType>()
99            .ok_or(Error::index(format!(
100                "Expect to be a float vector array, got: {:?}",
101                fsl.value_type()
102            )))?
103            .as_slice();
104
105        // TODO: support SQ4
106        let builder: Vec<u8> = scale_to_u8::<T>(data, &self.metadata.bounds);
107
108        Ok(Arc::new(FixedSizeListArray::try_new_from_values(
109            UInt8Array::from(builder),
110            fsl.value_length(),
111        )?))
112    }
113
114    pub fn bounds(&self) -> Range<f64> {
115        self.metadata.bounds.clone()
116    }
117
118    /// Whether to use residual as input or not.
119    pub fn use_residual(&self) -> bool {
120        false
121    }
122}
123
124impl TryFrom<Quantizer> for ScalarQuantizer {
125    type Error = Error;
126    fn try_from(value: Quantizer) -> Result<Self> {
127        match value {
128            Quantizer::Scalar(sq) => Ok(sq),
129            _ => Err(Error::index("Expect to be a ScalarQuantizer".to_string())),
130        }
131    }
132}
133
134impl Quantization for ScalarQuantizer {
135    type BuildParams = SQBuildParams;
136    type Metadata = ScalarQuantizationMetadata;
137    type Storage = ScalarQuantizationStorage;
138
139    fn build(data: &dyn Array, _: DistanceType, params: &Self::BuildParams) -> Result<Self> {
140        let fsl = data.as_fixed_size_list_opt().ok_or(Error::index(format!(
141            "SQ builder: input is not a FixedSizeList: {}",
142            data.data_type()
143        )))?;
144
145        let mut quantizer = Self::new(params.num_bits, fsl.value_length() as usize);
146
147        match fsl.value_type() {
148            DataType::Float16 => {
149                quantizer.update_bounds::<Float16Type>(fsl)?;
150            }
151            DataType::Float32 => {
152                quantizer.update_bounds::<Float32Type>(fsl)?;
153            }
154            DataType::Float64 => {
155                quantizer.update_bounds::<Float64Type>(fsl)?;
156            }
157            _ => {
158                return Err(Error::index(format!(
159                    "SQ builder: unsupported data type: {}",
160                    fsl.value_type()
161                )));
162            }
163        }
164
165        Ok(quantizer)
166    }
167
168    fn retrain(&mut self, data: &dyn Array) -> Result<()> {
169        let fsl = data.as_fixed_size_list_opt().ok_or(Error::index(format!(
170            "SQ retrain: input is not a FixedSizeList: {}",
171            data.data_type()
172        )))?;
173
174        match fsl.value_type() {
175            DataType::Float16 => {
176                self.update_bounds::<Float16Type>(fsl)?;
177            }
178            DataType::Float32 => {
179                self.update_bounds::<Float32Type>(fsl)?;
180            }
181            DataType::Float64 => {
182                self.update_bounds::<Float64Type>(fsl)?;
183            }
184            value_type => {
185                return Err(Error::invalid_input(format!(
186                    "unsupported data type {} for scalar quantizer",
187                    value_type
188                )));
189            }
190        }
191        Ok(())
192    }
193
194    fn code_dim(&self) -> usize {
195        self.metadata.dim
196    }
197
198    fn column(&self) -> &'static str {
199        SQ_CODE_COLUMN
200    }
201
202    fn quantize(&self, vectors: &dyn Array) -> Result<ArrayRef> {
203        match vectors.as_fixed_size_list().value_type() {
204            DataType::Float16 => self.transform::<Float16Type>(vectors),
205            DataType::Float32 => self.transform::<Float32Type>(vectors),
206            DataType::Float64 => self.transform::<Float64Type>(vectors),
207            value_type => Err(Error::invalid_input(format!(
208                "unsupported data type {} for scalar quantizer",
209                value_type
210            ))),
211        }
212    }
213
214    fn metadata_key() -> &'static str {
215        SQ_METADATA_KEY
216    }
217
218    fn quantization_type() -> QuantizationType {
219        QuantizationType::Scalar
220    }
221
222    fn metadata(&self, _: Option<QuantizationMetadata>) -> Self::Metadata {
223        self.metadata.clone()
224    }
225
226    fn from_metadata(metadata: &Self::Metadata, _: DistanceType) -> Result<Quantizer> {
227        Ok(Quantizer::Scalar(Self {
228            metadata: metadata.clone(),
229        }))
230    }
231
232    fn field(&self) -> Field {
233        Field::new(
234            SQ_CODE_COLUMN,
235            DataType::FixedSizeList(
236                Arc::new(Field::new("item", DataType::UInt8, true)),
237                self.code_dim() as i32,
238            ),
239            true,
240        )
241    }
242}
243
244pub(crate) fn scale_to_u8<T: ArrowFloatType>(values: &[T::Native], bounds: &Range<f64>) -> Vec<u8> {
245    if bounds.start == bounds.end {
246        return vec![0; values.len()];
247    }
248
249    let range = bounds.end - bounds.start;
250    values
251        .iter()
252        .map(|&v| {
253            let v = v.to_f64().unwrap();
254            let v = (v - bounds.start) * 255.0 / range;
255            v as u8 // rust `as` performs saturating cast when casting float to int, so it's safe and expected here
256        })
257        .collect_vec()
258}
259
#[cfg(test)]
mod tests {
    use arrow::datatypes::{Float16Type, Float32Type, Float64Type};
    use arrow_array::{Float16Array, Float32Array, Float64Array};
    use half::f16;

    use super::*;

    // These tests never await anything, so they run as plain synchronous tests rather than
    // spinning up an async runtime per test.

    /// Values 0..16 as f16 span the bounds 0..15, so value `i` quantizes to `i * 17`.
    #[test]
    fn test_f16_sq8() {
        let float_values = Vec::from_iter((0..16).map(|v| f16::from_usize(v).unwrap()));
        let float_array = Float16Array::from_iter_values(float_values.clone());
        let vectors =
            FixedSizeListArray::try_new_from_values(float_array, float_values.len() as i32)
                .unwrap();
        let mut sq = ScalarQuantizer::new(8, float_values.len());

        sq.update_bounds::<Float16Type>(&vectors).unwrap();
        assert_eq!(sq.bounds().start, float_values[0].to_f64());
        assert_eq!(
            sq.bounds().end,
            float_values.last().cloned().unwrap().to_f64()
        );

        let sq_code = sq.transform::<Float16Type>(&vectors).unwrap();
        let sq_values = sq_code
            .as_fixed_size_list()
            .values()
            .as_any()
            .downcast_ref::<UInt8Array>()
            .unwrap();

        sq_values.values().iter().enumerate().for_each(|(i, v)| {
            assert_eq!(*v, (i * 17) as u8);
        });
    }

    /// Same check as the f16 test, with f32 input.
    #[test]
    fn test_f32_sq8() {
        let float_values = Vec::from_iter((0..16).map(|v| v as f32));
        let float_array = Float32Array::from_iter_values(float_values.clone());
        let vectors =
            FixedSizeListArray::try_new_from_values(float_array, float_values.len() as i32)
                .unwrap();
        let mut sq = ScalarQuantizer::new(8, float_values.len());

        sq.update_bounds::<Float32Type>(&vectors).unwrap();
        assert_eq!(sq.bounds().start, float_values[0].to_f64().unwrap());
        assert_eq!(
            sq.bounds().end,
            float_values.last().cloned().unwrap().to_f64().unwrap()
        );

        let sq_code = sq.transform::<Float32Type>(&vectors).unwrap();
        let sq_values = sq_code
            .as_fixed_size_list()
            .values()
            .as_any()
            .downcast_ref::<UInt8Array>()
            .unwrap();

        sq_values.values().iter().enumerate().for_each(|(i, v)| {
            assert_eq!(*v, (i * 17) as u8);
        });
    }

    /// Same check as the f16 test, with f64 input.
    #[test]
    fn test_f64_sq8() {
        let float_values = Vec::from_iter((0..16).map(|v| v as f64));
        let float_array = Float64Array::from_iter_values(float_values.clone());
        let vectors =
            FixedSizeListArray::try_new_from_values(float_array, float_values.len() as i32)
                .unwrap();
        let mut sq = ScalarQuantizer::new(8, float_values.len());

        sq.update_bounds::<Float64Type>(&vectors).unwrap();
        assert_eq!(sq.bounds().start, float_values[0]);
        assert_eq!(sq.bounds().end, float_values.last().cloned().unwrap());

        let sq_code = sq.transform::<Float64Type>(&vectors).unwrap();
        let sq_values = sq_code
            .as_fixed_size_list()
            .values()
            .as_any()
            .downcast_ref::<UInt8Array>()
            .unwrap();

        sq_values.values().iter().enumerate().for_each(|(i, v)| {
            assert_eq!(*v, (i * 17) as u8);
        });
    }

    /// A NaN input must quantize to 0 rather than panic or produce an arbitrary code.
    #[test]
    fn test_scale_to_u8_with_nan() {
        let values = vec![0.0, 1.0, 2.0, 3.0, f64::NAN];
        let bounds = Range::<f64> {
            start: 0.0,
            end: 3.0,
        };
        let u8_values = scale_to_u8::<Float64Type>(&values, &bounds);
        assert_eq!(u8_values, vec![0, 85, 170, 255, 0]);
    }
}
362}