Skip to main content

lance_index/vector/pq/
builder.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4//! Product Quantizer Builder
5//!
6
7use std::sync::Arc;
8
9use crate::vector::quantizer::QuantizerBuildParams;
10use arrow::array::PrimitiveBuilder;
11use arrow_array::types::{Float16Type, Float64Type};
12use arrow_array::{Array, ArrayRef, cast::AsArray, types::Float32Type};
13use arrow_array::{ArrowNumericType, FixedSizeListArray, PrimitiveArray};
14use arrow_schema::DataType;
15use lance_arrow::FixedSizeListArrayExt;
16use lance_core::{Error, Result};
17use lance_linalg::distance::DistanceType;
18use lance_linalg::distance::{Dot, L2, Normalize};
19
20use super::ProductQuantizer;
21use super::utils::divide_to_subvectors;
22use crate::vector::kmeans::{KMeansParams, train_kmeans};
23
24/// Parameters for building product quantizer.
25#[derive(Debug, Clone)]
26pub struct PQBuildParams {
27    /// Number of sub-vectors to build PQ code
28    pub num_sub_vectors: usize,
29
30    /// The number of bits to present one PQ centroid.
31    pub num_bits: usize,
32
33    /// The max number of iterations for kmeans training.
34    pub max_iters: usize,
35
36    /// Run kmeans `REDOS` times and take the best result.
37    /// Default to 1.
38    pub kmeans_redos: usize,
39
40    /// User provided codebook.
41    pub codebook: Option<ArrayRef>,
42
43    /// Sample rate to train PQ codebook.
44    pub sample_rate: usize,
45}
46
47impl Default for PQBuildParams {
48    fn default() -> Self {
49        Self {
50            num_sub_vectors: 16,
51            num_bits: 8,
52            max_iters: 50,
53            kmeans_redos: 1,
54            codebook: None,
55            sample_rate: 256,
56        }
57    }
58}
59
60impl QuantizerBuildParams for PQBuildParams {
61    fn sample_size(&self) -> usize {
62        self.sample_rate * 2_usize.pow(self.num_bits as u32)
63    }
64
65    fn use_residual(distance_type: DistanceType) -> bool {
66        matches!(distance_type, DistanceType::L2 | DistanceType::Cosine)
67    }
68}
69
70impl PQBuildParams {
71    pub fn new(num_sub_vectors: usize, num_bits: usize) -> Self {
72        Self {
73            num_sub_vectors,
74            num_bits,
75            ..Default::default()
76        }
77    }
78
79    pub fn with_codebook(num_sub_vectors: usize, num_bits: usize, codebook: ArrayRef) -> Self {
80        Self {
81            num_sub_vectors,
82            num_bits,
83            codebook: Some(codebook),
84            ..Default::default()
85        }
86    }
87
88    fn build_from_fsl<T: ArrowNumericType>(
89        &self,
90        data: &FixedSizeListArray,
91        distance_type: DistanceType,
92    ) -> Result<ProductQuantizer>
93    where
94        T::Native: Dot + L2 + Normalize,
95        PrimitiveArray<T>: From<Vec<T::Native>>,
96    {
97        assert_ne!(
98            distance_type,
99            DistanceType::Cosine,
100            "PQ code does not support cosine"
101        );
102
103        let sub_vectors = divide_to_subvectors::<T>(data, self.num_sub_vectors)?;
104        let num_centroids = 2_usize.pow(self.num_bits as u32);
105        let dimension = data.value_length() as usize;
106        let sub_vector_dimension = dimension / self.num_sub_vectors;
107
108        let d = sub_vectors
109            .into_iter()
110            .enumerate()
111            .map(|(sub_vec_idx, sub_vec)| {
112                let params = KMeansParams::new(
113                    self.codebook.as_ref().map(|cb| {
114                        let sub_vec_centroids = FixedSizeListArray::try_new_from_values(
115                            cb.as_fixed_size_list().values().as_primitive::<T>().slice(
116                                sub_vec_idx * num_centroids * sub_vector_dimension,
117                                num_centroids * sub_vector_dimension,
118                            ),
119                            sub_vector_dimension as i32,
120                        )
121                        .unwrap();
122                        Arc::new(sub_vec_centroids)
123                    }),
124                    self.max_iters as u32,
125                    self.kmeans_redos,
126                    distance_type,
127                );
128                train_kmeans::<T>(
129                    &sub_vec,
130                    params,
131                    sub_vector_dimension,
132                    num_centroids,
133                    self.sample_rate,
134                )
135                .map(|kmeans| kmeans.centroids)
136            })
137            .collect::<Result<Vec<_>>>()?;
138        let mut codebook_builder = PrimitiveBuilder::<T>::with_capacity(num_centroids * dimension);
139        for centroid in d.iter() {
140            let c = centroid
141                .as_any()
142                .downcast_ref::<PrimitiveArray<T>>()
143                .expect("failed to downcast to PrimitiveArray");
144            codebook_builder.append_slice(c.values());
145        }
146
147        let pd_centroids = codebook_builder.finish();
148
149        Ok(ProductQuantizer::new(
150            self.num_sub_vectors,
151            self.num_bits as u32,
152            dimension,
153            FixedSizeListArray::try_new_from_values(pd_centroids, dimension as i32)?,
154            distance_type,
155        ))
156    }
157
158    /// Build a [ProductQuantizer] from the given data.
159    ///
160    /// If the [`DistanceType`] is [`DistanceType::Cosine`], the input data will be normalized.
161    pub fn build(&self, data: &dyn Array, distance_type: DistanceType) -> Result<ProductQuantizer> {
162        assert_eq!(data.null_count(), 0);
163        let fsl = data.as_fixed_size_list_opt().ok_or(Error::index(format!(
164            "PQ builder: input is not a FixedSizeList: {}",
165            data.data_type()
166        )))?;
167
168        let num_centroids = 2_usize.pow(self.num_bits as u32);
169        if data.len() < num_centroids {
170            return Err(Error::unprocessable(format!(
171                "Not enough rows to train PQ. Requires {num_centroids} rows but only {} available",
172                data.len()
173            )));
174        }
175
176        // TODO: support bf16 later.
177        match fsl.value_type() {
178            DataType::Float16 => self.build_from_fsl::<Float16Type>(fsl, distance_type),
179            DataType::Float32 => self.build_from_fsl::<Float32Type>(fsl, distance_type),
180            DataType::Float64 => self.build_from_fsl::<Float64Type>(fsl, distance_type),
181            _ => Err(Error::index(format!(
182                "PQ builder: unsupported data type: {}",
183                fsl.value_type()
184            ))),
185        }
186    }
187}