lance/io/exec/
knn.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4use std::any::Any;
5use std::collections::{HashMap, HashSet};
6use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
7use std::sync::{Arc, LazyLock, Mutex};
8use std::time::Instant;
9
10use arrow::array::Float32Builder;
11use arrow::datatypes::{Float32Type, UInt32Type, UInt64Type};
12use arrow_array::{
13    builder::{ListBuilder, UInt32Builder},
14    cast::AsArray,
15    ArrayRef, RecordBatch, StringArray,
16};
17use arrow_array::{Array, Float32Array, UInt32Array, UInt64Array};
18use arrow_schema::{DataType, Field, Schema, SchemaRef};
19use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
20use datafusion::physical_plan::PlanProperties;
21use datafusion::physical_plan::{
22    DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, SendableRecordBatchStream,
23    Statistics,
24};
25use datafusion::{
26    common::stats::Precision,
27    physical_plan::execution_plan::{Boundedness, EmissionType},
28};
29use datafusion::{common::ColumnStatistics, physical_plan::metrics::ExecutionPlanMetricsSet};
30use datafusion::{
31    error::{DataFusionError, Result as DataFusionResult},
32    physical_plan::metrics::MetricsSet,
33};
34use datafusion_physical_expr::{Distribution, EquivalenceProperties};
35use datafusion_physical_plan::metrics::{BaselineMetrics, Count};
36use futures::{future, stream, Stream, StreamExt, TryFutureExt, TryStreamExt};
37use itertools::Itertools;
38use lance_core::utils::futures::FinallyStreamExt;
39use lance_core::ROW_ID;
40use lance_core::{utils::tokio::get_num_compute_intensive_cpus, ROW_ID_FIELD};
41use lance_datafusion::utils::{
42    ExecutionPlanMetricsSetExt, DELTAS_SEARCHED_METRIC, PARTITIONS_RANKED_METRIC,
43    PARTITIONS_SEARCHED_METRIC,
44};
45use lance_index::prefilter::PreFilter;
46use lance_index::vector::{
47    flat::compute_distance, Query, DIST_COL, INDEX_UUID_COLUMN, PART_ID_COLUMN,
48};
49use lance_index::vector::{VectorIndex, DIST_Q_C_COLUMN};
50use lance_linalg::distance::DistanceType;
51use lance_linalg::kernels::normalize_arrow;
52use lance_table::format::IndexMetadata;
53use snafu::location;
54use tokio::sync::Notify;
55
56use crate::dataset::Dataset;
57use crate::index::prefilter::{DatasetPreFilter, FilterLoader};
58use crate::index::vector::utils::get_vector_type;
59use crate::index::DatasetIndexInternalExt;
60use crate::{Error, Result};
61use lance_arrow::*;
62
63use super::utils::{
64    FilteredRowIdsToPrefilter, IndexMetrics, InstrumentedRecordBatchStreamAdapter, PreFilterSource,
65    SelectionVectorToPrefilter,
66};
67
/// Metrics recorded by the partition-ranking stage ([ANNIvfPartitionExec]).
pub struct AnnPartitionMetrics {
    /// Index-level metrics, passed to `open_vector_index` so index loading is attributed here.
    index_metrics: IndexMetrics,
    /// Total number of partitions ranked (incremented by `total_partitions()` per delta).
    partitions_ranked: Count,
    /// Number of index deltas searched by this node.
    deltas_searched: Count,
    /// Standard DataFusion baseline metrics (output rows, elapsed compute, ...).
    baseline_metrics: BaselineMetrics,
}
74
75impl AnnPartitionMetrics {
76    pub fn new(metrics: &ExecutionPlanMetricsSet, partition: usize) -> Self {
77        Self {
78            index_metrics: IndexMetrics::new(metrics, partition),
79            partitions_ranked: metrics.new_count(PARTITIONS_RANKED_METRIC, partition),
80            deltas_searched: metrics.new_count(DELTAS_SEARCHED_METRIC, partition),
81            baseline_metrics: BaselineMetrics::new(metrics, partition),
82        }
83    }
84}
85
/// Metrics recorded by the sub-index search stage ([ANNIvfSubIndexExec]).
pub struct AnnIndexMetrics {
    /// Index-level metrics, passed to `open_vector_index` / `search_in_partition`.
    index_metrics: IndexMetrics,
    /// Number of IVF partitions actually searched.
    partitions_searched: Count,
    /// Standard DataFusion baseline metrics (output rows, elapsed compute, ...).
    baseline_metrics: BaselineMetrics,
}
91
92impl AnnIndexMetrics {
93    pub fn new(metrics: &ExecutionPlanMetricsSet, partition: usize) -> Self {
94        Self {
95            index_metrics: IndexMetrics::new(metrics, partition),
96            partitions_searched: metrics.new_count(PARTITIONS_SEARCHED_METRIC, partition),
97            baseline_metrics: BaselineMetrics::new(metrics, partition),
98        }
99    }
100}
101
/// [ExecutionPlan] compute vector distance from a query vector.
///
/// Preconditions:
/// - `input` schema must contains `query.column`,
/// - The column must be a vector column.
///
/// WARNING: Internal API with no stability guarantees.
#[derive(Debug)]
pub struct KNNVectorDistanceExec {
    /// Inner input node.
    pub input: Arc<dyn ExecutionPlan>,

    /// The vector query to execute.
    pub query: ArrayRef,
    /// Name of the vector column to compute distances against.
    pub column: String,
    /// Distance metric used to score each row against `query`.
    pub distance_type: DistanceType,

    /// Input schema with any pre-existing distance column replaced by a fresh
    /// nullable Float32 distance column appended at the end.
    output_schema: SchemaRef,
    /// Cached plan properties: same partitioning/boundedness as the input,
    /// but with all orderings discarded.
    properties: PlanProperties,

    /// Execution metrics for this node.
    metrics: ExecutionPlanMetricsSet,
}
124
125impl DisplayAs for KNNVectorDistanceExec {
126    fn fmt_as(&self, t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result {
127        match t {
128            DisplayFormatType::Default | DisplayFormatType::Verbose => {
129                write!(f, "KNNVectorDistance: metric={}", self.distance_type,)
130            }
131            DisplayFormatType::TreeRender => {
132                write!(f, "KNNVectorDistance\nmetric={}", self.distance_type,)
133            }
134        }
135    }
136}
137
impl KNNVectorDistanceExec {
    /// Create a new [KNNVectorDistanceExec] node.
    ///
    /// `column` must exist in `input`'s schema and be a vector column;
    /// `query` is the query vector and `distance_type` the metric to use.
    ///
    /// Returns an error if the preconditions are not met.
    pub fn try_new(
        input: Arc<dyn ExecutionPlan>,
        column: &str,
        query: ArrayRef,
        distance_type: DistanceType,
    ) -> Result<Self> {
        let mut output_schema = input.schema().as_ref().clone();
        // Validates the precondition: `column` exists and is a vector column.
        get_vector_type(&(&output_schema).try_into()?, column)?;

        // This node appends a distance column to the input schema. The input
        // may already have a distance column (possibly in the wrong position), so
        // we need to remove it before adding a new one.
        if output_schema.column_with_name(DIST_COL).is_some() {
            output_schema = output_schema.without_column(DIST_COL);
        }
        // The distance column is nullable (e.g. rows that cannot be scored).
        let output_schema = Arc::new(output_schema.try_with_column(Field::new(
            DIST_COL,
            DataType::Float32,
            true,
        ))?);

        // This node has the same partitioning & boundedness as the input node
        // but it destroys any ordering (distances are computed concurrently,
        // see `execute`).
        let properties = input
            .properties()
            .clone()
            .with_eq_properties(EquivalenceProperties::new(output_schema.clone()));

        Ok(Self {
            input,
            query,
            column: column.to_string(),
            distance_type,
            output_schema,
            properties,
            metrics: ExecutionPlanMetricsSet::new(),
        })
    }
}
181
impl ExecutionPlan for KNNVectorDistanceExec {
    fn name(&self) -> &str {
        "KNNVectorDistanceExec"
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Flat KNN inherits the schema from input node, and add one distance column.
    fn schema(&self) -> arrow_schema::SchemaRef {
        self.output_schema.clone()
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![&self.input]
    }

    fn with_new_children(
        self: Arc<Self>,
        mut children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> DataFusionResult<Arc<dyn ExecutionPlan>> {
        if children.len() != 1 {
            return Err(DataFusionError::Internal(
                "KNNVectorDistanceExec node must have exactly one child".to_string(),
            ));
        }

        // Rebuild via try_new so the output schema / properties are re-derived
        // from the new child.
        Ok(Arc::new(Self::try_new(
            children.pop().expect("length checked"),
            &self.column,
            self.query.clone(),
            self.distance_type,
        )?))
    }

    fn execute(
        &self,
        partition: usize,
        context: Arc<datafusion::execution::context::TaskContext>,
    ) -> DataFusionResult<SendableRecordBatchStream> {
        let input_stream = self.input.execute(partition, context)?;
        let key = self.query.clone();
        let column = self.column.clone();
        let dt = self.distance_type;
        let stream = input_stream
            // Empty batches carry no rows to score; drop them early.
            .try_filter(|batch| future::ready(batch.num_rows() > 0))
            .map(move |batch| {
                let key = key.clone();
                let column = column.clone();
                async move {
                    compute_distance(key, dt, &column, batch?)
                        .await
                        .map_err(|e| DataFusionError::Execution(e.to_string()))
                }
            })
            // Score batches concurrently on the compute pool.  `buffer_unordered`
            // means output batch order is not preserved, which is why this node
            // declares no output ordering in its properties.
            .buffer_unordered(get_num_compute_intensive_cpus());
        let schema = self.schema();
        Ok(Box::pin(InstrumentedRecordBatchStreamAdapter::new(
            schema,
            stream.boxed(),
            partition,
            &self.metrics,
        )) as SendableRecordBatchStream)
    }

    fn partition_statistics(&self, partition: Option<usize>) -> DataFusionResult<Statistics> {
        let inner_stats = self.input.partition_statistics(partition)?;
        let schema = self.input.schema();
        // Derive the distance column's statistics from the vector column:
        // only the null count carries over (a row without a vector presumably
        // gets a null distance — NOTE(review): confirm against compute_distance).
        let dist_stats = inner_stats
            .column_statistics
            .iter()
            .zip(schema.fields())
            .find(|(_, field)| field.name() == &self.column)
            .map(|(stats, _)| ColumnStatistics {
                null_count: stats.null_count,
                ..Default::default()
            })
            .unwrap_or_default();
        // Keep the input columns (minus any pre-existing distance column, which
        // try_new removed from the output schema) and append the distance
        // column's stats last, matching the output schema's column order.
        let column_statistics = inner_stats
            .column_statistics
            .into_iter()
            .zip(schema.fields())
            .filter(|(_, field)| field.name() != DIST_COL)
            .map(|(stats, _)| stats)
            .chain(std::iter::once(dist_stats))
            .collect::<Vec<_>>();
        Ok(Statistics {
            // Row count is unchanged: one distance per input row.
            num_rows: inner_stats.num_rows,
            column_statistics,
            ..Statistics::new_unknown(self.schema().as_ref())
        })
    }

    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }

    fn properties(&self) -> &PlanProperties {
        &self.properties
    }

    fn supports_limit_pushdown(&self) -> bool {
        false
    }
}
288
/// Schema of batches produced by a KNN index search: the distance column
/// ([DIST_COL], nullable Float32) followed by the row-id column.
pub static KNN_INDEX_SCHEMA: LazyLock<SchemaRef> = LazyLock::new(|| {
    Arc::new(Schema::new(vec![
        Field::new(DIST_COL, DataType::Float32, true),
        ROW_ID_FIELD.clone(),
    ]))
});
295
/// Schema of batches produced by [ANNIvfPartitionExec]: one row per index
/// delta, containing the ranked partition ids, the query-to-centroid
/// distances (parallel to the partition-id list), and the delta index uuid.
pub static KNN_PARTITION_SCHEMA: LazyLock<SchemaRef> = LazyLock::new(|| {
    Arc::new(Schema::new(vec![
        Field::new(
            PART_ID_COLUMN,
            DataType::List(Field::new("item", DataType::UInt32, false).into()),
            false,
        ),
        Field::new(
            DIST_Q_C_COLUMN,
            DataType::List(Field::new("item", DataType::Float32, false).into()),
            false,
        ),
        Field::new(INDEX_UUID_COLUMN, DataType::Utf8, false),
    ]))
});
311
312pub fn new_knn_exec(
313    dataset: Arc<Dataset>,
314    indices: &[IndexMetadata],
315    query: &Query,
316    prefilter_source: PreFilterSource,
317) -> Result<Arc<dyn ExecutionPlan>> {
318    let ivf_node = ANNIvfPartitionExec::try_new(
319        dataset.clone(),
320        indices.iter().map(|idx| idx.uuid.to_string()).collect_vec(),
321        query.clone(),
322    )?;
323
324    let sub_index = ANNIvfSubIndexExec::try_new(
325        Arc::new(ivf_node),
326        dataset,
327        indices.to_vec(),
328        query.clone(),
329        prefilter_source,
330    )?;
331
332    Ok(Arc::new(sub_index))
333}
334
/// [ExecutionPlan] to find the closest IVF partitions.
///
/// It searches the partition IDs using the input query.
///
/// It searches all index deltas in parallel.  For each delta it returns a
/// single batch with the partition IDs and the delta index `uuid`:
///
/// The number of partitions returned is at most `maximum_nprobes`.  If
/// `maximum_nprobes` is not set, it will return all partitions.  The partitions
/// are returned in sorted order from closest to farthest.
///
/// Typically, all partition ids will be identical for each delta index (since delta
/// indices have identical partitions) but the downstream nodes do not rely on this.
///
/// TODO: We may want to search the partitions once instead of once per delta index
/// since the centroids are the same.
///
/// ```text
/// {
///    "__ivf_part_id": List<UInt32>,
///    "__index_uuid": String,
/// }
/// ```
#[derive(Debug)]
pub struct ANNIvfPartitionExec {
    /// The dataset that owns the indices being searched.
    pub dataset: Arc<Dataset>,

    /// The vector query to execute.
    pub query: Query,

    /// The UUIDs of the indices to search.
    pub index_uuids: Vec<String>,

    /// Cached plan properties (single output partition, bounded).
    pub properties: PlanProperties,

    /// Execution metrics (deltas searched, partitions ranked, ...).
    pub metrics: ExecutionPlanMetricsSet,
}
372
impl ANNIvfPartitionExec {
    /// Create a new partition-ranking node.
    ///
    /// # Errors
    ///
    /// Returns an error if `query.column` is not a vector column of the
    /// dataset, or if `index_uuids` is empty.
    pub fn try_new(dataset: Arc<Dataset>, index_uuids: Vec<String>, query: Query) -> Result<Self> {
        let dataset_schema = dataset.schema();
        // Validates that the query column exists and is a vector column.
        get_vector_type(dataset_schema, &query.column)?;
        if index_uuids.is_empty() {
            return Err(Error::Execution {
                message: "ANNIVFPartitionExec node: no index found for query".to_string(),
                location: location!(),
            });
        }

        let schema = KNN_PARTITION_SCHEMA.clone();
        // Single output partition; execute() emits batches incrementally as
        // each delta finishes.
        let properties = PlanProperties::new(
            EquivalenceProperties::new(schema),
            Partitioning::RoundRobinBatch(1),
            EmissionType::Incremental,
            Boundedness::Bounded,
        );

        Ok(Self {
            dataset,
            query,
            index_uuids,
            properties,
            metrics: ExecutionPlanMetricsSet::new(),
        })
    }
}
401
402impl DisplayAs for ANNIvfPartitionExec {
403    fn fmt_as(&self, t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result {
404        match t {
405            DisplayFormatType::Default | DisplayFormatType::Verbose => {
406                write!(
407                    f,
408                    "ANNIvfPartition: uuid={}, minimum_nprobes={}, maximum_nprobes={:?}, deltas={}",
409                    self.index_uuids[0],
410                    self.query.minimum_nprobes,
411                    self.query.maximum_nprobes,
412                    self.index_uuids.len()
413                )
414            }
415            DisplayFormatType::TreeRender => {
416                write!(
417                    f,
418                    "ANNIvfPartition\nuuid={}\nminimum_nprobes={}\nmaximum_nprobes={:?}\ndeltas={}",
419                    self.index_uuids[0],
420                    self.query.minimum_nprobes,
421                    self.query.maximum_nprobes,
422                    self.index_uuids.len()
423                )
424            }
425        }
426    }
427}
428
impl ExecutionPlan for ANNIvfPartitionExec {
    fn name(&self) -> &str {
        "ANNIVFPartitionExec"
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    fn schema(&self) -> SchemaRef {
        KNN_PARTITION_SCHEMA.clone()
    }

    fn statistics(&self) -> DataFusionResult<Statistics> {
        Ok(Statistics {
            // NOTE(review): execute() emits exactly one row per index delta,
            // so the exact row count looks like `self.index_uuids.len()`,
            // not `minimum_nprobes` — confirm intent before changing.
            num_rows: Precision::Exact(self.query.minimum_nprobes),
            ..Statistics::new_unknown(self.schema().as_ref())
        })
    }

    fn properties(&self) -> &PlanProperties {
        &self.properties
    }

    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![]
    }

    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> DataFusionResult<Arc<dyn ExecutionPlan>> {
        if !children.is_empty() {
            Err(DataFusionError::Internal(
                "ANNIVFPartitionExec node does not accept children".to_string(),
            ))
        } else {
            Ok(self)
        }
    }

    /// Rank IVF partitions for every index delta in parallel; one single-row
    /// output batch per delta (see [KNN_PARTITION_SCHEMA]).
    fn execute(
        &self,
        partition: usize,
        _context: Arc<datafusion::execution::TaskContext>,
    ) -> DataFusionResult<SendableRecordBatchStream> {
        let timer = Instant::now();

        let query = self.query.clone();
        let ds = self.dataset.clone();
        let metrics = Arc::new(AnnPartitionMetrics::new(&self.metrics, partition));
        metrics.deltas_searched.add(self.index_uuids.len());
        let metrics_clone = metrics.clone();

        let stream = stream::iter(self.index_uuids.clone())
            .map(move |uuid| {
                let query = query.clone();
                let ds = ds.clone();
                let metrics = metrics.clone();
                async move {
                    let index = ds
                        .open_vector_index(&query.column, &uuid, &metrics.index_metrics)
                        .await?;
                    // Cosine indices store normalized vectors, so the query
                    // key must be normalized to match.
                    let mut query = query.clone();
                    if index.metric_type() == DistanceType::Cosine {
                        let key = normalize_arrow(&query.key)?.0;
                        query.key = key;
                    };

                    metrics.partitions_ranked.add(index.total_partitions());

                    let (partitions, dist_q_c) = index.find_partitions(&query).map_err(|e| {
                        DataFusionError::Execution(format!("Failed to find partitions: {}", e))
                    })?;

                    // Pack the ranked partition ids into a single list cell.
                    let mut part_list_builder = ListBuilder::new(UInt32Builder::new())
                        .with_field(Field::new("item", DataType::UInt32, false));
                    part_list_builder.append_value(partitions.iter());
                    let partition_col = part_list_builder.finish();

                    // Query-to-centroid distances, parallel to the id list.
                    let mut dist_q_c_list_builder = ListBuilder::new(Float32Builder::new())
                        .with_field(Field::new("item", DataType::Float32, false));
                    dist_q_c_list_builder.append_value(dist_q_c.iter());
                    let dist_q_c_col = dist_q_c_list_builder.finish();

                    let uuid_col = StringArray::from(vec![uuid.as_str()]);
                    let batch = RecordBatch::try_new(
                        KNN_PARTITION_SCHEMA.clone(),
                        vec![
                            Arc::new(partition_col),
                            Arc::new(dist_q_c_col),
                            Arc::new(uuid_col),
                        ],
                    )?;
                    metrics.baseline_metrics.record_output(batch.num_rows());
                    Ok::<_, DataFusionError>(batch)
                }
            })
            // All deltas run concurrently; `buffered` preserves delta order.
            .buffered(self.index_uuids.len())
            // Record wall-clock elapsed once the stream is fully consumed.
            .finally(move || {
                metrics_clone.baseline_metrics.done();
                metrics_clone
                    .baseline_metrics
                    .elapsed_compute()
                    .add_duration(timer.elapsed());
            });
        let schema = self.schema();
        Ok(
            Box::pin(RecordBatchStreamAdapter::new(schema, stream.boxed()))
                as SendableRecordBatchStream,
        )
    }

    fn supports_limit_pushdown(&self) -> bool {
        false
    }
}
550
/// Datafusion [ExecutionPlan] to run search on vector index partitions.
///
/// A IVF-{PQ/SQ/HNSW} query plan is:
///
/// ```text
/// AnnSubIndexExec: k=10
///   AnnPartitionExec: nprobes=20
/// ```
///
/// The partition index returns one batch per delta with `maximum_nprobes` partitions in sorted order.
///
/// The sub-index then runs a KNN search on each partition in parallel.
///
/// First, the index will search `minimum_nprobes` partitions on each delta.  If there are enough results
/// at that point to satisfy K then the results will be sorted and returned.
///
/// If there are not enough results then the prefilter will be consulted to determine how many potential
/// results exist.  If the number is smaller than K then those additional results will be fetched directly
/// and given maximum distance.
///
/// If the number of results is larger then additional partitions will be searched in batches of
/// `cpu_parallelism` until min(K, num_filtered_results) are found or `maximum_nprobes` partitions
/// have been searched.
///
/// TODO: In the future, if we can know that a filter will be highly selective (through cost estimation or
/// user-provided hints) we wait for the prefilter results before we load any partitions.  If there are less
/// than K (or some threshold) results then we can return without search.
#[derive(Debug)]
pub struct ANNIvfSubIndexExec {
    /// Inner input source node.
    input: Arc<dyn ExecutionPlan>,

    /// The dataset that owns the indices being searched.
    dataset: Arc<Dataset>,

    /// Metadata for each index delta to search.
    indices: Vec<IndexMetadata>,

    /// Vector Query.
    query: Query,

    /// Prefiltering input
    prefilter_source: PreFilterSource,

    /// Datafusion Plan Properties
    properties: PlanProperties,

    /// Execution metrics (partitions searched, output rows, ...).
    metrics: ExecutionPlanMetricsSet,
}
598
599impl ANNIvfSubIndexExec {
600    pub fn try_new(
601        input: Arc<dyn ExecutionPlan>,
602        dataset: Arc<Dataset>,
603        indices: Vec<IndexMetadata>,
604        query: Query,
605        prefilter_source: PreFilterSource,
606    ) -> Result<Self> {
607        if input.schema().field_with_name(PART_ID_COLUMN).is_err() {
608            return Err(Error::Index {
609                message: format!(
610                    "ANNSubIndexExec node: input schema does not have \"{}\" column",
611                    PART_ID_COLUMN
612                ),
613                location: location!(),
614            });
615        }
616        let properties = PlanProperties::new(
617            EquivalenceProperties::new(KNN_INDEX_SCHEMA.clone()),
618            Partitioning::RoundRobinBatch(1),
619            EmissionType::Final,
620            Boundedness::Bounded,
621        );
622        Ok(Self {
623            input,
624            dataset,
625            indices,
626            query,
627            prefilter_source,
628            properties,
629            metrics: ExecutionPlanMetricsSet::new(),
630        })
631    }
632}
633
634impl DisplayAs for ANNIvfSubIndexExec {
635    fn fmt_as(&self, t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result {
636        match t {
637            DisplayFormatType::Default | DisplayFormatType::Verbose => {
638                write!(
639                    f,
640                    "ANNSubIndex: name={}, k={}, deltas={}",
641                    self.indices[0].name,
642                    self.query.k * self.query.refine_factor.unwrap_or(1) as usize,
643                    self.indices.len()
644                )
645            }
646            DisplayFormatType::TreeRender => {
647                write!(
648                    f,
649                    "ANNSubIndex\nname={}\nk={}\ndeltas={}",
650                    self.indices[0].name,
651                    self.query.k * self.query.refine_factor.unwrap_or(1) as usize,
652                    self.indices.len()
653                )
654            }
655        }
656    }
657}
658
/// Shared state coordinating the initial (first `minimum_nprobes` partitions)
/// and late (remaining partitions) search phases across all index deltas.
struct ANNIvfEarlySearchResults {
    /// Number of results requested by the query.
    k: usize,
    /// Row ids collected during the initial phase (capped at `k`).
    initial_ids: Mutex<Vec<u64>>,
    /// Running count of results found; seeded from `initial_ids` when the
    /// last delta finishes its initial phase, then grown by late batches.
    num_results_found: AtomicUsize,
    /// Number of deltas that have not yet finished their initial phase.
    deltas_remaining: AtomicUsize,
    /// Signaled once every delta has finished its initial phase.
    all_deltas_done: Notify,
    /// Set by the first delta that takes the "fewer prefilter matches than k"
    /// shortcut, so the shortcut rows are emitted only once.
    took_no_rows_shortcut: AtomicBool,
}
667
668impl ANNIvfEarlySearchResults {
669    fn new(deltas_remaining: usize, k: usize) -> Self {
670        Self {
671            k,
672            initial_ids: Mutex::new(Vec::with_capacity(k)),
673            num_results_found: AtomicUsize::new(0),
674            deltas_remaining: AtomicUsize::new(deltas_remaining),
675            all_deltas_done: Notify::new(),
676            took_no_rows_shortcut: AtomicBool::new(false),
677        }
678    }
679
680    fn record_batch(&self, batch: &RecordBatch) {
681        let mut initial_ids = self.initial_ids.lock().unwrap();
682        let ids_to_record = (self.k - initial_ids.len()).min(batch.num_rows());
683        initial_ids.extend(
684            batch
685                .column(1)
686                .as_primitive::<UInt64Type>()
687                .values()
688                .iter()
689                .take(ids_to_record),
690        );
691    }
692
693    fn record_late_batch(&self, num_rows: usize) {
694        self.num_results_found
695            .fetch_add(num_rows, Ordering::Relaxed);
696    }
697
698    async fn wait_for_minimum_to_finish(&self) -> usize {
699        if self.deltas_remaining.fetch_sub(1, Ordering::Relaxed) == 1 {
700            {
701                let new_num_results_found = self.initial_ids.lock().unwrap().len();
702                self.num_results_found
703                    .store(new_num_results_found, Ordering::Relaxed);
704            }
705            self.all_deltas_done.notify_waiters();
706        } else {
707            self.all_deltas_done.notified().await;
708        }
709        self.num_results_found.load(Ordering::Relaxed)
710    }
711}
712
713impl ANNIvfSubIndexExec {
    /// The "late" search phase: searches partitions `minimum_nprobes..maximum_nprobes`
    /// for one index delta, but only if the initial phase did not find `k` results.
    ///
    /// Waits for the initial phase of *every* delta to finish before deciding.
    /// Yields nothing when all partitions were already searched or enough
    /// results were found.  If the prefilter admits fewer than `k` rows total,
    /// the first delta instead emits the unfound prefilter rows directly with
    /// infinite distance and no further partitions are searched.
    fn late_search(
        index: Arc<dyn VectorIndex>,
        query: Query,
        partitions: Arc<UInt32Array>,
        q_c_dists: Arc<Float32Array>,
        prefilter: Arc<DatasetPreFilter>,
        metrics: Arc<AnnIndexMetrics>,
        state: Arc<ANNIvfEarlySearchResults>,
    ) -> impl Stream<Item = DataFusionResult<RecordBatch>> {
        let stream = futures::stream::once(async move {
            // Without a maximum we may probe every ranked partition.
            let max_nprobes = query.maximum_nprobes.unwrap_or(partitions.len());
            if max_nprobes == query.minimum_nprobes {
                // We've already searched all partitions, no late search needed
                return futures::stream::empty().boxed();
            }

            let found_so_far = state.wait_for_minimum_to_finish().await;
            if found_so_far >= query.k {
                // We found enough results, no need for late search
                return futures::stream::empty().boxed();
            }

            // We know the prefilter should be ready at this point so we shouldn't
            // need to call wait_for_ready
            let prefilter_mask = prefilter.mask();

            // Upper bound on rows that can match the prefilter (None = unknown).
            let max_results = prefilter_mask.max_len().map(|x| x as usize);

            if let Some(max_results) = max_results {
                if found_so_far < max_results && max_results <= query.k {
                    // In this case there are fewer than k results matching the prefilter so
                    // just return the prefilter ids and don't bother searching any further

                    // This next if check should be true, because we wouldn't get max_results otherwise
                    if let Some(iter_ids) = prefilter_mask.iter_ids() {
                        // We only run this on the first delta because the prefilter mask is shared
                        // by all deltas and we don't want to duplicate the rows.
                        if state
                            .took_no_rows_shortcut
                            .compare_exchange(false, true, Ordering::Relaxed, Ordering::Relaxed)
                            .is_ok()
                        {
                            // Emit every prefilter-allowed row that the initial
                            // phase did not already find, with infinite distance
                            // so they sort after any real matches.
                            let initial_ids = state.initial_ids.lock().unwrap();
                            let found_ids = HashSet::<_>::from_iter(initial_ids.iter().copied());
                            drop(initial_ids);
                            let mask_ids = HashSet::from_iter(iter_ids.map(u64::from));
                            let not_found_ids = mask_ids.difference(&found_ids);
                            let not_found_ids =
                                UInt64Array::from_iter_values(not_found_ids.copied());
                            let not_found_distance =
                                Float32Array::from_value(f32::INFINITY, not_found_ids.len());
                            let not_found_batch = RecordBatch::try_new(
                                KNN_INDEX_SCHEMA.clone(),
                                vec![Arc::new(not_found_distance), Arc::new(not_found_ids)],
                            )
                            .unwrap();
                            return futures::stream::once(async move { Ok(not_found_batch) })
                                .boxed();
                        } else {
                            // We meet all the criteria for an early exit, but we aren't first
                            // delta so we just return an empty stream and skip the late search
                            return futures::stream::empty().boxed();
                        }
                    }
                }
            }

            // Stop searching if we have k results or we've found all the results
            // that could possibly match the prefilter
            let max_results = max_results.unwrap_or(usize::MAX).min(query.k);

            let state_clone = state.clone();

            futures::stream::iter(query.minimum_nprobes..max_nprobes)
                .map(move |idx| {
                    let part_id = partitions.value(idx);
                    let mut query = query.clone();
                    // Attach this partition's query-to-centroid distance.
                    query.dist_q_c = q_c_dists.value(idx);
                    let metrics = metrics.clone();
                    let pre_filter = prefilter.clone();
                    let state = state.clone();
                    let index = index.clone();
                    async move {
                        // Cosine indices store normalized vectors, so the
                        // query key must be normalized to match.
                        let mut query = query.clone();
                        if index.metric_type() == DistanceType::Cosine {
                            let key = normalize_arrow(&query.key)?.0;
                            query.key = key;
                        };

                        metrics.partitions_searched.add(1);
                        let batch = index
                            .search_in_partition(
                                part_id as usize,
                                &query,
                                pre_filter,
                                &metrics.index_metrics,
                            )
                            .map_err(|e| {
                                DataFusionError::Execution(format!(
                                    "Failed to calculate KNN: {}",
                                    e
                                ))
                            })
                            .await?;
                        metrics.baseline_metrics.record_output(batch.num_rows());
                        state.record_late_batch(batch.num_rows());
                        Ok(batch)
                    }
                })
                // take_while runs before `buffered`, so the running result
                // count is checked as each partition search is *enqueued*:
                // once enough results are in, no further searches start.
                .take_while(move |_| {
                    let found_so_far = state_clone.num_results_found.load(Ordering::Relaxed);
                    std::future::ready(found_so_far < max_results)
                })
                .buffered(get_num_compute_intensive_cpus())
                .boxed()
        });
        // The outer once() stream yields an inner stream; flatten to batches.
        stream.flatten()
    }
832
833    fn initial_search(
834        index: Arc<dyn VectorIndex>,
835        query: Query,
836        partitions: Arc<UInt32Array>,
837        q_c_dists: Arc<Float32Array>,
838        prefilter: Arc<DatasetPreFilter>,
839        metrics: Arc<AnnIndexMetrics>,
840        state: Arc<ANNIvfEarlySearchResults>,
841    ) -> impl Stream<Item = DataFusionResult<RecordBatch>> {
842        let minimum_nprobes = query.minimum_nprobes.min(partitions.len());
843        metrics.partitions_searched.add(minimum_nprobes);
844
845        futures::stream::iter(0..minimum_nprobes)
846            .map(move |idx| {
847                let part_id = partitions.value(idx);
848                let mut query = query.clone();
849                query.dist_q_c = q_c_dists.value(idx);
850                let metrics = metrics.clone();
851                let index = index.clone();
852                let pre_filter = prefilter.clone();
853                let state = state.clone();
854                async move {
855                    let mut query = query.clone();
856                    if index.metric_type() == DistanceType::Cosine {
857                        let key = normalize_arrow(&query.key)?.0;
858                        query.key = key;
859                    };
860
861                    let batch = index
862                        .search_in_partition(
863                            part_id as usize,
864                            &query,
865                            pre_filter,
866                            &metrics.index_metrics,
867                        )
868                        .map_err(|e| {
869                            DataFusionError::Execution(format!("Failed to calculate KNN: {}", e))
870                        })
871                        .await?;
872                    metrics.baseline_metrics.record_output(batch.num_rows());
873                    state.record_batch(&batch);
874                    Ok(batch)
875                }
876            })
877            .buffered(get_num_compute_intensive_cpus())
878    }
879}
880
impl ExecutionPlan for ANNIvfSubIndexExec {
    fn name(&self) -> &str {
        // NOTE(review): display name omits the "Ivf" in the struct name; the same
        // string is reused in the `with_new_children` error message.
        "ANNSubIndexExec"
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Output schema is always the fixed KNN result schema.
    fn schema(&self) -> arrow_schema::SchemaRef {
        KNN_INDEX_SCHEMA.clone()
    }

    /// One child (the partition-assignment input) plus, when present, the plan
    /// that produces the prefilter.
    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        match &self.prefilter_source {
            PreFilterSource::None => vec![&self.input],
            PreFilterSource::FilteredRowIds(src) => vec![&self.input, &src],
            PreFilterSource::ScalarIndexQuery(src) => vec![&self.input, &src],
        }
    }

    fn required_input_distribution(&self) -> Vec<Distribution> {
        // Prefilter inputs must be a single partition
        self.children()
            .iter()
            .map(|_| Distribution::SinglePartition)
            .collect()
    }

    /// Rebuilds the node with new children, preserving the kind of prefilter
    /// source.  Accepts one child (no prefilter) or two (input + prefilter).
    /// Note the metrics set is reset rather than carried over.
    fn with_new_children(
        self: Arc<Self>,
        mut children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> DataFusionResult<Arc<dyn ExecutionPlan>> {
        let plan = if children.len() == 1 || children.len() == 2 {
            let prefilter_source = if children.len() == 2 {
                // The prefilter is always the second child (see `children`).
                let prefilter = children.pop().expect("length checked");
                match &self.prefilter_source {
                    PreFilterSource::None => PreFilterSource::None,
                    PreFilterSource::FilteredRowIds(_) => {
                        PreFilterSource::FilteredRowIds(prefilter)
                    }
                    PreFilterSource::ScalarIndexQuery(_) => {
                        PreFilterSource::ScalarIndexQuery(prefilter)
                    }
                }
            } else {
                self.prefilter_source.clone()
            };

            Self {
                input: children.pop().expect("length checked"),
                dataset: self.dataset.clone(),
                indices: self.indices.clone(),
                query: self.query.clone(),
                prefilter_source,
                properties: self.properties.clone(),
                metrics: ExecutionPlanMetricsSet::new(),
            }
        } else {
            return Err(DataFusionError::Internal(
                "ANNSubIndexExec node must have exactly one or two (prefilter) child".to_string(),
            ));
        };
        Ok(Arc::new(plan))
    }

    /// Drives the two-phase (early/late) partition search across every delta
    /// index described by the input stream.
    fn execute(
        &self,
        partition: usize,
        context: Arc<datafusion::execution::context::TaskContext>,
    ) -> DataFusionResult<datafusion::physical_plan::SendableRecordBatchStream> {
        let input_stream = self.input.execute(partition, context.clone())?;
        let schema = self.schema();
        let query = self.query.clone();
        let ds = self.dataset.clone();
        let column = self.query.column.clone();
        let indices = self.indices.clone();
        let prefilter_source = self.prefilter_source.clone();
        let metrics = Arc::new(AnnIndexMetrics::new(&self.metrics, partition));
        let metrics_clone = metrics.clone();
        let timer = Instant::now();

        // Per-delta-index stream:
        //   Stream<(partitions, dist_q_c, index uuid)>
        let per_index_stream = input_stream
            .and_then(move |batch| {
                // The input batch is produced by the partition-ranking stage; it
                // must carry the partition-id, query-to-centroid-distance and
                // index-uuid columns (panics otherwise — an internal plan bug).
                let part_id_col = batch.column_by_name(PART_ID_COLUMN).unwrap_or_else(|| {
                    panic!("ANNSubIndexExec: input missing {} column", PART_ID_COLUMN)
                });
                let part_id_arr = part_id_col.as_list::<i32>().clone();
                let dist_q_c_col = batch.column_by_name(DIST_Q_C_COLUMN).unwrap_or_else(|| {
                    panic!("ANNSubIndexExec: input missing {} column", DIST_Q_C_COLUMN)
                });
                let dist_q_c_arr = dist_q_c_col.as_list::<i32>().clone();
                let index_uuid_col = batch.column_by_name(INDEX_UUID_COLUMN).unwrap_or_else(|| {
                    panic!(
                        "ANNSubIndexExec: input missing {} column",
                        INDEX_UUID_COLUMN
                    )
                });
                let index_uuid = index_uuid_col.as_string::<i32>().clone();

                // One (partitions, dists, uuid) tuple per delta index in the batch.
                let plan: Vec<DataFusionResult<(_, _, _)>> = part_id_arr
                    .iter()
                    .zip(dist_q_c_arr.iter())
                    .zip(index_uuid.iter())
                    .map(|((part_id, dist_q_c), uuid)| {
                        let partitions =
                            Arc::new(part_id.unwrap().as_primitive::<UInt32Type>().clone());
                        let dist_q_c =
                            Arc::new(dist_q_c.unwrap().as_primitive::<Float32Type>().clone());
                        let uuid = uuid.unwrap().to_string();
                        Ok((partitions, dist_q_c, uuid))
                    })
                    .collect_vec();
                async move { DataFusionResult::Ok(stream::iter(plan)) }
            })
            .try_flatten();
        // Build the (lazy) prefilter loader from whichever source is configured.
        let prefilter_loader = match &prefilter_source {
            PreFilterSource::FilteredRowIds(src_node) => {
                let stream = src_node.execute(partition, context)?;
                Some(Box::new(FilteredRowIdsToPrefilter(stream)) as Box<dyn FilterLoader>)
            }
            PreFilterSource::ScalarIndexQuery(src_node) => {
                let stream = src_node.execute(partition, context)?;
                Some(Box::new(SelectionVectorToPrefilter(stream)) as Box<dyn FilterLoader>)
            }
            PreFilterSource::None => None,
        };

        let pre_filter = Arc::new(DatasetPreFilter::new(
            ds.clone(),
            &indices,
            prefilter_loader,
        ));

        // Shared early/late coordination state, sized by delta count and `k`.
        let state = Arc::new(ANNIvfEarlySearchResults::new(indices.len(), query.k));

        Ok(Box::pin(RecordBatchStreamAdapter::new(
            schema,
            per_index_stream
                .and_then(move |(part_ids, q_c_dists, index_uuid)| {
                    let ds = ds.clone();
                    let column = column.clone();
                    let metrics = metrics.clone();
                    let pre_filter = pre_filter.clone();
                    let state = state.clone();
                    let query = query.clone();

                    async move {
                        let raw_index = ds
                            .open_vector_index(&column, &index_uuid, &metrics.index_metrics)
                            .await?;

                        // Early phase: the first `minimum_nprobes` partitions.
                        let early_search = Self::initial_search(
                            raw_index.clone(),
                            query.clone(),
                            part_ids.clone(),
                            q_c_dists.clone(),
                            pre_filter.clone(),
                            metrics.clone(),
                            state.clone(),
                        );
                        // Late phase: remaining partitions, gated on `state`.
                        let late_search = Self::late_search(
                            raw_index.clone(),
                            query,
                            part_ids,
                            q_c_dists,
                            pre_filter,
                            metrics,
                            state,
                        );
                        DataFusionResult::Ok(early_search.chain(late_search).boxed())
                    }
                })
                // Must use flatten_unordered to avoid deadlock.
                // Each delta stream is split into an early and late search.  The late search
                // will not start until the early search is complete across all deltas.
                .try_flatten_unordered(None)
                .finally(move || {
                    metrics_clone
                        .baseline_metrics
                        .elapsed_compute()
                        .add_duration(timer.elapsed());
                    metrics_clone.baseline_metrics.done();
                })
                .boxed(),
        )))
    }

    /// Row-count estimate: `k * refine_factor` results per input row.
    ///
    /// NOTE(review): this reports `Precision::Exact`, but it falls back to an
    /// input row count of 1 when the input's statistics are unknown — arguably
    /// `Inexact`; confirm before relying on exactness downstream.
    fn partition_statistics(
        &self,
        partition: Option<usize>,
    ) -> DataFusionResult<datafusion::physical_plan::Statistics> {
        Ok(Statistics {
            num_rows: Precision::Exact(
                self.query.k
                    * self.query.refine_factor.unwrap_or(1) as usize
                    * self
                        .input
                        .partition_statistics(partition)?
                        .num_rows
                        .get_value()
                        .unwrap_or(&1),
            ),
            ..Statistics::new_unknown(self.schema().as_ref())
        })
    }

    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }

    fn properties(&self) -> &PlanProperties {
        &self.properties
    }

    /// Limits cannot be pushed below this node; it produces ranked KNN results.
    fn supports_limit_pushdown(&self) -> bool {
        false
    }
}
1102
/// Merges the per-query-vector ANN results of a multivector search into a single
/// aggregated result set.
///
/// See `execute` for the scoring scheme: each input's distances are converted to
/// similarities (`1 - distance`, i.e. cosine) and summed per row id, with an
/// estimated lower-bound similarity substituted when a row is missing from one
/// input's top-k results.
#[derive(Debug)]
pub struct MultivectorScoringExec {
    // the inputs are sorted ANN search results
    inputs: Vec<Arc<dyn ExecutionPlan>>,
    // The search query; its `k` and `refine_factor` bound the output size.
    query: Query,
    properties: PlanProperties,
}
1110
1111impl MultivectorScoringExec {
1112    pub fn try_new(inputs: Vec<Arc<dyn ExecutionPlan>>, query: Query) -> Result<Self> {
1113        let properties = PlanProperties::new(
1114            EquivalenceProperties::new(KNN_INDEX_SCHEMA.clone()),
1115            Partitioning::RoundRobinBatch(1),
1116            EmissionType::Final,
1117            Boundedness::Bounded,
1118        );
1119
1120        Ok(Self {
1121            inputs,
1122            query,
1123            properties,
1124        })
1125    }
1126}
1127
1128impl DisplayAs for MultivectorScoringExec {
1129    fn fmt_as(&self, t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result {
1130        match t {
1131            DisplayFormatType::Default | DisplayFormatType::Verbose => {
1132                write!(f, "MultivectorScoring: k={}", self.query.k)
1133            }
1134            DisplayFormatType::TreeRender => {
1135                write!(f, "MultivectorScoring\nk={}", self.query.k)
1136            }
1137        }
1138    }
1139}
1140
impl ExecutionPlan for MultivectorScoringExec {
    fn name(&self) -> &str {
        "MultivectorScoringExec"
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Output schema is always the fixed KNN result schema.
    fn schema(&self) -> arrow_schema::SchemaRef {
        KNN_INDEX_SCHEMA.clone()
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        self.inputs.iter().collect()
    }

    fn required_input_distribution(&self) -> Vec<Distribution> {
        // This node fully consumes and re-orders the input rows.  It must be
        // run on a single partition.
        self.children()
            .iter()
            .map(|_| Distribution::SinglePartition)
            .collect()
    }

    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> DataFusionResult<Arc<dyn ExecutionPlan>> {
        let plan = Self::try_new(children, self.query.clone())?;
        Ok(Arc::new(plan))
    }

    /// Consumes all input streams and emits ONE batch of aggregated scores.
    ///
    /// Scoring scheme, per unique row id `r` across all inputs: distances are
    /// converted to similarities (`1 - distance`, cosine) and summed.  When `r`
    /// is absent from one input's results, that input's smallest observed
    /// similarity (`min_sim`) is used as an estimate instead.  The final output
    /// distance is `num_queries - total_similarity`.
    fn execute(
        &self,
        partition: usize,
        context: Arc<datafusion::execution::context::TaskContext>,
    ) -> DataFusionResult<SendableRecordBatchStream> {
        let inputs = self
            .inputs
            .iter()
            .map(|input| input.execute(partition, context.clone()))
            .collect::<DataFusionResult<Vec<_>>>()?;

        // collect the top k results from each stream,
        // and max-reduce for each query,
        // records the minimum distance for each query as estimation.
        let mut reduced_inputs = stream::select_all(inputs.into_iter().map(|stream| {
            stream.map(|batch| {
                let batch = batch?;
                let row_ids = batch[ROW_ID].as_primitive::<UInt64Type>();
                let dists = batch[DIST_COL].as_primitive::<Float32Type>();
                debug_assert_eq!(dists.null_count(), 0);

                // max-reduce for the same row id
                // NOTE(review): assumes the batch is sorted ascending by distance,
                // so the last distance yields the smallest similarity.
                let min_sim = dists
                    .values()
                    .last()
                    .map(|dist| 1.0 - *dist)
                    .unwrap_or_default();
                let mut new_row_ids = Vec::with_capacity(row_ids.len());
                let mut new_sims = Vec::with_capacity(row_ids.len());
                let mut visited_row_ids = HashSet::with_capacity(row_ids.len());

                for (row_id, dist) in row_ids.values().iter().zip(dists.values().iter()) {
                    // the results are sorted by distance, so we can skip if we have seen this row id before
                    if visited_row_ids.contains(row_id) {
                        continue;
                    }
                    visited_row_ids.insert(row_id);
                    new_row_ids.push(*row_id);
                    // it's cosine distance, so we need to convert it to similarity
                    new_sims.push(1.0 - *dist);
                }
                let new_row_ids = UInt64Array::from(new_row_ids);
                let new_dists = Float32Array::from(new_sims);

                // Note: despite reusing the KNN schema (DIST_COL), this
                // intermediate batch holds similarities, not distances.
                let batch = RecordBatch::try_new(
                    KNN_INDEX_SCHEMA.clone(),
                    vec![Arc::new(new_dists), Arc::new(new_row_ids)],
                )?;

                Ok::<_, DataFusionError>((min_sim, batch))
            })
        }));

        let k = self.query.k;
        let refactor = self.query.refine_factor.unwrap_or(1) as usize;
        let num_queries = self.inputs.len() as f32;
        let stream = stream::once(async move {
            // at most, we will have k * refine_factor results for each query
            let mut results = HashMap::with_capacity(k * refactor);
            let mut missed_sim_sum = 0.0;
            while let Some((min_sim, batch)) = reduced_inputs.try_next().await? {
                let row_ids = batch[ROW_ID].as_primitive::<UInt64Type>();
                let sims = batch[DIST_COL].as_primitive::<Float32Type>();

                let query_results = row_ids
                    .values()
                    .iter()
                    .copied()
                    .zip(sims.values().iter().copied())
                    .collect::<HashMap<_, _>>();

                // for a row `r`:
                // if `r` is in only `results``, then `results[r] += min_sim`
                // if `r` is in only `query_results`, then `results[r] = query_results[r] + missed_similarities`,
                // here `missed_similarities` is the sum of `min_sim` from previous iterations
                // if `r` is in both, then `results[r] += query_results[r]`
                results.iter_mut().for_each(|(row_id, sim)| {
                    if let Some(new_dist) = query_results.get(row_id) {
                        *sim += new_dist;
                    } else {
                        *sim += min_sim;
                    }
                });
                query_results.into_iter().for_each(|(row_id, sim)| {
                    results.entry(row_id).or_insert(sim + missed_sim_sum);
                });
                // Accumulate so rows first seen later are credited for the
                // inputs they missed.
                missed_sim_sum += min_sim;
            }

            let (row_ids, sims): (Vec<_>, Vec<_>) = results.into_iter().unzip();
            let dists = sims
                .into_iter()
                // it's similarity, so we need to convert it back to distance
                .map(|sim| num_queries - sim)
                .collect::<Vec<_>>();
            let row_ids = UInt64Array::from(row_ids);
            let dists = Float32Array::from(dists);
            // NOTE(review): rows come out in HashMap iteration order (unsorted);
            // downstream is expected to sort/limit.
            let batch = RecordBatch::try_new(
                KNN_INDEX_SCHEMA.clone(),
                vec![Arc::new(dists), Arc::new(row_ids)],
            )?;
            Ok::<_, DataFusionError>(batch)
        });
        Ok(Box::pin(RecordBatchStreamAdapter::new(
            self.schema(),
            stream.boxed(),
        )))
    }

    /// Row-count estimate only: at most `k * refine_factor` rows overall.
    fn statistics(&self) -> DataFusionResult<Statistics> {
        Ok(Statistics {
            num_rows: Precision::Inexact(
                self.query.k * self.query.refine_factor.unwrap_or(1) as usize,
            ),
            ..Statistics::new_unknown(self.schema().as_ref())
        })
    }

    fn properties(&self) -> &PlanProperties {
        &self.properties
    }

    /// Limits cannot be pushed below this node: it re-scores and re-orders rows.
    fn supports_limit_pushdown(&self) -> bool {
        false
    }
}
1301
1302#[cfg(test)]
1303mod tests {
1304    use super::*;
1305
1306    use arrow::compute::{concat_batches, sort_to_indices, take_record_batch};
1307    use arrow::datatypes::Float32Type;
1308    use arrow_array::{FixedSizeListArray, Int32Array, RecordBatchIterator, StringArray};
1309    use arrow_schema::{Field as ArrowField, Schema as ArrowSchema};
1310    use lance_core::utils::tempfile::TempStrDir;
1311    use lance_datafusion::exec::{ExecutionStatsCallback, ExecutionSummaryCounts};
1312    use lance_datagen::{array, BatchCount, RowCount};
1313    use lance_index::optimize::OptimizeOptions;
1314    use lance_index::vector::ivf::IvfBuildParams;
1315    use lance_index::vector::pq::PQBuildParams;
1316    use lance_index::{DatasetIndexExt, IndexType};
1317    use lance_linalg::distance::MetricType;
1318    use lance_testing::datagen::generate_random_array;
1319    use rstest::rstest;
1320
1321    use crate::dataset::{WriteMode, WriteParams};
1322    use crate::index::vector::VectorIndexParams;
1323    use crate::io::exec::testing::TestingExec;
1324
    /// End-to-end flat (non-indexed) KNN search: the scanner's `nearest` result
    /// must exactly match a brute-force distance computation over all rows.
    #[tokio::test]
    async fn knn_flat_search() {
        // Schema with a 128-dim float vector column plus scalar columns.
        let schema = Arc::new(ArrowSchema::new(vec![
            ArrowField::new("key", DataType::Int32, false),
            ArrowField::new(
                "vector",
                DataType::FixedSizeList(
                    Arc::new(ArrowField::new("item", DataType::Float32, true)),
                    128,
                ),
                true,
            ),
            ArrowField::new("uri", DataType::Utf8, true),
        ]));

        // 20 batches x 20 rows = 400 random vectors.
        let batches: Vec<RecordBatch> = (0..20)
            .map(|i| {
                RecordBatch::try_new(
                    schema.clone(),
                    vec![
                        Arc::new(Int32Array::from_iter_values(i * 20..(i + 1) * 20)),
                        Arc::new(
                            FixedSizeListArray::try_new_from_values(
                                generate_random_array(128 * 20),
                                128,
                            )
                            .unwrap(),
                        ),
                        Arc::new(StringArray::from_iter_values(
                            (i * 20..(i + 1) * 20).map(|i| format!("s3://bucket/file-{}", i)),
                        )),
                    ],
                )
                .unwrap()
            })
            .collect();

        let test_dir = TempStrDir::default();
        let test_uri = test_dir.as_str();

        // Small files/groups so the scan spans multiple fragments.
        let write_params = WriteParams {
            max_rows_per_file: 40,
            max_rows_per_group: 10,
            ..Default::default()
        };
        // Use an existing row's vector as the query so an exact match exists.
        let vector_arr = batches[0].column_by_name("vector").unwrap();
        let q = as_fixed_size_list_array(&vector_arr).value(5);

        let reader = RecordBatchIterator::new(batches.into_iter().map(Ok), schema.clone());
        Dataset::write(reader, test_uri, Some(write_params))
            .await
            .unwrap();

        // Flat KNN scan: top 10 nearest to `q`.
        let dataset = Dataset::open(test_uri).await.unwrap();
        let stream = dataset
            .scan()
            .nearest("vector", q.as_primitive::<Float32Type>(), 10)
            .unwrap()
            .try_into_stream()
            .await
            .unwrap();
        let results = stream.try_collect::<Vec<_>>().await.unwrap();

        // Result must carry the distance column and arrive as a single batch.
        assert!(results[0].schema().column_with_name(DIST_COL).is_some());

        assert_eq!(results.len(), 1);

        // Independently compute L2 distances for ALL rows and take the 10
        // smallest; this must match the KNN result exactly (same order).
        let stream = dataset.scan().try_into_stream().await.unwrap();
        let all_with_distances = stream
            .and_then(|batch| compute_distance(q.clone(), DistanceType::L2, "vector", batch))
            .try_collect::<Vec<_>>()
            .await
            .unwrap();
        let all_with_distances =
            concat_batches(&results[0].schema(), all_with_distances.iter()).unwrap();
        let dist_arr = all_with_distances.column_by_name(DIST_COL).unwrap();
        let distances = dist_arr.as_primitive::<Float32Type>();
        let indices = sort_to_indices(distances, None, Some(10)).unwrap();
        let expected = take_record_batch(&all_with_distances, &indices).unwrap();
        assert_eq!(expected, results[0]);
    }
1406
1407    #[test]
1408    fn test_create_knn_flat() {
1409        let dim: usize = 128;
1410        let schema = Arc::new(ArrowSchema::new(vec![
1411            ArrowField::new("key", DataType::Int32, false),
1412            ArrowField::new(
1413                "vector",
1414                DataType::FixedSizeList(
1415                    Arc::new(ArrowField::new("item", DataType::Float32, true)),
1416                    dim as i32,
1417                ),
1418                true,
1419            ),
1420            ArrowField::new("uri", DataType::Utf8, true),
1421        ]));
1422        let batch = RecordBatch::new_empty(schema);
1423
1424        let input: Arc<dyn ExecutionPlan> = Arc::new(TestingExec::new(vec![batch]));
1425
1426        let idx = KNNVectorDistanceExec::try_new(
1427            input,
1428            "vector",
1429            Arc::new(generate_random_array(dim)),
1430            DistanceType::L2,
1431        )
1432        .unwrap();
1433        assert_eq!(
1434            idx.schema().as_ref(),
1435            &ArrowSchema::new(vec![
1436                ArrowField::new("key", DataType::Int32, false),
1437                ArrowField::new(
1438                    "vector",
1439                    DataType::FixedSizeList(
1440                        Arc::new(ArrowField::new("item", DataType::Float32, true)),
1441                        dim as i32,
1442                    ),
1443                    true,
1444                ),
1445                ArrowField::new("uri", DataType::Utf8, true),
1446                ArrowField::new(DIST_COL, DataType::Float32, true),
1447            ])
1448        );
1449    }
1450
    /// Multivector scoring must be independent of the order in which the
    /// per-query input streams arrive.
    #[tokio::test]
    async fn test_multivector_score() {
        // Only `k`, `refine_factor`, and `metric_type` matter to the scoring
        // node; the key vector itself is never inspected by it.
        let query = Query {
            column: "vector".to_string(),
            key: Arc::new(generate_random_array(1)),
            k: 10,
            lower_bound: None,
            upper_bound: None,
            minimum_nprobes: 1,
            maximum_nprobes: None,
            ef: None,
            refine_factor: None,
            metric_type: DistanceType::Cosine,
            use_index: true,
            dist_q_c: 0.0,
        };

        // Runs MultivectorScoringExec over `inputs` and collects row id -> distance.
        async fn multivector_scoring(
            inputs: Vec<Arc<dyn ExecutionPlan>>,
            query: Query,
        ) -> Result<HashMap<u64, f32>> {
            let ctx = Arc::new(datafusion::execution::context::TaskContext::default());
            let plan = MultivectorScoringExec::try_new(inputs, query.clone())?;
            let batches = plan
                .execute(0, ctx.clone())
                .unwrap()
                .try_collect::<Vec<_>>()
                .await?;
            let mut results = HashMap::new();
            for batch in batches {
                let row_ids = batch[ROW_ID].as_primitive::<UInt64Type>();
                let dists = batch[DIST_COL].as_primitive::<Float32Type>();
                for (row_id, dist) in row_ids.values().iter().zip(dists.values().iter()) {
                    results.insert(*row_id, *dist);
                }
            }
            Ok(results)
        }

        // Three overlapping 2-row inputs: row ids {1,2}, {2,3}, {3,4}.
        let batches = (0..3)
            .map(|i| {
                RecordBatch::try_new(
                    KNN_INDEX_SCHEMA.clone(),
                    vec![
                        Arc::new(Float32Array::from(vec![i as f32 + 1.0, i as f32 + 2.0])),
                        Arc::new(UInt64Array::from(vec![i + 1, i + 2])),
                    ],
                )
                .unwrap()
            })
            .collect::<Vec<_>>();

        // Run every permutation of the inputs and check all produce the same
        // per-row scores as the first run.
        let mut res: Option<HashMap<_, _>> = None;
        for perm in batches.into_iter().permutations(3) {
            let inputs = perm
                .into_iter()
                .map(|batch| {
                    let input: Arc<dyn ExecutionPlan> = Arc::new(TestingExec::new(vec![batch]));
                    input
                })
                .collect::<Vec<_>>();
            let new_res = multivector_scoring(inputs, query.clone()).await.unwrap();
            // 4 distinct row ids total: 1, 2, 3, 4.
            assert_eq!(new_res.len(), 4);
            if let Some(res) = &res {
                for (row_id, dist) in new_res.iter() {
                    assert_eq!(res.get(row_id).unwrap(), dist)
                }
            } else {
                res = Some(new_res);
            }
        }
    }
1523
    /// A test dataset for testing the nprobes parameter.
    ///
    /// The dataset has `num_centroids` IVF partitions (tests use 100) and
    /// filterable columns set up to easily create filters whose results are
    /// spread across the partitions evenly.
    struct NprobesTestFixture {
        dataset: Dataset,
        // The IVF centroids used to build the index; also usable as query vectors.
        centroids: Arc<dyn Array>,
        // Keeps the temporary directory alive for the fixture's lifetime.
        _tmp_dir: TempStrDir,
    }
1533
    impl NprobesTestFixture {
        /// Builds the fixture dataset.
        ///
        /// Writes 100 fragments of 100 rows each (10,000 rows total), spread
        /// across `num_deltas` separate appends.  The first append creates an
        /// IVF_PQ index from pre-computed centroids; each later append extends
        /// the index with a delta via `optimize_indices`.
        pub async fn new(num_centroids: usize, num_deltas: usize) -> Self {
            let tempdir = TempStrDir::default();
            let tmppath = tempdir.as_str();

            // We create `num_centroids` centroids on the unit circle.
            // We generate 10,000 vectors jittered tightly around those centroids.
            // We assign labels 0..=60 (61 labels) cyclically, so each label has
            //   ~164 vectors spread out through all of the centroids.
            let centroids = array::cycle_unit_circle(num_centroids as u32)
                .generate_default(RowCount::from(num_centroids as u64))
                .unwrap();

            // Let's not deal with fractions
            assert!(100 % num_deltas == 0, "num_deltas must divide 100");
            let rows_per_frag = 100;
            let num_frags = 100;
            let frags_per_delta = num_frags / num_deltas;

            let batches = lance_datagen::gen_batch()
                .col("vector", array::jitter_centroids(centroids.clone(), 0.0001))
                .col("label", array::cycle::<UInt32Type>(Vec::from_iter(0..61)))
                .col("userid", array::step::<UInt64Type>())
                .into_reader_rows(
                    RowCount::from(rows_per_frag),
                    BatchCount::from(num_frags as u32),
                )
                .collect::<Vec<_>>();
            let schema = batches[0].as_ref().unwrap().schema();

            let mut first = true;
            for batches in batches.chunks(frags_per_delta) {
                let delta_batches = batches
                    .iter()
                    .map(|maybe_batch| Ok(maybe_batch.as_ref().unwrap().clone()))
                    .collect::<Vec<_>>();
                let reader = RecordBatchIterator::new(delta_batches, schema.clone());
                let mut dataset = Dataset::write(
                    reader,
                    tmppath,
                    Some(WriteParams {
                        mode: WriteMode::Append,
                        ..Default::default()
                    }),
                )
                .await
                .unwrap();

                // Reuse the known centroids so partition assignment is predictable.
                let ivf_params = IvfBuildParams::try_with_centroids(
                    num_centroids,
                    Arc::new(centroids.as_fixed_size_list().clone()),
                )
                .unwrap();

                let codebook = array::rand::<Float32Type>()
                    .generate_default(RowCount::from(256 * 2))
                    .unwrap();
                let pq_params = PQBuildParams::with_codebook(2, 8, codebook);
                let index_params =
                    VectorIndexParams::with_ivf_pq_params(MetricType::L2, ivf_params, pq_params);

                if first {
                    // First append: build the initial index.
                    first = false;
                    dataset
                        .create_index(&["vector"], IndexType::Vector, None, &index_params, false)
                        .await
                        .unwrap();
                } else {
                    // Subsequent appends: add a delta index covering the new rows.
                    dataset
                        .optimize_indices(&OptimizeOptions::append())
                        .await
                        .unwrap();
                }
            }

            // Re-open so the returned dataset reflects all commits.
            let dataset = Dataset::open(tmppath).await.unwrap();
            Self {
                dataset,
                centroids,
                _tmp_dir: tempdir,
            }
        }

        /// Returns the `idx`-th centroid vector — handy as a query that lands
        /// squarely in partition `idx` (vectors are only lightly jittered).
        pub fn get_centroid(&self, idx: usize) -> Arc<dyn Array> {
            let centroids = self.centroids.as_fixed_size_list();
            centroids.value(idx).clone()
        }
    }
1622
    /// Test helper that captures the `ExecutionSummaryCounts` delivered to a
    /// scan's stats callback so the test can assert on them after the scan.
    #[derive(Default)]
    struct StatsHolder {
        // Shared slot written by the callback (which may run on another
        // thread/task) and drained by the test via `consume`.
        pub collected_stats: Arc<Mutex<Option<ExecutionSummaryCounts>>>,
    }
1627
1628    impl StatsHolder {
1629        fn get_setter(&self) -> ExecutionStatsCallback {
1630            let collected_stats = self.collected_stats.clone();
1631            Arc::new(move |stats| {
1632                *collected_stats.lock().unwrap() = Some(stats.clone());
1633            })
1634        }
1635
1636        fn consume(self) -> ExecutionSummaryCounts {
1637            self.collected_stats.lock().unwrap().take().unwrap()
1638        }
1639    }
1640
1641    #[rstest]
1642    #[tokio::test]
1643    async fn test_no_max_nprobes(#[values(1, 20)] num_deltas: usize) {
1644        let fixture = NprobesTestFixture::new(100, num_deltas).await;
1645
1646        let q = fixture.get_centroid(0);
1647        let stats_holder = StatsHolder::default();
1648
1649        let results = fixture
1650            .dataset
1651            .scan()
1652            .nearest("vector", q.as_ref(), 50)
1653            .unwrap()
1654            .minimum_nprobes(10)
1655            .prefilter(true)
1656            .scan_stats_callback(stats_holder.get_setter())
1657            .filter("label = 17")
1658            .unwrap()
1659            .project(&Vec::<String>::new())
1660            .unwrap()
1661            .with_row_id()
1662            .try_into_batch()
1663            .await
1664            .unwrap();
1665
1666        assert_eq!(results.num_rows(), 50);
1667
1668        let stats = stats_holder.consume();
1669
1670        // We should not search _all_ partitions because we should hit 50 results partway
1671        // through the late search.
1672        // The exact number here is deterministic but depends on the number of CPUs
1673        if get_num_compute_intensive_cpus() <= 32 {
1674            assert!(*stats.all_counts.get(PARTITIONS_SEARCHED_METRIC).unwrap() < 100 * num_deltas);
1675        }
1676    }
1677
1678    #[rstest]
1679    #[tokio::test]
1680    async fn test_no_prefilter_results(#[values(1, 20)] num_deltas: usize) {
1681        let fixture = NprobesTestFixture::new(100, num_deltas).await;
1682
1683        let q = fixture.get_centroid(0);
1684        let stats_holder = StatsHolder::default();
1685
1686        let results = fixture
1687            .dataset
1688            .scan()
1689            .nearest("vector", q.as_ref(), 50)
1690            .unwrap()
1691            .minimum_nprobes(10)
1692            .prefilter(true)
1693            .scan_stats_callback(stats_holder.get_setter())
1694            .filter("label = 17 AND label = 18")
1695            .unwrap()
1696            .project(&Vec::<String>::new())
1697            .unwrap()
1698            .with_row_id()
1699            .try_into_batch()
1700            .await
1701            .unwrap();
1702
1703        assert_eq!(results.num_rows(), 0);
1704
1705        let stats = stats_holder.consume();
1706        // We still do the early search because we don't wait for the prefilter to execute
1707        // We skip the late search because by then we know there are no results
1708        assert_eq!(
1709            stats.all_counts.get(PARTITIONS_SEARCHED_METRIC).unwrap(),
1710            &(10 * num_deltas)
1711        );
1712    }
1713
1714    #[rstest]
1715    #[tokio::test]
1716    async fn test_some_max_nprobes(#[values(1, 20)] num_deltas: usize) {
1717        let fixture = NprobesTestFixture::new(100, num_deltas).await;
1718
1719        for (max_nprobes, expected_results) in [(10, 16), (20, 33), (30, 48)] {
1720            let q = fixture.get_centroid(0);
1721            let stats_holder = StatsHolder::default();
1722            let results = fixture
1723                .dataset
1724                .scan()
1725                .nearest("vector", q.as_ref(), 50)
1726                .unwrap()
1727                .minimum_nprobes(10)
1728                .maximum_nprobes(max_nprobes)
1729                .prefilter(true)
1730                .filter("label = 17")
1731                .unwrap()
1732                .scan_stats_callback(stats_holder.get_setter())
1733                .project(&Vec::<String>::new())
1734                .unwrap()
1735                .with_row_id()
1736                .try_into_batch()
1737                .await
1738                .unwrap();
1739
1740            let stats = stats_holder.consume();
1741
1742            assert_eq!(results.num_rows(), expected_results);
1743            assert_eq!(
1744                stats.all_counts.get(PARTITIONS_SEARCHED_METRIC).unwrap(),
1745                &(max_nprobes * num_deltas)
1746            );
1747            assert_eq!(
1748                stats.all_counts.get(PARTITIONS_RANKED_METRIC).unwrap(),
1749                &(100 * num_deltas)
1750            );
1751        }
1752    }
1753
1754    #[rstest]
1755    #[tokio::test]
1756    async fn test_fewer_than_k_results(#[values(1, 20)] num_deltas: usize) {
1757        let fixture = NprobesTestFixture::new(100, num_deltas).await;
1758
1759        let q = fixture.get_centroid(0);
1760        let stats_holder = StatsHolder::default();
1761        let results = fixture
1762            .dataset
1763            .scan()
1764            .nearest("vector", q.as_ref(), 50)
1765            .unwrap()
1766            .minimum_nprobes(10)
1767            .prefilter(true)
1768            .filter("userid < 20")
1769            .unwrap()
1770            .scan_stats_callback(stats_holder.get_setter())
1771            .project(&Vec::<String>::new())
1772            .unwrap()
1773            .with_row_id()
1774            .try_into_batch()
1775            .await
1776            .unwrap();
1777
1778        let stats = stats_holder.consume();
1779
1780        // We should only search minimum_nprobes before we look at the prefilter and realize
1781        // we can cheaply stop early.
1782        assert_eq!(
1783            stats.all_counts.get(PARTITIONS_SEARCHED_METRIC).unwrap(),
1784            &(10 * num_deltas)
1785        );
1786        assert_eq!(results.num_rows(), 20);
1787
1788        // 15 of the results come from beyond the closest 10 partitions and these will have infinite
1789        // distance.
1790        let num_infinite_results = results
1791            .column(0)
1792            .as_primitive::<Float32Type>()
1793            .values()
1794            .iter()
1795            .filter(|val| val.is_infinite())
1796            .count();
1797        assert_eq!(num_infinite_results, 15);
1798
1799        // If we set a refine factor then the distance should not be infinite.
1800        let results = fixture
1801            .dataset
1802            .scan()
1803            .nearest("vector", q.as_ref(), 50)
1804            .unwrap()
1805            .minimum_nprobes(10)
1806            .prefilter(true)
1807            .refine(1)
1808            .filter("userid < 20")
1809            .unwrap()
1810            .project(&Vec::<String>::new())
1811            .unwrap()
1812            .with_row_id()
1813            .try_into_batch()
1814            .await
1815            .unwrap();
1816
1817        assert_eq!(results.num_rows(), 20);
1818        let num_infinite_results = results
1819            .column(0)
1820            .as_primitive::<Float32Type>()
1821            .values()
1822            .iter()
1823            .filter(|val| val.is_infinite())
1824            .count();
1825        assert_eq!(num_infinite_results, 0);
1826    }
1827
1828    #[rstest]
1829    #[tokio::test]
1830    async fn test_dataset_too_small(#[values(1, 20)] num_deltas: usize) {
1831        let fixture = NprobesTestFixture::new(100, num_deltas).await;
1832
1833        let q = fixture.get_centroid(0);
1834        let stats_holder = StatsHolder::default();
1835        // There is no filter but we only have 10K rows.  Since maximum_nprobes is not set
1836        // we will search all partitions.
1837        let results = fixture
1838            .dataset
1839            .scan()
1840            .nearest("vector", q.as_ref(), 40000)
1841            .unwrap()
1842            .minimum_nprobes(10)
1843            .scan_stats_callback(stats_holder.get_setter())
1844            .project(&Vec::<String>::new())
1845            .unwrap()
1846            .with_row_id()
1847            .try_into_batch()
1848            .await
1849            .unwrap();
1850
1851        let stats = stats_holder.consume();
1852
1853        assert_eq!(
1854            stats.all_counts.get(PARTITIONS_SEARCHED_METRIC).unwrap(),
1855            &(100 * num_deltas)
1856        );
1857        assert_eq!(results.num_rows(), 10000);
1858    }
1859}