1use std::any::Any;
5use std::collections::{HashMap, HashSet};
6use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
7use std::sync::{Arc, LazyLock, Mutex};
8use std::time::Instant;
9
10use arrow::array::Float32Builder;
11use arrow::datatypes::{Float32Type, UInt32Type, UInt64Type};
12use arrow_array::{
13 builder::{ListBuilder, UInt32Builder},
14 cast::AsArray,
15 ArrayRef, RecordBatch, StringArray,
16};
17use arrow_array::{Array, Float32Array, UInt32Array, UInt64Array};
18use arrow_schema::{DataType, Field, Schema, SchemaRef};
19use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
20use datafusion::physical_plan::PlanProperties;
21use datafusion::physical_plan::{
22 DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, SendableRecordBatchStream,
23 Statistics,
24};
25use datafusion::{
26 common::stats::Precision,
27 physical_plan::execution_plan::{Boundedness, EmissionType},
28};
29use datafusion::{common::ColumnStatistics, physical_plan::metrics::ExecutionPlanMetricsSet};
30use datafusion::{
31 error::{DataFusionError, Result as DataFusionResult},
32 physical_plan::metrics::MetricsSet,
33};
34use datafusion_physical_expr::{Distribution, EquivalenceProperties};
35use datafusion_physical_plan::metrics::{BaselineMetrics, Count};
36use futures::{future, stream, Stream, StreamExt, TryFutureExt, TryStreamExt};
37use itertools::Itertools;
38use lance_core::utils::futures::FinallyStreamExt;
39use lance_core::ROW_ID;
40use lance_core::{utils::tokio::get_num_compute_intensive_cpus, ROW_ID_FIELD};
41use lance_datafusion::utils::{
42 ExecutionPlanMetricsSetExt, DELTAS_SEARCHED_METRIC, PARTITIONS_RANKED_METRIC,
43 PARTITIONS_SEARCHED_METRIC,
44};
45use lance_index::prefilter::PreFilter;
46use lance_index::vector::{
47 flat::compute_distance, Query, DIST_COL, INDEX_UUID_COLUMN, PART_ID_COLUMN,
48};
49use lance_index::vector::{VectorIndex, DIST_Q_C_COLUMN};
50use lance_linalg::distance::DistanceType;
51use lance_linalg::kernels::normalize_arrow;
52use lance_table::format::IndexMetadata;
53use snafu::location;
54use tokio::sync::Notify;
55
56use crate::dataset::Dataset;
57use crate::index::prefilter::{DatasetPreFilter, FilterLoader};
58use crate::index::vector::utils::get_vector_type;
59use crate::index::DatasetIndexInternalExt;
60use crate::{Error, Result};
61use lance_arrow::*;
62
63use super::utils::{
64 FilteredRowIdsToPrefilter, IndexMetrics, InstrumentedRecordBatchStreamAdapter, PreFilterSource,
65 SelectionVectorToPrefilter,
66};
67
/// Metrics for the partition-ranking phase of an ANN search
/// (see [`ANNIvfPartitionExec`]).
pub struct AnnPartitionMetrics {
    // Metrics recorded while opening / reading the vector index.
    index_metrics: IndexMetrics,
    // Total number of IVF partitions ranked, summed over all index deltas.
    partitions_ranked: Count,
    // Number of index deltas that were searched.
    deltas_searched: Count,
    // Standard DataFusion baseline metrics (output rows, elapsed compute).
    baseline_metrics: BaselineMetrics,
}
74
75impl AnnPartitionMetrics {
76 pub fn new(metrics: &ExecutionPlanMetricsSet, partition: usize) -> Self {
77 Self {
78 index_metrics: IndexMetrics::new(metrics, partition),
79 partitions_ranked: metrics.new_count(PARTITIONS_RANKED_METRIC, partition),
80 deltas_searched: metrics.new_count(DELTAS_SEARCHED_METRIC, partition),
81 baseline_metrics: BaselineMetrics::new(metrics, partition),
82 }
83 }
84}
85
/// Metrics for the sub-index search phase of an ANN search
/// (see [`ANNIvfSubIndexExec`]).
pub struct AnnIndexMetrics {
    // Metrics recorded while opening / reading the vector index.
    index_metrics: IndexMetrics,
    // Total number of IVF partitions actually probed.
    partitions_searched: Count,
    // Standard DataFusion baseline metrics (output rows, elapsed compute).
    baseline_metrics: BaselineMetrics,
}
91
92impl AnnIndexMetrics {
93 pub fn new(metrics: &ExecutionPlanMetricsSet, partition: usize) -> Self {
94 Self {
95 index_metrics: IndexMetrics::new(metrics, partition),
96 partitions_searched: metrics.new_count(PARTITIONS_SEARCHED_METRIC, partition),
97 baseline_metrics: BaselineMetrics::new(metrics, partition),
98 }
99 }
100}
101
/// Exec node that computes the distance between each input row's vector in
/// `column` and the `query` vector, appending the result as a nullable
/// Float32 `DIST_COL` column.
#[derive(Debug)]
pub struct KNNVectorDistanceExec {
    /// Child plan producing the batches to score.
    pub input: Arc<dyn ExecutionPlan>,

    /// The query vector to compare against.
    pub query: ArrayRef,
    /// Name of the vector column in the input.
    pub column: String,
    /// Distance metric used for scoring.
    pub distance_type: DistanceType,

    // Input schema with any pre-existing DIST_COL removed and a nullable
    // Float32 DIST_COL appended at the end (built in `try_new`).
    output_schema: SchemaRef,
    properties: PlanProperties,

    metrics: ExecutionPlanMetricsSet,
}
124
125impl DisplayAs for KNNVectorDistanceExec {
126 fn fmt_as(&self, t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result {
127 match t {
128 DisplayFormatType::Default | DisplayFormatType::Verbose => {
129 write!(f, "KNNVectorDistance: metric={}", self.distance_type,)
130 }
131 DisplayFormatType::TreeRender => {
132 write!(f, "KNNVectorDistance\nmetric={}", self.distance_type,)
133 }
134 }
135 }
136}
137
138impl KNNVectorDistanceExec {
139 pub fn try_new(
143 input: Arc<dyn ExecutionPlan>,
144 column: &str,
145 query: ArrayRef,
146 distance_type: DistanceType,
147 ) -> Result<Self> {
148 let mut output_schema = input.schema().as_ref().clone();
149 get_vector_type(&(&output_schema).try_into()?, column)?;
150
151 if output_schema.column_with_name(DIST_COL).is_some() {
155 output_schema = output_schema.without_column(DIST_COL);
156 }
157 let output_schema = Arc::new(output_schema.try_with_column(Field::new(
158 DIST_COL,
159 DataType::Float32,
160 true,
161 ))?);
162
163 let properties = input
166 .properties()
167 .clone()
168 .with_eq_properties(EquivalenceProperties::new(output_schema.clone()));
169
170 Ok(Self {
171 input,
172 query,
173 column: column.to_string(),
174 distance_type,
175 output_schema,
176 properties,
177 metrics: ExecutionPlanMetricsSet::new(),
178 })
179 }
180}
181
impl ExecutionPlan for KNNVectorDistanceExec {
    fn name(&self) -> &str {
        "KNNVectorDistanceExec"
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    fn schema(&self) -> arrow_schema::SchemaRef {
        self.output_schema.clone()
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![&self.input]
    }

    fn with_new_children(
        self: Arc<Self>,
        mut children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> DataFusionResult<Arc<dyn ExecutionPlan>> {
        if children.len() != 1 {
            return Err(DataFusionError::Internal(
                "KNNVectorDistanceExec node must have exactly one child".to_string(),
            ));
        }

        // Rebuild via try_new so the output schema and plan properties are
        // re-derived from the new child.
        Ok(Arc::new(Self::try_new(
            children.pop().expect("length checked"),
            &self.column,
            self.query.clone(),
            self.distance_type,
        )?))
    }

    /// Streams input batches, appending the distance column to each.
    fn execute(
        &self,
        partition: usize,
        context: Arc<datafusion::execution::context::TaskContext>,
    ) -> DataFusionResult<SendableRecordBatchStream> {
        let input_stream = self.input.execute(partition, context)?;
        let key = self.query.clone();
        let column = self.column.clone();
        let dt = self.distance_type;
        let stream = input_stream
            // Skip empty batches entirely; compute_distance never sees them.
            .try_filter(|batch| future::ready(batch.num_rows() > 0))
            .map(move |batch| {
                let key = key.clone();
                let column = column.clone();
                async move {
                    compute_distance(key, dt, &column, batch?)
                        .await
                        .map_err(|e| DataFusionError::Execution(e.to_string()))
                }
            })
            // Score multiple batches concurrently. `buffer_unordered` means
            // output batch order is NOT the input order.
            .buffer_unordered(get_num_compute_intensive_cpus());
        let schema = self.schema();
        Ok(Box::pin(InstrumentedRecordBatchStreamAdapter::new(
            schema,
            stream.boxed(),
            partition,
            &self.metrics,
        )) as SendableRecordBatchStream)
    }

    fn partition_statistics(&self, partition: Option<usize>) -> DataFusionResult<Statistics> {
        let inner_stats = self.input.partition_statistics(partition)?;
        let schema = self.input.schema();
        // Carry over the vector column's null count as the distance column's
        // null count (presumably distance is null where the vector is null —
        // everything else about the new column is unknown).
        let dist_stats = inner_stats
            .column_statistics
            .iter()
            .zip(schema.fields())
            .find(|(_, field)| field.name() == &self.column)
            .map(|(stats, _)| ColumnStatistics {
                null_count: stats.null_count,
                ..Default::default()
            })
            .unwrap_or_default();
        // Keep the input columns' statistics (dropping any pre-existing
        // DIST_COL) and append the derived distance stats last, matching the
        // output schema's column layout.
        let column_statistics = inner_stats
            .column_statistics
            .into_iter()
            .zip(schema.fields())
            .filter(|(_, field)| field.name() != DIST_COL)
            .map(|(stats, _)| stats)
            .chain(std::iter::once(dist_stats))
            .collect::<Vec<_>>();
        Ok(Statistics {
            // Row count passes through unchanged; only a column is added.
            num_rows: inner_stats.num_rows,
            column_statistics,
            ..Statistics::new_unknown(self.schema().as_ref())
        })
    }

    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }

    fn properties(&self) -> &PlanProperties {
        &self.properties
    }

    fn supports_limit_pushdown(&self) -> bool {
        false
    }
}
288
/// Schema of ANN search results: a nullable Float32 `DIST_COL` followed by
/// the row-id column.
pub static KNN_INDEX_SCHEMA: LazyLock<SchemaRef> = LazyLock::new(|| {
    Arc::new(Schema::new(vec![
        Field::new(DIST_COL, DataType::Float32, true),
        ROW_ID_FIELD.clone(),
    ]))
});
295
/// Schema emitted by the partition-ranking stage: one row per index delta
/// carrying the ranked partition ids and the query-to-centroid distances
/// (two parallel lists) plus the delta's index UUID.
pub static KNN_PARTITION_SCHEMA: LazyLock<SchemaRef> = LazyLock::new(|| {
    Arc::new(Schema::new(vec![
        Field::new(
            PART_ID_COLUMN,
            DataType::List(Field::new("item", DataType::UInt32, false).into()),
            false,
        ),
        Field::new(
            DIST_Q_C_COLUMN,
            DataType::List(Field::new("item", DataType::Float32, false).into()),
            false,
        ),
        Field::new(INDEX_UUID_COLUMN, DataType::Utf8, false),
    ]))
});
311
312pub fn new_knn_exec(
313 dataset: Arc<Dataset>,
314 indices: &[IndexMetadata],
315 query: &Query,
316 prefilter_source: PreFilterSource,
317) -> Result<Arc<dyn ExecutionPlan>> {
318 let ivf_node = ANNIvfPartitionExec::try_new(
319 dataset.clone(),
320 indices.iter().map(|idx| idx.uuid.to_string()).collect_vec(),
321 query.clone(),
322 )?;
323
324 let sub_index = ANNIvfSubIndexExec::try_new(
325 Arc::new(ivf_node),
326 dataset,
327 indices.to_vec(),
328 query.clone(),
329 prefilter_source,
330 )?;
331
332 Ok(Arc::new(sub_index))
333}
334
/// Leaf exec node that ranks IVF partitions for a vector query.
///
/// For each index delta in `index_uuids` it emits one row conforming to
/// [`KNN_PARTITION_SCHEMA`]: the ranked partition ids, the query-to-centroid
/// distances, and the delta's UUID.
#[derive(Debug)]
pub struct ANNIvfPartitionExec {
    /// Dataset being searched.
    pub dataset: Arc<Dataset>,

    /// The vector query.
    pub query: Query,

    /// UUIDs of the index deltas whose partitions should be ranked.
    pub index_uuids: Vec<String>,

    /// Cached plan properties (single partition, bounded, incremental).
    pub properties: PlanProperties,

    /// Runtime metrics registry.
    pub metrics: ExecutionPlanMetricsSet,
}
372
373impl ANNIvfPartitionExec {
374 pub fn try_new(dataset: Arc<Dataset>, index_uuids: Vec<String>, query: Query) -> Result<Self> {
375 let dataset_schema = dataset.schema();
376 get_vector_type(dataset_schema, &query.column)?;
377 if index_uuids.is_empty() {
378 return Err(Error::Execution {
379 message: "ANNIVFPartitionExec node: no index found for query".to_string(),
380 location: location!(),
381 });
382 }
383
384 let schema = KNN_PARTITION_SCHEMA.clone();
385 let properties = PlanProperties::new(
386 EquivalenceProperties::new(schema),
387 Partitioning::RoundRobinBatch(1),
388 EmissionType::Incremental,
389 Boundedness::Bounded,
390 );
391
392 Ok(Self {
393 dataset,
394 query,
395 index_uuids,
396 properties,
397 metrics: ExecutionPlanMetricsSet::new(),
398 })
399 }
400}
401
402impl DisplayAs for ANNIvfPartitionExec {
403 fn fmt_as(&self, t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result {
404 match t {
405 DisplayFormatType::Default | DisplayFormatType::Verbose => {
406 write!(
407 f,
408 "ANNIvfPartition: uuid={}, minimum_nprobes={}, maximum_nprobes={:?}, deltas={}",
409 self.index_uuids[0],
410 self.query.minimum_nprobes,
411 self.query.maximum_nprobes,
412 self.index_uuids.len()
413 )
414 }
415 DisplayFormatType::TreeRender => {
416 write!(
417 f,
418 "ANNIvfPartition\nuuid={}\nminimum_nprobes={}\nmaximum_nprobes={:?}\ndeltas={}",
419 self.index_uuids[0],
420 self.query.minimum_nprobes,
421 self.query.maximum_nprobes,
422 self.index_uuids.len()
423 )
424 }
425 }
426 }
427}
428
impl ExecutionPlan for ANNIvfPartitionExec {
    fn name(&self) -> &str {
        "ANNIVFPartitionExec"
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    fn schema(&self) -> SchemaRef {
        KNN_PARTITION_SCHEMA.clone()
    }

    fn statistics(&self) -> DataFusionResult<Statistics> {
        // NOTE(review): execute() emits one row per index delta, so reporting
        // `minimum_nprobes` as an *exact* row count looks suspicious — confirm
        // whether this should be `index_uuids.len()` or an inexact bound.
        Ok(Statistics {
            num_rows: Precision::Exact(self.query.minimum_nprobes),
            ..Statistics::new_unknown(self.schema().as_ref())
        })
    }

    fn properties(&self) -> &PlanProperties {
        &self.properties
    }

    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        // Leaf node.
        vec![]
    }

    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> DataFusionResult<Arc<dyn ExecutionPlan>> {
        if !children.is_empty() {
            Err(DataFusionError::Internal(
                "ANNIVFPartitionExec node does not accept children".to_string(),
            ))
        } else {
            Ok(self)
        }
    }

    /// Ranks partitions for every index delta, emitting one row per delta.
    fn execute(
        &self,
        partition: usize,
        _context: Arc<datafusion::execution::TaskContext>,
    ) -> DataFusionResult<SendableRecordBatchStream> {
        let timer = Instant::now();

        let query = self.query.clone();
        let ds = self.dataset.clone();
        let metrics = Arc::new(AnnPartitionMetrics::new(&self.metrics, partition));
        metrics.deltas_searched.add(self.index_uuids.len());
        let metrics_clone = metrics.clone();

        let stream = stream::iter(self.index_uuids.clone())
            .map(move |uuid| {
                let query = query.clone();
                let ds = ds.clone();
                let metrics = metrics.clone();
                async move {
                    let index = ds
                        .open_vector_index(&query.column, &uuid, &metrics.index_metrics)
                        .await?;
                    let mut query = query.clone();
                    // Cosine distance is computed over normalized vectors, so
                    // normalize the query key to match.
                    if index.metric_type() == DistanceType::Cosine {
                        let key = normalize_arrow(&query.key)?.0;
                        query.key = key;
                    };

                    metrics.partitions_ranked.add(index.total_partitions());

                    // Rank partitions; also returns query-to-centroid distances.
                    let (partitions, dist_q_c) = index.find_partitions(&query).map_err(|e| {
                        DataFusionError::Execution(format!("Failed to find partitions: {}", e))
                    })?;

                    // Pack the ranked partition ids into a single list cell.
                    let mut part_list_builder = ListBuilder::new(UInt32Builder::new())
                        .with_field(Field::new("item", DataType::UInt32, false));
                    part_list_builder.append_value(partitions.iter());
                    let partition_col = part_list_builder.finish();

                    // Parallel list of query-to-centroid distances.
                    let mut dist_q_c_list_builder = ListBuilder::new(Float32Builder::new())
                        .with_field(Field::new("item", DataType::Float32, false));
                    dist_q_c_list_builder.append_value(dist_q_c.iter());
                    let dist_q_c_col = dist_q_c_list_builder.finish();

                    let uuid_col = StringArray::from(vec![uuid.as_str()]);
                    let batch = RecordBatch::try_new(
                        KNN_PARTITION_SCHEMA.clone(),
                        vec![
                            Arc::new(partition_col),
                            Arc::new(dist_q_c_col),
                            Arc::new(uuid_col),
                        ],
                    )?;
                    metrics.baseline_metrics.record_output(batch.num_rows());
                    Ok::<_, DataFusionError>(batch)
                }
            })
            // Rank all deltas concurrently (one future per delta).
            .buffered(self.index_uuids.len())
            // Record wall time once the stream is fully consumed or dropped.
            .finally(move || {
                metrics_clone.baseline_metrics.done();
                metrics_clone
                    .baseline_metrics
                    .elapsed_compute()
                    .add_duration(timer.elapsed());
            });
        let schema = self.schema();
        Ok(
            Box::pin(RecordBatchStreamAdapter::new(schema, stream.boxed()))
                as SendableRecordBatchStream,
        )
    }

    fn supports_limit_pushdown(&self) -> bool {
        false
    }
}
550
/// Exec node that searches the ranked IVF partitions with the sub-index,
/// producing ANN candidates as [`KNN_INDEX_SCHEMA`] rows.
#[derive(Debug)]
pub struct ANNIvfSubIndexExec {
    /// Child producing [`KNN_PARTITION_SCHEMA`] rows (partition ranking).
    input: Arc<dyn ExecutionPlan>,

    /// Dataset being searched.
    dataset: Arc<Dataset>,

    /// Metadata of the index deltas being searched.
    indices: Vec<IndexMetadata>,

    /// The vector query.
    query: Query,

    /// Optional plan supplying row ids / a selection vector used to prefilter
    /// search results.
    prefilter_source: PreFilterSource,

    /// Cached plan properties (single partition, bounded, final emission).
    properties: PlanProperties,

    /// Runtime metrics registry.
    metrics: ExecutionPlanMetricsSet,
}
598
599impl ANNIvfSubIndexExec {
600 pub fn try_new(
601 input: Arc<dyn ExecutionPlan>,
602 dataset: Arc<Dataset>,
603 indices: Vec<IndexMetadata>,
604 query: Query,
605 prefilter_source: PreFilterSource,
606 ) -> Result<Self> {
607 if input.schema().field_with_name(PART_ID_COLUMN).is_err() {
608 return Err(Error::Index {
609 message: format!(
610 "ANNSubIndexExec node: input schema does not have \"{}\" column",
611 PART_ID_COLUMN
612 ),
613 location: location!(),
614 });
615 }
616 let properties = PlanProperties::new(
617 EquivalenceProperties::new(KNN_INDEX_SCHEMA.clone()),
618 Partitioning::RoundRobinBatch(1),
619 EmissionType::Final,
620 Boundedness::Bounded,
621 );
622 Ok(Self {
623 input,
624 dataset,
625 indices,
626 query,
627 prefilter_source,
628 properties,
629 metrics: ExecutionPlanMetricsSet::new(),
630 })
631 }
632}
633
634impl DisplayAs for ANNIvfSubIndexExec {
635 fn fmt_as(&self, t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result {
636 match t {
637 DisplayFormatType::Default | DisplayFormatType::Verbose => {
638 write!(
639 f,
640 "ANNSubIndex: name={}, k={}, deltas={}",
641 self.indices[0].name,
642 self.query.k * self.query.refine_factor.unwrap_or(1) as usize,
643 self.indices.len()
644 )
645 }
646 DisplayFormatType::TreeRender => {
647 write!(
648 f,
649 "ANNSubIndex\nname={}\nk={}\ndeltas={}",
650 self.indices[0].name,
651 self.query.k * self.query.refine_factor.unwrap_or(1) as usize,
652 self.indices.len()
653 )
654 }
655 }
656 }
657}
658
/// Shared state coordinating the initial (`minimum_nprobes`) search phase
/// with the optional late phase across all index deltas.
struct ANNIvfEarlySearchResults {
    // Number of results requested (query.k).
    k: usize,
    // Row ids collected during the initial phase, capped at `k` entries.
    initial_ids: Mutex<Vec<u64>>,
    // Running result count: set from `initial_ids` once all deltas finish
    // the initial phase, then incremented by late-phase batches.
    num_results_found: AtomicUsize,
    // Deltas that have not yet finished their initial phase.
    deltas_remaining: AtomicUsize,
    // Signalled once every delta finished the initial phase.
    all_deltas_done: Notify,
    // Ensures the "emit unmatched prefilter ids" shortcut runs at most once.
    took_no_rows_shortcut: AtomicBool,
}
667
impl ANNIvfEarlySearchResults {
    /// Creates coordination state for `deltas_remaining` index deltas with a
    /// target of `k` results.
    fn new(deltas_remaining: usize, k: usize) -> Self {
        Self {
            k,
            initial_ids: Mutex::new(Vec::with_capacity(k)),
            num_results_found: AtomicUsize::new(0),
            deltas_remaining: AtomicUsize::new(deltas_remaining),
            all_deltas_done: Notify::new(),
            took_no_rows_shortcut: AtomicBool::new(false),
        }
    }

    /// Records row ids from an initial-phase batch, keeping at most `k`
    /// ids in total.
    fn record_batch(&self, batch: &RecordBatch) {
        let mut initial_ids = self.initial_ids.lock().unwrap();
        // Cap so `initial_ids` never grows past `k`.
        let ids_to_record = (self.k - initial_ids.len()).min(batch.num_rows());
        initial_ids.extend(
            batch
                // Column 1 is the row-id column (see KNN_INDEX_SCHEMA).
                .column(1)
                .as_primitive::<UInt64Type>()
                .values()
                .iter()
                .take(ids_to_record),
        );
    }

    /// Adds `num_rows` late-phase results to the running total.
    fn record_late_batch(&self, num_rows: usize) {
        self.num_results_found
            .fetch_add(num_rows, Ordering::Relaxed);
    }

    /// Waits until every delta has finished its initial phase, then returns
    /// the number of results found so far. The last delta to arrive publishes
    /// the initial-phase result count and wakes the waiters.
    //
    // NOTE(review): if a non-final caller is preempted between `fetch_sub`
    // and `notified()`, the final caller's `notify_waiters()` could fire in
    // that window and the preempted task would then wait forever — confirm
    // the callers' scheduling makes this impossible, or register the
    // `notified()` future before decrementing.
    async fn wait_for_minimum_to_finish(&self) -> usize {
        if self.deltas_remaining.fetch_sub(1, Ordering::Relaxed) == 1 {
            {
                let new_num_results_found = self.initial_ids.lock().unwrap().len();
                self.num_results_found
                    .store(new_num_results_found, Ordering::Relaxed);
            }
            self.all_deltas_done.notify_waiters();
        } else {
            self.all_deltas_done.notified().await;
        }
        self.num_results_found.load(Ordering::Relaxed)
    }
}
712
impl ANNIvfSubIndexExec {
    /// Late search phase: probes ranked partitions beyond `minimum_nprobes`
    /// (up to `maximum_nprobes`), but only if the initial phase across all
    /// deltas did not already find `k` results.
    ///
    /// Also implements a shortcut: when the prefilter admits at most `k`
    /// rows, the admitted ids the initial phase missed are emitted directly
    /// with infinite distance instead of probing more partitions.
    fn late_search(
        index: Arc<dyn VectorIndex>,
        query: Query,
        partitions: Arc<UInt32Array>,
        q_c_dists: Arc<Float32Array>,
        prefilter: Arc<DatasetPreFilter>,
        metrics: Arc<AnnIndexMetrics>,
        state: Arc<ANNIvfEarlySearchResults>,
    ) -> impl Stream<Item = DataFusionResult<RecordBatch>> {
        let stream = futures::stream::once(async move {
            let max_nprobes = query.maximum_nprobes.unwrap_or(partitions.len());
            // No extra probes allowed: the late phase has nothing to do.
            if max_nprobes == query.minimum_nprobes {
                return futures::stream::empty().boxed();
            }

            // Block until every delta finished its initial phase.
            let found_so_far = state.wait_for_minimum_to_finish().await;
            if found_so_far >= query.k {
                return futures::stream::empty().boxed();
            }

            let prefilter_mask = prefilter.mask();

            // Upper bound on the number of rows the prefilter can admit.
            let max_results = prefilter_mask.max_len().map(|x| x as usize);

            if let Some(max_results) = max_results {
                if found_so_far < max_results && max_results <= query.k {
                    if let Some(iter_ids) = prefilter_mask.iter_ids() {
                        // CAS guard: only one delta may take this shortcut.
                        if state
                            .took_no_rows_shortcut
                            .compare_exchange(false, true, Ordering::Relaxed, Ordering::Relaxed)
                            .is_ok()
                        {
                            // Emit every admitted id the initial phase missed
                            // with +inf distance so they rank last.
                            let initial_ids = state.initial_ids.lock().unwrap();
                            let found_ids = HashSet::<_>::from_iter(initial_ids.iter().copied());
                            drop(initial_ids);
                            let mask_ids = HashSet::from_iter(iter_ids.map(u64::from));
                            let not_found_ids = mask_ids.difference(&found_ids);
                            let not_found_ids =
                                UInt64Array::from_iter_values(not_found_ids.copied());
                            let not_found_distance =
                                Float32Array::from_value(f32::INFINITY, not_found_ids.len());
                            let not_found_batch = RecordBatch::try_new(
                                KNN_INDEX_SCHEMA.clone(),
                                vec![Arc::new(not_found_distance), Arc::new(not_found_ids)],
                            )
                            .unwrap();
                            return futures::stream::once(async move { Ok(not_found_batch) })
                                .boxed();
                        } else {
                            return futures::stream::empty().boxed();
                        }
                    }
                }
            }

            // We never need more than k results from the late phase.
            let max_results = max_results.unwrap_or(usize::MAX).min(query.k);

            let state_clone = state.clone();

            // Probe the remaining ranked partitions in ranking order.
            futures::stream::iter(query.minimum_nprobes..max_nprobes)
                .map(move |idx| {
                    let part_id = partitions.value(idx);
                    let mut query = query.clone();
                    query.dist_q_c = q_c_dists.value(idx);
                    let metrics = metrics.clone();
                    let pre_filter = prefilter.clone();
                    let state = state.clone();
                    let index = index.clone();
                    async move {
                        let mut query = query.clone();
                        // Cosine search operates on normalized vectors.
                        if index.metric_type() == DistanceType::Cosine {
                            let key = normalize_arrow(&query.key)?.0;
                            query.key = key;
                        };

                        metrics.partitions_searched.add(1);
                        let batch = index
                            .search_in_partition(
                                part_id as usize,
                                &query,
                                pre_filter,
                                &metrics.index_metrics,
                            )
                            .map_err(|e| {
                                DataFusionError::Execution(format!(
                                    "Failed to calculate KNN: {}",
                                    e
                                ))
                            })
                            .await?;
                        metrics.baseline_metrics.record_output(batch.num_rows());
                        state.record_late_batch(batch.num_rows());
                        Ok(batch)
                    }
                })
                // Stop scheduling further partitions once enough results
                // have accumulated (checked before each new probe).
                .take_while(move |_| {
                    let found_so_far = state_clone.num_results_found.load(Ordering::Relaxed);
                    std::future::ready(found_so_far < max_results)
                })
                .buffered(get_num_compute_intensive_cpus())
                .boxed()
        });
        stream.flatten()
    }

    /// Initial search phase: probes the first `minimum_nprobes` ranked
    /// partitions concurrently, recording found row ids into `state` for the
    /// late phase to consult.
    fn initial_search(
        index: Arc<dyn VectorIndex>,
        query: Query,
        partitions: Arc<UInt32Array>,
        q_c_dists: Arc<Float32Array>,
        prefilter: Arc<DatasetPreFilter>,
        metrics: Arc<AnnIndexMetrics>,
        state: Arc<ANNIvfEarlySearchResults>,
    ) -> impl Stream<Item = DataFusionResult<RecordBatch>> {
        // Never probe more partitions than were ranked.
        let minimum_nprobes = query.minimum_nprobes.min(partitions.len());
        metrics.partitions_searched.add(minimum_nprobes);

        futures::stream::iter(0..minimum_nprobes)
            .map(move |idx| {
                let part_id = partitions.value(idx);
                let mut query = query.clone();
                query.dist_q_c = q_c_dists.value(idx);
                let metrics = metrics.clone();
                let index = index.clone();
                let pre_filter = prefilter.clone();
                let state = state.clone();
                async move {
                    let mut query = query.clone();
                    // Cosine search operates on normalized vectors.
                    if index.metric_type() == DistanceType::Cosine {
                        let key = normalize_arrow(&query.key)?.0;
                        query.key = key;
                    };

                    let batch = index
                        .search_in_partition(
                            part_id as usize,
                            &query,
                            pre_filter,
                            &metrics.index_metrics,
                        )
                        .map_err(|e| {
                            DataFusionError::Execution(format!("Failed to calculate KNN: {}", e))
                        })
                        .await?;
                    metrics.baseline_metrics.record_output(batch.num_rows());
                    state.record_batch(&batch);
                    Ok(batch)
                }
            })
            .buffered(get_num_compute_intensive_cpus())
    }
}
880
impl ExecutionPlan for ANNIvfSubIndexExec {
    fn name(&self) -> &str {
        "ANNSubIndexExec"
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    fn schema(&self) -> arrow_schema::SchemaRef {
        KNN_INDEX_SCHEMA.clone()
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        // The optional prefilter source plan counts as a second child.
        match &self.prefilter_source {
            PreFilterSource::None => vec![&self.input],
            PreFilterSource::FilteredRowIds(src) => vec![&self.input, &src],
            PreFilterSource::ScalarIndexQuery(src) => vec![&self.input, &src],
        }
    }

    fn required_input_distribution(&self) -> Vec<Distribution> {
        // Both the ranking input and the prefilter must be single-partition.
        self.children()
            .iter()
            .map(|_| Distribution::SinglePartition)
            .collect()
    }

    fn with_new_children(
        self: Arc<Self>,
        mut children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> DataFusionResult<Arc<dyn ExecutionPlan>> {
        let plan = if children.len() == 1 || children.len() == 2 {
            // A second child (if present) replaces the prefilter source plan,
            // preserving the original source's variant.
            let prefilter_source = if children.len() == 2 {
                let prefilter = children.pop().expect("length checked");
                match &self.prefilter_source {
                    PreFilterSource::None => PreFilterSource::None,
                    PreFilterSource::FilteredRowIds(_) => {
                        PreFilterSource::FilteredRowIds(prefilter)
                    }
                    PreFilterSource::ScalarIndexQuery(_) => {
                        PreFilterSource::ScalarIndexQuery(prefilter)
                    }
                }
            } else {
                self.prefilter_source.clone()
            };

            Self {
                input: children.pop().expect("length checked"),
                dataset: self.dataset.clone(),
                indices: self.indices.clone(),
                query: self.query.clone(),
                prefilter_source,
                properties: self.properties.clone(),
                metrics: ExecutionPlanMetricsSet::new(),
            }
        } else {
            return Err(DataFusionError::Internal(
                "ANNSubIndexExec node must have exactly one or two (prefilter) child".to_string(),
            ));
        };
        Ok(Arc::new(plan))
    }

    /// Runs the initial+late search per index delta and interleaves the
    /// resulting candidate batches.
    fn execute(
        &self,
        partition: usize,
        context: Arc<datafusion::execution::context::TaskContext>,
    ) -> DataFusionResult<datafusion::physical_plan::SendableRecordBatchStream> {
        let input_stream = self.input.execute(partition, context.clone())?;
        let schema = self.schema();
        let query = self.query.clone();
        let ds = self.dataset.clone();
        let column = self.query.column.clone();
        let indices = self.indices.clone();
        let prefilter_source = self.prefilter_source.clone();
        let metrics = Arc::new(AnnIndexMetrics::new(&self.metrics, partition));
        let metrics_clone = metrics.clone();
        let timer = Instant::now();

        // Unpack each upstream KNN_PARTITION_SCHEMA batch into one
        // (partitions, centroid distances, uuid) tuple per index delta.
        let per_index_stream = input_stream
            .and_then(move |batch| {
                // Missing columns are a planning bug: try_new validated the
                // partition-id column, so panicking here is intentional.
                let part_id_col = batch.column_by_name(PART_ID_COLUMN).unwrap_or_else(|| {
                    panic!("ANNSubIndexExec: input missing {} column", PART_ID_COLUMN)
                });
                let part_id_arr = part_id_col.as_list::<i32>().clone();
                let dist_q_c_col = batch.column_by_name(DIST_Q_C_COLUMN).unwrap_or_else(|| {
                    panic!("ANNSubIndexExec: input missing {} column", DIST_Q_C_COLUMN)
                });
                let dist_q_c_arr = dist_q_c_col.as_list::<i32>().clone();
                let index_uuid_col = batch.column_by_name(INDEX_UUID_COLUMN).unwrap_or_else(|| {
                    panic!(
                        "ANNSubIndexExec: input missing {} column",
                        INDEX_UUID_COLUMN
                    )
                });
                let index_uuid = index_uuid_col.as_string::<i32>().clone();

                // One tuple per row (each row describes one index delta).
                let plan: Vec<DataFusionResult<(_, _, _)>> = part_id_arr
                    .iter()
                    .zip(dist_q_c_arr.iter())
                    .zip(index_uuid.iter())
                    .map(|((part_id, dist_q_c), uuid)| {
                        let partitions =
                            Arc::new(part_id.unwrap().as_primitive::<UInt32Type>().clone());
                        let dist_q_c =
                            Arc::new(dist_q_c.unwrap().as_primitive::<Float32Type>().clone());
                        let uuid = uuid.unwrap().to_string();
                        Ok((partitions, dist_q_c, uuid))
                    })
                    .collect_vec();
                async move { DataFusionResult::Ok(stream::iter(plan)) }
            })
            .try_flatten();
        // Build a lazy prefilter loader from the source plan, if any.
        let prefilter_loader = match &prefilter_source {
            PreFilterSource::FilteredRowIds(src_node) => {
                let stream = src_node.execute(partition, context)?;
                Some(Box::new(FilteredRowIdsToPrefilter(stream)) as Box<dyn FilterLoader>)
            }
            PreFilterSource::ScalarIndexQuery(src_node) => {
                let stream = src_node.execute(partition, context)?;
                Some(Box::new(SelectionVectorToPrefilter(stream)) as Box<dyn FilterLoader>)
            }
            PreFilterSource::None => None,
        };

        let pre_filter = Arc::new(DatasetPreFilter::new(
            ds.clone(),
            &indices,
            prefilter_loader,
        ));

        // Shared coordination between the initial and late phases across all
        // deltas (see ANNIvfEarlySearchResults).
        let state = Arc::new(ANNIvfEarlySearchResults::new(indices.len(), query.k));

        Ok(Box::pin(RecordBatchStreamAdapter::new(
            schema,
            per_index_stream
                .and_then(move |(part_ids, q_c_dists, index_uuid)| {
                    let ds = ds.clone();
                    let column = column.clone();
                    let metrics = metrics.clone();
                    let pre_filter = pre_filter.clone();
                    let state = state.clone();
                    let query = query.clone();

                    async move {
                        let raw_index = ds
                            .open_vector_index(&column, &index_uuid, &metrics.index_metrics)
                            .await?;

                        // Per delta: the initial phase runs first, then the
                        // (conditional) late phase chained after it.
                        let early_search = Self::initial_search(
                            raw_index.clone(),
                            query.clone(),
                            part_ids.clone(),
                            q_c_dists.clone(),
                            pre_filter.clone(),
                            metrics.clone(),
                            state.clone(),
                        );
                        let late_search = Self::late_search(
                            raw_index.clone(),
                            query,
                            part_ids,
                            q_c_dists,
                            pre_filter,
                            metrics,
                            state,
                        );
                        DataFusionResult::Ok(early_search.chain(late_search).boxed())
                    }
                })
                // Interleave all deltas' batches without ordering guarantees.
                .try_flatten_unordered(None)
                .finally(move || {
                    metrics_clone
                        .baseline_metrics
                        .elapsed_compute()
                        .add_duration(timer.elapsed());
                    metrics_clone.baseline_metrics.done();
                })
                .boxed(),
        )))
    }

    fn partition_statistics(
        &self,
        partition: Option<usize>,
    ) -> DataFusionResult<datafusion::physical_plan::Statistics> {
        // Up to k * refine_factor candidates per index delta (one input row
        // per delta); defaults to one delta when the input count is unknown.
        Ok(Statistics {
            num_rows: Precision::Exact(
                self.query.k
                    * self.query.refine_factor.unwrap_or(1) as usize
                    * self
                        .input
                        .partition_statistics(partition)?
                        .num_rows
                        .get_value()
                        .unwrap_or(&1),
            ),
            ..Statistics::new_unknown(self.schema().as_ref())
        })
    }

    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }

    fn properties(&self) -> &PlanProperties {
        &self.properties
    }

    fn supports_limit_pushdown(&self) -> bool {
        false
    }
}
1102
/// Exec node that merges the per-query-vector ANN results (one input plan
/// per query vector) into a single scored result set.
#[derive(Debug)]
pub struct MultivectorScoringExec {
    /// One ANN search plan per query vector.
    inputs: Vec<Arc<dyn ExecutionPlan>>,
    /// The multivector query.
    query: Query,
    /// Cached plan properties (single partition, bounded, final emission).
    properties: PlanProperties,
}
1110
1111impl MultivectorScoringExec {
1112 pub fn try_new(inputs: Vec<Arc<dyn ExecutionPlan>>, query: Query) -> Result<Self> {
1113 let properties = PlanProperties::new(
1114 EquivalenceProperties::new(KNN_INDEX_SCHEMA.clone()),
1115 Partitioning::RoundRobinBatch(1),
1116 EmissionType::Final,
1117 Boundedness::Bounded,
1118 );
1119
1120 Ok(Self {
1121 inputs,
1122 query,
1123 properties,
1124 })
1125 }
1126}
1127
1128impl DisplayAs for MultivectorScoringExec {
1129 fn fmt_as(&self, t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result {
1130 match t {
1131 DisplayFormatType::Default | DisplayFormatType::Verbose => {
1132 write!(f, "MultivectorScoring: k={}", self.query.k)
1133 }
1134 DisplayFormatType::TreeRender => {
1135 write!(f, "MultivectorScoring\nk={}", self.query.k)
1136 }
1137 }
1138 }
1139}
1140
1141impl ExecutionPlan for MultivectorScoringExec {
1142 fn name(&self) -> &str {
1143 "MultivectorScoringExec"
1144 }
1145
1146 fn as_any(&self) -> &dyn Any {
1147 self
1148 }
1149
1150 fn schema(&self) -> arrow_schema::SchemaRef {
1151 KNN_INDEX_SCHEMA.clone()
1152 }
1153
1154 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
1155 self.inputs.iter().collect()
1156 }
1157
1158 fn required_input_distribution(&self) -> Vec<Distribution> {
1159 self.children()
1162 .iter()
1163 .map(|_| Distribution::SinglePartition)
1164 .collect()
1165 }
1166
1167 fn with_new_children(
1168 self: Arc<Self>,
1169 children: Vec<Arc<dyn ExecutionPlan>>,
1170 ) -> DataFusionResult<Arc<dyn ExecutionPlan>> {
1171 let plan = Self::try_new(children, self.query.clone())?;
1172 Ok(Arc::new(plan))
1173 }
1174
    /// Executes multivector scoring: merges the per-query-vector KNN streams
    /// and combines them into a single score per row id.
    ///
    /// Each input batch is expected to carry the KNN result schema
    /// (`DIST_COL`, `ROW_ID`). Distances are converted to similarities
    /// (`sim = 1 - dist`); when a row id is missing from one batch, that
    /// batch contributes its minimum observed similarity as a penalty
    /// instead. Assumes each batch is sorted by distance ascending, so the
    /// last distance is the largest — TODO confirm against the producers.
    fn execute(
        &self,
        partition: usize,
        context: Arc<datafusion::execution::context::TaskContext>,
    ) -> DataFusionResult<SendableRecordBatchStream> {
        // Start every child stream on the same partition.
        let inputs = self
            .inputs
            .iter()
            .map(|input| input.execute(partition, context.clone()))
            .collect::<DataFusionResult<Vec<_>>>()?;

        // Per-batch pre-processing merged into one unordered stream:
        // deduplicate row ids and replace distances with similarities.
        // NOTE(review): `1 - dist` is only a similarity for cosine-like
        // metrics — confirm this node is planned for cosine search only.
        let mut reduced_inputs = stream::select_all(inputs.into_iter().map(|stream| {
            stream.map(|batch| {
                let batch = batch?;
                let row_ids = batch[ROW_ID].as_primitive::<UInt64Type>();
                let dists = batch[DIST_COL].as_primitive::<Float32Type>();
                debug_assert_eq!(dists.null_count(), 0);

                // Minimum similarity observed in this batch (last value when
                // sorted by distance ascending); 0.0 for an empty batch.
                let min_sim = dists
                    .values()
                    .last()
                    .map(|dist| 1.0 - *dist)
                    .unwrap_or_default();
                let mut new_row_ids = Vec::with_capacity(row_ids.len());
                let mut new_sims = Vec::with_capacity(row_ids.len());
                let mut visited_row_ids = HashSet::with_capacity(row_ids.len());

                // Keep only the first occurrence of each row id.
                for (row_id, dist) in row_ids.values().iter().zip(dists.values().iter()) {
                    if visited_row_ids.contains(row_id) {
                        continue;
                    }
                    visited_row_ids.insert(row_id);
                    new_row_ids.push(*row_id);
                    new_sims.push(1.0 - *dist);
                }
                let new_row_ids = UInt64Array::from(new_row_ids);
                let new_dists = Float32Array::from(new_sims);

                // Re-emit with similarities occupying the distance column slot.
                let batch = RecordBatch::try_new(
                    KNN_INDEX_SCHEMA.clone(),
                    vec![Arc::new(new_dists), Arc::new(new_row_ids)],
                )?;

                Ok::<_, DataFusionError>((min_sim, batch))
            })
        }));

        let k = self.query.k;
        let refactor = self.query.refine_factor.unwrap_or(1) as usize;
        let num_queries = self.inputs.len() as f32;
        // Drain all batches, accumulating a similarity sum per row id.
        let stream = stream::once(async move {
            let mut results = HashMap::with_capacity(k * refactor);
            // Running sum of the min similarities of all batches processed so
            // far; rows first seen now are seeded with it so they are
            // penalized for every batch they were absent from.
            let mut missed_sim_sum = 0.0;
            while let Some((min_sim, batch)) = reduced_inputs.try_next().await? {
                let row_ids = batch[ROW_ID].as_primitive::<UInt64Type>();
                let sims = batch[DIST_COL].as_primitive::<Float32Type>();

                let query_results = row_ids
                    .values()
                    .iter()
                    .copied()
                    .zip(sims.values().iter().copied())
                    .collect::<HashMap<_, _>>();

                // Rows already tracked: add this batch's similarity when
                // present, otherwise the batch's minimum similarity.
                results.iter_mut().for_each(|(row_id, sim)| {
                    if let Some(new_dist) = query_results.get(row_id) {
                        *sim += new_dist;
                    } else {
                        *sim += min_sim;
                    }
                });
                // Newly-seen rows: seed with the accumulated missed-sim sum.
                query_results.into_iter().for_each(|(row_id, sim)| {
                    results.entry(row_id).or_insert(sim + missed_sim_sum);
                });
                missed_sim_sum += min_sim;
            }

            // Convert the summed similarities back into a distance-like
            // score (smaller is better): dist = num_queries - sim_sum.
            let (row_ids, sims): (Vec<_>, Vec<_>) = results.into_iter().unzip();
            let dists = sims
                .into_iter()
                .map(|sim| num_queries - sim)
                .collect::<Vec<_>>();
            let row_ids = UInt64Array::from(row_ids);
            let dists = Float32Array::from(dists);
            let batch = RecordBatch::try_new(
                KNN_INDEX_SCHEMA.clone(),
                vec![Arc::new(dists), Arc::new(row_ids)],
            )?;
            Ok::<_, DataFusionError>(batch)
        });
        Ok(Box::pin(RecordBatchStreamAdapter::new(
            self.schema(),
            stream.boxed(),
        )))
    }
1283
1284 fn statistics(&self) -> DataFusionResult<Statistics> {
1285 Ok(Statistics {
1286 num_rows: Precision::Inexact(
1287 self.query.k * self.query.refine_factor.unwrap_or(1) as usize,
1288 ),
1289 ..Statistics::new_unknown(self.schema().as_ref())
1290 })
1291 }
1292
    /// Plan properties pre-computed at construction time.
    fn properties(&self) -> &PlanProperties {
        &self.properties
    }
1296
    /// Limits cannot be pushed below this node: scores are only known after
    /// all inputs are merged, so truncating inputs early would change results.
    fn supports_limit_pushdown(&self) -> bool {
        false
    }
1300}
1301
1302#[cfg(test)]
1303mod tests {
1304 use super::*;
1305
1306 use arrow::compute::{concat_batches, sort_to_indices, take_record_batch};
1307 use arrow::datatypes::Float32Type;
1308 use arrow_array::{FixedSizeListArray, Int32Array, RecordBatchIterator, StringArray};
1309 use arrow_schema::{Field as ArrowField, Schema as ArrowSchema};
1310 use lance_core::utils::tempfile::TempStrDir;
1311 use lance_datafusion::exec::{ExecutionStatsCallback, ExecutionSummaryCounts};
1312 use lance_datagen::{array, BatchCount, RowCount};
1313 use lance_index::optimize::OptimizeOptions;
1314 use lance_index::vector::ivf::IvfBuildParams;
1315 use lance_index::vector::pq::PQBuildParams;
1316 use lance_index::{DatasetIndexExt, IndexType};
1317 use lance_linalg::distance::MetricType;
1318 use lance_testing::datagen::generate_random_array;
1319 use rstest::rstest;
1320
1321 use crate::dataset::{WriteMode, WriteParams};
1322 use crate::index::vector::VectorIndexParams;
1323 use crate::io::exec::testing::TestingExec;
1324
    /// End-to-end flat (un-indexed) KNN search over a small dataset,
    /// cross-checked against a brute-force distance computation.
    #[tokio::test]
    async fn knn_flat_search() {
        let schema = Arc::new(ArrowSchema::new(vec![
            ArrowField::new("key", DataType::Int32, false),
            ArrowField::new(
                "vector",
                DataType::FixedSizeList(
                    Arc::new(ArrowField::new("item", DataType::Float32, true)),
                    128,
                ),
                true,
            ),
            ArrowField::new("uri", DataType::Utf8, true),
        ]));

        // 20 batches x 20 rows of random 128-d vectors.
        let batches: Vec<RecordBatch> = (0..20)
            .map(|i| {
                RecordBatch::try_new(
                    schema.clone(),
                    vec![
                        Arc::new(Int32Array::from_iter_values(i * 20..(i + 1) * 20)),
                        Arc::new(
                            FixedSizeListArray::try_new_from_values(
                                generate_random_array(128 * 20),
                                128,
                            )
                            .unwrap(),
                        ),
                        Arc::new(StringArray::from_iter_values(
                            (i * 20..(i + 1) * 20).map(|i| format!("s3://bucket/file-{}", i)),
                        )),
                    ],
                )
                .unwrap()
            })
            .collect();

        let test_dir = TempStrDir::default();
        let test_uri = test_dir.as_str();

        // Small file/group sizes so the dataset spans multiple fragments.
        let write_params = WriteParams {
            max_rows_per_file: 40,
            max_rows_per_group: 10,
            ..Default::default()
        };
        // Use a stored vector as the query so an exact match exists.
        let vector_arr = batches[0].column_by_name("vector").unwrap();
        let q = as_fixed_size_list_array(&vector_arr).value(5);

        let reader = RecordBatchIterator::new(batches.into_iter().map(Ok), schema.clone());
        Dataset::write(reader, test_uri, Some(write_params))
            .await
            .unwrap();

        let dataset = Dataset::open(test_uri).await.unwrap();
        let stream = dataset
            .scan()
            .nearest("vector", q.as_primitive::<Float32Type>(), 10)
            .unwrap()
            .try_into_stream()
            .await
            .unwrap();
        let results = stream.try_collect::<Vec<_>>().await.unwrap();

        // The scan must append a distance column and return a single batch.
        assert!(results[0].schema().column_with_name(DIST_COL).is_some());

        assert_eq!(results.len(), 1);

        // Brute force: compute L2 distance for every row, take the 10
        // smallest, and expect exactly the batch the KNN scan produced.
        let stream = dataset.scan().try_into_stream().await.unwrap();
        let all_with_distances = stream
            .and_then(|batch| compute_distance(q.clone(), DistanceType::L2, "vector", batch))
            .try_collect::<Vec<_>>()
            .await
            .unwrap();
        let all_with_distances =
            concat_batches(&results[0].schema(), all_with_distances.iter()).unwrap();
        let dist_arr = all_with_distances.column_by_name(DIST_COL).unwrap();
        let distances = dist_arr.as_primitive::<Float32Type>();
        let indices = sort_to_indices(distances, None, Some(10)).unwrap();
        let expected = take_record_batch(&all_with_distances, &indices).unwrap();
        assert_eq!(expected, results[0]);
    }
1406
1407 #[test]
1408 fn test_create_knn_flat() {
1409 let dim: usize = 128;
1410 let schema = Arc::new(ArrowSchema::new(vec![
1411 ArrowField::new("key", DataType::Int32, false),
1412 ArrowField::new(
1413 "vector",
1414 DataType::FixedSizeList(
1415 Arc::new(ArrowField::new("item", DataType::Float32, true)),
1416 dim as i32,
1417 ),
1418 true,
1419 ),
1420 ArrowField::new("uri", DataType::Utf8, true),
1421 ]));
1422 let batch = RecordBatch::new_empty(schema);
1423
1424 let input: Arc<dyn ExecutionPlan> = Arc::new(TestingExec::new(vec![batch]));
1425
1426 let idx = KNNVectorDistanceExec::try_new(
1427 input,
1428 "vector",
1429 Arc::new(generate_random_array(dim)),
1430 DistanceType::L2,
1431 )
1432 .unwrap();
1433 assert_eq!(
1434 idx.schema().as_ref(),
1435 &ArrowSchema::new(vec![
1436 ArrowField::new("key", DataType::Int32, false),
1437 ArrowField::new(
1438 "vector",
1439 DataType::FixedSizeList(
1440 Arc::new(ArrowField::new("item", DataType::Float32, true)),
1441 dim as i32,
1442 ),
1443 true,
1444 ),
1445 ArrowField::new("uri", DataType::Utf8, true),
1446 ArrowField::new(DIST_COL, DataType::Float32, true),
1447 ])
1448 );
1449 }
1450
    /// Multivector scoring must be order-independent: feeding the same three
    /// per-query result batches in every permutation must produce identical
    /// per-row distances.
    #[tokio::test]
    async fn test_multivector_score() {
        let query = Query {
            column: "vector".to_string(),
            key: Arc::new(generate_random_array(1)),
            k: 10,
            lower_bound: None,
            upper_bound: None,
            minimum_nprobes: 1,
            maximum_nprobes: None,
            ef: None,
            refine_factor: None,
            metric_type: DistanceType::Cosine,
            use_index: true,
            dist_q_c: 0.0,
        };

        // Runs the scoring plan over `inputs` and flattens the output into
        // a row_id -> distance map.
        async fn multivector_scoring(
            inputs: Vec<Arc<dyn ExecutionPlan>>,
            query: Query,
        ) -> Result<HashMap<u64, f32>> {
            let ctx = Arc::new(datafusion::execution::context::TaskContext::default());
            let plan = MultivectorScoringExec::try_new(inputs, query.clone())?;
            let batches = plan
                .execute(0, ctx.clone())
                .unwrap()
                .try_collect::<Vec<_>>()
                .await?;
            let mut results = HashMap::new();
            for batch in batches {
                let row_ids = batch[ROW_ID].as_primitive::<UInt64Type>();
                let dists = batch[DIST_COL].as_primitive::<Float32Type>();
                for (row_id, dist) in row_ids.values().iter().zip(dists.values().iter()) {
                    results.insert(*row_id, *dist);
                }
            }
            Ok(results)
        }

        // Batch i covers row ids {i+1, i+2}, so the batches overlap pairwise
        // and the union of all three is row ids 1..=4.
        let batches = (0..3)
            .map(|i| {
                RecordBatch::try_new(
                    KNN_INDEX_SCHEMA.clone(),
                    vec![
                        Arc::new(Float32Array::from(vec![i as f32 + 1.0, i as f32 + 2.0])),
                        Arc::new(UInt64Array::from(vec![i + 1, i + 2])),
                    ],
                )
                .unwrap()
            })
            .collect::<Vec<_>>();

        // Score every permutation and check all runs agree with the first.
        let mut res: Option<HashMap<_, _>> = None;
        for perm in batches.into_iter().permutations(3) {
            let inputs = perm
                .into_iter()
                .map(|batch| {
                    let input: Arc<dyn ExecutionPlan> = Arc::new(TestingExec::new(vec![batch]));
                    input
                })
                .collect::<Vec<_>>();
            let new_res = multivector_scoring(inputs, query.clone()).await.unwrap();
            assert_eq!(new_res.len(), 4);
            if let Some(res) = &res {
                for (row_id, dist) in new_res.iter() {
                    assert_eq!(res.get(row_id).unwrap(), dist)
                }
            } else {
                res = Some(new_res);
            }
        }
    }
1523
    /// Shared setup for the nprobes tests: a 10,000-row dataset (100
    /// fragments x 100 rows) indexed with IVF_PQ across one or more index
    /// deltas, plus the IVF centroids used to build it.
    struct NprobesTestFixture {
        dataset: Dataset,
        // FixedSizeList of the IVF centroids; test queries are drawn from it.
        centroids: Arc<dyn Array>,
        // Keeps the temp directory alive for the fixture's lifetime.
        _tmp_dir: TempStrDir,
    }
1533
    impl NprobesTestFixture {
        /// Builds the dataset with `num_centroids` IVF partitions split over
        /// `num_deltas` delta indices: the first chunk of fragments gets a
        /// fresh index, later chunks are appended via `optimize_indices`.
        pub async fn new(num_centroids: usize, num_deltas: usize) -> Self {
            let tempdir = TempStrDir::default();
            let tmppath = tempdir.as_str();

            // Centroids cycled around the unit circle; vectors below are
            // tiny jitters of them, so partition assignment is unambiguous.
            let centroids = array::cycle_unit_circle(num_centroids as u32)
                .generate_default(RowCount::from(num_centroids as u64))
                .unwrap();

            assert!(100 % num_deltas == 0, "num_deltas must divide 100");
            let rows_per_frag = 100;
            let num_frags = 100;
            let frags_per_delta = num_frags / num_deltas;

            // label cycles 0..61 (used for selective filters); userid is a
            // monotonically increasing step column.
            let batches = lance_datagen::gen_batch()
                .col("vector", array::jitter_centroids(centroids.clone(), 0.0001))
                .col("label", array::cycle::<UInt32Type>(Vec::from_iter(0..61)))
                .col("userid", array::step::<UInt64Type>())
                .into_reader_rows(
                    RowCount::from(rows_per_frag),
                    BatchCount::from(num_frags as u32),
                )
                .collect::<Vec<_>>();
            let schema = batches[0].as_ref().unwrap().schema();

            // Append each chunk of fragments, then index it.
            let mut first = true;
            for batches in batches.chunks(frags_per_delta) {
                let delta_batches = batches
                    .iter()
                    .map(|maybe_batch| Ok(maybe_batch.as_ref().unwrap().clone()))
                    .collect::<Vec<_>>();
                let reader = RecordBatchIterator::new(delta_batches, schema.clone());
                let mut dataset = Dataset::write(
                    reader,
                    tmppath,
                    Some(WriteParams {
                        mode: WriteMode::Append,
                        ..Default::default()
                    }),
                )
                .await
                .unwrap();

                // Reuse the known centroids so partition boundaries are
                // deterministic across all deltas.
                let ivf_params = IvfBuildParams::try_with_centroids(
                    num_centroids,
                    Arc::new(centroids.as_fixed_size_list().clone()),
                )
                .unwrap();

                // Pre-generated PQ codebook (2 sub-vectors, 8 bits) —
                // presumably to skip codebook training; confirm if changed.
                let codebook = array::rand::<Float32Type>()
                    .generate_default(RowCount::from(256 * 2))
                    .unwrap();
                let pq_params = PQBuildParams::with_codebook(2, 8, codebook);
                let index_params =
                    VectorIndexParams::with_ivf_pq_params(MetricType::L2, ivf_params, pq_params);

                if first {
                    first = false;
                    dataset
                        .create_index(&["vector"], IndexType::Vector, None, &index_params, false)
                        .await
                        .unwrap();
                } else {
                    // Subsequent chunks become delta indices.
                    dataset
                        .optimize_indices(&OptimizeOptions::append())
                        .await
                        .unwrap();
                }
            }

            let dataset = Dataset::open(tmppath).await.unwrap();
            Self {
                dataset,
                centroids,
                _tmp_dir: tempdir,
            }
        }

        /// Returns centroid `idx` as a query vector.
        pub fn get_centroid(&self, idx: usize) -> Arc<dyn Array> {
            let centroids = self.centroids.as_fixed_size_list();
            centroids.value(idx).clone()
        }
    }
1622
    /// Captures the execution summary emitted by a scan's stats callback so
    /// a test can assert on metrics after the scan completes.
    #[derive(Default)]
    struct StatsHolder {
        // Shared slot written by the callback, read by `consume`.
        pub collected_stats: Arc<Mutex<Option<ExecutionSummaryCounts>>>,
    }
1627
1628 impl StatsHolder {
1629 fn get_setter(&self) -> ExecutionStatsCallback {
1630 let collected_stats = self.collected_stats.clone();
1631 Arc::new(move |stats| {
1632 *collected_stats.lock().unwrap() = Some(stats.clone());
1633 })
1634 }
1635
1636 fn consume(self) -> ExecutionSummaryCounts {
1637 self.collected_stats.lock().unwrap().take().unwrap()
1638 }
1639 }
1640
    /// With no maximum_nprobes, the search may widen past the minimum (10)
    /// to fill k=50 under a selective prefilter, but should not have to
    /// probe every partition.
    #[rstest]
    #[tokio::test]
    async fn test_no_max_nprobes(#[values(1, 20)] num_deltas: usize) {
        let fixture = NprobesTestFixture::new(100, num_deltas).await;

        let q = fixture.get_centroid(0);
        let stats_holder = StatsHolder::default();

        let results = fixture
            .dataset
            .scan()
            .nearest("vector", q.as_ref(), 50)
            .unwrap()
            .minimum_nprobes(10)
            .prefilter(true)
            .scan_stats_callback(stats_holder.get_setter())
            .filter("label = 17")
            .unwrap()
            .project(&Vec::<String>::new())
            .unwrap()
            .with_row_id()
            .try_into_batch()
            .await
            .unwrap();

        assert_eq!(results.num_rows(), 50);

        let stats = stats_holder.consume();

        // How far the search widens depends on available parallelism, so
        // only assert the upper bound on machines with <= 32 compute CPUs.
        if get_num_compute_intensive_cpus() <= 32 {
            assert!(*stats.all_counts.get(PARTITIONS_SEARCHED_METRIC).unwrap() < 100 * num_deltas);
        }
    }
1677
    /// A contradictory filter (`label = 17 AND label = 18`) matches nothing:
    /// the search must return zero rows and stop at the minimum nprobes
    /// (10 per delta) rather than widening further.
    #[rstest]
    #[tokio::test]
    async fn test_no_prefilter_results(#[values(1, 20)] num_deltas: usize) {
        let fixture = NprobesTestFixture::new(100, num_deltas).await;

        let q = fixture.get_centroid(0);
        let stats_holder = StatsHolder::default();

        let results = fixture
            .dataset
            .scan()
            .nearest("vector", q.as_ref(), 50)
            .unwrap()
            .minimum_nprobes(10)
            .prefilter(true)
            .scan_stats_callback(stats_holder.get_setter())
            .filter("label = 17 AND label = 18")
            .unwrap()
            .project(&Vec::<String>::new())
            .unwrap()
            .with_row_id()
            .try_into_batch()
            .await
            .unwrap();

        assert_eq!(results.num_rows(), 0);

        let stats = stats_holder.consume();
        assert_eq!(
            stats.all_counts.get(PARTITIONS_SEARCHED_METRIC).unwrap(),
            &(10 * num_deltas)
        );
    }
1713
    /// With a maximum_nprobes cap, exactly `max_nprobes` partitions per delta
    /// are searched (yielding a known number of matching rows for this data),
    /// while all 100 partitions per delta are still ranked.
    #[rstest]
    #[tokio::test]
    async fn test_some_max_nprobes(#[values(1, 20)] num_deltas: usize) {
        let fixture = NprobesTestFixture::new(100, num_deltas).await;

        for (max_nprobes, expected_results) in [(10, 16), (20, 33), (30, 48)] {
            let q = fixture.get_centroid(0);
            let stats_holder = StatsHolder::default();
            let results = fixture
                .dataset
                .scan()
                .nearest("vector", q.as_ref(), 50)
                .unwrap()
                .minimum_nprobes(10)
                .maximum_nprobes(max_nprobes)
                .prefilter(true)
                .filter("label = 17")
                .unwrap()
                .scan_stats_callback(stats_holder.get_setter())
                .project(&Vec::<String>::new())
                .unwrap()
                .with_row_id()
                .try_into_batch()
                .await
                .unwrap();

            let stats = stats_holder.consume();

            assert_eq!(results.num_rows(), expected_results);
            assert_eq!(
                stats.all_counts.get(PARTITIONS_SEARCHED_METRIC).unwrap(),
                &(max_nprobes * num_deltas)
            );
            assert_eq!(
                stats.all_counts.get(PARTITIONS_RANKED_METRIC).unwrap(),
                &(100 * num_deltas)
            );
        }
    }
1753
    /// When the prefilter leaves fewer rows than k (20 < 50), the search
    /// still returns all 20 matches after the minimum nprobes, and a refine
    /// step replaces the placeholder distances with exact ones.
    #[rstest]
    #[tokio::test]
    async fn test_fewer_than_k_results(#[values(1, 20)] num_deltas: usize) {
        let fixture = NprobesTestFixture::new(100, num_deltas).await;

        let q = fixture.get_centroid(0);
        let stats_holder = StatsHolder::default();
        let results = fixture
            .dataset
            .scan()
            .nearest("vector", q.as_ref(), 50)
            .unwrap()
            .minimum_nprobes(10)
            .prefilter(true)
            .filter("userid < 20")
            .unwrap()
            .scan_stats_callback(stats_holder.get_setter())
            .project(&Vec::<String>::new())
            .unwrap()
            .with_row_id()
            .try_into_batch()
            .await
            .unwrap();

        let stats = stats_holder.consume();

        assert_eq!(
            stats.all_counts.get(PARTITIONS_SEARCHED_METRIC).unwrap(),
            &(10 * num_deltas)
        );
        assert_eq!(results.num_rows(), 20);

        // NOTE(review): without a refine step, 15 of the 20 rows come back
        // with infinite distances — appears these rows fall outside the
        // searched partitions and get placeholder scores; confirm.
        let num_infinite_results = results
            .column(0)
            .as_primitive::<Float32Type>()
            .values()
            .iter()
            .filter(|val| val.is_infinite())
            .count();
        assert_eq!(num_infinite_results, 15);

        // With refine(1), every returned row gets a real (finite) distance.
        let results = fixture
            .dataset
            .scan()
            .nearest("vector", q.as_ref(), 50)
            .unwrap()
            .minimum_nprobes(10)
            .prefilter(true)
            .refine(1)
            .filter("userid < 20")
            .unwrap()
            .project(&Vec::<String>::new())
            .unwrap()
            .with_row_id()
            .try_into_batch()
            .await
            .unwrap();

        assert_eq!(results.num_rows(), 20);
        let num_infinite_results = results
            .column(0)
            .as_primitive::<Float32Type>()
            .values()
            .iter()
            .filter(|val| val.is_infinite())
            .count();
        assert_eq!(num_infinite_results, 0);
    }
1827
    /// When k (40,000) exceeds the dataset size (10,000 rows), the search
    /// must widen to every partition (100 per delta) and return all rows.
    #[rstest]
    #[tokio::test]
    async fn test_dataset_too_small(#[values(1, 20)] num_deltas: usize) {
        let fixture = NprobesTestFixture::new(100, num_deltas).await;

        let q = fixture.get_centroid(0);
        let stats_holder = StatsHolder::default();
        let results = fixture
            .dataset
            .scan()
            .nearest("vector", q.as_ref(), 40000)
            .unwrap()
            .minimum_nprobes(10)
            .scan_stats_callback(stats_holder.get_setter())
            .project(&Vec::<String>::new())
            .unwrap()
            .with_row_id()
            .try_into_batch()
            .await
            .unwrap();

        let stats = stats_holder.consume();

        assert_eq!(
            stats.all_counts.get(PARTITIONS_SEARCHED_METRIC).unwrap(),
            &(100 * num_deltas)
        );
        assert_eq!(results.num_rows(), 10000);
    }
1859}