lance_index/vector/flat/
index.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4//! Flat Vector Index.
5//!
6
7use std::collections::{BinaryHeap, HashMap};
8use std::sync::Arc;
9
10use arrow::array::AsArray;
11use arrow_array::{Array, ArrayRef, Float32Array, RecordBatch, UInt64Array};
12use arrow_schema::{DataType, Field, Schema, SchemaRef};
13use deepsize::DeepSizeOf;
14use lance_core::{Error, Result, ROW_ID_FIELD};
15use lance_file::previous::reader::FileReader as PreviousFileReader;
16use lance_linalg::distance::DistanceType;
17use serde::{Deserialize, Serialize};
18use snafu::location;
19
20use crate::{
21    metrics::MetricsCollector,
22    prefilter::PreFilter,
23    vector::{
24        graph::OrderedNode,
25        quantizer::{Quantization, QuantizationType, Quantizer, QuantizerMetadata},
26        storage::{DistCalculator, VectorStore},
27        v3::subindex::IvfSubIndex,
28        Query, DIST_COL,
29    },
30};
31
32use super::storage::{FlatBinStorage, FlatFloatStorage, FLAT_COLUMN};
33
34/// A Flat index is any index that stores no metadata, and
35/// during query, it simply scans over the storage and returns the top k results
36#[derive(Debug, Clone, Default, DeepSizeOf)]
37pub struct FlatIndex {}
38
39use std::sync::LazyLock;
40
41static ANN_SEARCH_SCHEMA: LazyLock<SchemaRef> = LazyLock::new(|| {
42    Schema::new(vec![
43        Field::new(DIST_COL, DataType::Float32, true),
44        ROW_ID_FIELD.clone(),
45    ])
46    .into()
47});
48
/// Per-query search parameters for [`FlatIndex`], extracted from a [`Query`].
#[derive(Default)]
pub struct FlatQueryParams {
    /// Inclusive lower distance bound of a range query; `None` when unset.
    lower_bound: Option<f32>,
    /// Exclusive upper distance bound of a range query; `None` when unset.
    upper_bound: Option<f32>,
    /// Passed through unchanged to `VectorStore::dist_calculator`.
    dist_q_c: f32,
}
55
56impl From<&Query> for FlatQueryParams {
57    fn from(q: &Query) -> Self {
58        Self {
59            lower_bound: q.lower_bound,
60            upper_bound: q.upper_bound,
61            dist_q_c: q.dist_q_c,
62        }
63    }
64}
65
66impl IvfSubIndex for FlatIndex {
67    type QueryParams = FlatQueryParams;
68    type BuildParams = ();
69
70    fn name() -> &'static str {
71        "FLAT"
72    }
73
74    fn metadata_key() -> &'static str {
75        "lance:flat"
76    }
77
78    fn schema() -> arrow_schema::SchemaRef {
79        Schema::new(vec![Field::new("__flat_marker", DataType::UInt64, false)]).into()
80    }
81
82    fn search(
83        &self,
84        query: ArrayRef,
85        k: usize,
86        params: Self::QueryParams,
87        storage: &impl VectorStore,
88        prefilter: Arc<dyn PreFilter>,
89        metrics: &dyn MetricsCollector,
90    ) -> Result<RecordBatch> {
91        let is_range_query = params.lower_bound.is_some() || params.upper_bound.is_some();
92        let row_ids = storage.row_ids();
93        let dist_calc = storage.dist_calculator(query, params.dist_q_c);
94        let mut res = BinaryHeap::with_capacity(k);
95        metrics.record_comparisons(storage.len());
96
97        match prefilter.is_empty() {
98            true => {
99                let dists = dist_calc.distance_all(k);
100
101                if is_range_query {
102                    let lower_bound = params.lower_bound.unwrap_or(f32::MIN).into();
103                    let upper_bound = params.upper_bound.unwrap_or(f32::MAX).into();
104
105                    for (&row_id, dist) in row_ids.zip(dists) {
106                        let dist = dist.into();
107                        if dist < lower_bound || dist >= upper_bound {
108                            continue;
109                        }
110                        if res.len() < k {
111                            res.push(OrderedNode::new(row_id, dist));
112                        } else if res.peek().unwrap().dist > dist {
113                            res.pop();
114                            res.push(OrderedNode::new(row_id, dist));
115                        }
116                    }
117                } else {
118                    for (&row_id, dist) in row_ids.zip(dists) {
119                        let dist = dist.into();
120                        if res.len() < k {
121                            res.push(OrderedNode::new(row_id, dist));
122                        } else if res.peek().unwrap().dist > dist {
123                            res.pop();
124                            res.push(OrderedNode::new(row_id, dist));
125                        }
126                    }
127                }
128            }
129            false => {
130                let row_id_mask = prefilter.mask();
131                if is_range_query {
132                    let lower_bound = params.lower_bound.unwrap_or(f32::MIN).into();
133                    let upper_bound = params.upper_bound.unwrap_or(f32::MAX).into();
134                    for (id, &row_id) in row_ids.enumerate() {
135                        if !row_id_mask.selected(row_id) {
136                            continue;
137                        }
138                        let dist = dist_calc.distance(id as u32).into();
139                        if dist < lower_bound || dist >= upper_bound {
140                            continue;
141                        }
142
143                        if res.len() < k {
144                            res.push(OrderedNode::new(row_id, dist));
145                        } else if res.peek().unwrap().dist > dist {
146                            res.pop();
147                            res.push(OrderedNode::new(row_id, dist));
148                        }
149                    }
150                } else {
151                    for (id, &row_id) in row_ids.enumerate() {
152                        if !row_id_mask.selected(row_id) {
153                            continue;
154                        }
155
156                        let dist = dist_calc.distance(id as u32).into();
157                        if res.len() < k {
158                            res.push(OrderedNode::new(row_id, dist));
159                        } else if res.peek().unwrap().dist > dist {
160                            res.pop();
161                            res.push(OrderedNode::new(row_id, dist));
162                        }
163                    }
164                }
165            }
166        };
167
168        // we don't need to sort the results by distances here
169        // because there's a SortExec node in the query plan which sorts the results from all partitions
170        let (row_ids, dists): (Vec<_>, Vec<_>) = res.into_iter().map(|r| (r.id, r.dist.0)).unzip();
171        let (row_ids, dists) = (UInt64Array::from(row_ids), Float32Array::from(dists));
172
173        Ok(RecordBatch::try_new(
174            ANN_SEARCH_SCHEMA.clone(),
175            vec![Arc::new(dists), Arc::new(row_ids)],
176        )?)
177    }
178
179    fn load(_: RecordBatch) -> Result<Self> {
180        Ok(Self {})
181    }
182
183    fn index_vectors(_: &impl VectorStore, _: Self::BuildParams) -> Result<Self>
184    where
185        Self: Sized,
186    {
187        Ok(Self {})
188    }
189
190    fn remap(&self, _: &HashMap<u64, Option<u64>>, _: &impl VectorStore) -> Result<Self> {
191        Ok(self.clone())
192    }
193
194    fn to_batch(&self) -> Result<RecordBatch> {
195        Ok(RecordBatch::new_empty(Schema::empty().into()))
196    }
197}
198
199#[derive(Debug, Clone, Serialize, Deserialize, DeepSizeOf)]
200pub struct FlatMetadata {
201    pub dim: usize,
202}
203
204#[async_trait::async_trait]
205impl QuantizerMetadata for FlatMetadata {
206    async fn load(_: &PreviousFileReader) -> Result<Self> {
207        unimplemented!("Flat will be used in new index builder which doesn't require this")
208    }
209}
210
211#[derive(Debug, Clone, DeepSizeOf)]
212pub struct FlatQuantizer {
213    dim: usize,
214    distance_type: DistanceType,
215}
216
217impl FlatQuantizer {
218    pub fn new(dim: usize, distance_type: DistanceType) -> Self {
219        Self { dim, distance_type }
220    }
221}
222
223impl Quantization for FlatQuantizer {
224    type BuildParams = ();
225    type Metadata = FlatMetadata;
226    type Storage = FlatFloatStorage;
227
228    fn build(data: &dyn Array, distance_type: DistanceType, _: &Self::BuildParams) -> Result<Self> {
229        let dim = data.as_fixed_size_list().value_length();
230        Ok(Self::new(dim as usize, distance_type))
231    }
232
233    fn retrain(&mut self, _: &dyn Array) -> Result<()> {
234        Ok(())
235    }
236
237    fn code_dim(&self) -> usize {
238        self.dim
239    }
240
241    fn column(&self) -> &'static str {
242        FLAT_COLUMN
243    }
244
245    fn from_metadata(metadata: &Self::Metadata, distance_type: DistanceType) -> Result<Quantizer> {
246        Ok(Quantizer::Flat(Self {
247            dim: metadata.dim,
248            distance_type,
249        }))
250    }
251
252    fn metadata(&self, _: Option<crate::vector::quantizer::QuantizationMetadata>) -> FlatMetadata {
253        FlatMetadata { dim: self.dim }
254    }
255
256    fn metadata_key() -> &'static str {
257        "flat"
258    }
259
260    fn quantization_type() -> QuantizationType {
261        QuantizationType::Flat
262    }
263
264    fn quantize(&self, vectors: &dyn Array) -> Result<ArrayRef> {
265        Ok(vectors.slice(0, vectors.len()))
266    }
267
268    fn field(&self) -> Field {
269        Field::new(
270            FLAT_COLUMN,
271            DataType::FixedSizeList(
272                Arc::new(Field::new("item", DataType::Float32, true)),
273                self.dim as i32,
274            ),
275            true,
276        )
277    }
278}
279
280impl From<FlatQuantizer> for Quantizer {
281    fn from(value: FlatQuantizer) -> Self {
282        Self::Flat(value)
283    }
284}
285
286impl TryFrom<Quantizer> for FlatQuantizer {
287    type Error = Error;
288
289    fn try_from(value: Quantizer) -> Result<Self> {
290        match value {
291            Quantizer::Flat(quantizer) => Ok(quantizer),
292            _ => Err(Error::invalid_input(
293                "quantizer is not FlatQuantizer",
294                location!(),
295            )),
296        }
297    }
298}
299
300#[derive(Debug, Clone, DeepSizeOf)]
301pub struct FlatBinQuantizer {
302    dim: usize,
303    distance_type: DistanceType,
304}
305
306impl FlatBinQuantizer {
307    pub fn new(dim: usize, distance_type: DistanceType) -> Self {
308        Self { dim, distance_type }
309    }
310}
311
312impl Quantization for FlatBinQuantizer {
313    type BuildParams = ();
314    type Metadata = FlatMetadata;
315    type Storage = FlatBinStorage;
316
317    fn build(data: &dyn Array, distance_type: DistanceType, _: &Self::BuildParams) -> Result<Self> {
318        let dim = data.as_fixed_size_list().value_length();
319        Ok(Self::new(dim as usize, distance_type))
320    }
321
322    fn retrain(&mut self, _: &dyn Array) -> Result<()> {
323        Ok(())
324    }
325
326    fn code_dim(&self) -> usize {
327        self.dim
328    }
329
330    fn column(&self) -> &'static str {
331        FLAT_COLUMN
332    }
333
334    fn from_metadata(metadata: &Self::Metadata, distance_type: DistanceType) -> Result<Quantizer> {
335        Ok(Quantizer::FlatBin(Self {
336            dim: metadata.dim,
337            distance_type,
338        }))
339    }
340
341    fn metadata(&self, _: Option<crate::vector::quantizer::QuantizationMetadata>) -> FlatMetadata {
342        FlatMetadata { dim: self.dim }
343    }
344
345    fn metadata_key() -> &'static str {
346        "flat"
347    }
348
349    fn quantization_type() -> QuantizationType {
350        QuantizationType::Flat
351    }
352
353    fn quantize(&self, vectors: &dyn Array) -> Result<ArrayRef> {
354        Ok(vectors.slice(0, vectors.len()))
355    }
356
357    fn field(&self) -> Field {
358        Field::new(
359            FLAT_COLUMN,
360            DataType::FixedSizeList(
361                Arc::new(Field::new("item", DataType::UInt8, true)),
362                self.dim as i32,
363            ),
364            true,
365        )
366    }
367}
368
369impl From<FlatBinQuantizer> for Quantizer {
370    fn from(value: FlatBinQuantizer) -> Self {
371        Self::FlatBin(value)
372    }
373}
374
375impl TryFrom<Quantizer> for FlatBinQuantizer {
376    type Error = Error;
377
378    fn try_from(value: Quantizer) -> Result<Self> {
379        match value {
380            Quantizer::FlatBin(quantizer) => Ok(quantizer),
381            _ => Err(Error::invalid_input(
382                "quantizer is not FlatBinQuantizer",
383                location!(),
384            )),
385        }
386    }
387}