lance_index/vector/
pq.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4//! Product Quantization
5//!
6
7use std::sync::Arc;
8
9use arrow::datatypes::{self, ArrowPrimitiveType};
10use arrow_array::{cast::AsArray, Array, FixedSizeListArray, UInt8Array};
11use arrow_array::{ArrayRef, Float32Array, PrimitiveArray};
12use arrow_schema::{DataType, Field};
13use deepsize::DeepSizeOf;
14use distance::build_distance_table_dot;
15use lance_arrow::*;
16use lance_core::{assume_eq, Error, Result};
17use lance_linalg::distance::{DistanceType, Dot, L2};
18use lance_table::utils::LanceIteratorExtension;
19use num_traits::Float;
20use prost::Message;
21use snafu::location;
22use storage::{ProductQuantizationMetadata, ProductQuantizationStorage, PQ_METADATA_KEY};
23use tracing::instrument;
24
25pub mod builder;
26pub mod distance;
27pub mod storage;
28pub mod transform;
29pub(crate) mod utils;
30
31use self::distance::{build_distance_table_l2, compute_pq_distance};
32pub use self::utils::num_centroids;
33use super::quantizer::{
34    Quantization, QuantizationMetadata, QuantizationType, Quantizer, QuantizerBuildParams,
35};
36use super::{pb, PQ_CODE_COLUMN};
37use crate::vector::kmeans::compute_partition;
38pub use builder::PQBuildParams;
39use utils::get_sub_vector_centroids;
40
/// Product quantizer: a centroid codebook together with the parameters
/// needed to encode vectors into PQ codes and to compute distances on them.
#[derive(Debug, Clone)]
pub struct ProductQuantizer {
    /// Number of sub-vectors each full vector is split into.
    pub num_sub_vectors: usize,
    /// Number of bits per sub-vector code. The methods in this file only
    /// handle 4 and 8.
    pub num_bits: u32,
    /// Dimension of the full (un-split) vectors.
    pub dimension: usize,
    /// Centroid codebook.
    // NOTE(review): the tests in this file build it with a list width equal
    // to `dimension` and `2^num_bits` rows — confirm this is the required layout.
    pub codebook: FixedSizeListArray,
    /// Distance type used by `compute_distances` and when assigning codes.
    pub distance_type: DistanceType,
}
49
50impl DeepSizeOf for ProductQuantizer {
51    fn deep_size_of_children(&self, _context: &mut deepsize::Context) -> usize {
52        self.codebook.get_array_memory_size()
53            + self.num_sub_vectors.deep_size_of_children(_context)
54            + self.num_bits.deep_size_of_children(_context)
55            + self.dimension.deep_size_of_children(_context)
56            + self.distance_type.deep_size_of_children(_context)
57    }
58}
59
60impl ProductQuantizer {
61    pub fn new(
62        num_sub_vectors: usize,
63        num_bits: u32,
64        dimension: usize,
65        codebook: FixedSizeListArray,
66        distance_type: DistanceType,
67    ) -> Self {
68        Self {
69            num_bits,
70            num_sub_vectors,
71            dimension,
72            codebook,
73            distance_type,
74        }
75    }
76
77    pub fn from_proto(proto: &pb::Pq, distance_type: DistanceType) -> Result<Self> {
78        let distance_type = match distance_type {
79            DistanceType::Cosine => DistanceType::L2,
80            _ => distance_type,
81        };
82        let codebook = match proto.codebook_tensor.as_ref() {
83            Some(tensor) => FixedSizeListArray::try_from(tensor)?,
84            None => FixedSizeListArray::try_new_from_values(
85                Float32Array::from(proto.codebook.clone()),
86                proto.dimension as i32,
87            )?,
88        };
89        Ok(Self {
90            num_bits: proto.num_bits,
91            num_sub_vectors: proto.num_sub_vectors as usize,
92            dimension: proto.dimension as usize,
93            codebook,
94            distance_type,
95        })
96    }
97
98    #[instrument(name = "ProductQuantizer::transform", level = "debug", skip_all)]
99    fn transform<T: ArrowPrimitiveType>(&self, vectors: &dyn Array) -> Result<ArrayRef>
100    where
101        T::Native: Float + L2 + Dot,
102    {
103        match self.num_bits {
104            4 => self.transform_impl::<4, T>(vectors),
105            8 => self.transform_impl::<8, T>(vectors),
106            _ => Err(Error::Index {
107                message: format!(
108                    "ProductQuantization: num_bits {} not supported",
109                    self.num_bits
110                ),
111                location: location!(),
112            }),
113        }
114    }
115
116    fn transform_impl<const NUM_BITS: u32, T: ArrowPrimitiveType>(
117        &self,
118        vectors: &dyn Array,
119    ) -> Result<ArrayRef>
120    where
121        T::Native: Float + L2 + Dot,
122    {
123        let fsl = vectors.as_fixed_size_list_opt().ok_or(Error::Index {
124            message: format!(
125                "Expect to be a FixedSizeList<float> vector array, got: {:?} array",
126                vectors.data_type()
127            ),
128            location: location!(),
129        })?;
130        let num_sub_vectors = self.num_sub_vectors;
131        let dim = self.dimension;
132        if NUM_BITS == 4 && num_sub_vectors % 2 != 0 {
133            return Err(Error::Index {
134                message: format!(
135                    "PQ: num_sub_vectors must be divisible by 2 for num_bits=4, but got {}",
136                    num_sub_vectors,
137                ),
138                location: location!(),
139            });
140        }
141        let codebook = self.codebook.values().as_primitive::<T>();
142
143        let distance_type = self.distance_type;
144
145        let flatten_data = fsl.values().as_primitive::<T>();
146        let sub_dim = dim / num_sub_vectors;
147        let total_code_length = fsl.len() * num_sub_vectors / (8 / NUM_BITS as usize);
148        let values = flatten_data
149            .values()
150            .chunks_exact(dim)
151            .flat_map(|vector| {
152                let sub_vec_code = vector
153                    .chunks_exact(sub_dim)
154                    .enumerate()
155                    .map(|(sub_idx, sub_vector)| {
156                        let centroids = get_sub_vector_centroids::<NUM_BITS, _>(
157                            codebook.values(),
158                            dim,
159                            num_sub_vectors,
160                            sub_idx,
161                        );
162                        // SAFETY: The must be 2^NUM_BITS centroids, it's safe to unwrap_or(0),
163                        // this could happen if all distances are INFs in the case of vectors are large.
164                        assume_eq!(centroids.len(), 2_usize.pow(NUM_BITS) * sub_dim);
165                        compute_partition(centroids, sub_vector, distance_type).unwrap_or(0) as u8
166                    })
167                    .collect::<Vec<_>>();
168                if NUM_BITS == 4 {
169                    sub_vec_code
170                        .chunks_exact(2)
171                        .map(|v| (v[1] << 4) | v[0])
172                        .collect::<Vec<_>>()
173                } else {
174                    sub_vec_code
175                }
176            })
177            .exact_size(total_code_length)
178            .collect::<Vec<_>>();
179
180        let num_sub_vectors_in_byte = if NUM_BITS == 4 {
181            num_sub_vectors / 2
182        } else {
183            num_sub_vectors
184        };
185
186        debug_assert_eq!(values.len(), fsl.len() * num_sub_vectors_in_byte);
187        Ok(Arc::new(FixedSizeListArray::try_new_from_values(
188            UInt8Array::from(values),
189            num_sub_vectors_in_byte as i32,
190        )?))
191    }
192
193    // the code must be transposed
194    pub fn compute_distances(&self, query: &dyn Array, code: &UInt8Array) -> Result<Float32Array> {
195        if code.is_empty() {
196            return Ok(Float32Array::from(Vec::<f32>::new()));
197        }
198
199        match self.distance_type {
200            DistanceType::L2 => self.l2_distances(query, code),
201            DistanceType::Cosine => {
202                // it seems we implemented cosine distance at some version,
203                // but from now on, we should use normalized L2 distance.
204                debug_assert!(
205                    false,
206                    "cosine distance should be converted to normalized L2 distance"
207                );
208                // L2 over normalized vectors:  ||x - y|| = x^2 + y^2 - 2 * xy = 1 + 1 - 2 * xy = 2 * (1 - xy)
209                // Cosine distance: 1 - |xy| / (||x|| * ||y||) = 1 - xy / (x^2 * y^2) = 1 - xy / (1 * 1) = 1 - xy
210                // Therefore, Cosine = L2 / 2
211                let l2_dists = self.l2_distances(query, code)?;
212                Ok(l2_dists.values().iter().map(|v| *v / 2.0).collect())
213            }
214            DistanceType::Dot => self.dot_distances(query, code),
215            _ => panic!(
216                "ProductQuantization: distance type {} not supported",
217                self.distance_type
218            ),
219        }
220    }
221
222    /// Pre-compute L2 distance from the query to all code.
223    ///
224    /// It returns the squared L2 distance.
225    fn l2_distances(&self, key: &dyn Array, code: &UInt8Array) -> Result<Float32Array> {
226        let distance_table = self.build_l2_distance_table(key)?;
227
228        #[cfg(target_feature = "avx512f")]
229        {
230            Ok(self.compute_l2_distance(&distance_table, code.values()))
231        }
232        #[cfg(not(target_feature = "avx512f"))]
233        {
234            Ok(self.compute_l2_distance(&distance_table, code.values()))
235        }
236    }
237
238    /// Parameters
239    /// ----------
240    ///  - query: the query vector, with shape (dimension, )
241    ///  - code: the PQ code in one partition.
242    ///
243    fn dot_distances(&self, key: &dyn Array, code: &UInt8Array) -> Result<Float32Array> {
244        match key.data_type() {
245            DataType::Float16 => {
246                self.dot_distances_impl::<datatypes::Float16Type>(key.as_primitive(), code)
247            }
248            DataType::Float32 => {
249                self.dot_distances_impl::<datatypes::Float32Type>(key.as_primitive(), code)
250            }
251            DataType::Float64 => {
252                self.dot_distances_impl::<datatypes::Float64Type>(key.as_primitive(), code)
253            }
254            _ => Err(Error::Index {
255                message: format!("unsupported data type: {}", key.data_type()),
256                location: location!(),
257            }),
258        }
259    }
260
261    fn dot_distances_impl<T: ArrowPrimitiveType>(
262        &self,
263        key: &PrimitiveArray<T>,
264        code: &UInt8Array,
265    ) -> Result<Float32Array>
266    where
267        T::Native: Dot,
268    {
269        let distance_table = build_distance_table_dot(
270            self.codebook.values().as_primitive::<T>().values(),
271            self.num_bits,
272            self.num_sub_vectors,
273            key.values(),
274        );
275
276        let distances = compute_pq_distance(
277            &distance_table,
278            self.num_bits,
279            self.num_sub_vectors,
280            code.values(),
281            0,
282        );
283
284        let diff = self.num_sub_vectors as f32 - 1.0;
285        let distances = distances.into_iter().map(|d| d - diff).collect::<Vec<_>>();
286        Ok(distances.into())
287    }
288
289    fn build_l2_distance_table(&self, key: &dyn Array) -> Result<Vec<f32>> {
290        match key.data_type() {
291            DataType::Float16 => {
292                Ok(self.build_l2_distance_table_impl::<datatypes::Float16Type>(key.as_primitive()))
293            }
294            DataType::Float32 => {
295                Ok(self.build_l2_distance_table_impl::<datatypes::Float32Type>(key.as_primitive()))
296            }
297            DataType::Float64 => {
298                Ok(self.build_l2_distance_table_impl::<datatypes::Float64Type>(key.as_primitive()))
299            }
300            _ => Err(Error::Index {
301                message: format!("unsupported data type: {}", key.data_type()),
302                location: location!(),
303            }),
304        }
305    }
306
307    fn build_l2_distance_table_impl<T: ArrowPrimitiveType>(
308        &self,
309        key: &PrimitiveArray<T>,
310    ) -> Vec<f32>
311    where
312        T::Native: L2,
313    {
314        build_distance_table_l2(
315            self.codebook.values().as_primitive::<T>().values(),
316            self.num_bits,
317            self.num_sub_vectors,
318            key.values(),
319        )
320    }
321
322    /// Compute L2 distance from the query to all code.
323    ///
324    /// Type parameters
325    /// ---------------
326    /// - C: the tile size of code-book to run at once.
327    /// - V: the tile size of PQ code to run at once.
328    ///
329    /// Parameters
330    /// ----------
331    /// - distance_table: the pre-computed L2 distance table.
332    ///   It is a flatten array of [num_sub_vectors, num_centroids] f32.
333    /// - code: the PQ code to be used to compute the distances.
334    ///
335    /// Returns
336    /// -------
337    ///  The squared L2 distance.
338    #[inline]
339    fn compute_l2_distance(&self, distance_table: &[f32], code: &[u8]) -> Float32Array {
340        Float32Array::from(compute_pq_distance(
341            distance_table,
342            self.num_bits,
343            self.num_sub_vectors,
344            code,
345            100,
346        ))
347    }
348
349    /// Get the centroids for one sub-vector.
350    ///
351    /// Returns a flatten `num_centroids * sub_vector_width` f32 array.
352    pub fn centroids<T: ArrowPrimitiveType>(&self, sub_vector_idx: usize) -> &[T::Native] {
353        match self.num_bits {
354            4 => get_sub_vector_centroids::<4, _>(
355                self.codebook.values().as_primitive::<T>().values(),
356                self.dimension,
357                self.num_sub_vectors,
358                sub_vector_idx,
359            ),
360            8 => get_sub_vector_centroids::<8, _>(
361                self.codebook.values().as_primitive::<T>().values(),
362                self.dimension,
363                self.num_sub_vectors,
364                sub_vector_idx,
365            ),
366            _ => panic!(
367                "ProductQuantization: num_bits {} not supported",
368                self.num_bits
369            ),
370        }
371    }
372}
373
374impl Quantization for ProductQuantizer {
375    type BuildParams = PQBuildParams;
376    type Metadata = ProductQuantizationMetadata;
377    type Storage = ProductQuantizationStorage;
378
379    fn build(
380        data: &dyn Array,
381        distance_type: DistanceType,
382        params: &Self::BuildParams,
383    ) -> Result<Self> {
384        assert_eq!(data.null_count(), 0);
385        let fsl = data.as_fixed_size_list_opt().ok_or(Error::Index {
386            message: format!(
387                "PQ builder: input is not a FixedSizeList: {}",
388                data.data_type()
389            ),
390            location: location!(),
391        })?;
392
393        if let Some(codebook) = params.codebook.as_ref() {
394            return Ok(Self::new(
395                params.num_sub_vectors,
396                params.num_bits as u32,
397                fsl.value_length() as usize,
398                FixedSizeListArray::try_new_from_values(codebook.clone(), fsl.value_length())?,
399                distance_type,
400            ));
401        }
402
403        params.build(data, distance_type)
404    }
405
406    fn retrain(&mut self, data: &dyn Array) -> Result<()> {
407        assert_eq!(data.null_count(), 0);
408        let params = PQBuildParams::with_codebook(
409            self.num_sub_vectors,
410            self.num_bits as usize,
411            Arc::new(self.codebook.clone()),
412        );
413
414        *self = params.build(data, self.distance_type)?;
415        Ok(())
416    }
417
    /// Dimension of the quantized code, reported as the number of sub-vectors.
    fn code_dim(&self) -> usize {
        self.num_sub_vectors
    }
421
    /// Name of the column holding the PQ codes (`PQ_CODE_COLUMN`).
    fn column(&self) -> &'static str {
        PQ_CODE_COLUMN
    }
425
    /// Whether residual vectors should be used for `distance_type`;
    /// delegates to `PQBuildParams::use_residual`.
    fn use_residual(distance_type: DistanceType) -> bool {
        PQBuildParams::use_residual(distance_type)
    }
429
430    fn quantize(&self, vectors: &dyn Array) -> Result<ArrayRef> {
431        let fsl = vectors.as_fixed_size_list_opt().ok_or(Error::Index {
432            message: format!(
433                "Expect to be a FixedSizeList<float> vector array, got: {:?} array",
434                vectors.data_type()
435            ),
436            location: location!(),
437        })?;
438
439        match fsl.value_type() {
440            DataType::Float16 => self.transform::<datatypes::Float16Type>(vectors),
441            DataType::Float32 => self.transform::<datatypes::Float32Type>(vectors),
442            DataType::Float64 => self.transform::<datatypes::Float64Type>(vectors),
443            _ => Err(Error::Index {
444                message: format!("unsupported data type: {}", fsl.value_type()),
445                location: location!(),
446            }),
447        }
448    }
449
    /// Key under which the PQ metadata is stored (`PQ_METADATA_KEY`).
    fn metadata_key() -> &'static str {
        PQ_METADATA_KEY
    }
453
    /// Identifies this quantizer as `QuantizationType::Product`.
    fn quantization_type() -> QuantizationType {
        QuantizationType::Product
    }
457
458    fn metadata(&self, args: Option<QuantizationMetadata>) -> Self::Metadata {
459        let codebook_position = match &args {
460            Some(args) => args.codebook_position,
461            None => Some(0),
462        };
463
464        let codebook_position = codebook_position.expect("codebook position should be set");
465        ProductQuantizationMetadata {
466            codebook_position,
467            nbits: self.num_bits,
468            num_sub_vectors: self.num_sub_vectors,
469            dimension: self.dimension,
470            codebook: Some(self.codebook.clone()),
471            codebook_tensor: Vec::new(),
472            transposed: args.map(|args| args.transposed).unwrap_or_default(),
473        }
474    }
475
476    fn from_metadata(metadata: &Self::Metadata, distance_type: DistanceType) -> Result<Quantizer> {
477        let distance_type = match distance_type {
478            DistanceType::Cosine => DistanceType::L2,
479            _ => distance_type,
480        };
481        let codebook = match metadata.codebook.as_ref() {
482            Some(fsl) => fsl.clone(),
483            None => {
484                let tensor = pb::Tensor::decode(metadata.codebook_tensor.as_ref())?;
485                FixedSizeListArray::try_from(&tensor)?
486            }
487        };
488        Ok(Quantizer::Product(Self::new(
489            metadata.num_sub_vectors,
490            metadata.nbits,
491            metadata.dimension,
492            codebook,
493            distance_type,
494        )))
495    }
496
497    fn field(&self) -> Field {
498        let num_bytes_per_sub_vector = self.num_sub_vectors * self.num_bits as usize / 8;
499        Field::new(
500            PQ_CODE_COLUMN,
501            DataType::FixedSizeList(
502                Arc::new(Field::new("item", DataType::UInt8, true)),
503                num_bytes_per_sub_vector as i32,
504            ),
505            true,
506        )
507    }
508}
509
510impl TryFrom<&ProductQuantizer> for pb::Pq {
511    type Error = Error;
512
513    fn try_from(pq: &ProductQuantizer) -> Result<Self> {
514        let tensor = pb::Tensor::try_from(&pq.codebook)?;
515        Ok(Self {
516            num_bits: pq.num_bits,
517            num_sub_vectors: pq.num_sub_vectors as u32,
518            dimension: pq.dimension as u32,
519            codebook: vec![],
520            codebook_tensor: Some(tensor),
521        })
522    }
523}
524
525impl TryFrom<Quantizer> for ProductQuantizer {
526    type Error = Error;
527    fn try_from(value: Quantizer) -> Result<Self> {
528        match value {
529            Quantizer::Product(pq) => Ok(pq),
530            _ => Err(Error::Index {
531                message: "Expect to be a ProductQuantizer".to_string(),
532                location: location!(),
533            }),
534        }
535    }
536}
537
538#[cfg(test)]
539mod tests {
540    use super::*;
541
542    use std::iter::repeat_n;
543
544    use approx::assert_relative_eq;
545    use arrow::datatypes::UInt8Type;
546    use arrow_array::Float16Array;
547    use half::f16;
548    use lance_linalg::distance::l2_distance_batch;
549    use lance_linalg::kernels::argmin;
550    use lance_testing::datagen::generate_random_array;
551    use num_traits::Zero;
552    use storage::transpose;
553
554    #[test]
555    fn test_f16_pq_to_protobuf() {
556        let pq = ProductQuantizer::new(
557            4,
558            8,
559            16,
560            FixedSizeListArray::try_new_from_values(
561                Float16Array::from_iter_values(repeat_n(f16::zero(), 256 * 16)),
562                16,
563            )
564            .unwrap(),
565            DistanceType::L2,
566        );
567        let proto: pb::Pq = pb::Pq::try_from(&pq).unwrap();
568        assert_eq!(proto.num_bits, 8);
569        assert_eq!(proto.num_sub_vectors, 4);
570        assert_eq!(proto.dimension, 16);
571        assert!(proto.codebook.is_empty());
572        assert!(proto.codebook_tensor.is_some());
573
574        let tensor = proto.codebook_tensor.as_ref().unwrap();
575        assert_eq!(tensor.data_type, pb::tensor::DataType::Float16 as i32);
576        assert_eq!(tensor.shape, vec![256, 16]);
577    }
578
579    #[test]
580    fn test_l2_distance() {
581        const DIM: usize = 512;
582        const TOTAL: usize = 66; // 64 + 2 to make sure reminder is handled correctly.
583        let codebook = generate_random_array(256 * DIM);
584        let pq = ProductQuantizer::new(
585            16,
586            8,
587            DIM,
588            FixedSizeListArray::try_new_from_values(codebook, DIM as i32).unwrap(),
589            DistanceType::L2,
590        );
591        let pq_code = UInt8Array::from_iter_values((0..16 * TOTAL).map(|v| v as u8));
592        let query = generate_random_array(DIM);
593
594        let transposed_pq_codes = transpose(&pq_code, TOTAL, 16);
595        let dists = pq.compute_distances(&query, &transposed_pq_codes).unwrap();
596
597        let sub_vec_len = DIM / 16;
598        let expected = pq_code
599            .values()
600            .chunks(16)
601            .map(|code| {
602                code.iter()
603                    .enumerate()
604                    .flat_map(|(sub_idx, c)| {
605                        let subvec_centroids = pq.centroids::<datatypes::Float32Type>(sub_idx);
606                        let subvec =
607                            &query.values()[sub_idx * sub_vec_len..(sub_idx + 1) * sub_vec_len];
608                        l2_distance_batch(
609                            subvec,
610                            &subvec_centroids
611                                [*c as usize * sub_vec_len..(*c as usize + 1) * sub_vec_len],
612                            sub_vec_len,
613                        )
614                    })
615                    .sum::<f32>()
616            })
617            .collect::<Vec<_>>();
618        dists
619            .values()
620            .iter()
621            .zip(expected.iter())
622            .for_each(|(v, e)| {
623                assert_relative_eq!(*v, *e, epsilon = 1e-4);
624            });
625    }
626
627    #[test]
628    fn test_pq_transform() {
629        const DIM: usize = 16;
630        const TOTAL: usize = 64;
631        let codebook = generate_random_array(DIM * 256);
632        let pq = ProductQuantizer::new(
633            4,
634            8,
635            DIM,
636            FixedSizeListArray::try_new_from_values(codebook, DIM as i32).unwrap(),
637            DistanceType::L2,
638        );
639
640        let vectors = generate_random_array(DIM * TOTAL);
641        let fsl = FixedSizeListArray::try_new_from_values(vectors.clone(), DIM as i32).unwrap();
642        let pq_code = pq.quantize(&fsl).unwrap();
643
644        let mut expected = Vec::with_capacity(TOTAL * 4);
645        vectors.values().chunks_exact(DIM).for_each(|vec| {
646            vec.chunks_exact(DIM / 4)
647                .enumerate()
648                .for_each(|(sub_idx, sub_vec)| {
649                    let centroids = pq.centroids::<datatypes::Float32Type>(sub_idx);
650                    let dists = l2_distance_batch(sub_vec, centroids, DIM / 4);
651                    let code = argmin(dists).unwrap() as u8;
652                    expected.push(code);
653                });
654        });
655
656        assert_eq!(pq_code.len(), TOTAL);
657        assert_eq!(
658            &expected,
659            pq_code
660                .as_fixed_size_list()
661                .values()
662                .as_primitive::<UInt8Type>()
663                .values()
664        );
665    }
666}