lance_index/vector/pq/
builder.rs1use std::sync::Arc;
8
9use crate::vector::quantizer::QuantizerBuildParams;
10use arrow::array::PrimitiveBuilder;
11use arrow_array::types::{Float16Type, Float64Type};
12use arrow_array::{cast::AsArray, types::Float32Type, Array, ArrayRef};
13use arrow_array::{ArrowNumericType, FixedSizeListArray, PrimitiveArray};
14use arrow_schema::DataType;
15use lance_arrow::FixedSizeListArrayExt;
16use lance_core::{Error, Result};
17use lance_linalg::distance::DistanceType;
18use lance_linalg::distance::{Dot, Normalize, L2};
19use snafu::location;
20
21use super::utils::divide_to_subvectors;
22use super::ProductQuantizer;
23use crate::vector::kmeans::{train_kmeans, KMeansParams};
24
25#[derive(Debug, Clone)]
27pub struct PQBuildParams {
28 pub num_sub_vectors: usize,
30
31 pub num_bits: usize,
33
34 pub max_iters: usize,
36
37 pub kmeans_redos: usize,
40
41 pub codebook: Option<ArrayRef>,
43
44 pub sample_rate: usize,
46}
47
48impl Default for PQBuildParams {
49 fn default() -> Self {
50 Self {
51 num_sub_vectors: 16,
52 num_bits: 8,
53 max_iters: 50,
54 kmeans_redos: 1,
55 codebook: None,
56 sample_rate: 256,
57 }
58 }
59}
60
61impl QuantizerBuildParams for PQBuildParams {
62 fn sample_size(&self) -> usize {
63 self.sample_rate * 2_usize.pow(self.num_bits as u32)
64 }
65
66 fn use_residual(distance_type: DistanceType) -> bool {
67 matches!(distance_type, DistanceType::L2 | DistanceType::Cosine)
68 }
69}
70
71impl PQBuildParams {
72 pub fn new(num_sub_vectors: usize, num_bits: usize) -> Self {
73 Self {
74 num_sub_vectors,
75 num_bits,
76 ..Default::default()
77 }
78 }
79
80 pub fn with_codebook(num_sub_vectors: usize, num_bits: usize, codebook: ArrayRef) -> Self {
81 Self {
82 num_sub_vectors,
83 num_bits,
84 codebook: Some(codebook),
85 ..Default::default()
86 }
87 }
88
89 fn build_from_fsl<T: ArrowNumericType>(
90 &self,
91 data: &FixedSizeListArray,
92 distance_type: DistanceType,
93 ) -> Result<ProductQuantizer>
94 where
95 T::Native: Dot + L2 + Normalize,
96 PrimitiveArray<T>: From<Vec<T::Native>>,
97 {
98 assert_ne!(
99 distance_type,
100 DistanceType::Cosine,
101 "PQ code does not support cosine"
102 );
103
104 let sub_vectors = divide_to_subvectors::<T>(data, self.num_sub_vectors)?;
105 let num_centroids = 2_usize.pow(self.num_bits as u32);
106 let dimension = data.value_length() as usize;
107 let sub_vector_dimension = dimension / self.num_sub_vectors;
108
109 let d = sub_vectors
110 .into_iter()
111 .enumerate()
112 .map(|(sub_vec_idx, sub_vec)| {
113 let params = KMeansParams::new(
114 self.codebook.as_ref().map(|cb| {
115 let sub_vec_centroids = FixedSizeListArray::try_new_from_values(
116 cb.as_fixed_size_list().values().as_primitive::<T>().slice(
117 sub_vec_idx * num_centroids * sub_vector_dimension,
118 num_centroids * sub_vector_dimension,
119 ),
120 sub_vector_dimension as i32,
121 )
122 .unwrap();
123 Arc::new(sub_vec_centroids)
124 }),
125 self.max_iters as u32,
126 self.kmeans_redos,
127 distance_type,
128 );
129 train_kmeans::<T>(
130 &sub_vec,
131 params,
132 sub_vector_dimension,
133 num_centroids,
134 self.sample_rate,
135 )
136 .map(|kmeans| kmeans.centroids)
137 })
138 .collect::<Result<Vec<_>>>()?;
139 let mut codebook_builder = PrimitiveBuilder::<T>::with_capacity(num_centroids * dimension);
140 for centroid in d.iter() {
141 let c = centroid
142 .as_any()
143 .downcast_ref::<PrimitiveArray<T>>()
144 .expect("failed to downcast to PrimitiveArray");
145 codebook_builder.append_slice(c.values());
146 }
147
148 let pd_centroids = codebook_builder.finish();
149
150 Ok(ProductQuantizer::new(
151 self.num_sub_vectors,
152 self.num_bits as u32,
153 dimension,
154 FixedSizeListArray::try_new_from_values(pd_centroids, dimension as i32)?,
155 distance_type,
156 ))
157 }
158
159 pub fn build(&self, data: &dyn Array, distance_type: DistanceType) -> Result<ProductQuantizer> {
163 assert_eq!(data.null_count(), 0);
164 let fsl = data.as_fixed_size_list_opt().ok_or(Error::Index {
165 message: format!(
166 "PQ builder: input is not a FixedSizeList: {}",
167 data.data_type()
168 ),
169 location: location!(),
170 })?;
171
172 let num_centroids = 2_usize.pow(self.num_bits as u32);
173 if data.len() < num_centroids {
174 return Err(Error::Unprocessable {
175 message: format!(
176 "Not enough rows to train PQ. Requires {num_centroids} rows but only {} available",
177 data.len()
178 ),
179 location: location!(),
180 });
181 }
182
183 match fsl.value_type() {
185 DataType::Float16 => self.build_from_fsl::<Float16Type>(fsl, distance_type),
186 DataType::Float32 => self.build_from_fsl::<Float32Type>(fsl, distance_type),
187 DataType::Float64 => self.build_from_fsl::<Float64Type>(fsl, distance_type),
188 _ => Err(Error::Index {
189 message: format!("PQ builder: unsupported data type: {}", fsl.value_type()),
190 location: location!(),
191 }),
192 }
193 }
194}