Skip to main content

lance_index/vector/flat/
index.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4//! Flat Vector Index.
5//!
6
7use std::collections::{BinaryHeap, HashMap};
8use std::sync::Arc;
9
10use arrow::array::AsArray;
11use arrow_array::{Array, ArrayRef, Float32Array, RecordBatch, UInt64Array};
12use arrow_schema::{DataType, Field, Schema, SchemaRef};
13use deepsize::DeepSizeOf;
14use lance_core::{Error, ROW_ID_FIELD, Result};
15use lance_file::previous::reader::FileReader as PreviousFileReader;
16use lance_linalg::distance::DistanceType;
17use serde::{Deserialize, Serialize};
18
19use crate::{
20    metrics::MetricsCollector,
21    prefilter::PreFilter,
22    vector::{
23        DIST_COL, Query,
24        graph::{OrderedFloat, OrderedNode},
25        quantizer::{Quantization, QuantizationType, Quantizer, QuantizerMetadata},
26        storage::{DistCalculator, VectorStore},
27        v3::subindex::IvfSubIndex,
28    },
29};
30
31use super::storage::{FLAT_COLUMN, FlatBinStorage, FlatFloatStorage};
32
33#[inline(always)]
34fn push_candidate_local(
35    res: &mut BinaryHeap<OrderedNode<u64>>,
36    k: usize,
37    row_id: u64,
38    dist: OrderedFloat,
39) {
40    if k == 0 {
41        return;
42    }
43    if res.len() < k {
44        res.push(OrderedNode::new(row_id, dist));
45    } else if res.peek().is_some_and(|node| node.dist > dist) {
46        res.pop();
47        res.push(OrderedNode::new(row_id, dist));
48    }
49}
50
51#[inline(always)]
52fn push_candidate_global(
53    res: &mut BinaryHeap<OrderedNode<u64>>,
54    k: usize,
55    row_id: u64,
56    dist: OrderedFloat,
57    max_dist: &mut Option<OrderedFloat>,
58) {
59    if k == 0 {
60        return;
61    }
62    if res.len() < k {
63        res.push(OrderedNode::new(row_id, dist));
64        if res.len() == k {
65            *max_dist = res.peek().map(|node| node.dist);
66        }
67    } else if max_dist.is_some_and(|max_dist| max_dist > dist) {
68        res.pop();
69        res.push(OrderedNode::new(row_id, dist));
70        *max_dist = res.peek().map(|node| node.dist);
71    }
72}
73
74/// A Flat index is any index that stores no metadata, and
75/// during query, it simply scans over the storage and returns the top k results
76#[derive(Debug, Clone, Default, DeepSizeOf)]
77pub struct FlatIndex {}
78
79use std::sync::LazyLock;
80
81static ANN_SEARCH_SCHEMA: LazyLock<SchemaRef> = LazyLock::new(|| {
82    Schema::new(vec![
83        Field::new(DIST_COL, DataType::Float32, true),
84        ROW_ID_FIELD.clone(),
85    ])
86    .into()
87});
88
/// Per-query parameters consumed by the flat sub-index.
///
/// The default value has no distance bounds and a `dist_q_c` of `0.0`.
#[derive(Debug, Clone, Default)]
pub struct FlatQueryParams {
    /// Inclusive lower distance bound. When either bound is set the query is
    /// treated as a range query and candidates with a smaller distance are
    /// skipped.
    lower_bound: Option<f32>,
    /// Exclusive upper distance bound. When either bound is set the query is
    /// treated as a range query and candidates with an equal or larger
    /// distance are skipped.
    upper_bound: Option<f32>,
    /// Forwarded unchanged to `VectorStore::dist_calculator`.
    // NOTE(review): presumably the query-to-centroid distance — confirm against the caller.
    dist_q_c: f32,
}
95
96impl From<&Query> for FlatQueryParams {
97    fn from(q: &Query) -> Self {
98        Self {
99            lower_bound: q.lower_bound,
100            upper_bound: q.upper_bound,
101            dist_q_c: q.dist_q_c,
102        }
103    }
104}
105
106impl IvfSubIndex for FlatIndex {
107    type QueryParams = FlatQueryParams;
108    type BuildParams = ();
109
110    fn name() -> &'static str {
111        "FLAT"
112    }
113
114    fn metadata_key() -> &'static str {
115        "lance:flat"
116    }
117
118    fn schema() -> arrow_schema::SchemaRef {
119        Schema::new(vec![Field::new("__flat_marker", DataType::UInt64, false)]).into()
120    }
121
122    fn search(
123        &self,
124        query: ArrayRef,
125        k: usize,
126        params: Self::QueryParams,
127        storage: &impl VectorStore,
128        prefilter: Arc<dyn PreFilter>,
129        metrics: &dyn MetricsCollector,
130    ) -> Result<RecordBatch> {
131        let is_range_query = params.lower_bound.is_some() || params.upper_bound.is_some();
132        let row_ids = storage.row_ids();
133        let dist_calc = storage.dist_calculator(query, params.dist_q_c);
134        let mut res = BinaryHeap::with_capacity(k);
135        metrics.record_comparisons(storage.len());
136
137        match prefilter.is_empty() {
138            true => {
139                let dists = dist_calc.distance_all(k);
140
141                if is_range_query {
142                    let lower_bound = params.lower_bound.unwrap_or(f32::MIN).into();
143                    let upper_bound = params.upper_bound.unwrap_or(f32::MAX).into();
144
145                    for (&row_id, dist) in row_ids.zip(dists) {
146                        let dist = dist.into();
147                        if dist < lower_bound || dist >= upper_bound {
148                            continue;
149                        }
150                        push_candidate_local(&mut res, k, row_id, dist);
151                    }
152                } else {
153                    for (&row_id, dist) in row_ids.zip(dists) {
154                        let dist = dist.into();
155                        push_candidate_local(&mut res, k, row_id, dist);
156                    }
157                }
158            }
159            false => {
160                let row_addr_mask = prefilter.mask();
161                if is_range_query {
162                    let lower_bound = params.lower_bound.unwrap_or(f32::MIN).into();
163                    let upper_bound = params.upper_bound.unwrap_or(f32::MAX).into();
164                    for (id, &row_addr) in row_ids.enumerate() {
165                        if !row_addr_mask.selected(row_addr) {
166                            continue;
167                        }
168                        let dist = dist_calc.distance(id as u32).into();
169                        if dist < lower_bound || dist >= upper_bound {
170                            continue;
171                        }
172
173                        push_candidate_local(&mut res, k, row_addr, dist);
174                    }
175                } else {
176                    for (id, &row_addr) in row_ids.enumerate() {
177                        if !row_addr_mask.selected(row_addr) {
178                            continue;
179                        }
180
181                        let dist = dist_calc.distance(id as u32).into();
182                        push_candidate_local(&mut res, k, row_addr, dist);
183                    }
184                }
185            }
186        };
187
188        // we don't need to sort the results by distances here
189        // because there's a SortExec node in the query plan which sorts the results from all partitions
190        let (row_ids, dists): (Vec<_>, Vec<_>) = res.into_iter().map(|r| (r.id, r.dist.0)).unzip();
191        let (row_ids, dists) = (UInt64Array::from(row_ids), Float32Array::from(dists));
192
193        Ok(RecordBatch::try_new(
194            ANN_SEARCH_SCHEMA.clone(),
195            vec![Arc::new(dists), Arc::new(row_ids)],
196        )?)
197    }
198
199    fn supports_global_topk_heap() -> bool {
200        true
201    }
202
203    fn accumulate_topk(
204        &self,
205        query: ArrayRef,
206        k: usize,
207        params: Self::QueryParams,
208        storage: &impl VectorStore,
209        prefilter: Arc<dyn PreFilter>,
210        res: &mut BinaryHeap<OrderedNode<u64>>,
211        metrics: &dyn MetricsCollector,
212    ) -> Result<()> {
213        let mut distance_scratch = Vec::new();
214        let mut u16_scratch = Vec::new();
215        let mut u8_scratch = Vec::new();
216        self.accumulate_topk_with_scratch(
217            query,
218            k,
219            params,
220            storage,
221            prefilter,
222            res,
223            &mut distance_scratch,
224            &mut u16_scratch,
225            &mut u8_scratch,
226            metrics,
227        )
228    }
229
230    fn accumulate_topk_with_scratch(
231        &self,
232        query: ArrayRef,
233        k: usize,
234        params: Self::QueryParams,
235        storage: &impl VectorStore,
236        prefilter: Arc<dyn PreFilter>,
237        res: &mut BinaryHeap<OrderedNode<u64>>,
238        distance_scratch: &mut Vec<f32>,
239        u16_scratch: &mut Vec<u16>,
240        u8_scratch: &mut Vec<u8>,
241        metrics: &dyn MetricsCollector,
242    ) -> Result<()> {
243        let is_range_query = params.lower_bound.is_some() || params.upper_bound.is_some();
244        let row_ids = storage.row_ids();
245        let dist_calc = storage.dist_calculator(query, params.dist_q_c);
246        let mut max_dist = res.peek().map(|node| node.dist);
247        metrics.record_comparisons(storage.len());
248
249        match prefilter.is_empty() {
250            true => {
251                dist_calc.distance_all_with_scratch(k, distance_scratch, u16_scratch, u8_scratch);
252                let dists = distance_scratch.iter().copied();
253
254                if is_range_query {
255                    let lower_bound = params.lower_bound.unwrap_or(f32::MIN).into();
256                    let upper_bound = params.upper_bound.unwrap_or(f32::MAX).into();
257
258                    for (&row_id, dist) in row_ids.zip(dists) {
259                        let dist = dist.into();
260                        if dist < lower_bound || dist >= upper_bound {
261                            continue;
262                        }
263                        push_candidate_global(res, k, row_id, dist, &mut max_dist);
264                    }
265                } else {
266                    for (&row_id, dist) in row_ids.zip(dists) {
267                        let dist = dist.into();
268                        push_candidate_global(res, k, row_id, dist, &mut max_dist);
269                    }
270                }
271            }
272            false => {
273                let row_addr_mask = prefilter.mask();
274                if is_range_query {
275                    let lower_bound = params.lower_bound.unwrap_or(f32::MIN).into();
276                    let upper_bound = params.upper_bound.unwrap_or(f32::MAX).into();
277                    for (id, &row_addr) in row_ids.enumerate() {
278                        if !row_addr_mask.selected(row_addr) {
279                            continue;
280                        }
281                        let dist = dist_calc.distance(id as u32).into();
282                        if dist < lower_bound || dist >= upper_bound {
283                            continue;
284                        }
285
286                        push_candidate_global(res, k, row_addr, dist, &mut max_dist);
287                    }
288                } else {
289                    for (id, &row_addr) in row_ids.enumerate() {
290                        if !row_addr_mask.selected(row_addr) {
291                            continue;
292                        }
293                        let dist = dist_calc.distance(id as u32).into();
294                        push_candidate_global(res, k, row_addr, dist, &mut max_dist);
295                    }
296                }
297            }
298        };
299        Ok(())
300    }
301
302    fn load(_: RecordBatch) -> Result<Self> {
303        Ok(Self {})
304    }
305
306    fn index_vectors(_: &impl VectorStore, _: Self::BuildParams) -> Result<Self>
307    where
308        Self: Sized,
309    {
310        Ok(Self {})
311    }
312
313    fn remap(&self, _: &HashMap<u64, Option<u64>>, _: &impl VectorStore) -> Result<Self> {
314        Ok(self.clone())
315    }
316
317    fn to_batch(&self) -> Result<RecordBatch> {
318        Ok(RecordBatch::new_empty(Schema::empty().into()))
319    }
320}
321
322#[derive(Debug, Clone, Serialize, Deserialize, DeepSizeOf)]
323pub struct FlatMetadata {
324    pub dim: usize,
325}
326
327#[async_trait::async_trait]
328impl QuantizerMetadata for FlatMetadata {
329    async fn load(_: &PreviousFileReader) -> Result<Self> {
330        unimplemented!("Flat will be used in new index builder which doesn't require this")
331    }
332}
333
334#[derive(Debug, Clone, DeepSizeOf)]
335pub struct FlatQuantizer {
336    dim: usize,
337    distance_type: DistanceType,
338}
339
340impl FlatQuantizer {
341    pub fn new(dim: usize, distance_type: DistanceType) -> Self {
342        Self { dim, distance_type }
343    }
344}
345
346impl Quantization for FlatQuantizer {
347    type BuildParams = ();
348    type Metadata = FlatMetadata;
349    type Storage = FlatFloatStorage;
350
351    fn build(data: &dyn Array, distance_type: DistanceType, _: &Self::BuildParams) -> Result<Self> {
352        let dim = data.as_fixed_size_list().value_length();
353        Ok(Self::new(dim as usize, distance_type))
354    }
355
356    fn retrain(&mut self, _: &dyn Array) -> Result<()> {
357        Ok(())
358    }
359
360    fn code_dim(&self) -> usize {
361        self.dim
362    }
363
364    fn column(&self) -> &'static str {
365        FLAT_COLUMN
366    }
367
368    fn from_metadata(metadata: &Self::Metadata, distance_type: DistanceType) -> Result<Quantizer> {
369        Ok(Quantizer::Flat(Self {
370            dim: metadata.dim,
371            distance_type,
372        }))
373    }
374
375    fn metadata(&self, _: Option<crate::vector::quantizer::QuantizationMetadata>) -> FlatMetadata {
376        FlatMetadata { dim: self.dim }
377    }
378
379    fn metadata_key() -> &'static str {
380        "flat"
381    }
382
383    fn quantization_type() -> QuantizationType {
384        QuantizationType::Flat
385    }
386
387    fn quantize(&self, vectors: &dyn Array) -> Result<ArrayRef> {
388        Ok(vectors.slice(0, vectors.len()))
389    }
390
391    fn field(&self) -> Field {
392        Field::new(
393            FLAT_COLUMN,
394            DataType::FixedSizeList(
395                Arc::new(Field::new("item", DataType::Float32, true)),
396                self.dim as i32,
397            ),
398            true,
399        )
400    }
401}
402
403impl From<FlatQuantizer> for Quantizer {
404    fn from(value: FlatQuantizer) -> Self {
405        Self::Flat(value)
406    }
407}
408
409impl TryFrom<Quantizer> for FlatQuantizer {
410    type Error = Error;
411
412    fn try_from(value: Quantizer) -> Result<Self> {
413        match value {
414            Quantizer::Flat(quantizer) => Ok(quantizer),
415            _ => Err(Error::invalid_input("quantizer is not FlatQuantizer")),
416        }
417    }
418}
419
420#[derive(Debug, Clone, DeepSizeOf)]
421pub struct FlatBinQuantizer {
422    dim: usize,
423    distance_type: DistanceType,
424}
425
426impl FlatBinQuantizer {
427    pub fn new(dim: usize, distance_type: DistanceType) -> Self {
428        Self { dim, distance_type }
429    }
430}
431
432impl Quantization for FlatBinQuantizer {
433    type BuildParams = ();
434    type Metadata = FlatMetadata;
435    type Storage = FlatBinStorage;
436
437    fn build(data: &dyn Array, distance_type: DistanceType, _: &Self::BuildParams) -> Result<Self> {
438        let dim = data.as_fixed_size_list().value_length();
439        Ok(Self::new(dim as usize, distance_type))
440    }
441
442    fn retrain(&mut self, _: &dyn Array) -> Result<()> {
443        Ok(())
444    }
445
446    fn code_dim(&self) -> usize {
447        self.dim
448    }
449
450    fn column(&self) -> &'static str {
451        FLAT_COLUMN
452    }
453
454    fn from_metadata(metadata: &Self::Metadata, distance_type: DistanceType) -> Result<Quantizer> {
455        Ok(Quantizer::FlatBin(Self {
456            dim: metadata.dim,
457            distance_type,
458        }))
459    }
460
461    fn metadata(&self, _: Option<crate::vector::quantizer::QuantizationMetadata>) -> FlatMetadata {
462        FlatMetadata { dim: self.dim }
463    }
464
465    fn metadata_key() -> &'static str {
466        "flat"
467    }
468
469    fn quantization_type() -> QuantizationType {
470        QuantizationType::FlatBin
471    }
472
473    fn quantize(&self, vectors: &dyn Array) -> Result<ArrayRef> {
474        Ok(vectors.slice(0, vectors.len()))
475    }
476
477    fn field(&self) -> Field {
478        Field::new(
479            FLAT_COLUMN,
480            DataType::FixedSizeList(
481                Arc::new(Field::new("item", DataType::UInt8, true)),
482                self.dim as i32,
483            ),
484            true,
485        )
486    }
487}
488
489impl From<FlatBinQuantizer> for Quantizer {
490    fn from(value: FlatBinQuantizer) -> Self {
491        Self::FlatBin(value)
492    }
493}
494
495impl TryFrom<Quantizer> for FlatBinQuantizer {
496    type Error = Error;
497
498    fn try_from(value: Quantizer) -> Result<Self> {
499        match value {
500            Quantizer::FlatBin(quantizer) => Ok(quantizer),
501            _ => Err(Error::invalid_input("quantizer is not FlatBinQuantizer")),
502        }
503    }
504}
505
#[cfg(test)]
mod tests {
    use super::*;

    use arrow_array::FixedSizeListArray;
    use async_trait::async_trait;
    use lance_arrow::FixedSizeListArrayExt;
    use lance_core::utils::mask::{RowAddrMask, RowAddrTreeMap};

    use crate::metrics::NoOpMetricsCollector;
    use crate::prefilter::NoFilter;

    /// Prefilter backed by a fixed row-address mask; it always reports itself
    /// as non-empty so the masked search path is exercised.
    struct MaskPreFilter {
        mask: Arc<RowAddrMask>,
    }

    #[async_trait]
    impl PreFilter for MaskPreFilter {
        async fn wait_for_ready(&self) -> Result<()> {
            Ok(())
        }

        fn is_empty(&self) -> bool {
            false
        }

        fn mask(&self) -> Arc<RowAddrMask> {
            Arc::clone(&self.mask)
        }

        fn filter_row_ids<'a>(&self, row_ids: Box<dyn Iterator<Item = &'a u64> + 'a>) -> Vec<u64> {
            self.mask.selected_indices(row_ids)
        }
    }

    /// Five 2-d vectors stored with the L2 distance type.
    fn test_storage() -> FlatFloatStorage {
        let values = Float32Array::from(vec![
            0.0, 0.0, // row 0
            1.0, 0.0, // row 1
            1.0, 1.0, // row 2
            3.0, 3.0, // row 3
            4.0, 4.0, // row 4
        ]);
        let vectors = FixedSizeListArray::try_new_from_values(values, 2).unwrap();
        FlatFloatStorage::new(vectors, DistanceType::L2)
    }

    /// Query vector used by every test.
    fn query() -> ArrayRef {
        Arc::new(Float32Array::from(vec![1.0, 1.0]))
    }

    /// Orders `(row_id, distance)` pairs by row id so results can be compared
    /// regardless of heap order.
    fn sorted_by_row_id(mut pairs: Vec<(u64, f32)>) -> Vec<(u64, f32)> {
        pairs.sort_by_key(|&(row_id, _)| row_id);
        pairs
    }

    /// Flattens a search result batch into `(row_id, distance)` pairs sorted by row id.
    fn batch_results(batch: RecordBatch) -> Vec<(u64, f32)> {
        let dists = batch
            .column(0)
            .as_primitive::<arrow_array::types::Float32Type>();
        let row_ids = batch
            .column(1)
            .as_primitive::<arrow_array::types::UInt64Type>();
        let pairs = row_ids
            .values()
            .iter()
            .copied()
            .zip(dists.values().iter().copied())
            .collect();
        sorted_by_row_id(pairs)
    }

    /// Flattens a result heap into `(row_id, distance)` pairs sorted by row id.
    fn heap_results(heap: BinaryHeap<OrderedNode<u64>>) -> Vec<(u64, f32)> {
        let pairs = heap
            .into_iter()
            .map(|node| (node.id, node.dist.0))
            .collect();
        sorted_by_row_id(pairs)
    }

    /// Runs `FlatIndex::search` with default parameters and returns its results.
    fn run_search(
        index: &FlatIndex,
        storage: &FlatFloatStorage,
        k: usize,
        prefilter: Arc<dyn PreFilter>,
    ) -> Vec<(u64, f32)> {
        let batch = index
            .search(
                query(),
                k,
                FlatQueryParams::default(),
                storage,
                prefilter,
                &NoOpMetricsCollector,
            )
            .unwrap();
        batch_results(batch)
    }

    /// Runs `FlatIndex::accumulate_topk` into a fresh heap and returns its results.
    fn run_accumulate(
        index: &FlatIndex,
        storage: &FlatFloatStorage,
        k: usize,
        prefilter: Arc<dyn PreFilter>,
    ) -> Vec<(u64, f32)> {
        let mut heap = BinaryHeap::with_capacity(k);
        index
            .accumulate_topk(
                query(),
                k,
                FlatQueryParams::default(),
                storage,
                prefilter,
                &mut heap,
                &NoOpMetricsCollector,
            )
            .unwrap();
        heap_results(heap)
    }

    #[test]
    fn test_flat_search_matches_accumulate_topk_without_prefilter() {
        let index = FlatIndex::default();
        let storage = test_storage();
        let k = 3;

        let search_results = run_search(&index, &storage, k, Arc::new(NoFilter));
        let accumulated = run_accumulate(&index, &storage, k, Arc::new(NoFilter));

        assert_eq!(search_results, accumulated);
    }

    #[test]
    fn test_flat_search_matches_accumulate_topk_with_prefilter() {
        let index = FlatIndex::default();
        let storage = test_storage();
        let k = 2;
        let allowed = RowAddrTreeMap::from_iter([0_u64, 3, 4]);
        let filter = Arc::new(MaskPreFilter {
            mask: Arc::new(RowAddrMask::from_allowed(allowed)),
        });

        let search_results = run_search(&index, &storage, k, filter.clone());
        let accumulated = run_accumulate(&index, &storage, k, filter);

        assert_eq!(search_results, accumulated);
        let ids = search_results
            .iter()
            .map(|&(row_id, _)| row_id)
            .collect::<Vec<_>>();
        assert_eq!(ids, vec![0, 3]);
    }
}