lance_index/vector/pq/
builder.rs1use std::sync::Arc;
8
9use crate::vector::quantizer::QuantizerBuildParams;
10use arrow::array::PrimitiveBuilder;
11use arrow_array::types::{Float16Type, Float64Type};
12use arrow_array::{Array, ArrayRef, cast::AsArray, types::Float32Type};
13use arrow_array::{ArrowNumericType, FixedSizeListArray, PrimitiveArray};
14use arrow_schema::DataType;
15use lance_arrow::FixedSizeListArrayExt;
16use lance_core::{Error, Result};
17use lance_linalg::distance::DistanceType;
18use lance_linalg::distance::{Dot, L2, Normalize};
19
20use super::ProductQuantizer;
21use super::utils::divide_to_subvectors;
22use crate::vector::kmeans::{KMeansParams, train_kmeans};
23
24#[derive(Debug, Clone)]
26pub struct PQBuildParams {
27 pub num_sub_vectors: usize,
29
30 pub num_bits: usize,
32
33 pub max_iters: usize,
35
36 pub kmeans_redos: usize,
39
40 pub codebook: Option<ArrayRef>,
42
43 pub sample_rate: usize,
45}
46
47impl From<&PQBuildParams> for crate::pb::vector_index_details::ProductQuantization {
48 fn from(params: &PQBuildParams) -> Self {
49 Self {
50 num_bits: params.num_bits as u32,
51 num_sub_vectors: params.num_sub_vectors as u32,
52 }
53 }
54}
55
56impl Default for PQBuildParams {
57 fn default() -> Self {
58 Self {
59 num_sub_vectors: 16,
60 num_bits: 8,
61 max_iters: 50,
62 kmeans_redos: 1,
63 codebook: None,
64 sample_rate: 256,
65 }
66 }
67}
68
69impl QuantizerBuildParams for PQBuildParams {
70 fn sample_size(&self) -> usize {
71 self.sample_rate * 2_usize.pow(self.num_bits as u32)
72 }
73
74 fn use_residual(distance_type: DistanceType) -> bool {
75 matches!(distance_type, DistanceType::L2 | DistanceType::Cosine)
76 }
77}
78
79impl PQBuildParams {
80 pub fn new(num_sub_vectors: usize, num_bits: usize) -> Self {
81 Self {
82 num_sub_vectors,
83 num_bits,
84 ..Default::default()
85 }
86 }
87
88 pub fn with_codebook(num_sub_vectors: usize, num_bits: usize, codebook: ArrayRef) -> Self {
89 Self {
90 num_sub_vectors,
91 num_bits,
92 codebook: Some(codebook),
93 ..Default::default()
94 }
95 }
96
97 fn build_from_fsl<T: ArrowNumericType>(
98 &self,
99 data: &FixedSizeListArray,
100 distance_type: DistanceType,
101 ) -> Result<ProductQuantizer>
102 where
103 T::Native: Dot + L2 + Normalize,
104 PrimitiveArray<T>: From<Vec<T::Native>>,
105 {
106 assert_ne!(
107 distance_type,
108 DistanceType::Cosine,
109 "PQ code does not support cosine"
110 );
111
112 let sub_vectors = divide_to_subvectors::<T>(data, self.num_sub_vectors)?;
113 let num_centroids = 2_usize.pow(self.num_bits as u32);
114 let dimension = data.value_length() as usize;
115 let sub_vector_dimension = dimension / self.num_sub_vectors;
116
117 let d = sub_vectors
118 .into_iter()
119 .enumerate()
120 .map(|(sub_vec_idx, sub_vec)| {
121 let params = KMeansParams::new(
122 self.codebook.as_ref().map(|cb| {
123 let sub_vec_centroids = FixedSizeListArray::try_new_from_values(
124 cb.as_fixed_size_list().values().as_primitive::<T>().slice(
125 sub_vec_idx * num_centroids * sub_vector_dimension,
126 num_centroids * sub_vector_dimension,
127 ),
128 sub_vector_dimension as i32,
129 )
130 .unwrap();
131 Arc::new(sub_vec_centroids)
132 }),
133 self.max_iters as u32,
134 self.kmeans_redos,
135 distance_type,
136 );
137 train_kmeans::<T>(
138 &sub_vec,
139 params,
140 sub_vector_dimension,
141 num_centroids,
142 self.sample_rate,
143 )
144 .map(|kmeans| kmeans.centroids)
145 })
146 .collect::<Result<Vec<_>>>()?;
147 let mut codebook_builder = PrimitiveBuilder::<T>::with_capacity(num_centroids * dimension);
148 for centroid in d.iter() {
149 let c = centroid
150 .as_any()
151 .downcast_ref::<PrimitiveArray<T>>()
152 .expect("failed to downcast to PrimitiveArray");
153 codebook_builder.append_slice(c.values());
154 }
155
156 let pd_centroids = codebook_builder.finish();
157
158 Ok(ProductQuantizer::new(
159 self.num_sub_vectors,
160 self.num_bits as u32,
161 dimension,
162 FixedSizeListArray::try_new_from_values(pd_centroids, dimension as i32)?,
163 distance_type,
164 ))
165 }
166
167 pub fn build(&self, data: &dyn Array, distance_type: DistanceType) -> Result<ProductQuantizer> {
171 assert_eq!(data.null_count(), 0);
172 let fsl = data.as_fixed_size_list_opt().ok_or(Error::index(format!(
173 "PQ builder: input is not a FixedSizeList: {}",
174 data.data_type()
175 )))?;
176
177 let num_centroids = 2_usize.pow(self.num_bits as u32);
178 if data.len() < num_centroids {
179 return Err(Error::unprocessable(format!(
180 "Not enough rows to train PQ. Requires {num_centroids} rows but only {} available",
181 data.len()
182 )));
183 }
184
185 match fsl.value_type() {
187 DataType::Float16 => self.build_from_fsl::<Float16Type>(fsl, distance_type),
188 DataType::Float32 => self.build_from_fsl::<Float32Type>(fsl, distance_type),
189 DataType::Float64 => self.build_from_fsl::<Float64Type>(fsl, distance_type),
190 _ => Err(Error::index(format!(
191 "PQ builder: unsupported data type: {}",
192 fsl.value_type()
193 ))),
194 }
195 }
196}