use std::any::Any;
use std::collections::{HashMap, HashSet};
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::{Arc, LazyLock, Mutex};
use std::time::Instant;
use arrow::array::Float32Builder;
use arrow::datatypes::{Float32Type, UInt32Type, UInt64Type};
use arrow_array::{Array, Float32Array, UInt32Array, UInt64Array};
use arrow_array::{
ArrayRef, BooleanArray, RecordBatch, StringArray,
builder::{ListBuilder, UInt32Builder},
cast::AsArray,
};
use arrow_schema::{DataType, Field, Schema, SchemaRef};
use datafusion::physical_plan::PlanProperties;
use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
use datafusion::physical_plan::{
DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, SendableRecordBatchStream,
Statistics,
};
use datafusion::{common::ColumnStatistics, physical_plan::metrics::ExecutionPlanMetricsSet};
use datafusion::{
common::stats::Precision,
physical_plan::execution_plan::{Boundedness, EmissionType},
};
use datafusion::{
error::{DataFusionError, Result as DataFusionResult},
physical_plan::metrics::MetricsSet,
};
use datafusion_physical_expr::{Distribution, EquivalenceProperties};
use datafusion_physical_plan::metrics::{BaselineMetrics, Count, Time};
use futures::{Stream, StreamExt, TryFutureExt, TryStreamExt, future, stream};
use itertools::Itertools;
use lance_core::ROW_ID;
use lance_core::utils::futures::FinallyStreamExt;
use lance_core::{
ROW_ID_FIELD,
utils::tokio::{get_num_compute_intensive_cpus, spawn_cpu},
};
use lance_datafusion::utils::{
DELTAS_SEARCHED_METRIC, ExecutionPlanMetricsSetExt, FIND_PARTITIONS_ELAPSED_METRIC,
PARTITIONS_RANKED_METRIC, PARTITIONS_SEARCHED_METRIC,
};
use lance_index::metrics::MetricsCollector;
use lance_index::prefilter::PreFilter;
use lance_index::vector::DIST_Q_C_COLUMN;
use lance_index::vector::{
DIST_COL, INDEX_UUID_COLUMN, PART_ID_COLUMN, PartitionSearchControl, Query, VectorIndex,
flat::compute_distance,
};
use lance_linalg::distance::DistanceType;
use lance_linalg::kernels::normalize_arrow;
use lance_table::format::IndexMetadata;
use tokio::sync::Notify;
use crate::dataset::Dataset;
use crate::index::DatasetIndexInternalExt;
use crate::index::prefilter::{DatasetPreFilter, FilterLoader};
use crate::index::vector::utils::{get_vector_type, validate_distance_type_for};
use crate::{Error, Result};
use lance_arrow::*;
use super::utils::{
FilteredRowIdsToPrefilter, IndexMetrics, InstrumentedRecordBatchStreamAdapter, PreFilterSource,
SelectionVectorToPrefilter,
};
/// Metrics reported by the partition-ranking node ([`ANNIvfPartitionExec`]).
pub struct AnnPartitionMetrics {
    // Index-level IO/search metrics shared with the index implementation.
    index_metrics: IndexMetrics,
    // Total number of partitions ranked across all delta indices.
    partitions_ranked: Count,
    // Number of delta indices consulted for this query.
    deltas_searched: Count,
    // Wall time spent ranking partitions (`find_partitions`).
    find_partitions_elapsed: Time,
    // Standard DataFusion output-row/elapsed-compute metrics.
    baseline_metrics: BaselineMetrics,
}
impl AnnPartitionMetrics {
    /// Registers all partition-ranking metrics for the given output partition.
    pub fn new(metrics: &ExecutionPlanMetricsSet, partition: usize) -> Self {
        Self {
            index_metrics: IndexMetrics::new(metrics, partition),
            partitions_ranked: metrics.new_count(PARTITIONS_RANKED_METRIC, partition),
            deltas_searched: metrics.new_count(DELTAS_SEARCHED_METRIC, partition),
            find_partitions_elapsed: metrics.new_time(FIND_PARTITIONS_ELAPSED_METRIC, partition),
            baseline_metrics: BaselineMetrics::new(metrics, partition),
        }
    }
}
/// Metrics reported by the sub-index search node ([`ANNIvfSubIndexExec`]).
pub struct AnnIndexMetrics {
    // Index-level IO/search metrics shared with the index implementation.
    index_metrics: IndexMetrics,
    // Number of IVF partitions actually probed.
    partitions_searched: Count,
    // Standard DataFusion output-row/elapsed-compute metrics.
    baseline_metrics: BaselineMetrics,
}
impl AnnIndexMetrics {
    /// Registers all sub-index search metrics for the given output partition.
    pub fn new(metrics: &ExecutionPlanMetricsSet, partition: usize) -> Self {
        Self {
            index_metrics: IndexMetrics::new(metrics, partition),
            partitions_searched: metrics.new_count(PARTITIONS_SEARCHED_METRIC, partition),
            baseline_metrics: BaselineMetrics::new(metrics, partition),
        }
    }
}
/// Ranks the index's partitions for `query` on the CPU-intensive thread pool.
///
/// Returns the partition ids paired with their query-to-centroid distances,
/// mapping any failure into a DataFusion execution error.
async fn find_partitions_on_cpu(
    index: Arc<dyn VectorIndex>,
    query: Query,
) -> DataFusionResult<(UInt32Array, Float32Array)> {
    let ranked = spawn_cpu(move || index.find_partitions(&query)).await;
    ranked.map_err(|e| DataFusionError::Execution(format!("Failed to find partitions: {}", e)))
}
/// Normalizes the query vector when the index uses cosine distance; every
/// other distance type passes the query through unchanged.
fn normalize_query_for_index(index: &dyn VectorIndex, mut query: Query) -> DataFusionResult<Query> {
    if index.metric_type() == DistanceType::Cosine {
        // Cosine search operates on normalized vectors, so normalize the key.
        query.key = normalize_arrow(&query.key)
            .map_err(|e| DataFusionError::Execution(format!("Failed to normalize query: {e}")))?
            .0;
    }
    Ok(query)
}
/// Physical node that computes the flat (exhaustive) distance between the
/// query vector and every row of the input, appending a distance column.
#[derive(Debug)]
pub struct KNNVectorDistanceExec {
    /// Child plan producing the vectors to score.
    pub input: Arc<dyn ExecutionPlan>,
    /// The query vector.
    pub query: ArrayRef,
    /// Name of the vector column to score.
    pub column: String,
    /// Distance metric used for scoring.
    pub distance_type: DistanceType,
    // Input schema with any stale distance column removed and a nullable
    // Float32 distance column appended (see `try_new`).
    output_schema: SchemaRef,
    properties: Arc<PlanProperties>,
    metrics: ExecutionPlanMetricsSet,
}
impl DisplayAs for KNNVectorDistanceExec {
    /// Renders this node for `EXPLAIN` output.
    fn fmt_as(&self, t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        match t {
            DisplayFormatType::TreeRender => {
                write!(f, "KNNVectorDistance\nmetric={}", self.distance_type)
            }
            DisplayFormatType::Default | DisplayFormatType::Verbose => {
                write!(f, "KNNVectorDistance: metric={}", self.distance_type)
            }
        }
    }
}
impl KNNVectorDistanceExec {
    /// Creates a flat KNN distance node over `input`.
    ///
    /// Validates that `column` is a vector column compatible with
    /// `distance_type`, then builds the output schema: the input schema with
    /// any pre-existing distance column removed and a nullable Float32
    /// distance column appended at the end.
    pub fn try_new(
        input: Arc<dyn ExecutionPlan>,
        column: &str,
        query: ArrayRef,
        distance_type: DistanceType,
    ) -> Result<Self> {
        let mut output_schema = input.schema().as_ref().clone();
        let (_, element_type) = get_vector_type(&(&output_schema).try_into()?, column)?;
        validate_distance_type_for(distance_type, &element_type)?;
        // Drop any stale distance column so ours is the only one.
        if output_schema.column_with_name(DIST_COL).is_some() {
            output_schema = output_schema.without_column(DIST_COL);
        }
        let output_schema = Arc::new(output_schema.try_with_column(Field::new(
            DIST_COL,
            DataType::Float32,
            true,
        ))?);
        // Keep the input's plan properties but swap in the new output schema.
        let properties = Arc::new(
            input
                .properties()
                .as_ref()
                .clone()
                .with_eq_properties(EquivalenceProperties::new(output_schema.clone())),
        );
        Ok(Self {
            input,
            query,
            column: column.to_string(),
            distance_type,
            output_schema,
            properties,
            metrics: ExecutionPlanMetricsSet::new(),
        })
    }
}
impl ExecutionPlan for KNNVectorDistanceExec {
    fn name(&self) -> &str {
        "KNNVectorDistanceExec"
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    fn schema(&self) -> arrow_schema::SchemaRef {
        self.output_schema.clone()
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![&self.input]
    }

    fn with_new_children(
        self: Arc<Self>,
        mut children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> DataFusionResult<Arc<dyn ExecutionPlan>> {
        if children.len() != 1 {
            return Err(DataFusionError::Internal(
                "KNNVectorDistanceExec node must have exactly one child".to_string(),
            ));
        }
        Ok(Arc::new(Self::try_new(
            children.pop().expect("length checked"),
            &self.column,
            self.query.clone(),
            self.distance_type,
        )?))
    }

    /// Streams input batches, appending the distance between each row's
    /// vector and the query vector, dropping empty batches and any row whose
    /// distance is null or NaN.
    fn execute(
        &self,
        partition: usize,
        context: Arc<datafusion::execution::context::TaskContext>,
    ) -> DataFusionResult<SendableRecordBatchStream> {
        let input_stream = self.input.execute(partition, context)?;
        let key = self.query.clone();
        let column = self.column.clone();
        let dt = self.distance_type;
        let stream = input_stream
            // Skip empty batches; the distance kernel has nothing to do.
            .try_filter(|batch| future::ready(batch.num_rows() > 0))
            .map(move |batch| {
                let key = key.clone();
                let column = column.clone();
                async move {
                    let batch = compute_distance(key, dt, &column, batch?)
                        .await
                        .map_err(|e| DataFusionError::External(Box::new(e)))?;
                    // Keep only rows with a finite, non-null distance so
                    // downstream ordering is well-defined.
                    let distances = batch[DIST_COL].as_primitive::<Float32Type>();
                    let mask = BooleanArray::from_iter(
                        distances
                            .iter()
                            .map(|v| Some(v.map(|v| !v.is_nan()).unwrap_or(false))),
                    );
                    arrow::compute::filter_record_batch(&batch, &mask)
                        .map_err(|e| DataFusionError::ArrowError(Box::new(e), None))
                }
            })
            // NOTE(review): unordered buffering may emit batches out of input
            // order — assumed acceptable for downstream consumers; confirm.
            .buffer_unordered(get_num_compute_intensive_cpus());
        let schema = self.schema();
        Ok(Box::pin(InstrumentedRecordBatchStreamAdapter::new(
            schema,
            stream.boxed(),
            partition,
            &self.metrics,
        )) as SendableRecordBatchStream)
    }

    /// Derives output statistics from the input: every input column (minus
    /// any pre-existing distance column) keeps its stats; the appended
    /// distance column inherits only the null count of the searched vector
    /// column.
    fn partition_statistics(&self, partition: Option<usize>) -> DataFusionResult<Statistics> {
        let inner_stats = self.input.partition_statistics(partition)?;
        let schema = self.input.schema();
        let dist_stats = inner_stats
            .column_statistics
            .iter()
            .zip(schema.fields())
            .find(|(_, field)| field.name() == &self.column)
            .map(|(stats, _)| ColumnStatistics {
                null_count: stats.null_count,
                ..Default::default()
            })
            .unwrap_or_default();
        let column_statistics = inner_stats
            .column_statistics
            .into_iter()
            .zip(schema.fields())
            .filter(|(_, field)| field.name() != DIST_COL)
            .map(|(stats, _)| stats)
            .chain(std::iter::once(dist_stats))
            .collect::<Vec<_>>();
        Ok(Statistics {
            num_rows: inner_stats.num_rows,
            column_statistics,
            ..Statistics::new_unknown(self.schema().as_ref())
        })
    }

    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }

    fn properties(&self) -> &Arc<PlanProperties> {
        &self.properties
    }

    fn supports_limit_pushdown(&self) -> bool {
        false
    }
}
/// Output schema of ANN search nodes: a nullable Float32 distance column
/// followed by the row-id column.
pub static KNN_INDEX_SCHEMA: LazyLock<SchemaRef> = LazyLock::new(|| {
    Arc::new(Schema::new(vec![
        Field::new(DIST_COL, DataType::Float32, true),
        ROW_ID_FIELD.clone(),
    ]))
});
/// Schema produced by [`ANNIvfPartitionExec`]: one row per delta index.
pub static KNN_PARTITION_SCHEMA: LazyLock<SchemaRef> = LazyLock::new(|| {
    Arc::new(Schema::new(vec![
        // Partition ids to probe, ranked by query-to-centroid distance.
        Field::new(
            PART_ID_COLUMN,
            DataType::List(Field::new("item", DataType::UInt32, false).into()),
            false,
        ),
        // Query-to-centroid distance for each ranked partition.
        Field::new(
            DIST_Q_C_COLUMN,
            DataType::List(Field::new("item", DataType::Float32, false).into()),
            false,
        ),
        // UUID of the delta index the partitions belong to.
        Field::new(INDEX_UUID_COLUMN, DataType::Utf8, false),
    ]))
});
/// Builds the two-stage ANN physical plan: partition ranking
/// ([`ANNIvfPartitionExec`]) feeding sub-index search ([`ANNIvfSubIndexExec`])
/// across all delta indices.
pub fn new_knn_exec(
    dataset: Arc<Dataset>,
    indices: &[IndexMetadata],
    query: &Query,
    prefilter_source: PreFilterSource,
) -> Result<Arc<dyn ExecutionPlan>> {
    let uuids: Vec<String> = indices.iter().map(|idx| idx.uuid.to_string()).collect();
    let ivf_node = Arc::new(ANNIvfPartitionExec::try_new(
        dataset.clone(),
        uuids,
        query.clone(),
    )?);
    let sub_index = ANNIvfSubIndexExec::try_new(
        ivf_node,
        dataset,
        indices.to_vec(),
        query.clone(),
        prefilter_source,
    )?;
    Ok(Arc::new(sub_index))
}
/// Physical node that ranks IVF partitions for a query.
///
/// Emits one row per delta index containing the ranked partition ids, the
/// query-to-centroid distances, and the index uuid (see
/// [`KNN_PARTITION_SCHEMA`]).
#[derive(Debug)]
pub struct ANNIvfPartitionExec {
    pub dataset: Arc<Dataset>,
    pub query: Query,
    /// One uuid per delta of the vector index.
    pub index_uuids: Vec<String>,
    pub properties: Arc<PlanProperties>,
    pub metrics: ExecutionPlanMetricsSet,
}
impl ANNIvfPartitionExec {
    /// Creates a partition-ranking node over the given delta indices.
    ///
    /// Fails if the query column is not a vector column, or (after that
    /// check) if no index uuid was supplied.
    pub fn try_new(dataset: Arc<Dataset>, index_uuids: Vec<String>, query: Query) -> Result<Self> {
        get_vector_type(dataset.schema(), &query.column)?;
        if index_uuids.is_empty() {
            return Err(Error::execution(
                "ANNIVFPartitionExec node: no index found for query".to_string(),
            ));
        }
        let properties = Arc::new(PlanProperties::new(
            EquivalenceProperties::new(KNN_PARTITION_SCHEMA.clone()),
            Partitioning::RoundRobinBatch(1),
            EmissionType::Incremental,
            Boundedness::Bounded,
        ));
        Ok(Self {
            dataset,
            query,
            index_uuids,
            properties,
            metrics: ExecutionPlanMetricsSet::new(),
        })
    }
}
impl DisplayAs for ANNIvfPartitionExec {
    /// Renders this node for `EXPLAIN` output.
    fn fmt_as(&self, t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let uuid = &self.index_uuids[0];
        let deltas = self.index_uuids.len();
        match t {
            DisplayFormatType::Default | DisplayFormatType::Verbose => write!(
                f,
                "ANNIvfPartition: uuid={}, minimum_nprobes={}, maximum_nprobes={:?}, deltas={}",
                uuid, self.query.minimum_nprobes, self.query.maximum_nprobes, deltas
            ),
            DisplayFormatType::TreeRender => write!(
                f,
                "ANNIvfPartition\nuuid={}\nminimum_nprobes={}\nmaximum_nprobes={:?}\ndeltas={}",
                uuid, self.query.minimum_nprobes, self.query.maximum_nprobes, deltas
            ),
        }
    }
}
impl ExecutionPlan for ANNIvfPartitionExec {
    fn name(&self) -> &str {
        "ANNIVFPartitionExec"
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    fn schema(&self) -> SchemaRef {
        KNN_PARTITION_SCHEMA.clone()
    }

    fn properties(&self) -> &Arc<PlanProperties> {
        &self.properties
    }

    fn partition_statistics(&self, _partition: Option<usize>) -> DataFusionResult<Statistics> {
        // NOTE(review): `execute` emits one single-row batch per delta index,
        // so `self.index_uuids.len()` looks like the accurate row count here;
        // confirm whether using `minimum_nprobes` is intentional.
        Ok(Statistics {
            num_rows: Precision::Exact(self.query.minimum_nprobes),
            ..Statistics::new_unknown(self.schema().as_ref())
        })
    }

    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }

    // Leaf node: partition ranking reads directly from the dataset's indices.
    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![]
    }

    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> DataFusionResult<Arc<dyn ExecutionPlan>> {
        if !children.is_empty() {
            Err(DataFusionError::Internal(
                "ANNIVFPartitionExec node does not accept children".to_string(),
            ))
        } else {
            Ok(self)
        }
    }

    /// For each delta index: open it, normalize the query if needed, rank its
    /// partitions, and emit a single-row batch holding the ranked partition
    /// list, the query-to-centroid distances, and the index uuid.
    fn execute(
        &self,
        partition: usize,
        _context: Arc<datafusion::execution::TaskContext>,
    ) -> DataFusionResult<SendableRecordBatchStream> {
        let timer = Instant::now();
        let query = self.query.clone();
        let ds = self.dataset.clone();
        let metrics = Arc::new(AnnPartitionMetrics::new(&self.metrics, partition));
        metrics.deltas_searched.add(self.index_uuids.len());
        let metrics_clone = metrics.clone();
        let stream = stream::iter(self.index_uuids.clone())
            .map(move |uuid| {
                let query = query.clone();
                let ds = ds.clone();
                let metrics = metrics.clone();
                async move {
                    let index = ds
                        .open_vector_index(&query.column, &uuid, &metrics.index_metrics)
                        .await?;
                    let query = normalize_query_for_index(index.as_ref(), query.clone())?;
                    metrics.partitions_ranked.add(index.total_partitions());
                    let (partitions, dist_q_c) = {
                        // Time only the partition-ranking step.
                        let _timer = metrics.find_partitions_elapsed.timer();
                        find_partitions_on_cpu(index, query).await?
                    };
                    // Wrap the ranked ids and distances each into a
                    // single-element list column.
                    let mut part_list_builder = ListBuilder::new(UInt32Builder::new())
                        .with_field(Field::new("item", DataType::UInt32, false));
                    part_list_builder.append_value(partitions.iter());
                    let partition_col = part_list_builder.finish();
                    let mut dist_q_c_list_builder = ListBuilder::new(Float32Builder::new())
                        .with_field(Field::new("item", DataType::Float32, false));
                    dist_q_c_list_builder.append_value(dist_q_c.iter());
                    let dist_q_c_col = dist_q_c_list_builder.finish();
                    let uuid_col = StringArray::from(vec![uuid.as_str()]);
                    let batch = RecordBatch::try_new(
                        KNN_PARTITION_SCHEMA.clone(),
                        vec![
                            Arc::new(partition_col),
                            Arc::new(dist_q_c_col),
                            Arc::new(uuid_col),
                        ],
                    )?;
                    metrics.baseline_metrics.record_output(batch.num_rows());
                    Ok::<_, DataFusionError>(batch)
                }
            })
            // Rank all deltas concurrently while preserving output order.
            .buffered(self.index_uuids.len())
            .finally(move || {
                metrics_clone.baseline_metrics.done();
                metrics_clone
                    .baseline_metrics
                    .elapsed_compute()
                    .add_duration(timer.elapsed());
            });
        let schema = self.schema();
        Ok(
            Box::pin(RecordBatchStreamAdapter::new(schema, stream.boxed()))
                as SendableRecordBatchStream,
        )
    }

    fn supports_limit_pushdown(&self) -> bool {
        false
    }
}
/// Physical node that searches IVF sub-indices within the partitions chosen
/// by [`ANNIvfPartitionExec`], producing (distance, row id) results.
#[derive(Debug)]
pub struct ANNIvfSubIndexExec {
    // Child producing one row per delta index (see `KNN_PARTITION_SCHEMA`).
    input: Arc<dyn ExecutionPlan>,
    dataset: Arc<Dataset>,
    // Metadata for each delta of the searched index.
    indices: Vec<IndexMetadata>,
    query: Query,
    // Optional plan feeding a prefilter that restricts searchable rows.
    prefilter_source: PreFilterSource,
    properties: Arc<PlanProperties>,
    metrics: ExecutionPlanMetricsSet,
}
impl ANNIvfSubIndexExec {
    /// Creates a sub-index search node.
    ///
    /// `input` must produce batches containing the partition-id column, i.e.
    /// the output of [`ANNIvfPartitionExec`].
    pub fn try_new(
        input: Arc<dyn ExecutionPlan>,
        dataset: Arc<Dataset>,
        indices: Vec<IndexMetadata>,
        query: Query,
        prefilter_source: PreFilterSource,
    ) -> Result<Self> {
        if input.schema().field_with_name(PART_ID_COLUMN).is_err() {
            return Err(Error::index(format!(
                "ANNSubIndexExec node: input schema does not have \"{}\" column",
                PART_ID_COLUMN
            )));
        }
        let properties = Arc::new(PlanProperties::new(
            EquivalenceProperties::new(KNN_INDEX_SCHEMA.clone()),
            Partitioning::RoundRobinBatch(1),
            EmissionType::Final,
            Boundedness::Bounded,
        ));
        Ok(Self {
            input,
            dataset,
            indices,
            query,
            prefilter_source,
            properties,
            metrics: ExecutionPlanMetricsSet::new(),
        })
    }

    /// The ANN query this node executes.
    pub fn query(&self) -> &Query {
        &self.query
    }

    /// The dataset being searched.
    pub fn dataset(&self) -> &Arc<Dataset> {
        &self.dataset
    }

    /// Metadata for each delta of the searched index.
    pub fn indices(&self) -> &[IndexMetadata] {
        &self.indices
    }

    /// The prefilter source, if any.
    pub fn prefilter_source(&self) -> &PreFilterSource {
        &self.prefilter_source
    }
}
impl DisplayAs for ANNIvfSubIndexExec {
    /// Renders this node for `EXPLAIN` output.
    fn fmt_as(&self, t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        // Effective fetch size: k inflated by the refine factor (if any).
        let k = self.query.k * self.query.refine_factor.unwrap_or(1) as usize;
        let metric_str = match self.query.metric_type {
            Some(m) => format!("{:?}", m),
            None => "default".to_string(),
        };
        match t {
            DisplayFormatType::TreeRender => write!(
                f,
                "ANNSubIndex\nname={}\nk={}\ndeltas={}\nmetric={}",
                self.indices[0].name,
                k,
                self.indices.len(),
                metric_str
            ),
            DisplayFormatType::Default | DisplayFormatType::Verbose => write!(
                f,
                "ANNSubIndex: name={}, k={}, deltas={}, metric={}",
                self.indices[0].name,
                k,
                self.indices.len(),
                metric_str
            ),
        }
    }
}
/// Shared state coordinating the initial ("minimum nprobes") search phase
/// across all delta indices with the optional late search phase.
struct ANNIvfEarlySearchResults {
    // Number of results the query asked for.
    k: usize,
    // Row ids found during the initial phase, capped at `k` entries.
    initial_ids: Mutex<Vec<u64>>,
    // Running result count; seeded from `initial_ids` once every delta
    // finishes its initial search, then bumped by late-phase batches.
    num_results_found: AtomicUsize,
    // Deltas that have not yet finished the initial phase.
    deltas_remaining: AtomicUsize,
    // Signalled when the last delta finishes its initial search.
    all_deltas_done: Notify,
    // Ensures only one delta emits the "all remaining prefiltered rows"
    // shortcut batch in the late phase.
    took_no_rows_shortcut: AtomicBool,
}
impl ANNIvfEarlySearchResults {
fn new(deltas_remaining: usize, k: usize) -> Self {
Self {
k,
initial_ids: Mutex::new(Vec::with_capacity(k)),
num_results_found: AtomicUsize::new(0),
deltas_remaining: AtomicUsize::new(deltas_remaining),
all_deltas_done: Notify::new(),
took_no_rows_shortcut: AtomicBool::new(false),
}
}
fn record_batch(&self, batch: &RecordBatch) {
let mut initial_ids = self.initial_ids.lock().unwrap();
let ids_to_record = (self.k - initial_ids.len()).min(batch.num_rows());
initial_ids.extend(
batch
.column(1)
.as_primitive::<UInt64Type>()
.values()
.iter()
.take(ids_to_record),
);
}
fn record_late_batch(&self, num_rows: usize) {
self.num_results_found
.fetch_add(num_rows, Ordering::Relaxed);
}
async fn wait_for_minimum_to_finish(&self) -> usize {
if self.deltas_remaining.fetch_sub(1, Ordering::Relaxed) == 1 {
{
let new_num_results_found = self.initial_ids.lock().unwrap().len();
self.num_results_found
.store(new_num_results_found, Ordering::Relaxed);
}
self.all_deltas_done.notify_waiters();
} else {
self.all_deltas_done.notified().await;
}
self.num_results_found.load(Ordering::Relaxed)
}
}
/// Hooks the late search phase into the index's sequential partition search
/// so it can stop probing once enough results have accumulated.
struct LatePartitionSearchControl {
    state: Arc<ANNIvfEarlySearchResults>,
    // Stop once this many results have been found overall.
    max_results: usize,
}
impl PartitionSearchControl for LatePartitionSearchControl {
    // Stop probing further partitions once enough results have accumulated.
    fn should_stop(&self) -> bool {
        self.state.num_results_found.load(Ordering::Relaxed) >= self.max_results
    }

    // Feed every produced batch into the shared running result count.
    fn record_batch(&self, batch: &RecordBatch) {
        self.state.record_late_batch(batch.num_rows());
    }
}
/// Resolves the parallelism to use for this query on the current machine,
/// combining the CPU pool size with the index's own auto-parallelism hint.
fn effective_query_parallelism(query: &Query, index: &dyn VectorIndex) -> usize {
    let pool_size = get_num_compute_intensive_cpus();
    let auto = index.auto_query_parallelism(pool_size);
    effective_query_parallelism_for(query, pool_size, auto)
}
/// Maps the user-facing `query_parallelism` setting onto an actual task
/// count: -1 means "use the whole CPU pool", 0 means "auto" (the index's
/// hint clamped to the pool), a positive value is honored up to the pool
/// size, and anything else falls back to sequential search.
fn effective_query_parallelism_for(
    query: &Query,
    cpu_pool_size: usize,
    auto_parallelism: usize,
) -> usize {
    let pool = cpu_pool_size.max(1);
    if query.query_parallelism == -1 {
        return pool;
    }
    if query.query_parallelism == 0 {
        return auto_parallelism.clamp(1, pool);
    }
    if query.query_parallelism > 0 {
        return (query.query_parallelism as usize).min(pool).max(1);
    }
    1
}
impl ANNIvfSubIndexExec {
    /// Searches a single IVF partition and records the produced row count.
    async fn search_partition(
        index: Arc<dyn VectorIndex>,
        query: Query,
        part_id: usize,
        pre_filter: Arc<DatasetPreFilter>,
        metrics: Arc<AnnIndexMetrics>,
    ) -> DataFusionResult<RecordBatch> {
        let batch = index
            .search_in_partition(part_id, &query, pre_filter, &metrics.index_metrics)
            .map_err(|e| DataFusionError::Execution(format!("Failed to calculate KNN: {}", e)))
            .await?;
        metrics.baseline_metrics.record_output(batch.num_rows());
        Ok(batch)
    }

    /// Wraps a sequential (index-driven) partition-search stream, updating
    /// metrics per batch and optionally feeding batches into the shared
    /// early-search state.
    fn instrument_sequential_partition_stream(
        stream: stream::BoxStream<'static, DataFusionResult<RecordBatch>>,
        metrics: Arc<AnnIndexMetrics>,
        state: Arc<ANNIvfEarlySearchResults>,
        record_initial: bool,
        record_partition_per_batch: bool,
    ) -> stream::BoxStream<'static, DataFusionResult<RecordBatch>> {
        stream
            .map(move |batch| {
                let metrics = metrics.clone();
                let state = state.clone();
                batch.inspect(move |batch| {
                    if record_partition_per_batch {
                        metrics.partitions_searched.add(1);
                    }
                    metrics.baseline_metrics.record_output(batch.num_rows());
                    if record_initial {
                        state.record_batch(batch);
                    }
                })
            })
            .boxed()
    }

    /// Late search phase: probes partitions beyond `minimum_nprobes` (up to
    /// `maximum_nprobes`), but only after all deltas finished their initial
    /// phase and only if fewer than `k` results were found.
    fn late_search(
        index: Arc<dyn VectorIndex>,
        query: Query,
        partitions: Arc<UInt32Array>,
        q_c_dists: Arc<Float32Array>,
        prefilter: Arc<DatasetPreFilter>,
        metrics: Arc<AnnIndexMetrics>,
        state: Arc<ANNIvfEarlySearchResults>,
    ) -> impl Stream<Item = DataFusionResult<RecordBatch>> {
        let stream = futures::stream::once(async move {
            // Clamp the probe window to the number of ranked partitions.
            let max_nprobes = query
                .maximum_nprobes
                .unwrap_or(partitions.len())
                .min(partitions.len());
            let min_nprobes = query.minimum_nprobes.min(max_nprobes);
            if max_nprobes <= min_nprobes {
                // Nothing beyond the initial phase to probe.
                return futures::stream::empty().boxed();
            }
            // Block until every delta finished its initial search.
            let found_so_far = state.wait_for_minimum_to_finish().await;
            if found_so_far >= query.k {
                // Initial phase already satisfied the query.
                return futures::stream::empty().boxed();
            }
            let prefilter_mask = prefilter.mask();
            let max_results = prefilter_mask.max_len().map(|x| x as usize);
            if let Some(max_results) = max_results
                && found_so_far < max_results
                && max_results <= query.k
            {
                // The prefilter allows at most `max_results` (<= k) rows in
                // total, so instead of probing more partitions we can emit
                // every not-yet-found allowed row directly, with an infinite
                // distance.
                if let Some(iter_addrs) = prefilter_mask.iter_addrs() {
                    if state
                        .took_no_rows_shortcut
                        .compare_exchange(false, true, Ordering::Relaxed, Ordering::Relaxed)
                        .is_ok()
                    {
                        // Only the first delta to reach this point emits the
                        // shortcut batch...
                        let initial_addrs = state.initial_ids.lock().unwrap();
                        let found_addrs = HashSet::<_>::from_iter(initial_addrs.iter().copied());
                        drop(initial_addrs);
                        let mask_addrs = HashSet::from_iter(iter_addrs.map(u64::from));
                        let not_found_addrs = mask_addrs.difference(&found_addrs);
                        let not_found_addrs =
                            UInt64Array::from_iter_values(not_found_addrs.copied());
                        let not_found_distance =
                            Float32Array::from_value(f32::INFINITY, not_found_addrs.len());
                        let not_found_batch = RecordBatch::try_new(
                            KNN_INDEX_SCHEMA.clone(),
                            vec![Arc::new(not_found_distance), Arc::new(not_found_addrs)],
                        )
                        .unwrap();
                        return futures::stream::once(async move { Ok(not_found_batch) }).boxed();
                    } else {
                        // ...all other deltas emit nothing.
                        return futures::stream::empty().boxed();
                    }
                }
            }
            let max_results = max_results.unwrap_or(usize::MAX).min(query.k);
            let state_clone = state.clone();
            let query_parallelism = effective_query_parallelism(&query, index.as_ref());
            if query_parallelism <= 1 {
                // Sequential path: let the index drive iteration so it can
                // stop early via `LatePartitionSearchControl`.
                return stream::once(async move {
                    let prefilter: Arc<dyn PreFilter> = prefilter;
                    let index_metrics: Arc<dyn MetricsCollector> =
                        Arc::new(metrics.index_metrics.clone());
                    let stream = index
                        .search_partitions(
                            query,
                            partitions,
                            q_c_dists,
                            min_nprobes,
                            max_nprobes,
                            prefilter,
                            Some(Arc::new(LatePartitionSearchControl {
                                state: state.clone(),
                                max_results,
                            })),
                            index_metrics,
                        )
                        .await
                        .map_err(|e| {
                            DataFusionError::Execution(format!("Failed to search partitions: {e}"))
                        })?;
                    Ok::<stream::BoxStream<'static, DataFusionResult<RecordBatch>>, DataFusionError>(
                        Self::instrument_sequential_partition_stream(
                            stream.boxed(),
                            metrics,
                            state,
                            false,
                            true,
                        ),
                    )
                })
                .try_flatten()
                .boxed();
            }
            // Parallel path: probe the remaining partitions concurrently,
            // ceasing to launch new probes once enough results accumulated.
            futures::stream::iter(min_nprobes..max_nprobes)
                .map(move |idx| {
                    let part_id = partitions.value(idx);
                    let mut query = query.clone();
                    query.dist_q_c = q_c_dists.value(idx);
                    let metrics = metrics.clone();
                    let pre_filter = prefilter.clone();
                    let state = state.clone();
                    let index = index.clone();
                    async move {
                        metrics.partitions_searched.add(1);
                        let batch = Self::search_partition(
                            index,
                            query,
                            part_id as usize,
                            pre_filter,
                            metrics,
                        )
                        .await?;
                        state.record_late_batch(batch.num_rows());
                        Ok(batch)
                    }
                })
                // Checked before each new probe is scheduled; in-flight
                // probes still complete.
                .take_while(move |_| {
                    let found_so_far = state_clone.num_results_found.load(Ordering::Relaxed);
                    std::future::ready(found_so_far < max_results)
                })
                .buffered(query_parallelism)
                .boxed()
        });
        stream.flatten()
    }

    /// Initial search phase: probes the first `minimum_nprobes` ranked
    /// partitions, feeding the found row ids into the shared state.
    fn initial_search(
        index: Arc<dyn VectorIndex>,
        query: Query,
        partitions: Arc<UInt32Array>,
        q_c_dists: Arc<Float32Array>,
        prefilter: Arc<DatasetPreFilter>,
        metrics: Arc<AnnIndexMetrics>,
        state: Arc<ANNIvfEarlySearchResults>,
    ) -> impl Stream<Item = DataFusionResult<RecordBatch>> {
        let minimum_nprobes = query.minimum_nprobes.min(partitions.len());
        let query_parallelism = effective_query_parallelism(&query, index.as_ref());
        if query_parallelism <= 1 {
            // Sequential path: one index-driven search over the first
            // `minimum_nprobes` partitions.
            metrics.partitions_searched.add(minimum_nprobes);
            return stream::once(async move {
                let prefilter: Arc<dyn PreFilter> = prefilter;
                let index_metrics: Arc<dyn MetricsCollector> =
                    Arc::new(metrics.index_metrics.clone());
                let stream = index
                    .search_partitions(
                        query,
                        partitions,
                        q_c_dists,
                        0,
                        minimum_nprobes,
                        prefilter,
                        None,
                        index_metrics,
                    )
                    .await
                    .map_err(|e| {
                        DataFusionError::Execution(format!("Failed to search partitions: {e}"))
                    })?;
                Ok::<stream::BoxStream<'static, DataFusionResult<RecordBatch>>, DataFusionError>(
                    Self::instrument_sequential_partition_stream(
                        stream.boxed(),
                        metrics,
                        state,
                        true,
                        false,
                    ),
                )
            })
            .try_flatten()
            .boxed();
        }
        metrics.partitions_searched.add(minimum_nprobes);
        // Parallel path: probe each of the first `minimum_nprobes`
        // partitions as its own task.
        futures::stream::iter(0..minimum_nprobes)
            .map(move |idx| {
                let part_id = partitions.value(idx);
                let mut query = query.clone();
                query.dist_q_c = q_c_dists.value(idx);
                let metrics = metrics.clone();
                let index = index.clone();
                let pre_filter = prefilter.clone();
                let state = state.clone();
                async move {
                    let batch =
                        Self::search_partition(index, query, part_id as usize, pre_filter, metrics)
                            .await?;
                    state.record_batch(&batch);
                    Ok(batch)
                }
            })
            .buffered(query_parallelism)
            .boxed()
    }
}
impl ExecutionPlan for ANNIvfSubIndexExec {
    fn name(&self) -> &str {
        "ANNSubIndexExec"
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    fn schema(&self) -> arrow_schema::SchemaRef {
        KNN_INDEX_SCHEMA.clone()
    }

    // Children are the partition-ranking input plus, when present, the plan
    // feeding the prefilter.
    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        match &self.prefilter_source {
            PreFilterSource::None => vec![&self.input],
            PreFilterSource::FilteredRowIds(src) => vec![&self.input, &src],
            PreFilterSource::ScalarIndexQuery(src) => vec![&self.input, &src],
        }
    }

    fn required_input_distribution(&self) -> Vec<Distribution> {
        self.children()
            .iter()
            .map(|_| Distribution::SinglePartition)
            .collect()
    }

    fn with_new_children(
        self: Arc<Self>,
        mut children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> DataFusionResult<Arc<dyn ExecutionPlan>> {
        let plan = if children.len() == 1 || children.len() == 2 {
            // The prefilter plan (if any) is the second child; pop it first
            // so the remaining element is the input.
            let prefilter_source = if children.len() == 2 {
                let prefilter = children.pop().expect("length checked");
                match &self.prefilter_source {
                    PreFilterSource::None => PreFilterSource::None,
                    PreFilterSource::FilteredRowIds(_) => {
                        PreFilterSource::FilteredRowIds(prefilter)
                    }
                    PreFilterSource::ScalarIndexQuery(_) => {
                        PreFilterSource::ScalarIndexQuery(prefilter)
                    }
                }
            } else {
                self.prefilter_source.clone()
            };
            Self {
                input: children.pop().expect("length checked"),
                dataset: self.dataset.clone(),
                indices: self.indices.clone(),
                query: self.query.clone(),
                prefilter_source,
                properties: self.properties.clone(),
                metrics: ExecutionPlanMetricsSet::new(),
            }
        } else {
            return Err(DataFusionError::Internal(
                "ANNSubIndexExec node must have exactly one or two (prefilter) child".to_string(),
            ));
        };
        Ok(Arc::new(plan))
    }

    /// For each delta index described by the input stream, runs the two-phase
    /// (initial + late) partition search and merges all result streams.
    fn execute(
        &self,
        partition: usize,
        context: Arc<datafusion::execution::context::TaskContext>,
    ) -> DataFusionResult<datafusion::physical_plan::SendableRecordBatchStream> {
        let input_stream = self.input.execute(partition, context.clone())?;
        let schema = self.schema();
        let query = self.query.clone();
        let ds = self.dataset.clone();
        let column = self.query.column.clone();
        let indices = self.indices.clone();
        let prefilter_source = self.prefilter_source.clone();
        let metrics = Arc::new(AnnIndexMetrics::new(&self.metrics, partition));
        let metrics_clone = metrics.clone();
        let timer = Instant::now();
        // Unpack each input batch into one (partitions, centroid distances,
        // uuid) tuple per delta index.
        let per_index_stream = input_stream
            .and_then(move |batch| {
                let part_id_col = batch.column_by_name(PART_ID_COLUMN).unwrap_or_else(|| {
                    panic!("ANNSubIndexExec: input missing {} column", PART_ID_COLUMN)
                });
                let part_id_arr = part_id_col.as_list::<i32>().clone();
                let dist_q_c_col = batch.column_by_name(DIST_Q_C_COLUMN).unwrap_or_else(|| {
                    panic!("ANNSubIndexExec: input missing {} column", DIST_Q_C_COLUMN)
                });
                let dist_q_c_arr = dist_q_c_col.as_list::<i32>().clone();
                let index_uuid_col = batch.column_by_name(INDEX_UUID_COLUMN).unwrap_or_else(|| {
                    panic!(
                        "ANNSubIndexExec: input missing {} column",
                        INDEX_UUID_COLUMN
                    )
                });
                let index_uuid = index_uuid_col.as_string::<i32>().clone();
                let plan: Vec<DataFusionResult<(_, _, _)>> = part_id_arr
                    .iter()
                    .zip(dist_q_c_arr.iter())
                    .zip(index_uuid.iter())
                    .map(|((part_id, dist_q_c), uuid)| {
                        let partitions =
                            Arc::new(part_id.unwrap().as_primitive::<UInt32Type>().clone());
                        let dist_q_c =
                            Arc::new(dist_q_c.unwrap().as_primitive::<Float32Type>().clone());
                        let uuid = uuid.unwrap().to_string();
                        Ok((partitions, dist_q_c, uuid))
                    })
                    .collect_vec();
                async move { DataFusionResult::Ok(stream::iter(plan)) }
            })
            .try_flatten();
        // The prefilter (if any) is loaded lazily from the second child.
        let prefilter_loader = match &prefilter_source {
            PreFilterSource::FilteredRowIds(src_node) => {
                let stream = src_node.execute(partition, context)?;
                Some(Box::new(FilteredRowIdsToPrefilter(stream)) as Box<dyn FilterLoader>)
            }
            PreFilterSource::ScalarIndexQuery(src_node) => {
                let stream = src_node.execute(partition, context)?;
                Some(Box::new(SelectionVectorToPrefilter(stream)) as Box<dyn FilterLoader>)
            }
            PreFilterSource::None => None,
        };
        let pre_filter = Arc::new(DatasetPreFilter::new(
            ds.clone(),
            &indices,
            prefilter_loader,
        ));
        // Shared state coordinating the early/late phases across deltas.
        let state = Arc::new(ANNIvfEarlySearchResults::new(indices.len(), query.k));
        Ok(Box::pin(RecordBatchStreamAdapter::new(
            schema,
            per_index_stream
                .and_then(move |(part_ids, q_c_dists, index_uuid)| {
                    let ds = ds.clone();
                    let column = column.clone();
                    let metrics = metrics.clone();
                    let pre_filter = pre_filter.clone();
                    let state = state.clone();
                    let mut query = query.clone();
                    // Heuristically raise minimum_nprobes based on the
                    // centroid-distance profile of this delta.
                    let pruned_nprobes = early_pruning(q_c_dists.values(), query.k);
                    adjust_probes(&mut query, pruned_nprobes);
                    async move {
                        let raw_index = ds
                            .open_vector_index(&column, &index_uuid, &metrics.index_metrics)
                            .await?;
                        let query = normalize_query_for_index(raw_index.as_ref(), query)?;
                        let early_search = Self::initial_search(
                            raw_index.clone(),
                            query.clone(),
                            part_ids.clone(),
                            q_c_dists.clone(),
                            pre_filter.clone(),
                            metrics.clone(),
                            state.clone(),
                        );
                        let late_search = Self::late_search(
                            raw_index.clone(),
                            query,
                            part_ids,
                            q_c_dists,
                            pre_filter,
                            metrics,
                            state,
                        );
                        DataFusionResult::Ok(early_search.chain(late_search).boxed())
                    }
                })
                // Merge the per-delta streams without ordering guarantees.
                .try_flatten_unordered(None)
                .finally(move || {
                    metrics_clone
                        .baseline_metrics
                        .elapsed_compute()
                        .add_duration(timer.elapsed());
                    metrics_clone.baseline_metrics.done();
                })
                .boxed(),
        )))
    }

    // NOTE(review): reported as Exact, but this is really an upper bound
    // (k * refine_factor per input row) when fewer matches exist — confirm
    // downstream consumers tolerate the overestimate.
    fn partition_statistics(
        &self,
        partition: Option<usize>,
    ) -> DataFusionResult<datafusion::physical_plan::Statistics> {
        Ok(Statistics {
            num_rows: Precision::Exact(
                self.query.k
                    * self.query.refine_factor.unwrap_or(1) as usize
                    * self
                        .input
                        .partition_statistics(partition)?
                        .num_rows
                        .get_value()
                        .unwrap_or(&1),
            ),
            ..Statistics::new_unknown(self.schema().as_ref())
        })
    }

    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }

    fn properties(&self) -> &Arc<PlanProperties> {
        &self.properties
    }

    fn supports_limit_pushdown(&self) -> bool {
        false
    }
}
fn adjust_probes(query: &mut Query, pruned_nprobes: usize) {
query.minimum_nprobes = query.minimum_nprobes.max(pruned_nprobes);
if let Some(maximum) = query.maximum_nprobes
&& query.minimum_nprobes > maximum
{
query.minimum_nprobes = maximum;
}
}
/// Estimates how many ranked partitions are worth probing, given the sorted
/// query-to-centroid distances: all partitions whose distance is within a
/// k-dependent factor of the best distance.
fn early_pruning(dists: &[f32], k: usize) -> usize {
    let Some(&best) = dists.first() else {
        return 0;
    };
    // Larger k tolerates a wider distance spread.
    let factor = if k <= 1 {
        0.6
    } else if k <= 10 {
        7.0
    } else {
        81.0
    };
    let threshold = best * factor;
    // `dists` is assumed sorted ascending, so partition_point counts the
    // leading run within the threshold.
    dists.partition_point(|dist| *dist <= threshold)
}
/// Physical node that merges per-query-vector KNN result streams into a
/// single combined score per row id (one input per query vector).
#[derive(Debug)]
pub struct MultivectorScoringExec {
    // One KNN result stream per query vector of the multivector query.
    inputs: Vec<Arc<dyn ExecutionPlan>>,
    query: Query,
    properties: Arc<PlanProperties>,
}
impl MultivectorScoringExec {
    /// Creates a scoring node over the per-query-vector KNN inputs.
    pub fn try_new(inputs: Vec<Arc<dyn ExecutionPlan>>, query: Query) -> Result<Self> {
        let eq_properties = EquivalenceProperties::new(KNN_INDEX_SCHEMA.clone());
        let properties = Arc::new(PlanProperties::new(
            eq_properties,
            Partitioning::RoundRobinBatch(1),
            EmissionType::Final,
            Boundedness::Bounded,
        ));
        Ok(Self {
            inputs,
            query,
            properties,
        })
    }
}
impl DisplayAs for MultivectorScoringExec {
    /// Renders this node for `EXPLAIN` output.
    fn fmt_as(&self, t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        match t {
            DisplayFormatType::TreeRender => {
                write!(f, "MultivectorScoring\nk={}", self.query.k)
            }
            DisplayFormatType::Default | DisplayFormatType::Verbose => {
                write!(f, "MultivectorScoring: k={}", self.query.k)
            }
        }
    }
}
impl ExecutionPlan for MultivectorScoringExec {
fn name(&self) -> &str {
    "MultivectorScoringExec"
}

fn as_any(&self) -> &dyn Any {
    self
}

// Output uses the standard (distance, row id) ANN schema.
fn schema(&self) -> arrow_schema::SchemaRef {
    KNN_INDEX_SCHEMA.clone()
}

fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
    self.inputs.iter().collect()
}

fn required_input_distribution(&self) -> Vec<Distribution> {
    self.children()
        .iter()
        .map(|_| Distribution::SinglePartition)
        .collect()
}

// NOTE(review): no arity check here — assumes the planner always passes the
// same number of children as `self.inputs`; confirm.
fn with_new_children(
    self: Arc<Self>,
    children: Vec<Arc<dyn ExecutionPlan>>,
) -> DataFusionResult<Arc<dyn ExecutionPlan>> {
    let plan = Self::try_new(children, self.query.clone())?;
    Ok(Arc::new(plan))
}
fn execute(
&self,
partition: usize,
context: Arc<datafusion::execution::context::TaskContext>,
) -> DataFusionResult<SendableRecordBatchStream> {
let inputs = self
.inputs
.iter()
.map(|input| input.execute(partition, context.clone()))
.collect::<DataFusionResult<Vec<_>>>()?;
let mut reduced_inputs = stream::select_all(inputs.into_iter().map(|stream| {
stream.map(|batch| {
let batch = batch?;
let row_ids = batch[ROW_ID].as_primitive::<UInt64Type>();
let dists = batch[DIST_COL].as_primitive::<Float32Type>();
debug_assert_eq!(dists.null_count(), 0);
let min_sim = dists
.values()
.last()
.map(|dist| 1.0 - *dist)
.unwrap_or_default();
let mut new_row_ids = Vec::with_capacity(row_ids.len());
let mut new_sims = Vec::with_capacity(row_ids.len());
let mut visited_row_ids = HashSet::with_capacity(row_ids.len());
for (row_id, dist) in row_ids.values().iter().zip(dists.values().iter()) {
if visited_row_ids.contains(row_id) {
continue;
}
visited_row_ids.insert(row_id);
new_row_ids.push(*row_id);
new_sims.push(1.0 - *dist);
}
let new_row_ids = UInt64Array::from(new_row_ids);
let new_dists = Float32Array::from(new_sims);
let batch = RecordBatch::try_new(
KNN_INDEX_SCHEMA.clone(),
vec![Arc::new(new_dists), Arc::new(new_row_ids)],
)?;
Ok::<_, DataFusionError>((min_sim, batch))
})
}));
let k = self.query.k;
let refactor = self.query.refine_factor.unwrap_or(1) as usize;
let num_queries = self.inputs.len() as f32;
let stream = stream::once(async move {
let mut results = HashMap::with_capacity(k * refactor);
let mut missed_sim_sum = 0.0;
while let Some((min_sim, batch)) = reduced_inputs.try_next().await? {
let row_ids = batch[ROW_ID].as_primitive::<UInt64Type>();
let sims = batch[DIST_COL].as_primitive::<Float32Type>();
let query_results = row_ids
.values()
.iter()
.copied()
.zip(sims.values().iter().copied())
.collect::<HashMap<_, _>>();
results.iter_mut().for_each(|(row_id, sim)| {
if let Some(new_dist) = query_results.get(row_id) {
*sim += new_dist;
} else {
*sim += min_sim;
}
});
query_results.into_iter().for_each(|(row_id, sim)| {
results.entry(row_id).or_insert(sim + missed_sim_sum);
});
missed_sim_sum += min_sim;
}
let (row_ids, sims): (Vec<_>, Vec<_>) = results.into_iter().unzip();
let dists = sims
.into_iter()
.map(|sim| num_queries - sim)
.collect::<Vec<_>>();
let row_ids = UInt64Array::from(row_ids);
let dists = Float32Array::from(dists);
let batch = RecordBatch::try_new(
KNN_INDEX_SCHEMA.clone(),
vec![Arc::new(dists), Arc::new(row_ids)],
)?;
Ok::<_, DataFusionError>(batch)
});
Ok(Box::pin(RecordBatchStreamAdapter::new(
self.schema(),
stream.boxed(),
)))
}
fn properties(&self) -> &Arc<PlanProperties> {
&self.properties
}
fn supports_limit_pushdown(&self) -> bool {
false
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::index::DatasetIndexExt;
use arrow::compute::{concat_batches, sort_to_indices, take_record_batch};
use arrow::datatypes::Float32Type;
use arrow_array::{
ArrayRef, FixedSizeListArray, Float32Array, Int32Array, RecordBatchIterator, StringArray,
};
use arrow_schema::{Field as ArrowField, Schema as ArrowSchema};
use async_trait::async_trait;
use datafusion::error::Result as DataFusionResult;
use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
use deepsize::DeepSizeOf;
use lance_core::utils::tempfile::TempStrDir;
use lance_datafusion::exec::{ExecutionStatsCallback, ExecutionSummaryCounts};
use lance_datafusion::utils::FIND_PARTITIONS_ELAPSED_METRIC;
use lance_datagen::{BatchCount, RowCount, array};
use lance_index::optimize::OptimizeOptions;
use lance_index::vector::ivf::IvfBuildParams;
use lance_index::vector::pq::PQBuildParams;
use lance_index::vector::{DEFAULT_QUERY_PARALLELISM, PreparedPartitionSearchHandle};
use lance_index::{Index, IndexType};
use lance_io::traits::Reader;
use lance_linalg::distance::MetricType;
use lance_testing::datagen::generate_random_array;
use roaring::RoaringBitmap;
use rstest::rstest;
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;
use crate::dataset::{WriteMode, WriteParams};
use crate::index::vector::VectorIndexParams;
use crate::io::exec::testing::TestingExec;
/// A minimal L2 query over the "vec" column used as the starting point for
/// the tests below; individual tests override fields as needed.
fn base_query() -> Query {
    Query {
        column: "vec".to_string(),
        key: Arc::new(Float32Array::from(vec![0.0f32])) as ArrayRef,
        k: 10,
        metric_type: Some(DistanceType::L2),
        use_index: true,
        minimum_nprobes: 1,
        maximum_nprobes: None,
        lower_bound: None,
        upper_bound: None,
        ef: None,
        refine_factor: None,
        query_parallelism: DEFAULT_QUERY_PARALLELISM,
        dist_q_c: 0.0,
    }
}
#[test]
fn test_effective_query_parallelism_clamps_to_cpu_pool() {
    // Exercises effective_query_parallelism_for(query, pool_size, hint).
    // NOTE(review): exact semantics of the two integer args inferred from
    // the assertions below — confirm against the function's definition.
    let mut query = base_query();
    // Negative parallelism appears to mean "use the whole pool".
    query.query_parallelism = -1;
    assert_eq!(effective_query_parallelism_for(&query, 16, 1), 16);
    // Zero appears to mean "auto": follow the hint, clamped to the pool.
    query.query_parallelism = 0;
    assert_eq!(effective_query_parallelism_for(&query, 16, 1), 1);
    assert_eq!(effective_query_parallelism_for(&query, 16, 8), 8);
    assert_eq!(effective_query_parallelism_for(&query, 16, 128), 16);
    // A positive value is honored as-is...
    query.query_parallelism = 1;
    assert_eq!(effective_query_parallelism_for(&query, 16, 8), 1);
    query.query_parallelism = 4;
    assert_eq!(effective_query_parallelism_for(&query, 16, 1), 4);
    // ...but never exceeds the pool size.
    query.query_parallelism = 128;
    assert_eq!(effective_query_parallelism_for(&query, 16, 1), 16);
}
/// Test double that records which thread `find_partitions` ran on, so tests
/// can assert the call was dispatched to the dedicated CPU runtime.
#[derive(Debug, DeepSizeOf)]
struct ThreadCapturingIndex {
    /// Name of the thread that executed `find_partitions`, once called.
    thread_name: Arc<Mutex<Option<String>>>,
    row_ids: Vec<u64>,
}
/// Test double that records the prepare/search lifecycle of the prepared
/// partition-search path, including the threads searches ran on.
#[derive(Debug, DeepSizeOf)]
struct PreparedThreadCapturingIndex {
    /// Partition ids passed to `prepare_partition_search` / `search_partitions`.
    prepared_partitions: Arc<Mutex<Vec<usize>>>,
    /// Partition ids actually searched via `search_prepared_partition`.
    searched_partitions: Arc<Mutex<Vec<usize>>>,
    /// Name of the thread each prepared search executed on.
    search_threads: Arc<Mutex<Vec<String>>>,
    /// One synthetic row id per partition; partition `i` returns `row_ids[i]`.
    row_ids: Vec<u64>,
}
// Minimal `Index` implementation so the test double is usable wherever a
// vector index is expected; every answer is an inert placeholder.
#[async_trait]
impl Index for ThreadCapturingIndex {
    fn as_any(&self) -> &dyn Any {
        self
    }
    fn as_index(self: Arc<Self>) -> Arc<dyn Index> {
        self
    }
    fn as_vector_index(self: Arc<Self>) -> Result<Arc<dyn VectorIndex>> {
        Ok(self)
    }
    fn statistics(&self) -> Result<serde_json::Value> {
        Ok(serde_json::json!({}))
    }
    async fn prewarm(&self) -> Result<()> {
        Ok(())
    }
    fn index_type(&self) -> IndexType {
        IndexType::Vector
    }
    async fn calculate_included_frags(&self) -> Result<RoaringBitmap> {
        Ok(RoaringBitmap::new())
    }
}
// `VectorIndex` implementation whose only real behavior is `find_partitions`:
// it captures the current thread's name so the test can verify which runtime
// ran it. All search/load paths are deliberately unimplemented.
#[async_trait]
impl VectorIndex for ThreadCapturingIndex {
    async fn search(
        &self,
        _query: &Query,
        _pre_filter: Arc<dyn PreFilter>,
        _metrics: &dyn lance_index::metrics::MetricsCollector,
    ) -> Result<RecordBatch> {
        unimplemented!()
    }
    fn find_partitions(&self, _query: &Query) -> Result<(UInt32Array, Float32Array)> {
        // Record the executing thread's name for the test to inspect.
        let thread_name = std::thread::current()
            .name()
            .unwrap_or("unknown")
            .to_string();
        *self.thread_name.lock().unwrap() = Some(thread_name);
        // A single dummy partition with zero distance.
        Ok((UInt32Array::from(vec![0]), Float32Array::from(vec![0.0])))
    }
    fn total_partitions(&self) -> usize {
        1
    }
    async fn search_in_partition(
        &self,
        _partition_id: usize,
        _query: &Query,
        _pre_filter: Arc<dyn PreFilter>,
        _metrics: &dyn lance_index::metrics::MetricsCollector,
    ) -> Result<RecordBatch> {
        unimplemented!()
    }
    fn is_loadable(&self) -> bool {
        false
    }
    fn use_residual(&self) -> bool {
        false
    }
    async fn load(
        &self,
        _reader: Arc<dyn Reader>,
        _offset: usize,
        _length: usize,
    ) -> Result<Box<dyn VectorIndex>> {
        unimplemented!()
    }
    fn num_rows(&self) -> u64 {
        0
    }
    fn row_ids(&self) -> Box<dyn Iterator<Item = &'_ u64> + '_> {
        Box::new(self.row_ids.iter())
    }
    async fn remap(&mut self, _mapping: &HashMap<u64, Option<u64>>) -> Result<()> {
        Ok(())
    }
    async fn to_batch_stream(&self, _with_vector: bool) -> Result<SendableRecordBatchStream> {
        unimplemented!()
    }
    fn ivf_model(&self) -> &lance_index::vector::ivf::storage::IvfModel {
        unimplemented!()
    }
    fn quantizer(&self) -> lance_index::vector::quantizer::Quantizer {
        unimplemented!()
    }
    fn partition_size(&self, _part_id: usize) -> usize {
        unimplemented!()
    }
    fn sub_index_type(
        &self,
    ) -> (
        lance_index::vector::v3::subindex::SubIndexType,
        lance_index::vector::quantizer::QuantizationType,
    ) {
        unimplemented!()
    }
    fn metric_type(&self) -> DistanceType {
        DistanceType::L2
    }
}
// Minimal `Index` implementation for the prepared-search test double; every
// answer is an inert placeholder (mirrors `ThreadCapturingIndex`'s impl).
#[async_trait]
impl Index for PreparedThreadCapturingIndex {
    fn as_any(&self) -> &dyn Any {
        self
    }
    fn as_index(self: Arc<Self>) -> Arc<dyn Index> {
        self
    }
    fn as_vector_index(self: Arc<Self>) -> Result<Arc<dyn VectorIndex>> {
        Ok(self)
    }
    fn statistics(&self) -> Result<serde_json::Value> {
        Ok(serde_json::json!({}))
    }
    async fn prewarm(&self) -> Result<()> {
        Ok(())
    }
    fn index_type(&self) -> IndexType {
        IndexType::Vector
    }
    async fn calculate_included_frags(&self) -> Result<RoaringBitmap> {
        Ok(RoaringBitmap::new())
    }
}
// `VectorIndex` implementation exercising the *prepared* partition-search
// path: partitions are prepared (recorded) up front, then searched
// sequentially on a single CPU-runtime task, honoring an optional
// `PartitionSearchControl` for early stopping.
#[async_trait]
impl VectorIndex for PreparedThreadCapturingIndex {
    async fn search(
        &self,
        _query: &Query,
        _pre_filter: Arc<dyn PreFilter>,
        _metrics: &dyn lance_index::metrics::MetricsCollector,
    ) -> Result<RecordBatch> {
        unimplemented!()
    }
    fn find_partitions(&self, _query: &Query) -> Result<(UInt32Array, Float32Array)> {
        unimplemented!()
    }
    fn total_partitions(&self) -> usize {
        // One partition per synthetic row id.
        self.row_ids.len()
    }
    async fn search_in_partition(
        &self,
        _partition_id: usize,
        _query: &Query,
        _pre_filter: Arc<dyn PreFilter>,
        _metrics: &dyn lance_index::metrics::MetricsCollector,
    ) -> Result<RecordBatch> {
        // The tests assert only the prepared path is used.
        panic!("sequential prepared path should not call search_in_partition")
    }
    async fn prepare_partition_search(
        &self,
        partition_id: usize,
        _query: &Query,
        _pre_filter: Arc<dyn PreFilter>,
        _metrics: &dyn lance_index::metrics::MetricsCollector,
    ) -> Result<PreparedPartitionSearchHandle> {
        // Record the preparation and hand back the partition id itself as
        // the opaque handle.
        self.prepared_partitions.lock().unwrap().push(partition_id);
        Ok(Box::new(partition_id))
    }
    fn search_prepared_partition(
        &self,
        prepared: PreparedPartitionSearchHandle,
        _metrics: &dyn lance_index::metrics::MetricsCollector,
    ) -> Result<RecordBatch> {
        // Recover the partition id stored by `prepare_partition_search`.
        let partition_id = *prepared.downcast::<usize>().unwrap();
        self.searched_partitions.lock().unwrap().push(partition_id);
        // Capture the executing thread so tests can assert runtime affinity.
        self.search_threads.lock().unwrap().push(
            std::thread::current()
                .name()
                .unwrap_or("unknown")
                .to_string(),
        );
        // Single-row result: distance = partition id, row id from the fixture.
        Ok(RecordBatch::try_new(
            KNN_INDEX_SCHEMA.clone(),
            vec![
                Arc::new(Float32Array::from(vec![partition_id as f32])),
                Arc::new(UInt64Array::from(vec![self.row_ids[partition_id]])),
            ],
        )?)
    }
    fn supports_prepared_partition_search(&self) -> bool {
        true
    }
    #[allow(clippy::too_many_arguments)]
    async fn search_partitions(
        self: Arc<Self>,
        _query: Query,
        partitions: Arc<UInt32Array>,
        _q_c_dists: Arc<Float32Array>,
        start_idx: usize,
        end_idx: usize,
        _pre_filter: Arc<dyn PreFilter>,
        control: Option<Arc<dyn PartitionSearchControl>>,
        _metrics: Arc<dyn lance_index::metrics::MetricsCollector>,
    ) -> Result<SendableRecordBatchStream> {
        let (batch_tx, batch_rx) = mpsc::channel(1);
        let batch_tx_for_search = batch_tx.clone();
        // Eagerly "prepare" every requested partition before any search runs.
        let prepared_partition_ids = (start_idx..end_idx)
            .map(|idx| partitions.value(idx) as usize)
            .collect::<Vec<_>>();
        self.prepared_partitions
            .lock()
            .unwrap()
            .extend(prepared_partition_ids.iter().copied());
        tokio::spawn(async move {
            // All searches run inside one spawn_cpu task, so they share a
            // single CPU-runtime thread (asserted by the tests).
            let search_result = spawn_cpu(move || -> DataFusionResult<()> {
                for partition_id in prepared_partition_ids {
                    // Stop early when the controller says enough results
                    // have been gathered.
                    if control
                        .as_ref()
                        .is_some_and(|control| control.should_stop())
                    {
                        return Ok(());
                    }
                    let batch = self
                        .search_prepared_partition(
                            Box::new(partition_id),
                            &lance_index::metrics::NoOpMetricsCollector,
                        )
                        .map_err(datafusion::error::DataFusionError::from);
                    match batch {
                        Ok(batch) => {
                            if let Some(control) = control.as_ref() {
                                control.record_batch(&batch);
                            }
                            // A closed receiver just means the consumer is
                            // done; stop quietly.
                            if batch_tx_for_search.blocking_send(Ok(batch)).is_err() {
                                return Ok(());
                            }
                        }
                        Err(err) => {
                            // Forward the error to the stream and stop.
                            let _ = batch_tx_for_search.blocking_send(Err(err));
                            return Ok(());
                        }
                    }
                }
                Ok(())
            })
            .await;
            // Surface a spawn_cpu-level failure through the stream as well.
            if let Err(err) = search_result {
                let _ = batch_tx.send(Err(err)).await;
            }
        });
        Ok(Box::pin(RecordBatchStreamAdapter::new(
            KNN_INDEX_SCHEMA.clone(),
            ReceiverStream::new(batch_rx),
        )))
    }
    fn is_loadable(&self) -> bool {
        false
    }
    fn use_residual(&self) -> bool {
        false
    }
    async fn load(
        &self,
        _reader: Arc<dyn Reader>,
        _offset: usize,
        _length: usize,
    ) -> Result<Box<dyn VectorIndex>> {
        unimplemented!()
    }
    fn num_rows(&self) -> u64 {
        self.row_ids.len() as u64
    }
    fn row_ids(&self) -> Box<dyn Iterator<Item = &'_ u64> + '_> {
        Box::new(self.row_ids.iter())
    }
    async fn remap(&mut self, _mapping: &HashMap<u64, Option<u64>>) -> Result<()> {
        Ok(())
    }
    async fn to_batch_stream(&self, _with_vector: bool) -> Result<SendableRecordBatchStream> {
        unimplemented!()
    }
    fn ivf_model(&self) -> &lance_index::vector::ivf::storage::IvfModel {
        unimplemented!()
    }
    fn quantizer(&self) -> lance_index::vector::quantizer::Quantizer {
        unimplemented!()
    }
    fn partition_size(&self, _part_id: usize) -> usize {
        unimplemented!()
    }
    fn sub_index_type(
        &self,
    ) -> (
        lance_index::vector::v3::subindex::SubIndexType,
        lance_index::vector::quantizer::QuantizationType,
    ) {
        unimplemented!()
    }
    fn metric_type(&self) -> DistanceType {
        DistanceType::L2
    }
}
/// Builds a ready-to-use `DatasetPreFilter` over a tiny in-memory dataset
/// whose fragments are all marked as indexed, so the prefilter passes
/// everything through.
async fn empty_prefilter() -> Arc<DatasetPreFilter> {
    // Each call gets a unique memory:// URI so tests don't share datasets.
    static NEXT_PREFILTER_DATASET_ID: AtomicUsize = AtomicUsize::new(0);
    let schema = Arc::new(ArrowSchema::new(vec![ArrowField::new(
        "id",
        DataType::Int32,
        false,
    )]));
    let batch = RecordBatch::try_new(
        schema.clone(),
        vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
    )
    .unwrap();
    let uri = format!(
        "memory://sequential-prefilter-{}",
        NEXT_PREFILTER_DATASET_ID.fetch_add(1, Ordering::Relaxed)
    );
    let dataset = Arc::new(
        Dataset::write(
            RecordBatchIterator::new(vec![Ok(batch)].into_iter(), schema),
            &uri,
            None,
        )
        .await
        .unwrap(),
    );
    // Mark every fragment as covered by the (synthetic) index.
    let mut indexed_fragments = RoaringBitmap::new();
    for fragment in dataset.manifest.fragments.iter() {
        indexed_fragments.insert(fragment.id as u32);
    }
    let index = IndexMetadata {
        uuid: uuid::Uuid::new_v4(),
        fields: vec![],
        name: "test".to_string(),
        dataset_version: 1,
        fragment_bitmap: Some(indexed_fragments),
        index_details: None,
        index_version: 0,
        created_at: None,
        base_id: None,
        files: None,
    };
    let prefilter = Arc::new(DatasetPreFilter::new(dataset, &[index], None));
    // Block until the prefilter has finished loading before handing it out.
    prefilter.wait_for_ready().await.unwrap();
    prefilter
}
/// Fresh `AnnIndexMetrics` for partition 0, backed by a throwaway metrics set.
fn prepared_metrics() -> Arc<AnnIndexMetrics> {
    Arc::new(AnnIndexMetrics::new(&ExecutionPlanMetricsSet::new(), 0))
}
/// Bundle returned by `prepared_index`:
/// (index, prepared partition ids, searched partition ids, search thread names).
type PreparedIndexState = (
    Arc<dyn VectorIndex>,
    Arc<Mutex<Vec<usize>>>,
    Arc<Mutex<Vec<usize>>>,
    Arc<Mutex<Vec<String>>>,
);
/// Builds a `PreparedThreadCapturingIndex` over `row_ids` and returns the
/// shared recorders so tests can inspect what the index observed.
fn prepared_index(row_ids: Vec<u64>) -> PreparedIndexState {
    let prepared_partitions: Arc<Mutex<Vec<usize>>> = Default::default();
    let searched_partitions: Arc<Mutex<Vec<usize>>> = Default::default();
    let search_threads: Arc<Mutex<Vec<String>>> = Default::default();
    let index: Arc<dyn VectorIndex> = Arc::new(PreparedThreadCapturingIndex {
        prepared_partitions: Arc::clone(&prepared_partitions),
        searched_partitions: Arc::clone(&searched_partitions),
        search_threads: Arc::clone(&search_threads),
        row_ids,
    });
    (index, prepared_partitions, searched_partitions, search_threads)
}
#[test]
fn test_adjust_probes_rules() {
    // adjust_probes(query, num_partitions) tunes the probe bounds.
    // Default minimum (1) is raised to the partition count.
    let mut query = base_query();
    adjust_probes(&mut query, 10);
    assert_eq!(query.minimum_nprobes, 10);
    assert_eq!(query.maximum_nprobes, None);
    // An explicit minimum above the partition count is left alone.
    let mut query = base_query();
    query.minimum_nprobes = 20;
    adjust_probes(&mut query, 10);
    assert_eq!(query.minimum_nprobes, 20);
    assert_eq!(query.maximum_nprobes, None);
    // A maximum above the partition count is preserved as-is.
    let mut query = base_query();
    query.maximum_nprobes = Some(25);
    adjust_probes(&mut query, 10);
    assert_eq!(query.minimum_nprobes, 10);
    assert_eq!(query.maximum_nprobes, Some(25));
    // The raised minimum is capped by an explicit maximum.
    let mut query = base_query();
    query.maximum_nprobes = Some(5);
    adjust_probes(&mut query, 10);
    assert_eq!(query.minimum_nprobes, 5);
    assert_eq!(query.maximum_nprobes, Some(5));
    // Explicit min/max both above the partition count are untouched.
    let mut query = base_query();
    query.minimum_nprobes = 30;
    query.maximum_nprobes = Some(50);
    adjust_probes(&mut query, 10);
    assert_eq!(query.minimum_nprobes, 30);
    assert_eq!(query.maximum_nprobes, Some(50));
}
#[tokio::test]
async fn test_find_partitions_runs_on_cpu_runtime() {
    // The capturing index records the thread `find_partitions` executes on.
    let thread_name = Arc::new(Mutex::new(None));
    let index: Arc<dyn VectorIndex> = Arc::new(ThreadCapturingIndex {
        thread_name: thread_name.clone(),
        row_ids: Vec::new(),
    });
    let (_partitions, _distances) = find_partitions_on_cpu(index, base_query()).await.unwrap();
    let thread_name = thread_name.lock().unwrap().clone().unwrap();
    // Partition ranking must be offloaded to the dedicated CPU thread pool.
    assert!(
        thread_name.contains("lance-cpu"),
        "expected find_partitions to run on the dedicated cpu runtime, got thread {thread_name}",
    );
}
#[tokio::test]
async fn test_sequential_initial_search_prepares_all_then_searches_on_one_cpu_thread() {
    let (index, prepared_partitions, searched_partitions, search_threads) =
        prepared_index(vec![10, 11, 12]);
    let mut query = base_query();
    // Probe all three partitions in the initial phase.
    query.minimum_nprobes = 3;
    let state = Arc::new(ANNIvfEarlySearchResults::new(1, query.k));
    let batches = ANNIvfSubIndexExec::initial_search(
        index,
        query,
        Arc::new(UInt32Array::from(vec![0, 1, 2])),
        Arc::new(Float32Array::from(vec![0.1, 0.2, 0.3])),
        empty_prefilter().await,
        prepared_metrics(),
        state,
    )
    .try_collect::<Vec<_>>()
    .await
    .unwrap();
    // All three partitions were prepared and searched, in order.
    assert_eq!(batches.len(), 3);
    assert_eq!(*prepared_partitions.lock().unwrap(), vec![0, 1, 2]);
    assert_eq!(*searched_partitions.lock().unwrap(), vec![0, 1, 2]);
    let search_threads = search_threads.lock().unwrap().clone();
    assert_eq!(search_threads.len(), 3);
    assert!(
        search_threads.iter().all(|name| name.contains("lance-cpu")),
        "expected prepared searches to run on the cpu runtime, got threads {search_threads:?}",
    );
    // Sequential mode: every search reuses the same CPU-runtime thread.
    assert!(
        search_threads.iter().all(|name| name == &search_threads[0]),
        "expected all prepared searches to reuse one cpu thread, got threads {search_threads:?}",
    );
}
#[tokio::test]
async fn test_sequential_late_search_prepares_all_then_stops_search_early() {
    let (index, prepared_partitions, searched_partitions, _search_threads) =
        prepared_index(vec![21, 22, 23]);
    let mut query = base_query();
    // k = 2 with one result pre-seeded below, so a single late-search batch
    // (one row each) is enough to satisfy the query.
    query.k = 2;
    query.minimum_nprobes = 0;
    query.maximum_nprobes = Some(3);
    let state = Arc::new(ANNIvfEarlySearchResults::new(1, query.k));
    // Seed one prior result so only one more is needed.
    state.record_batch(
        &RecordBatch::try_new(
            KNN_INDEX_SCHEMA.clone(),
            vec![
                Arc::new(Float32Array::from(vec![0.0])),
                Arc::new(UInt64Array::from(vec![999])),
            ],
        )
        .unwrap(),
    );
    let batches = ANNIvfSubIndexExec::late_search(
        index,
        query,
        Arc::new(UInt32Array::from(vec![0, 1, 2])),
        Arc::new(Float32Array::from(vec![0.1, 0.2, 0.3])),
        empty_prefilter().await,
        prepared_metrics(),
        state.clone(),
    )
    .try_collect::<Vec<_>>()
    .await
    .unwrap();
    // All partitions were prepared up front, but searching stopped after the
    // first partition once k results had been found.
    assert_eq!(batches.len(), 1);
    assert_eq!(*prepared_partitions.lock().unwrap(), vec![0, 1, 2]);
    assert_eq!(*searched_partitions.lock().unwrap(), vec![0]);
    assert_eq!(state.num_results_found.load(Ordering::Relaxed), 2);
}
#[tokio::test]
async fn knn_flat_search() {
    // End-to-end flat (non-indexed) KNN: scan with `nearest` must match a
    // brute-force distance computation + sort over the whole dataset.
    let schema = Arc::new(ArrowSchema::new(vec![
        ArrowField::new("key", DataType::Int32, false),
        ArrowField::new(
            "vector",
            DataType::FixedSizeList(
                Arc::new(ArrowField::new("item", DataType::Float32, true)),
                128,
            ),
            true,
        ),
        ArrowField::new("uri", DataType::Utf8, true),
    ]));
    // 20 batches x 20 rows of random 128-dim vectors.
    let batches: Vec<RecordBatch> = (0..20)
        .map(|i| {
            RecordBatch::try_new(
                schema.clone(),
                vec![
                    Arc::new(Int32Array::from_iter_values(i * 20..(i + 1) * 20)),
                    Arc::new(
                        FixedSizeListArray::try_new_from_values(
                            generate_random_array(128 * 20),
                            128,
                        )
                        .unwrap(),
                    ),
                    Arc::new(StringArray::from_iter_values(
                        (i * 20..(i + 1) * 20).map(|i| format!("s3://bucket/file-{}", i)),
                    )),
                ],
            )
            .unwrap()
        })
        .collect();
    let test_dir = TempStrDir::default();
    let test_uri = test_dir.as_str();
    // Small files/groups to force multiple fragments.
    let write_params = WriteParams {
        max_rows_per_file: 40,
        max_rows_per_group: 10,
        ..Default::default()
    };
    // Use a vector from the dataset itself as the query.
    let vector_arr = batches[0].column_by_name("vector").unwrap();
    let q = as_fixed_size_list_array(&vector_arr).value(5);
    let reader = RecordBatchIterator::new(batches.into_iter().map(Ok), schema.clone());
    Dataset::write(reader, test_uri, Some(write_params))
        .await
        .unwrap();
    let dataset = Dataset::open(test_uri).await.unwrap();
    let stream = dataset
        .scan()
        .nearest("vector", q.as_primitive::<Float32Type>(), 10)
        .unwrap()
        .try_into_stream()
        .await
        .unwrap();
    let results = stream.try_collect::<Vec<_>>().await.unwrap();
    // Result carries a distance column and arrives as a single batch.
    assert!(results[0].schema().column_with_name(DIST_COL).is_some());
    assert_eq!(results.len(), 1);
    // Reference answer: compute distances for every row, take the 10 nearest.
    let stream = dataset.scan().try_into_stream().await.unwrap();
    let all_with_distances = stream
        .and_then(|batch| compute_distance(q.clone(), DistanceType::L2, "vector", batch))
        .try_collect::<Vec<_>>()
        .await
        .unwrap();
    let all_with_distances =
        concat_batches(&results[0].schema(), all_with_distances.iter()).unwrap();
    let dist_arr = all_with_distances.column_by_name(DIST_COL).unwrap();
    let distances = dist_arr.as_primitive::<Float32Type>();
    let indices = sort_to_indices(distances, None, Some(10)).unwrap();
    let expected = take_record_batch(&all_with_distances, &indices).unwrap();
    assert_eq!(expected, results[0]);
}
#[test]
fn test_create_knn_flat() {
    // KNNVectorDistanceExec must append a DIST_COL float column to its
    // input's schema without disturbing the existing columns.
    let dim: usize = 128;
    let schema = Arc::new(ArrowSchema::new(vec![
        ArrowField::new("key", DataType::Int32, false),
        ArrowField::new(
            "vector",
            DataType::FixedSizeList(
                Arc::new(ArrowField::new("item", DataType::Float32, true)),
                dim as i32,
            ),
            true,
        ),
        ArrowField::new("uri", DataType::Utf8, true),
    ]));
    let batch = RecordBatch::new_empty(schema);
    let input: Arc<dyn ExecutionPlan> = Arc::new(TestingExec::new(vec![batch]));
    let idx = KNNVectorDistanceExec::try_new(
        input,
        "vector",
        Arc::new(generate_random_array(dim)),
        DistanceType::L2,
    )
    .unwrap();
    assert_eq!(
        idx.schema().as_ref(),
        &ArrowSchema::new(vec![
            ArrowField::new("key", DataType::Int32, false),
            ArrowField::new(
                "vector",
                DataType::FixedSizeList(
                    Arc::new(ArrowField::new("item", DataType::Float32, true)),
                    dim as i32,
                ),
                true,
            ),
            ArrowField::new("uri", DataType::Utf8, true),
            ArrowField::new(DIST_COL, DataType::Float32, true),
        ])
    );
}
#[tokio::test]
async fn test_multivector_score() {
    let query = Query {
        column: "vector".to_string(),
        key: Arc::new(generate_random_array(1)),
        k: 10,
        lower_bound: None,
        upper_bound: None,
        minimum_nprobes: 1,
        maximum_nprobes: None,
        ef: None,
        refine_factor: None,
        metric_type: Some(DistanceType::Cosine),
        use_index: true,
        query_parallelism: DEFAULT_QUERY_PARALLELISM,
        dist_q_c: 0.0,
    };
    // Runs MultivectorScoringExec over `inputs` and collects row_id -> score.
    async fn multivector_scoring(
        inputs: Vec<Arc<dyn ExecutionPlan>>,
        query: Query,
    ) -> Result<HashMap<u64, f32>> {
        let ctx = Arc::new(datafusion::execution::context::TaskContext::default());
        let plan = MultivectorScoringExec::try_new(inputs, query.clone())?;
        let batches = plan
            .execute(0, ctx.clone())
            .unwrap()
            .try_collect::<Vec<_>>()
            .await?;
        let mut results = HashMap::new();
        for batch in batches {
            let row_ids = batch[ROW_ID].as_primitive::<UInt64Type>();
            let dists = batch[DIST_COL].as_primitive::<Float32Type>();
            for (row_id, dist) in row_ids.values().iter().zip(dists.values().iter()) {
                results.insert(*row_id, *dist);
            }
        }
        Ok(results)
    }
    // Three overlapping per-query results: row ids (1,2), (2,3), (3,4).
    let batches = (0..3)
        .map(|i| {
            RecordBatch::try_new(
                KNN_INDEX_SCHEMA.clone(),
                vec![
                    Arc::new(Float32Array::from(vec![i as f32 + 1.0, i as f32 + 2.0])),
                    Arc::new(UInt64Array::from(vec![i + 1, i + 2])),
                ],
            )
            .unwrap()
        })
        .collect::<Vec<_>>();
    // Scoring must be invariant to the order the inputs arrive in: every
    // permutation yields the same 4 row ids with identical scores.
    let mut res: Option<HashMap<_, _>> = None;
    for perm in batches.into_iter().permutations(3) {
        let inputs = perm
            .into_iter()
            .map(|batch| {
                let input: Arc<dyn ExecutionPlan> = Arc::new(TestingExec::new(vec![batch]));
                input
            })
            .collect::<Vec<_>>();
        let new_res = multivector_scoring(inputs, query.clone()).await.unwrap();
        assert_eq!(new_res.len(), 4);
        if let Some(res) = &res {
            for (row_id, dist) in new_res.iter() {
                assert_eq!(res.get(row_id).unwrap(), dist)
            }
        } else {
            res = Some(new_res);
        }
    }
}
/// Fixture for the nprobes tests: a dataset whose vectors cluster tightly
/// around known centroids, indexed with IVF_PQ across one or more deltas.
struct NprobesTestFixture {
    dataset: Dataset,
    /// The exact centroids used to build the IVF index; tests query with them.
    centroids: Arc<dyn Array>,
    // Held only to keep the temp directory alive for the fixture's lifetime.
    _tmp_dir: TempStrDir,
}
impl NprobesTestFixture {
    /// Builds a dataset of 100 fragments x 100 rows whose vectors are small
    /// jitters around `num_centroids` unit-circle centroids, written as
    /// `num_deltas` appends. The first append creates an IVF_PQ index with
    /// the known centroids; each later append optimizes the index into a
    /// new delta, so every partition exists once per delta.
    pub async fn new(num_centroids: usize, num_deltas: usize) -> Self {
        let tempdir = TempStrDir::default();
        let tmppath = tempdir.as_str();
        let centroids = array::cycle_unit_circle(num_centroids as u32)
            .generate_default(RowCount::from(num_centroids as u64))
            .unwrap();
        assert!(100 % num_deltas == 0, "num_deltas must divide 100");
        let rows_per_frag = 100;
        let num_frags = 100;
        let frags_per_delta = num_frags / num_deltas;
        let batches = lance_datagen::gen_batch()
            .col("vector", array::jitter_centroids(centroids.clone(), 0.0001))
            .col("label", array::cycle::<UInt32Type>(Vec::from_iter(0..61)))
            .col("userid", array::step::<UInt64Type>())
            .into_reader_rows(
                RowCount::from(rows_per_frag),
                BatchCount::from(num_frags as u32),
            )
            .collect::<Vec<_>>();
        let schema = batches[0].as_ref().unwrap().schema();
        let mut first = true;
        for batches in batches.chunks(frags_per_delta) {
            let delta_batches = batches
                .iter()
                .map(|maybe_batch| Ok(maybe_batch.as_ref().unwrap().clone()))
                .collect::<Vec<_>>();
            let reader = RecordBatchIterator::new(delta_batches, schema.clone());
            let mut dataset = Dataset::write(
                reader,
                tmppath,
                Some(WriteParams {
                    mode: WriteMode::Append,
                    ..Default::default()
                }),
            )
            .await
            .unwrap();
            // Pin the IVF centroids so partition assignment is deterministic.
            let ivf_params = IvfBuildParams::try_with_centroids(
                num_centroids,
                Arc::new(centroids.as_fixed_size_list().clone()),
            )
            .unwrap();
            let codebook = array::rand::<Float32Type>()
                .generate_default(RowCount::from(256 * 2))
                .unwrap();
            let pq_params = PQBuildParams::with_codebook(2, 8, codebook);
            let index_params =
                VectorIndexParams::with_ivf_pq_params(MetricType::L2, ivf_params, pq_params);
            if first {
                first = false;
                dataset
                    .create_index(&["vector"], IndexType::Vector, None, &index_params, false)
                    .await
                    .unwrap();
            } else {
                // Subsequent appends become delta indices via optimize.
                dataset
                    .optimize_indices(&OptimizeOptions::append())
                    .await
                    .unwrap();
            }
        }
        let dataset = Dataset::open(tmppath).await.unwrap();
        Self {
            dataset,
            centroids,
            _tmp_dir: tempdir,
        }
    }
    /// Returns centroid `idx` as a query vector.
    pub fn get_centroid(&self, idx: usize) -> Arc<dyn Array> {
        let centroids = self.centroids.as_fixed_size_list();
        centroids.value(idx).clone()
    }
}
/// Captures the `ExecutionSummaryCounts` reported through a scan's stats
/// callback so tests can assert on search metrics after the scan finishes.
#[derive(Default)]
struct StatsHolder {
    pub collected_stats: Arc<Mutex<Option<ExecutionSummaryCounts>>>,
}
impl StatsHolder {
fn get_setter(&self) -> ExecutionStatsCallback {
let collected_stats = self.collected_stats.clone();
Arc::new(move |stats| {
*collected_stats.lock().unwrap() = Some(stats.clone());
})
}
fn consume(self) -> ExecutionSummaryCounts {
self.collected_stats.lock().unwrap().take().unwrap()
}
}
/// Asserts the scan recorded a nonzero "find partitions" timing metric.
fn assert_find_partitions_elapsed_recorded(stats: &ExecutionSummaryCounts) {
    let elapsed = stats
        .all_times
        .get(FIND_PARTITIONS_ELAPSED_METRIC)
        .copied()
        .unwrap_or_default();
    assert!(elapsed > 0);
}
#[rstest]
#[tokio::test]
async fn test_no_max_nprobes(#[values(1, 20)] num_deltas: usize) {
    // With no maximum, the search keeps probing past minimum_nprobes until
    // it satisfies k, but should not need to touch every partition.
    let fixture = NprobesTestFixture::new(100, num_deltas).await;
    let q = fixture.get_centroid(0);
    let stats_holder = StatsHolder::default();
    let results = fixture
        .dataset
        .scan()
        .nearest("vector", q.as_ref(), 50)
        .unwrap()
        .minimum_nprobes(10)
        .prefilter(true)
        .scan_stats_callback(stats_holder.get_setter())
        .filter("label = 17")
        .unwrap()
        .project(&Vec::<String>::new())
        .unwrap()
        .with_row_id()
        .try_into_batch()
        .await
        .unwrap();
    assert_eq!(results.num_rows(), 50);
    let stats = stats_holder.consume();
    // On very large CPU pools the parallel search may speculatively probe
    // everything, so only assert partial coverage on smaller pools.
    if get_num_compute_intensive_cpus() <= 32 {
        assert!(*stats.all_counts.get(PARTITIONS_SEARCHED_METRIC).unwrap() < 100 * num_deltas);
    }
    assert_find_partitions_elapsed_recorded(&stats);
}
#[rstest]
#[tokio::test]
async fn test_no_prefilter_results(#[values(1, 20)] num_deltas: usize) {
    // A contradictory filter matches nothing; the search must still run the
    // minimum probes (per delta) and then stop without widening further.
    let fixture = NprobesTestFixture::new(100, num_deltas).await;
    let q = fixture.get_centroid(0);
    let stats_holder = StatsHolder::default();
    let results = fixture
        .dataset
        .scan()
        .nearest("vector", q.as_ref(), 50)
        .unwrap()
        .minimum_nprobes(10)
        .prefilter(true)
        .scan_stats_callback(stats_holder.get_setter())
        .filter("label = 17 AND label = 18")
        .unwrap()
        .project(&Vec::<String>::new())
        .unwrap()
        .with_row_id()
        .try_into_batch()
        .await
        .unwrap();
    assert_eq!(results.num_rows(), 0);
    let stats = stats_holder.consume();
    assert_eq!(
        stats.all_counts.get(PARTITIONS_SEARCHED_METRIC).unwrap(),
        &(10 * num_deltas)
    );
    assert_find_partitions_elapsed_recorded(&stats);
}
#[rstest]
#[tokio::test]
async fn test_some_max_nprobes(#[values(1, 20)] num_deltas: usize) {
    let fixture = NprobesTestFixture::new(100, num_deltas).await;
    // Capping nprobes truncates the result set; more probes, more matches.
    for (max_nprobes, expected_results) in [(10, 16), (20, 33), (30, 48)] {
        let q = fixture.get_centroid(0);
        let stats_holder = StatsHolder::default();
        let results = fixture
            .dataset
            .scan()
            .nearest("vector", q.as_ref(), 50)
            .unwrap()
            .minimum_nprobes(max_nprobes)
            .maximum_nprobes(max_nprobes)
            .prefilter(true)
            .filter("label = 17")
            .unwrap()
            .scan_stats_callback(stats_holder.get_setter())
            .project(&Vec::<String>::new())
            .unwrap()
            .with_row_id()
            .try_into_batch()
            .await
            .unwrap();
        let stats = stats_holder.consume();
        assert_eq!(results.num_rows(), expected_results);
        // Exactly max_nprobes partitions searched in each delta...
        assert_eq!(
            stats.all_counts.get(PARTITIONS_SEARCHED_METRIC).unwrap(),
            &(max_nprobes * num_deltas)
        );
        // ...while every partition is still ranked.
        assert_eq!(
            stats.all_counts.get(PARTITIONS_RANKED_METRIC).unwrap(),
            &(100 * num_deltas)
        );
        assert_find_partitions_elapsed_recorded(&stats);
    }
}
#[rstest]
#[tokio::test]
async fn test_fewer_than_k_results(#[values(1, 20)] num_deltas: usize) {
    // The prefilter leaves only 20 rows while k = 50, so the search cannot
    // fill k even after exhausting its probes.
    let fixture = NprobesTestFixture::new(100, num_deltas).await;
    let q = fixture.get_centroid(0);
    let stats_holder = StatsHolder::default();
    let results = fixture
        .dataset
        .scan()
        .nearest("vector", q.as_ref(), 50)
        .unwrap()
        .minimum_nprobes(10)
        .prefilter(true)
        .filter("userid < 20")
        .unwrap()
        .scan_stats_callback(stats_holder.get_setter())
        .project(&Vec::<String>::new())
        .unwrap()
        .with_row_id()
        .try_into_batch()
        .await
        .unwrap();
    let stats = stats_holder.consume();
    assert_eq!(
        stats.all_counts.get(PARTITIONS_SEARCHED_METRIC).unwrap(),
        &(10 * num_deltas)
    );
    assert_find_partitions_elapsed_recorded(&stats);
    assert_eq!(results.num_rows(), 20);
    // Without refinement, 15 of the 20 rows carry infinite distances —
    // presumably rows the prefilter admits but the probed partitions never
    // scored (TODO confirm against the search implementation).
    let num_infinite_results = results
        .column(0)
        .as_primitive::<Float32Type>()
        .values()
        .iter()
        .filter(|val| val.is_infinite())
        .count();
    assert_eq!(num_infinite_results, 15);
    let results = fixture
        .dataset
        .scan()
        .nearest("vector", q.as_ref(), 50)
        .unwrap()
        .minimum_nprobes(10)
        .prefilter(true)
        .refine(1)
        .filter("userid < 20")
        .unwrap()
        .project(&Vec::<String>::new())
        .unwrap()
        .with_row_id()
        .try_into_batch()
        .await
        .unwrap();
    assert_eq!(results.num_rows(), 20);
    // Refinement recomputes true distances, eliminating the infinities.
    let num_infinite_results = results
        .column(0)
        .as_primitive::<Float32Type>()
        .values()
        .iter()
        .filter(|val| val.is_infinite())
        .count();
    assert_eq!(num_infinite_results, 0);
}
#[rstest]
#[tokio::test]
async fn test_dataset_too_small(#[values(1, 20)] num_deltas: usize) {
    // k (40000) exceeds the dataset size (10000 rows), so the search must
    // widen to every partition in every delta and return all rows.
    let fixture = NprobesTestFixture::new(100, num_deltas).await;
    let q = fixture.get_centroid(0);
    let stats_holder = StatsHolder::default();
    let results = fixture
        .dataset
        .scan()
        .nearest("vector", q.as_ref(), 40000)
        .unwrap()
        .minimum_nprobes(10)
        .scan_stats_callback(stats_holder.get_setter())
        .project(&Vec::<String>::new())
        .unwrap()
        .with_row_id()
        .try_into_batch()
        .await
        .unwrap();
    let stats = stats_holder.consume();
    assert_eq!(
        stats.all_counts.get(PARTITIONS_SEARCHED_METRIC).unwrap(),
        &(100 * num_deltas)
    );
    assert_find_partitions_elapsed_recorded(&stats);
    assert_eq!(results.num_rows(), 10000);
}
}