lance_index/vector/pq/builder.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4//! Product Quantizer Builder
5//!
6
7use std::sync::Arc;
8
9use crate::vector::quantizer::QuantizerBuildParams;
10use arrow::array::PrimitiveBuilder;
11use arrow_array::types::{Float16Type, Float64Type};
12use arrow_array::{cast::AsArray, types::Float32Type, Array, ArrayRef};
13use arrow_array::{ArrowNumericType, FixedSizeListArray, PrimitiveArray};
14use arrow_schema::DataType;
15use lance_arrow::FixedSizeListArrayExt;
16use lance_core::{Error, Result};
17use lance_linalg::distance::DistanceType;
18use lance_linalg::distance::{Dot, Normalize, L2};
19use snafu::location;
20
21use super::utils::divide_to_subvectors;
22use super::ProductQuantizer;
23use crate::vector::kmeans::{train_kmeans, KMeansParams};
24
25/// Parameters for building product quantizer.
26#[derive(Debug, Clone)]
27pub struct PQBuildParams {
28    /// Number of sub-vectors to build PQ code
29    pub num_sub_vectors: usize,
30
31    /// The number of bits to present one PQ centroid.
32    pub num_bits: usize,
33
34    /// The max number of iterations for kmeans training.
35    pub max_iters: usize,
36
37    /// Run kmeans `REDOS` times and take the best result.
38    /// Default to 1.
39    pub kmeans_redos: usize,
40
41    /// User provided codebook.
42    pub codebook: Option<ArrayRef>,
43
44    /// Sample rate to train PQ codebook.
45    pub sample_rate: usize,
46}
47
48impl Default for PQBuildParams {
49    fn default() -> Self {
50        Self {
51            num_sub_vectors: 16,
52            num_bits: 8,
53            max_iters: 50,
54            kmeans_redos: 1,
55            codebook: None,
56            sample_rate: 256,
57        }
58    }
59}
60
61impl QuantizerBuildParams for PQBuildParams {
62    fn sample_size(&self) -> usize {
63        self.sample_rate * 2_usize.pow(self.num_bits as u32)
64    }
65
66    fn use_residual(distance_type: DistanceType) -> bool {
67        matches!(distance_type, DistanceType::L2 | DistanceType::Cosine)
68    }
69}
70
71impl PQBuildParams {
72    pub fn new(num_sub_vectors: usize, num_bits: usize) -> Self {
73        Self {
74            num_sub_vectors,
75            num_bits,
76            ..Default::default()
77        }
78    }
79
80    pub fn with_codebook(num_sub_vectors: usize, num_bits: usize, codebook: ArrayRef) -> Self {
81        Self {
82            num_sub_vectors,
83            num_bits,
84            codebook: Some(codebook),
85            ..Default::default()
86        }
87    }
88
89    fn build_from_fsl<T: ArrowNumericType>(
90        &self,
91        data: &FixedSizeListArray,
92        distance_type: DistanceType,
93    ) -> Result<ProductQuantizer>
94    where
95        T::Native: Dot + L2 + Normalize,
96        PrimitiveArray<T>: From<Vec<T::Native>>,
97    {
98        assert_ne!(
99            distance_type,
100            DistanceType::Cosine,
101            "PQ code does not support cosine"
102        );
103
104        let sub_vectors = divide_to_subvectors::<T>(data, self.num_sub_vectors)?;
105        let num_centroids = 2_usize.pow(self.num_bits as u32);
106        let dimension = data.value_length() as usize;
107        let sub_vector_dimension = dimension / self.num_sub_vectors;
108
109        let d = sub_vectors
110            .into_iter()
111            .enumerate()
112            .map(|(sub_vec_idx, sub_vec)| {
113                let params = KMeansParams::new(
114                    self.codebook.as_ref().map(|cb| {
115                        let sub_vec_centroids = FixedSizeListArray::try_new_from_values(
116                            cb.as_fixed_size_list().values().as_primitive::<T>().slice(
117                                sub_vec_idx * num_centroids * sub_vector_dimension,
118                                num_centroids * sub_vector_dimension,
119                            ),
120                            sub_vector_dimension as i32,
121                        )
122                        .unwrap();
123                        Arc::new(sub_vec_centroids)
124                    }),
125                    self.max_iters as u32,
126                    self.kmeans_redos,
127                    distance_type,
128                );
129                train_kmeans::<T>(
130                    &sub_vec,
131                    params,
132                    sub_vector_dimension,
133                    num_centroids,
134                    self.sample_rate,
135                )
136                .map(|kmeans| kmeans.centroids)
137            })
138            .collect::<Result<Vec<_>>>()?;
139        let mut codebook_builder = PrimitiveBuilder::<T>::with_capacity(num_centroids * dimension);
140        for centroid in d.iter() {
141            let c = centroid
142                .as_any()
143                .downcast_ref::<PrimitiveArray<T>>()
144                .expect("failed to downcast to PrimitiveArray");
145            codebook_builder.append_slice(c.values());
146        }
147
148        let pd_centroids = codebook_builder.finish();
149
150        Ok(ProductQuantizer::new(
151            self.num_sub_vectors,
152            self.num_bits as u32,
153            dimension,
154            FixedSizeListArray::try_new_from_values(pd_centroids, dimension as i32)?,
155            distance_type,
156        ))
157    }
158
159    /// Build a [ProductQuantizer] from the given data.
160    ///
161    /// If the [MetricType] is [MetricType::Cosine], the input data will be normalized.
162    pub fn build(&self, data: &dyn Array, distance_type: DistanceType) -> Result<ProductQuantizer> {
163        assert_eq!(data.null_count(), 0);
164        let fsl = data.as_fixed_size_list_opt().ok_or(Error::Index {
165            message: format!(
166                "PQ builder: input is not a FixedSizeList: {}",
167                data.data_type()
168            ),
169            location: location!(),
170        })?;
171
172        let num_centroids = 2_usize.pow(self.num_bits as u32);
173        if data.len() < num_centroids {
174            return Err(Error::Unprocessable {
175                message: format!(
176                    "Not enough rows to train PQ. Requires {num_centroids} rows but only {} available",
177                    data.len()
178                ),
179                location: location!(),
180            });
181        }
182
183        // TODO: support bf16 later.
184        match fsl.value_type() {
185            DataType::Float16 => self.build_from_fsl::<Float16Type>(fsl, distance_type),
186            DataType::Float32 => self.build_from_fsl::<Float32Type>(fsl, distance_type),
187            DataType::Float64 => self.build_from_fsl::<Float64Type>(fsl, distance_type),
188            _ => Err(Error::Index {
189                message: format!("PQ builder: unsupported data type: {}", fsl.value_type()),
190                location: location!(),
191            }),
192        }
193    }
194}