Skip to main content

lance_index/vector/pq/
builder.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4//! Product Quantizer Builder
5//!
6
7use std::sync::Arc;
8
9use crate::vector::quantizer::QuantizerBuildParams;
10use arrow::array::PrimitiveBuilder;
11use arrow_array::types::{Float16Type, Float64Type};
12use arrow_array::{Array, ArrayRef, cast::AsArray, types::Float32Type};
13use arrow_array::{ArrowNumericType, FixedSizeListArray, PrimitiveArray};
14use arrow_schema::DataType;
15use lance_arrow::FixedSizeListArrayExt;
16use lance_core::{Error, Result};
17use lance_linalg::distance::DistanceType;
18use lance_linalg::distance::{Dot, L2, Normalize};
19
20use super::ProductQuantizer;
21use super::utils::divide_to_subvectors;
22use crate::vector::kmeans::{KMeansParams, train_kmeans};
23
24/// Parameters for building product quantizer.
25#[derive(Debug, Clone)]
26pub struct PQBuildParams {
27    /// Number of sub-vectors to build PQ code
28    pub num_sub_vectors: usize,
29
30    /// The number of bits to present one PQ centroid.
31    pub num_bits: usize,
32
33    /// The max number of iterations for kmeans training.
34    pub max_iters: usize,
35
36    /// Run kmeans `REDOS` times and take the best result.
37    /// Default to 1.
38    pub kmeans_redos: usize,
39
40    /// User provided codebook.
41    pub codebook: Option<ArrayRef>,
42
43    /// Sample rate to train PQ codebook.
44    pub sample_rate: usize,
45}
46
47impl From<&PQBuildParams> for crate::pb::vector_index_details::ProductQuantization {
48    fn from(params: &PQBuildParams) -> Self {
49        Self {
50            num_bits: params.num_bits as u32,
51            num_sub_vectors: params.num_sub_vectors as u32,
52        }
53    }
54}
55
56impl Default for PQBuildParams {
57    fn default() -> Self {
58        Self {
59            num_sub_vectors: 16,
60            num_bits: 8,
61            max_iters: 50,
62            kmeans_redos: 1,
63            codebook: None,
64            sample_rate: 256,
65        }
66    }
67}
68
69impl QuantizerBuildParams for PQBuildParams {
70    fn sample_size(&self) -> usize {
71        self.sample_rate * 2_usize.pow(self.num_bits as u32)
72    }
73
74    fn use_residual(distance_type: DistanceType) -> bool {
75        matches!(distance_type, DistanceType::L2 | DistanceType::Cosine)
76    }
77}
78
79impl PQBuildParams {
80    pub fn new(num_sub_vectors: usize, num_bits: usize) -> Self {
81        Self {
82            num_sub_vectors,
83            num_bits,
84            ..Default::default()
85        }
86    }
87
88    pub fn with_codebook(num_sub_vectors: usize, num_bits: usize, codebook: ArrayRef) -> Self {
89        Self {
90            num_sub_vectors,
91            num_bits,
92            codebook: Some(codebook),
93            ..Default::default()
94        }
95    }
96
97    fn build_from_fsl<T: ArrowNumericType>(
98        &self,
99        data: &FixedSizeListArray,
100        distance_type: DistanceType,
101    ) -> Result<ProductQuantizer>
102    where
103        T::Native: Dot + L2 + Normalize,
104        PrimitiveArray<T>: From<Vec<T::Native>>,
105    {
106        assert_ne!(
107            distance_type,
108            DistanceType::Cosine,
109            "PQ code does not support cosine"
110        );
111
112        let sub_vectors = divide_to_subvectors::<T>(data, self.num_sub_vectors)?;
113        let num_centroids = 2_usize.pow(self.num_bits as u32);
114        let dimension = data.value_length() as usize;
115        let sub_vector_dimension = dimension / self.num_sub_vectors;
116
117        let d = sub_vectors
118            .into_iter()
119            .enumerate()
120            .map(|(sub_vec_idx, sub_vec)| {
121                let params = KMeansParams::new(
122                    self.codebook.as_ref().map(|cb| {
123                        let sub_vec_centroids = FixedSizeListArray::try_new_from_values(
124                            cb.as_fixed_size_list().values().as_primitive::<T>().slice(
125                                sub_vec_idx * num_centroids * sub_vector_dimension,
126                                num_centroids * sub_vector_dimension,
127                            ),
128                            sub_vector_dimension as i32,
129                        )
130                        .unwrap();
131                        Arc::new(sub_vec_centroids)
132                    }),
133                    self.max_iters as u32,
134                    self.kmeans_redos,
135                    distance_type,
136                );
137                train_kmeans::<T>(
138                    &sub_vec,
139                    params,
140                    sub_vector_dimension,
141                    num_centroids,
142                    self.sample_rate,
143                )
144                .map(|kmeans| kmeans.centroids)
145            })
146            .collect::<Result<Vec<_>>>()?;
147        let mut codebook_builder = PrimitiveBuilder::<T>::with_capacity(num_centroids * dimension);
148        for centroid in d.iter() {
149            let c = centroid
150                .as_any()
151                .downcast_ref::<PrimitiveArray<T>>()
152                .expect("failed to downcast to PrimitiveArray");
153            codebook_builder.append_slice(c.values());
154        }
155
156        let pd_centroids = codebook_builder.finish();
157
158        Ok(ProductQuantizer::new(
159            self.num_sub_vectors,
160            self.num_bits as u32,
161            dimension,
162            FixedSizeListArray::try_new_from_values(pd_centroids, dimension as i32)?,
163            distance_type,
164        ))
165    }
166
167    /// Build a [ProductQuantizer] from the given data.
168    ///
169    /// If the [`DistanceType`] is [`DistanceType::Cosine`], the input data will be normalized.
170    pub fn build(&self, data: &dyn Array, distance_type: DistanceType) -> Result<ProductQuantizer> {
171        assert_eq!(data.null_count(), 0);
172        let fsl = data.as_fixed_size_list_opt().ok_or(Error::index(format!(
173            "PQ builder: input is not a FixedSizeList: {}",
174            data.data_type()
175        )))?;
176
177        let num_centroids = 2_usize.pow(self.num_bits as u32);
178        if data.len() < num_centroids {
179            return Err(Error::unprocessable(format!(
180                "Not enough rows to train PQ. Requires {num_centroids} rows but only {} available",
181                data.len()
182            )));
183        }
184
185        // TODO: support bf16 later.
186        match fsl.value_type() {
187            DataType::Float16 => self.build_from_fsl::<Float16Type>(fsl, distance_type),
188            DataType::Float32 => self.build_from_fsl::<Float32Type>(fsl, distance_type),
189            DataType::Float64 => self.build_from_fsl::<Float64Type>(fsl, distance_type),
190            _ => Err(Error::index(format!(
191                "PQ builder: unsupported data type: {}",
192                fsl.value_type()
193            ))),
194        }
195    }
196}