lance_index/vector/
ivf.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4//! IVF - Inverted File Index
5
6use std::ops::Range;
7use std::sync::Arc;
8
9use arrow_array::{Array, FixedSizeListArray, Float32Array, RecordBatch, UInt32Array};
10
11pub use builder::IvfBuildParams;
12use lance_core::Result;
13use lance_linalg::distance::{DistanceType, MetricType};
14use tracing::instrument;
15
16use crate::vector::bq::builder::RabitQuantizer;
17use crate::vector::bq::transform::RQTransformer;
18use crate::vector::ivf::transform::PartitionTransformer;
19use crate::vector::kmeans::{compute_partitions_arrow_array, kmeans_find_partitions_arrow_array};
20use crate::vector::{pq::ProductQuantizer, transform::Transformer};
21
22use super::flat::transform::FlatTransformer;
23use super::pq::transform::PQTransformer;
24use super::quantizer::Quantization;
25use super::residual::ResidualTransform;
26use super::sq::transform::SQTransformer;
27use super::sq::ScalarQuantizer;
28use super::transform::KeepFiniteVectors;
29use super::{quantizer::Quantizer, residual::compute_residual};
30use super::{PART_ID_COLUMN, PQ_CODE_COLUMN, SQ_CODE_COLUMN};
31
32pub mod builder;
33pub mod shuffler;
34pub mod storage;
35mod transform;
36
37/// Create an IVF from the flatten centroids.
38///
39/// Parameters
40/// ----------
41/// - *centroids*: a flatten floating number array of centroids.
42/// - *dimension*: dimension of the vector.
43/// - *metric_type*: metric type to compute pair-wise vector distance.
44/// - *transforms*: a list of transforms to apply to the vector column.
45/// - *range*: only covers a range of partitions. Default is None
46pub fn new_ivf_transformer(
47    centroids: FixedSizeListArray,
48    metric_type: DistanceType,
49    transforms: Vec<Arc<dyn Transformer>>,
50) -> IvfTransformer {
51    IvfTransformer::new(centroids, metric_type, transforms)
52}
53
54pub fn new_ivf_transformer_with_quantizer(
55    centroids: FixedSizeListArray,
56    metric_type: MetricType,
57    vector_column: &str,
58    quantizer: Quantizer,
59    range: Option<Range<u32>>,
60) -> Result<IvfTransformer> {
61    match quantizer {
62        Quantizer::Flat(_) | Quantizer::FlatBin(_) => Ok(IvfTransformer::new_flat(
63            centroids,
64            metric_type,
65            vector_column,
66            range,
67        )),
68        Quantizer::Product(pq) => Ok(IvfTransformer::with_pq(
69            centroids,
70            metric_type,
71            vector_column,
72            pq,
73            range,
74        )),
75        Quantizer::Scalar(sq) => Ok(IvfTransformer::with_sq(
76            centroids,
77            metric_type,
78            vector_column,
79            sq,
80            range,
81        )),
82        Quantizer::Rabit(rq) => Ok(IvfTransformer::with_rq(
83            centroids,
84            metric_type,
85            vector_column,
86            rq,
87            range,
88        )),
89    }
90}
91
/// IVF (inverted file) transformer.
///
/// Holds the IVF centroids, the distance type used for partition and residual
/// computations, and an ordered pipeline of transforms that
/// `Transformer::transform` applies to each `RecordBatch`.
#[derive(Debug)]
pub struct IvfTransformer {
    /// Centroids produced by the clustering step that defines the IVF partitions.
    ///
    /// Conceptually a 2-D `(num_partitions, dimension)` floating-point matrix,
    /// stored as a fixed-size-list array with one entry per centroid
    /// (assumed from how it is passed to the k-means helpers — confirm).
    centroids: FixedSizeListArray,

    /// Transforms applied, in order, to every batch.
    transforms: Vec<Arc<dyn Transformer>>,

    /// Distance type used by `compute_residual`, `compute_partitions` and
    /// `find_partitions`.
    distance_type: DistanceType,
}
107
108impl IvfTransformer {
109    /// Create a new Ivf model.
110    pub fn new(
111        centroids: FixedSizeListArray,
112        metric_type: MetricType,
113        transforms: Vec<Arc<dyn Transformer>>,
114    ) -> Self {
115        Self {
116            centroids,
117            distance_type: metric_type,
118            transforms,
119        }
120    }
121
122    pub fn new_partition_transformer(
123        centroids: FixedSizeListArray,
124        distance_type: DistanceType,
125        vector_column: &str,
126    ) -> Self {
127        let mut transforms: Vec<Arc<dyn Transformer>> =
128            vec![Arc::new(super::transform::Flatten::new(vector_column))];
129
130        let distance_type = if distance_type == MetricType::Cosine {
131            transforms.push(Arc::new(super::transform::NormalizeTransformer::new(
132                vector_column,
133            )));
134            MetricType::L2
135        } else {
136            distance_type
137        };
138        transforms.push(Arc::new(KeepFiniteVectors::new(vector_column)));
139
140        let partition_transform = Arc::new(PartitionTransformer::new(
141            centroids.clone(),
142            distance_type,
143            vector_column,
144        ));
145        transforms.push(partition_transform);
146        Self::new(centroids, distance_type, transforms)
147    }
148
149    pub fn new_flat(
150        centroids: FixedSizeListArray,
151        distance_type: DistanceType,
152        vector_column: &str,
153        range: Option<Range<u32>>,
154    ) -> Self {
155        let mut transforms: Vec<Arc<dyn Transformer>> =
156            vec![Arc::new(super::transform::Flatten::new(vector_column))];
157
158        let dt = if distance_type == DistanceType::Cosine {
159            transforms.push(Arc::new(super::transform::NormalizeTransformer::new(
160                vector_column,
161            )));
162            MetricType::L2
163        } else {
164            distance_type
165        };
166        transforms.push(Arc::new(KeepFiniteVectors::new(vector_column)));
167
168        let ivf_transform = Arc::new(PartitionTransformer::new(
169            centroids.clone(),
170            dt,
171            vector_column,
172        ));
173        transforms.push(ivf_transform);
174
175        if let Some(range) = range {
176            transforms.push(Arc::new(transform::PartitionFilter::new(
177                PART_ID_COLUMN,
178                range,
179            )));
180        }
181
182        transforms.push(Arc::new(FlatTransformer::new(vector_column)));
183
184        Self::new(centroids, distance_type, transforms)
185    }
186
187    /// Create a IVF_PQ struct.
188    pub fn with_pq(
189        centroids: FixedSizeListArray,
190        distance_type: DistanceType,
191        vector_column: &str,
192        pq: ProductQuantizer,
193        range: Option<Range<u32>>,
194    ) -> Self {
195        let mut transforms: Vec<Arc<dyn Transformer>> =
196            vec![Arc::new(super::transform::Flatten::new(vector_column))];
197
198        let distance_type = if distance_type == MetricType::Cosine {
199            transforms.push(Arc::new(super::transform::NormalizeTransformer::new(
200                vector_column,
201            )));
202            MetricType::L2
203        } else {
204            distance_type
205        };
206        transforms.push(Arc::new(KeepFiniteVectors::new(vector_column)));
207
208        let partition_transform = Arc::new(PartitionTransformer::new(
209            centroids.clone(),
210            distance_type,
211            vector_column,
212        ));
213        transforms.push(partition_transform);
214
215        if let Some(range) = range {
216            transforms.push(Arc::new(transform::PartitionFilter::new(
217                PART_ID_COLUMN,
218                range,
219            )));
220        }
221
222        if ProductQuantizer::use_residual(distance_type) {
223            transforms.push(Arc::new(ResidualTransform::new(
224                centroids.clone(),
225                PART_ID_COLUMN,
226                vector_column,
227            )));
228        }
229        transforms.push(Arc::new(PQTransformer::new(
230            pq,
231            vector_column,
232            PQ_CODE_COLUMN,
233        )));
234
235        Self::new(centroids, distance_type, transforms)
236    }
237
238    fn with_sq(
239        centroids: FixedSizeListArray,
240        metric_type: MetricType,
241        vector_column: &str,
242        sq: ScalarQuantizer,
243        range: Option<Range<u32>>,
244    ) -> Self {
245        let mut transforms: Vec<Arc<dyn Transformer>> =
246            vec![Arc::new(super::transform::Flatten::new(vector_column))];
247
248        let distance_type = if metric_type == MetricType::Cosine {
249            transforms.push(Arc::new(super::transform::NormalizeTransformer::new(
250                vector_column,
251            )));
252            MetricType::L2
253        } else {
254            metric_type
255        };
256        transforms.push(Arc::new(KeepFiniteVectors::new(vector_column)));
257
258        let partition_transformer = Arc::new(PartitionTransformer::new(
259            centroids.clone(),
260            distance_type,
261            vector_column,
262        ));
263        transforms.push(partition_transformer);
264
265        if let Some(range) = range {
266            transforms.push(Arc::new(transform::PartitionFilter::new(
267                PART_ID_COLUMN,
268                range,
269            )));
270        }
271
272        transforms.push(Arc::new(SQTransformer::new(
273            sq,
274            vector_column.to_owned(),
275            SQ_CODE_COLUMN.to_owned(),
276        )));
277
278        Self::new(centroids, distance_type, transforms)
279    }
280
281    fn with_rq(
282        centroids: FixedSizeListArray,
283        distance_type: DistanceType,
284        vector_column: &str,
285        rq: RabitQuantizer,
286        range: Option<Range<u32>>,
287    ) -> Self {
288        let mut transforms: Vec<Arc<dyn Transformer>> =
289            vec![Arc::new(super::transform::Flatten::new(vector_column))];
290
291        let distance_type = if distance_type == MetricType::Cosine {
292            transforms.push(Arc::new(super::transform::NormalizeTransformer::new(
293                vector_column,
294            )));
295            MetricType::L2
296        } else {
297            distance_type
298        };
299        transforms.push(Arc::new(KeepFiniteVectors::new(vector_column)));
300
301        let partition_transform = Arc::new(
302            PartitionTransformer::new(centroids.clone(), distance_type, vector_column)
303                .with_distance(true),
304        );
305        transforms.push(partition_transform);
306
307        if let Some(range) = range {
308            transforms.push(Arc::new(transform::PartitionFilter::new(
309                PART_ID_COLUMN,
310                range,
311            )));
312        }
313
314        transforms.push(Arc::new(ResidualTransform::new(
315            centroids.clone(),
316            PART_ID_COLUMN,
317            vector_column,
318        )));
319
320        transforms.push(Arc::new(RQTransformer::new(
321            rq,
322            distance_type,
323            centroids.clone(),
324            vector_column,
325        )));
326
327        Self::new(centroids, distance_type, transforms)
328    }
329
330    #[inline]
331    pub fn compute_residual(&self, data: &FixedSizeListArray) -> Result<FixedSizeListArray> {
332        compute_residual(&self.centroids, data, Some(self.distance_type), None)
333    }
334
335    #[inline]
336    pub fn compute_partitions(&self, data: &FixedSizeListArray) -> Result<UInt32Array> {
337        Ok(
338            compute_partitions_arrow_array(&self.centroids, data, self.distance_type)
339                .map(|(part_ids, _)| part_ids.into())?,
340        )
341    }
342
343    pub fn find_partitions(
344        &self,
345        query: &dyn Array,
346        nprobes: usize,
347    ) -> Result<(UInt32Array, Float32Array)> {
348        Ok(kmeans_find_partitions_arrow_array(
349            &self.centroids,
350            query,
351            nprobes,
352            self.distance_type,
353        )?)
354    }
355}
356
357impl Transformer for IvfTransformer {
358    #[instrument(name = "IvfTransformer::transform", level = "debug", skip_all)]
359    fn transform(&self, batch: &RecordBatch) -> Result<RecordBatch> {
360        let mut batch = batch.clone();
361        for transform in self.transforms.as_slice() {
362            batch = transform.transform(&batch)?;
363        }
364        Ok(batch)
365    }
366}