lance_index/vector/pq/
builder.rs1use std::sync::Arc;
8
9use crate::vector::quantizer::QuantizerBuildParams;
10use arrow::array::PrimitiveBuilder;
11use arrow_array::types::{Float16Type, Float64Type};
12use arrow_array::{Array, ArrayRef, cast::AsArray, types::Float32Type};
13use arrow_array::{ArrowNumericType, FixedSizeListArray, PrimitiveArray};
14use arrow_schema::DataType;
15use lance_arrow::FixedSizeListArrayExt;
16use lance_core::{Error, Result};
17use lance_linalg::distance::DistanceType;
18use lance_linalg::distance::{Dot, L2, Normalize};
19
20use super::ProductQuantizer;
21use super::utils::divide_to_subvectors;
22use crate::vector::kmeans::{KMeansParams, train_kmeans};
23
24#[derive(Debug, Clone)]
26pub struct PQBuildParams {
27 pub num_sub_vectors: usize,
29
30 pub num_bits: usize,
32
33 pub max_iters: usize,
35
36 pub kmeans_redos: usize,
39
40 pub codebook: Option<ArrayRef>,
42
43 pub sample_rate: usize,
45}
46
47impl Default for PQBuildParams {
48 fn default() -> Self {
49 Self {
50 num_sub_vectors: 16,
51 num_bits: 8,
52 max_iters: 50,
53 kmeans_redos: 1,
54 codebook: None,
55 sample_rate: 256,
56 }
57 }
58}
59
60impl QuantizerBuildParams for PQBuildParams {
61 fn sample_size(&self) -> usize {
62 self.sample_rate * 2_usize.pow(self.num_bits as u32)
63 }
64
65 fn use_residual(distance_type: DistanceType) -> bool {
66 matches!(distance_type, DistanceType::L2 | DistanceType::Cosine)
67 }
68}
69
70impl PQBuildParams {
71 pub fn new(num_sub_vectors: usize, num_bits: usize) -> Self {
72 Self {
73 num_sub_vectors,
74 num_bits,
75 ..Default::default()
76 }
77 }
78
79 pub fn with_codebook(num_sub_vectors: usize, num_bits: usize, codebook: ArrayRef) -> Self {
80 Self {
81 num_sub_vectors,
82 num_bits,
83 codebook: Some(codebook),
84 ..Default::default()
85 }
86 }
87
88 fn build_from_fsl<T: ArrowNumericType>(
89 &self,
90 data: &FixedSizeListArray,
91 distance_type: DistanceType,
92 ) -> Result<ProductQuantizer>
93 where
94 T::Native: Dot + L2 + Normalize,
95 PrimitiveArray<T>: From<Vec<T::Native>>,
96 {
97 assert_ne!(
98 distance_type,
99 DistanceType::Cosine,
100 "PQ code does not support cosine"
101 );
102
103 let sub_vectors = divide_to_subvectors::<T>(data, self.num_sub_vectors)?;
104 let num_centroids = 2_usize.pow(self.num_bits as u32);
105 let dimension = data.value_length() as usize;
106 let sub_vector_dimension = dimension / self.num_sub_vectors;
107
108 let d = sub_vectors
109 .into_iter()
110 .enumerate()
111 .map(|(sub_vec_idx, sub_vec)| {
112 let params = KMeansParams::new(
113 self.codebook.as_ref().map(|cb| {
114 let sub_vec_centroids = FixedSizeListArray::try_new_from_values(
115 cb.as_fixed_size_list().values().as_primitive::<T>().slice(
116 sub_vec_idx * num_centroids * sub_vector_dimension,
117 num_centroids * sub_vector_dimension,
118 ),
119 sub_vector_dimension as i32,
120 )
121 .unwrap();
122 Arc::new(sub_vec_centroids)
123 }),
124 self.max_iters as u32,
125 self.kmeans_redos,
126 distance_type,
127 );
128 train_kmeans::<T>(
129 &sub_vec,
130 params,
131 sub_vector_dimension,
132 num_centroids,
133 self.sample_rate,
134 )
135 .map(|kmeans| kmeans.centroids)
136 })
137 .collect::<Result<Vec<_>>>()?;
138 let mut codebook_builder = PrimitiveBuilder::<T>::with_capacity(num_centroids * dimension);
139 for centroid in d.iter() {
140 let c = centroid
141 .as_any()
142 .downcast_ref::<PrimitiveArray<T>>()
143 .expect("failed to downcast to PrimitiveArray");
144 codebook_builder.append_slice(c.values());
145 }
146
147 let pd_centroids = codebook_builder.finish();
148
149 Ok(ProductQuantizer::new(
150 self.num_sub_vectors,
151 self.num_bits as u32,
152 dimension,
153 FixedSizeListArray::try_new_from_values(pd_centroids, dimension as i32)?,
154 distance_type,
155 ))
156 }
157
158 pub fn build(&self, data: &dyn Array, distance_type: DistanceType) -> Result<ProductQuantizer> {
162 assert_eq!(data.null_count(), 0);
163 let fsl = data.as_fixed_size_list_opt().ok_or(Error::index(format!(
164 "PQ builder: input is not a FixedSizeList: {}",
165 data.data_type()
166 )))?;
167
168 let num_centroids = 2_usize.pow(self.num_bits as u32);
169 if data.len() < num_centroids {
170 return Err(Error::unprocessable(format!(
171 "Not enough rows to train PQ. Requires {num_centroids} rows but only {} available",
172 data.len()
173 )));
174 }
175
176 match fsl.value_type() {
178 DataType::Float16 => self.build_from_fsl::<Float16Type>(fsl, distance_type),
179 DataType::Float32 => self.build_from_fsl::<Float32Type>(fsl, distance_type),
180 DataType::Float64 => self.build_from_fsl::<Float64Type>(fsl, distance_type),
181 _ => Err(Error::index(format!(
182 "PQ builder: unsupported data type: {}",
183 fsl.value_type()
184 ))),
185 }
186 }
187}