lance_index/vector/
sq.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4use std::{ops::Range, sync::Arc};
5
6use arrow::array::AsArray;
7use arrow::datatypes::{Float16Type, Float32Type, Float64Type};
8use arrow_array::{Array, ArrayRef, FixedSizeListArray, UInt8Array};
9
10use arrow_schema::{DataType, Field};
11use builder::SQBuildParams;
12use deepsize::DeepSizeOf;
13use itertools::Itertools;
14use lance_arrow::*;
15use lance_core::{Error, Result};
16use lance_linalg::distance::DistanceType;
17use num_traits::*;
18use snafu::location;
19use storage::{ScalarQuantizationMetadata, ScalarQuantizationStorage, SQ_METADATA_KEY};
20
21use super::quantizer::{Quantization, QuantizationMetadata, QuantizationType, Quantizer};
22use super::SQ_CODE_COLUMN;
23
24pub mod builder;
25pub mod storage;
26pub mod transform;
27
28/// Scalar Quantization, optimized for [Apache Arrow] buffer memory layout.
29///
30//
31// TODO: move this to be pub(crate) once we have a better way to test it.
32#[derive(Debug, Clone)]
33pub struct ScalarQuantizer {
34    metadata: ScalarQuantizationMetadata,
35}
36
37impl DeepSizeOf for ScalarQuantizer {
38    fn deep_size_of_children(&self, _context: &mut deepsize::Context) -> usize {
39        0
40    }
41}
42
43impl ScalarQuantizer {
44    pub fn new(num_bits: u16, dim: usize) -> Self {
45        Self {
46            metadata: ScalarQuantizationMetadata {
47                num_bits,
48                dim,
49                bounds: Range::<f64> {
50                    start: f64::MAX,
51                    end: f64::MIN,
52                },
53            },
54        }
55    }
56
57    pub fn with_bounds(num_bits: u16, dim: usize, bounds: Range<f64>) -> Self {
58        let mut sq = Self::new(num_bits, dim);
59        sq.metadata.bounds = bounds;
60        sq
61    }
62
63    pub fn num_bits(&self) -> u16 {
64        self.metadata.num_bits
65    }
66
67    pub fn update_bounds<T: ArrowFloatType>(
68        &mut self,
69        vectors: &FixedSizeListArray,
70    ) -> Result<Range<f64>> {
71        let data = vectors
72            .values()
73            .as_any()
74            .downcast_ref::<T::ArrayType>()
75            .ok_or(Error::Index {
76                message: format!(
77                    "Expect to be a float vector array, got: {:?}",
78                    vectors.value_type()
79                ),
80                location: location!(),
81            })?
82            .as_slice();
83
84        self.metadata.bounds = data.iter().fold(self.metadata.bounds.clone(), |f, v| {
85            f.start.min(v.as_())..f.end.max(v.as_())
86        });
87
88        Ok(self.metadata.bounds.clone())
89    }
90
91    pub fn transform<T: ArrowFloatType>(&self, data: &dyn Array) -> Result<ArrayRef> {
92        let fsl = data
93            .as_fixed_size_list_opt()
94            .ok_or(Error::Index {
95                message: format!(
96                    "Expect to be a FixedSizeList<float> vector array, got: {:?} array",
97                    data.data_type()
98                ),
99                location: location!(),
100            })?
101            .clone();
102        let data = fsl
103            .values()
104            .as_any()
105            .downcast_ref::<T::ArrayType>()
106            .ok_or(Error::Index {
107                message: format!(
108                    "Expect to be a float vector array, got: {:?}",
109                    fsl.value_type()
110                ),
111                location: location!(),
112            })?
113            .as_slice();
114
115        // TODO: support SQ4
116        let builder: Vec<u8> = scale_to_u8::<T>(data, &self.metadata.bounds);
117
118        Ok(Arc::new(FixedSizeListArray::try_new_from_values(
119            UInt8Array::from(builder),
120            fsl.value_length(),
121        )?))
122    }
123
124    pub fn bounds(&self) -> Range<f64> {
125        self.metadata.bounds.clone()
126    }
127
128    /// Whether to use residual as input or not.
129    pub fn use_residual(&self) -> bool {
130        false
131    }
132}
133
134impl TryFrom<Quantizer> for ScalarQuantizer {
135    type Error = Error;
136    fn try_from(value: Quantizer) -> Result<Self> {
137        match value {
138            Quantizer::Scalar(sq) => Ok(sq),
139            _ => Err(Error::Index {
140                message: "Expect to be a ScalarQuantizer".to_string(),
141                location: location!(),
142            }),
143        }
144    }
145}
146
147impl Quantization for ScalarQuantizer {
148    type BuildParams = SQBuildParams;
149    type Metadata = ScalarQuantizationMetadata;
150    type Storage = ScalarQuantizationStorage;
151
152    fn build(data: &dyn Array, _: DistanceType, params: &Self::BuildParams) -> Result<Self> {
153        let fsl = data.as_fixed_size_list_opt().ok_or(Error::Index {
154            message: format!(
155                "SQ builder: input is not a FixedSizeList: {}",
156                data.data_type()
157            ),
158            location: location!(),
159        })?;
160
161        let mut quantizer = Self::new(params.num_bits, fsl.value_length() as usize);
162
163        match fsl.value_type() {
164            DataType::Float16 => {
165                quantizer.update_bounds::<Float16Type>(fsl)?;
166            }
167            DataType::Float32 => {
168                quantizer.update_bounds::<Float32Type>(fsl)?;
169            }
170            DataType::Float64 => {
171                quantizer.update_bounds::<Float64Type>(fsl)?;
172            }
173            _ => {
174                return Err(Error::Index {
175                    message: format!("SQ builder: unsupported data type: {}", fsl.value_type()),
176                    location: location!(),
177                })
178            }
179        }
180
181        Ok(quantizer)
182    }
183
184    fn retrain(&mut self, data: &dyn Array) -> Result<()> {
185        let fsl = data.as_fixed_size_list_opt().ok_or(Error::Index {
186            message: format!(
187                "SQ retrain: input is not a FixedSizeList: {}",
188                data.data_type()
189            ),
190            location: location!(),
191        })?;
192
193        match fsl.value_type() {
194            DataType::Float16 => {
195                self.update_bounds::<Float16Type>(fsl)?;
196            }
197            DataType::Float32 => {
198                self.update_bounds::<Float32Type>(fsl)?;
199            }
200            DataType::Float64 => {
201                self.update_bounds::<Float64Type>(fsl)?;
202            }
203            value_type => {
204                return Err(Error::invalid_input(
205                    format!("unsupported data type {} for scalar quantizer", value_type),
206                    location!(),
207                ))
208            }
209        }
210        Ok(())
211    }
212
213    fn code_dim(&self) -> usize {
214        self.metadata.dim
215    }
216
217    fn column(&self) -> &'static str {
218        SQ_CODE_COLUMN
219    }
220
221    fn quantize(&self, vectors: &dyn Array) -> Result<ArrayRef> {
222        match vectors.as_fixed_size_list().value_type() {
223            DataType::Float16 => self.transform::<Float16Type>(vectors),
224            DataType::Float32 => self.transform::<Float32Type>(vectors),
225            DataType::Float64 => self.transform::<Float64Type>(vectors),
226            value_type => Err(Error::invalid_input(
227                format!("unsupported data type {} for scalar quantizer", value_type),
228                location!(),
229            )),
230        }
231    }
232
233    fn metadata_key() -> &'static str {
234        SQ_METADATA_KEY
235    }
236
237    fn quantization_type() -> QuantizationType {
238        QuantizationType::Scalar
239    }
240
241    fn metadata(&self, _: Option<QuantizationMetadata>) -> Self::Metadata {
242        self.metadata.clone()
243    }
244
245    fn from_metadata(metadata: &Self::Metadata, _: DistanceType) -> Result<Quantizer> {
246        Ok(Quantizer::Scalar(Self {
247            metadata: metadata.clone(),
248        }))
249    }
250
251    fn field(&self) -> Field {
252        Field::new(
253            SQ_CODE_COLUMN,
254            DataType::FixedSizeList(
255                Arc::new(Field::new("item", DataType::UInt8, true)),
256                self.code_dim() as i32,
257            ),
258            true,
259        )
260    }
261}
262
263pub(crate) fn scale_to_u8<T: ArrowFloatType>(values: &[T::Native], bounds: &Range<f64>) -> Vec<u8> {
264    if bounds.start == bounds.end {
265        return vec![0; values.len()];
266    }
267
268    let range = bounds.end - bounds.start;
269    values
270        .iter()
271        .map(|&v| {
272            let v = v.to_f64().unwrap();
273            let v = (v - bounds.start) * 255.0 / range;
274            v as u8 // rust `as` performs saturating cast when casting float to int, so it's safe and expected here
275        })
276        .collect_vec()
277}
278
/// Rescales distances computed on `u8` codes back to the original value scale.
///
/// Each value is multiplied by `range^2 / 255^2`, where `range = bounds.end - bounds.start`,
/// which is the square of the width of one code step. NOTE(review): this assumes the inputs
/// are distances that scale quadratically with the code step (e.g. squared L2); confirm
/// against callers.
pub(crate) fn inverse_scalar_dist(
    values: impl Iterator<Item = f32>,
    bounds: &Range<f64>,
) -> Vec<f32> {
    // Width of the quantized interval, narrowed to f32 to match the distances.
    let range = (bounds.end - bounds.start) as f32;
    // Both factors are loop-invariant, so compute them once. The literal is typed explicitly
    // because a method call on an untyped float literal (`255.0.powi(2)`) is ambiguous (E0689).
    let range_sq = range.powi(2);
    let max_code_sq = 255.0_f32.powi(2);
    values.map(|v| v * range_sq / max_code_sq).collect()
}
#[cfg(test)]
mod tests {
    use arrow::datatypes::{Float16Type, Float32Type, Float64Type};
    use arrow_array::{Float16Array, Float32Array, Float64Array};
    use half::f16;

    use super::*;

    /// Asserts that `codes` is a `FixedSizeList<UInt8>` holding exactly `expected_len` codes,
    /// with the i-th code equal to `i * 17`.
    ///
    /// That is what SQ8 produces for the inputs `0..16` with bounds `0..15`.
    fn assert_codes_are_multiples_of_17(codes: &ArrayRef, expected_len: usize) {
        let values = codes
            .as_fixed_size_list()
            .values()
            .as_any()
            .downcast_ref::<UInt8Array>()
            .unwrap();

        // Guard against a vacuous pass on empty output.
        assert_eq!(values.len(), expected_len);
        values.values().iter().enumerate().for_each(|(i, v)| {
            assert_eq!(*v, (i * 17) as u8);
        });
    }

    // These tests never await anything, so plain `#[test]` avoids spinning up a runtime.
    #[test]
    fn test_f16_sq8() {
        let float_values = Vec::from_iter((0..16).map(|v| f16::from_usize(v).unwrap()));
        let float_array = Float16Array::from_iter_values(float_values.clone());
        let vectors =
            FixedSizeListArray::try_new_from_values(float_array, float_values.len() as i32)
                .unwrap();
        let mut sq = ScalarQuantizer::new(8, float_values.len());

        sq.update_bounds::<Float16Type>(&vectors).unwrap();
        assert_eq!(sq.bounds().start, float_values[0].to_f64());
        assert_eq!(
            sq.bounds().end,
            float_values.last().cloned().unwrap().to_f64()
        );

        let sq_code = sq.transform::<Float16Type>(&vectors).unwrap();
        assert_codes_are_multiples_of_17(&sq_code, float_values.len());
    }

    #[test]
    fn test_f32_sq8() {
        let float_values = Vec::from_iter((0..16).map(|v| v as f32));
        let float_array = Float32Array::from_iter_values(float_values.clone());
        let vectors =
            FixedSizeListArray::try_new_from_values(float_array, float_values.len() as i32)
                .unwrap();
        let mut sq = ScalarQuantizer::new(8, float_values.len());

        sq.update_bounds::<Float32Type>(&vectors).unwrap();
        assert_eq!(sq.bounds().start, float_values[0].to_f64().unwrap());
        assert_eq!(
            sq.bounds().end,
            float_values.last().cloned().unwrap().to_f64().unwrap()
        );

        let sq_code = sq.transform::<Float32Type>(&vectors).unwrap();
        assert_codes_are_multiples_of_17(&sq_code, float_values.len());
    }

    #[test]
    fn test_f64_sq8() {
        let float_values = Vec::from_iter((0..16).map(|v| v as f64));
        let float_array = Float64Array::from_iter_values(float_values.clone());
        let vectors =
            FixedSizeListArray::try_new_from_values(float_array, float_values.len() as i32)
                .unwrap();
        let mut sq = ScalarQuantizer::new(8, float_values.len());

        sq.update_bounds::<Float64Type>(&vectors).unwrap();
        assert_eq!(sq.bounds().start, float_values[0]);
        assert_eq!(sq.bounds().end, float_values.last().cloned().unwrap());

        let sq_code = sq.transform::<Float64Type>(&vectors).unwrap();
        assert_codes_are_multiples_of_17(&sq_code, float_values.len());
    }

    #[test]
    fn test_scale_to_u8_with_nan() {
        let values = vec![0.0, 1.0, 2.0, 3.0, f64::NAN];
        let bounds = Range::<f64> {
            start: 0.0,
            end: 3.0,
        };
        let u8_values = scale_to_u8::<Float64Type>(&values, &bounds);
        assert_eq!(u8_values, vec![0, 85, 170, 255, 0]);
    }

    #[test]
    fn test_scale_to_u8_with_degenerate_bounds() {
        let values = vec![1.0, 2.0, 5.0];
        let bounds = Range::<f64> {
            start: 1.0,
            end: 1.0,
        };
        let u8_values = scale_to_u8::<Float64Type>(&values, &bounds);
        assert_eq!(u8_values, vec![0, 0, 0]);
    }

    #[test]
    fn test_inverse_scalar_dist() {
        let bounds = Range::<f64> {
            start: 0.0,
            end: 255.0,
        };
        // With a range of 255 the scale factor is exactly 1.
        let dists = inverse_scalar_dist(vec![1.0f32, 2.0].into_iter(), &bounds);
        assert_eq!(dists, vec![1.0f32, 2.0]);
    }
}