use std::ops::Range;
use std::pin::Pin;
use std::sync::{Arc, LazyLock};
use std::task::{Context, Poll};
use arrow::array::AsArray;
use arrow_array::{Array, Float32Array, Int64Array, RecordBatch};
use arrow_schema::{DataType, Field as ArrowField, Schema as ArrowSchema, SchemaRef, SortOptions};
use arrow_select::concat::concat_batches;
use async_recursion::async_recursion;
use chrono::Utc;
use datafusion::common::{DFSchema, JoinType, NullEquality, SchemaExt, exec_datafusion_err};
use datafusion::functions_aggregate;
use datafusion::logical_expr::{Expr, ScalarUDF, col, lit};
use datafusion::physical_expr::PhysicalSortExpr;
use datafusion::physical_plan::coalesce_batches::CoalesceBatchesExec;
use datafusion::physical_plan::expressions;
use datafusion::physical_plan::projection::ProjectionExec as DFProjectionExec;
use datafusion::physical_plan::sorts::sort::SortExec;
use datafusion::physical_plan::{
ExecutionPlan, SendableRecordBatchStream,
aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy},
display::DisplayableExecutionPlan,
limit::GlobalLimitExec,
repartition::RepartitionExec,
union::UnionExec,
};
use datafusion::scalar::ScalarValue;
use datafusion_expr::ExprSchemable;
use datafusion_expr::execution_props::ExecutionProps;
use datafusion_functions::core::getfield::GetFieldFunc;
use datafusion_physical_expr::expressions::Column;
use datafusion_physical_expr::{LexOrdering, Partitioning, PhysicalExpr, create_physical_expr};
use datafusion_physical_plan::joins::PartitionMode;
use datafusion_physical_plan::projection::ProjectionExec;
use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
use datafusion_physical_plan::{empty::EmptyExec, joins::HashJoinExec};
use futures::future::BoxFuture;
use futures::stream::{Stream, StreamExt};
use futures::{FutureExt, TryStreamExt};
use lance_arrow::floats::{FloatType, coerce_float_vector};
use lance_arrow::{DataTypeExt, SchemaExt as ArrowSchemaExt};
use lance_core::datatypes::{
BlobHandling, Field, OnMissing, Projection, escape_field_path_for_project, format_field_path,
};
use lance_core::error::LanceOptionExt;
use lance_core::utils::address::RowAddress;
use lance_core::utils::mask::{RowAddrMask, RowAddrTreeMap};
use lance_core::utils::tokio::get_num_compute_intensive_cpus;
use lance_core::{ROW_ADDR, ROW_ID, ROW_OFFSET};
use lance_datafusion::aggregate::Aggregate;
use lance_datafusion::exec::{
LanceExecutionOptions, OneShotExec, StrictBatchSizeExec, analyze_plan, execute_plan,
};
use lance_datafusion::expr::safe_coerce_scalar;
use lance_datafusion::projection::ProjectionPlan;
use lance_file::reader::FileReaderOptions;
use lance_index::IndexCriteria;
use lance_index::scalar::FullTextSearchQuery;
use lance_index::scalar::expression::{INDEX_EXPR_RESULT_SCHEMA, IndexExprResult, PlannerIndexExt};
use lance_index::scalar::inverted::query::{
FtsQuery, FtsQueryNode, FtsSearchParams, MatchQuery, PhraseQuery, fill_fts_query_column,
};
use lance_index::scalar::inverted::{SCORE_COL, SCORE_FIELD};
use lance_index::vector::{DIST_COL, Query};
use lance_index::{DatasetIndexExt, scalar::expression::ScalarIndexExpr};
use lance_index::{metrics::NoOpMetricsCollector, scalar::inverted::FTS_SCHEMA};
use lance_io::stream::RecordBatchStream;
use lance_linalg::distance::MetricType;
use lance_table::format::{Fragment, IndexMetadata};
use roaring::RoaringBitmap;
use tracing::{Span, info_span, instrument};
use super::Dataset;
use crate::dataset::row_offsets_to_row_addresses;
use crate::dataset::utils::SchemaAdapter;
use crate::index::DatasetIndexInternalExt;
use crate::index::vector::utils::{
default_distance_type_for, get_vector_dim, get_vector_type, validate_distance_type_for,
};
use crate::io::exec::filtered_read::{FilteredReadExec, FilteredReadOptions};
use crate::io::exec::fts::{
BoostQueryExec, FlatMatchFilterExec, FlatMatchQueryExec, MatchQueryExec, PhraseQueryExec,
};
use crate::io::exec::knn::MultivectorScoringExec;
use crate::io::exec::scalar_index::{MaterializeIndexExec, ScalarIndexExec};
use crate::io::exec::{
AddRowAddrExec, FilterPlan as ExprFilterPlan, KNNVectorDistanceExec, LancePushdownScanExec,
LanceScanExec, Planner, PreFilterSource, ScanConfig, TakeExec,
knn::{KNN_INDEX_SCHEMA, new_knn_exec},
project,
};
use crate::io::exec::{AddRowOffsetExec, LanceFilterExec, LanceScanConfig, get_physical_optimizer};
use crate::{Error, Result};
use crate::{datatypes::Schema, io::exec::fts::BooleanQueryExec};
pub use lance_datafusion::exec::{ExecutionStatsCallback, ExecutionSummaryCounts};
#[cfg(feature = "substrait")]
use lance_datafusion::substrait::parse_substrait;
pub(crate) const BATCH_SIZE_FALLBACK: usize = 8192;
fn parse_env_var<T: std::str::FromStr>(env_var_name: &str, default_val: &str) -> Option<T>
where
T::Err: std::fmt::Display,
{
std::env::var(env_var_name)
.ok()
.and_then(|val| match val.parse() {
Ok(value) => Some(value),
Err(e) => {
log::warn!(
"Failed to parse the environment variable {}='{}': {}, the default value is: {}.",
env_var_name,
val,
e,
default_val
);
None
}
})
}
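/// Batch size requested via the `LANCE_DEFAULT_BATCH_SIZE` environment
/// variable, or `None` when the variable is unset or fails to parse.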
pub fn get_default_batch_size() -> Option<usize> {
parse_env_var("LANCE_DEFAULT_BATCH_SIZE", &BATCH_SIZE_FALLBACK.to_string())
}
pub const LEGACY_DEFAULT_FRAGMENT_READAHEAD: usize = 4;
pub static DEFAULT_FRAGMENT_READAHEAD: LazyLock<Option<usize>> = LazyLock::new(|| {
parse_env_var(
"LANCE_DEFAULT_FRAGMENT_READAHEAD",
&LEGACY_DEFAULT_FRAGMENT_READAHEAD.to_string(),
)
});
const DEFAULT_XTR_OVERFETCH_VALUE: u32 = 10;
pub static DEFAULT_XTR_OVERFETCH: LazyLock<u32> = LazyLock::new(|| {
parse_env_var(
"LANCE_XTR_OVERFETCH",
&DEFAULT_XTR_OVERFETCH_VALUE.to_string(),
)
.unwrap_or(DEFAULT_XTR_OVERFETCH_VALUE)
});
const DEFAULT_IO_BUFFER_SIZE_VALUE: u64 = 2 * 1024 * 1024 * 1024;
pub static DEFAULT_IO_BUFFER_SIZE: LazyLock<u64> = LazyLock::new(|| {
parse_env_var(
"LANCE_DEFAULT_IO_BUFFER_SIZE",
&DEFAULT_IO_BUFFER_SIZE_VALUE.to_string(),
)
.unwrap_or(DEFAULT_IO_BUFFER_SIZE_VALUE)
});
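/// A single sort key for [`Scanner::order_by`]: a column name plus its sort
/// direction and null placement.
///
/// A minimal sketch (column names are illustrative):
///
/// ```ignore
/// // Sort by "timestamp" descending, breaking ties by "id" ascending.
/// let ordering = vec![
///     ColumnOrdering::desc_nulls_last("timestamp".to_string()),
///     ColumnOrdering::asc_nulls_first("id".to_string()),
/// ];
/// scanner.order_by(Some(ordering))?;
/// ```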
#[derive(Debug, Clone)]
pub struct ColumnOrdering {
pub ascending: bool,
pub nulls_first: bool,
pub column_name: String,
}
impl ColumnOrdering {
pub fn asc_nulls_first(column_name: String) -> Self {
Self {
ascending: true,
nulls_first: true,
column_name,
}
}
pub fn asc_nulls_last(column_name: String) -> Self {
Self {
ascending: true,
nulls_first: false,
column_name,
}
}
pub fn desc_nulls_first(column_name: String) -> Self {
Self {
ascending: false,
nulls_first: true,
column_name,
}
}
pub fn desc_nulls_last(column_name: String) -> Self {
Self {
ascending: false,
nulls_first: false,
column_name,
}
}
}
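/// Controls when columns are materialized relative to filtering.
///
/// Early materialization loads a column as part of the initial scan, which
/// wastes I/O on rows the filter later discards; late materialization takes
/// only the surviving rows afterwards, at the cost of random-access reads.
/// `Heuristic` (the default) decides per column from its estimated byte
/// width and whether the dataset lives on cloud storage.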
#[derive(Clone)]
pub enum MaterializationStyle {
Heuristic,
AllLate,
AllEarly,
AllEarlyExcept(Vec<u32>),
}
impl MaterializationStyle {
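    /// Convenience constructor for [`MaterializationStyle::AllEarlyExcept`]:
    /// resolves the given column names against `schema` and records their
    /// field ids. A sketch (the column name is illustrative):
    ///
    /// ```ignore
    /// // Materialize everything early except a wide "embedding" column.
    /// let style = MaterializationStyle::all_early_except(&["embedding"], dataset.schema())?;
    /// ```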
pub fn all_early_except(columns: &[impl AsRef<str>], schema: &Schema) -> Result<Self> {
let field_ids = schema
.project(columns)?
.field_ids()
.into_iter()
.map(|id| id as u32)
.collect();
Ok(Self::AllEarlyExcept(field_ids))
}
}
#[derive(Debug)]
struct PlannedFilteredScan {
plan: Arc<dyn ExecutionPlan>,
limit_pushed_down: bool,
filter_pushed_down: bool,
}
pub struct FilterPlan {
query_filter: Option<QueryFilter>,
refine_query_filter: bool,
expr_filter_plan: ExprFilterPlan,
}
impl FilterPlan {
pub fn new(query_filter: Option<QueryFilter>, expr_filter_plan: ExprFilterPlan) -> Self {
Self {
query_filter,
refine_query_filter: false,
expr_filter_plan,
}
}
pub fn disable_refine(&mut self) {
self.expr_filter_plan = ExprFilterPlan::default();
self.refine_query_filter = false;
}
pub fn make_refine_only(&mut self) {
self.expr_filter_plan.make_refine_only();
self.refine_query_filter = true;
}
pub fn fts_filter(&self) -> Option<FullTextSearchQuery> {
match &self.query_filter {
Some(QueryFilter::Fts(query)) => Some(query.clone()),
_ => None,
}
}
pub fn vector_filter(&self) -> Option<Query> {
match &self.query_filter {
Some(QueryFilter::Vector(query)) => Some(query.clone()),
_ => None,
}
}
pub fn has_refine(&self) -> bool {
self.expr_filter_plan.has_refine() || self.refine_query_filter
}
pub async fn refine_columns(&self, dataset: &Arc<Dataset>) -> Result<Vec<String>> {
let mut columns = vec![];
if self.expr_filter_plan.has_refine() {
columns.extend(self.expr_filter_plan.refine_columns());
}
if self.refine_query_filter {
match &self.query_filter {
Some(QueryFilter::Fts(fts_query)) => {
let cols = if fts_query.columns().is_empty() {
let indexed_columns = fts_indexed_columns(dataset.clone()).await?;
let q = fill_fts_query_column(&fts_query.query, &indexed_columns, false)?;
q.columns()
} else {
fts_query.columns()
};
if let FtsQuery::Match(_) = &fts_query.query {
                    columns.extend(cols.iter().cloned());
}
}
Some(QueryFilter::Vector(vector_query)) => {
columns.push(vector_query.column.clone());
}
None => {}
}
}
Ok(columns)
}
pub async fn refine_filter(
&self,
input: Arc<dyn ExecutionPlan>,
scanner: &Scanner,
) -> Result<Arc<dyn ExecutionPlan>> {
let mut plan = input;
if self.refine_query_filter {
match &self.query_filter {
Some(QueryFilter::Fts(fts_query)) => {
plan = scanner.flat_fts_filter(plan, fts_query).await?;
}
Some(QueryFilter::Vector(vector_query)) => {
plan = scanner.flat_knn(plan, vector_query)?;
}
None => {}
}
}
if let Some(refine_expr) = &self.expr_filter_plan.refine_expr {
plan = Arc::new(LanceFilterExec::try_new(refine_expr.clone(), plan)?);
}
Ok(plan)
}
}
#[derive(Debug, Clone, Default)]
pub struct LanceFilter {
query_filter: Option<QueryFilter>,
expr_filter: Option<ExprFilter>,
}
impl LanceFilter {
pub fn is_none(&self) -> bool {
self.query_filter.is_none() && self.expr_filter.is_none()
}
}
#[derive(Debug, Clone)]
pub enum QueryFilter {
Fts(FullTextSearchQuery),
Vector(Query),
}
#[derive(Debug, Clone)]
pub enum ExprFilter {
Sql(String),
Substrait(Vec<u8>),
Datafusion(Expr),
}
impl ExprFilter {
#[allow(unused)]
#[instrument(level = "trace", name = "filter_to_df", skip_all)]
pub fn to_datafusion(&self, dataset_schema: &Schema, full_schema: &Schema) -> Result<Expr> {
match self {
Self::Sql(sql) => {
let schema = Arc::new(ArrowSchema::from(full_schema));
let planner = Planner::new(schema.clone());
let filter = planner.parse_filter(sql)?;
let df_schema = DFSchema::try_from(schema)?;
let ret_field = filter.to_field(&df_schema)?.1;
let ret_type = ret_field.data_type();
if ret_type != &DataType::Boolean {
return Err(Error::invalid_input_source(
format!("The filter {} does not return a boolean", filter).into(),
));
}
let optimized = planner.optimize_expr(filter).map_err(|e| {
Error::invalid_input(format!("Error optimizing sql filter: {sql} ({e})"))
})?;
Ok(optimized)
}
#[cfg(feature = "substrait")]
Self::Substrait(expr) => {
use lance_datafusion::exec::{LanceExecutionOptions, get_session_context};
let ctx = get_session_context(&LanceExecutionOptions::default());
let state = ctx.state();
let schema = Arc::new(ArrowSchema::from(dataset_schema));
let expr = parse_substrait(expr, schema.clone(), &ctx.state())
.now_or_never()
.expect("could not parse the Substrait filter in a synchronous fashion")?;
let planner = Planner::new(schema);
planner.optimize_expr(expr.clone()).map_err(|e| {
Error::invalid_input(format!(
"Error optimizing substrait filter: {expr:?} ({e})"
))
})
}
#[cfg(not(feature = "substrait"))]
Self::Substrait(_) => Err(Error::not_supported_source(
"Substrait filter is not supported in this build".into(),
)),
Self::Datafusion(expr) => Ok(expr.clone()),
}
}
}
#[derive(Debug, Clone)]
pub enum AggregateExpr {
#[cfg(feature = "substrait")]
Substrait(Vec<u8>),
Datafusion {
group_by: Vec<Expr>,
aggregates: Vec<Expr>,
},
}
impl AggregateExpr {
pub fn builder() -> AggregateExprBuilder<false> {
AggregateExprBuilder::new()
}
#[cfg(feature = "substrait")]
pub fn substrait(bytes: impl Into<Vec<u8>>) -> Self {
Self::Substrait(bytes.into())
}
pub fn datafusion(group_by: Vec<Expr>, aggregates: Vec<Expr>) -> Self {
Self::Datafusion {
group_by,
aggregates,
}
}
fn parse(self, #[allow(unused_variables)] schema: Arc<ArrowSchema>) -> Result<Aggregate> {
match self {
#[cfg(feature = "substrait")]
Self::Substrait(bytes) => {
use lance_datafusion::exec::{LanceExecutionOptions, get_session_context};
use lance_datafusion::substrait::parse_substrait_aggregate;
let ctx = get_session_context(&LanceExecutionOptions::default());
parse_substrait_aggregate(&bytes, schema, &ctx.state())
.now_or_never()
.expect("could not parse the Substrait aggregate in a synchronous fashion")
}
Self::Datafusion {
group_by,
aggregates,
} => Ok(Aggregate::new(group_by, aggregates)),
}
}
}
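/// Typestate builder for [`AggregateExpr::Datafusion`].
///
/// The `HAS_PENDING` const parameter tracks whether the most recently added
/// aggregate can still receive an `.alias(..)`: the aggregate methods move
/// the builder to the `true` state, while `alias` and the `group_by` methods
/// return it to `false`.
///
/// A minimal sketch of a grouped count (column names are illustrative):
///
/// ```ignore
/// let agg = AggregateExpr::builder()
///     .group_by("category")
///     .count_star()
///     .alias("n")
///     .build();
/// scanner.aggregate(agg)?;
/// ```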
#[derive(Debug, Clone)]
pub struct AggregateExprBuilder<const HAS_PENDING: bool> {
group_by: Vec<Expr>,
aggregates: Vec<Expr>,
}
impl Default for AggregateExprBuilder<false> {
fn default() -> Self {
Self {
group_by: Vec::new(),
aggregates: Vec::new(),
}
}
}
impl AggregateExprBuilder<false> {
pub fn new() -> Self {
Self::default()
}
pub fn build(self) -> AggregateExpr {
AggregateExpr::Datafusion {
group_by: self.group_by,
aggregates: self.aggregates,
}
}
}
impl<const HAS_PENDING: bool> AggregateExprBuilder<HAS_PENDING> {
pub fn group_by(mut self, column: impl Into<String>) -> AggregateExprBuilder<false> {
self.group_by.push(col(column.into()));
AggregateExprBuilder {
group_by: self.group_by,
aggregates: self.aggregates,
}
}
pub fn group_by_columns(
mut self,
columns: impl IntoIterator<Item = impl Into<String>>,
) -> AggregateExprBuilder<false> {
for column in columns {
self.group_by.push(col(column.into()));
}
AggregateExprBuilder {
group_by: self.group_by,
aggregates: self.aggregates,
}
}
pub fn count_star(mut self) -> AggregateExprBuilder<true> {
self.aggregates
.push(functions_aggregate::count::count(lit(1)));
AggregateExprBuilder {
group_by: self.group_by,
aggregates: self.aggregates,
}
}
pub fn count(mut self, column: impl Into<String>) -> AggregateExprBuilder<true> {
self.aggregates
.push(functions_aggregate::count::count(col(column.into())));
AggregateExprBuilder {
group_by: self.group_by,
aggregates: self.aggregates,
}
}
pub fn sum(mut self, column: impl Into<String>) -> AggregateExprBuilder<true> {
self.aggregates
.push(functions_aggregate::sum::sum(col(column.into())));
AggregateExprBuilder {
group_by: self.group_by,
aggregates: self.aggregates,
}
}
pub fn avg(mut self, column: impl Into<String>) -> AggregateExprBuilder<true> {
self.aggregates
.push(functions_aggregate::average::avg(col(column.into())));
AggregateExprBuilder {
group_by: self.group_by,
aggregates: self.aggregates,
}
}
pub fn min(mut self, column: impl Into<String>) -> AggregateExprBuilder<true> {
self.aggregates
.push(functions_aggregate::min_max::min(col(column.into())));
AggregateExprBuilder {
group_by: self.group_by,
aggregates: self.aggregates,
}
}
pub fn max(mut self, column: impl Into<String>) -> AggregateExprBuilder<true> {
self.aggregates
.push(functions_aggregate::min_max::max(col(column.into())));
AggregateExprBuilder {
group_by: self.group_by,
aggregates: self.aggregates,
}
}
}
impl AggregateExprBuilder<true> {
pub fn alias(mut self, name: impl Into<String>) -> AggregateExprBuilder<false> {
let pending = self.aggregates.pop().expect("pending aggregate must exist");
self.aggregates.push(pending.alias(name.into()));
AggregateExprBuilder {
group_by: self.group_by,
aggregates: self.aggregates,
}
}
pub fn build(self) -> AggregateExpr {
AggregateExpr::Datafusion {
group_by: self.group_by,
aggregates: self.aggregates,
}
}
}
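/// Builder-style scanner over a [`Dataset`].
///
/// Configuration methods (`project`, `filter`, `limit`, `nearest`, ...)
/// mutate the scanner in place and return `&mut Self` so calls can be
/// chained; the scan itself runs through `try_into_stream`,
/// `try_into_batch`, or `count_rows`.
///
/// A minimal sketch (assuming an opened `Arc<Dataset>`; column names and
/// the filter are illustrative):
///
/// ```ignore
/// let mut scanner = Scanner::new(dataset.clone());
/// let batch = scanner
///     .project(&["id", "text"])?
///     .filter("id > 10")?
///     .limit(Some(100), None)?
///     .try_into_batch()
///     .await?;
/// ```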
#[derive(Clone)]
pub struct Scanner {
dataset: Arc<Dataset>,
projection_plan: ProjectionPlan,
blob_handling: BlobHandling,
prefilter: bool,
materialization_style: MaterializationStyle,
filter: LanceFilter,
full_text_query: Option<FullTextSearchQuery>,
batch_size: Option<usize>,
batch_readahead: usize,
fragment_readahead: Option<usize>,
io_buffer_size: Option<u64>,
limit: Option<i64>,
offset: Option<i64>,
ordering: Option<Vec<ColumnOrdering>>,
nearest: Option<Query>,
use_scalar_index: bool,
use_stats: bool,
ordered: bool,
fragments: Option<Vec<Fragment>>,
fast_search: bool,
include_deleted_rows: bool,
scan_stats_callback: Option<ExecutionStatsCallback>,
strict_batch_size: bool,
file_reader_options: Option<FileReaderOptions>,
aggregate: Option<Aggregate>,
legacy_with_row_id: bool,
legacy_with_row_addr: bool,
explicit_projection: bool,
autoproject_scoring_columns: bool,
}
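/// A filter that reduces to a direct row lookup.
///
/// `try_from_expr` recognizes predicates such as `_rowid = 42` or
/// `_rowaddr IN (...)` (including `AND`/`OR` combinations over a single key
/// type) so the scan can be planned as a take instead of a full read. For
/// an `AND`, any non-take remainder is returned alongside the operation to
/// be applied later as a refine filter.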
#[derive(Debug, Clone)]
pub enum TakeOperation {
RowIds(Vec<u64>),
RowAddrs(Vec<u64>),
RowOffsets(Vec<u64>),
}
impl TakeOperation {
fn extract_u64_list(list: &[Expr]) -> Option<Vec<u64>> {
let mut u64s = Vec::with_capacity(list.len());
for expr in list {
if let Expr::Literal(lit, _) = expr {
if let Some(ScalarValue::UInt64(Some(val))) =
safe_coerce_scalar(lit, &DataType::UInt64)
{
u64s.push(val);
} else {
return None;
}
} else {
return None;
}
}
Some(u64s)
}
fn merge(self, other: Self) -> Option<Self> {
match (self, other) {
(Self::RowIds(mut left), Self::RowIds(right)) => {
left.extend(right);
Some(Self::RowIds(left))
}
(Self::RowAddrs(mut left), Self::RowAddrs(right)) => {
left.extend(right);
Some(Self::RowAddrs(left))
}
(Self::RowOffsets(mut left), Self::RowOffsets(right)) => {
left.extend(right);
Some(Self::RowOffsets(left))
}
_ => None,
}
}
fn try_from_expr(expr: &Expr) -> Option<(Self, Option<Expr>)> {
if let Expr::BinaryExpr(binary) = expr {
match binary.op {
datafusion_expr::Operator::And => {
let left_take = Self::try_from_expr(&binary.left);
let right_take = Self::try_from_expr(&binary.right);
match (left_take, right_take) {
(Some(_), Some(_)) => {
return None;
}
(Some((left_op, left_rem)), None) => {
let remainder = match left_rem {
Some(expr) => Expr::and(expr, binary.right.as_ref().clone()),
None => binary.right.as_ref().clone(),
};
return Some((left_op, Some(remainder)));
}
(None, Some((right_op, right_rem))) => {
let remainder = match right_rem {
Some(expr) => Expr::and(expr, binary.left.as_ref().clone()),
None => binary.left.as_ref().clone(),
};
return Some((right_op, Some(remainder)));
}
(None, None) => {
return None;
}
}
}
datafusion_expr::Operator::Eq => {
if let (Expr::Column(col), Expr::Literal(lit, _)) =
(binary.left.as_ref(), binary.right.as_ref())
&& let Some(ScalarValue::UInt64(Some(val))) =
safe_coerce_scalar(lit, &DataType::UInt64)
{
if col.name == ROW_ID {
return Some((Self::RowIds(vec![val]), None));
} else if col.name == ROW_ADDR {
return Some((Self::RowAddrs(vec![val]), None));
} else if col.name == ROW_OFFSET {
return Some((Self::RowOffsets(vec![val]), None));
}
}
}
datafusion_expr::Operator::Or => {
let left_take = Self::try_from_expr(&binary.left);
let right_take = Self::try_from_expr(&binary.right);
if let (Some(left), Some(right)) = (left_take, right_take) {
if left.1.is_some() || right.1.is_some() {
return None;
}
return left.0.merge(right.0).map(|op| (op, None));
}
}
_ => {}
}
} else if let Expr::InList(in_expr) = expr
&& let Expr::Column(col) = in_expr.expr.as_ref()
&& let Some(u64s) = Self::extract_u64_list(&in_expr.list)
{
if col.name == ROW_ID {
return Some((Self::RowIds(u64s), None));
} else if col.name == ROW_ADDR {
return Some((Self::RowAddrs(u64s), None));
} else if col.name == ROW_OFFSET {
return Some((Self::RowOffsets(u64s), None));
}
}
None
}
}
impl Scanner {
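    /// Create a scanner over the full schema of `dataset`, inheriting the
    /// dataset's file reader options.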
pub fn new(dataset: Arc<Dataset>) -> Self {
let projection_plan = ProjectionPlan::full(dataset.clone()).unwrap();
let file_reader_options = dataset.file_reader_options.clone();
let mut scanner = Self {
dataset,
projection_plan,
blob_handling: BlobHandling::default(),
prefilter: false,
materialization_style: MaterializationStyle::Heuristic,
filter: LanceFilter::default(),
full_text_query: None,
batch_size: None,
batch_readahead: get_num_compute_intensive_cpus(),
fragment_readahead: None,
io_buffer_size: None,
limit: None,
offset: None,
ordering: None,
nearest: None,
use_stats: true,
ordered: true,
fragments: None,
fast_search: false,
use_scalar_index: true,
include_deleted_rows: false,
scan_stats_callback: None,
strict_batch_size: false,
file_reader_options,
aggregate: None,
legacy_with_row_addr: false,
legacy_with_row_id: false,
explicit_projection: false,
autoproject_scoring_columns: true,
};
scanner.apply_blob_handling();
scanner
}
fn apply_blob_handling(&mut self) {
let projection = self
.projection_plan
.physical_projection
.clone()
.with_blob_handling(self.blob_handling.clone());
self.projection_plan.physical_projection = projection;
}
pub fn blob_handling(&mut self, blob_handling: BlobHandling) -> &mut Self {
self.blob_handling = blob_handling;
self.apply_blob_handling();
self
}
pub fn from_fragment(dataset: Arc<Dataset>, fragment: Fragment) -> Self {
Self {
fragments: Some(vec![fragment]),
..Self::new(dataset)
}
}
pub fn with_fragments(&mut self, fragments: Vec<Fragment>) -> &mut Self {
self.fragments = Some(fragments);
self
}
fn get_batch_size(&self) -> usize {
get_default_batch_size().unwrap_or_else(|| {
self.batch_size.unwrap_or_else(|| {
std::cmp::max(
self.dataset.object_store().block_size() / 4,
BATCH_SIZE_FALLBACK,
)
})
})
}
fn ensure_not_fragment_scan(&self) -> Result<()> {
if self.is_fragment_scan() {
Err(Error::not_supported(
"This operation is not supported for fragment scan".to_string(),
))
} else {
Ok(())
}
}
fn is_fragment_scan(&self) -> bool {
self.fragments.is_some()
}
pub fn empty_project(&mut self) -> Result<&mut Self> {
self.project(&[] as &[&str])
}
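    /// Restrict the scan output to the given columns.
    ///
    /// Each name is escaped with `escape_field_path_for_project`, so dotted
    /// paths address nested fields. A sketch (column names are
    /// illustrative):
    ///
    /// ```ignore
    /// scanner.project(&["id", "metadata.author"])?;
    /// ```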
pub fn project<T: AsRef<str>>(&mut self, columns: &[T]) -> Result<&mut Self> {
let transformed_columns: Vec<(&str, String)> = columns
.iter()
.map(|c| (c.as_ref(), escape_field_path_for_project(c.as_ref())))
.collect();
self.project_with_transform(&transformed_columns)
}
pub fn project_with_transform(
&mut self,
columns: &[(impl AsRef<str>, impl AsRef<str>)],
) -> Result<&mut Self> {
self.explicit_projection = true;
self.projection_plan = ProjectionPlan::from_expressions(self.dataset.clone(), columns)?;
if self.legacy_with_row_id {
self.projection_plan.include_row_id();
}
if self.legacy_with_row_addr {
self.projection_plan.include_row_addr();
}
self.apply_blob_handling();
Ok(self)
}
pub fn prefilter(&mut self, should_prefilter: bool) -> &mut Self {
self.prefilter = should_prefilter;
self
}
pub fn scan_stats_callback(&mut self, callback: ExecutionStatsCallback) -> &mut Self {
self.scan_stats_callback = Some(callback);
self
}
pub fn materialization_style(&mut self, style: MaterializationStyle) -> &mut Self {
self.materialization_style = style;
self
}
pub fn filter(&mut self, filter: &str) -> Result<&mut Self> {
self.filter.expr_filter = Some(ExprFilter::Sql(filter.to_string()));
Ok(self)
}
pub fn filter_query(&mut self, filter: QueryFilter) -> Result<&mut Self> {
self.filter.query_filter = Some(filter);
Ok(self)
}
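    /// Set a full text search query.
    ///
    /// Any columns named in the query must exist in the dataset schema;
    /// when no columns are given, the planner later fills in every
    /// FTS-indexed column. Scores are exposed through the `SCORE_COL`
    /// output column.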
pub fn full_text_search(&mut self, query: FullTextSearchQuery) -> Result<&mut Self> {
let fields = query.columns();
if !fields.is_empty() {
for field in fields.iter() {
if self.dataset.schema().field(field).is_none() {
return Err(Error::invalid_input(format!("Column {} not found", field)));
}
}
}
self.full_text_query = Some(query);
Ok(self)
}
pub fn filter_substrait(&mut self, filter: &[u8]) -> Result<&mut Self> {
self.filter.expr_filter = Some(ExprFilter::Substrait(filter.to_vec()));
Ok(self)
}
pub fn filter_expr(&mut self, filter: Expr) -> &mut Self {
self.filter.expr_filter = Some(ExprFilter::Datafusion(filter));
self
}
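    /// Replace the scan output with an aggregation.
    ///
    /// The expression is parsed eagerly against the dataset schema, so
    /// invalid aggregates fail here rather than at execution time. It
    /// cannot be combined with `limit`/`offset` or `order_by`; apply those
    /// to the aggregated result instead.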
pub fn aggregate(&mut self, aggregate: AggregateExpr) -> Result<&mut Self> {
let schema: Arc<ArrowSchema> = Arc::new(self.dataset.schema().into());
let parsed = aggregate.parse(schema)?;
self.aggregate = Some(parsed);
Ok(self)
}
pub fn batch_size(&mut self, batch_size: usize) -> &mut Self {
self.batch_size = Some(batch_size);
self
}
pub fn include_deleted_rows(&mut self) -> &mut Self {
self.include_deleted_rows = true;
self
}
pub fn io_buffer_size(&mut self, size: u64) -> &mut Self {
self.io_buffer_size = Some(size);
self
}
pub fn batch_readahead(&mut self, nbatches: usize) -> &mut Self {
self.batch_readahead = nbatches;
self
}
pub fn fragment_readahead(&mut self, nfragments: usize) -> &mut Self {
self.fragment_readahead = Some(nfragments);
self
}
pub fn scan_in_order(&mut self, ordered: bool) -> &mut Self {
self.ordered = ordered;
self
}
pub fn use_scalar_index(&mut self, use_scalar_index: bool) -> &mut Self {
self.use_scalar_index = use_scalar_index;
self
}
pub fn strict_batch_size(&mut self, strict_batch_size: bool) -> &mut Self {
self.strict_batch_size = strict_batch_size;
self
}
pub fn limit(&mut self, limit: Option<i64>, offset: Option<i64>) -> Result<&mut Self> {
if limit.unwrap_or_default() < 0 {
return Err(Error::invalid_input(
"Limit must be non-negative".to_string(),
));
}
if let Some(off) = offset
&& off < 0
{
return Err(Error::invalid_input(
"Offset must be non-negative".to_string(),
));
}
self.limit = limit;
self.offset = offset;
Ok(self)
}
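    /// Configure a k-nearest-neighbor vector search on `column`.
    ///
    /// `k` must be positive and the query vector non-empty. The query
    /// length must match the column's vector dimension (for multivector
    /// columns every sub-vector is checked), and the query's element type
    /// is coerced to the column's float type when they differ. A sketch,
    /// assuming a 128-dimensional float32 vector column named "embedding":
    ///
    /// ```ignore
    /// let q = Float32Array::from(vec![0.0f32; 128]);
    /// scanner.nearest("embedding", &q, 10)?;
    /// ```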
pub fn nearest(&mut self, column: &str, q: &dyn Array, k: usize) -> Result<&mut Self> {
if !self.prefilter {
self.ensure_not_fragment_scan()?;
}
if k == 0 {
return Err(Error::invalid_input("k must be positive".to_string()));
}
if q.is_empty() {
return Err(Error::invalid_input(
"Query vector must have non-zero length".to_string(),
));
}
let (vector_type, element_type) = get_vector_type(self.dataset.schema(), column)?;
let dim = get_vector_dim(self.dataset.schema(), column)?;
let q = match q.data_type() {
DataType::List(_) | DataType::FixedSizeList(_, _) => {
if !matches!(vector_type, DataType::List(_)) {
return Err(Error::invalid_input(format!(
"Query is multivector but column {}({})is not multivector",
column, vector_type,
)));
}
if let Some(list_array) = q.as_list_opt::<i32>() {
for i in 0..list_array.len() {
let vec = list_array.value(i);
if vec.len() != dim {
return Err(Error::invalid_input(format!(
"query dim({}) doesn't match the column {} vector dim({})",
vec.len(),
column,
dim,
)));
}
}
list_array.values().clone()
} else {
let fsl = q.as_fixed_size_list();
if fsl.value_length() as usize != dim {
return Err(Error::invalid_input(format!(
"query dim({}) doesn't match the column {} vector dim({})",
fsl.value_length(),
column,
dim,
)));
}
fsl.values().clone()
}
}
_ => {
if q.len() != dim {
return Err(Error::invalid_input(format!(
"query dim({}) doesn't match the column {} vector dim({})",
q.len(),
column,
dim,
)));
}
q.slice(0, q.len())
}
};
let key = match &element_type {
dt if dt == q.data_type() => q,
dt if dt.is_floating() => coerce_float_vector(
q.as_any().downcast_ref::<Float32Array>().unwrap(),
FloatType::try_from(dt)?,
)?,
_ => {
return Err(Error::invalid_input(format!(
"Column {} has element type {} and the query vector is {}",
column,
element_type,
q.data_type(),
)));
}
};
self.nearest = Some(Query {
column: column.to_string(),
key,
k,
lower_bound: None,
upper_bound: None,
minimum_nprobes: 1,
maximum_nprobes: None,
ef: None,
refine_factor: None,
metric_type: None,
use_index: true,
dist_q_c: 0.0,
});
Ok(self)
}
#[cfg(test)]
fn nearest_mut(&mut self) -> Option<&mut Query> {
self.nearest.as_mut()
}
pub fn distance_range(
&mut self,
lower_bound: Option<f32>,
upper_bound: Option<f32>,
) -> &mut Self {
if let Some(q) = self.nearest.as_mut() {
q.lower_bound = lower_bound;
q.upper_bound = upper_bound;
}
self
}
pub fn nprobes(&mut self, n: usize) -> &mut Self {
if let Some(q) = self.nearest.as_mut() {
q.minimum_nprobes = n;
q.maximum_nprobes = Some(n);
} else {
log::warn!("nprobes is not set because nearest has not been called yet");
}
self
}
#[deprecated(note = "Use nprobes instead")]
pub fn nprobs(&mut self, n: usize) -> &mut Self {
if let Some(q) = self.nearest.as_mut() {
q.minimum_nprobes = n;
q.maximum_nprobes = Some(n);
} else {
log::warn!("nprobes is not set because nearest has not been called yet");
}
self
}
pub fn minimum_nprobes(&mut self, n: usize) -> &mut Self {
if let Some(q) = self.nearest.as_mut() {
q.minimum_nprobes = n;
} else {
log::warn!("minimum_nprobes is not set because nearest has not been called yet");
}
self
}
pub fn maximum_nprobes(&mut self, n: usize) -> &mut Self {
if let Some(q) = self.nearest.as_mut() {
q.maximum_nprobes = Some(n);
} else {
log::warn!("maximum_nprobes is not set because nearest has not been called yet");
}
self
}
pub fn ef(&mut self, ef: usize) -> &mut Self {
if let Some(q) = self.nearest.as_mut() {
q.ef = Some(ef);
}
self
}
pub fn fast_search(&mut self) -> &mut Self {
if let Some(q) = self.nearest.as_mut() {
q.use_index = true;
}
self.fast_search = true;
        self.projection_plan.include_row_id();
        self
}
pub fn refine(&mut self, factor: u32) -> &mut Self {
        if let Some(q) = self.nearest.as_mut() {
            q.refine_factor = Some(factor);
        }
self
}
pub fn distance_metric(&mut self, metric_type: MetricType) -> &mut Self {
if let Some(q) = self.nearest.as_mut() {
q.metric_type = Some(metric_type)
}
self
}
pub fn order_by(&mut self, ordering: Option<Vec<ColumnOrdering>>) -> Result<&mut Self> {
if let Some(ordering) = &ordering {
if ordering.is_empty() {
self.ordering = None;
return Ok(self);
}
for column in ordering {
self.dataset
.schema()
.field(&column.column_name)
.ok_or(Error::invalid_input(format!(
"Column {} not found",
&column.column_name
)))?;
}
}
self.ordering = ordering;
Ok(self)
}
pub fn use_index(&mut self, use_index: bool) -> &mut Self {
if let Some(q) = self.nearest.as_mut() {
q.use_index = use_index
}
self
}
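    /// Include the row id column in the output. For backwards
    /// compatibility this column is moved to the end of the projection
    /// (see `calculate_final_projection`).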
pub fn with_row_id(&mut self) -> &mut Self {
self.legacy_with_row_id = true;
self.projection_plan.include_row_id();
self
}
pub fn with_row_address(&mut self) -> &mut Self {
self.legacy_with_row_addr = true;
self.projection_plan.include_row_addr();
self
}
pub fn disable_scoring_autoprojection(&mut self) -> &mut Self {
self.autoproject_scoring_columns = false;
self
}
pub fn with_file_reader_options(&mut self, options: FileReaderOptions) -> &mut Self {
self.file_reader_options = Some(options);
self
}
fn create_column_expr(
column_name: &str,
dataset: &Dataset,
arrow_schema: &ArrowSchema,
) -> Result<Arc<dyn PhysicalExpr>> {
let lance_schema = dataset.schema();
let field_path = lance_schema
.resolve_case_insensitive(column_name)
.ok_or_else(|| {
Error::invalid_input(format!("Field '{}' not found in schema", column_name))
})?;
if field_path.len() == 1 {
expressions::col(&field_path[0].name, arrow_schema).map_err(|e| {
Error::internal(format!(
"Failed to create column expression for '{}': {}",
column_name, e
))
})
} else {
let get_field_func = ScalarUDF::from(GetFieldFunc::default());
let mut expr = Expr::Column(datafusion::common::Column::new_unqualified(
&field_path[0].name,
));
for nested_field in &field_path[1..] {
expr = get_field_func.call(vec![expr, lit(&nested_field.name)]);
}
let df_schema = Arc::new(DFSchema::try_from(arrow_schema.clone())?);
let execution_props = ExecutionProps::new().with_query_execution_start_time(Utc::now());
create_physical_expr(&expr, &df_schema, &execution_props).map_err(|e| {
Error::internal(format!(
"Failed to create physical expression for nested field '{}': {}",
column_name, e
))
})
}
}
pub fn use_stats(&mut self, use_stats: bool) -> &mut Self {
self.use_stats = use_stats;
self
}
pub async fn schema(&self) -> Result<SchemaRef> {
let plan = self.create_plan().await?;
Ok(plan.schema())
}
pub fn get_expr_filter(&self) -> Result<Option<Expr>> {
if let Some(filter) = &self.filter.expr_filter {
let filter_schema = self.filterable_schema()?;
Ok(Some(filter.to_datafusion(
self.dataset.schema(),
filter_schema.as_ref(),
)?))
} else {
Ok(None)
}
}
fn add_extra_columns(&self, schema: Schema) -> Result<Schema> {
let mut extra_columns = vec![ArrowField::new(ROW_OFFSET, DataType::UInt64, true)];
        if self.nearest.is_some() {
            extra_columns.push(ArrowField::new(DIST_COL, DataType::Float32, true));
        }
if self.full_text_query.is_some() {
extra_columns.push(ArrowField::new(SCORE_COL, DataType::Float32, true));
}
schema.merge(&ArrowSchema::new(extra_columns))
}
fn filterable_schema(&self) -> Result<Arc<Schema>> {
let base_schema = Projection::full(self.dataset.clone())
.with_row_id()
.with_row_addr()
.with_row_last_updated_at_version()
.with_row_created_at_version()
.to_schema();
Ok(Arc::new(self.add_extra_columns(base_schema)?))
}
pub(crate) fn calculate_final_projection(
&self,
current_schema: &ArrowSchema,
) -> Result<Vec<(Arc<dyn PhysicalExpr>, String)>> {
let mut output_expr = self.projection_plan.to_physical_exprs(current_schema)?;
if self.autoproject_scoring_columns {
if self.nearest.is_some() && output_expr.iter().all(|(_, name)| name != DIST_COL) {
if self.explicit_projection {
log::warn!(
"Deprecation warning, this behavior will change in the future. This search specified output columns but did not include `_distance`. Currently the `_distance` column will be included. In the future it will not. Call `disable_scoring_autoprojection` to adopt the future behavior and avoid this warning"
);
}
let vector_expr = expressions::col(DIST_COL, current_schema)?;
output_expr.push((vector_expr, DIST_COL.to_string()));
}
if self.full_text_query.is_some()
&& output_expr.iter().all(|(_, name)| name != SCORE_COL)
{
if self.explicit_projection {
log::warn!(
"Deprecation warning, this behavior will change in the future. This search specified output columns but did not include `_score`. Currently the `_score` column will be included. In the future it will not. Call `disable_scoring_autoprojection` to adopt the future behavior and avoid this warning"
);
}
let score_expr = expressions::col(SCORE_COL, current_schema)?;
output_expr.push((score_expr, SCORE_COL.to_string()));
}
}
if self.legacy_with_row_id {
let row_id_pos = output_expr
.iter()
.position(|(_, name)| name == ROW_ID)
.ok_or_else(|| {
Error::internal(
"user specified with_row_id but the _rowid column was not in the output"
.to_string(),
)
})?;
if row_id_pos != output_expr.len() - 1 {
let row_id_expr = output_expr.remove(row_id_pos);
output_expr.push(row_id_expr);
}
}
if self.legacy_with_row_addr {
            let row_addr_pos = output_expr
                .iter()
                .position(|(_, name)| name == ROW_ADDR)
                .ok_or_else(|| {
                    Error::internal(
                        "user specified with_row_address but the _rowaddr column was not in the output"
                            .to_string(),
                    )
                })?;
if row_addr_pos != output_expr.len() - 1 {
let row_addr_expr = output_expr.remove(row_addr_pos);
output_expr.push(row_addr_expr);
}
}
Ok(output_expr)
}
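    /// Build the execution plan and run it, returning the results as a
    /// stream of record batches. The scanner's batch size and execution
    /// stats callback are passed through to the execution options.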
#[instrument(skip_all)]
pub fn try_into_stream(&self) -> BoxFuture<'_, Result<DatasetRecordBatchStream>> {
async move {
let plan = self.create_plan().await?;
Ok(DatasetRecordBatchStream::new(execute_plan(
plan,
LanceExecutionOptions {
batch_size: self.batch_size,
execution_stats_callback: self.scan_stats_callback.clone(),
..Default::default()
},
)?))
}
.boxed()
}
pub(crate) async fn try_into_dfstream(
&self,
mut options: LanceExecutionOptions,
) -> Result<SendableRecordBatchStream> {
let plan = self.create_plan().await?;
if options.execution_stats_callback.is_none() {
options.execution_stats_callback = self.scan_stats_callback.clone();
}
execute_plan(plan, options)
}
pub async fn try_into_batch(&self) -> Result<RecordBatch> {
let stream = self.try_into_stream().await?;
let schema = stream.schema();
let batches = stream.try_collect::<Vec<_>>().await?;
Ok(concat_batches(&schema, &batches)?)
}
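    /// Count the rows selected by the current scanner configuration by
    /// cloning the scanner, attaching a `count(*)` aggregate, and executing
    /// the resulting plan.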
#[instrument(skip_all)]
pub fn count_rows(&self) -> BoxFuture<'_, Result<u64>> {
async move {
let mut scanner = self.clone();
scanner.aggregate(AggregateExpr::builder().count_star().build())?;
let plan = scanner.create_plan().await?;
let mut stream = execute_plan(plan, LanceExecutionOptions::default())?;
if let Some(first_batch) = stream.next().await {
let batch = first_batch?;
let array = batch
.column(0)
.as_any()
.downcast_ref::<Int64Array>()
.ok_or(Error::invalid_input(
"Count plan did not return an Int64Array".to_string(),
))?;
Ok(array.value(0) as u64)
} else {
Ok(0)
}
}
.boxed()
}
#[deprecated(note = "Use create_plan() instead, which now applies aggregate automatically")]
pub fn create_aggregate_plan(&self) -> BoxFuture<'_, Result<Arc<dyn ExecutionPlan>>> {
async move {
if self.aggregate.is_none() {
return Err(Error::invalid_input(
"create_aggregate_plan called but no aggregate was set",
));
}
self.create_plan().await
}
.boxed()
}
async fn apply_aggregate(
&self,
plan: Arc<dyn ExecutionPlan>,
agg: &Aggregate,
) -> Result<Arc<dyn ExecutionPlan>> {
use datafusion_physical_expr::aggregate::AggregateFunctionExpr;
let schema = plan.schema();
let df_schema = DFSchema::try_from(schema.as_ref().clone())?;
let group_exprs: Vec<(Arc<dyn PhysicalExpr>, String)> = agg
.group_by
.iter()
.map(|expr| {
let name = expr.schema_name().to_string();
let physical_expr =
create_physical_expr(expr, &df_schema, &ExecutionProps::default())?;
Ok((physical_expr, name))
})
.collect::<Result<_>>()?;
#[allow(clippy::type_complexity)]
let aggr_results: Vec<(Arc<AggregateFunctionExpr>, Option<Arc<dyn PhysicalExpr>>)> = agg
.aggregates
.iter()
.map(|expr| self.build_physical_aggregate_expr(expr, &df_schema, &schema))
.collect::<Result<_>>()?;
let (aggr_exprs, filters): (Vec<_>, Vec<_>) = aggr_results.into_iter().unzip();
Ok(Arc::new(AggregateExec::try_new(
AggregateMode::Single,
PhysicalGroupBy::new_single(group_exprs),
aggr_exprs,
filters,
plan,
schema,
)?) as Arc<dyn ExecutionPlan>)
}
#[allow(clippy::type_complexity)]
fn build_physical_aggregate_expr(
&self,
expr: &Expr,
df_schema: &DFSchema,
input_schema: &SchemaRef,
) -> Result<(
Arc<datafusion_physical_expr::aggregate::AggregateFunctionExpr>,
Option<Arc<dyn PhysicalExpr>>,
)> {
use datafusion::physical_planner::create_aggregate_expr_and_maybe_filter;
let coerced_expr = self.coerce_aggregate_expr(expr, df_schema)?;
let (agg_expr, filter, _order_by) = create_aggregate_expr_and_maybe_filter(
&coerced_expr,
df_schema,
input_schema.as_ref(),
&ExecutionProps::default(),
)?;
Ok((agg_expr, filter))
}
fn coerce_aggregate_expr(&self, expr: &Expr, schema: &DFSchema) -> Result<Expr> {
Self::coerce_aggregate_expr_impl(expr, schema)
}
fn coerce_aggregate_expr_impl(expr: &Expr, schema: &DFSchema) -> Result<Expr> {
use datafusion::logical_expr::Expr;
use datafusion::logical_expr::expr::AggregateFunction;
use datafusion::logical_expr::type_coercion::functions::fields_with_udf;
match expr {
Expr::AggregateFunction(agg_func) => {
let func = &agg_func.func;
let args = &agg_func.params.args;
if args.is_empty() {
return Ok(expr.clone());
}
let current_fields: Vec<arrow_schema::FieldRef> = args
.iter()
.enumerate()
.map(|(i, e)| {
let dt = e.get_type(schema)?;
Ok(Arc::new(arrow_schema::Field::new(
format!("arg_{i}"),
dt,
true,
)))
})
.collect::<std::result::Result<_, datafusion::common::DataFusionError>>()?;
                let coerced_fields = fields_with_udf(&current_fields, func.as_ref())?;
let coerced_args: Vec<Expr> = args
.iter()
.zip(coerced_fields.iter())
.map(|(arg, target_field)| {
let arg_type = arg.get_type(schema)?;
let target_type = target_field.data_type();
if arg_type == *target_type {
Ok(arg.clone())
} else {
arg.clone().cast_to(target_type, schema)
}
})
.collect::<std::result::Result<_, _>>()?;
Ok(Expr::AggregateFunction(AggregateFunction::new_udf(
func.clone(),
coerced_args,
agg_func.params.distinct,
agg_func.params.filter.clone(),
agg_func.params.order_by.clone(),
agg_func.params.null_treatment,
)))
}
Expr::Alias(alias) => {
let coerced_inner = Self::coerce_aggregate_expr_impl(&alias.expr, schema)?;
Ok(coerced_inner.alias(&alias.name))
}
other => Err(Error::invalid_input(format!(
"Expected aggregate function expression, got {:?}",
other.variant_name()
))),
}
}
fn is_early_field(&self, field: &Field) -> bool {
match self.materialization_style {
MaterializationStyle::AllEarly => true,
MaterializationStyle::AllLate => false,
MaterializationStyle::AllEarlyExcept(ref cols) => !cols.contains(&(field.id as u32)),
MaterializationStyle::Heuristic => {
if field.is_blob() {
return true;
}
let byte_width = field.data_type().byte_width_opt();
let is_cloud = self.dataset.object_store().is_cloud();
if is_cloud {
byte_width.is_some_and(|bw| bw < 1000)
} else {
byte_width.is_some_and(|bw| bw < 10)
}
}
}
}
fn calc_eager_projection(
&self,
filter_plan: &ExprFilterPlan,
desired_projection: &Projection,
) -> Result<Projection> {
let filter_columns = filter_plan.all_columns();
let filter_schema = self
.dataset
.empty_projection()
.union_columns(filter_columns, OnMissing::Error)?
.into_schema();
Ok(desired_projection
.clone()
.subtract_predicate(|f| !self.is_early_field(f))
.union_schema(&filter_schema))
}
fn validate_options(&self) -> Result<()> {
if self.include_deleted_rows && !self.projection_plan.physical_projection.with_row_id {
return Err(Error::invalid_input_source(
"include_deleted_rows is set but with_row_id is false".into(),
));
}
if self.aggregate.is_some() {
if self.limit.is_some() || self.offset.is_some() {
return Err(Error::invalid_input_source(
"Cannot use limit/offset with aggregate. Apply limit to the result instead."
.into(),
));
}
if self.ordering.is_some() {
return Err(Error::invalid_input_source(
"Cannot use order_by with aggregate. Apply ordering to the result instead."
.into(),
));
}
}
Ok(())
}
async fn create_filter_plan(&self, use_scalar_index: bool) -> Result<FilterPlan> {
let filter_schema = self.filterable_schema()?;
let planner = Planner::new(Arc::new(filter_schema.as_ref().into()));
let filter_plan = if let Some(filter) = self.filter.expr_filter.as_ref() {
let expr = filter.to_datafusion(self.dataset.schema(), filter_schema.as_ref())?;
let index_info = self.dataset.scalar_index_info().await?;
let filter_plan =
planner.create_filter_plan(expr.clone(), &index_info, use_scalar_index)?;
if filter_plan.index_query.is_some() {
let fragments = if let Some(fragments) = self.fragments.as_ref() {
fragments
} else {
self.dataset.fragments()
};
                // Scalar-indexed scans need per-fragment row counts; if any
                // fragment is missing them, fall back to an unindexed filter plan.
                let has_missing_row_count =
                    fragments.iter().any(|frag| frag.physical_rows.is_none());
if has_missing_row_count {
let filter_plan =
planner.create_filter_plan(expr.clone(), &index_info, false)?;
FilterPlan::new(self.filter.query_filter.clone(), filter_plan)
} else {
FilterPlan::new(self.filter.query_filter.clone(), filter_plan)
}
} else {
FilterPlan::new(self.filter.query_filter.clone(), filter_plan)
}
} else {
FilterPlan::new(self.filter.query_filter.clone(), ExprFilterPlan::default())
};
if filter_plan.query_filter.is_some()
&& self.nearest.is_none()
&& self.full_text_query.is_none()
{
return Err(Error::invalid_input_source(
"Query filter can only be used with full text search or vector search".into(),
));
}
if self.nearest.is_some() && filter_plan.vector_filter().is_some() {
return Err(Error::invalid_input_source(
"Query filter can't be used with vector search".into(),
));
}
if self.full_text_query.is_some() && filter_plan.fts_filter().is_some() {
return Err(Error::invalid_input_source(
"Fts filter can't be used with fts search".into(),
));
}
Ok(filter_plan)
}
async fn get_scan_range(&self, filter_plan: &ExprFilterPlan) -> Result<Option<Range<u64>>> {
if filter_plan.has_any_filter() {
Ok(None)
} else if self.ordering.is_some() {
Ok(None)
} else {
match (self.limit, self.offset) {
(None, None) => Ok(None),
(Some(limit), None) => {
let num_rows = self.dataset.count_all_rows().await? as i64;
Ok(Some(0..limit.min(num_rows) as u64))
}
(None, Some(offset)) => {
let num_rows = self.dataset.count_all_rows().await? as i64;
Ok(Some(offset.min(num_rows) as u64..num_rows as u64))
}
(Some(limit), Some(offset)) => {
let num_rows = self.dataset.count_all_rows().await? as i64;
Ok(Some(
offset.min(num_rows) as u64..(offset + limit).min(num_rows) as u64,
))
}
}
}
}
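    /// Build the physical execution plan for this scan.
    ///
    /// In rough order: pick a source (vector search, FTS search, a direct
    /// take when the filter is a pure row id/address/offset lookup, or a
    /// filtered read), take any columns needed by refine filters and
    /// ordering, apply the refine filter, then aggregation or sorting,
    /// limit/offset, the final projection, and finally the physical
    /// optimizer rules.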
#[instrument(level = "debug", skip_all)]
pub async fn create_plan(&self) -> Result<Arc<dyn ExecutionPlan>> {
log::trace!("creating scanner plan");
self.validate_options()?;
let use_scalar_index = self.use_scalar_index && (self.prefilter || self.nearest.is_none());
let mut filter_plan = self.create_filter_plan(use_scalar_index).await?;
let mut use_limit_node = true;
let mut plan: Arc<dyn ExecutionPlan> = match (&self.nearest, &self.full_text_query) {
(Some(_), None) => self.vector_search_source(&mut filter_plan).await?,
(None, Some(query)) => self.fts_search_source(&mut filter_plan, query).await?,
(None, None) => {
if self.projection_plan.has_output_cols()
&& self.projection_plan.physical_projection.is_empty()
{
let output_expr = self.calculate_final_projection(&ArrowSchema::empty())?;
                    return Err(Error::not_supported_source(
                        format!(
                            "Scans must request at least one column. Received only dynamic expressions: {:?}",
                            output_expr
                        )
                        .into(),
                    ));
}
let take_op = filter_plan
.expr_filter_plan
.full_expr
.as_ref()
.and_then(TakeOperation::try_from_expr);
if let Some((take_op, remainder)) = take_op {
filter_plan.expr_filter_plan = remainder
.map(ExprFilterPlan::new_refine_only)
.unwrap_or(ExprFilterPlan::default());
self.take_source(take_op).await?
} else {
let planned_read = self
.filtered_read_source(&mut filter_plan.expr_filter_plan)
.await?;
if planned_read.limit_pushed_down {
use_limit_node = false;
}
if planned_read.filter_pushed_down {
filter_plan.disable_refine();
}
planned_read.plan
}
}
_ => {
return Err(Error::invalid_input_source(
"Cannot have both nearest and full text search".into(),
));
}
};
let mut pre_filter_projection = self.dataset.empty_projection();
if filter_plan.has_refine() {
pre_filter_projection = pre_filter_projection.union_columns(
filter_plan.refine_columns(&self.dataset).await?,
OnMissing::Ignore,
)?;
}
if let Some(ordering) = &self.ordering {
pre_filter_projection = pre_filter_projection.union_columns(
ordering.iter().map(|col| &col.column_name),
OnMissing::Error,
)?;
}
plan = self.take(plan, pre_filter_projection)?;
plan = filter_plan.refine_filter(plan, self).await?;
if let Some(agg) = &self.aggregate {
let required_columns = agg.required_columns();
let agg_projection = if required_columns.is_empty() {
self.dataset.empty_projection()
} else {
self.dataset
.empty_projection()
.union_columns(&required_columns, OnMissing::Error)?
};
plan = self.take(plan, agg_projection)?;
plan = self.apply_aggregate(plan, agg).await?;
let optimizer = get_physical_optimizer();
let options = Default::default();
for rule in optimizer.rules {
plan = rule.optimize(plan, &options)?;
}
return Ok(plan);
}
if let Some(ordering) = &self.ordering {
let ordering_columns = ordering.iter().map(|col| &col.column_name);
let projection_with_ordering = self
.dataset
.empty_projection()
.union_columns(ordering_columns, OnMissing::Error)?;
plan = self.take(plan, projection_with_ordering)?;
let col_exprs = ordering
.iter()
.map(|col| {
Ok(PhysicalSortExpr {
expr: Self::create_column_expr(
&col.column_name,
&self.dataset,
plan.schema().as_ref(),
)?,
options: SortOptions {
descending: !col.ascending,
nulls_first: col.nulls_first,
},
})
})
.collect::<Result<Vec<_>>>()?;
plan = Arc::new(SortExec::new(
LexOrdering::new(col_exprs)
.ok_or(exec_datafusion_err!("Unexpected empty sort expressions"))?,
plan,
));
}
if use_limit_node && (self.limit.unwrap_or(0) > 0 || self.offset.is_some()) {
plan = self.limit_node(plan);
}
plan = self.take(plan, self.projection_plan.physical_projection.clone())?;
if self.projection_plan.must_add_row_offset {
plan = Arc::new(AddRowOffsetExec::try_new(plan, self.dataset.clone()).await?);
}
let final_projection = self.calculate_final_projection(plan.schema().as_ref())?;
plan = Arc::new(DFProjectionExec::try_new(final_projection, plan)?);
if self.strict_batch_size {
plan = Arc::new(StrictBatchSizeExec::new(plan, self.get_batch_size()));
}
let optimizer = get_physical_optimizer();
let options = Default::default();
for rule in optimizer.rules {
plan = rule.optimize(plan, &options)?;
}
Ok(plan)
}
fn filter_references_version_columns(&self, filter_plan: &ExprFilterPlan) -> bool {
use lance_core::{ROW_CREATED_AT_VERSION, ROW_LAST_UPDATED_AT_VERSION};
if let Some(refine_expr) = &filter_plan.refine_expr {
let column_names = Planner::column_names_in_expr(refine_expr);
for col_name in column_names {
if col_name == ROW_CREATED_AT_VERSION || col_name == ROW_LAST_UPDATED_AT_VERSION {
return true;
}
}
}
false
}
async fn legacy_filtered_read(
&self,
filter_plan: &ExprFilterPlan,
projection: Projection,
make_deletions_null: bool,
fragments: Option<Arc<Vec<Fragment>>>,
scan_range: Option<Range<u64>>,
is_prefilter: bool,
) -> Result<PlannedFilteredScan> {
let fragments = fragments.unwrap_or(self.dataset.fragments().clone());
let mut filter_pushed_down = false;
let plan: Arc<dyn ExecutionPlan> = if filter_plan.has_index_query() {
if self.include_deleted_rows {
return Err(Error::invalid_input_source(
"Cannot include deleted rows in a scalar indexed scan".into(),
));
}
self.scalar_indexed_scan(projection, filter_plan, fragments)
.await
} else if !is_prefilter
&& filter_plan.has_refine()
&& self.batch_size.is_none()
&& self.use_stats
&& !self.filter_references_version_columns(filter_plan)
{
filter_pushed_down = true;
self.pushdown_scan(false, filter_plan)
} else {
let ordered = if self.ordering.is_some() || self.nearest.is_some() {
false
} else if projection.with_row_last_updated_at_version
|| projection.with_row_created_at_version
{
true
} else {
self.ordered
};
let projection = if let Some(refine_expr) = filter_plan.refine_expr.as_ref() {
if is_prefilter {
let refine_cols = Planner::column_names_in_expr(refine_expr);
projection.union_columns(refine_cols, OnMissing::Error)?
} else {
projection
}
} else {
projection
};
let scan_range = if filter_plan.has_refine() {
None
} else {
scan_range
};
let scan = self.scan_fragments(
projection.with_row_id,
self.projection_plan.physical_projection.with_row_addr,
self.projection_plan
.physical_projection
.with_row_last_updated_at_version,
self.projection_plan
.physical_projection
.with_row_created_at_version,
make_deletions_null,
Arc::new(projection.to_bare_schema()),
fragments,
scan_range,
ordered,
);
if filter_plan.has_refine() && is_prefilter {
Ok(Arc::new(LanceFilterExec::try_new(
filter_plan.refine_expr.clone().unwrap(),
scan,
)?) as Arc<dyn ExecutionPlan>)
} else {
Ok(scan)
}
}?;
Ok(PlannedFilteredScan {
plan,
limit_pushed_down: false,
filter_pushed_down,
})
}
async fn new_filtered_read(
&self,
filter_plan: &ExprFilterPlan,
projection: Projection,
make_deletions_null: bool,
fragments: Option<Arc<Vec<Fragment>>>,
scan_range: Option<Range<u64>>,
) -> Result<Arc<dyn ExecutionPlan>> {
let mut read_options = FilteredReadOptions::basic_full_read(&self.dataset)
.with_filter_plan(filter_plan.clone())
.with_projection(projection);
if let Some(fragments) = fragments {
read_options = read_options.with_fragments(fragments);
}
if let Some(scan_range) = scan_range {
read_options = read_options.with_scan_range_before_filter(scan_range)?;
}
if let Some(batch_size) = self.batch_size {
read_options = read_options.with_batch_size(batch_size as u32);
}
if let Some(fragment_readahead) = self.fragment_readahead {
read_options = read_options.with_fragment_readahead(fragment_readahead);
}
if make_deletions_null {
read_options = read_options.with_deleted_rows()?;
}
if let Some(io_buffer_size_bytes) = self.io_buffer_size {
read_options = read_options.with_io_buffer_size(io_buffer_size_bytes);
}
let index_input = filter_plan.index_query.clone().map(|index_query| {
Arc::new(ScalarIndexExec::new(self.dataset.clone(), index_query))
as Arc<dyn ExecutionPlan>
});
Ok(Arc::new(FilteredReadExec::try_new(
self.dataset.clone(),
read_options,
index_input,
)?))
}
async fn filtered_read(
&self,
filter_plan: &ExprFilterPlan,
projection: Projection,
make_deletions_null: bool,
fragments: Option<Arc<Vec<Fragment>>>,
scan_range: Option<Range<u64>>,
is_prefilter: bool,
) -> Result<PlannedFilteredScan> {
if self.dataset.is_legacy_storage() {
self.legacy_filtered_read(
filter_plan,
projection,
make_deletions_null,
fragments,
scan_range,
is_prefilter,
)
.await
} else {
let limit_pushed_down = scan_range.is_some();
let plan = self
.new_filtered_read(
filter_plan,
projection,
make_deletions_null,
fragments,
scan_range,
)
.await?;
Ok(PlannedFilteredScan {
filter_pushed_down: true,
limit_pushed_down,
plan,
})
}
}
fn u64s_as_take_input(&self, u64s: Vec<u64>) -> Result<Arc<dyn ExecutionPlan>> {
let row_addrs = RowAddrTreeMap::from_iter(u64s);
let row_addr_mask = RowAddrMask::from_allowed(row_addrs);
let index_result = IndexExprResult::Exact(row_addr_mask);
let fragments_covered = self.dataset.fragment_bitmap.as_ref().clone();
let batch = index_result.serialize_to_arrow(&fragments_covered)?;
let stream = futures::stream::once(async move { Ok(batch) });
let stream = Box::pin(RecordBatchStreamAdapter::new(
INDEX_EXPR_RESULT_SCHEMA.clone(),
stream,
));
Ok(Arc::new(OneShotExec::new(stream)))
}
async fn take_source(&self, take_op: TakeOperation) -> Result<Arc<dyn ExecutionPlan>> {
let projection = self.projection_plan.physical_projection.clone();
let input = match take_op {
TakeOperation::RowIds(ids) => self.u64s_as_take_input(ids),
TakeOperation::RowAddrs(addrs) => self.u64s_as_take_input(addrs),
TakeOperation::RowOffsets(offsets) => {
let mut addrs =
row_offsets_to_row_addresses(self.dataset.as_ref(), &offsets).await?;
addrs.retain(|addr| *addr != RowAddress::TOMBSTONE_ROW);
self.u64s_as_take_input(addrs)
}
}?;
let mut filtered_read_options = FilteredReadOptions::new(projection);
if let Some(fragment) = self.fragments.as_ref() {
filtered_read_options =
filtered_read_options.with_fragments(Arc::new(fragment.clone()));
}
Ok(Arc::new(FilteredReadExec::try_new(
self.dataset.clone(),
filtered_read_options,
Some(input),
)?))
}
async fn filtered_read_source(
&self,
filter_plan: &mut ExprFilterPlan,
) -> Result<PlannedFilteredScan> {
log::trace!("source is a filtered read");
let effective_projection = if let Some(agg) = &self.aggregate {
let required_columns = agg.required_columns();
if required_columns.is_empty() {
self.dataset.empty_projection()
} else {
self.dataset
.empty_projection()
.union_columns(&required_columns, OnMissing::Error)?
}
} else {
self.projection_plan.physical_projection.clone()
};
let mut projection = if filter_plan.has_refine() {
self.calc_eager_projection(filter_plan, &effective_projection)?
.with_row_id()
} else {
effective_projection
};
if projection.is_empty() {
projection.with_row_addr = true;
}
let scan_range = if filter_plan.is_empty() {
log::trace!("pushing scan_range into filtered_read");
self.get_scan_range(filter_plan).await?
} else {
None
};
self.filtered_read(
filter_plan,
projection,
self.include_deleted_rows,
self.fragments.clone().map(Arc::new),
scan_range,
false,
)
.await
}
async fn fts_search_source(
&self,
filter_plan: &mut FilterPlan,
query: &FullTextSearchQuery,
) -> Result<Arc<dyn ExecutionPlan>> {
log::trace!("source is an fts search");
if self.include_deleted_rows {
return Err(Error::invalid_input_source(
"Cannot include deleted rows in an FTS search".into(),
));
}
if self.prefilter {
let source: Arc<dyn ExecutionPlan> = match &filter_plan.vector_filter() {
Some(vector_query) => {
let vector_plan = self
.vector_search(&filter_plan.expr_filter_plan, vector_query)
.await?;
self.fts_rerank(vector_plan, query).await?
}
None => self.fts(&filter_plan.expr_filter_plan, query).await?,
};
filter_plan.disable_refine();
Ok(source)
} else {
filter_plan.make_refine_only();
self.fts(&ExprFilterPlan::default(), query).await
}
}
async fn vector_search_source(
&self,
filter_plan: &mut FilterPlan,
) -> Result<Arc<dyn ExecutionPlan>> {
if self.include_deleted_rows {
return Err(Error::invalid_input_source(
"Cannot include deleted rows in a nearest neighbor search".into(),
));
}
let Some(query) = self.nearest.as_ref() else {
return Err(Error::invalid_input("No nearest query".to_string()));
};
if self.prefilter {
log::trace!("source is a vector search (prefilter)");
let source: Arc<dyn ExecutionPlan> = match &filter_plan.fts_filter() {
Some(fts_query) => {
let fts_plan = self.fts(&filter_plan.expr_filter_plan, fts_query).await?;
let projection = self
.dataset
.empty_projection()
.union_column(&query.column, OnMissing::Error)?;
let plan = self.take(fts_plan, projection)?;
self.flat_knn(plan, query)?
}
None => {
self.vector_search(&filter_plan.expr_filter_plan, query)
.await?
}
};
filter_plan.disable_refine();
Ok(source)
} else {
log::trace!("source is a vector search (postfilter)");
filter_plan.make_refine_only();
self.vector_search(&ExprFilterPlan::default(), query).await
}
}
async fn fragments_covered_by_fts_leaf(
&self,
column: &str,
accum: &mut RoaringBitmap,
) -> Result<bool> {
let index = self
.dataset
.load_scalar_index(IndexCriteria::default().for_column(column).supports_fts())
.await?;
match index {
Some(index) => match &index.fragment_bitmap {
Some(fragmap) => {
*accum |= fragmap;
Ok(true)
}
None => Ok(false),
},
None => Ok(false),
}
}
#[async_recursion]
async fn fragments_covered_by_fts_query_helper(
&self,
query: &FtsQuery,
accum: &mut RoaringBitmap,
) -> Result<bool> {
match query {
FtsQuery::Match(match_query) => {
self.fragments_covered_by_fts_leaf(
match_query.column.as_ref().ok_or(Error::invalid_input(
"the column must be specified in the query".to_string(),
))?,
accum,
)
.await
}
FtsQuery::Boost(boost) => Ok(self
.fragments_covered_by_fts_query_helper(&boost.negative, accum)
.await?
& self
.fragments_covered_by_fts_query_helper(&boost.positive, accum)
.await?),
FtsQuery::MultiMatch(multi_match) => {
for mq in &multi_match.match_queries {
if !self
.fragments_covered_by_fts_leaf(
mq.column.as_ref().ok_or(Error::invalid_input(
"the column must be specified in the query".to_string(),
))?,
accum,
)
.await?
{
return Ok(false);
}
}
Ok(true)
}
FtsQuery::Phrase(phrase_query) => {
self.fragments_covered_by_fts_leaf(
phrase_query.column.as_ref().ok_or(Error::invalid_input(
"the column must be specified in the query".to_string(),
))?,
accum,
)
.await
}
FtsQuery::Boolean(bool_query) => {
for query in bool_query.must.iter() {
if !self
.fragments_covered_by_fts_query_helper(query, accum)
.await?
{
return Ok(false);
}
}
for query in &bool_query.should {
if !self
.fragments_covered_by_fts_query_helper(query, accum)
.await?
{
return Ok(false);
}
}
Ok(true)
}
}
}
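/// Computes which target fragments are covered by indices for every leaf of
/// the FTS query. If any leaf lacks coverage information we conservatively
/// report all target fragments so the prefilter reads everything.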
async fn fragments_covered_by_fts_query(&self, query: &FtsQuery) -> Result<RoaringBitmap> {
let all_fragments = self.get_fragments_as_bitmap();
let mut referenced_fragments = RoaringBitmap::new();
if !self
.fragments_covered_by_fts_query_helper(query, &mut referenced_fragments)
.await?
{
Ok(all_fragments)
} else {
Ok(all_fragments & referenced_fragments)
}
}
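/// Plans a full text search. When the query carries no explicit limit, one
/// is derived from the scanner's limit and offset (limit + offset, so the
/// later offset skip still leaves enough rows). Unqualified queries are
/// filled in with every FTS-indexed column before planning.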
async fn fts(
&self,
filter_plan: &ExprFilterPlan,
query: &FullTextSearchQuery,
) -> Result<Arc<dyn ExecutionPlan>> {
let columns = query.columns();
let mut params = query.params();
if params.limit.is_none() {
let search_limit = match (self.limit, self.offset) {
(Some(limit), Some(offset)) => Some((limit + offset) as usize),
(Some(limit), None) => Some(limit as usize),
(None, Some(_)) => None,
(None, None) => None,

};
params = params.with_limit(search_limit);
}
let query = if columns.is_empty() {
let indexed_columns = fts_indexed_columns(self.dataset.clone()).await?;
fill_fts_query_column(&query.query, &indexed_columns, false)?
} else {
query.query.clone()
};
let prefilter_source = self
.prefilter_source(
filter_plan,
self.fragments_covered_by_fts_query(&query).await?,
)
.await?;
let fts_exec = self
.plan_fts(&query, ¶ms, filter_plan, &prefilter_source)
.await?;
Ok(fts_exec)
}
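/// Recursively converts an FTS query tree into an execution plan.
///
/// Leaf match/phrase queries become index (or flat) search nodes; boost,
/// multi-match, and boolean nodes combine their children with unions, joins,
/// and dedicated scoring operators. Boost and boolean children run without a
/// limit so that rescoring sees every candidate.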
async fn plan_fts(
&self,
query: &FtsQuery,
params: &FtsSearchParams,
filter_plan: &ExprFilterPlan,
prefilter_source: &PreFilterSource,
) -> Result<Arc<dyn ExecutionPlan>> {
let plan: Arc<dyn ExecutionPlan> = match query {
FtsQuery::Match(query) => {
self.plan_match_query(query, params, filter_plan, prefilter_source)
.await?
}
FtsQuery::Phrase(query) => {
self.plan_phrase_query(query, params, prefilter_source)
.await?
}
FtsQuery::Boost(query) => {
let unlimited_params = params.clone().with_limit(None);
let positive_exec = Box::pin(self.plan_fts(
&query.positive,
&unlimited_params,
filter_plan,
prefilter_source,
));
let negative_exec = Box::pin(self.plan_fts(
&query.negative,
&unlimited_params,
filter_plan,
prefilter_source,
));
let (positive_exec, negative_exec) =
futures::future::try_join(positive_exec, negative_exec).await?;
Arc::new(BoostQueryExec::new(
query.clone(),
params.clone(),
positive_exec,
negative_exec,
))
}
FtsQuery::MultiMatch(query) => {
let mut children = Vec::with_capacity(query.match_queries.len());
for match_query in &query.match_queries {
let child =
self.plan_match_query(match_query, params, filter_plan, prefilter_source);
children.push(child);
}
let children = futures::future::try_join_all(children).await?;
let schema = children[0].schema();
let group_expr = vec![(
expressions::col(ROW_ID, schema.as_ref())?,
ROW_ID.to_string(),
)];
let fts_node = UnionExec::try_new(children)?;
let fts_node = Arc::new(RepartitionExec::try_new(
fts_node,
Partitioning::RoundRobinBatch(1),
)?);
let fts_node = Arc::new(AggregateExec::try_new(
AggregateMode::Single,
PhysicalGroupBy::new_single(group_expr),
vec![Arc::new(
datafusion_physical_expr::aggregate::AggregateExprBuilder::new(
functions_aggregate::min_max::max_udaf(),
vec![expressions::col(SCORE_COL, &schema)?],
)
.schema(schema.clone())
.alias(SCORE_COL)
.build()?,
)],
vec![None],
fts_node,
schema,
)?);
let sort_expr = PhysicalSortExpr {
expr: expressions::col(SCORE_COL, fts_node.schema().as_ref())?,
options: SortOptions {
descending: true,
nulls_first: false,
},
};
Arc::new(
SortExec::new([sort_expr].into(), fts_node)
// Fetch params.limit (which already folds in the offset) so a later
// GlobalLimitExec can still skip the offset rows.
.with_fetch(params.limit),
)
}
FtsQuery::Boolean(query) => {
let unlimited_params = params.clone().with_limit(None);
let mut should = Vec::with_capacity(query.should.len());
for subquery in &query.should {
let plan = Box::pin(self.plan_fts(
subquery,
&unlimited_params,
filter_plan,
prefilter_source,
))
.await?;
should.push(plan);
}
let should = if should.is_empty() {
Arc::new(EmptyExec::new(FTS_SCHEMA.clone()))
} else if should.len() == 1 {
should.pop().unwrap()
} else {
let unioned = UnionExec::try_new(should)?;
Arc::new(RepartitionExec::try_new(
unioned,
Partitioning::RoundRobinBatch(1),
)?)
};
let mut must = None;
for query in &query.must {
let plan = Box::pin(self.plan_fts(
query,
&unlimited_params,
filter_plan,
prefilter_source,
))
.await?;
if let Some(joined_plan) = must {
must = Some(Arc::new(HashJoinExec::try_new(
joined_plan,
plan,
vec![(
Arc::new(Column::new_with_schema(ROW_ID, &FTS_SCHEMA)?),
Arc::new(Column::new_with_schema(ROW_ID, &FTS_SCHEMA)?),
)],
None,
&datafusion_expr::JoinType::Inner,
None,
datafusion_physical_plan::joins::PartitionMode::CollectLeft,
NullEquality::NullEqualsNothing,
)?) as _);
} else {
must = Some(plan);
}
}
let mut must_not = Vec::with_capacity(query.must_not.len());
for query in &query.must_not {
let plan = Box::pin(self.plan_fts(
query,
&unlimited_params,
filter_plan,
prefilter_source,
))
.await?;
must_not.push(plan);
}
let must_not = if must_not.is_empty() {
Arc::new(EmptyExec::new(FTS_SCHEMA.clone()))
} else if must_not.len() == 1 {
must_not.pop().unwrap()
} else {
let unioned = UnionExec::try_new(must_not)?;
Arc::new(RepartitionExec::try_new(
unioned,
Partitioning::RoundRobinBatch(1),
)?)
};
if query.should.is_empty() && must.is_none() {
return Err(Error::invalid_input(
"boolean query must have at least one should/must query".to_string(),
));
}
Arc::new(BooleanQueryExec::new(
query.clone(),
params.clone(),
should,
must,
must_not,
))
}
};
Ok(plan)
}
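/// Plans a phrase query against the inverted index, verifying first that the
/// index was built with positions (phrase matching needs token positions).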
async fn plan_phrase_query(
&self,
query: &PhraseQuery,
params: &FtsSearchParams,
prefilter_source: &PreFilterSource,
) -> Result<Arc<dyn ExecutionPlan>> {
let column = query.column.clone().ok_or(Error::invalid_input(
"the column must be specified in the query".to_string(),
))?;
let index_meta = self
.dataset
.load_scalar_index(IndexCriteria::default().for_column(&column).supports_fts())
.await?
.ok_or(Error::invalid_input(format!(
"No Inverted index found for column {}",
column
)))?;
let details_any =
crate::index::scalar::fetch_index_details(&self.dataset, &column, &index_meta).await?;
let details = details_any
.as_ref()
.to_msg::<lance_index::pbold::InvertedIndexDetails>()?;
if !details.with_position {
return Err(Error::invalid_input("position is not found but required for phrase queries, try recreating the index with position"
.to_string()));
}
Ok(Arc::new(PhraseQueryExec::new(
self.dataset.clone(),
query.clone(),
params.clone(),
prefilter_source.clone(),
)))
}
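/// Plans a match query. Indexed fragments are searched through the inverted
/// index while unindexed fragments get a flat scan; when both paths exist
/// their results are unioned and re-sorted by score.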
async fn plan_match_query(
&self,
query: &MatchQuery,
params: &FtsSearchParams,
filter_plan: &ExprFilterPlan,
prefilter_source: &PreFilterSource,
) -> Result<Arc<dyn ExecutionPlan>> {
let column = query
.column
.as_ref()
.ok_or(Error::invalid_input(
"the column must be specified in the query".to_string(),
))?
.clone();
let index = self
.dataset
.load_scalar_index(IndexCriteria::default().for_column(&column).supports_fts())
.await?;
let target_fragments = self
.fragments
.clone()
.unwrap_or_else(|| self.dataset.fragments().to_vec());
let (match_plan, flat_match_plan) = match &index {
Some(index) => {
let unindexed_fragments = self
.retain_target_fragments(self.dataset.unindexed_fragments(&index.name).await?);
if unindexed_fragments.len() == target_fragments.len() {
if self.fast_search {
return Ok(Arc::new(EmptyExec::new(FTS_SCHEMA.clone())));
}
let flat_match_plan = self
.plan_flat_match_query(unindexed_fragments, query, params, filter_plan)
.await?;
return Ok(flat_match_plan);
}
let match_plan: Arc<dyn ExecutionPlan> = Arc::new(MatchQueryExec::new(
self.dataset.clone(),
query.clone(),
params.clone(),
prefilter_source.clone(),
));
if self.fast_search || unindexed_fragments.is_empty() {
(Some(match_plan), None)
} else {
let flat_match_plan = self
.plan_flat_match_query(unindexed_fragments, query, params, filter_plan)
.await?;
(Some(match_plan), Some(flat_match_plan))
}
}
None => {
if self.fast_search {
return Ok(Arc::new(EmptyExec::new(FTS_SCHEMA.clone())));
}
let flat_match_plan = self
.plan_flat_match_query(target_fragments.clone(), query, params, filter_plan)
.await?;
(None, Some(flat_match_plan))
}
};
let plan = match (match_plan, flat_match_plan) {
(Some(match_plan), Some(flat_match_plan)) => {
let match_plan = UnionExec::try_new(vec![match_plan, flat_match_plan])?;
let match_plan = Arc::new(RepartitionExec::try_new(
match_plan,
Partitioning::RoundRobinBatch(1),
)?);
let sort_expr = PhysicalSortExpr {
expr: expressions::col(SCORE_COL, match_plan.schema().as_ref())?,
options: SortOptions {
descending: true,
nulls_first: false,
},
};
Arc::new(SortExec::new([sort_expr].into(), match_plan).with_fetch(params.limit))
}
(Some(match_plan), None) => match_plan,
(None, Some(flat_match_plan)) => flat_match_plan,
(None, None) => unreachable!(),
};
Ok(plan)
}
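/// Plans a brute-force match query over the given fragments by scanning the
/// search column (plus any filter columns) and scoring matches on the fly.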
async fn plan_flat_match_query(
&self,
fragments: Vec<Fragment>,
query: &MatchQuery,
params: &FtsSearchParams,
filter_plan: &ExprFilterPlan,
) -> Result<Arc<dyn ExecutionPlan>> {
let column = query
.column
.as_ref()
.ok_or(Error::invalid_input(
"the column must be specified in the query".to_string(),
))?
.clone();
let mut columns = vec![column];
if let Some(expr) = filter_plan.full_expr.as_ref() {
let filter_columns = Planner::column_names_in_expr(expr);
columns.extend(filter_columns);
}
let flat_fts_scan_schema = Arc::new(self.dataset.schema().project(&columns).unwrap());
let mut scan_node = self.scan_fragments(
true,
false,
false,
false,
false,
flat_fts_scan_schema,
Arc::new(fragments),
None,
false,
);
if let Some(expr) = filter_plan.full_expr.as_ref() {
scan_node = Arc::new(LanceFilterExec::try_new(expr.clone(), scan_node)?);
}
let flat_match_plan = Arc::new(FlatMatchQueryExec::new(
self.dataset.clone(),
query.clone(),
params.clone(),
scan_node,
));
Ok(flat_match_plan)
}
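/// Plans a vector search. If a compatible vector index covers the column we
/// use ANN search (optionally refined, and combined with unindexed data
/// unless fast_search is set); otherwise we fall back to a filtered scan
/// followed by flat KNN.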
async fn vector_search(
&self,
filter_plan: &ExprFilterPlan,
q: &Query,
) -> Result<Arc<dyn ExecutionPlan>> {
let mut q = q.clone();
let (vector_type, element_type) = get_vector_type(self.dataset.schema(), &q.column)?;
let column_id = self.dataset.schema().field_id(q.column.as_str())?;
let use_index = q.use_index;
let indices = if use_index {
self.dataset.load_indices().await?
} else {
Arc::new(vec![])
};
let matching_index = if let Some(index) =
indices.iter().find(|i| i.fields.contains(&column_id))
{
let idx = self
.dataset
.open_vector_index(
q.column.as_str(),
&index.uuid.to_string(),
&NoOpMetricsCollector,
)
.await?;
let index_metric = idx.metric_type();
let use_this_index = match q.metric_type {
Some(user_metric) => {
if user_metric == index_metric {
true
} else {
log::warn!(
"Requested metric {:?} is incompatible with index metric {:?}, falling back to brute-force search",
user_metric,
index_metric
);
false
}
}
None => true,
};
if use_this_index {
Some((index, idx, index_metric))
} else {
None
}
} else {
None
};
let index_and_deltas = if let Some((index, _idx, index_metric)) = matching_index {
let deltas = self.dataset.load_indices_by_name(&index.name).await?;
let index_frags = self.get_indexed_frags(&deltas);
if !index_frags.is_empty() {
Some((index, deltas, index_metric))
} else {
None
}
} else {
None
};
if let Some((index, deltas, index_metric)) = index_and_deltas {
log::trace!("index found for vector search");
q.metric_type = Some(index_metric);
validate_distance_type_for(index_metric, &element_type)?;
if matches!(q.refine_factor, Some(0)) {
return Err(Error::invalid_input(
"Refine factor cannot be zero".to_string(),
));
}
let ann_node = match vector_type {
DataType::FixedSizeList(_, _) => self.ann(&q, &deltas, filter_plan).await?,
DataType::List(_) => self.multivec_ann(&q, &deltas, filter_plan).await?,
_ => unreachable!(),
};
let mut knn_node = if q.refine_factor.is_some() {
let vector_projection = self
.dataset
.empty_projection()
.union_column(&q.column, OnMissing::Error)
.unwrap();
let knn_node_with_vector = self.take(ann_node, vector_projection)?;
self.flat_knn(knn_node_with_vector, &q)?
} else {
ann_node
};
if !self.fast_search {
knn_node = self.knn_combined(&q, index, knn_node, filter_plan).await?;
}
Ok(knn_node)
} else {
if self.fast_search {
return Ok(Arc::new(EmptyExec::new(KNN_INDEX_SCHEMA.clone())));
}
let metric = q
.metric_type
.unwrap_or_else(|| default_distance_type_for(&element_type));
q.metric_type = Some(metric);
validate_distance_type_for(metric, &element_type)?;
let mut columns = vec![q.column.clone()];
if let Some(refine_expr) = filter_plan.refine_expr.as_ref() {
columns.extend(Planner::column_names_in_expr(refine_expr));
}
let mut vector_scan_projection = self
.dataset
.empty_projection()
.with_row_id()
.union_columns(&columns, OnMissing::Error)?;
vector_scan_projection.with_row_addr =
self.projection_plan.physical_projection.with_row_addr;
let PlannedFilteredScan { mut plan, .. } = self
.filtered_read(
filter_plan,
vector_scan_projection,
true,
self.fragments.clone().map(Arc::new),
None,
true,
)
.await?;
if let Some(refine_expr) = &filter_plan.refine_expr {
plan = Arc::new(LanceFilterExec::try_new(refine_expr.clone(), plan)?);
}
Ok(self.flat_knn(plan, &q)?)
}
}
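/// Combines ANN results from the index with flat KNN over fragments the
/// index has not seen yet, then re-sorts the union to a single top-k.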
async fn knn_combined(
&self,
q: &Query,
index: &IndexMetadata,
mut knn_node: Arc<dyn ExecutionPlan>,
filter_plan: &ExprFilterPlan,
) -> Result<Arc<dyn ExecutionPlan>> {
let unindexed_fragments =
self.retain_target_fragments(self.dataset.unindexed_fragments(&index.name).await?);
if !unindexed_fragments.is_empty() {
let idx = self
.dataset
.open_vector_index(
q.column.as_str(),
&index.uuid.to_string(),
&NoOpMetricsCollector,
)
.await?;
let mut q = q.clone();
q.metric_type = Some(idx.metric_type());
if knn_node.schema().column_with_name(&q.column).is_none() {
let vector_projection = self
.dataset
.empty_projection()
.union_column(&q.column, OnMissing::Error)
.unwrap();
knn_node = self.take(knn_node, vector_projection)?;
}
let mut columns = vec![q.column.clone()];
if let Some(expr) = filter_plan.full_expr.as_ref() {
let filter_columns = Planner::column_names_in_expr(expr);
columns.extend(filter_columns);
}
let vector_scan_projection = Arc::new(self.dataset.schema().project(&columns).unwrap());
let mut scan_node = self.scan_fragments(
true,
false,
false,
false,
false,
vector_scan_projection,
Arc::new(unindexed_fragments),
None,
false,
);
if let Some(expr) = filter_plan.full_expr.as_ref() {
scan_node = Arc::new(LanceFilterExec::try_new(expr.clone(), scan_node)?);
}
let topk_appended = self.flat_knn(scan_node, &q)?;
let topk_appended = project(topk_appended, knn_node.schema().as_ref())?;
assert!(
topk_appended
.schema()
.equivalent_names_and_types(&knn_node.schema())
);
let unioned = UnionExec::try_new(vec![Arc::new(topk_appended), knn_node])?;
let unioned = RepartitionExec::try_new(
unioned,
datafusion::physical_plan::Partitioning::RoundRobinBatch(1),
)?;
return self.flat_knn(Arc::new(unioned), &q);
}
Ok(knn_node)
}
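/// Computes the set of fragments whose rows can be answered entirely by the
/// scalar indices referenced in the query expression.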
#[async_recursion]
async fn fragments_covered_by_index_query(
&self,
index_expr: &ScalarIndexExpr,
) -> Result<RoaringBitmap> {
match index_expr {
ScalarIndexExpr::And(lhs, rhs) => {
Ok(self.fragments_covered_by_index_query(lhs).await?
& self.fragments_covered_by_index_query(rhs).await?)
}
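// Note: `Or` intersects coverage just like `And`. To answer a disjunction
// from indices alone, a fragment must be covered by both sides; if only one
// side covered it, the other side's matches in that fragment would be missed.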
ScalarIndexExpr::Or(lhs, rhs) => Ok(self.fragments_covered_by_index_query(lhs).await?
& self.fragments_covered_by_index_query(rhs).await?),
ScalarIndexExpr::Not(expr) => self.fragments_covered_by_index_query(expr).await,
ScalarIndexExpr::Query(search) => {
let idx = self
.dataset
.load_scalar_index(IndexCriteria::default().with_name(&search.index_name))
.await?
.expect("Index not found even though it must have been found earlier");
Ok(idx
.fragment_bitmap
.expect("scalar indices should always have a fragment bitmap"))
}
}
}
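/// Splits `fragments` into those covered by the index query and those that
/// still require a full scan.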
async fn partition_frags_by_coverage(
&self,
index_expr: &ScalarIndexExpr,
fragments: Arc<Vec<Fragment>>,
) -> Result<(Vec<Fragment>, Vec<Fragment>)> {
let covered_frags = self.fragments_covered_by_index_query(index_expr).await?;
let mut relevant_frags = Vec::with_capacity(fragments.len());
let mut missing_frags = Vec::with_capacity(fragments.len());
for fragment in fragments.iter() {
if covered_frags.contains(fragment.id as u32) {
relevant_frags.push(fragment.clone());
} else {
missing_frags.push(fragment.clone());
}
}
Ok((relevant_frags, missing_frags))
}
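/// Plans a scan that materializes index results for covered fragments and
/// falls back to a filtered full scan for fragments missing from the index.
/// Extra columns are taken and the filter re-applied whenever a recheck or
/// refine step is needed.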
async fn scalar_indexed_scan(
&self,
projection: Projection,
filter_plan: &ExprFilterPlan,
fragments: Arc<Vec<Fragment>>,
) -> Result<Arc<dyn ExecutionPlan>> {
log::trace!("scalar indexed scan");
let index_expr = filter_plan.index_query.as_ref().unwrap();
let needs_recheck = index_expr.needs_recheck();
let (relevant_frags, missing_frags) = self
.partition_frags_by_coverage(index_expr, fragments)
.await?;
let mut plan: Arc<dyn ExecutionPlan> = Arc::new(MaterializeIndexExec::new(
self.dataset.clone(),
index_expr.clone(),
Arc::new(relevant_frags),
));
let refine_expr = filter_plan.refine_expr.as_ref();
let needs_take =
needs_recheck || projection.has_data_fields() || filter_plan.refine_expr.is_some();
if needs_take {
let mut take_projection = projection.clone();
if needs_recheck {
let filter_expr = index_expr.to_expr();
let filter_cols = Planner::column_names_in_expr(&filter_expr);
take_projection = take_projection.union_columns(filter_cols, OnMissing::Error)?;
}
if let Some(refine_expr) = refine_expr {
let refine_cols = Planner::column_names_in_expr(refine_expr);
take_projection = take_projection.union_columns(refine_cols, OnMissing::Error)?;
}
log::trace!("need to take additional columns for scalar_indexed_scan");
plan = self.take(plan, take_projection)?;
}
let post_take_filter = match (needs_recheck, refine_expr) {
(false, None) => None,
(true, None) => Some(index_expr.to_expr()),
(true, Some(_)) => Some(filter_plan.full_expr.as_ref().unwrap().clone()),
(false, Some(refine_expr)) => Some(refine_expr.clone()),
};
if let Some(post_take_filter) = post_take_filter {
let planner = Planner::new(plan.schema());
let optimized_filter = planner.optimize_expr(post_take_filter)?;
log::trace!("applying post-take filter to indexed scan");
plan = Arc::new(LanceFilterExec::try_new(optimized_filter, plan)?);
}
if self.projection_plan.physical_projection.with_row_addr {
plan = Arc::new(AddRowAddrExec::try_new(plan, self.dataset.clone(), 0)?);
}
let new_data_path: Option<Arc<dyn ExecutionPlan>> = if !missing_frags.is_empty() {
log::trace!(
"scalar_indexed_scan will need full scan of {} missing fragments",
missing_frags.len()
);
let filter = filter_plan.full_expr.as_ref().unwrap();
let filter_cols = Planner::column_names_in_expr(filter);
let scan_projection = projection.union_columns(filter_cols, OnMissing::Error)?;
let scan_schema = Arc::new(scan_projection.to_bare_schema());
let scan_arrow_schema = Arc::new(scan_schema.as_ref().into());
let planner = Planner::new(scan_arrow_schema);
let optimized_filter = planner.optimize_expr(filter.clone())?;
let new_data_scan = self.scan_fragments(
true,
self.projection_plan.physical_projection.with_row_addr,
self.projection_plan
.physical_projection
.with_row_last_updated_at_version,
self.projection_plan
.physical_projection
.with_row_created_at_version,
false,
scan_schema,
missing_frags.into(),
None,
false,
);
let filtered = Arc::new(LanceFilterExec::try_new(optimized_filter, new_data_scan)?);
Some(Arc::new(project(filtered, plan.schema().as_ref())?))
} else {
log::trace!("scalar_indexed_scan will not need full scan of any missing fragments");
None
};
if let Some(new_data_path) = new_data_path {
let unioned = UnionExec::try_new(vec![plan, new_data_path])?;
let unioned = Arc::new(RepartitionExec::try_new(
unioned,
datafusion::physical_plan::Partitioning::RoundRobinBatch(1),
)?);
Ok(unioned)
} else {
Ok(plan)
}
}
fn get_io_buffer_size(&self) -> u64 {
self.io_buffer_size.unwrap_or(*DEFAULT_IO_BUFFER_SIZE)
}
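/// Plans a plain scan of the dataset (or the selected fragments). Ordered
/// output is disabled when the results will be re-sorted anyway by an
/// explicit ordering or a nearest query.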
#[allow(clippy::too_many_arguments)]
pub(crate) fn scan(
&self,
with_row_id: bool,
with_row_address: bool,
with_row_last_updated_at_version: bool,
with_row_created_at_version: bool,
with_make_deletions_null: bool,
range: Option<Range<u64>>,
projection: Arc<Schema>,
) -> Arc<dyn ExecutionPlan> {
let fragments = if let Some(fragment) = self.fragments.as_ref() {
Arc::new(fragment.clone())
} else {
self.dataset.fragments().clone()
};
let ordered = if self.ordering.is_some() || self.nearest.is_some() {
false
} else {
self.ordered
};
self.scan_fragments(
with_row_id,
with_row_address,
with_row_last_updated_at_version,
with_row_created_at_version,
with_make_deletions_null,
projection,
fragments,
range,
ordered,
)
}
#[allow(clippy::too_many_arguments)]
fn scan_fragments(
&self,
with_row_id: bool,
with_row_address: bool,
with_row_last_updated_at_version: bool,
with_row_created_at_version: bool,
with_make_deletions_null: bool,
projection: Arc<Schema>,
fragments: Arc<Vec<Fragment>>,
range: Option<Range<u64>>,
ordered: bool,
) -> Arc<dyn ExecutionPlan> {
log::trace!("scan_fragments covered {} fragments", fragments.len());
let config = LanceScanConfig {
batch_size: self.get_batch_size(),
batch_readahead: self.batch_readahead,
fragment_readahead: self.fragment_readahead,
io_buffer_size: self.get_io_buffer_size(),
with_row_id,
with_row_address,
with_row_last_updated_at_version,
with_row_created_at_version,
with_make_deletions_null,
ordered_output: ordered,
};
Arc::new(LanceScanExec::new(
self.dataset.clone(),
fragments,
range,
projection,
config,
))
}
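/// Plans a scan with the filter evaluated inside the fragment reader, so
/// non-matching rows can be skipped during the read rather than filtered
/// afterwards.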
fn pushdown_scan(
&self,
make_deletions_null: bool,
filter_plan: &ExprFilterPlan,
) -> Result<Arc<dyn ExecutionPlan>> {
log::trace!("pushdown_scan");
let config = ScanConfig {
batch_readahead: self.batch_readahead,
fragment_readahead: self
.fragment_readahead
.unwrap_or(LEGACY_DEFAULT_FRAGMENT_READAHEAD),
with_row_id: self.projection_plan.physical_projection.with_row_id,
with_row_address: self.projection_plan.physical_projection.with_row_addr,
make_deletions_null,
ordered_output: self.ordered,
file_reader_options: self
.file_reader_options
.clone()
.or_else(|| self.dataset.file_reader_options.clone()),
};
let fragments = if let Some(fragment) = self.fragments.as_ref() {
Arc::new(fragment.clone())
} else {
self.dataset.fragments().clone()
};
Ok(Arc::new(LancePushdownScanExec::try_new(
self.dataset.clone(),
fragments,
Arc::new(self.projection_plan.physical_projection.to_bare_schema()),
filter_plan.refine_expr.clone().unwrap(),
config,
)?))
}
async fn flat_fts_filter(
&self,
input: Arc<dyn ExecutionPlan>,
q: &FullTextSearchQuery,
) -> Result<Arc<dyn ExecutionPlan>> {
let fts_query = if q.columns().is_empty() {
let indexed_columns = fts_indexed_columns(self.dataset.clone()).await?;
fill_fts_query_column(&q.query, &indexed_columns, false)?
} else {
q.query.clone()
};
match &fts_query {
FtsQuery::Match(match_query) => {
let schema = Arc::new(input.schema().try_with_column(SCORE_FIELD.clone())?);
let column = match_query
.column
.as_ref()
.ok_or(Error::invalid_input(
"the column must be specified in the query".to_string(),
))?
.clone();
let input = if schema.column_with_name(&column).is_none() {
let projection = self
.dataset
.empty_projection()
.union_column(&column, OnMissing::Error)?;
self.take(input, projection)?
} else {
input
};
Ok(Arc::new(FlatMatchFilterExec::new(
input,
self.dataset.clone(),
match_query.clone(),
q.params(),
)))
}
_ => Err(Error::not_supported(
"Only Match queries are supported currently when using FTS as a post-filter",
)),
}
}
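/// Re-scores `input` rows by FTS relevance. Match queries are scored
/// directly against the rows; other query types run a full FTS search and
/// join its scores to the input by row id.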
async fn fts_rerank(
&self,
input: Arc<dyn ExecutionPlan>,
q: &FullTextSearchQuery,
) -> Result<Arc<dyn ExecutionPlan>> {
let fts_query = if q.columns().is_empty() {
let indexed_columns = fts_indexed_columns(self.dataset.clone()).await?;
fill_fts_query_column(&q.query, &indexed_columns, false)?
} else {
q.query.clone()
};
match &fts_query {
FtsQuery::Match(match_query) => {
let schema = Arc::new(input.schema().try_with_column(SCORE_FIELD.clone())?);
let column = match_query
.column
.as_ref()
.ok_or(Error::invalid_input(
"the column must be specified in the query".to_string(),
))?
.clone();
let input = if schema.column_with_name(&column).is_none() {
let projection = self
.dataset
.empty_projection()
.union_column(&column, OnMissing::Error)?;
self.take(input, projection)?
} else {
input
};
Ok(Arc::new(FlatMatchQueryExec::new(
self.dataset.clone(),
match_query.clone(),
q.params(),
input,
)))
}
_ => {
let default_filter = ExprFilterPlan::default();
let fts_plan = self.fts(&default_filter, q).await?;
let vector_row_id = Column::new_with_schema(ROW_ID, input.schema().as_ref())?;
let fts_row_id = Column::new_with_schema(ROW_ID, fts_plan.schema().as_ref())?;
let join = HashJoinExec::try_new(
input,
fts_plan,
vec![(Arc::new(vector_row_id), Arc::new(fts_row_id))],
None,
&JoinType::Inner,
None,
PartitionMode::CollectLeft,
NullEquality::NullEqualsNull,
)?;
let schema = join.schema();
let mut projection_exprs = Vec::new();
let mut contain_rowid = false;
for field in schema.fields() {
if field.name() == ROW_ID {
if contain_rowid {
continue;
}
contain_rowid = true;
}
projection_exprs.push((
Arc::new(Column::new_with_schema(field.name(), schema.as_ref())?)
as Arc<dyn PhysicalExpr>,
field.name().clone(),
));
}
let projection_exec = ProjectionExec::try_new(projection_exprs, Arc::new(join))?;
Ok(Arc::new(projection_exec))
}
}
}
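/// Applies brute-force KNN on top of `input`. The resulting plan is roughly:
///
///   Filter(_distance IS NOT NULL)
///     Sort(_distance ASC, _rowid ASC, fetch = k)
///       [Filter(lower/upper distance bounds, if requested)]
///         KNNVectorDistance(metric)
///           input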
fn flat_knn(&self, input: Arc<dyn ExecutionPlan>, q: &Query) -> Result<Arc<dyn ExecutionPlan>> {
let metric_type = match q.metric_type {
Some(m) => m,
None => {
let (_, element_type) = get_vector_type(self.dataset.schema(), &q.column)?;
default_distance_type_for(&element_type)
}
};
let flat_dist = Arc::new(KNNVectorDistanceExec::try_new(
input,
&q.column,
q.key.clone(),
metric_type,
)?);
let lower: Option<(Expr, Arc<dyn PhysicalExpr>)> = q
.lower_bound
.map(|v| -> Result<(Expr, Arc<dyn PhysicalExpr>)> {
let logical = col(DIST_COL).gt_eq(lit(v));
let schema = flat_dist.schema();
let df_schema = DFSchema::try_from(schema)?;
let physical = create_physical_expr(&logical, &df_schema, &ExecutionProps::new())?;
Ok((logical, physical))
})
.transpose()?;
let upper = q
.upper_bound
.map(|v| -> Result<(Expr, Arc<dyn PhysicalExpr>)> {
let logical = col(DIST_COL).lt(lit(v));
let schema = flat_dist.schema();
let df_schema = DFSchema::try_from(schema)?;
let physical = create_physical_expr(&logical, &df_schema, &ExecutionProps::new())?;
Ok((logical, physical))
})
.transpose()?;
let filter_expr = match (lower, upper) {
(Some((llog, _)), Some((ulog, _))) => {
let logical = llog.and(ulog);
let schema = flat_dist.schema();
let df_schema = DFSchema::try_from(schema)?;
let physical = create_physical_expr(&logical, &df_schema, &ExecutionProps::new())?;
Some((logical, physical))
}
(Some((llog, lphys)), None) => Some((llog, lphys)),
(None, Some((ulog, uphys))) => Some((ulog, uphys)),
(None, None) => None,
};
let knn_plan: Arc<dyn ExecutionPlan> = if let Some(filter_expr) = filter_expr {
Arc::new(LanceFilterExec::try_new(filter_expr.0, flat_dist)?)
} else {
flat_dist
};
let sort = SortExec::new(
[
PhysicalSortExpr {
expr: expressions::col(DIST_COL, knn_plan.schema().as_ref())?,
options: SortOptions {
descending: false,
nulls_first: false,
},
},
PhysicalSortExpr {
expr: expressions::col(ROW_ID, knn_plan.schema().as_ref())?,
options: SortOptions {
descending: false,
nulls_first: false,
},
},
]
.into(),
knn_plan,
)
.with_fetch(Some(q.k));
let logical_not_null = col(DIST_COL).is_not_null();
let not_nulls = Arc::new(LanceFilterExec::try_new(logical_not_null, Arc::new(sort))?);
Ok(not_nulls)
}
fn get_fragments_as_bitmap(&self) -> RoaringBitmap {
if let Some(fragments) = &self.fragments {
RoaringBitmap::from_iter(fragments.iter().map(|f| f.id as u32))
} else {
self.dataset.fragment_bitmap.as_ref().clone()
}
}
fn retain_target_fragments(&self, mut fragments: Vec<Fragment>) -> Vec<Fragment> {
if let Some(target) = &self.fragments {
let bitmap = RoaringBitmap::from_iter(target.iter().map(|f| f.id as u32));
fragments.retain(|f| bitmap.contains(f.id as u32));
}
fragments
}
fn get_indexed_frags(&self, index: &[IndexMetadata]) -> RoaringBitmap {
let all_fragments = self.get_fragments_as_bitmap();
let mut all_indexed_frags = RoaringBitmap::new();
for idx in index {
if let Some(fragmap) = idx.fragment_bitmap.as_ref() {
all_indexed_frags |= fragmap;
} else {
return all_fragments;
}
}
all_indexed_frags & all_fragments
}
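/// Plans an ANN search over the index deltas, sorted by distance (row id as
/// a tie-breaker) and fetching k * refine_factor candidates.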
async fn ann(
&self,
q: &Query,
index: &[IndexMetadata],
filter_plan: &ExprFilterPlan,
) -> Result<Arc<dyn ExecutionPlan>> {
let prefilter_source = self
.prefilter_source(filter_plan, self.get_indexed_frags(index))
.await?;
let inner_fanout_search = new_knn_exec(self.dataset.clone(), index, q, prefilter_source)?;
let sort_expr = PhysicalSortExpr {
expr: expressions::col(DIST_COL, inner_fanout_search.schema().as_ref())?,
options: SortOptions {
descending: false,
nulls_first: false,
},
};
let sort_expr_row_id = PhysicalSortExpr {
expr: expressions::col(ROW_ID, inner_fanout_search.schema().as_ref())?,
options: SortOptions {
descending: false,
nulls_first: false,
},
};
Ok(Arc::new(
SortExec::new([sort_expr, sort_expr_row_id].into(), inner_fanout_search)
.with_fetch(Some(q.k * q.refine_factor.unwrap_or(1) as usize)),
))
}
async fn multivec_ann(
&self,
q: &Query,
index: &[IndexMetadata],
filter_plan: &ExprFilterPlan,
) -> Result<Arc<dyn ExecutionPlan>> {
let over_fetch_factor = *DEFAULT_XTR_OVERFETCH;
let prefilter_source = self
.prefilter_source(filter_plan, self.get_indexed_frags(index))
.await?;
let dim = get_vector_dim(self.dataset.schema(), &q.column)?;
let num_queries = q.key.len() / dim;
let new_queries = (0..num_queries)
.map(|i| q.key.slice(i * dim, dim))
.map(|query_vec| {
let mut new_query = q.clone();
new_query.key = query_vec;
new_query.refine_factor = Some(over_fetch_factor);
new_query
});
let mut ann_nodes = Vec::with_capacity(new_queries.len());
for query in new_queries {
let ann_node = new_knn_exec(
self.dataset.clone(),
index,
&query,
prefilter_source.clone(),
)?;
let sort_expr = PhysicalSortExpr {
expr: expressions::col(DIST_COL, ann_node.schema().as_ref())?,
options: SortOptions {
descending: false,
nulls_first: false,
},
};
let sort_expr_row_id = PhysicalSortExpr {
expr: expressions::col(ROW_ID, ann_node.schema().as_ref())?,
options: SortOptions {
descending: false,
nulls_first: false,
},
};
let ann_node = Arc::new(
SortExec::new([sort_expr, sort_expr_row_id].into(), ann_node)
.with_fetch(Some(q.k * over_fetch_factor as usize)),
);
ann_nodes.push(ann_node as Arc<dyn ExecutionPlan>);
}
let ann_node = Arc::new(MultivectorScoringExec::try_new(ann_nodes, q.clone())?);
let sort_expr = PhysicalSortExpr {
expr: expressions::col(DIST_COL, ann_node.schema().as_ref())?,
options: SortOptions {
descending: false,
nulls_first: false,
},
};
let sort_expr_row_id = PhysicalSortExpr {
expr: expressions::col(ROW_ID, ann_node.schema().as_ref())?,
options: SortOptions {
descending: false,
nulls_first: false,
},
};
let ann_node = Arc::new(
SortExec::new([sort_expr, sort_expr_row_id].into(), ann_node)
.with_fetch(Some(q.k * q.refine_factor.unwrap_or(1) as usize)),
);
Ok(ann_node)
}
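/// Builds the prefilter feeding an index search: either the result of an
/// exact scalar index query (when it covers every required fragment) or a
/// filtered read producing the allowed row ids.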
async fn prefilter_source(
&self,
filter_plan: &ExprFilterPlan,
required_frags: RoaringBitmap,
) -> Result<PreFilterSource> {
if filter_plan.is_empty() && self.fragments.is_none() {
log::trace!("no filter plan, no prefilter");
return Ok(PreFilterSource::None);
}
let fragments: Vec<Fragment> = self
.dataset
.manifest
.fragments
.iter()
.filter(|f| required_frags.contains(f.id as u32))
.cloned()
.collect();
let fragments = Arc::new(self.retain_target_fragments(fragments));
if filter_plan.is_exact_index_search() && self.fragments.is_none() {
let index_query = filter_plan.index_query.as_ref().expect_ok()?;
let (_, missing_frags) = self
.partition_frags_by_coverage(index_query, fragments.clone())
.await?;
if missing_frags.is_empty() {
log::trace!("prefilter entirely satisfied by exact index search");
return Ok(PreFilterSource::ScalarIndexQuery(Arc::new(
ScalarIndexExec::new(self.dataset.clone(), index_query.clone()),
)));
} else {
log::trace!("exact index search did not cover all fragments");
}
}
log::trace!(
"prefilter is a filtered read of {} fragments",
fragments.len()
);
let PlannedFilteredScan { plan, .. } = self
.filtered_read(
filter_plan,
self.dataset.empty_projection().with_row_id(),
false,
Some(fragments),
None,
true,
)
.await?;
Ok(PreFilterSource::FilteredRowIds(plan))
}
fn take(
&self,
input: Arc<dyn ExecutionPlan>,
output_projection: Projection,
) -> Result<Arc<dyn ExecutionPlan>> {
let coalesced = Arc::new(CoalesceBatchesExec::new(
input.clone(),
self.get_batch_size(),
));
if let Some(take_plan) =
TakeExec::try_new(self.dataset.clone(), coalesced, output_projection)?
{
Ok(Arc::new(take_plan))
} else {
Ok(input)
}
}
fn limit_node(&self, plan: Arc<dyn ExecutionPlan>) -> Arc<dyn ExecutionPlan> {
Arc::new(GlobalLimitExec::new(
plan,
self.offset.unwrap_or(0) as usize,
self.limit.map(|l| l as usize),
))
}
#[instrument(level = "info", skip(self))]
pub async fn analyze_plan(&self) -> Result<String> {
let plan = self.create_plan().await?;
analyze_plan(
plan,
LanceExecutionOptions {
batch_size: self.batch_size,
..Default::default()
},
)
.await
}
#[instrument(level = "info", skip(self))]
pub async fn explain_plan(&self, verbose: bool) -> Result<String> {
let plan = self.create_plan().await?;
let display = DisplayableExecutionPlan::new(plan.as_ref());
Ok(format!("{}", display.indent(verbose)))
}
}
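/// Lists every string (or list-of-string) column in the dataset that has an
/// FTS-capable scalar index, used to fill in unqualified FTS queries.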
async fn fts_indexed_columns(dataset: Arc<Dataset>) -> Result<Vec<String>> {
let mut indexed_columns = Vec::new();
for field in dataset.schema().fields_pre_order() {
let is_string_field = match field.data_type() {
DataType::Utf8 | DataType::LargeUtf8 => true,
DataType::List(inner_field) | DataType::LargeList(inner_field) => {
matches!(
inner_field.data_type(),
DataType::Utf8 | DataType::LargeUtf8
)
}
_ => false,
};
if is_string_field {
let column_path =
if let Some(ancestors) = dataset.schema().field_ancestry_by_id(field.id) {
let field_refs: Vec<&str> = ancestors.iter().map(|f| f.name.as_str()).collect();
format_field_path(&field_refs)
} else {
continue;
};
let has_fts_index = dataset
.load_scalar_index(
IndexCriteria::default()
.for_column(&column_path)
.supports_fts(),
)
.await?
.is_some();
if has_fts_index {
indexed_columns.push(column_path);
}
}
}
Ok(indexed_columns)
}
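/// A stream of record batches produced by a dataset scan, adapting physical
/// batches to the logical schema when needed and tracing batch polling.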
#[pin_project::pin_project]
pub struct DatasetRecordBatchStream {
#[pin]
exec_node: SendableRecordBatchStream,
span: Span,
}
impl DatasetRecordBatchStream {
pub fn new(exec_node: SendableRecordBatchStream) -> Self {
let schema = exec_node.schema();
let adapter = SchemaAdapter::new(schema.clone());
let exec_node = if SchemaAdapter::requires_logical_conversion(&schema) {
adapter.to_logical_stream(exec_node)
} else {
exec_node
};
let span = info_span!("DatasetRecordBatchStream");
Self { exec_node, span }
}
}
impl RecordBatchStream for DatasetRecordBatchStream {
fn schema(&self) -> SchemaRef {
self.exec_node.schema()
}
}
impl Stream for DatasetRecordBatchStream {
type Item = Result<RecordBatch>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let mut this = self.project();
let _guard = this.span.enter();
match this.exec_node.poll_next_unpin(cx) {
Poll::Ready(result) => Poll::Ready(result.map(|r| Ok(r?))),
Poll::Pending => Poll::Pending,
}
}
}
impl From<DatasetRecordBatchStream> for SendableRecordBatchStream {
fn from(stream: DatasetRecordBatchStream) -> Self {
stream.exec_node
}
}
#[cfg(test)]
pub mod test_dataset {
use super::*;
use std::{collections::HashMap, vec};
use arrow_array::{
ArrayRef, FixedSizeListArray, Int32Array, RecordBatch, RecordBatchIterator, StringArray,
};
use arrow_schema::{ArrowError, DataType};
use lance_arrow::FixedSizeListArrayExt;
use lance_core::utils::tempfile::TempStrDir;
use lance_file::version::LanceFileVersion;
use lance_index::{
IndexType,
scalar::{ScalarIndexParams, inverted::tokenizer::InvertedIndexParams},
};
use crate::dataset::WriteParams;
use crate::index::vector::VectorIndexParams;
pub struct TestVectorDataset {
pub tmp_dir: TempStrDir,
pub schema: Arc<ArrowSchema>,
pub dataset: Dataset,
dimension: u32,
}
impl TestVectorDataset {
pub async fn new(
data_storage_version: LanceFileVersion,
stable_row_ids: bool,
) -> Result<Self> {
Self::new_with_dimension(data_storage_version, stable_row_ids, 32).await
}
pub async fn new_with_dimension(
data_storage_version: LanceFileVersion,
stable_row_ids: bool,
dimension: u32,
) -> Result<Self> {
let path = TempStrDir::default();
let metadata: HashMap<String, String> =
vec![("dataset".to_string(), "vector".to_string())]
.into_iter()
.collect();
let schema = Arc::new(ArrowSchema::new_with_metadata(
vec![
ArrowField::new("i", DataType::Int32, true),
ArrowField::new("s", DataType::Utf8, true),
ArrowField::new(
"vec",
DataType::FixedSizeList(
Arc::new(ArrowField::new("item", DataType::Float32, true)),
dimension as i32,
),
true,
),
],
metadata,
));
let batches: Vec<RecordBatch> = (0..5)
.map(|i| {
let vector_values: Float32Array =
(0..dimension * 80).map(|v| v as f32).collect();
let vectors =
FixedSizeListArray::try_new_from_values(vector_values, dimension as i32)
.unwrap();
RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(Int32Array::from_iter_values(i * 80..(i + 1) * 80)),
Arc::new(StringArray::from_iter_values(
(i * 80..(i + 1) * 80).map(|v| format!("s-{}", v)),
)),
Arc::new(vectors),
],
)
})
.collect::<std::result::Result<Vec<_>, ArrowError>>()?;
let params = WriteParams {
max_rows_per_group: 10,
max_rows_per_file: 200,
data_storage_version: Some(data_storage_version),
enable_stable_row_ids: stable_row_ids,
..Default::default()
};
let reader = RecordBatchIterator::new(batches.into_iter().map(Ok), schema.clone());
let dataset = Dataset::write(reader, &path, Some(params)).await?;
Ok(Self {
tmp_dir: path,
schema,
dataset,
dimension,
})
}
pub async fn make_vector_index(&mut self) -> Result<()> {
let params = VectorIndexParams::ivf_pq(2, 8, 2, MetricType::L2, 2);
self.dataset
.create_index(
&["vec"],
IndexType::Vector,
Some("idx".to_string()),
¶ms,
true,
)
.await?;
Ok(())
}
pub async fn make_scalar_index(&mut self) -> Result<()> {
self.dataset
.create_index(
&["i"],
IndexType::Scalar,
None,
&ScalarIndexParams::default(),
true,
)
.await?;
Ok(())
}
pub async fn make_fts_index(&mut self) -> Result<()> {
let params = InvertedIndexParams::default().with_position(true);
self.dataset
.create_index(&["s"], IndexType::Inverted, None, ¶ms, true)
.await?;
Ok(())
}
pub async fn append_new_data(&mut self) -> Result<()> {
self.append_data_with_range(400, 410).await
}
pub async fn append_data_with_range(&mut self, start: i32, end: i32) -> Result<()> {
let count = (end - start) as usize;
let vector_values: Float32Array = (0..count)
.flat_map(|i| vec![i as f32; self.dimension as usize].into_iter())
.collect();
let new_vectors =
FixedSizeListArray::try_new_from_values(vector_values, self.dimension as i32)
.unwrap();
let new_data: Vec<ArrayRef> = vec![
Arc::new(Int32Array::from_iter_values(start..end)),
Arc::new(StringArray::from_iter_values(
(start..end).map(|v| format!("s-{}", v)),
)),
Arc::new(new_vectors),
];
let reader = RecordBatchIterator::new(
vec![RecordBatch::try_new(self.schema.clone(), new_data).unwrap()]
.into_iter()
.map(Ok),
self.schema.clone(),
);
self.dataset.append(reader, None).await?;
Ok(())
}
}
}
#[cfg(test)]
mod test {
use std::collections::BTreeSet;
use std::time::{Duration, Instant};
use std::vec;
use arrow::array::as_primitive_array;
use arrow::datatypes::{Float64Type, Int32Type, Int64Type};
use arrow_array::cast::AsArray;
use arrow_array::types::{Float32Type, UInt64Type};
use arrow_array::{
ArrayRef, FixedSizeListArray, Float16Array, Int32Array, LargeStringArray, PrimitiveArray,
RecordBatchIterator, StringArray, StructArray, UInt8Array,
};
use arrow_ord::sort::sort_to_indices;
use arrow_schema::Fields;
use arrow_select::take;
use datafusion::logical_expr::{col, lit};
use half::f16;
use lance_arrow::{FixedSizeListArrayExt, SchemaExt};
use lance_core::utils::tempfile::TempStrDir;
use lance_core::{ROW_CREATED_AT_VERSION, ROW_LAST_UPDATED_AT_VERSION};
use lance_datagen::{
ArrayGeneratorExt, BatchCount, ByteCount, Dimension, RowCount, array, gen_batch,
};
use lance_file::version::LanceFileVersion;
use lance_index::optimize::OptimizeOptions;
use lance_index::scalar::inverted::query::{MatchQuery, PhraseQuery};
use lance_index::vector::hnsw::builder::HnswBuildParams;
use lance_index::vector::ivf::IvfBuildParams;
use lance_index::vector::pq::PQBuildParams;
use lance_index::vector::sq::builder::SQBuildParams;
use lance_index::{IndexType, scalar::ScalarIndexParams};
use lance_io::assert_io_gt;
use lance_io::object_store::ObjectStoreParams;
use lance_linalg::distance::DistanceType;
use lance_testing::datagen::{BatchGenerator, IncrementingInt32, RandomVector};
use object_store::throttle::ThrottleConfig;
use rstest::rstest;
use super::*;
use crate::dataset::WriteMode;
use crate::dataset::WriteParams;
use crate::dataset::optimize::{CompactionOptions, compact_files};
use crate::dataset::scanner::test_dataset::TestVectorDataset;
use crate::index::vector::{StageParams, VectorIndexParams};
use crate::utils::test::{
DatagenExt, FragmentCount, FragmentRowCount, ThrottledStoreWrapper, assert_plan_node_equals,
};
#[test]
fn test_env_var_parsing() {
unsafe {
std::env::set_var("LANCE_DEFAULT_BATCH_SIZE", "not_a_number");
}
let result = get_default_batch_size();
assert_eq!(result, None, "Should return None for invalid batch size");
unsafe {
std::env::set_var("LANCE_DEFAULT_BATCH_SIZE", "2048");
}
let result = get_default_batch_size();
assert_eq!(result, Some(2048), "Should parse valid batch size");
unsafe {
std::env::remove_var("LANCE_DEFAULT_BATCH_SIZE");
}
let result = get_default_batch_size();
assert_eq!(result, None, "Should return None when env var is not set");
}
#[test]
fn test_parse_env_var() {
let test_var = "LANCE_TEST_PARSE_ENV_VAR_USIZE";
unsafe {
std::env::set_var(test_var, "12345");
}
let result: Option<usize> = parse_env_var(test_var, "Using default.");
assert_eq!(result, Some(12345));
unsafe {
std::env::set_var(test_var, "not_a_number");
}
let result: Option<usize> = parse_env_var(test_var, "Using default.");
assert_eq!(result, None);
unsafe {
std::env::remove_var(test_var);
}
let result: Option<usize> = parse_env_var(test_var, "Using default.");
assert_eq!(result, None);
let test_var_u32 = "LANCE_TEST_PARSE_ENV_VAR_U32";
unsafe {
std::env::set_var(test_var_u32, "42");
}
let result: Option<u32> = parse_env_var(test_var_u32, "Using default value.");
assert_eq!(result, Some(42));
unsafe {
std::env::set_var(test_var_u32, "invalid");
}
let result: Option<u32> = parse_env_var(test_var_u32, "Using default value.");
assert_eq!(result, None);
unsafe {
std::env::remove_var(test_var_u32);
}
let test_var_u64 = "LANCE_TEST_PARSE_ENV_VAR_U64";
unsafe {
std::env::set_var(test_var_u64, "9999999999");
}
let result: Option<u64> = parse_env_var(test_var_u64, "Using default value.");
assert_eq!(result, Some(9999999999));
unsafe {
std::env::set_var(test_var_u64, "-1");
}
let result: Option<u64> = parse_env_var(test_var_u64, "Using default value.");
assert_eq!(result, None);
unsafe {
std::env::remove_var(test_var_u64);
}
}
async fn make_binary_vector_dataset() -> Result<(TempStrDir, Dataset)> {
let tmp_dir = TempStrDir::default();
let dim = 4;
let schema = Arc::new(ArrowSchema::new(vec![
ArrowField::new("id", DataType::Int32, false),
ArrowField::new(
"bin",
DataType::FixedSizeList(
Arc::new(ArrowField::new("item", DataType::UInt8, true)),
dim,
),
false,
),
]));
let vectors = FixedSizeListArray::try_new_from_values(
UInt8Array::from(vec![
// Three vectors of dimension 4, one row per line.
0b0000_1111u8, 0, 0, 0,
0b0000_0011u8, 0, 0, 0,
0u8, 0, 0, 0,
dim,
)?;
let ids = Int32Array::from(vec![0, 1, 2]);
let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(ids), Arc::new(vectors)])?;
let reader = RecordBatchIterator::new(vec![Ok(batch)], schema.clone());
Dataset::write(reader, &tmp_dir, None).await?;
let dataset = Dataset::open(&tmp_dir).await?;
Ok((tmp_dir, dataset))
}
#[tokio::test]
async fn test_batch_size() {
let schema = Arc::new(ArrowSchema::new(vec![
ArrowField::new("i", DataType::Int32, true),
ArrowField::new("s", DataType::Utf8, true),
]));
let batches: Vec<RecordBatch> = (0..5)
.map(|i| {
RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(Int32Array::from_iter_values(i * 20..(i + 1) * 20)),
Arc::new(StringArray::from_iter_values(
(i * 20..(i + 1) * 20).map(|v| format!("s-{}", v)),
)),
],
)
.unwrap()
})
.collect();
for use_filter in [false, true] {
let test_dir = TempStrDir::default();
let test_uri = &test_dir;
let write_params = WriteParams {
max_rows_per_file: 40,
max_rows_per_group: 10,
..Default::default()
};
let batches =
RecordBatchIterator::new(batches.clone().into_iter().map(Ok), schema.clone());
Dataset::write(batches, test_uri, Some(write_params))
.await
.unwrap();
let dataset = Dataset::open(test_uri).await.unwrap();
let mut builder = dataset.scan();
builder.batch_size(8);
if use_filter {
builder.filter("i IS NOT NULL").unwrap();
}
let mut stream = builder.try_into_stream().await.unwrap();
let mut rows_read = 0;
while let Some(next) = stream.next().await {
let next = next.unwrap();
let expected = 8.min(100 - rows_read);
assert_eq!(next.num_rows(), expected);
rows_read += next.num_rows();
}
}
}
#[tokio::test]
async fn test_strict_batch_size() {
let dataset = lance_datagen::gen_batch()
.col("x", array::step::<Int32Type>())
.anon_col(array::step::<Int64Type>())
.into_ram_dataset(FragmentCount::from(7), FragmentRowCount::from(6))
.await
.unwrap();
let mut scan = dataset.scan();
scan.batch_size(10)
.strict_batch_size(true)
.filter("x % 2 == 0")
.unwrap();
let batches = scan
.try_into_stream()
.await
.unwrap()
.try_collect::<Vec<_>>()
.await
.unwrap();
let batch_sizes = batches.iter().map(|b| b.num_rows()).collect::<Vec<_>>();
assert_eq!(batch_sizes, vec![10, 10, 1]);
}
#[tokio::test]
async fn test_column_not_exist() {
let dataset = lance_datagen::gen_batch()
.col("x", array::step::<Int32Type>())
.into_ram_dataset(FragmentCount::from(7), FragmentRowCount::from(6))
.await
.unwrap();
let check_err_msg = |r: Result<DatasetRecordBatchStream>| {
let Err(err) = r else {
panic!(
"Expected an error to be raised saying column y is not found but got no error"
)
};
assert!(
err.to_string().contains("No field named y"),
"Expected error to contain 'No field named y' but got {}",
err
);
};
let mut scan = dataset.scan();
scan.project(&["x", "y"]).unwrap();
check_err_msg(scan.try_into_stream().await);
let mut scan = dataset.scan();
scan.project(&["y"]).unwrap();
check_err_msg(scan.try_into_stream().await);
let mut scan = dataset.scan();
scan.project_with_transform(&[("foo", "1")]).unwrap();
match scan.try_into_stream().await {
Ok(_) => panic!("Expected an error to be raised saying not supported"),
Err(e) => {
assert!(
e.to_string().contains("Received only dynamic expressions"),
"Expected error to contain 'Received only dynamic expressions' but got {}",
e
);
}
}
}
#[cfg(not(windows))]
#[tokio::test]
async fn test_local_object_store() {
let schema = Arc::new(ArrowSchema::new(vec![
ArrowField::new("i", DataType::Int32, true),
ArrowField::new("s", DataType::Utf8, true),
]));
let batches: Vec<RecordBatch> = (0..5)
.map(|i| {
RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(Int32Array::from_iter_values(i * 20..(i + 1) * 20)),
Arc::new(StringArray::from_iter_values(
(i * 20..(i + 1) * 20).map(|v| format!("s-{}", v)),
)),
],
)
.unwrap()
})
.collect();
let test_dir = TempStrDir::default();
let test_uri = &test_dir;
let write_params = WriteParams {
max_rows_per_file: 40,
max_rows_per_group: 10,
..Default::default()
};
let batches = RecordBatchIterator::new(batches.clone().into_iter().map(Ok), schema.clone());
Dataset::write(batches, test_uri, Some(write_params))
.await
.unwrap();
let dataset = Dataset::open(&format!("file-object-store://{}", test_uri))
.await
.unwrap();
let mut builder = dataset.scan();
builder.batch_size(8);
let mut stream = builder.try_into_stream().await.unwrap();
let mut rows_read = 0;
while let Some(next) = stream.next().await {
let next = next.unwrap();
let expected = 8.min(100 - rows_read);
assert_eq!(next.num_rows(), expected);
rows_read += next.num_rows();
}
}
#[tokio::test]
async fn test_filter_parsing() -> Result<()> {
let test_ds = TestVectorDataset::new(LanceFileVersion::Stable, false).await?;
let dataset = &test_ds.dataset;
let mut scan = dataset.scan();
assert!(scan.filter.is_none());
scan.filter("i > 50")?;
assert_eq!(scan.get_expr_filter().unwrap(), Some(col("i").gt(lit(50))));
for use_stats in [false, true] {
let batches = scan
.project(&["s"])?
.use_stats(use_stats)
.try_into_stream()
.await?
.try_collect::<Vec<_>>()
.await?;
let batch = concat_batches(&batches[0].schema(), &batches)?;
let expected_batch = RecordBatch::try_new(
Arc::new(test_ds.schema.project(&[1])?),
vec![Arc::new(StringArray::from_iter_values(
(51..400).map(|v| format!("s-{}", v)),
))],
)?;
assert_eq!(batch, expected_batch);
}
Ok(())
}
#[tokio::test]
async fn test_scan_regexp_match_and_non_empty_captions() {
let schema = Arc::new(ArrowSchema::new(vec![
ArrowField::new("keywords", DataType::Utf8, true),
ArrowField::new("natural_caption", DataType::Utf8, true),
ArrowField::new("poetic_caption", DataType::Utf8, true),
]));
let batch = RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(StringArray::from(vec![
Some("Liberty for all"),
Some("peace"),
Some("revolution now"),
Some("Liberty"),
Some("revolutionary"),
Some("none"),
])) as ArrayRef,
Arc::new(StringArray::from(vec![
Some("a"),
Some("b"),
None,
Some(""),
Some("c"),
Some("d"),
])) as ArrayRef,
Arc::new(StringArray::from(vec![
Some("x"),
Some(""),
Some("y"),
Some("z"),
None,
Some("w"),
])) as ArrayRef,
],
)
.unwrap();
let reader = RecordBatchIterator::new(vec![Ok(batch.clone())], schema.clone());
let dataset = Dataset::write(reader, "memory://", None).await.unwrap();
let mut scan = dataset.scan();
scan.filter(
"regexp_match(keywords, 'Liberty|revolution') AND \
(natural_caption IS NOT NULL AND natural_caption <> '' AND \
poetic_caption IS NOT NULL AND poetic_caption <> '')",
)
.unwrap();
let out = scan.try_into_batch().await.unwrap();
assert_eq!(out.num_rows(), 1);
let out_keywords = out
.column_by_name("keywords")
.unwrap()
.as_string::<i32>()
.value(0);
let out_nat = out
.column_by_name("natural_caption")
.unwrap()
.as_string::<i32>()
.value(0);
let out_poetic = out
.column_by_name("poetic_caption")
.unwrap()
.as_string::<i32>()
.value(0);
assert_eq!(out_keywords, "Liberty for all");
assert_eq!(out_nat, "a");
assert_eq!(out_poetic, "x");
}
#[tokio::test]
async fn test_nested_projection() {
let point_fields: Fields = vec![
ArrowField::new("x", DataType::Float32, true),
ArrowField::new("y", DataType::Float32, true),
]
.into();
let metadata_fields: Fields = vec![
ArrowField::new("location", DataType::Struct(point_fields), true),
ArrowField::new("age", DataType::Int32, true),
]
.into();
let metadata_field = ArrowField::new("metadata", DataType::Struct(metadata_fields), true);
let schema = Arc::new(ArrowSchema::new(vec![
metadata_field,
ArrowField::new("idx", DataType::Int32, true),
]));
let data = lance_datagen::rand(&schema)
.into_ram_dataset(FragmentCount::from(7), FragmentRowCount::from(6))
.await
.unwrap();
let mut scan = data.scan();
scan.project(&["metadata.location.x", "metadata.age"])
.unwrap();
let batch = scan.try_into_batch().await.unwrap();
assert_eq!(
batch.schema().as_ref(),
&ArrowSchema::new(vec![
ArrowField::new("metadata.location.x", DataType::Float32, true),
ArrowField::new("metadata.age", DataType::Int32, true),
])
);
let take_schema = data.schema().project_by_ids(&[0, 2, 4], false);
let taken = data.take_rows(&[0, 5], take_schema).await.unwrap();
let part_point_fields = Fields::from(vec![ArrowField::new("x", DataType::Float32, true)]);
let part_metadata_fields = Fields::from(vec![
ArrowField::new("location", DataType::Struct(part_point_fields), true),
ArrowField::new("age", DataType::Int32, true),
]);
let part_schema = Arc::new(ArrowSchema::new(vec![ArrowField::new(
"metadata",
DataType::Struct(part_metadata_fields),
true,
)]));
assert_eq!(taken.schema(), part_schema);
}
#[rstest]
#[tokio::test]
async fn test_limit(
#[values(LanceFileVersion::Legacy, LanceFileVersion::Stable)]
data_storage_version: LanceFileVersion,
) -> Result<()> {
let test_ds = TestVectorDataset::new(data_storage_version, false).await?;
let dataset = &test_ds.dataset;
let full_data = dataset.scan().try_into_batch().await?.slice(19, 2);
let actual = dataset
.scan()
.limit(Some(2), Some(19))?
.try_into_batch()
.await?;
assert_eq!(actual.num_rows(), 2);
assert_eq!(actual, full_data);
Ok(())
}
#[test_log::test(tokio::test)]
async fn test_limit_cancel() {
let throttled = Arc::new(ThrottledStoreWrapper {
config: ThrottleConfig {
wait_get_per_call: Duration::from_secs(1),
..Default::default()
},
});
let write_params = WriteParams {
store_params: Some(ObjectStoreParams {
object_store_wrapper: Some(throttled.clone()),
..Default::default()
}),
max_rows_per_file: 1,
..Default::default()
};
let dataset = gen_batch()
.col("i", array::step::<Int32Type>().with_random_nulls(0.1))
.into_ram_dataset_with_params(
FragmentCount::from(2000),
FragmentRowCount::from(1),
Some(write_params),
)
.await
.unwrap();
let mut scan = dataset.scan();
scan.filter("i IS NOT NULL").unwrap();
scan.limit(Some(10), None).unwrap();
let start = Instant::now();
scan.try_into_stream()
.await
.unwrap()
.try_collect::<Vec<_>>()
.await
.unwrap();
let duration = start.elapsed();
assert!(duration < Duration::from_secs(10));
}
#[rstest]
#[tokio::test]
async fn test_knn_nodes(
#[values(LanceFileVersion::Legacy, LanceFileVersion::Stable)]
data_storage_version: LanceFileVersion,
#[values(false, true)] stable_row_ids: bool,
#[values(false, true)] build_index: bool,
) {
let mut test_ds = TestVectorDataset::new(data_storage_version, stable_row_ids)
.await
.unwrap();
if build_index {
test_ds.make_vector_index().await.unwrap();
}
let dataset = &test_ds.dataset;
let mut scan = dataset.scan();
let key: Float32Array = (32..64).map(|v| v as f32).collect();
scan.nearest("vec", &key, 5).unwrap();
scan.refine(5);
let batch = scan.try_into_batch().await.unwrap();
assert_eq!(batch.num_rows(), 5);
assert_eq!(
batch.schema().as_ref(),
&ArrowSchema::new(vec![
ArrowField::new("i", DataType::Int32, true),
ArrowField::new("s", DataType::Utf8, true),
ArrowField::new(
"vec",
DataType::FixedSizeList(
Arc::new(ArrowField::new("item", DataType::Float32, true)),
32,
),
true,
),
ArrowField::new(DIST_COL, DataType::Float32, true),
])
.with_metadata([("dataset".into(), "vector".into())].into())
);
let expected_i = BTreeSet::from_iter(vec![1, 81, 161, 241, 321]);
let column_i = batch.column_by_name("i").unwrap();
let actual_i: BTreeSet<i32> = as_primitive_array::<Int32Type>(column_i.as_ref())
.values()
.iter()
.copied()
.collect();
assert_eq!(expected_i, actual_i);
}
#[rstest]
#[tokio::test]
async fn test_can_project_distance() {
let test_ds = TestVectorDataset::new(LanceFileVersion::Stable, true)
.await
.unwrap();
let dataset = &test_ds.dataset;
let mut scan = dataset.scan();
let key: Float32Array = (32..64).map(|v| v as f32).collect();
scan.nearest("vec", &key, 5).unwrap();
scan.refine(5);
scan.project(&["_distance"]).unwrap();
let batch = scan.try_into_batch().await.unwrap();
assert_eq!(batch.num_rows(), 5);
assert_eq!(batch.num_columns(), 1);
}
#[rstest]
#[tokio::test]
async fn test_knn_with_new_data(
#[values(LanceFileVersion::Legacy, LanceFileVersion::Stable)]
data_storage_version: LanceFileVersion,
#[values(false, true)] stable_row_ids: bool,
) {
let mut test_ds = TestVectorDataset::new(data_storage_version, stable_row_ids)
.await
.unwrap();
test_ds.make_vector_index().await.unwrap();
test_ds.append_new_data().await.unwrap();
let dataset = &test_ds.dataset;
let key: Float32Array = [0f32; 32].into_iter().collect();
let k = 20;
#[derive(Debug)]
struct TestCase {
filter: Option<&'static str>,
limit: Option<i64>,
use_index: bool,
}
let mut cases = vec![];
for filter in [Some("i > 100"), None] {
for limit in [None, Some(10)] {
for use_index in [true, false] {
cases.push(TestCase {
filter,
limit,
use_index,
});
}
}
}
for case in cases {
let mut scanner = dataset.scan();
scanner
.nearest("vec", &key, k)
.unwrap()
.limit(case.limit, None)
.unwrap()
.refine(3)
.use_index(case.use_index);
if let Some(filter) = case.filter {
scanner.filter(filter).unwrap();
}
let result = scanner
.try_into_stream()
.await
.unwrap()
.try_collect::<Vec<_>>()
.await
.unwrap();
assert!(!result.is_empty());
let result = concat_batches(&result[0].schema(), result.iter()).unwrap();
if case.filter.is_some() {
let result_rows = result.num_rows();
let expected_rows = case.limit.unwrap_or(k as i64) as usize;
assert!(
result_rows <= expected_rows,
"Expected less than {} rows, got {}",
expected_rows,
result_rows
);
} else {
assert_eq!(result.num_rows(), case.limit.unwrap_or(k as i64) as usize);
}
assert_eq!(
as_primitive_array::<Int32Type>(result.column(0).as_ref()).value(0),
400
);
}
}
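// With prefilter(true) the filter runs before the vector search, so the top-k
// results are drawn only from rows matching `i > 100`.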
#[rstest]
#[tokio::test]
async fn test_knn_with_prefilter(
#[values(LanceFileVersion::Legacy, LanceFileVersion::Stable)]
data_storage_version: LanceFileVersion,
#[values(false, true)] stable_row_ids: bool,
) {
let mut test_ds = TestVectorDataset::new(data_storage_version, stable_row_ids)
.await
.unwrap();
test_ds.make_vector_index().await.unwrap();
let dataset = &test_ds.dataset;
let mut scan = dataset.scan();
let key: Float32Array = (32..64).map(|v| v as f32).collect();
scan.filter("i > 100").unwrap();
scan.prefilter(true);
scan.project(&["i", "vec"]).unwrap();
scan.nearest("vec", &key, 5).unwrap();
scan.use_index(false);
let results = scan
.try_into_stream()
.await
.unwrap()
.try_collect::<Vec<_>>()
.await
.unwrap();
assert_eq!(results.len(), 1);
let batch = &results[0];
assert_eq!(batch.num_rows(), 5);
assert_eq!(
batch.schema().as_ref(),
&ArrowSchema::new(vec![
ArrowField::new("i", DataType::Int32, true),
ArrowField::new(
"vec",
DataType::FixedSizeList(
Arc::new(ArrowField::new("item", DataType::Float32, true)),
32,
),
true,
),
ArrowField::new(DIST_COL, DataType::Float32, true),
])
.with_metadata([("dataset".into(), "vector".into())].into())
);
let exact_i = BTreeSet::from_iter(vec![161, 241, 321]);
let close_i = BTreeSet::from_iter(vec![161, 241, 321, 160, 162, 240, 242, 320, 322]);
let column_i = batch.column_by_name("i").unwrap();
let actual_i: BTreeSet<i32> = as_primitive_array::<Int32Type>(column_i.as_ref())
.values()
.iter()
.copied()
.collect();
assert!(exact_i.is_subset(&actual_i));
assert!(actual_i.is_subset(&close_i));
}
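// A filter that excludes a row living in the unindexed (freshly appended) data
// must be honored both as a prefilter and as a postfilter.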
#[rstest]
#[tokio::test]
async fn test_knn_filter_new_data(
#[values(LanceFileVersion::Legacy, LanceFileVersion::Stable)]
data_storage_version: LanceFileVersion,
#[values(false, true)] stable_row_ids: bool,
) {
let mut test_ds = TestVectorDataset::new(data_storage_version, stable_row_ids)
.await
.unwrap();
test_ds.make_vector_index().await.unwrap();
test_ds.append_new_data().await.unwrap();
let dataset = &test_ds.dataset;
let key: Float32Array = [0f32; 32].into_iter().collect();
let mut query = dataset.scan();
query.nearest("vec", &key, 20).unwrap();
let results = query
.try_into_stream()
.await
.unwrap()
.try_collect::<Vec<_>>()
.await
.unwrap();
let results_i = results[0]["i"]
.as_primitive::<Int32Type>()
.values()
.iter()
.copied()
.collect::<BTreeSet<_>>();
assert!(results_i.contains(&400));
for prefilter in [false, true] {
let mut query = dataset.scan();
query
.filter("i != 400")
.unwrap()
.prefilter(prefilter)
.nearest("vec", &key, 20)
.unwrap();
let results = query
.try_into_stream()
.await
.unwrap()
.try_collect::<Vec<_>>()
.await
.unwrap();
let results_i = results[0]["i"]
.as_primitive::<Int32Type>()
.values()
.iter()
.copied()
.collect::<BTreeSet<_>>();
assert!(!results_i.contains(&400));
}
}
#[rstest]
#[tokio::test]
async fn test_knn_with_filter(
#[values(LanceFileVersion::Legacy, LanceFileVersion::Stable)]
data_storage_version: LanceFileVersion,
#[values(false, true)] stable_row_ids: bool,
) {
let test_ds = TestVectorDataset::new(data_storage_version, stable_row_ids)
.await
.unwrap();
let dataset = &test_ds.dataset;
let mut scan = dataset.scan();
let key: Float32Array = (32..64).map(|v| v as f32).collect();
scan.nearest("vec", &key, 5).unwrap();
scan.filter("i > 100").unwrap();
scan.project(&["i", "vec"]).unwrap();
scan.refine(5);
let results = scan
.try_into_stream()
.await
.unwrap()
.try_collect::<Vec<_>>()
.await
.unwrap();
assert_eq!(results.len(), 1);
let batch = &results[0];
assert_eq!(batch.num_rows(), 3);
assert_eq!(
batch.schema().as_ref(),
&ArrowSchema::new(vec![
ArrowField::new("i", DataType::Int32, true),
ArrowField::new(
"vec",
DataType::FixedSizeList(
Arc::new(ArrowField::new("item", DataType::Float32, true)),
32,
),
true,
),
ArrowField::new(DIST_COL, DataType::Float32, true),
])
.with_metadata([("dataset".into(), "vector".into())].into())
);
let expected_i = BTreeSet::from_iter(vec![161, 241, 321]);
let column_i = batch.column_by_name("i").unwrap();
let actual_i: BTreeSet<i32> = as_primitive_array::<Int32Type>(column_i.as_ref())
.values()
.iter()
.copied()
.collect();
assert_eq!(expected_i, actual_i);
}
#[rstest]
#[tokio::test]
async fn test_refine_factor(
#[values(LanceFileVersion::Legacy, LanceFileVersion::Stable)]
data_storage_version: LanceFileVersion,
#[values(false, true)] stable_row_ids: bool,
) {
let test_ds = TestVectorDataset::new(data_storage_version, stable_row_ids)
.await
.unwrap();
let dataset = &test_ds.dataset;
let mut scan = dataset.scan();
let key: Float32Array = (32..64).map(|v| v as f32).collect();
scan.nearest("vec", &key, 5).unwrap();
scan.refine(5);
let results = scan
.try_into_stream()
.await
.unwrap()
.try_collect::<Vec<_>>()
.await
.unwrap();
assert_eq!(results.len(), 1);
let batch = &results[0];
assert_eq!(batch.num_rows(), 5);
assert_eq!(
batch.schema().as_ref(),
&ArrowSchema::new(vec![
ArrowField::new("i", DataType::Int32, true),
ArrowField::new("s", DataType::Utf8, true),
ArrowField::new(
"vec",
DataType::FixedSizeList(
Arc::new(ArrowField::new("item", DataType::Float32, true)),
32,
),
true,
),
ArrowField::new(DIST_COL, DataType::Float32, true),
])
.with_metadata([("dataset".into(), "vector".into())].into())
);
let expected_i = BTreeSet::from_iter(vec![1, 81, 161, 241, 321]);
let column_i = batch.column_by_name("i").unwrap();
let actual_i: BTreeSet<i32> = as_primitive_array::<Int32Type>(column_i.as_ref())
.values()
.iter()
.copied()
.collect();
assert_eq!(expected_i, actual_i);
}
#[tokio::test]
async fn test_binary_vectors_default_to_hamming() {
let (_tmp_dir, dataset) = make_binary_vector_dataset().await.unwrap();
let query = UInt8Array::from(vec![0b0000_1111u8, 0, 0, 0]);
let mut scan = dataset.scan();
scan.nearest("bin", &query, 3).unwrap();
assert_eq!(scan.nearest.as_ref().unwrap().metric_type, None);
let batch = scan.try_into_batch().await.unwrap();
let ids = batch
.column_by_name("id")
.unwrap()
.as_primitive::<Int32Type>()
.values();
assert_eq!(ids, &[0, 1, 2]);
let distances = batch
.column_by_name(DIST_COL)
.unwrap()
.as_primitive::<Float32Type>()
.values();
assert_eq!(distances, &[0.0, 2.0, 4.0]);
}
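// Requesting a float-only metric (L2) against a binary (UInt8) vector column
// should fail with an InvalidInput error that names both the metric and type.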
#[tokio::test]
async fn test_binary_vectors_invalid_distance_error() {
let (_tmp_dir, dataset) = make_binary_vector_dataset().await.unwrap();
let query = UInt8Array::from(vec![0b0000_1111u8, 0, 0, 0]);
let mut scan = dataset.scan();
scan.nearest("bin", &query, 1).unwrap();
scan.distance_metric(DistanceType::L2);
let err = scan.try_into_batch().await.unwrap_err();
assert!(matches!(err, Error::InvalidInput { .. }));
let message = err.to_string();
assert!(
message.contains("l2") && message.contains("UInt8"),
"unexpected message: {message}"
);
}
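// If the requested distance metric differs from the one the index was built
// with, the planner is expected to skip the ANN index and run a flat search
// with the requested metric.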
#[tokio::test]
async fn test_knn_metric_mismatch_falls_back_to_flat_search() {
let mut test_ds = TestVectorDataset::new(LanceFileVersion::Stable, true)
.await
.unwrap();
test_ds.make_vector_index().await.unwrap();
let dataset = &test_ds.dataset;
let key: Float32Array = (32..64).map(|v| v as f32).collect();
let mut scan = dataset.scan();
scan.nearest("vec", &key, 5).unwrap();
scan.distance_metric(DistanceType::Dot);
let plan = scan.explain_plan(false).await.unwrap();
assert!(
!plan.contains("ANNSubIndex"),
"Expected flat search, but got ANN index in plan:\n{}",
plan
);
assert!(
plan.contains("KNNVectorDistance") && plan.to_lowercase().contains("dot"),
"Expected flat KNN with Dot metric in plan:\n{}",
plan
);
let dot_batch = dataset
.scan()
.nearest("vec", &key, 5)
.unwrap()
.distance_metric(DistanceType::Dot)
.try_into_batch()
.await
.unwrap();
let l2_batch = dataset
.scan()
.nearest("vec", &key, 5)
.unwrap()
.distance_metric(DistanceType::L2)
.try_into_batch()
.await
.unwrap();
let dot_distances: Vec<f32> = dot_batch
.column_by_name(DIST_COL)
.unwrap()
.as_primitive::<Float32Type>()
.values()
.to_vec();
let l2_distances: Vec<f32> = l2_batch
.column_by_name(DIST_COL)
.unwrap()
.as_primitive::<Float32Type>()
.values()
.to_vec();
assert_ne!(dot_distances, l2_distances);
}
#[tokio::test]
async fn test_knn_no_metric_uses_index_metric() {
let mut test_ds = TestVectorDataset::new(LanceFileVersion::Stable, true)
.await
.unwrap();
test_ds.make_vector_index().await.unwrap();
let dataset = &test_ds.dataset;
let key: Float32Array = (32..64).map(|v| v as f32).collect();
let mut scan = dataset.scan();
scan.nearest("vec", &key, 5).unwrap();
let plan = scan.explain_plan(false).await.unwrap();
assert!(
plan.contains("ANNSubIndex") && plan.to_lowercase().contains("l2"),
"Expected ANN index with L2 metric in plan:\n{}",
plan
);
}
#[rstest]
#[tokio::test]
async fn test_only_row_id(
#[values(LanceFileVersion::Legacy, LanceFileVersion::Stable)]
data_storage_version: LanceFileVersion,
) {
let test_ds = TestVectorDataset::new(data_storage_version, false)
.await
.unwrap();
let dataset = &test_ds.dataset;
let mut scan = dataset.scan();
scan.project::<&str>(&[]).unwrap().with_row_id();
let batch = scan.try_into_batch().await.unwrap();
assert_eq!(batch.num_columns(), 1);
assert_eq!(batch.num_rows(), 400);
let expected_schema =
ArrowSchema::new(vec![ArrowField::new(ROW_ID, DataType::UInt64, true)])
.with_metadata(dataset.schema().metadata.clone());
assert_eq!(batch.schema().as_ref(), &expected_schema);
let expected_row_ids: Vec<u64> = (0..200_u64).chain((1 << 32)..((1 << 32) + 200)).collect();
let actual_row_ids: Vec<u64> = as_primitive_array::<UInt64Type>(batch.column(0).as_ref())
.values()
.iter()
.copied()
.collect();
assert_eq!(expected_row_ids, actual_row_ids);
}
#[tokio::test]
async fn test_scan_unordered_with_row_id() {
let test_ds = TestVectorDataset::new(LanceFileVersion::Legacy, false)
.await
.unwrap();
let dataset = &test_ds.dataset;
let mut scan = dataset.scan();
scan.with_row_id();
let ordered_batches = scan
.try_into_stream()
.await
.unwrap()
.try_collect::<Vec<RecordBatch>>()
.await
.unwrap();
assert!(ordered_batches.len() > 2);
let ordered_batch =
concat_batches(&ordered_batches[0].schema(), ordered_batches.iter()).unwrap();
scan.scan_in_order(false);
for _ in 0..10 {
let unordered_batches = scan
.try_into_stream()
.await
.unwrap()
.try_collect::<Vec<RecordBatch>>()
.await
.unwrap();
let unordered_batch =
concat_batches(&unordered_batches[0].schema(), unordered_batches.iter()).unwrap();
assert_eq!(ordered_batch.num_rows(), unordered_batch.num_rows());
if ordered_batch != unordered_batch {
let sort_indices = sort_to_indices(&unordered_batch[ROW_ID], None, None).unwrap();
let ordered_i = ordered_batch["i"].clone();
let sorted_i = take::take(&unordered_batch["i"], &sort_indices, None).unwrap();
assert_eq!(&ordered_i, &sorted_i);
break;
}
}
}
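// Wildcard projection: `*` expands to all of the dataset's columns and can be
// combined with pseudo-columns such as _rowid and _rowoffset, but not with an
// explicitly listed regular column; duplicate columns are rejected.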
#[tokio::test]
async fn test_scan_with_wildcard() {
let data = gen_batch()
.col("x", array::step::<Float64Type>())
.col("y", array::step::<Float64Type>())
.into_ram_dataset(FragmentCount::from(1), FragmentRowCount::from(100))
.await
.unwrap();
let check_cols = async |projection: &[&str], expected_cols: &[&str]| {
let mut scan = data.scan();
scan.project(projection).unwrap();
let stream = scan.try_into_stream().await.unwrap();
let schema = stream.schema();
let field_names = schema.field_names();
assert_eq!(field_names, expected_cols);
};
check_cols(&["*"], &["x", "y"]).await;
check_cols(&["x", "y"], &["x", "y"]).await;
check_cols(&["x"], &["x"]).await;
check_cols(&["_rowid", "*"], &["_rowid", "x", "y"]).await;
check_cols(&["*", "_rowid"], &["x", "y", "_rowid"]).await;
check_cols(
&["_rowid", "*", "_rowoffset"],
&["_rowid", "x", "y", "_rowoffset"],
)
.await;
let check_exprs = async |exprs: &[&str], expected_cols: &[&str]| {
let mut scan = data.scan();
let projection = exprs
.iter()
.map(|e| (e.to_string(), e.to_string()))
.collect::<Vec<_>>();
scan.project_with_transform(&projection).unwrap();
let stream = scan.try_into_stream().await.unwrap();
let schema = stream.schema();
let field_names = schema.field_names();
assert_eq!(field_names, expected_cols);
};
check_exprs(&["_rowid", "*", "x * 2"], &["_rowid", "x", "y", "x * 2"]).await;
let check_fails = |projection: &[&str]| {
let mut scan = data.scan();
assert!(scan.project(projection).is_err());
};
check_fails(&["x", "*"]);
check_fails(&["_rowid", "_rowid"]);
}
#[rstest]
#[tokio::test]
async fn test_scan_order(
#[values(LanceFileVersion::Legacy, LanceFileVersion::Stable)]
data_storage_version: LanceFileVersion,
) {
let test_dir = TempStrDir::default();
let test_uri = &test_dir;
let schema = Arc::new(ArrowSchema::new(vec![ArrowField::new(
"i",
DataType::Int32,
true,
)]));
let batch1 = RecordBatch::try_new(
schema.clone(),
vec![Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]))],
)
.unwrap();
let batch2 = RecordBatch::try_new(
schema.clone(),
vec![Arc::new(Int32Array::from(vec![6, 7, 8]))],
)
.unwrap();
let params = WriteParams {
mode: WriteMode::Append,
data_storage_version: Some(data_storage_version),
..Default::default()
};
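// `Some(params)` moves `params` into the returned future, making this closure
// FnOnce; the first call below therefore clones the closure so that it can be
// invoked a second time.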
let write_batch = |batch: RecordBatch| async {
let reader = RecordBatchIterator::new(vec![Ok(batch)], schema.clone());
Dataset::write(reader, test_uri, Some(params)).await
};
write_batch.clone()(batch1.clone()).await.unwrap();
write_batch(batch2.clone()).await.unwrap();
let dataset = Arc::new(Dataset::open(test_uri).await.unwrap());
let fragment1 = dataset.get_fragment(0).unwrap().metadata().clone();
let fragment2 = dataset.get_fragment(1).unwrap().metadata().clone();
let mut scanner = dataset.scan();
scanner.with_fragments(vec![fragment1.clone(), fragment2.clone()]);
let output = scanner
.try_into_stream()
.await
.unwrap()
.try_collect::<Vec<_>>()
.await
.unwrap();
assert_eq!(output.len(), 2);
assert_eq!(output[0], batch1);
assert_eq!(output[1], batch2);
let mut scanner = dataset.scan();
scanner.with_fragments(vec![fragment2, fragment1]);
let output = scanner
.try_into_stream()
.await
.unwrap()
.try_collect::<Vec<_>>()
.await
.unwrap();
assert_eq!(output.len(), 2);
assert_eq!(output[0], batch2);
assert_eq!(output[1], batch1);
}
#[rstest]
#[tokio::test]
async fn test_scan_sort(
#[values(LanceFileVersion::Legacy, LanceFileVersion::Stable)]
data_storage_version: LanceFileVersion,
) {
let test_dir = TempStrDir::default();
let test_uri = &test_dir;
let data = gen_batch()
.col("int", array::cycle::<Int32Type>(vec![5, 4, 1, 2, 3]))
.col(
"str",
array::cycle_utf8_literals(&["a", "b", "c", "e", "d"]),
);
let sorted_by_int = gen_batch()
.col("int", array::cycle::<Int32Type>(vec![1, 2, 3, 4, 5]))
.col(
"str",
array::cycle_utf8_literals(&["c", "e", "d", "b", "a"]),
)
.into_batch_rows(RowCount::from(5))
.unwrap();
let sorted_by_str = gen_batch()
.col("int", array::cycle::<Int32Type>(vec![5, 4, 1, 3, 2]))
.col(
"str",
array::cycle_utf8_literals(&["a", "b", "c", "d", "e"]),
)
.into_batch_rows(RowCount::from(5))
.unwrap();
Dataset::write(
data.into_reader_rows(RowCount::from(5), BatchCount::from(1)),
test_uri,
Some(WriteParams {
data_storage_version: Some(data_storage_version),
..Default::default()
}),
)
.await
.unwrap();
let dataset = Arc::new(Dataset::open(test_uri).await.unwrap());
let batches_by_int = dataset
.scan()
.order_by(Some(vec![ColumnOrdering::asc_nulls_first(
"int".to_string(),
)]))
.unwrap()
.try_into_stream()
.await
.unwrap()
.try_collect::<Vec<_>>()
.await
.unwrap();
assert_eq!(batches_by_int[0], sorted_by_int);
let batches_by_str = dataset
.scan()
.order_by(Some(vec![ColumnOrdering::asc_nulls_first(
"str".to_string(),
)]))
.unwrap()
.try_into_stream()
.await
.unwrap()
.try_collect::<Vec<_>>()
.await
.unwrap();
assert_eq!(batches_by_str[0], sorted_by_str);
dataset
.scan()
.order_by(Some(vec![]))
.unwrap()
.try_into_stream()
.await
.unwrap()
.try_collect::<Vec<_>>()
.await
.unwrap();
}
#[rstest]
#[tokio::test]
async fn test_sort_multi_columns(
#[values(LanceFileVersion::Legacy, LanceFileVersion::Stable)]
data_storage_version: LanceFileVersion,
) {
let test_dir = TempStrDir::default();
let test_uri = &test_dir;
let data = gen_batch()
.col("int", array::cycle::<Int32Type>(vec![5, 5, 1, 1, 3]))
.col(
"float",
array::cycle::<Float32Type>(vec![7.3, -f32::NAN, f32::NAN, 4.3, f32::INFINITY]),
);
let sorted_by_int_then_float = gen_batch()
.col("int", array::cycle::<Int32Type>(vec![1, 1, 3, 5, 5]))
.col(
"float",
array::cycle::<Float32Type>(vec![4.3, f32::NAN, f32::INFINITY, -f32::NAN, 7.3]),
)
.into_batch_rows(RowCount::from(5))
.unwrap();
Dataset::write(
data.into_reader_rows(RowCount::from(5), BatchCount::from(1)),
test_uri,
Some(WriteParams {
data_storage_version: Some(data_storage_version),
..Default::default()
}),
)
.await
.unwrap();
let dataset = Arc::new(Dataset::open(test_uri).await.unwrap());
let batches_by_int_then_float = dataset
.scan()
.order_by(Some(vec![
ColumnOrdering::asc_nulls_first("int".to_string()),
ColumnOrdering::asc_nulls_first("float".to_string()),
]))
.unwrap()
.try_into_stream()
.await
.unwrap()
.try_collect::<Vec<_>>()
.await
.unwrap();
assert_eq!(batches_by_int_then_float[0], sorted_by_int_then_float);
}
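// With a postfilter, k=1 can come back empty (the single nearest row is then
// filtered away); with prefilter(true) the search only considers matching
// rows, so a result is guaranteed here.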
#[rstest]
#[tokio::test]
async fn test_ann_prefilter(
#[values(LanceFileVersion::Legacy, LanceFileVersion::Stable)]
data_storage_version: LanceFileVersion,
#[values(false, true)] stable_row_ids: bool,
#[values(
VectorIndexParams::ivf_pq(2, 8, 2, MetricType::L2, 2),
VectorIndexParams::with_ivf_hnsw_sq_params(
MetricType::L2,
IvfBuildParams::new(2),
HnswBuildParams::default(),
SQBuildParams::default()
)
)]
index_params: VectorIndexParams,
) {
use lance_arrow::{FixedSizeListArrayExt, fixed_size_list_type};
let test_dir = TempStrDir::default();
let test_uri = &test_dir;
let schema = Arc::new(ArrowSchema::new(vec![
ArrowField::new("filterable", DataType::Int32, true),
ArrowField::new("vector", fixed_size_list_type(2, DataType::Float32), true),
]));
let vector_values = Float32Array::from_iter_values((0..600).map(|x| x as f32));
let batches = vec![
RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(Int32Array::from_iter_values(0..300)),
Arc::new(FixedSizeListArray::try_new_from_values(vector_values, 2).unwrap()),
],
)
.unwrap(),
];
let write_params = WriteParams {
data_storage_version: Some(data_storage_version),
max_rows_per_file: 300,
enable_stable_row_ids: stable_row_ids,
..Default::default()
};
let batches = RecordBatchIterator::new(batches.into_iter().map(Ok), schema.clone());
let mut dataset = Dataset::write(batches, test_uri, Some(write_params))
.await
.unwrap();
dataset
.create_index(&["vector"], IndexType::Vector, None, &index_params, false)
.await
.unwrap();
let query_key = Arc::new(Float32Array::from_iter_values((0..2).map(|x| x as f32)));
let mut scan = dataset.scan();
scan.filter("filterable > 5").unwrap();
scan.nearest("vector", query_key.as_ref(), 1).unwrap();
scan.minimum_nprobes(100);
scan.ef(100);
scan.with_row_id();
let batches = scan
.try_into_stream()
.await
.unwrap()
.try_collect::<Vec<_>>()
.await
.unwrap();
assert_eq!(batches.len(), 0);
scan.prefilter(true);
let batches = scan
.try_into_stream()
.await
.unwrap()
.try_collect::<Vec<_>>()
.await
.unwrap();
assert_eq!(batches.len(), 1);
let first_match = batches[0][ROW_ID].as_primitive::<UInt64Type>().values()[0];
assert_eq!(6, first_match);
}
#[rstest]
#[tokio::test]
async fn test_filter_on_large_utf8(
#[values(LanceFileVersion::Legacy, LanceFileVersion::Stable)]
data_storage_version: LanceFileVersion,
) {
let test_dir = TempStrDir::default();
let test_uri = &test_dir;
let schema = Arc::new(ArrowSchema::new(vec![ArrowField::new(
"ls",
DataType::LargeUtf8,
true,
)]));
let batches = vec![
RecordBatch::try_new(
schema.clone(),
vec![Arc::new(LargeStringArray::from_iter_values(
(0..10).map(|v| format!("s-{}", v)),
))],
)
.unwrap(),
];
let write_params = WriteParams {
data_storage_version: Some(data_storage_version),
..Default::default()
};
let batches = RecordBatchIterator::new(batches.into_iter().map(Ok), schema.clone());
Dataset::write(batches, test_uri, Some(write_params))
.await
.unwrap();
let dataset = Dataset::open(test_uri).await.unwrap();
let mut scan = dataset.scan();
scan.filter("ls = 's-8'").unwrap();
let batches = scan
.try_into_stream()
.await
.unwrap()
.try_collect::<Vec<_>>()
.await
.unwrap();
let batch = &batches[0];
let expected = RecordBatch::try_new(
schema.clone(),
vec![Arc::new(LargeStringArray::from_iter_values(
(8..9).map(|v| format!("s-{}", v)),
))],
)
.unwrap();
assert_eq!(batch, &expected);
}
#[rstest]
#[tokio::test]
async fn test_filter_with_regex(
#[values(LanceFileVersion::Legacy, LanceFileVersion::Stable)]
data_storage_version: LanceFileVersion,
) {
let test_dir = TempStrDir::default();
let test_uri = &test_dir;
let schema = Arc::new(ArrowSchema::new(vec![ArrowField::new(
"ls",
DataType::Utf8,
true,
)]));
let batches = vec![
RecordBatch::try_new(
schema.clone(),
vec![Arc::new(StringArray::from_iter_values(
(0..20).map(|v| format!("s-{}", v)),
))],
)
.unwrap(),
];
let write_params = WriteParams {
data_storage_version: Some(data_storage_version),
..Default::default()
};
let batches = RecordBatchIterator::new(batches.into_iter().map(Ok), schema.clone());
Dataset::write(batches, test_uri, Some(write_params))
.await
.unwrap();
let dataset = Dataset::open(test_uri).await.unwrap();
let mut scan = dataset.scan();
scan.filter("regexp_match(ls, 's-1.')").unwrap();
let stream = scan.try_into_stream().await.unwrap();
let batches = stream.try_collect::<Vec<_>>().await.unwrap();
let batch = &batches[0];
let expected = RecordBatch::try_new(
schema.clone(),
vec![Arc::new(StringArray::from_iter_values(
(10..=19).map(|v| format!("s-{}", v)),
))],
)
.unwrap();
assert_eq!(batch, &expected);
}
#[tokio::test]
async fn test_filter_proj_bug() {
let struct_i_field = ArrowField::new("i", DataType::Int32, true);
let struct_o_field = ArrowField::new("o", DataType::Utf8, true);
let schema = Arc::new(ArrowSchema::new(vec![
ArrowField::new(
"struct",
DataType::Struct(vec![struct_i_field.clone(), struct_o_field.clone()].into()),
true,
),
ArrowField::new("s", DataType::Utf8, true),
]));
let input_batches: Vec<RecordBatch> = (0..5)
.map(|i| {
let struct_i_arr: Arc<Int32Array> =
Arc::new(Int32Array::from_iter_values(i * 20..(i + 1) * 20));
let struct_o_arr: Arc<StringArray> = Arc::new(StringArray::from_iter_values(
(i * 20..(i + 1) * 20).map(|v| format!("o-{:02}", v)),
));
RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(StructArray::from(vec![
(Arc::new(struct_i_field.clone()), struct_i_arr as ArrayRef),
(Arc::new(struct_o_field.clone()), struct_o_arr as ArrayRef),
])),
Arc::new(StringArray::from_iter_values(
(i * 20..(i + 1) * 20).map(|v| format!("s-{}", v)),
)),
],
)
.unwrap()
})
.collect();
let batches =
RecordBatchIterator::new(input_batches.clone().into_iter().map(Ok), schema.clone());
let test_dir = TempStrDir::default();
let test_uri = &test_dir;
let write_params = WriteParams {
max_rows_per_file: 40,
max_rows_per_group: 10,
data_storage_version: Some(LanceFileVersion::Legacy),
..Default::default()
};
Dataset::write(batches, test_uri, Some(write_params))
.await
.unwrap();
let dataset = Dataset::open(test_uri).await.unwrap();
let batches = dataset
.scan()
.filter("struct.i >= 20")
.unwrap()
.try_into_stream()
.await
.unwrap()
.try_collect::<Vec<_>>()
.await
.unwrap();
let batch = concat_batches(&batches[0].schema(), &batches).unwrap();
let expected_batch = concat_batches(&schema, &input_batches.as_slice()[1..]).unwrap();
assert_eq!(batch, expected_batch);
let batches = dataset
.scan()
.filter("struct.o >= 'o-20'")
.unwrap()
.try_into_stream()
.await
.unwrap()
.try_collect::<Vec<_>>()
.await
.unwrap();
let batch = concat_batches(&batches[0].schema(), &batches).unwrap();
assert_eq!(batch, expected_batch);
let batches = dataset
.scan()
.project(vec!["struct"].as_slice())
.unwrap()
.try_into_stream()
.await
.unwrap()
.try_collect::<Vec<_>>()
.await
.unwrap();
concat_batches(&batches[0].schema(), &batches).unwrap();
}
#[rstest]
#[tokio::test]
async fn test_ann_with_deletion(
#[values(LanceFileVersion::Legacy, LanceFileVersion::Stable)]
data_storage_version: LanceFileVersion,
#[values(false, true)] stable_row_ids: bool,
) {
let vec_params = vec![
VectorIndexParams::ivf_pq(4, 8, 2, MetricType::L2, 2),
];
for params in vec_params {
use lance_arrow::FixedSizeListArrayExt;
let test_dir = TempStrDir::default();
let test_uri = &test_dir;
let schema = Arc::new(ArrowSchema::new(vec![
ArrowField::new("i", DataType::Int32, true),
ArrowField::new(
"vec",
DataType::FixedSizeList(
Arc::new(ArrowField::new("item", DataType::Float32, true)),
32,
),
true,
),
]));
let vector_values: Float32Array =
(0..32 * 512).map(|v| (v / 32) as f32 + 1.0).collect();
let vectors = FixedSizeListArray::try_new_from_values(vector_values, 32).unwrap();
let batches = vec![
RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(Int32Array::from_iter_values(0..512)),
Arc::new(vectors.clone()),
],
)
.unwrap(),
];
let reader = RecordBatchIterator::new(batches.into_iter().map(Ok), schema.clone());
let mut dataset = Dataset::write(
reader,
test_uri,
Some(WriteParams {
data_storage_version: Some(data_storage_version),
enable_stable_row_ids: stable_row_ids,
..Default::default()
}),
)
.await
.unwrap();
assert_eq!(dataset.index_cache_entry_count().await, 0);
dataset
.create_index(
&["vec"],
IndexType::Vector,
Some("idx".to_string()),
&params,
true,
)
.await
.unwrap();
let mut scan = dataset.scan();
let key: Float32Array = (0..32).map(|_v| 1.0_f32).collect();
scan.nearest("vec", &key, 5).unwrap();
scan.refine(100);
scan.minimum_nprobes(100);
assert_eq!(dataset.index_cache_entry_count().await, 2);
let results = scan
.try_into_stream()
.await
.unwrap()
.try_collect::<Vec<_>>()
.await
.unwrap();
assert_eq!(
dataset.index_cache_entry_count().await,
5 + dataset.versions().await.unwrap().len()
);
assert_eq!(results.len(), 1);
let batch = &results[0];
let expected_i = BTreeSet::from_iter(vec![0, 1, 2, 3, 4]);
let column_i = batch.column_by_name("i").unwrap();
let actual_i: BTreeSet<i32> = as_primitive_array::<Int32Type>(column_i.as_ref())
.values()
.iter()
.copied()
.collect();
assert_eq!(expected_i, actual_i);
dataset.delete("i = 1").await.unwrap();
let mut scan = dataset.scan();
scan.nearest("vec", &key, 5).unwrap();
scan.refine(100);
scan.minimum_nprobes(100);
let results = scan
.try_into_stream()
.await
.unwrap()
.try_collect::<Vec<_>>()
.await
.unwrap();
assert_eq!(results.len(), 1);
let batch = &results[0];
let expected_i = BTreeSet::from_iter(vec![0, 2, 3, 4, 5]);
let column_i = batch.column_by_name("i").unwrap();
let actual_i: BTreeSet<i32> = as_primitive_array::<Int32Type>(column_i.as_ref())
.values()
.iter()
.copied()
.collect();
assert_eq!(expected_i, actual_i);
let batches = vec![
RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(Int32Array::from_iter_values(512..1024)),
Arc::new(vectors),
],
)
.unwrap(),
];
let reader = RecordBatchIterator::new(batches.into_iter().map(Ok), schema.clone());
let mut dataset = Dataset::write(
reader,
test_uri,
Some(WriteParams {
mode: WriteMode::Append,
data_storage_version: Some(data_storage_version),
..Default::default()
}),
)
.await
.unwrap();
dataset
.create_index(
&["vec"],
IndexType::Vector,
Some("idx".to_string()),
&params,
true,
)
.await
.unwrap();
dataset.delete("i < 512").await.unwrap();
let mut scan = dataset.scan();
scan.nearest("vec", &key, 5).unwrap();
scan.refine(100);
scan.minimum_nprobes(100);
let results = scan
.try_into_stream()
.await
.unwrap()
.try_collect::<Vec<_>>()
.await
.unwrap();
assert_eq!(results.len(), 1);
let batch = &results[0];
let expected_i = BTreeSet::from_iter(vec![512, 513, 514, 515, 516]);
let column_i = batch.column_by_name("i").unwrap();
let actual_i: BTreeSet<i32> = as_primitive_array::<Int32Type>(column_i.as_ref())
.values()
.iter()
.copied()
.collect();
assert_eq!(expected_i, actual_i);
}
}
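// The projection should keep its requested order, with the implicit _distance
// and _rowid columns appended at the end.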
#[tokio::test]
async fn test_projection_order() {
let vec_params = VectorIndexParams::ivf_pq(4, 8, 2, MetricType::L2, 2);
let mut data = gen_batch()
.col("vec", array::rand_vec::<Float32Type>(Dimension::from(4)))
.col("text", array::rand_utf8(ByteCount::from(10), false))
.into_ram_dataset(FragmentCount::from(3), FragmentRowCount::from(100))
.await
.unwrap();
data.create_index(&["vec"], IndexType::Vector, None, &vec_params, true)
.await
.unwrap();
let mut scan = data.scan();
scan.nearest("vec", &Float32Array::from(vec![1.0, 1.0, 1.0, 1.0]), 5)
.unwrap();
scan.with_row_id().project(&["text"]).unwrap();
let results = scan
.try_into_stream()
.await
.unwrap()
.try_collect::<Vec<_>>()
.await
.unwrap();
assert_eq!(
results[0].schema().field_names(),
vec!["text", "_distance", "_rowid"]
);
}
#[rstest]
#[tokio::test]
async fn test_count_rows_with_filter(
#[values(LanceFileVersion::Legacy, LanceFileVersion::Stable)]
data_storage_version: LanceFileVersion,
) {
let test_dir = TempStrDir::default();
let test_uri = &test_dir;
let mut data_gen = BatchGenerator::new().col(Box::new(
IncrementingInt32::new().named("Filter_me".to_owned()),
));
Dataset::write(
data_gen.batch(32),
test_uri,
Some(WriteParams {
data_storage_version: Some(data_storage_version),
..Default::default()
}),
)
.await
.unwrap();
let dataset = Dataset::open(test_uri).await.unwrap();
assert_eq!(32, dataset.count_rows(None).await.unwrap());
assert_eq!(
16,
dataset
.count_rows(Some("`Filter_me` > 15".to_string()))
.await
.unwrap()
);
}
#[rstest]
#[tokio::test]
async fn test_dynamic_projection(
#[values(LanceFileVersion::Legacy, LanceFileVersion::Stable)]
data_storage_version: LanceFileVersion,
) {
let test_dir = TempStrDir::default();
let test_uri = &test_dir;
let mut data_gen =
BatchGenerator::new().col(Box::new(IncrementingInt32::new().named("i".to_owned())));
Dataset::write(
data_gen.batch(32),
test_uri,
Some(WriteParams {
data_storage_version: Some(data_storage_version),
..Default::default()
}),
)
.await
.unwrap();
let dataset = Dataset::open(test_uri).await.unwrap();
assert_eq!(dataset.count_rows(None).await.unwrap(), 32);
let mut scanner = dataset.scan();
let scan_res = scanner
.project_with_transform(&[("bool", "i > 15")])
.unwrap()
.try_into_batch()
.await
.unwrap();
assert_eq!(1, scan_res.num_columns());
let bool_col = scan_res
.column_by_name("bool")
.expect("bool column should exist");
let bool_arr = bool_col.as_boolean();
for i in 0..32 {
if i > 15 {
assert!(bool_arr.value(i));
} else {
assert!(!bool_arr.value(i));
}
}
}
#[rstest]
#[tokio::test]
async fn test_column_casting_function(
#[values(LanceFileVersion::Legacy, LanceFileVersion::Stable)]
data_storage_version: LanceFileVersion,
) {
let test_dir = TempStrDir::default();
let test_uri = &test_dir;
let mut data_gen =
BatchGenerator::new().col(Box::new(RandomVector::new().named("vec".to_owned())));
Dataset::write(
data_gen.batch(32),
test_uri,
Some(WriteParams {
data_storage_version: Some(data_storage_version),
..Default::default()
}),
)
.await
.unwrap();
let dataset = Dataset::open(test_uri).await.unwrap();
assert_eq!(dataset.count_rows(None).await.unwrap(), 32);
let mut scanner = dataset.scan();
let scan_res = scanner
.project_with_transform(&[("f16", "_cast_list_f16(vec)")])
.unwrap()
.try_into_batch()
.await
.unwrap();
assert_eq!(1, scan_res.num_columns());
assert_eq!(32, scan_res.num_rows());
assert_eq!("f16", scan_res.schema().field(0).name());
let mut scanner = dataset.scan();
let scan_res_original = scanner
.project(&["vec"])
.unwrap()
.try_into_batch()
.await
.unwrap();
let f32_col: &Float32Array = scan_res_original
.column_by_name("vec")
.unwrap()
.as_fixed_size_list()
.values()
.as_primitive();
let f16_col: &Float16Array = scan_res
.column_by_name("f16")
.unwrap()
.as_fixed_size_list()
.values()
.as_primitive();
for (f32_val, f16_val) in f32_col.iter().zip(f16_col.iter()) {
let f32_val = f32_val.unwrap();
let f16_val = f16_val.unwrap();
assert_eq!(f16::from_f32(f32_val), f16_val);
}
}
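// Fixture for scalar-index scan tests. `new` builds a 1000-row dataset with a
// vector index on "vector" and a scalar index on "indexed", then records a
// version for each interesting state: original data, appended data, optimized
// (updated) indices, append-then-delete, delete-only, and compacted. Tests
// check out whichever version matches their ScalarTestParams.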
struct ScalarIndexTestFixture {
_test_dir: TempStrDir,
dataset: Dataset,
sample_query: Arc<dyn Array>,
delete_query: Arc<dyn Array>,
original_version: u64,
compact_version: u64,
append_version: u64,
updated_version: u64,
delete_version: u64,
append_then_delete_version: u64,
}
#[derive(Debug, PartialEq)]
struct ScalarTestParams {
use_index: bool,
use_projection: bool,
use_deleted_data: bool,
use_new_data: bool,
with_row_id: bool,
use_compaction: bool,
use_updated: bool,
}
impl ScalarIndexTestFixture {
async fn new(data_storage_version: LanceFileVersion, use_stable_row_ids: bool) -> Self {
let test_dir = TempStrDir::default();
let test_uri = &test_dir;
let data = gen_batch()
.col(
"vector",
array::rand_vec::<Float32Type>(Dimension::from(32)),
)
.col("indexed", array::step::<Int32Type>())
.col("not_indexed", array::step::<Int32Type>())
.into_batch_rows(RowCount::from(1000))
.unwrap();
let mut dataset = Dataset::write(
RecordBatchIterator::new(vec![Ok(data.clone())], data.schema().clone()),
test_uri,
Some(WriteParams {
max_rows_per_file: 500,
data_storage_version: Some(data_storage_version),
enable_stable_row_ids: use_stable_row_ids,
..Default::default()
}),
)
.await
.unwrap();
dataset
.create_index(
&["vector"],
IndexType::Vector,
None,
&VectorIndexParams::ivf_pq(2, 8, 2, MetricType::L2, 2),
false,
)
.await
.unwrap();
dataset
.create_index(
&["indexed"],
IndexType::Scalar,
None,
&ScalarIndexParams::default(),
false,
)
.await
.unwrap();
let original_version = dataset.version().version;
let sample_query = data["vector"].as_fixed_size_list().value(50);
let delete_query = data["vector"].as_fixed_size_list().value(75);
let new_indexed =
arrow_arith::numeric::add(&data["indexed"], &Int32Array::new_scalar(1000)).unwrap();
let new_not_indexed =
arrow_arith::numeric::add(&data["indexed"], &Int32Array::new_scalar(1000)).unwrap();
let append_data = RecordBatch::try_new(
data.schema(),
vec![data["vector"].clone(), new_indexed, new_not_indexed],
)
.unwrap();
dataset
.append(
RecordBatchIterator::new(vec![Ok(append_data)], data.schema()),
Some(WriteParams {
data_storage_version: Some(data_storage_version),
..Default::default()
}),
)
.await
.unwrap();
let append_version = dataset.version().version;
dataset
.optimize_indices(&OptimizeOptions::merge(1))
.await
.unwrap();
let updated_version = dataset.version().version;
let mut dataset = dataset.checkout_version(append_version).await.unwrap();
dataset.restore().await.unwrap();
dataset.delete("not_indexed = 75").await.unwrap();
let append_then_delete_version = dataset.version().version;
let mut dataset = dataset.checkout_version(original_version).await.unwrap();
dataset.restore().await.unwrap();
dataset.delete("not_indexed = 75").await.unwrap();
let delete_version = dataset.version().version;
compact_files(&mut dataset, CompactionOptions::default(), None)
.await
.unwrap();
let compact_version = dataset.version().version;
let mut dataset = dataset.checkout_version(original_version).await.unwrap();
dataset.restore().await.unwrap();
Self {
_test_dir: test_dir,
dataset,
sample_query,
delete_query,
original_version,
compact_version,
append_version,
updated_version,
delete_version,
append_then_delete_version,
}
}
fn sample_query(&self) -> &PrimitiveArray<Float32Type> {
self.sample_query.as_primitive::<Float32Type>()
}
fn delete_query(&self) -> &PrimitiveArray<Float32Type> {
self.delete_query.as_primitive::<Float32Type>()
}
async fn get_dataset(&self, params: &ScalarTestParams) -> Dataset {
let version = if params.use_compaction {
if params.use_deleted_data || params.use_new_data || params.use_updated {
panic!(
"There is no test data combining new/deleted/updated data with compaction"
);
} else {
self.compact_version
}
} else if params.use_updated {
if params.use_deleted_data || params.use_new_data || params.use_compaction {
panic!(
"There is no test data combining updated data with new/deleted/compaction"
);
} else {
self.updated_version
}
} else {
match (params.use_new_data, params.use_deleted_data) {
(false, false) => self.original_version,
(false, true) => self.delete_version,
(true, false) => self.append_version,
(true, true) => self.append_then_delete_version,
}
};
self.dataset.checkout_version(version).await.unwrap()
}
async fn run_query(
&self,
query: &str,
vector: Option<&PrimitiveArray<Float32Type>>,
params: &ScalarTestParams,
) -> (String, RecordBatch) {
let dataset = self.get_dataset(params).await;
let mut scan = dataset.scan();
if let Some(vector) = vector {
scan.nearest("vector", vector, 10).unwrap();
}
if params.use_projection {
scan.project(&["indexed"]).unwrap();
}
if params.with_row_id {
scan.with_row_id();
}
scan.scan_in_order(true);
scan.use_index(params.use_index);
scan.filter(query).unwrap();
scan.prefilter(true);
let plan = scan.explain_plan(true).await.unwrap();
let batch = scan.try_into_batch().await.unwrap();
if params.use_projection {
let mut expected_columns = 1;
if vector.is_some() {
expected_columns += 1;
}
if params.with_row_id {
expected_columns += 1;
}
assert_eq!(batch.num_columns(), expected_columns);
} else {
let mut expected_columns = 3;
if vector.is_some() {
expected_columns += 1;
}
if params.with_row_id {
expected_columns += 1;
}
assert_eq!(batch.num_columns(), expected_columns);
}
(plan, batch)
}
fn assert_none<F: Fn(i32) -> bool>(
&self,
batch: &RecordBatch,
predicate: F,
message: &str,
) {
let indexed = batch["indexed"].as_primitive::<Int32Type>();
if indexed.iter().map(|val| val.unwrap()).any(predicate) {
panic!("{}", message);
}
}
fn assert_one<F: Fn(i32) -> bool>(&self, batch: &RecordBatch, predicate: F, message: &str) {
let indexed = batch["indexed"].as_primitive::<Int32Type>();
if !indexed.iter().map(|val| val.unwrap()).any(predicate) {
panic!("{}", message);
}
}
async fn check_vector_scalar_indexed_and_refine(&self, params: &ScalarTestParams) {
let (query_plan, batch) = self
.run_query(
"indexed != 50 AND ((not_indexed < 100) OR (not_indexed >= 1000 AND not_indexed < 1100))",
Some(self.sample_query()),
params,
)
.await;
if self.dataset.is_legacy_storage() {
assert!(query_plan.contains("MaterializeIndex"));
}
self.assert_none(
&batch,
|val| val == 50,
"The query contained 50 even though it was filtered",
);
if !params.use_new_data {
self.assert_none(
&batch,
|val| (100..1000).contains(&val) || (val >= 1100),
"The non-indexed refine filter was not applied",
);
}
if params.use_new_data || params.use_updated {
self.assert_one(
&batch,
|val| val == 1050,
"The query did not contain 1050 from the new data",
);
}
}
async fn check_vector_scalar_indexed_only(&self, params: &ScalarTestParams) {
let (query_plan, batch) = self
.run_query("indexed != 50", Some(self.sample_query()), params)
.await;
if self.dataset.is_legacy_storage() {
if params.use_index {
assert!(query_plan.contains("ScalarIndexQuery"));
} else {
assert!(query_plan.contains("MaterializeIndex"));
}
}
self.assert_none(
&batch,
|val| val == 50,
"The query contained 50 even though it was filtered",
);
if params.use_new_data {
self.assert_one(
&batch,
|val| val == 1050,
"The query did not contain 1050 from the new data",
);
// Also ensure a filter can target a value that only exists in the new data.
let (_, batch) = self
.run_query("indexed == 1050", Some(self.sample_query()), params)
.await;
assert_eq!(batch.num_rows(), 1);
}
if params.use_deleted_data {
let (_, batch) = self
.run_query("indexed == 75", Some(self.delete_query()), params)
.await;
if !params.use_new_data {
assert_eq!(batch.num_rows(), 0);
}
}
}
async fn check_vector_queries(&self, params: &ScalarTestParams) {
self.check_vector_scalar_indexed_only(params).await;
self.check_vector_scalar_indexed_and_refine(params).await;
}
async fn check_simple_indexed_only(&self, params: &ScalarTestParams) {
let (query_plan, batch) = self.run_query("indexed != 50", None, params).await;
if self.dataset.is_legacy_storage() {
assert!(query_plan.contains("MaterializeIndex"));
} else {
assert!(query_plan.contains("LanceRead"));
}
self.assert_none(
&batch,
|val| val == 50,
"The query contained 50 even though it was filtered",
);
let mut expected_num_rows = if params.use_new_data || params.use_updated {
1999
} else {
999
};
if params.use_deleted_data || params.use_compaction {
expected_num_rows -= 1;
}
assert_eq!(batch.num_rows(), expected_num_rows);
if params.use_new_data || params.use_updated {
let (_, batch) = self.run_query("indexed == 1050", None, params).await;
assert_eq!(batch.num_rows(), 1);
}
if params.use_deleted_data || params.use_compaction {
let (_, batch) = self.run_query("indexed == 75", None, params).await;
assert_eq!(batch.num_rows(), 0);
}
}
async fn check_simple_indexed_and_refine(&self, params: &ScalarTestParams) {
let (query_plan, batch) = self.run_query(
"indexed != 50 AND ((not_indexed < 100) OR (not_indexed >= 1000 AND not_indexed < 1100))",
None,
params
).await;
if self.dataset.is_legacy_storage() {
assert!(query_plan.contains("MaterializeIndex"));
} else {
assert!(query_plan.contains("LanceRead"));
}
self.assert_none(
&batch,
|val| val == 50,
"The query contained 50 even though it was filtered",
);
self.assert_none(
&batch,
|val| (100..1000).contains(&val) || (val >= 1100),
"The non-indexed refine filter was not applied",
);
let mut expected_num_rows = if params.use_new_data || params.use_updated {
199
} else {
99
};
if params.use_deleted_data || params.use_compaction {
expected_num_rows -= 1;
}
assert_eq!(batch.num_rows(), expected_num_rows);
}
async fn check_simple_queries(&self, params: &ScalarTestParams) {
self.check_simple_indexed_only(params).await;
self.check_simple_indexed_and_refine(params).await;
}
}
#[rstest]
#[tokio::test]
async fn test_secondary_index_scans(
#[values(LanceFileVersion::Legacy, LanceFileVersion::Stable)]
data_storage_version: LanceFileVersion,
#[values(false, true)] use_stable_row_ids: bool,
) {
let fixture = Box::pin(ScalarIndexTestFixture::new(
data_storage_version,
use_stable_row_ids,
))
.await;
for use_index in [false, true] {
for use_projection in [false, true] {
for use_deleted_data in [false, true] {
for use_new_data in [false, true] {
let compaction_choices =
if use_deleted_data || use_new_data || use_stable_row_ids {
vec![false]
} else {
vec![false, true]
};
for use_compaction in compaction_choices {
let updated_choices =
if use_deleted_data || use_new_data || use_compaction {
vec![false]
} else {
vec![false, true]
};
for use_updated in updated_choices {
for with_row_id in [false, true] {
let params = ScalarTestParams {
use_index,
use_projection,
use_deleted_data,
use_new_data,
with_row_id,
use_compaction,
use_updated,
};
fixture.check_vector_queries(&params).await;
fixture.check_simple_queries(&params).await;
}
}
}
}
}
}
}
}
#[tokio::test]
async fn can_filter_row_id() {
let dataset = lance_datagen::gen_batch()
.col("x", array::step::<Int32Type>())
.into_ram_dataset(FragmentCount::from(1), FragmentRowCount::from(1000))
.await
.unwrap();
let mut scan = dataset.scan();
scan.with_row_id();
scan.project::<&str>(&[]).unwrap();
scan.filter("_rowid == 50").unwrap();
let batch = scan.try_into_batch().await.unwrap();
assert_eq!(batch.num_rows(), 1);
assert_eq!(batch.column(0).as_primitive::<UInt64Type>().values()[0], 50);
}
#[rstest]
#[tokio::test]
async fn test_index_take_batch_size() {
let fixture = Box::pin(ScalarIndexTestFixture::new(LanceFileVersion::Stable, false)).await;
let stream = fixture
.dataset
.scan()
.filter("indexed > 0")
.unwrap()
.batch_size(16)
.try_into_stream()
.await
.unwrap();
let batches = stream.collect::<Vec<_>>().await;
assert_eq!(batches.len(), 1000_usize.div_ceil(16));
}
async fn assert_plan_equals(
dataset: &Dataset,
plan: impl Fn(&mut Scanner) -> Result<&mut Scanner>,
expected: &str,
) -> Result<()> {
let mut scan = dataset.scan();
plan(&mut scan)?;
let exec_plan = scan.create_plan().await?;
assert_plan_node_equals(exec_plan, expected).await
}
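// Plans for filters mixing indexed and unindexed columns: indexed predicates
// are pushed down into a ScalarIndexQuery, while predicates on unindexed
// columns must survive as a refine_filter applied after the index lookup.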
#[tokio::test]
async fn test_inexact_scalar_index_plans() {
let data = gen_batch()
.col("ngram", array::rand_utf8(ByteCount::from(5), false))
.col("exact", array::rand_type(&DataType::UInt32))
.col("no_index", array::rand_type(&DataType::UInt32))
.into_reader_rows(RowCount::from(1000), BatchCount::from(5));
let mut dataset = Dataset::write(data, "memory://test", None).await.unwrap();
dataset
.create_index(
&["ngram"],
IndexType::NGram,
None,
&ScalarIndexParams::default(),
true,
)
.await
.unwrap();
dataset
.create_index(
&["exact"],
IndexType::BTree,
None,
&ScalarIndexParams::default(),
true,
)
.await
.unwrap();
assert_plan_equals(
&dataset,
|scanner| scanner.filter("contains(ngram, 'test string')"),
"LanceRead: uri=..., projection=[ngram, exact, no_index], num_fragments=1, \
range_before=None, range_after=None, row_id=false, row_addr=false, \
full_filter=contains(ngram, Utf8(\"test string\")), refine_filter=--
ScalarIndexQuery: query=[contains(ngram, Utf8(\"test string\"))]@ngram_idx",
)
.await
.unwrap();
assert_plan_equals(
&dataset,
|scanner| scanner.filter("contains(ngram, 'test string') and exact < 50"),
"LanceRead: uri=..., projection=[ngram, exact, no_index], num_fragments=1, \
range_before=None, range_after=None, row_id=false, row_addr=false, \
full_filter=contains(ngram, Utf8(\"test string\")) AND exact < UInt32(50), \
refine_filter=--
ScalarIndexQuery: query=AND([contains(ngram, Utf8(\"test string\"))]@ngram_idx,[exact < 50]@exact_idx)",
)
.await
.unwrap();
assert_plan_equals(
&dataset,
|scanner| {
scanner.filter("contains(ngram, 'test string') and exact < 50 AND no_index > 100")
},
"ProjectionExec: expr=[ngram@0 as ngram, exact@1 as exact, no_index@2 as no_index]
LanceRead: uri=..., projection=[ngram, exact, no_index], num_fragments=1, range_before=None, \
range_after=None, row_id=true, row_addr=false, full_filter=contains(ngram, Utf8(\"test string\")) AND exact < UInt32(50) AND no_index > UInt32(100), \
refine_filter=no_index > UInt32(100)
ScalarIndexQuery: query=AND([contains(ngram, Utf8(\"test string\"))]@ngram_idx,[exact < 50]@exact_idx)",
)
.await
.unwrap();
}
#[tokio::test]
async fn test_like_prefix_with_btree_index() {
let data = gen_batch()
.col(
"name",
array::cycle_utf8_literals(&[
"apple",
"application",
"app",
"banana",
"band",
"testns1",
"testns2",
"test",
"testing",
"zoo",
]),
)
.col("id", array::step::<Int32Type>())
.into_reader_rows(RowCount::from(100), BatchCount::from(1));
let mut dataset = Dataset::write(data, "memory://test_like", None)
.await
.unwrap();
dataset
.create_index(
&["name"],
IndexType::BTree,
None,
&ScalarIndexParams::default(),
true,
)
.await
.unwrap();
assert_plan_equals(
&dataset,
|scanner| scanner.filter("name LIKE 'app%'"),
"LanceRead: uri=..., projection=[name, id], num_fragments=1, \
range_before=None, range_after=None, row_id=false, row_addr=false, \
full_filter=name LIKE Utf8(\"app%\"), refine_filter=--
ScalarIndexQuery: query=[name LIKE 'app%']@name_idx",
)
.await
.unwrap();
let results = dataset
.scan()
.filter("name LIKE 'app%'")
.unwrap()
.try_into_batch()
.await
.unwrap();
let names: Vec<&str> = results
.column_by_name("name")
.unwrap()
.as_any()
.downcast_ref::<StringArray>()
.unwrap()
.iter()
.map(|s| s.unwrap())
.collect();
assert!(names.iter().all(|n| n.starts_with("app")));
assert!(!names.is_empty());
assert_plan_equals(
&dataset,
|scanner| scanner.filter("starts_with(name, 'ban')"),
"LanceRead: uri=..., projection=[name, id], num_fragments=1, \
range_before=None, range_after=None, row_id=false, row_addr=false, \
full_filter=name LIKE Utf8(\"ban%\"), refine_filter=--
ScalarIndexQuery: query=[name LIKE 'ban%']@name_idx",
)
.await
.unwrap();
let results = dataset
.scan()
.filter("starts_with(name, 'ban')")
.unwrap()
.try_into_batch()
.await
.unwrap();
let names: Vec<&str> = results
.column_by_name("name")
.unwrap()
.as_any()
.downcast_ref::<StringArray>()
.unwrap()
.iter()
.map(|s| s.unwrap())
.collect();
assert!(names.iter().all(|n| n.starts_with("ban")));
assert!(!names.is_empty());
assert_plan_equals(
&dataset,
|scanner| scanner.filter("name LIKE 'test%2'"),
"ProjectionExec: expr=[name@0 as name, id@1 as id]
LanceRead: uri=..., projection=[name, id], num_fragments=1, \
range_before=None, range_after=None, row_id=true, row_addr=false, \
full_filter=name LIKE Utf8(\"test%2\"), refine_filter=name LIKE Utf8(\"test%2\")
ScalarIndexQuery: query=[name LIKE 'test%']@name_idx",
)
.await
.unwrap();
let results = dataset
.scan()
.filter("name LIKE 'test%2'")
.unwrap()
.try_into_batch()
.await
.unwrap();
let names: Vec<&str> = results
.column_by_name("name")
.unwrap()
.as_any()
.downcast_ref::<StringArray>()
.unwrap()
.iter()
.map(|s| s.unwrap())
.collect();
assert!(
names
.iter()
.all(|n| n.starts_with("test") && n.ends_with("2"))
);
let mut scanner = dataset.scan();
scanner.filter("name LIKE '%app%'").unwrap();
let plan = scanner.create_plan().await.unwrap();
let plan_str = format!("{:?}", plan);
assert!(
!plan_str.contains("ScalarIndexQuery"),
"LIKE '%app%' should not use scalar index, but got: {}",
plan_str
);
let results = dataset
.scan()
.filter("name LIKE '%app%'")
.unwrap()
.try_into_batch()
.await
.unwrap();
let names: Vec<&str> = results
.column_by_name("name")
.unwrap()
.as_any()
.downcast_ref::<StringArray>()
.unwrap()
.iter()
.map(|s| s.unwrap())
.collect();
assert!(names.iter().all(|n| n.contains("app")));
let mut scanner = dataset.scan();
scanner.filter("name NOT LIKE 'app%'").unwrap();
let plan = scanner.create_plan().await.unwrap();
let plan_str = format!("{:?}", plan);
assert!(
!plan_str.contains("ScalarIndexQuery"),
"NOT LIKE should not use scalar index, but got: {}",
plan_str
);
}
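// Correctness check: a LIKE prefix query answered through the btree index must
// return exactly the same rows as the same query with the scalar index
// disabled.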
#[tokio::test]
async fn test_like_prefix_correctness_with_btree_index() {
let names: Vec<&str> = vec![
"alpha", "alphabet", "beta", "gamma", "delta", "epsilon", "eta", "theta", "iota",
"kappa",
];
let data = RecordBatch::try_new(
Arc::new(ArrowSchema::new(vec![
ArrowField::new("name", DataType::Utf8, false),
ArrowField::new("id", DataType::Int32, false),
])),
vec![
Arc::new(StringArray::from(names.clone())),
Arc::new(Int32Array::from_iter_values(0..10)),
],
)
.unwrap();
let reader = RecordBatchIterator::new(
vec![Ok(data)],
Arc::new(ArrowSchema::new(vec![
ArrowField::new("name", DataType::Utf8, false),
ArrowField::new("id", DataType::Int32, false),
])),
);
let mut dataset = Dataset::write(reader, "memory://test_like_correctness", None)
.await
.unwrap();
dataset
.create_index(
&["name"],
IndexType::BTree,
None,
&ScalarIndexParams::default(),
true,
)
.await
.unwrap();
let with_index = dataset
.scan()
.filter("name LIKE 'alpha%'")
.unwrap()
.try_into_batch()
.await
.unwrap();
let without_index = dataset
.scan()
.use_scalar_index(false)
.filter("name LIKE 'alpha%'")
.unwrap()
.try_into_batch()
.await
.unwrap();
assert_eq!(with_index.num_rows(), without_index.num_rows());
assert_eq!(with_index.num_rows(), 2);
let with_index_names: BTreeSet<String> = with_index
.column_by_name("name")
.unwrap()
.as_any()
.downcast_ref::<StringArray>()
.unwrap()
.iter()
.map(|s| s.unwrap().to_string())
.collect();
let without_index_names: BTreeSet<String> = without_index
.column_by_name("name")
.unwrap()
.as_any()
.downcast_ref::<StringArray>()
.unwrap()
.iter()
.map(|s| s.unwrap().to_string())
.collect();
assert_eq!(with_index_names, without_index_names);
assert_eq!(
with_index_names,
BTreeSet::from(["alpha".to_string(), "alphabet".to_string()])
);
let starts_with_result = dataset
.scan()
.filter("starts_with(name, 'e')")
.unwrap()
.try_into_batch()
.await
.unwrap();
let starts_with_names: BTreeSet<String> = starts_with_result
.column_by_name("name")
.unwrap()
.as_any()
.downcast_ref::<StringArray>()
.unwrap()
.iter()
.map(|s| s.unwrap().to_string())
.collect();
assert_eq!(
starts_with_names,
BTreeSet::from(["epsilon".to_string(), "eta".to_string()])
);
}
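// Zone maps are inexact, but a LIKE prefix (or starts_with) should still be
// pushed down as a LikePrefix query through ScalarIndexExec; infix patterns
// like '%app%' cannot use the prefix pushdown.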
#[tokio::test]
async fn test_like_prefix_with_zone_map() {
use lance_index::scalar::BuiltinIndexType;
let data = gen_batch()
.col(
"name",
array::cycle_utf8_literals(&[
"apple",
"application",
"app",
"banana",
"band",
"testns1",
"testns2",
"test",
"testing",
"zoo",
]),
)
.col("id", array::step::<Int32Type>())
.into_reader_rows(RowCount::from(100), BatchCount::from(1));
let mut dataset = Dataset::write(data, "memory://test_like_zonemap", None)
.await
.unwrap();
let params = ScalarIndexParams::for_builtin(BuiltinIndexType::ZoneMap);
dataset
.create_index(
&["name"],
IndexType::Scalar,
Some("name_zonemap".to_string()),
&params,
true,
)
.await
.unwrap();
let mut scanner = dataset.scan();
scanner.filter("name LIKE 'app%'").unwrap();
let plan = scanner.create_plan().await.unwrap();
let plan_str = format!("{:?}", plan);
assert!(
plan_str.contains("ScalarIndexExec") && plan_str.contains("LikePrefix"),
"LIKE 'app%' should use zone map index with LikePrefix, but got: {}",
plan_str
);
let results = dataset
.scan()
.filter("name LIKE 'app%'")
.unwrap()
.try_into_batch()
.await
.unwrap();
let names: Vec<&str> = results
.column_by_name("name")
.unwrap()
.as_any()
.downcast_ref::<StringArray>()
.unwrap()
.iter()
.map(|s| s.unwrap())
.collect();
assert!(names.iter().all(|n| n.starts_with("app")));
assert!(!names.is_empty());
let mut scanner = dataset.scan();
scanner.filter("starts_with(name, 'ban')").unwrap();
let plan = scanner.create_plan().await.unwrap();
let plan_str = format!("{:?}", plan);
assert!(
plan_str.contains("ScalarIndexExec") && plan_str.contains("LikePrefix"),
"starts_with should use zone map index with LikePrefix, but got: {}",
plan_str
);
let results = dataset
.scan()
.filter("starts_with(name, 'ban')")
.unwrap()
.try_into_batch()
.await
.unwrap();
let names: Vec<&str> = results
.column_by_name("name")
.unwrap()
.as_any()
.downcast_ref::<StringArray>()
.unwrap()
.iter()
.map(|s| s.unwrap())
.collect();
assert!(names.iter().all(|n| n.starts_with("ban")));
let mut scanner = dataset.scan();
scanner.filter("name LIKE 'test%2'").unwrap();
let plan = scanner.create_plan().await.unwrap();
let plan_str = format!("{:?}", plan);
assert!(
plan_str.contains("ScalarIndexExec") && plan_str.contains("LikePrefix"),
"LIKE 'test%2' should use zone map index for prefix, but got: {}",
plan_str
);
let mut scanner = dataset.scan();
scanner.filter("name LIKE '%app%'").unwrap();
let plan = scanner.create_plan().await.unwrap();
let plan_str = format!("{:?}", plan);
assert!(
!plan_str.contains("LikePrefix"),
"LIKE '%app%' should not use LikePrefix index, but got: {}",
plan_str
);
}
#[tokio::test]
async fn test_like_prefix_correctness_with_zone_map() {
use lance_index::scalar::BuiltinIndexType;
let names: Vec<&str> = vec![
"alpha", "alphabet", "beta", "gamma", "delta", "epsilon", "eta", "theta", "iota",
"kappa",
];
let data = RecordBatch::try_new(
Arc::new(ArrowSchema::new(vec![
ArrowField::new("name", DataType::Utf8, false),
ArrowField::new("id", DataType::Int32, false),
])),
vec![
Arc::new(StringArray::from(names.clone())),
Arc::new(Int32Array::from_iter_values(0..10)),
],
)
.unwrap();
let reader = RecordBatchIterator::new(
vec![Ok(data)],
Arc::new(ArrowSchema::new(vec![
ArrowField::new("name", DataType::Utf8, false),
ArrowField::new("id", DataType::Int32, false),
])),
);
let mut dataset = Dataset::write(reader, "memory://test_like_correctness_zonemap", None)
.await
.unwrap();
let params = ScalarIndexParams::for_builtin(BuiltinIndexType::ZoneMap);
dataset
.create_index(
&["name"],
IndexType::Scalar,
Some("name_zonemap".to_string()),
&params,
true,
)
.await
.unwrap();
let with_index = dataset
.scan()
.filter("name LIKE 'alpha%'")
.unwrap()
.try_into_batch()
.await
.unwrap();
let without_index = dataset
.scan()
.use_scalar_index(false)
.filter("name LIKE 'alpha%'")
.unwrap()
.try_into_batch()
.await
.unwrap();
assert_eq!(with_index.num_rows(), without_index.num_rows());
assert_eq!(with_index.num_rows(), 2);
let with_index_names: BTreeSet<String> = with_index
.column_by_name("name")
.unwrap()
.as_any()
.downcast_ref::<StringArray>()
.unwrap()
.iter()
.map(|s| s.unwrap().to_string())
.collect();
let without_index_names: BTreeSet<String> = without_index
.column_by_name("name")
.unwrap()
.as_any()
.downcast_ref::<StringArray>()
.unwrap()
.iter()
.map(|s| s.unwrap().to_string())
.collect();
assert_eq!(with_index_names, without_index_names);
assert_eq!(
with_index_names,
BTreeSet::from(["alpha".to_string(), "alphabet".to_string()])
);
let starts_with_result = dataset
.scan()
.filter("starts_with(name, 'e')")
.unwrap()
.try_into_batch()
.await
.unwrap();
let starts_with_names: BTreeSet<String> = starts_with_result
.column_by_name("name")
.unwrap()
.as_any()
.downcast_ref::<StringArray>()
.unwrap()
.iter()
.map(|s| s.unwrap().to_string())
.collect();
assert_eq!(
starts_with_names,
BTreeSet::from(["epsilon".to_string(), "eta".to_string()])
);
}
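// Late materialization: a filtered scan should read fewer bytes than a full
// scan because non-filter columns are only fetched for matching rows. Each
// measurement below reads io_stats_incremental(), which reports the IO
// accumulated since the previous call.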
#[rstest]
#[tokio::test]
async fn test_late_materialization(
#[values(LanceFileVersion::Legacy, LanceFileVersion::Stable)]
data_storage_version: LanceFileVersion,
) {
use lance_io::assert_io_lt;
use lance_table::io::commit::RenameCommitHandler;
let data = gen_batch()
.col(
"vector",
array::rand_vec::<Float32Type>(Dimension::from(32)),
)
.col("indexed", array::step::<Int32Type>())
.col("not_indexed", array::step::<Int32Type>())
.into_reader_rows(RowCount::from(1000), BatchCount::from(20));
let mut dataset = Dataset::write(
data,
"memory://test",
Some(WriteParams {
commit_handler: Some(Arc::new(RenameCommitHandler)),
data_storage_version: Some(data_storage_version),
..Default::default()
}),
)
.await
.unwrap();
dataset
.create_index(
&["indexed"],
IndexType::Scalar,
None,
&ScalarIndexParams::default(),
false,
)
.await
.unwrap();
// Reset the incremental IO stats counter, then measure the cost of a full scan.
let _ = dataset.object_store().io_stats_incremental();
dataset.scan().try_into_batch().await.unwrap();
let io_stats = dataset.object_store().io_stats_incremental();
let full_scan_bytes = io_stats.read_bytes;
dataset
.scan()
.use_stats(false)
.filter("not_indexed = 50")
.unwrap()
.try_into_batch()
.await
.unwrap();
let io_stats = dataset.object_store().io_stats_incremental();
assert_io_lt!(io_stats, read_bytes, full_scan_bytes);
let filtered_scan_bytes = io_stats.read_bytes;
if data_storage_version == LanceFileVersion::Legacy {
dataset
.scan()
.filter("not_indexed = 50")
.unwrap()
.try_into_batch()
.await
.unwrap();
let io_stats = dataset.object_store().io_stats_incremental();
assert_io_lt!(io_stats, read_bytes, filtered_scan_bytes);
}
dataset
.scan()
.filter("indexed = 50")
.unwrap()
.try_into_batch()
.await
.unwrap();
let io_stats = dataset.object_store().io_stats_incremental();
assert_io_lt!(io_stats, read_bytes, full_scan_bytes);
let index_scan_bytes = io_stats.read_bytes;
dataset
.scan()
.filter("indexed = 50")
.unwrap()
.try_into_batch()
.await
.unwrap();
let io_stats = dataset.object_store().io_stats_incremental();
assert_io_lt!(io_stats, read_bytes, index_scan_bytes);
}
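// Projecting the nested field "struct.i" should surface it as a top-level
// column in the output batch.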
#[rstest]
#[tokio::test]
async fn test_project_nested(
#[values(LanceFileVersion::Legacy, LanceFileVersion::Stable)]
data_storage_version: LanceFileVersion,
) -> Result<()> {
let struct_i_field = ArrowField::new("i", DataType::Int32, true);
let struct_o_field = ArrowField::new("o", DataType::Utf8, true);
let schema = Arc::new(ArrowSchema::new(vec![
ArrowField::new(
"struct",
DataType::Struct(vec![struct_i_field.clone(), struct_o_field.clone()].into()),
true,
),
ArrowField::new("s", DataType::Utf8, true),
]));
let input_batches: Vec<RecordBatch> = (0..5)
.map(|i| {
let struct_i_arr: Arc<Int32Array> =
Arc::new(Int32Array::from_iter_values(i * 20..(i + 1) * 20));
let struct_o_arr: Arc<StringArray> = Arc::new(StringArray::from_iter_values(
(i * 20..(i + 1) * 20).map(|v| format!("o-{:02}", v)),
));
RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(StructArray::from(vec![
(Arc::new(struct_i_field.clone()), struct_i_arr as ArrayRef),
(Arc::new(struct_o_field.clone()), struct_o_arr as ArrayRef),
])),
Arc::new(StringArray::from_iter_values(
(i * 20..(i + 1) * 20).map(|v| format!("s-{}", v)),
)),
],
)
.unwrap()
})
.collect();
let batches =
RecordBatchIterator::new(input_batches.clone().into_iter().map(Ok), schema.clone());
let test_dir = TempStrDir::default();
let test_uri = &test_dir;
let write_params = WriteParams {
max_rows_per_file: 40,
max_rows_per_group: 10,
data_storage_version: Some(data_storage_version),
..Default::default()
};
Dataset::write(batches, test_uri, Some(write_params))
.await
.unwrap();
let dataset = Dataset::open(test_uri).await.unwrap();
let batches = dataset
.scan()
.project(&["struct.i"])
.unwrap()
.try_into_stream()
.await
.unwrap()
.try_collect::<Vec<_>>()
.await
.unwrap();
let batch = concat_batches(&batches[0].schema(), &batches).unwrap();
assert!(batch.column_by_name("struct.i").is_some());
Ok(())
}
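// Golden-plan tests: each case builds a scanner, renders the physical plan, and
// compares it against the expected plan text (the `...` segments act as
// wildcards) for both the legacy and stable file formats.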
#[rstest]
#[tokio::test]
async fn test_plans(
#[values(LanceFileVersion::Legacy, LanceFileVersion::Stable)]
data_storage_version: LanceFileVersion,
#[values(false, true)] stable_row_id: bool,
) -> Result<()> {
use lance_index::scalar::inverted::query::BoostQuery;
let dim = 256;
let mut dataset =
TestVectorDataset::new_with_dimension(data_storage_version, stable_row_id, dim).await?;
let lance_schema = dataset.dataset.schema();
if data_storage_version == LanceFileVersion::Legacy {
log::info!("Test case: Pushdown scan");
assert_plan_equals(
&dataset.dataset,
|scan| scan.project(&["s"])?.filter("i > 10 and i < 20"),
"LancePushdownScan: uri=..., projection=[s], predicate=i > Int32(10) AND i < Int32(20), row_id=false, row_addr=false, ordered=true"
).await?;
}
log::info!("Test case: Project and filter");
let expected = if data_storage_version == LanceFileVersion::Legacy {
"ProjectionExec: expr=[s@2 as s]
Take: columns=\"i, _rowid, (s)\"
CoalesceBatchesExec: target_batch_size=8192
FilterExec: i@0 > 10 AND i@0 < 20
LanceScan: uri..., projection=[i], row_id=true, row_addr=false, ordered=true, range=None"
} else {
"ProjectionExec: expr=[s@2 as s]
Take: columns=\"i, _rowid, (s)\"
CoalesceBatchesExec: target_batch_size=8192
LanceRead: ..., projection=[i], num_fragments=2, range_before=None, range_after=None, row_id=true, row_addr=false, full_filter=i > Int32(10) AND i < Int32(20), refine_filter=i > Int32(10) AND i < Int32(20)"
};
assert_plan_equals(
&dataset.dataset,
|scan| {
scan.use_stats(false)
.project(&["s"])?
.filter("i > 10 and i < 20")
},
expected,
)
.await?;
log::info!("Test case: Late materialization");
let expected = if data_storage_version == LanceFileVersion::Legacy {
"ProjectionExec: expr=[i@0 as i, s@1 as s, vec@3 as vec]
Take: columns=\"i, s, _rowid, (vec)\"
CoalesceBatchesExec: target_batch_size=8192
FilterExec: s@1 IS NOT NULL
LanceScan: uri..., projection=[i, s], row_id=true, row_addr=false, ordered=true, range=None"
} else {
"ProjectionExec: expr=[i@0 as i, s@1 as s, vec@3 as vec]
Take: columns=\"i, s, _rowid, (vec)\"
CoalesceBatchesExec: target_batch_size=8192
LanceRead: uri=..., projection=[i, s], num_fragments=2, range_before=None, range_after=None, \
row_id=true, row_addr=false, full_filter=s IS NOT NULL, refine_filter=s IS NOT NULL"
};
assert_plan_equals(
&dataset.dataset,
|scan| scan.use_stats(false).filter("s IS NOT NULL"),
expected,
)
.await?;
log::info!("Test case: Custom materialization (all early)");
let expected = if data_storage_version == LanceFileVersion::Legacy {
"ProjectionExec: expr=[i@0 as i, s@1 as s, vec@2 as vec]
FilterExec: s@1 IS NOT NULL
LanceScan: uri..., projection=[i, s, vec], row_id=true, row_addr=false, ordered=true, range=None"
} else {
"ProjectionExec: expr=[i@0 as i, s@1 as s, vec@2 as vec]
LanceRead: uri=..., projection=[i, s, vec], num_fragments=2, range_before=None, \
range_after=None, row_id=true, row_addr=false, full_filter=s IS NOT NULL, refine_filter=s IS NOT NULL"
};
assert_plan_equals(
&dataset.dataset,
|scan| {
scan.use_stats(false)
.materialization_style(MaterializationStyle::AllEarly)
.filter("s IS NOT NULL")
},
expected,
)
.await?;
log::info!("Test case: Custom materialization 2 (all late)");
let expected = if data_storage_version == LanceFileVersion::Legacy {
"ProjectionExec: expr=[i@2 as i, s@0 as s, vec@3 as vec]
Take: columns=\"s, _rowid, (i), (vec)\"
CoalesceBatchesExec: target_batch_size=8192
FilterExec: s@0 IS NOT NULL
LanceScan: uri..., projection=[s], row_id=true, row_addr=false, ordered=true, range=None"
} else {
"ProjectionExec: expr=[i@2 as i, s@0 as s, vec@3 as vec]
Take: columns=\"s, _rowid, (i), (vec)\"
CoalesceBatchesExec: target_batch_size=8192
LanceRead: uri=..., projection=[s], num_fragments=2, range_before=None, \
range_after=None, row_id=true, row_addr=false, full_filter=s IS NOT NULL, refine_filter=s IS NOT NULL"
};
assert_plan_equals(
&dataset.dataset,
|scan| {
scan.use_stats(false)
.materialization_style(MaterializationStyle::AllLate)
.filter("s IS NOT NULL")
},
expected,
)
.await?;
log::info!("Test case: Custom materialization 3 (mixed)");
let expected = if data_storage_version == LanceFileVersion::Legacy {
"ProjectionExec: expr=[i@3 as i, s@0 as s, vec@1 as vec]
Take: columns=\"s, vec, _rowid, (i)\"
CoalesceBatchesExec: target_batch_size=8192
FilterExec: s@0 IS NOT NULL
LanceScan: uri..., projection=[s, vec], row_id=true, row_addr=false, ordered=true, range=None"
} else {
"ProjectionExec: expr=[i@3 as i, s@0 as s, vec@1 as vec]
Take: columns=\"s, vec, _rowid, (i)\"
CoalesceBatchesExec: target_batch_size=8192
LanceRead: uri=..., projection=[s, vec], num_fragments=2, range_before=None, range_after=None, \
row_id=true, row_addr=false, full_filter=s IS NOT NULL, refine_filter=s IS NOT NULL"
};
assert_plan_equals(
&dataset.dataset,
|scan| {
scan.use_stats(false)
.materialization_style(
MaterializationStyle::all_early_except(&["i"], lance_schema).unwrap(),
)
.filter("s IS NOT NULL")
},
expected,
)
.await?;
log::info!("Test case: Scan out of order");
let expected = if data_storage_version == LanceFileVersion::Legacy {
"LanceScan: uri=..., projection=[s], row_id=true, row_addr=false, ordered=false, range=None"
} else {
"LanceRead: uri=..., projection=[s], num_fragments=2, range_before=None, range_after=None, row_id=true, \
row_addr=false, full_filter=--, refine_filter=--"
};
assert_plan_equals(
&dataset.dataset,
|scan| Ok(scan.project(&["s"])?.with_row_id().scan_in_order(false)),
expected,
)
.await?;
let q: Float32Array = (32..32 + dim).map(|v| v as f32).collect();
log::info!("Test case: Basic KNN");
let expected = if data_storage_version == LanceFileVersion::Legacy {
"ProjectionExec: expr=[i@3 as i, s@4 as s, vec@0 as vec, _distance@2 as _distance]
Take: columns=\"vec, _rowid, _distance, (i), (s)\"
CoalesceBatchesExec: target_batch_size=8192
FilterExec: _distance@2 IS NOT NULL
SortExec: TopK(fetch=5), expr=...
KNNVectorDistance: metric=l2
LanceScan: uri=..., projection=[vec], row_id=true, row_addr=false, ordered=false, range=None"
} else {
"ProjectionExec: expr=[i@3 as i, s@4 as s, vec@0 as vec, _distance@2 as _distance]
Take: columns=\"vec, _rowid, _distance, (i), (s)\"
CoalesceBatchesExec: target_batch_size=8192
FilterExec: _distance@2 IS NOT NULL
SortExec: TopK(fetch=5), expr=...
KNNVectorDistance: metric=l2
LanceRead: uri=..., projection=[vec], num_fragments=2, range_before=None, range_after=None, \
row_id=true, row_addr=false, full_filter=--, refine_filter=--"
};
assert_plan_equals(
&dataset.dataset,
|scan| scan.nearest("vec", &q, 5),
expected,
)
.await?;
let q: Float32Array = (32..32 + dim).map(|v| v as f32).collect();
log::info!("Test case: KNN with extraneous limit");
let expected = if data_storage_version == LanceFileVersion::Legacy {
"ProjectionExec: expr=[i@3 as i, s@4 as s, vec@0 as vec, _distance@2 as _distance]
Take: columns=\"vec, _rowid, _distance, (i), (s)\"
CoalesceBatchesExec: target_batch_size=8192
GlobalLimitExec: skip=0, fetch=1
FilterExec: _distance@2 IS NOT NULL
SortExec: TopK(fetch=5), expr=...
KNNVectorDistance: metric=l2
LanceScan: uri=..., projection=[vec], row_id=true, row_addr=false, ordered=false, range=None"
} else {
"ProjectionExec: expr=[i@3 as i, s@4 as s, vec@0 as vec, _distance@2 as _distance]
Take: columns=\"vec, _rowid, _distance, (i), (s)\"
CoalesceBatchesExec: target_batch_size=8192
GlobalLimitExec: skip=0, fetch=1
FilterExec: _distance@2 IS NOT NULL
SortExec: TopK(fetch=5), expr=...
KNNVectorDistance: metric=l2
LanceRead: uri=..., projection=[vec], num_fragments=2, range_before=None, range_after=None, \
row_id=true, row_addr=false, full_filter=--, refine_filter=--"
};
assert_plan_equals(
&dataset.dataset,
|scan| scan.nearest("vec", &q, 5)?.limit(Some(1), None),
expected,
)
.await?;
dataset.make_vector_index().await?;
log::info!("Test case: Basic ANN");
let expected =
"ProjectionExec: expr=[i@2 as i, s@3 as s, vec@4 as vec, _distance@0 as _distance]
Take: columns=\"_distance, _rowid, (i), (s), (vec)\"
CoalesceBatchesExec: target_batch_size=8192
SortExec: TopK(fetch=42), expr=...
ANNSubIndex: name=..., k=42, deltas=1, metric=L2
ANNIvfPartition: uuid=..., minimum_nprobes=1, maximum_nprobes=None, deltas=1";
assert_plan_equals(
&dataset.dataset,
|scan| scan.nearest("vec", &q, 42),
expected,
)
.await?;
log::info!("Test case: ANN with refine");
let expected =
"ProjectionExec: expr=[i@3 as i, s@4 as s, vec@1 as vec, _distance@2 as _distance]
Take: columns=\"_rowid, vec, _distance, (i), (s)\"
CoalesceBatchesExec: target_batch_size=8192
FilterExec: _distance@... IS NOT NULL
SortExec: TopK(fetch=10), expr=...
KNNVectorDistance: metric=l2
Take: columns=\"_distance, _rowid, (vec)\"
CoalesceBatchesExec: target_batch_size=8192
SortExec: TopK(fetch=40), expr=...
ANNSubIndex: name=..., k=40, deltas=1, metric=L2
ANNIvfPartition: uuid=..., minimum_nprobes=1, maximum_nprobes=None, deltas=1";
assert_plan_equals(
&dataset.dataset,
|scan| Ok(scan.nearest("vec", &q, 10)?.refine(4)),
expected,
)
.await?;
log::info!("Test case: ANN with index disabled");
let expected = if data_storage_version == LanceFileVersion::Legacy {
"ProjectionExec: expr=[i@3 as i, s@4 as s, vec@0 as vec, _distance@2 as _distance]
Take: columns=\"vec, _rowid, _distance, (i), (s)\"
CoalesceBatchesExec: target_batch_size=8192
FilterExec: _distance@... IS NOT NULL
SortExec: TopK(fetch=13), expr=...
KNNVectorDistance: metric=l2
LanceScan: uri=..., projection=[vec], row_id=true, row_addr=false, ordered=false, range=None"
} else {
"ProjectionExec: expr=[i@3 as i, s@4 as s, vec@0 as vec, _distance@2 as _distance]
Take: columns=\"vec, _rowid, _distance, (i), (s)\"
CoalesceBatchesExec: target_batch_size=8192
FilterExec: _distance@... IS NOT NULL
SortExec: TopK(fetch=13), expr=...
KNNVectorDistance: metric=l2
LanceRead: uri=..., projection=[vec], num_fragments=2, range_before=None, range_after=None, \
row_id=true, row_addr=false, full_filter=--, refine_filter=--"
};
assert_plan_equals(
&dataset.dataset,
|scan| Ok(scan.nearest("vec", &q, 13)?.use_index(false)),
expected,
)
.await?;
log::info!("Test case: ANN with postfilter");
let expected = "ProjectionExec: expr=[s@3 as s, vec@4 as vec, _distance@0 as _distance, _rowid@1 as _rowid]
Take: columns=\"_distance, _rowid, i, (s), (vec)\"
CoalesceBatchesExec: target_batch_size=8192
FilterExec: i@2 > 10
Take: columns=\"_distance, _rowid, (i)\"
CoalesceBatchesExec: target_batch_size=8192
SortExec: TopK(fetch=17), expr=...
ANNSubIndex: name=..., k=17, deltas=1, metric=L2
ANNIvfPartition: uuid=..., minimum_nprobes=1, maximum_nprobes=None, deltas=1";
assert_plan_equals(
&dataset.dataset,
|scan| {
Ok(scan
.nearest("vec", &q, 17)?
.filter("i > 10")?
.project(&["s", "vec"])?
.with_row_id())
},
expected,
)
.await?;
log::info!("Test case: ANN with prefilter");
let expected = if data_storage_version == LanceFileVersion::Legacy {
"ProjectionExec: expr=[i@2 as i, s@3 as s, vec@4 as vec, _distance@0 as _distance]
Take: columns=\"_distance, _rowid, (i), (s), (vec)\"
CoalesceBatchesExec: target_batch_size=8192
SortExec: TopK(fetch=17), expr=...
ANNSubIndex: name=..., k=17, deltas=1, metric=L2
ANNIvfPartition: uuid=..., minimum_nprobes=1, maximum_nprobes=None, deltas=1
FilterExec: i@0 > 10
LanceScan: uri=..., projection=[i], row_id=true, row_addr=false, ordered=false, range=None"
} else {
"ProjectionExec: expr=[i@2 as i, s@3 as s, vec@4 as vec, _distance@0 as _distance]
Take: columns=\"_distance, _rowid, (i), (s), (vec)\"
CoalesceBatchesExec: target_batch_size=8192
SortExec: TopK(fetch=17), expr=...
ANNSubIndex: name=..., k=17, deltas=1, metric=L2
ANNIvfPartition: uuid=..., minimum_nprobes=1, maximum_nprobes=None, deltas=1
LanceRead: uri=..., projection=[], num_fragments=2, range_before=None, range_after=None, \
row_id=true, row_addr=false, full_filter=i > Int32(10), refine_filter=i > Int32(10)"
};
assert_plan_equals(
&dataset.dataset,
|scan| {
Ok(scan
.nearest("vec", &q, 17)?
.filter("i > 10")?
.prefilter(true))
},
expected,
)
.await?;
dataset.append_new_data().await?;
log::info!("Test case: Combined KNN/ANN");
let expected = "ProjectionExec: expr=[i@3 as i, s@4 as s, vec@1 as vec, _distance@2 as _distance]
Take: columns=\"_rowid, vec, _distance, (i), (s)\"
CoalesceBatchesExec: target_batch_size=8192
FilterExec: _distance@... IS NOT NULL
SortExec: TopK(fetch=6), expr=...
KNNVectorDistance: metric=l2
RepartitionExec: partitioning=RoundRobinBatch(1), input_partitions=2
UnionExec
ProjectionExec: expr=[_distance@2 as _distance, _rowid@1 as _rowid, vec@0 as vec]
FilterExec: _distance@... IS NOT NULL
SortExec: TopK(fetch=6), expr=...
KNNVectorDistance: metric=l2
LanceScan: uri=..., projection=[vec], row_id=true, row_addr=false, ordered=false, range=None
Take: columns=\"_distance, _rowid, (vec)\"
CoalesceBatchesExec: target_batch_size=8192
SortExec: TopK(fetch=6), expr=...
ANNSubIndex: name=..., k=6, deltas=1, metric=L2
ANNIvfPartition: uuid=..., minimum_nprobes=1, maximum_nprobes=None, deltas=1";
assert_plan_equals(
&dataset.dataset,
|scan| scan.nearest("vec", &q, 6),
expected,
)
.await?;
log::info!("Test case: Combined KNN/ANN with postfilter");
let expected = "ProjectionExec: expr=[i@3 as i, s@4 as s, vec@1 as vec, _distance@2 as _distance]
Take: columns=\"_rowid, vec, _distance, i, (s)\"
CoalesceBatchesExec: target_batch_size=8192
FilterExec: i@3 > 10
Take: columns=\"_rowid, vec, _distance, (i)\"
CoalesceBatchesExec: target_batch_size=8192
FilterExec: _distance@... IS NOT NULL
SortExec: TopK(fetch=15), expr=...
KNNVectorDistance: metric=l2
RepartitionExec: partitioning=RoundRobinBatch(1), input_partitions=2
UnionExec
ProjectionExec: expr=[_distance@2 as _distance, _rowid@1 as _rowid, vec@0 as vec]
FilterExec: _distance@... IS NOT NULL
SortExec: TopK(fetch=15), expr=...
KNNVectorDistance: metric=l2
LanceScan: uri=..., projection=[vec], row_id=true, row_addr=false, ordered=false, range=None
Take: columns=\"_distance, _rowid, (vec)\"
CoalesceBatchesExec: target_batch_size=8192
SortExec: TopK(fetch=15), expr=...
ANNSubIndex: name=..., k=15, deltas=1, metric=L2
ANNIvfPartition: uuid=..., minimum_nprobes=1, maximum_nprobes=None, deltas=1";
assert_plan_equals(
&dataset.dataset,
|scan| scan.nearest("vec", &q, 15)?.filter("i > 10"),
expected,
)
.await?;
log::info!("Test case: Combined KNN/ANN with prefilter");
let expected = if data_storage_version == LanceFileVersion::Legacy {
"ProjectionExec: expr=[i@3 as i, s@4 as s, vec@1 as vec, _distance@2 as _distance]
Take: columns=\"_rowid, vec, _distance, (i), (s)\"
CoalesceBatchesExec: target_batch_size=8192
FilterExec: _distance@... IS NOT NULL
SortExec: TopK(fetch=5), expr=...
KNNVectorDistance: metric=l2
RepartitionExec: partitioning=RoundRobinBatch(1), input_partitions=2
UnionExec
ProjectionExec: expr=[_distance@3 as _distance, _rowid@2 as _rowid, vec@0 as vec]
FilterExec: _distance@... IS NOT NULL
SortExec: TopK(fetch=5), expr=...
KNNVectorDistance: metric=l2
FilterExec: i@1 > 10
LanceScan: uri=..., projection=[vec, i], row_id=true, row_addr=false, ordered=false, range=None
Take: columns=\"_distance, _rowid, (vec)\"
CoalesceBatchesExec: target_batch_size=8192
SortExec: TopK(fetch=5), expr=...
ANNSubIndex: name=..., k=5, deltas=1, metric=L2
ANNIvfPartition: uuid=..., minimum_nprobes=1, maximum_nprobes=None, deltas=1
FilterExec: i@0 > 10
LanceScan: uri=..., projection=[i], row_id=true, row_addr=false, ordered=false, range=None"
} else {
"ProjectionExec: expr=[i@3 as i, s@4 as s, vec@1 as vec, _distance@2 as _distance]
Take: columns=\"_rowid, vec, _distance, (i), (s)\"
CoalesceBatchesExec: target_batch_size=8192
FilterExec: _distance@... IS NOT NULL
SortExec: TopK(fetch=5), expr=...
KNNVectorDistance: metric=l2
RepartitionExec: partitioning=RoundRobinBatch(1), input_partitions=2
UnionExec
ProjectionExec: expr=[_distance@3 as _distance, _rowid@2 as _rowid, vec@0 as vec]
FilterExec: _distance@... IS NOT NULL
SortExec: TopK(fetch=5), expr=...
KNNVectorDistance: metric=l2
FilterExec: i@1 > 10
LanceScan: uri=..., projection=[vec, i], row_id=true, row_addr=false, ordered=false, range=None
Take: columns=\"_distance, _rowid, (vec)\"
CoalesceBatchesExec: target_batch_size=8192
SortExec: TopK(fetch=5), expr=...
ANNSubIndex: name=..., k=5, deltas=1, metric=L2
ANNIvfPartition: uuid=..., minimum_nprobes=1, maximum_nprobes=None, deltas=1
LanceRead: uri=..., projection=[], num_fragments=2, range_before=None, range_after=None, \
row_id=true, row_addr=false, full_filter=i > Int32(10), refine_filter=i > Int32(10)"
};
assert_plan_equals(
&dataset.dataset,
|scan| {
Ok(scan
.nearest("vec", &q, 5)?
.filter("i > 10")?
.prefilter(true))
},
expected,
)
.await?;
dataset.make_vector_index().await?;
dataset.make_scalar_index().await?;
log::info!("Test case: ANN with scalar index");
let expected =
"ProjectionExec: expr=[i@2 as i, s@3 as s, vec@4 as vec, _distance@0 as _distance]
Take: columns=\"_distance, _rowid, (i), (s), (vec)\"
CoalesceBatchesExec: target_batch_size=8192
SortExec: TopK(fetch=5), expr=...
ANNSubIndex: name=..., k=5, deltas=1, metric=L2
ANNIvfPartition: uuid=..., minimum_nprobes=1, maximum_nprobes=None, deltas=1
ScalarIndexQuery: query=[i > 10]@i_idx";
assert_plan_equals(
&dataset.dataset,
|scan| {
Ok(scan
.nearest("vec", &q, 5)?
.filter("i > 10")?
.prefilter(true))
},
expected,
)
.await?;
log::info!("Test case: ANN with scalar index disabled");
let expected = if data_storage_version == LanceFileVersion::Legacy {
"ProjectionExec: expr=[i@2 as i, s@3 as s, vec@4 as vec, _distance@0 as _distance]
Take: columns=\"_distance, _rowid, (i), (s), (vec)\"
CoalesceBatchesExec: target_batch_size=8192
SortExec: TopK(fetch=5), expr=...
ANNSubIndex: name=..., k=5, deltas=1, metric=L2
ANNIvfPartition: uuid=..., minimum_nprobes=1, maximum_nprobes=None, deltas=1
FilterExec: i@0 > 10
LanceScan: uri=..., projection=[i], row_id=true, row_addr=false, ordered=false, range=None"
} else {
"ProjectionExec: expr=[i@2 as i, s@3 as s, vec@4 as vec, _distance@0 as _distance]
Take: columns=\"_distance, _rowid, (i), (s), (vec)\"
CoalesceBatchesExec: target_batch_size=8192
SortExec: TopK(fetch=5), expr=...
ANNSubIndex: name=..., k=5, deltas=1, metric=L2
ANNIvfPartition: uuid=..., minimum_nprobes=1, maximum_nprobes=None, deltas=1
LanceRead: uri=..., projection=[], num_fragments=3, range_before=None, \
range_after=None, row_id=true, row_addr=false, full_filter=i > Int32(10), refine_filter=i > Int32(10)"
};
assert_plan_equals(
&dataset.dataset,
|scan| {
Ok(scan
.nearest("vec", &q, 5)?
.use_scalar_index(false)
.filter("i > 10")?
.prefilter(true))
},
expected,
)
.await?;
dataset.append_new_data().await?;
log::info!("Test case: Combined KNN/ANN with scalar index");
let expected = "ProjectionExec: expr=[i@3 as i, s@4 as s, vec@1 as vec, _distance@2 as _distance]
Take: columns=\"_rowid, vec, _distance, (i), (s)\"
CoalesceBatchesExec: target_batch_size=8192
FilterExec: _distance@... IS NOT NULL
SortExec: TopK(fetch=8), expr=...
KNNVectorDistance: metric=l2
RepartitionExec: partitioning=RoundRobinBatch(1), input_partitions=2
UnionExec
ProjectionExec: expr=[_distance@3 as _distance, _rowid@2 as _rowid, vec@0 as vec]
FilterExec: _distance@... IS NOT NULL
SortExec: TopK(fetch=8), expr=...
KNNVectorDistance: metric=l2
FilterExec: i@1 > 10
LanceScan: uri=..., projection=[vec, i], row_id=true, row_addr=false, ordered=false, range=None
Take: columns=\"_distance, _rowid, (vec)\"
CoalesceBatchesExec: target_batch_size=8192
SortExec: TopK(fetch=8), expr=...
ANNSubIndex: name=..., k=8, deltas=1, metric=L2
ANNIvfPartition: uuid=..., minimum_nprobes=1, maximum_nprobes=None, deltas=1
ScalarIndexQuery: query=[i > 10]@i_idx";
assert_plan_equals(
&dataset.dataset,
|scan| {
Ok(scan
.nearest("vec", &q, 8)?
.filter("i > 10")?
.prefilter(true))
},
expected,
)
.await?;
log::info!(
"Test case: Combined KNN/ANN with updated scalar index and outdated vector index"
);
let expected = "ProjectionExec: expr=[i@3 as i, s@4 as s, vec@1 as vec, _distance@2 as _distance]
Take: columns=\"_rowid, vec, _distance, (i), (s)\"
CoalesceBatchesExec: target_batch_size=8192
FilterExec: _distance@... IS NOT NULL
SortExec: TopK(fetch=11), expr=...
KNNVectorDistance: metric=l2
RepartitionExec: partitioning=RoundRobinBatch(1), input_partitions=2
UnionExec
ProjectionExec: expr=[_distance@3 as _distance, _rowid@2 as _rowid, vec@0 as vec]
FilterExec: _distance@... IS NOT NULL
SortExec: TopK(fetch=11), expr=...
KNNVectorDistance: metric=l2
FilterExec: i@1 > 10
LanceScan: uri=..., projection=[vec, i], row_id=true, row_addr=false, ordered=false, range=None
Take: columns=\"_distance, _rowid, (vec)\"
CoalesceBatchesExec: target_batch_size=8192
SortExec: TopK(fetch=11), expr=...
ANNSubIndex: name=..., k=11, deltas=1, metric=L2
ANNIvfPartition: uuid=..., minimum_nprobes=1, maximum_nprobes=None, deltas=1
ScalarIndexQuery: query=[i > 10]@i_idx";
dataset.make_scalar_index().await?;
assert_plan_equals(
&dataset.dataset,
|scan| {
Ok(scan
.nearest("vec", &q, 11)?
.filter("i > 10")?
.prefilter(true))
},
expected,
)
.await?;
log::info!("Test case: Filtered read with scalar index");
let expected = if data_storage_version == LanceFileVersion::Legacy {
"ProjectionExec: expr=[s@1 as s]
Take: columns=\"_rowid, (s)\"
CoalesceBatchesExec: target_batch_size=8192
MaterializeIndex: query=[i > 10]@i_idx"
} else {
"LanceRead: uri=..., projection=[s], num_fragments=4, range_before=None, \
range_after=None, row_id=false, row_addr=false, full_filter=i > Int32(10), refine_filter=--
ScalarIndexQuery: query=[i > 10]@i_idx"
};
assert_plan_equals(
&dataset.dataset,
|scan| scan.project(&["s"])?.filter("i > 10"),
expected,
)
.await?;
if data_storage_version != LanceFileVersion::Legacy {
log::info!(
"Test case: Filtered read with scalar index disabled (late materialization)"
);
assert_plan_equals(
&dataset.dataset,
|scan| {
scan.project(&["s"])?
.use_scalar_index(false)
.filter("i > 10")
},
"ProjectionExec: expr=[s@2 as s]
Take: columns=\"i, _rowid, (s)\"
CoalesceBatchesExec: target_batch_size=8192
LanceRead: uri=..., projection=[i], num_fragments=4, range_before=None, \
range_after=None, row_id=true, row_addr=false, full_filter=i > Int32(10), refine_filter=i > Int32(10)",
)
.await?;
}
log::info!("Test case: Empty projection");
let expected = if data_storage_version == LanceFileVersion::Legacy {
"ProjectionExec: expr=[_rowaddr@0 as _rowaddr]
AddRowAddrExec
MaterializeIndex: query=[i > 10]@i_idx"
} else {
"LanceRead: uri=..., projection=[], num_fragments=4, range_before=None, \
range_after=None, row_id=false, row_addr=true, full_filter=i > Int32(10), refine_filter=--
ScalarIndexQuery: query=[i > 10]@i_idx"
};
assert_plan_equals(
&dataset.dataset,
|scan| {
scan.filter("i > 10")
.unwrap()
.with_row_address()
.project::<&str>(&[])
},
expected,
)
.await?;
dataset.append_new_data().await?;
log::info!("Test case: Combined Scalar/non-scalar filtered read");
let expected = if data_storage_version == LanceFileVersion::Legacy {
"ProjectionExec: expr=[s@1 as s]
RepartitionExec: partitioning=RoundRobinBatch(1), input_partitions=2
UnionExec
Take: columns=\"_rowid, (s)\"
CoalesceBatchesExec: target_batch_size=8192
MaterializeIndex: query=[i > 10]@i_idx
ProjectionExec: expr=[_rowid@2 as _rowid, s@1 as s]
FilterExec: i@0 > 10
LanceScan: uri=..., projection=[i, s], row_id=true, row_addr=false, ordered=false, range=None"
} else {
"LanceRead: uri=..., projection=[s], num_fragments=5, range_before=None, \
range_after=None, row_id=false, row_addr=false, full_filter=i > Int32(10), refine_filter=--
ScalarIndexQuery: query=[i > 10]@i_idx"
};
assert_plan_equals(
&dataset.dataset,
|scan| scan.project(&["s"])?.filter("i > 10"),
expected,
)
.await?;
log::info!("Test case: Combined Scalar/non-scalar filtered read with empty projection");
let expected = if data_storage_version == LanceFileVersion::Legacy {
"ProjectionExec: expr=[_rowaddr@0 as _rowaddr]
RepartitionExec: partitioning=RoundRobinBatch(1), input_partitions=2
UnionExec
AddRowAddrExec
MaterializeIndex: query=[i > 10]@i_idx
ProjectionExec: expr=[_rowaddr@2 as _rowaddr, _rowid@1 as _rowid]
FilterExec: i@0 > 10
LanceScan: uri=..., projection=[i], row_id=true, row_addr=true, ordered=false, range=None"
} else {
"LanceRead: uri=..., projection=[], num_fragments=5, range_before=None, \
range_after=None, row_id=false, row_addr=true, full_filter=i > Int32(10), refine_filter=--
ScalarIndexQuery: query=[i > 10]@i_idx"
};
assert_plan_equals(
&dataset.dataset,
|scan| {
scan.filter("i > 10")
.unwrap()
.with_row_address()
.project::<&str>(&[])
},
expected,
)
.await?;
log::info!("Test case: Dynamic projection");
let expected = if data_storage_version == LanceFileVersion::Legacy {
"ProjectionExec: expr=[regexp_match(s@1, .*) as matches]
RepartitionExec: partitioning=RoundRobinBatch(1), input_partitions=2
UnionExec
Take: columns=\"_rowid, (s)\"
CoalesceBatchesExec: target_batch_size=8192
MaterializeIndex: query=[i > 10]@i_idx
ProjectionExec: expr=[_rowid@2 as _rowid, s@1 as s]
FilterExec: i@0 > 10
LanceScan: uri=..., row_id=true, row_addr=false, ordered=false, range=None"
} else {
"ProjectionExec: expr=[regexp_match(s@0, .*) as matches]
LanceRead: uri=..., projection=[s], num_fragments=5, range_before=None, \
range_after=None, row_id=false, row_addr=false, full_filter=i > Int32(10), refine_filter=--
ScalarIndexQuery: query=[i > 10]@i_idx"
};
assert_plan_equals(
&dataset.dataset,
|scan| {
scan.project_with_transform(&[("matches", "regexp_match(s, \".*\")")])?
.filter("i > 10")
},
expected,
)
.await?;
dataset.make_fts_index().await?;
log::info!("Test case: Full text search (match query)");
let expected = r#"ProjectionExec: expr=[s@2 as s, _score@1 as _score, _rowid@0 as _rowid]
Take: columns="_rowid, _score, (s)"
CoalesceBatchesExec: target_batch_size=8192
MatchQuery: column=s, query=hello"#;
assert_plan_equals(
&dataset.dataset,
|scan| {
scan.project(&["s"])?
.with_row_id()
.full_text_search(FullTextSearchQuery::new("hello".to_owned()))
},
expected,
)
.await?;
log::info!("Test case: Full text search (phrase query)");
let expected = r#"ProjectionExec: expr=[s@2 as s, _score@1 as _score, _rowid@0 as _rowid]
Take: columns="_rowid, _score, (s)"
CoalesceBatchesExec: target_batch_size=8192
PhraseQuery: column=s, query=hello world"#;
assert_plan_equals(
&dataset.dataset,
|scan| {
let query = PhraseQuery::new("hello world".to_owned());
scan.project(&["s"])?
.with_row_id()
.full_text_search(FullTextSearchQuery::new_query(query.into()))
},
expected,
)
.await?;
log::info!("Test case: Full text search (boost query)");
let expected = r#"ProjectionExec: expr=[s@2 as s, _score@1 as _score, _rowid@0 as _rowid]
Take: columns="_rowid, _score, (s)"
CoalesceBatchesExec: target_batch_size=8192
BoostQuery: negative_boost=1
MatchQuery: column=s, query=hello
MatchQuery: column=s, query=world"#;
assert_plan_equals(
&dataset.dataset,
|scan| {
let positive =
MatchQuery::new("hello".to_owned()).with_column(Some("s".to_owned()));
let negative =
MatchQuery::new("world".to_owned()).with_column(Some("s".to_owned()));
let query = BoostQuery::new(positive.into(), negative.into(), Some(1.0));
scan.project(&["s"])?
.with_row_id()
.full_text_search(FullTextSearchQuery::new_query(query.into()))
},
expected,
)
.await?;
log::info!("Test case: Full text search with prefilter");
let expected = if data_storage_version == LanceFileVersion::Legacy {
r#"ProjectionExec: expr=[s@2 as s, _score@1 as _score, _rowid@0 as _rowid]
Take: columns="_rowid, _score, (s)"
CoalesceBatchesExec: target_batch_size=8192
MatchQuery: column=s, query=hello
RepartitionExec: partitioning=RoundRobinBatch(1), input_partitions=2
UnionExec
MaterializeIndex: query=[i > 10]@i_idx
ProjectionExec: expr=[_rowid@1 as _rowid]
FilterExec: i@0 > 10
LanceScan: uri=..., projection=[i], row_id=true, row_addr=false, ordered=false, range=None"#
} else {
r#"ProjectionExec: expr=[s@2 as s, _score@1 as _score, _rowid@0 as _rowid]
Take: columns="_rowid, _score, (s)"
CoalesceBatchesExec: target_batch_size=8192
MatchQuery: column=s, query=hello
LanceRead: uri=..., projection=[], num_fragments=5, range_before=None, range_after=None, row_id=true, row_addr=false, full_filter=i > Int32(10), refine_filter=--
ScalarIndexQuery: query=[i > 10]@i_idx"#
};
assert_plan_equals(
&dataset.dataset,
|scan| {
scan.project(&["s"])?
.with_row_id()
.filter("i > 10")?
.prefilter(true)
.full_text_search(FullTextSearchQuery::new("hello".to_owned()))
},
expected,
)
.await?;
log::info!("Test case: Full text search with unindexed rows");
let expected = r#"ProjectionExec: expr=[s@2 as s, _score@1 as _score, _rowid@0 as _rowid]
Take: columns="_rowid, _score, (s)"
CoalesceBatchesExec: target_batch_size=8192
SortExec: expr=[_score@1 DESC NULLS LAST], preserve_partitioning=[false]
RepartitionExec: partitioning=RoundRobinBatch(1), input_partitions=2
UnionExec
MatchQuery: column=s, query=hello
FlatMatchQuery: column=s, query=hello
LanceScan: uri=..., projection=[s], row_id=true, row_addr=false, ordered=false, range=None"#;
dataset.append_new_data().await?;
assert_plan_equals(
&dataset.dataset,
|scan| {
scan.project(&["s"])?
.with_row_id()
.full_text_search(FullTextSearchQuery::new("hello".to_owned()))
},
expected,
)
.await?;
log::info!("Test case: Full text search with unindexed rows and fast_search");
let expected = r#"ProjectionExec: expr=[s@2 as s, _score@1 as _score, _rowid@0 as _rowid]
Take: columns="_rowid, _score, (s)"
CoalesceBatchesExec: target_batch_size=8192
MatchQuery: column=s, query=hello"#;
assert_plan_equals(
&dataset.dataset,
|scan| {
let scan = scan
.project(&["s"])?
.with_row_id()
.full_text_search(FullTextSearchQuery::new("hello".to_owned()))?;
scan.fast_search();
Ok(scan)
},
expected,
)
.await?;
log::info!("Test case: Full text search with unindexed rows and prefilter");
let expected = if data_storage_version == LanceFileVersion::Legacy {
r#"ProjectionExec: expr=[s@2 as s, _score@1 as _score, _rowid@0 as _rowid]
Take: columns="_rowid, _score, (s)"
CoalesceBatchesExec: target_batch_size=8192
SortExec: expr=[_score@1 DESC NULLS LAST], preserve_partitioning=[false]
RepartitionExec: partitioning=RoundRobinBatch(1), input_partitions=2
UnionExec
MatchQuery: column=s, query=hello
RepartitionExec: partitioning=RoundRobinBatch(1), input_partitions=2
UnionExec
MaterializeIndex: query=[i > 10]@i_idx
ProjectionExec: expr=[_rowid@1 as _rowid]
FilterExec: i@0 > 10
LanceScan: uri=..., projection=[i], row_id=true, row_addr=false, ordered=false, range=None
FlatMatchQuery: column=s, query=hello
FilterExec: i@1 > 10
LanceScan: uri=..., projection=[s, i], row_id=true, row_addr=false, ordered=false, range=None"#
} else {
r#"ProjectionExec: expr=[s@2 as s, _score@1 as _score, _rowid@0 as _rowid]
Take: columns="_rowid, _score, (s)"
CoalesceBatchesExec: target_batch_size=8192
SortExec: expr=[_score@1 DESC NULLS LAST], preserve_partitioning=[false]
RepartitionExec: partitioning=RoundRobinBatch(1), input_partitions=2
UnionExec
MatchQuery: column=s, query=hello
LanceRead: uri=..., projection=[], num_fragments=5, range_before=None, range_after=None, row_id=true, row_addr=false, full_filter=i > Int32(10), refine_filter=--
ScalarIndexQuery: query=[i > 10]@i_idx
FlatMatchQuery: column=s, query=hello
FilterExec: i@1 > 10
LanceScan: uri=..., projection=[s, i], row_id=true, row_addr=false, ordered=false, range=None"#
};
assert_plan_equals(
&dataset.dataset,
|scan| {
scan.project(&["s"])?
.with_row_id()
.filter("i > 10")?
.prefilter(true)
.full_text_search(FullTextSearchQuery::new("hello".to_owned()))
},
expected,
)
.await?;
Ok(())
}
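// fast_search should serve the query purely from the vector index and skip the
// flat-KNN branch over unindexed rows; without fast_search the plan unions an
// exhaustive KNN over the appended data with the ANN search.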
#[tokio::test]
async fn test_fast_search_plan() {
let mut dataset = TestVectorDataset::new(LanceFileVersion::Stable, true)
.await
.unwrap();
dataset.make_vector_index().await.unwrap();
dataset.append_new_data().await.unwrap();
let q: Float32Array = (32..64).map(|v| v as f32).collect();
assert_plan_equals(
&dataset.dataset,
|scan| {
scan.nearest("vec", &q, 32)?
.fast_search()
.project(&["_distance", "_rowid"])
},
"SortExec: TopK(fetch=32), expr=[_distance@0 ASC NULLS LAST, _rowid@1 ASC NULLS LAST]...
ANNSubIndex: name=idx, k=32, deltas=1, metric=L2
ANNIvfPartition: uuid=..., minimum_nprobes=1, maximum_nprobes=None, deltas=1",
)
.await
.unwrap();
assert_plan_equals(
&dataset.dataset,
|scan| {
scan.nearest("vec", &q, 33)?
.fast_search()
.with_row_id()
.project(&["_distance", "_rowid"])
},
"SortExec: TopK(fetch=33), expr=[_distance@0 ASC NULLS LAST, _rowid@1 ASC NULLS LAST]...
ANNSubIndex: name=idx, k=33, deltas=1, metric=L2
ANNIvfPartition: uuid=..., minimum_nprobes=1, maximum_nprobes=None, deltas=1",
)
.await
.unwrap();
assert_plan_equals(
&dataset.dataset,
|scan| {
scan.nearest("vec", &q, 34)?
.with_row_id()
.project(&["_distance", "_rowid"])
},
"ProjectionExec: expr=[_distance@2 as _distance, _rowid@0 as _rowid]
FilterExec: _distance@2 IS NOT NULL
SortExec: TopK(fetch=34), expr=[_distance@2 ASC NULLS LAST, _rowid@0 ASC NULLS LAST]...
KNNVectorDistance: metric=l2
RepartitionExec: partitioning=RoundRobinBatch(1), input_partitions=2
UnionExec
ProjectionExec: expr=[_distance@2 as _distance, _rowid@1 as _rowid, vec@0 as vec]
FilterExec: _distance@2 IS NOT NULL
SortExec: TopK(fetch=34), expr=[_distance@2 ASC NULLS LAST, _rowid@1 ASC NULLS LAST]...
KNNVectorDistance: metric=l2
LanceScan: uri=..., projection=[vec], row_id=true, row_addr=false, ordered=false, range=None
Take: columns=\"_distance, _rowid, (vec)\"
CoalesceBatchesExec: target_batch_size=8192
SortExec: TopK(fetch=34), expr=[_distance@0 ASC NULLS LAST, _rowid@1 ASC NULLS LAST]...
ANNSubIndex: name=idx, k=34, deltas=1, metric=L2
ANNIvfPartition: uuid=..., minimum_nprobes=1, maximum_nprobes=None, deltas=1",
)
.await
.unwrap();
}
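// With no vector index at all, fast_search has nothing to search and must
// return an empty result rather than silently falling back to flat KNN.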
#[tokio::test]
async fn test_fast_search_without_vector_index_returns_empty() {
let dataset = TestVectorDataset::new(LanceFileVersion::Stable, true)
.await
.unwrap();
let q: Float32Array = (32..64).map(|v| v as f32).collect();
let mut scanner = dataset.dataset.scan();
scanner.nearest("vec", &q, 10).unwrap();
let normal_rows = scanner.try_into_batch().await.unwrap().num_rows();
let mut scanner = dataset.dataset.scan();
scanner.nearest("vec", &q, 10).unwrap().fast_search();
let fast_rows = scanner.try_into_batch().await.unwrap().num_rows();
assert_eq!(normal_rows, 10);
assert_eq!(fast_rows, 0);
}
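// Planning should hit the object store only once: the first explain_plan may
// read index metadata, but every subsequent plan must be served from cache
// with zero read IOPS.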
#[rstest]
#[tokio::test]
pub async fn test_scan_planning_io(
#[values(LanceFileVersion::Legacy, LanceFileVersion::Stable)]
data_storage_version: LanceFileVersion,
) {
use lance_index::scalar::inverted::tokenizer::InvertedIndexParams;
use lance_io::{assert_io_eq, assert_io_gt};
let data = gen_batch()
.col(
"vector",
array::rand_vec::<Float32Type>(Dimension::from(32)),
)
.col("text", array::rand_utf8(ByteCount::from(4), false))
.col("indexed", array::step::<Int32Type>())
.col("not_indexed", array::step::<Int32Type>())
.into_reader_rows(RowCount::from(100), BatchCount::from(5));
let mut dataset = Dataset::write(
data,
"memory://test",
Some(WriteParams {
data_storage_version: Some(data_storage_version),
..Default::default()
}),
)
.await
.unwrap();
dataset
.create_index(
&["indexed"],
IndexType::Scalar,
None,
&ScalarIndexParams::default(),
false,
)
.await
.unwrap();
dataset
.create_index(
&["text"],
IndexType::Inverted,
None,
&InvertedIndexParams::default(),
false,
)
.await
.unwrap();
dataset
.create_index(
&["vector"],
IndexType::Vector,
None,
&VectorIndexParams {
metric_type: DistanceType::L2,
stages: vec![
StageParams::Ivf(IvfBuildParams {
max_iters: 2,
num_partitions: Some(2),
sample_rate: 2,
..Default::default()
}),
StageParams::PQ(PQBuildParams {
max_iters: 2,
num_sub_vectors: 2,
..Default::default()
}),
],
version: crate::index::vector::IndexFileVersion::Legacy,
skip_transpose: false,
},
false,
)
.await
.unwrap();
dataset
.scan()
.prefilter(true)
.filter("indexed > 10")
.unwrap()
.explain_plan(true)
.await
.unwrap();
let io_stats = dataset.object_store().io_stats_incremental();
assert_io_gt!(io_stats, read_iops, 0);
dataset
.scan()
.prefilter(true)
.filter("indexed > 10")
.unwrap()
.explain_plan(true)
.await
.unwrap();
let io_stats = dataset.object_store().io_stats_incremental();
assert_io_eq!(io_stats, read_iops, 0);
dataset
.scan()
.prefilter(true)
.filter("true")
.unwrap()
.explain_plan(true)
.await
.unwrap();
let io_stats = dataset.object_store().io_stats_incremental();
assert_io_eq!(io_stats, read_iops, 0);
dataset
.scan()
.prefilter(true)
.materialization_style(MaterializationStyle::AllEarly)
.filter("true")
.unwrap()
.explain_plan(true)
.await
.unwrap();
let io_stats = dataset.object_store().io_stats_incremental();
assert_io_eq!(io_stats, read_iops, 0);
dataset
.scan()
.prefilter(true)
.materialization_style(MaterializationStyle::AllLate)
.filter("true")
.unwrap()
.explain_plan(true)
.await
.unwrap();
let io_stats = dataset.object_store().io_stats_incremental();
assert_io_eq!(io_stats, read_iops, 0);
}
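// The _rowid / _rowaddr meta columns should be emitted whether they are
// requested via with_row_id()/with_row_address(), via an explicit projection,
// or both.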
#[rstest]
#[tokio::test]
pub async fn test_row_meta_columns(
#[values((true, false), (false, true), (true, true))]
columns: (bool, bool),
) {
let (with_row_id, with_row_address) = columns;
let test_dir = TempStrDir::default();
let uri = &test_dir;
let schema = Arc::new(arrow_schema::Schema::new(vec![
arrow_schema::Field::new("data_item_id", arrow_schema::DataType::Int32, false),
arrow_schema::Field::new("a", arrow_schema::DataType::Int32, false),
]));
let data = RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(Int32Array::from(vec![1001, 1002, 1003])),
Arc::new(Int32Array::from(vec![1, 2, 3])),
],
)
.unwrap();
let dataset = Dataset::write(
RecordBatchIterator::new(vec![Ok(data)], schema.clone()),
uri,
None,
)
.await
.unwrap();
let mut scanner = dataset.scan();
let mut projection = vec!["data_item_id".to_string()];
if with_row_id {
scanner.with_row_id();
projection.push(ROW_ID.to_string());
}
if with_row_address {
scanner.with_row_address();
projection.push(ROW_ADDR.to_string());
}
scanner.project(&projection).unwrap();
let stream = scanner.try_into_stream().await.unwrap();
let batch = stream.try_collect::<Vec<_>>().await.unwrap().pop().unwrap();
if with_row_id {
let column = batch.column_by_name(ROW_ID).unwrap();
assert_eq!(column.data_type(), &DataType::UInt64);
}
if with_row_address {
let column = batch.column_by_name(ROW_ADDR).unwrap();
assert_eq!(column.data_type(), &DataType::UInt64);
}
let mut scanner = dataset.scan();
if with_row_id {
scanner.with_row_id();
}
if with_row_address {
scanner.with_row_address();
}
scanner.project(&["data_item_id"]).unwrap();
let stream = scanner.try_into_stream().await.unwrap();
let batch = stream.try_collect::<Vec<_>>().await.unwrap().pop().unwrap();
let meta_column = batch.column_by_name(if with_row_id { ROW_ID } else { ROW_ADDR });
assert!(meta_column.is_some());
let mut scanner = dataset.scan();
if with_row_id {
scanner.project(&[ROW_ID]).unwrap();
} else {
scanner.project(&[ROW_ADDR]).unwrap();
};
let stream = scanner.try_into_stream().await.unwrap();
assert_eq!(stream.schema().fields().len(), 1);
if with_row_id {
assert!(stream.schema().field_with_name(ROW_ID).is_ok());
} else {
assert!(stream.schema().field_with_name(ROW_ADDR).is_ok());
}
}
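// Checks that limit/offset produce exactly the same rows as slicing the full
// result set, across a grid of limit/offset combinations.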
async fn limit_offset_equivalency_test(scanner: &Scanner) {
async fn test_one(
scanner: &Scanner,
full_result: &RecordBatch,
limit: Option<i64>,
offset: Option<i64>,
) {
let mut new_scanner = scanner.clone();
new_scanner.limit(limit, offset).unwrap();
if let Some(nearest) = new_scanner.nearest_mut() {
nearest.k = offset.unwrap_or(0).saturating_add(limit.unwrap_or(10_000)) as usize;
}
let result = new_scanner.try_into_batch().await.unwrap();
let resolved_offset = offset.unwrap_or(0).min(full_result.num_rows() as i64);
let resolved_length = limit
.unwrap_or(i64::MAX)
.min(full_result.num_rows() as i64 - resolved_offset);
let expected = full_result.slice(resolved_offset as usize, resolved_length as usize);
if expected != result {
let plan = new_scanner.analyze_plan().await.unwrap();
assert_eq!(
&expected, &result,
"Limit: {:?}, Offset: {:?}, Plan: \n{}",
limit, offset, plan
);
}
}
let mut scanner_full = scanner.clone();
if let Some(nearest) = scanner_full.nearest_mut() {
nearest.k = 500;
}
let full_results = scanner_full.try_into_batch().await.unwrap();
test_one(scanner, &full_results, Some(1), None).await;
test_one(scanner, &full_results, Some(1), Some(1)).await;
test_one(scanner, &full_results, Some(1), Some(2)).await;
test_one(scanner, &full_results, Some(1), Some(10)).await;
test_one(scanner, &full_results, Some(3), None).await;
test_one(scanner, &full_results, Some(3), Some(2)).await;
test_one(scanner, &full_results, Some(3), Some(4)).await;
test_one(scanner, &full_results, None, Some(3)).await;
test_one(scanner, &full_results, None, Some(10)).await;
}
#[tokio::test]
async fn test_scan_limit_offset() {
let test_ds = TestVectorDataset::new(LanceFileVersion::Stable, false)
.await
.unwrap();
let scanner = test_ds.dataset.scan();
limit_offset_equivalency_test(&scanner).await;
}
#[tokio::test]
async fn test_knn_limit_offset() {
let test_ds = TestVectorDataset::new(LanceFileVersion::Stable, false)
.await
.unwrap();
let query_vector = Float32Array::from(vec![0.0; 32]);
let mut scanner = test_ds.dataset.scan();
scanner
.nearest("vec", &query_vector, 5)
.unwrap()
.project(&["i"])
.unwrap();
limit_offset_equivalency_test(&scanner).await;
}
#[tokio::test]
async fn test_ivf_pq_limit_offset() {
let mut test_ds = TestVectorDataset::new(LanceFileVersion::Stable, false)
.await
.unwrap();
test_ds.make_vector_index().await.unwrap();
test_ds.append_new_data().await.unwrap();
let query_vector = Float32Array::from(vec![0.0; 32]);
let mut scanner = test_ds.dataset.scan();
scanner.nearest("vec", &query_vector, 500).unwrap();
limit_offset_equivalency_test(&scanner).await;
}
#[tokio::test]
async fn test_fts_limit_offset() {
let mut test_ds = TestVectorDataset::new(LanceFileVersion::Stable, false)
.await
.unwrap();
test_ds.make_fts_index().await.unwrap();
test_ds.append_new_data().await.unwrap();
let mut scanner = test_ds.dataset.scan();
scanner
.full_text_search(FullTextSearchQuery::new("4".into()))
.unwrap();
limit_offset_equivalency_test(&scanner).await;
}
#[tokio::test]
async fn test_fts_fast_search_excludes_unindexed_rows() {
let mut test_ds = TestVectorDataset::new(LanceFileVersion::Stable, false)
.await
.unwrap();
test_ds.make_fts_index().await.unwrap();
test_ds.append_data_with_range(10, 20).await.unwrap();
let mut scanner = test_ds.dataset.scan();
scanner
.full_text_search(FullTextSearchQuery::new_query(
MatchQuery::new("15".to_owned())
.with_column(Some("s".to_owned()))
.into(),
))
.unwrap();
let normal_rows = scanner.try_into_batch().await.unwrap().num_rows();
let mut scanner = test_ds.dataset.scan();
scanner
.full_text_search(FullTextSearchQuery::new_query(
MatchQuery::new("15".to_owned())
.with_column(Some("s".to_owned()))
.into(),
))
.unwrap()
.fast_search();
let fast_rows = scanner.try_into_batch().await.unwrap().num_rows();
assert_eq!(normal_rows, 2);
assert_eq!(fast_rows, 1);
}
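// Runs a scan built by `scan_builder` and checks both the output schema and
// the values of the _rowoffset column.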
async fn test_row_offset_read_helper(
ds: &Dataset,
scan_builder: impl FnOnce(&mut Scanner) -> &mut Scanner,
expected_cols: &[&str],
expected_row_offsets: &[u64],
) {
let mut scanner = ds.scan();
let scanner = scan_builder(&mut scanner);
let stream = scanner.try_into_stream().await.unwrap();
let schema = stream.schema();
let actual_cols = schema
.fields()
.iter()
.map(|f| f.name().as_str())
.collect::<Vec<_>>();
assert_eq!(&actual_cols, expected_cols);
let batches = stream.try_collect::<Vec<_>>().await.unwrap();
let batch = arrow_select::concat::concat_batches(&schema, &batches).unwrap();
let row_offsets = batch
.column_by_name(ROW_OFFSET)
.unwrap()
.as_primitive::<UInt64Type>()
.values();
assert_eq!(row_offsets.as_ref(), expected_row_offsets);
}
#[tokio::test]
async fn test_row_offset_read() {
let mut ds = lance_datagen::gen_batch()
.col("idx", array::step::<Int32Type>())
.into_ram_dataset(FragmentCount::from(3), FragmentRowCount::from(3))
.await
.unwrap();
ds.delete("idx >= 2 AND idx <= 6").await.unwrap();
test_row_offset_read_helper(
&ds,
|scanner| scanner.project(&["idx", ROW_OFFSET]).unwrap(),
&["idx", ROW_OFFSET],
&[0, 1, 2, 3],
)
.await;
test_row_offset_read_helper(
&ds,
|scanner| scanner.project(&[ROW_OFFSET]).unwrap(),
&[ROW_OFFSET],
&[0, 1, 2, 3],
)
.await;
test_row_offset_read_helper(
&ds,
|scanner| {
scanner
.filter("idx > 1")
.unwrap()
.project(&[ROW_OFFSET])
.unwrap()
},
&[ROW_OFFSET],
&[2, 3],
)
.await;
}
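// Filters that constrain only _rowid, _rowaddr, or _rowoffset should be
// rewritten into a direct take (surfaced as a OneShotStream in the plan)
// instead of a full scan plus filter.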
#[tokio::test]
async fn test_filter_to_take() {
let mut ds = lance_datagen::gen_batch()
.col("idx", array::step::<Int32Type>())
.into_ram_dataset(FragmentCount::from(3), FragmentRowCount::from(100))
.await
.unwrap();
let row_ids = ds
.scan()
.project(&Vec::<&str>::default())
.unwrap()
.with_row_id()
.try_into_stream()
.await
.unwrap()
.try_collect::<Vec<_>>()
.await
.unwrap();
let schema = row_ids[0].schema();
let row_ids = concat_batches(&schema, row_ids.iter()).unwrap();
let row_ids = row_ids.column(0).as_primitive::<UInt64Type>().clone();
let row_addrs = ds
.scan()
.project(&Vec::<&str>::default())
.unwrap()
.with_row_address()
.try_into_stream()
.await
.unwrap()
.try_collect::<Vec<_>>()
.await
.unwrap();
let schema = row_addrs[0].schema();
let row_addrs = concat_batches(&schema, row_addrs.iter()).unwrap();
let row_addrs = row_addrs.column(0).as_primitive::<UInt64Type>().clone();
ds.delete("idx >= 190 AND idx < 210").await.unwrap();
let ds_copy = ds.clone();
let do_check = async move |filt: &str, expected_idx: &[i32], applies_optimization: bool| {
let mut scanner = ds_copy.scan();
scanner.filter(filt).unwrap();
let plan = scanner.explain_plan(true).await.unwrap();
if applies_optimization {
assert!(
plan.contains("OneShotStream"),
"expected take optimization to be applied. Filter: '{}'. Plan:\n{}",
filt,
plan
);
} else {
assert!(
!plan.contains("OneShotStream"),
"expected take optimization to not be applied. Filter: '{}'. Plan:\n{}",
filt,
plan
);
}
let stream = scanner.try_into_stream().await.unwrap();
let batches = stream.try_collect::<Vec<_>>().await.unwrap();
let idx = batches
.iter()
.map(|b| b.column_by_name("idx").unwrap().as_ref())
.collect::<Vec<_>>();
if idx.is_empty() {
assert!(expected_idx.is_empty());
return;
}
let idx = arrow::compute::concat(&idx).unwrap();
assert_eq!(idx.as_primitive::<Int32Type>().values(), expected_idx);
};
let check =
async |filt: &str, expected_idx: &[i32]| do_check(filt, expected_idx, true).await;
let check_no_opt = async |filt: &str, expected_idx: &[i32]| {
do_check(filt, expected_idx, false).await;
};
check("_rowid = 50", &[50]).await;
check("_rowaddr = 50", &[50]).await;
check("_rowoffset = 50", &[50]).await;
check(
"_rowid = 50 OR _rowid = 51 OR _rowid = 52 OR _rowid = 49",
&[49, 50, 51, 52],
)
.await;
check(
"_rowaddr = 50 OR _rowaddr = 51 OR _rowaddr = 52 OR _rowaddr = 49",
&[49, 50, 51, 52],
)
.await;
check(
"_rowoffset = 50 OR _rowoffset = 51 OR _rowoffset = 52 OR _rowoffset = 49",
&[49, 50, 51, 52],
)
.await;
check("_rowid IN (52, 51, 50, 17)", &[17, 50, 51, 52]).await;
check("_rowaddr IN (52, 51, 50, 17)", &[17, 50, 51, 52]).await;
check("_rowoffset IN (52, 51, 50, 17)", &[17, 50, 51, 52]).await;
check(&format!("_rowid = {}", row_ids.value(190)), &[]).await;
check(&format!("_rowaddr = {}", row_addrs.value(190)), &[]).await;
check("_rowoffset = 190", &[210]).await;
check(&format!("_rowid = {}", row_ids.value(250)), &[250]).await;
check(&format!("_rowaddr = {}", row_addrs.value(250)), &[250]).await;
check("_rowoffset = 250", &[270]).await;
check("_rowoffset = 1000", &[]).await;
check("_rowid IN (5, 10, 15) AND idx > 10", &[15]).await;
check("_rowaddr IN (5, 10, 15) AND idx > 10", &[15]).await;
check("_rowoffset IN (5, 10, 15) AND idx > 10", &[15]).await;
check("idx > 10 AND _rowid IN (5, 10, 15)", &[15]).await;
check("idx > 10 AND _rowaddr IN (5, 10, 15)", &[15]).await;
check("idx > 10 AND _rowoffset IN (5, 10, 15)", &[15]).await;
check("_rowid = 50 AND _rowid = 50", &[50]).await;
check_no_opt("_rowid = 50 AND _rowid = 51", &[]).await;
check_no_opt("(_rowid = 50 AND idx < 100) OR _rowid = 51", &[50, 51]).await;
let mut scanner = ds.scan();
scanner.filter("_rowoffset = 77").unwrap();
scanner
.project_with_transform(&[("foo", "idx * 2")])
.unwrap();
let stream = scanner.try_into_stream().await.unwrap();
let batches = stream.try_collect::<Vec<_>>().await.unwrap();
assert_eq!(batches[0].schema().field(0).name(), "foo");
let val = batches[0].column(0).as_primitive::<Int32Type>().values()[0];
assert_eq!(val, 154);
}
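// order_by on a nested field ("nested.value") should sort the top-level rows.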
#[tokio::test]
async fn test_nested_field_ordering() {
use arrow_array::StructArray;
let id_array = Int32Array::from(vec![3, 1, 2]);
let nested_values = Int32Array::from(vec![30, 10, 20]);
let nested_struct = StructArray::from(vec![(
Arc::new(ArrowField::new("value", DataType::Int32, false)),
Arc::new(nested_values) as ArrayRef,
)]);
let schema = Arc::new(ArrowSchema::new(vec![
ArrowField::new("id", DataType::Int32, false),
ArrowField::new(
"nested",
DataType::Struct(vec![ArrowField::new("value", DataType::Int32, false)].into()),
false,
),
]));
let batch = RecordBatch::try_new(
schema.clone(),
vec![Arc::new(id_array), Arc::new(nested_struct)],
)
.unwrap();
let test_dir = TempStrDir::default();
let test_uri = &test_dir;
let reader = RecordBatchIterator::new(vec![Ok(batch)].into_iter(), schema.clone());
let dataset = Dataset::write(reader, test_uri, None).await.unwrap();
let mut scanner = dataset.scan();
scanner
.order_by(Some(vec![ColumnOrdering {
column_name: "nested.value".to_string(),
ascending: true,
nulls_first: true,
}]))
.unwrap();
let stream = scanner.try_into_stream().await.unwrap();
let batches = stream.try_collect::<Vec<_>>().await.unwrap();
let sorted_ids = batches[0].column(0).as_primitive::<Int32Type>().values();
assert_eq!(sorted_ids[0], 1);
assert_eq!(sorted_ids[1], 2);
assert_eq!(sorted_ids[2], 3);
}
#[tokio::test]
async fn test_limit_with_ordering_not_pushed_down() {
let id_array = Int32Array::from(vec![5, 2, 8, 1, 3, 7, 4, 6]);
let value_array = Int32Array::from(vec![50, 20, 80, 10, 30, 70, 40, 60]);
let schema = Arc::new(ArrowSchema::new(vec![
ArrowField::new("id", DataType::Int32, false),
ArrowField::new("value", DataType::Int32, false),
]));
let batch = RecordBatch::try_new(
schema.clone(),
vec![Arc::new(id_array), Arc::new(value_array)],
)
.unwrap();
let test_dir = TempStrDir::default();
let test_uri = &test_dir;
let reader = RecordBatchIterator::new(vec![Ok(batch)].into_iter(), schema.clone());
let dataset = Dataset::write(reader, test_uri, None).await.unwrap();
let mut scanner = dataset.scan();
scanner
.order_by(Some(vec![ColumnOrdering {
column_name: "value".to_string(),
ascending: true,
nulls_first: true,
}]))
.unwrap();
scanner.limit(Some(3), None).unwrap();
let stream = scanner.try_into_stream().await.unwrap();
let batches = stream.try_collect::<Vec<_>>().await.unwrap();
let sorted_ids = batches[0].column(0).as_primitive::<Int32Type>().values();
let sorted_values = batches[0].column(1).as_primitive::<Int32Type>().values();
assert_eq!(batches[0].num_rows(), 3);
assert_eq!(sorted_ids[0], 1);
assert_eq!(sorted_ids[1], 2);
assert_eq!(sorted_ids[2], 3);
assert_eq!(sorted_values[0], 10);
assert_eq!(sorted_values[1], 20);
assert_eq!(sorted_values[2], 30);
let mut scanner = dataset.scan();
scanner
.order_by(Some(vec![ColumnOrdering {
column_name: "value".to_string(),
ascending: true,
nulls_first: true,
}]))
.unwrap();
scanner.limit(Some(3), Some(2)).unwrap();
let stream = scanner.try_into_stream().await.unwrap();
let batches = stream.try_collect::<Vec<_>>().await.unwrap();
let sorted_ids = batches[0].column(0).as_primitive::<Int32Type>().values();
let sorted_values = batches[0].column(1).as_primitive::<Int32Type>().values();
assert_eq!(batches[0].num_rows(), 3);
assert_eq!(sorted_ids[0], 3);
assert_eq!(sorted_ids[1], 4);
assert_eq!(sorted_ids[2], 5);
assert_eq!(sorted_values[0], 30);
assert_eq!(sorted_values[1], 40);
assert_eq!(sorted_values[2], 50);
let mut scanner = dataset.scan();
scanner.limit(Some(3), None).unwrap();
let stream = scanner.try_into_stream().await.unwrap();
let batches = stream.try_collect::<Vec<_>>().await.unwrap();
assert_eq!(batches[0].num_rows(), 3);
let unsorted_values = batches[0].column(1).as_primitive::<Int32Type>().values();
assert_eq!(unsorted_values[0], 50);
assert_eq!(unsorted_values[1], 20);
assert_eq!(unsorted_values[2], 80);
}
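// With stable row ids enabled, the _row_created_at_version and
// _row_last_updated_at_version meta columns should be projectable and equal to
// 1 for a freshly written dataset.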
#[tokio::test]
async fn test_scan_with_version_columns() {
use arrow_array::{Int32Array, RecordBatch, RecordBatchIterator};
use arrow_schema::{DataType, Field as ArrowField, Schema as ArrowSchema};
let schema = Arc::new(ArrowSchema::new(vec![ArrowField::new(
"id",
DataType::Int32,
false,
)]));
let batch = RecordBatch::try_new(
schema.clone(),
vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
)
.unwrap();
let test_dir = lance_core::utils::tempfile::TempStrDir::default();
let test_uri = test_dir.as_str();
let reader = RecordBatchIterator::new(vec![Ok(batch)], schema);
let write_params = WriteParams {
enable_stable_row_ids: true,
..Default::default()
};
Dataset::write(reader, test_uri, Some(write_params))
.await
.unwrap();
let dataset = Dataset::open(test_uri).await.unwrap();
let mut scanner = dataset.scan();
scanner
.project(&[ROW_CREATED_AT_VERSION, ROW_LAST_UPDATED_AT_VERSION])
.unwrap();
let output_schema = scanner.schema().await.unwrap();
assert!(
output_schema
.column_with_name("_row_last_updated_at_version")
.is_some(),
"Schema should include _row_last_updated_at_version"
);
assert!(
output_schema
.column_with_name("_row_created_at_version")
.is_some(),
"Schema should include _row_created_at_version"
);
let batches = scanner
.try_into_stream()
.await
.unwrap()
.try_collect::<Vec<_>>()
.await
.unwrap();
assert_eq!(batches.len(), 1);
let batch = &batches[0];
let last_updated = batch
.column_by_name("_row_last_updated_at_version")
.expect("Should have _row_last_updated_at_version column");
let created_at = batch
.column_by_name("_row_created_at_version")
.expect("Should have _row_created_at_version column");
let last_updated_array = last_updated
.as_any()
.downcast_ref::<arrow_array::UInt64Array>()
.unwrap();
let created_at_array = created_at
.as_any()
.downcast_ref::<arrow_array::UInt64Array>()
.unwrap();
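// A single write produced the dataset, so every row was created and last updated at version 1.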
for i in 0..batch.num_rows() {
assert_eq!(
last_updated_array.value(i),
1,
"All rows last updated at version 1"
);
assert_eq!(
created_at_array.value(i),
1,
"All rows created at version 1"
);
}
}
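// Regression test: dropping a scan stream early must not leave background tasks alive.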
#[test_log::test(test)]
fn test_scan_finishes_all_tasks() {
let runtime = tokio::runtime::Builder::new_multi_thread()
.enable_time()
.build()
.unwrap();
runtime.block_on(async move {
let ds = lance_datagen::gen_batch()
.col("id", lance_datagen::array::step::<Int32Type>())
.into_ram_dataset(FragmentCount::from(1000), FragmentRowCount::from(10))
.await
.unwrap();
let mut stream = ds
.scan()
.fragment_readahead(1000)
.batch_size(1)
.io_buffer_size(1)
.batch_readahead(1)
.try_into_stream()
.await
.unwrap();
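// Consume a single batch, then drop the stream when the block ends.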
stream.next().await.unwrap().unwrap();
});
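// Poll runtime metrics for up to 10 seconds until all spawned tasks have exited.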
let start = Instant::now();
while start.elapsed() < Duration::from_secs(10) {
if runtime.handle().metrics().num_alive_tasks() == 0 {
break;
}
std::thread::sleep(Duration::from_millis(100));
}
assert!(
runtime.handle().metrics().num_alive_tasks() == 0,
"Tasks should have finished within 10 seconds but there are still {} tasks running",
runtime.handle().metrics().num_alive_tasks()
);
}
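/// Asserts that `array` is non-empty and that every value falls within `range`.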
fn assert_values_in_range(array: &Int32Array, range: std::ops::Range<i32>, msg: &str) {
assert!(!array.is_empty(), "Expected some results but got none");
assert!(
array
.iter()
.all(|v| v.is_some_and(|val| range.contains(&val))),
"{msg} (expected range {range:?})"
);
}
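/// Asserts that the results contain at least one row from each of the four
/// fragments: 0..200, 200..400, 400..410, and 410..420.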
fn assert_has_all_fragments(array: &Int32Array) {
assert!(
array
.iter()
.any(|v| v.is_some_and(|val| (0..200).contains(&val)))
&& array
.iter()
.any(|v| v.is_some_and(|val| (200..400).contains(&val)))
&& array
.iter()
.any(|v| v.is_some_and(|val| (400..410).contains(&val)))
&& array
.iter()
.any(|v| v.is_some_and(|val| (410..420).contains(&val))),
"Expected results from all fragments"
);
}
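/// Runs `build_scanner` against several `with_fragments` subsets and verifies
/// that results only come from the selected fragments.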
async fn test_fragment_list_filtering(
test_ds: &TestVectorDataset,
fragments: &[Fragment],
mut build_scanner: impl FnMut(&Dataset) -> Scanner,
) {
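// Baseline: no fragment restriction, so all four fragments should contribute.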
let batch = build_scanner(&test_ds.dataset)
.try_into_batch()
.await
.unwrap();
let i_array = batch
.column_by_name("i")
.unwrap()
.as_any()
.downcast_ref::<Int32Array>()
.unwrap();
assert_has_all_fragments(i_array);
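// Restrict to fragment 2 only (rows 400..410).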
let mut scanner = build_scanner(&test_ds.dataset);
scanner.with_fragments(vec![fragments[2].clone()]);
let batch = scanner.try_into_batch().await.unwrap();
let i_array = batch
.column_by_name("i")
.unwrap()
.as_any()
.downcast_ref::<Int32Array>()
.unwrap();
assert_values_in_range(i_array, 400..410, "Should only get results from fragment 2");
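// Restrict to fragment 0 only (rows 0..200).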
let mut scanner = build_scanner(&test_ds.dataset);
scanner.with_fragments(vec![fragments[0].clone()]);
let batch = scanner.try_into_batch().await.unwrap();
let i_array = batch
.column_by_name("i")
.unwrap()
.as_any()
.downcast_ref::<Int32Array>()
.unwrap();
assert_values_in_range(i_array, 0..200, "Should only get results from fragment 0");
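// Fragments 0, 1, and 2 together; fragment 3 must be excluded.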
let mut scanner = build_scanner(&test_ds.dataset);
scanner.with_fragments(vec![
fragments[0].clone(),
fragments[1].clone(),
fragments[2].clone(),
]);
let batch = scanner.try_into_batch().await.unwrap();
let i_array = batch
.column_by_name("i")
.unwrap()
.as_any()
.downcast_ref::<Int32Array>()
.unwrap();
assert_values_in_range(
i_array,
0..410,
"Should get results from fragments 0, 1, and 2, excluding fragment 3",
);
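// Non-contiguous subset: fragments 0 and 2 only.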
let mut scanner = build_scanner(&test_ds.dataset);
scanner.with_fragments(vec![fragments[0].clone(), fragments[2].clone()]);
let batch = scanner.try_into_batch().await.unwrap();
let i_array = batch
.column_by_name("i")
.unwrap()
.as_any()
.downcast_ref::<Int32Array>()
.unwrap();
assert!(
i_array
.iter()
.all(|v| v.is_some_and(|val| (0..200).contains(&val) || (400..410).contains(&val)))
&& i_array
.iter()
.any(|v| v.is_some_and(|val| (0..200).contains(&val)))
&& i_array
.iter()
.any(|v| v.is_some_and(|val| (400..410).contains(&val))),
"Should only get results from fragment 0 (indexed) and fragment 2 (unindexed)"
);
}
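// Vector search with a fragment list should only return rows from the chosen fragments,
// covering both indexed (fragments 0-1) and unindexed (fragments 2-3) data.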
#[tokio::test]
async fn test_vector_search_respects_fragment_list() {
let mut test_ds = TestVectorDataset::new(LanceFileVersion::Stable, false)
.await
.unwrap();
test_ds.make_vector_index().await.unwrap();
let query: Float32Array = (0..32).map(|v| v as f32).collect();
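// Append two more fragments after building the index; they remain unindexed.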
test_ds.append_data_with_range(400, 410).await.unwrap();
test_ds.append_data_with_range(410, 420).await.unwrap();
let fragments = test_ds.dataset.fragments();
assert_eq!(fragments.len(), 4);
test_fragment_list_filtering(&test_ds, fragments, |dataset| {
let mut scanner = dataset.scan();
scanner.nearest("vec", &query, 420).unwrap();
scanner
})
.await;
}
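// FTS with a fragment list should likewise only return rows from the chosen fragments.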
#[tokio::test]
async fn test_fts_respects_fragment_list() {
let mut test_ds = TestVectorDataset::new(LanceFileVersion::Stable, false)
.await
.unwrap();
test_ds.make_fts_index().await.unwrap();
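// Append two more fragments after building the FTS index; they remain unindexed.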
test_ds.append_data_with_range(400, 410).await.unwrap();
test_ds.append_data_with_range(410, 420).await.unwrap();
let fragments = test_ds.dataset.fragments();
assert_eq!(fragments.len(), 4);
test_fragment_list_filtering(&test_ds, fragments, |dataset| {
let mut scanner = dataset.scan();
scanner
.full_text_search(FullTextSearchQuery::new("s-5".into()))
.unwrap();
scanner
})
.await;
}
}