1use std::collections::{BinaryHeap, HashMap};
8use std::sync::Arc;
9
10use arrow::array::AsArray;
11use arrow_array::{Array, ArrayRef, Float32Array, RecordBatch, UInt64Array};
12use arrow_schema::{DataType, Field, Schema, SchemaRef};
13use deepsize::DeepSizeOf;
14use lance_core::{Error, ROW_ID_FIELD, Result};
15use lance_file::previous::reader::FileReader as PreviousFileReader;
16use lance_linalg::distance::DistanceType;
17use serde::{Deserialize, Serialize};
18
19use crate::{
20 metrics::MetricsCollector,
21 prefilter::PreFilter,
22 vector::{
23 DIST_COL, Query,
24 graph::{OrderedFloat, OrderedNode},
25 quantizer::{Quantization, QuantizationType, Quantizer, QuantizerMetadata},
26 storage::{DistCalculator, VectorStore},
27 v3::subindex::IvfSubIndex,
28 },
29};
30
31use super::storage::{FLAT_COLUMN, FlatBinStorage, FlatFloatStorage};
32
33#[inline(always)]
34fn push_candidate_local(
35 res: &mut BinaryHeap<OrderedNode<u64>>,
36 k: usize,
37 row_id: u64,
38 dist: OrderedFloat,
39) {
40 if k == 0 {
41 return;
42 }
43 if res.len() < k {
44 res.push(OrderedNode::new(row_id, dist));
45 } else if res.peek().is_some_and(|node| node.dist > dist) {
46 res.pop();
47 res.push(OrderedNode::new(row_id, dist));
48 }
49}
50
51#[inline(always)]
52fn push_candidate_global(
53 res: &mut BinaryHeap<OrderedNode<u64>>,
54 k: usize,
55 row_id: u64,
56 dist: OrderedFloat,
57 max_dist: &mut Option<OrderedFloat>,
58) {
59 if k == 0 {
60 return;
61 }
62 if res.len() < k {
63 res.push(OrderedNode::new(row_id, dist));
64 if res.len() == k {
65 *max_dist = res.peek().map(|node| node.dist);
66 }
67 } else if max_dist.is_some_and(|max_dist| max_dist > dist) {
68 res.pop();
69 res.push(OrderedNode::new(row_id, dist));
70 *max_dist = res.peek().map(|node| node.dist);
71 }
72}
73
74#[derive(Debug, Clone, Default, DeepSizeOf)]
77pub struct FlatIndex {}
78
79use std::sync::LazyLock;
80
81static ANN_SEARCH_SCHEMA: LazyLock<SchemaRef> = LazyLock::new(|| {
82 Schema::new(vec![
83 Field::new(DIST_COL, DataType::Float32, true),
84 ROW_ID_FIELD.clone(),
85 ])
86 .into()
87});
88
/// Per-query parameters for a flat sub-index search.
///
/// Setting either bound turns the search into a range query that keeps only
/// distances inside the half-open window `[lower_bound, upper_bound)`.
#[derive(Default)]
pub struct FlatQueryParams {
    /// Inclusive lower distance bound; `None` means unbounded below.
    lower_bound: Option<f32>,
    /// Exclusive upper distance bound; `None` means unbounded above.
    upper_bound: Option<f32>,
    /// Passed through unchanged to `VectorStore::dist_calculator`.
    dist_q_c: f32,
}
95
96impl From<&Query> for FlatQueryParams {
97 fn from(q: &Query) -> Self {
98 Self {
99 lower_bound: q.lower_bound,
100 upper_bound: q.upper_bound,
101 dist_q_c: q.dist_q_c,
102 }
103 }
104}
105
106impl IvfSubIndex for FlatIndex {
107 type QueryParams = FlatQueryParams;
108 type BuildParams = ();
109
110 fn name() -> &'static str {
111 "FLAT"
112 }
113
114 fn metadata_key() -> &'static str {
115 "lance:flat"
116 }
117
118 fn schema() -> arrow_schema::SchemaRef {
119 Schema::new(vec![Field::new("__flat_marker", DataType::UInt64, false)]).into()
120 }
121
122 fn search(
123 &self,
124 query: ArrayRef,
125 k: usize,
126 params: Self::QueryParams,
127 storage: &impl VectorStore,
128 prefilter: Arc<dyn PreFilter>,
129 metrics: &dyn MetricsCollector,
130 ) -> Result<RecordBatch> {
131 let is_range_query = params.lower_bound.is_some() || params.upper_bound.is_some();
132 let row_ids = storage.row_ids();
133 let dist_calc = storage.dist_calculator(query, params.dist_q_c);
134 let mut res = BinaryHeap::with_capacity(k);
135 metrics.record_comparisons(storage.len());
136
137 match prefilter.is_empty() {
138 true => {
139 let dists = dist_calc.distance_all(k);
140
141 if is_range_query {
142 let lower_bound = params.lower_bound.unwrap_or(f32::MIN).into();
143 let upper_bound = params.upper_bound.unwrap_or(f32::MAX).into();
144
145 for (&row_id, dist) in row_ids.zip(dists) {
146 let dist = dist.into();
147 if dist < lower_bound || dist >= upper_bound {
148 continue;
149 }
150 push_candidate_local(&mut res, k, row_id, dist);
151 }
152 } else {
153 for (&row_id, dist) in row_ids.zip(dists) {
154 let dist = dist.into();
155 push_candidate_local(&mut res, k, row_id, dist);
156 }
157 }
158 }
159 false => {
160 let row_addr_mask = prefilter.mask();
161 if is_range_query {
162 let lower_bound = params.lower_bound.unwrap_or(f32::MIN).into();
163 let upper_bound = params.upper_bound.unwrap_or(f32::MAX).into();
164 for (id, &row_addr) in row_ids.enumerate() {
165 if !row_addr_mask.selected(row_addr) {
166 continue;
167 }
168 let dist = dist_calc.distance(id as u32).into();
169 if dist < lower_bound || dist >= upper_bound {
170 continue;
171 }
172
173 push_candidate_local(&mut res, k, row_addr, dist);
174 }
175 } else {
176 for (id, &row_addr) in row_ids.enumerate() {
177 if !row_addr_mask.selected(row_addr) {
178 continue;
179 }
180
181 let dist = dist_calc.distance(id as u32).into();
182 push_candidate_local(&mut res, k, row_addr, dist);
183 }
184 }
185 }
186 };
187
188 let (row_ids, dists): (Vec<_>, Vec<_>) = res.into_iter().map(|r| (r.id, r.dist.0)).unzip();
191 let (row_ids, dists) = (UInt64Array::from(row_ids), Float32Array::from(dists));
192
193 Ok(RecordBatch::try_new(
194 ANN_SEARCH_SCHEMA.clone(),
195 vec![Arc::new(dists), Arc::new(row_ids)],
196 )?)
197 }
198
199 fn supports_global_topk_heap() -> bool {
200 true
201 }
202
203 fn accumulate_topk(
204 &self,
205 query: ArrayRef,
206 k: usize,
207 params: Self::QueryParams,
208 storage: &impl VectorStore,
209 prefilter: Arc<dyn PreFilter>,
210 res: &mut BinaryHeap<OrderedNode<u64>>,
211 metrics: &dyn MetricsCollector,
212 ) -> Result<()> {
213 let mut distance_scratch = Vec::new();
214 let mut u16_scratch = Vec::new();
215 let mut u8_scratch = Vec::new();
216 self.accumulate_topk_with_scratch(
217 query,
218 k,
219 params,
220 storage,
221 prefilter,
222 res,
223 &mut distance_scratch,
224 &mut u16_scratch,
225 &mut u8_scratch,
226 metrics,
227 )
228 }
229
230 fn accumulate_topk_with_scratch(
231 &self,
232 query: ArrayRef,
233 k: usize,
234 params: Self::QueryParams,
235 storage: &impl VectorStore,
236 prefilter: Arc<dyn PreFilter>,
237 res: &mut BinaryHeap<OrderedNode<u64>>,
238 distance_scratch: &mut Vec<f32>,
239 u16_scratch: &mut Vec<u16>,
240 u8_scratch: &mut Vec<u8>,
241 metrics: &dyn MetricsCollector,
242 ) -> Result<()> {
243 let is_range_query = params.lower_bound.is_some() || params.upper_bound.is_some();
244 let row_ids = storage.row_ids();
245 let dist_calc = storage.dist_calculator(query, params.dist_q_c);
246 let mut max_dist = res.peek().map(|node| node.dist);
247 metrics.record_comparisons(storage.len());
248
249 match prefilter.is_empty() {
250 true => {
251 dist_calc.distance_all_with_scratch(k, distance_scratch, u16_scratch, u8_scratch);
252 let dists = distance_scratch.iter().copied();
253
254 if is_range_query {
255 let lower_bound = params.lower_bound.unwrap_or(f32::MIN).into();
256 let upper_bound = params.upper_bound.unwrap_or(f32::MAX).into();
257
258 for (&row_id, dist) in row_ids.zip(dists) {
259 let dist = dist.into();
260 if dist < lower_bound || dist >= upper_bound {
261 continue;
262 }
263 push_candidate_global(res, k, row_id, dist, &mut max_dist);
264 }
265 } else {
266 for (&row_id, dist) in row_ids.zip(dists) {
267 let dist = dist.into();
268 push_candidate_global(res, k, row_id, dist, &mut max_dist);
269 }
270 }
271 }
272 false => {
273 let row_addr_mask = prefilter.mask();
274 if is_range_query {
275 let lower_bound = params.lower_bound.unwrap_or(f32::MIN).into();
276 let upper_bound = params.upper_bound.unwrap_or(f32::MAX).into();
277 for (id, &row_addr) in row_ids.enumerate() {
278 if !row_addr_mask.selected(row_addr) {
279 continue;
280 }
281 let dist = dist_calc.distance(id as u32).into();
282 if dist < lower_bound || dist >= upper_bound {
283 continue;
284 }
285
286 push_candidate_global(res, k, row_addr, dist, &mut max_dist);
287 }
288 } else {
289 for (id, &row_addr) in row_ids.enumerate() {
290 if !row_addr_mask.selected(row_addr) {
291 continue;
292 }
293 let dist = dist_calc.distance(id as u32).into();
294 push_candidate_global(res, k, row_addr, dist, &mut max_dist);
295 }
296 }
297 }
298 };
299 Ok(())
300 }
301
302 fn load(_: RecordBatch) -> Result<Self> {
303 Ok(Self {})
304 }
305
306 fn index_vectors(_: &impl VectorStore, _: Self::BuildParams) -> Result<Self>
307 where
308 Self: Sized,
309 {
310 Ok(Self {})
311 }
312
313 fn remap(&self, _: &HashMap<u64, Option<u64>>, _: &impl VectorStore) -> Result<Self> {
314 Ok(self.clone())
315 }
316
317 fn to_batch(&self) -> Result<RecordBatch> {
318 Ok(RecordBatch::new_empty(Schema::empty().into()))
319 }
320}
321
322#[derive(Debug, Clone, Serialize, Deserialize, DeepSizeOf)]
323pub struct FlatMetadata {
324 pub dim: usize,
325}
326
327#[async_trait::async_trait]
328impl QuantizerMetadata for FlatMetadata {
329 async fn load(_: &PreviousFileReader) -> Result<Self> {
330 unimplemented!("Flat will be used in new index builder which doesn't require this")
331 }
332}
333
334#[derive(Debug, Clone, DeepSizeOf)]
335pub struct FlatQuantizer {
336 dim: usize,
337 distance_type: DistanceType,
338}
339
340impl FlatQuantizer {
341 pub fn new(dim: usize, distance_type: DistanceType) -> Self {
342 Self { dim, distance_type }
343 }
344}
345
346impl Quantization for FlatQuantizer {
347 type BuildParams = ();
348 type Metadata = FlatMetadata;
349 type Storage = FlatFloatStorage;
350
351 fn build(data: &dyn Array, distance_type: DistanceType, _: &Self::BuildParams) -> Result<Self> {
352 let dim = data.as_fixed_size_list().value_length();
353 Ok(Self::new(dim as usize, distance_type))
354 }
355
356 fn retrain(&mut self, _: &dyn Array) -> Result<()> {
357 Ok(())
358 }
359
360 fn code_dim(&self) -> usize {
361 self.dim
362 }
363
364 fn column(&self) -> &'static str {
365 FLAT_COLUMN
366 }
367
368 fn from_metadata(metadata: &Self::Metadata, distance_type: DistanceType) -> Result<Quantizer> {
369 Ok(Quantizer::Flat(Self {
370 dim: metadata.dim,
371 distance_type,
372 }))
373 }
374
375 fn metadata(&self, _: Option<crate::vector::quantizer::QuantizationMetadata>) -> FlatMetadata {
376 FlatMetadata { dim: self.dim }
377 }
378
379 fn metadata_key() -> &'static str {
380 "flat"
381 }
382
383 fn quantization_type() -> QuantizationType {
384 QuantizationType::Flat
385 }
386
387 fn quantize(&self, vectors: &dyn Array) -> Result<ArrayRef> {
388 Ok(vectors.slice(0, vectors.len()))
389 }
390
391 fn field(&self) -> Field {
392 Field::new(
393 FLAT_COLUMN,
394 DataType::FixedSizeList(
395 Arc::new(Field::new("item", DataType::Float32, true)),
396 self.dim as i32,
397 ),
398 true,
399 )
400 }
401}
402
403impl From<FlatQuantizer> for Quantizer {
404 fn from(value: FlatQuantizer) -> Self {
405 Self::Flat(value)
406 }
407}
408
409impl TryFrom<Quantizer> for FlatQuantizer {
410 type Error = Error;
411
412 fn try_from(value: Quantizer) -> Result<Self> {
413 match value {
414 Quantizer::Flat(quantizer) => Ok(quantizer),
415 _ => Err(Error::invalid_input("quantizer is not FlatQuantizer")),
416 }
417 }
418}
419
420#[derive(Debug, Clone, DeepSizeOf)]
421pub struct FlatBinQuantizer {
422 dim: usize,
423 distance_type: DistanceType,
424}
425
426impl FlatBinQuantizer {
427 pub fn new(dim: usize, distance_type: DistanceType) -> Self {
428 Self { dim, distance_type }
429 }
430}
431
432impl Quantization for FlatBinQuantizer {
433 type BuildParams = ();
434 type Metadata = FlatMetadata;
435 type Storage = FlatBinStorage;
436
437 fn build(data: &dyn Array, distance_type: DistanceType, _: &Self::BuildParams) -> Result<Self> {
438 let dim = data.as_fixed_size_list().value_length();
439 Ok(Self::new(dim as usize, distance_type))
440 }
441
442 fn retrain(&mut self, _: &dyn Array) -> Result<()> {
443 Ok(())
444 }
445
446 fn code_dim(&self) -> usize {
447 self.dim
448 }
449
450 fn column(&self) -> &'static str {
451 FLAT_COLUMN
452 }
453
454 fn from_metadata(metadata: &Self::Metadata, distance_type: DistanceType) -> Result<Quantizer> {
455 Ok(Quantizer::FlatBin(Self {
456 dim: metadata.dim,
457 distance_type,
458 }))
459 }
460
461 fn metadata(&self, _: Option<crate::vector::quantizer::QuantizationMetadata>) -> FlatMetadata {
462 FlatMetadata { dim: self.dim }
463 }
464
465 fn metadata_key() -> &'static str {
466 "flat"
467 }
468
469 fn quantization_type() -> QuantizationType {
470 QuantizationType::FlatBin
471 }
472
473 fn quantize(&self, vectors: &dyn Array) -> Result<ArrayRef> {
474 Ok(vectors.slice(0, vectors.len()))
475 }
476
477 fn field(&self) -> Field {
478 Field::new(
479 FLAT_COLUMN,
480 DataType::FixedSizeList(
481 Arc::new(Field::new("item", DataType::UInt8, true)),
482 self.dim as i32,
483 ),
484 true,
485 )
486 }
487}
488
489impl From<FlatBinQuantizer> for Quantizer {
490 fn from(value: FlatBinQuantizer) -> Self {
491 Self::FlatBin(value)
492 }
493}
494
495impl TryFrom<Quantizer> for FlatBinQuantizer {
496 type Error = Error;
497
498 fn try_from(value: Quantizer) -> Result<Self> {
499 match value {
500 Quantizer::FlatBin(quantizer) => Ok(quantizer),
501 _ => Err(Error::invalid_input("quantizer is not FlatBinQuantizer")),
502 }
503 }
504}
505
506#[cfg(test)]
507mod tests {
508 use super::*;
509
510 use arrow_array::FixedSizeListArray;
511 use async_trait::async_trait;
512 use lance_arrow::FixedSizeListArrayExt;
513 use lance_core::utils::mask::{RowAddrMask, RowAddrTreeMap};
514
515 use crate::metrics::NoOpMetricsCollector;
516 use crate::prefilter::NoFilter;
517
518 struct MaskPreFilter {
519 mask: Arc<RowAddrMask>,
520 }
521
522 #[async_trait]
523 impl PreFilter for MaskPreFilter {
524 async fn wait_for_ready(&self) -> Result<()> {
525 Ok(())
526 }
527
528 fn is_empty(&self) -> bool {
529 false
530 }
531
532 fn mask(&self) -> Arc<RowAddrMask> {
533 self.mask.clone()
534 }
535
536 fn filter_row_ids<'a>(&self, row_ids: Box<dyn Iterator<Item = &'a u64> + 'a>) -> Vec<u64> {
537 self.mask.selected_indices(row_ids)
538 }
539 }
540
541 fn test_storage() -> FlatFloatStorage {
542 let values = Float32Array::from(vec![
543 0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 3.0, 3.0, 4.0, 4.0, ]);
549 let vectors = FixedSizeListArray::try_new_from_values(values, 2).unwrap();
550 FlatFloatStorage::new(vectors, DistanceType::L2)
551 }
552
553 fn query() -> ArrayRef {
554 Arc::new(Float32Array::from(vec![1.0, 1.0]))
555 }
556
557 fn batch_results(batch: RecordBatch) -> Vec<(u64, f32)> {
558 let dists = batch
559 .column(0)
560 .as_primitive::<arrow_array::types::Float32Type>();
561 let row_ids = batch
562 .column(1)
563 .as_primitive::<arrow_array::types::UInt64Type>();
564 let mut results = row_ids
565 .values()
566 .iter()
567 .zip(dists.values().iter())
568 .map(|(row_id, dist)| (*row_id, *dist))
569 .collect::<Vec<_>>();
570 results.sort_by(|left, right| left.0.cmp(&right.0));
571 results
572 }
573
574 fn heap_results(heap: BinaryHeap<OrderedNode<u64>>) -> Vec<(u64, f32)> {
575 let mut results = heap
576 .into_iter()
577 .map(|node| (node.id, node.dist.0))
578 .collect::<Vec<_>>();
579 results.sort_by(|left, right| left.0.cmp(&right.0));
580 results
581 }
582
583 #[test]
584 fn test_flat_search_matches_accumulate_topk_without_prefilter() {
585 let index = FlatIndex::default();
586 let storage = test_storage();
587 let k = 3;
588 let search_results = batch_results(
589 index
590 .search(
591 query(),
592 k,
593 FlatQueryParams::default(),
594 &storage,
595 Arc::new(NoFilter),
596 &NoOpMetricsCollector,
597 )
598 .unwrap(),
599 );
600
601 let mut heap = BinaryHeap::with_capacity(k);
602 index
603 .accumulate_topk(
604 query(),
605 k,
606 FlatQueryParams::default(),
607 &storage,
608 Arc::new(NoFilter),
609 &mut heap,
610 &NoOpMetricsCollector,
611 )
612 .unwrap();
613
614 assert_eq!(search_results, heap_results(heap));
615 }
616
617 #[test]
618 fn test_flat_search_matches_accumulate_topk_with_prefilter() {
619 let index = FlatIndex::default();
620 let storage = test_storage();
621 let k = 2;
622 let filter = Arc::new(MaskPreFilter {
623 mask: Arc::new(RowAddrMask::from_allowed(RowAddrTreeMap::from_iter([
624 0_u64, 3, 4,
625 ]))),
626 });
627 let search_results = batch_results(
628 index
629 .search(
630 query(),
631 k,
632 FlatQueryParams::default(),
633 &storage,
634 filter.clone(),
635 &NoOpMetricsCollector,
636 )
637 .unwrap(),
638 );
639
640 let mut heap = BinaryHeap::with_capacity(k);
641 index
642 .accumulate_topk(
643 query(),
644 k,
645 FlatQueryParams::default(),
646 &storage,
647 filter,
648 &mut heap,
649 &NoOpMetricsCollector,
650 )
651 .unwrap();
652
653 assert_eq!(search_results, heap_results(heap));
654 assert_eq!(
655 search_results.iter().map(|(id, _)| *id).collect::<Vec<_>>(),
656 vec![0, 3]
657 );
658 }
659}