1use std::ops::Range;
7use std::sync::Arc;
8
9use arrow_array::{Array, FixedSizeListArray, Float32Array, RecordBatch, UInt32Array};
10
11pub use builder::IvfBuildParams;
12use lance_core::Result;
13use lance_linalg::distance::{DistanceType, MetricType};
14use tracing::instrument;
15
16use crate::vector::bq::builder::RabitQuantizer;
17use crate::vector::bq::transform::RQTransformer;
18use crate::vector::ivf::transform::PartitionTransformer;
19use crate::vector::kmeans::{compute_partitions_arrow_array, kmeans_find_partitions_arrow_array};
20use crate::vector::{pq::ProductQuantizer, transform::Transformer};
21
22use super::flat::transform::FlatTransformer;
23use super::pq::transform::PQTransformer;
24use super::quantizer::Quantization;
25use super::residual::ResidualTransform;
26use super::sq::transform::SQTransformer;
27use super::sq::ScalarQuantizer;
28use super::transform::KeepFiniteVectors;
29use super::{quantizer::Quantizer, residual::compute_residual};
30use super::{PART_ID_COLUMN, PQ_CODE_COLUMN, SQ_CODE_COLUMN};
31
32pub mod builder;
33pub mod shuffler;
34pub mod storage;
35mod transform;
36
37pub fn new_ivf_transformer(
47 centroids: FixedSizeListArray,
48 metric_type: DistanceType,
49 transforms: Vec<Arc<dyn Transformer>>,
50) -> IvfTransformer {
51 IvfTransformer::new(centroids, metric_type, transforms)
52}
53
54pub fn new_ivf_transformer_with_quantizer(
55 centroids: FixedSizeListArray,
56 metric_type: MetricType,
57 vector_column: &str,
58 quantizer: Quantizer,
59 range: Option<Range<u32>>,
60) -> Result<IvfTransformer> {
61 match quantizer {
62 Quantizer::Flat(_) | Quantizer::FlatBin(_) => Ok(IvfTransformer::new_flat(
63 centroids,
64 metric_type,
65 vector_column,
66 range,
67 )),
68 Quantizer::Product(pq) => Ok(IvfTransformer::with_pq(
69 centroids,
70 metric_type,
71 vector_column,
72 pq,
73 range,
74 )),
75 Quantizer::Scalar(sq) => Ok(IvfTransformer::with_sq(
76 centroids,
77 metric_type,
78 vector_column,
79 sq,
80 range,
81 )),
82 Quantizer::Rabit(rq) => Ok(IvfTransformer::with_rq(
83 centroids,
84 metric_type,
85 vector_column,
86 rq,
87 range,
88 )),
89 }
90}
91
92#[derive(Debug)]
95pub struct IvfTransformer {
96 centroids: FixedSizeListArray,
100
101 transforms: Vec<Arc<dyn Transformer>>,
103
104 distance_type: DistanceType,
106}
107
108impl IvfTransformer {
109 pub fn new(
111 centroids: FixedSizeListArray,
112 metric_type: MetricType,
113 transforms: Vec<Arc<dyn Transformer>>,
114 ) -> Self {
115 Self {
116 centroids,
117 distance_type: metric_type,
118 transforms,
119 }
120 }
121
122 pub fn new_partition_transformer(
123 centroids: FixedSizeListArray,
124 distance_type: DistanceType,
125 vector_column: &str,
126 ) -> Self {
127 let mut transforms: Vec<Arc<dyn Transformer>> =
128 vec![Arc::new(super::transform::Flatten::new(vector_column))];
129
130 let distance_type = if distance_type == MetricType::Cosine {
131 transforms.push(Arc::new(super::transform::NormalizeTransformer::new(
132 vector_column,
133 )));
134 MetricType::L2
135 } else {
136 distance_type
137 };
138 transforms.push(Arc::new(KeepFiniteVectors::new(vector_column)));
139
140 let partition_transform = Arc::new(PartitionTransformer::new(
141 centroids.clone(),
142 distance_type,
143 vector_column,
144 ));
145 transforms.push(partition_transform);
146 Self::new(centroids, distance_type, transforms)
147 }
148
149 pub fn new_flat(
150 centroids: FixedSizeListArray,
151 distance_type: DistanceType,
152 vector_column: &str,
153 range: Option<Range<u32>>,
154 ) -> Self {
155 let mut transforms: Vec<Arc<dyn Transformer>> =
156 vec![Arc::new(super::transform::Flatten::new(vector_column))];
157
158 let dt = if distance_type == DistanceType::Cosine {
159 transforms.push(Arc::new(super::transform::NormalizeTransformer::new(
160 vector_column,
161 )));
162 MetricType::L2
163 } else {
164 distance_type
165 };
166 transforms.push(Arc::new(KeepFiniteVectors::new(vector_column)));
167
168 let ivf_transform = Arc::new(PartitionTransformer::new(
169 centroids.clone(),
170 dt,
171 vector_column,
172 ));
173 transforms.push(ivf_transform);
174
175 if let Some(range) = range {
176 transforms.push(Arc::new(transform::PartitionFilter::new(
177 PART_ID_COLUMN,
178 range,
179 )));
180 }
181
182 transforms.push(Arc::new(FlatTransformer::new(vector_column)));
183
184 Self::new(centroids, distance_type, transforms)
185 }
186
187 pub fn with_pq(
189 centroids: FixedSizeListArray,
190 distance_type: DistanceType,
191 vector_column: &str,
192 pq: ProductQuantizer,
193 range: Option<Range<u32>>,
194 ) -> Self {
195 let mut transforms: Vec<Arc<dyn Transformer>> =
196 vec![Arc::new(super::transform::Flatten::new(vector_column))];
197
198 let distance_type = if distance_type == MetricType::Cosine {
199 transforms.push(Arc::new(super::transform::NormalizeTransformer::new(
200 vector_column,
201 )));
202 MetricType::L2
203 } else {
204 distance_type
205 };
206 transforms.push(Arc::new(KeepFiniteVectors::new(vector_column)));
207
208 let partition_transform = Arc::new(PartitionTransformer::new(
209 centroids.clone(),
210 distance_type,
211 vector_column,
212 ));
213 transforms.push(partition_transform);
214
215 if let Some(range) = range {
216 transforms.push(Arc::new(transform::PartitionFilter::new(
217 PART_ID_COLUMN,
218 range,
219 )));
220 }
221
222 if ProductQuantizer::use_residual(distance_type) {
223 transforms.push(Arc::new(ResidualTransform::new(
224 centroids.clone(),
225 PART_ID_COLUMN,
226 vector_column,
227 )));
228 }
229 transforms.push(Arc::new(PQTransformer::new(
230 pq,
231 vector_column,
232 PQ_CODE_COLUMN,
233 )));
234
235 Self::new(centroids, distance_type, transforms)
236 }
237
238 fn with_sq(
239 centroids: FixedSizeListArray,
240 metric_type: MetricType,
241 vector_column: &str,
242 sq: ScalarQuantizer,
243 range: Option<Range<u32>>,
244 ) -> Self {
245 let mut transforms: Vec<Arc<dyn Transformer>> =
246 vec![Arc::new(super::transform::Flatten::new(vector_column))];
247
248 let distance_type = if metric_type == MetricType::Cosine {
249 transforms.push(Arc::new(super::transform::NormalizeTransformer::new(
250 vector_column,
251 )));
252 MetricType::L2
253 } else {
254 metric_type
255 };
256 transforms.push(Arc::new(KeepFiniteVectors::new(vector_column)));
257
258 let partition_transformer = Arc::new(PartitionTransformer::new(
259 centroids.clone(),
260 distance_type,
261 vector_column,
262 ));
263 transforms.push(partition_transformer);
264
265 if let Some(range) = range {
266 transforms.push(Arc::new(transform::PartitionFilter::new(
267 PART_ID_COLUMN,
268 range,
269 )));
270 }
271
272 transforms.push(Arc::new(SQTransformer::new(
273 sq,
274 vector_column.to_owned(),
275 SQ_CODE_COLUMN.to_owned(),
276 )));
277
278 Self::new(centroids, distance_type, transforms)
279 }
280
281 fn with_rq(
282 centroids: FixedSizeListArray,
283 distance_type: DistanceType,
284 vector_column: &str,
285 rq: RabitQuantizer,
286 range: Option<Range<u32>>,
287 ) -> Self {
288 let mut transforms: Vec<Arc<dyn Transformer>> =
289 vec![Arc::new(super::transform::Flatten::new(vector_column))];
290
291 let distance_type = if distance_type == MetricType::Cosine {
292 transforms.push(Arc::new(super::transform::NormalizeTransformer::new(
293 vector_column,
294 )));
295 MetricType::L2
296 } else {
297 distance_type
298 };
299 transforms.push(Arc::new(KeepFiniteVectors::new(vector_column)));
300
301 let partition_transform = Arc::new(
302 PartitionTransformer::new(centroids.clone(), distance_type, vector_column)
303 .with_distance(true),
304 );
305 transforms.push(partition_transform);
306
307 if let Some(range) = range {
308 transforms.push(Arc::new(transform::PartitionFilter::new(
309 PART_ID_COLUMN,
310 range,
311 )));
312 }
313
314 transforms.push(Arc::new(ResidualTransform::new(
315 centroids.clone(),
316 PART_ID_COLUMN,
317 vector_column,
318 )));
319
320 transforms.push(Arc::new(RQTransformer::new(
321 rq,
322 distance_type,
323 centroids.clone(),
324 vector_column,
325 )));
326
327 Self::new(centroids, distance_type, transforms)
328 }
329
330 #[inline]
331 pub fn compute_residual(&self, data: &FixedSizeListArray) -> Result<FixedSizeListArray> {
332 compute_residual(&self.centroids, data, Some(self.distance_type), None)
333 }
334
335 #[inline]
336 pub fn compute_partitions(&self, data: &FixedSizeListArray) -> Result<UInt32Array> {
337 Ok(
338 compute_partitions_arrow_array(&self.centroids, data, self.distance_type)
339 .map(|(part_ids, _)| part_ids.into())?,
340 )
341 }
342
343 pub fn find_partitions(
344 &self,
345 query: &dyn Array,
346 nprobes: usize,
347 ) -> Result<(UInt32Array, Float32Array)> {
348 Ok(kmeans_find_partitions_arrow_array(
349 &self.centroids,
350 query,
351 nprobes,
352 self.distance_type,
353 )?)
354 }
355}
356
357impl Transformer for IvfTransformer {
358 #[instrument(name = "IvfTransformer::transform", level = "debug", skip_all)]
359 fn transform(&self, batch: &RecordBatch) -> Result<RecordBatch> {
360 let mut batch = batch.clone();
361 for transform in self.transforms.as_slice() {
362 batch = transform.transform(&batch)?;
363 }
364 Ok(batch)
365 }
366}