use std::{
ops::Bound,
sync::{Arc, LazyLock},
};
use arrow::array::BinaryBuilder;
use arrow_array::{Array, RecordBatch, UInt32Array};
use arrow_schema::{DataType, Field, Schema, SchemaRef};
use async_recursion::async_recursion;
use async_trait::async_trait;
use datafusion_common::ScalarValue;
use datafusion_expr::{
Between, BinaryExpr, Expr, Operator, ReturnFieldArgs, ScalarUDF,
expr::{InList, Like, ScalarFunction},
};
use tokio::try_join;
use super::{
AnyQuery, BloomFilterQuery, LabelListQuery, MetricsCollector, SargableQuery, ScalarIndex,
SearchResult, TextQuery, TokenQuery,
};
#[cfg(feature = "geo")]
use super::{GeoQuery, RelationQuery};
use lance_core::{
Error, Result,
utils::mask::{NullableRowAddrMask, RowAddrMask},
};
use lance_datafusion::{expr::safe_coerce_scalar, planner::Planner};
use roaring::RoaringBitmap;
use tracing::instrument;
/// Maximum nesting depth the filter-expression visitor is allowed to recurse to.
///
/// The visitor threads a `depth` counter through its recursive calls (`visit_not` forwards
/// `depth + 1`). NOTE(review): the comparison against this limit is performed by the visitor
/// itself — confirm every recursive path checks it.
const MAX_DEPTH: usize = 500;
/// The result of translating a filter expression against the available scalar indexes.
///
/// The expression is split into two parts:
/// * `scalar_query` — the portion that can be answered by index searches, and
/// * `refine_expr` — a residual predicate that still has to be evaluated on the rows.
///
/// Either part may be absent; a node with neither part is treated as invalid by `maybe_not`.
#[derive(Debug, PartialEq)]
pub struct IndexedExpression {
    /// Index search tree, if any part of the filter can use an index.
    pub scalar_query: Option<ScalarIndexExpr>,
    /// Remaining predicate that the index cannot answer, if any.
    pub refine_expr: Option<Expr>,
}
/// Converts individual predicates over an indexed column into [`IndexedExpression`]s.
///
/// Every `visit_*` method returns `None` when this parser cannot serve the predicate.
pub trait ScalarQueryParser: std::fmt::Debug + Send + Sync {
    /// `column BETWEEN low AND high`, expressed as a pair of bounds.
    fn visit_between(
        &self,
        column: &str,
        low: &Bound<ScalarValue>,
        high: &Bound<ScalarValue>,
    ) -> Option<IndexedExpression>;

    /// `column IN (in_list...)`.
    fn visit_in_list(&self, column: &str, in_list: &[ScalarValue]) -> Option<IndexedExpression>;

    /// `column IS TRUE` / `column IS FALSE`; also used for a bare boolean column reference
    /// (with `value == true`).
    fn visit_is_bool(&self, column: &str, value: bool) -> Option<IndexedExpression>;

    /// `column IS NULL`.
    fn visit_is_null(&self, column: &str) -> Option<IndexedExpression>;

    /// `column <op> value` for a binary comparison operator `op`.
    fn visit_comparison(
        &self,
        column: &str,
        value: &ScalarValue,
        op: &Operator,
    ) -> Option<IndexedExpression>;

    /// A scalar function call over the indexed column.
    ///
    /// `data_type` is the column's type and `args` are the full call arguments (the column
    /// reference is among them; implementations typically read the literal at `args[1]`).
    fn visit_scalar_function(
        &self,
        column: &str,
        data_type: &DataType,
        func: &ScalarUDF,
        args: &[Expr],
    ) -> Option<IndexedExpression>;

    /// `column LIKE pattern`. The default declines every pattern; parsers that can accelerate
    /// `LIKE` override it.
    fn visit_like(
        &self,
        _column: &str,
        _like: &Like,
        _pattern: &ScalarValue,
    ) -> Option<IndexedExpression> {
        None
    }

    /// Decides whether `func` is an acceptable reference to an indexed column of type
    /// `data_type`, returning the type that literals should be coerced to.
    ///
    /// The default accepts plain column references only.
    fn is_valid_reference(&self, func: &Expr, data_type: &DataType) -> Option<DataType> {
        matches!(func, Expr::Column(_)).then(|| data_type.clone())
    }
}
/// A [`ScalarQueryParser`] that fans out to several delegate parsers and uses the answer of
/// the first delegate that returns one.
#[derive(Debug)]
pub struct MultiQueryParser {
    parsers: Vec<Box<dyn ScalarQueryParser>>,
}

impl MultiQueryParser {
    /// Creates a composite parser with exactly one delegate.
    pub fn single(parser: Box<dyn ScalarQueryParser>) -> Self {
        let parsers = vec![parser];
        Self { parsers }
    }

    /// Appends another delegate; it is consulted after every previously added parser.
    pub fn add(&mut self, other: Box<dyn ScalarQueryParser>) {
        self.parsers.push(other);
    }
}
// Every method asks the delegates in insertion order and returns the first non-`None` answer.
impl ScalarQueryParser for MultiQueryParser {
    fn visit_between(
        &self,
        column: &str,
        low: &Bound<ScalarValue>,
        high: &Bound<ScalarValue>,
    ) -> Option<IndexedExpression> {
        for parser in &self.parsers {
            if let Some(indexed) = parser.visit_between(column, low, high) {
                return Some(indexed);
            }
        }
        None
    }

    fn visit_in_list(&self, column: &str, in_list: &[ScalarValue]) -> Option<IndexedExpression> {
        for parser in &self.parsers {
            if let Some(indexed) = parser.visit_in_list(column, in_list) {
                return Some(indexed);
            }
        }
        None
    }

    fn visit_is_bool(&self, column: &str, value: bool) -> Option<IndexedExpression> {
        for parser in &self.parsers {
            if let Some(indexed) = parser.visit_is_bool(column, value) {
                return Some(indexed);
            }
        }
        None
    }

    fn visit_is_null(&self, column: &str) -> Option<IndexedExpression> {
        for parser in &self.parsers {
            if let Some(indexed) = parser.visit_is_null(column) {
                return Some(indexed);
            }
        }
        None
    }

    fn visit_comparison(
        &self,
        column: &str,
        value: &ScalarValue,
        op: &Operator,
    ) -> Option<IndexedExpression> {
        for parser in &self.parsers {
            if let Some(indexed) = parser.visit_comparison(column, value, op) {
                return Some(indexed);
            }
        }
        None
    }

    fn visit_scalar_function(
        &self,
        column: &str,
        data_type: &DataType,
        func: &ScalarUDF,
        args: &[Expr],
    ) -> Option<IndexedExpression> {
        for parser in &self.parsers {
            if let Some(indexed) = parser.visit_scalar_function(column, data_type, func, args) {
                return Some(indexed);
            }
        }
        None
    }

    fn visit_like(
        &self,
        column: &str,
        like: &Like,
        pattern: &ScalarValue,
    ) -> Option<IndexedExpression> {
        for parser in &self.parsers {
            if let Some(indexed) = parser.visit_like(column, like, pattern) {
                return Some(indexed);
            }
        }
        None
    }

    fn is_valid_reference(&self, func: &Expr, data_type: &DataType) -> Option<DataType> {
        for parser in &self.parsers {
            if let Some(resolved) = parser.is_valid_reference(func, data_type) {
                return Some(resolved);
            }
        }
        None
    }
}
/// Query parser for indexes that answer ordered ("sargable") predicates: ranges, equality,
/// `IN` lists, `IS NULL`, booleans and string-prefix matches.
#[derive(Debug)]
pub struct SargableQueryParser {
    /// Name of the index every emitted search targets.
    index_name: String,
    /// Copied onto every emitted search. NOTE(review): presumably set when the index can
    /// return false positives — confirm with the index implementations.
    needs_recheck: bool,
}

impl SargableQueryParser {
    /// Builds a parser for the index `index_name`.
    pub fn new(index_name: String, needs_recheck: bool) -> Self {
        Self {
            needs_recheck,
            index_name,
        }
    }
}
impl ScalarQueryParser for SargableQueryParser {
fn is_valid_reference(&self, func: &Expr, data_type: &DataType) -> Option<DataType> {
match func {
Expr::Column(_) => Some(data_type.clone()),
Expr::ScalarFunction(udf) if udf.name() == "get_field" => Some(data_type.clone()),
_ => None,
}
}
fn visit_between(
&self,
column: &str,
low: &Bound<ScalarValue>,
high: &Bound<ScalarValue>,
) -> Option<IndexedExpression> {
if let Bound::Included(val) | Bound::Excluded(val) = low
&& val.is_null()
{
return None;
}
if let Bound::Included(val) | Bound::Excluded(val) = high
&& val.is_null()
{
return None;
}
let query = SargableQuery::Range(low.clone(), high.clone());
Some(IndexedExpression::index_query_with_recheck(
column.to_string(),
self.index_name.clone(),
Arc::new(query),
self.needs_recheck,
))
}
fn visit_in_list(&self, column: &str, in_list: &[ScalarValue]) -> Option<IndexedExpression> {
if in_list.iter().any(|val| val.is_null()) {
return None;
}
let query = SargableQuery::IsIn(in_list.to_vec());
Some(IndexedExpression::index_query_with_recheck(
column.to_string(),
self.index_name.clone(),
Arc::new(query),
self.needs_recheck,
))
}
fn visit_is_bool(&self, column: &str, value: bool) -> Option<IndexedExpression> {
Some(IndexedExpression::index_query_with_recheck(
column.to_string(),
self.index_name.clone(),
Arc::new(SargableQuery::Equals(ScalarValue::Boolean(Some(value)))),
self.needs_recheck,
))
}
fn visit_is_null(&self, column: &str) -> Option<IndexedExpression> {
Some(IndexedExpression::index_query_with_recheck(
column.to_string(),
self.index_name.clone(),
Arc::new(SargableQuery::IsNull()),
self.needs_recheck,
))
}
fn visit_comparison(
&self,
column: &str,
value: &ScalarValue,
op: &Operator,
) -> Option<IndexedExpression> {
if value.is_null() {
return None;
}
let query = match op {
Operator::Lt => SargableQuery::Range(Bound::Unbounded, Bound::Excluded(value.clone())),
Operator::LtEq => {
SargableQuery::Range(Bound::Unbounded, Bound::Included(value.clone()))
}
Operator::Gt => SargableQuery::Range(Bound::Excluded(value.clone()), Bound::Unbounded),
Operator::GtEq => {
SargableQuery::Range(Bound::Included(value.clone()), Bound::Unbounded)
}
Operator::Eq => SargableQuery::Equals(value.clone()),
Operator::NotEq => SargableQuery::Equals(value.clone()),
_ => unreachable!(),
};
Some(IndexedExpression::index_query_with_recheck(
column.to_string(),
self.index_name.clone(),
Arc::new(query),
self.needs_recheck,
))
}
fn visit_scalar_function(
&self,
column: &str,
_data_type: &DataType,
func: &ScalarUDF,
args: &[Expr],
) -> Option<IndexedExpression> {
if func.name() == "starts_with" && args.len() == 2 {
let prefix = match &args[1] {
Expr::Literal(ScalarValue::Utf8(Some(s)), _) => ScalarValue::Utf8(Some(s.clone())),
Expr::Literal(ScalarValue::LargeUtf8(Some(s)), _) => {
ScalarValue::LargeUtf8(Some(s.clone()))
}
_ => return None,
};
let query = SargableQuery::LikePrefix(prefix);
return Some(IndexedExpression::index_query_with_recheck(
column.to_string(),
self.index_name.clone(),
Arc::new(query),
self.needs_recheck,
));
}
None
}
fn visit_like(
&self,
column: &str,
like: &Like,
pattern: &ScalarValue,
) -> Option<IndexedExpression> {
if like.case_insensitive {
return None;
}
let pattern_str = match pattern {
ScalarValue::Utf8(Some(s)) => s.as_str(),
ScalarValue::LargeUtf8(Some(s)) => s.as_str(),
_ => return None,
};
let (prefix, needs_refine) = extract_like_leading_prefix(pattern_str, like.escape_char)?;
let prefix_value = match pattern {
ScalarValue::Utf8(_) => ScalarValue::Utf8(Some(prefix)),
ScalarValue::LargeUtf8(_) => ScalarValue::LargeUtf8(Some(prefix)),
_ => return None,
};
let query = SargableQuery::LikePrefix(prefix_value);
let scalar_query = Some(ScalarIndexExpr::Query(ScalarIndexSearch {
column: column.to_string(),
index_name: self.index_name.clone(),
query: Arc::new(query),
needs_recheck: self.needs_recheck,
}));
let refine_expr = if needs_refine {
Some(Expr::Like(like.clone()))
} else {
None
};
Some(IndexedExpression {
scalar_query,
refine_expr,
})
}
}
/// Extracts the literal prefix that precedes the first unescaped wildcard of a SQL `LIKE`
/// pattern.
///
/// `escape_char` defaults to a backslash. An escape char followed by `%`, `_` or another
/// escape char yields that character literally.
///
/// Returns `Some((prefix, needs_refine))`, where `needs_refine` is `false` only when the
/// pattern is exactly `<prefix>%`. In that case the prefix search alone is exact.
///
/// Returns `None` when the pattern cannot be turned into a prefix search:
/// * the pattern is empty, or contains no unescaped wildcard (it is an equality match);
/// * the pattern starts with a wildcard, so the prefix is empty;
/// * an escape char precedes any other character, or ends the pattern before a wildcard.
///   Such sequences are declined rather than guessed at, so the index is never asked for a
///   wrong prefix.
fn extract_like_leading_prefix(pattern: &str, escape_char: Option<char>) -> Option<(String, bool)> {
    let escape = escape_char.unwrap_or('\\');
    let mut chars = pattern.chars();
    let mut prefix = String::new();

    // Single left-to-right scan, so escape runs of any length are handled correctly
    // (e.g. `\\\%` is an escaped backslash followed by an escaped `%`).
    let first_wildcard = loop {
        // Running out of input here means there is no unescaped wildcard at all.
        let c = chars.next()?;
        if c == escape {
            match chars.next() {
                Some(next) if next == '%' || next == '_' || next == escape => {
                    prefix.push(next);
                    continue;
                }
                // Unknown escape (`\x`) or a dangling escape char: decline.
                _ => return None,
            }
        }
        if c == '%' || c == '_' {
            break c;
        }
        prefix.push(c);
    };

    if prefix.is_empty() {
        return None;
    }

    // Only `<prefix>%` is fully answered by the prefix search; anything after the wildcard,
    // or a leading `_`, requires re-evaluating the original LIKE.
    let needs_refine = !(first_wildcard == '%' && chars.as_str().is_empty());
    Some((prefix, needs_refine))
}
/// Query parser for bloom-filter indexes: equality, `IN` lists, booleans and `IS NULL`.
#[derive(Debug)]
pub struct BloomFilterQueryParser {
    /// Name of the index every emitted search targets.
    index_name: String,
    /// Copied onto every emitted search.
    needs_recheck: bool,
}

impl BloomFilterQueryParser {
    /// Builds a parser for the index `index_name`.
    pub fn new(index_name: String, needs_recheck: bool) -> Self {
        Self {
            needs_recheck,
            index_name,
        }
    }
}
impl ScalarQueryParser for BloomFilterQueryParser {
fn visit_between(
&self,
_: &str,
_: &Bound<ScalarValue>,
_: &Bound<ScalarValue>,
) -> Option<IndexedExpression> {
None
}
fn visit_in_list(&self, column: &str, in_list: &[ScalarValue]) -> Option<IndexedExpression> {
let query = BloomFilterQuery::IsIn(in_list.to_vec());
Some(IndexedExpression::index_query_with_recheck(
column.to_string(),
self.index_name.clone(),
Arc::new(query),
self.needs_recheck,
))
}
fn visit_is_bool(&self, column: &str, value: bool) -> Option<IndexedExpression> {
Some(IndexedExpression::index_query_with_recheck(
column.to_string(),
self.index_name.clone(),
Arc::new(BloomFilterQuery::Equals(ScalarValue::Boolean(Some(value)))),
self.needs_recheck,
))
}
fn visit_is_null(&self, column: &str) -> Option<IndexedExpression> {
Some(IndexedExpression::index_query_with_recheck(
column.to_string(),
self.index_name.clone(),
Arc::new(BloomFilterQuery::IsNull()),
self.needs_recheck,
))
}
fn visit_comparison(
&self,
column: &str,
value: &ScalarValue,
op: &Operator,
) -> Option<IndexedExpression> {
let query = match op {
Operator::Eq => BloomFilterQuery::Equals(value.clone()),
Operator::NotEq => BloomFilterQuery::Equals(value.clone()),
_ => return None,
};
Some(IndexedExpression::index_query_with_recheck(
column.to_string(),
self.index_name.clone(),
Arc::new(query),
self.needs_recheck,
))
}
fn visit_scalar_function(
&self,
_: &str,
_: &DataType,
_: &ScalarUDF,
_: &[Expr],
) -> Option<IndexedExpression> {
None
}
}
/// Query parser for label-list indexes over list columns (`array_has*` functions).
#[derive(Debug)]
pub struct LabelListQueryParser {
    /// Name of the index every emitted search targets.
    index_name: String,
}

impl LabelListQueryParser {
    /// Builds a parser for the index `index_name`.
    pub fn new(index_name: String) -> Self {
        let parser = Self { index_name };
        parser
    }
}
impl ScalarQueryParser for LabelListQueryParser {
    fn visit_between(
        &self,
        _: &str,
        _: &Bound<ScalarValue>,
        _: &Bound<ScalarValue>,
    ) -> Option<IndexedExpression> {
        None
    }

    fn visit_in_list(&self, _: &str, _: &[ScalarValue]) -> Option<IndexedExpression> {
        None
    }

    fn visit_is_bool(&self, _: &str, _: bool) -> Option<IndexedExpression> {
        None
    }

    fn visit_is_null(&self, _: &str) -> Option<IndexedExpression> {
        None
    }

    fn visit_comparison(
        &self,
        _: &str,
        _: &ScalarValue,
        _: &Operator,
    ) -> Option<IndexedExpression> {
        None
    }

    /// Handles `array_has(col, label)`, `array_has_all(col, [labels])` and
    /// `array_has_any(col, [labels])`.
    ///
    /// The label-list literal of `array_has_all` / `array_has_any` may be a `List` or a
    /// `LargeList`, matching the column types accepted by the `array_has` branch. Empty label
    /// lists, NULL single labels and unrecognised functions are declined.
    fn visit_scalar_function(
        &self,
        column: &str,
        data_type: &DataType,
        func: &ScalarUDF,
        args: &[Expr],
    ) -> Option<IndexedExpression> {
        if args.len() != 2 {
            return None;
        }
        let query = match func.name() {
            "array_has" => {
                let inner_type = match data_type {
                    DataType::List(field) | DataType::LargeList(field) => field.data_type(),
                    _ => return None,
                };
                let label = maybe_scalar(&args[1], inner_type)?;
                if label.is_null() {
                    return None;
                }
                LabelListQuery::HasAnyLabel(vec![label])
            }
            name @ ("array_has_all" | "array_has_any") => {
                let label_list = maybe_scalar(&args[1], data_type)?;
                let list_values = match &label_list {
                    ScalarValue::List(list_arr) => list_arr.values(),
                    ScalarValue::LargeList(list_arr) => list_arr.values(),
                    _ => return None,
                };
                if list_values.is_empty() {
                    return None;
                }
                let labels = (0..list_values.len())
                    .map(|idx| ScalarValue::try_from_array(list_values.as_ref(), idx).ok())
                    .collect::<Option<Vec<_>>>()?;
                if name == "array_has_all" {
                    LabelListQuery::HasAllLabels(labels)
                } else {
                    LabelListQuery::HasAnyLabel(labels)
                }
            }
            _ => return None,
        };
        Some(IndexedExpression::index_query(
            column.to_string(),
            self.index_name.clone(),
            Arc::new(query),
        ))
    }
}
/// Query parser for text indexes that answer substring `contains(col, 'needle')` predicates.
#[derive(Debug, Clone)]
pub struct TextQueryParser {
    /// Name of the index every emitted search targets.
    index_name: String,
    /// Copied onto every emitted search.
    needs_recheck: bool,
}

impl TextQueryParser {
    /// Builds a parser for the index `index_name`.
    pub fn new(index_name: String, needs_recheck: bool) -> Self {
        Self {
            needs_recheck,
            index_name,
        }
    }
}
impl ScalarQueryParser for TextQueryParser {
    fn visit_between(
        &self,
        _: &str,
        _: &Bound<ScalarValue>,
        _: &Bound<ScalarValue>,
    ) -> Option<IndexedExpression> {
        None
    }

    fn visit_in_list(&self, _: &str, _: &[ScalarValue]) -> Option<IndexedExpression> {
        None
    }

    fn visit_is_bool(&self, _: &str, _: bool) -> Option<IndexedExpression> {
        None
    }

    fn visit_is_null(&self, _: &str) -> Option<IndexedExpression> {
        None
    }

    fn visit_comparison(
        &self,
        _: &str,
        _: &ScalarValue,
        _: &Operator,
    ) -> Option<IndexedExpression> {
        None
    }

    /// Handles `contains(col, 'needle')` where the needle is a `Utf8` or `LargeUtf8` literal.
    fn visit_scalar_function(
        &self,
        column: &str,
        data_type: &DataType,
        func: &ScalarUDF,
        args: &[Expr],
    ) -> Option<IndexedExpression> {
        if args.len() != 2 || func.name() != "contains" {
            return None;
        }
        match maybe_scalar(&args[1], data_type)? {
            ScalarValue::Utf8(Some(needle)) | ScalarValue::LargeUtf8(Some(needle)) => {
                Some(IndexedExpression::index_query_with_recheck(
                    column.to_string(),
                    self.index_name.clone(),
                    Arc::new(TextQuery::StringContains(needle)),
                    self.needs_recheck,
                ))
            }
            _ => None,
        }
    }
}
/// Query parser for full-text (token) indexes, handling `contains_tokens(col, 'text')`.
#[derive(Debug, Clone)]
pub struct FtsQueryParser {
    /// Name of the index every emitted search targets.
    index_name: String,
}

impl FtsQueryParser {
    /// Builds a parser for the index called `name`.
    pub fn new(name: String) -> Self {
        let index_name = name;
        Self { index_name }
    }
}
impl ScalarQueryParser for FtsQueryParser {
    fn visit_between(
        &self,
        _: &str,
        _: &Bound<ScalarValue>,
        _: &Bound<ScalarValue>,
    ) -> Option<IndexedExpression> {
        None
    }

    fn visit_in_list(&self, _: &str, _: &[ScalarValue]) -> Option<IndexedExpression> {
        None
    }

    fn visit_is_bool(&self, _: &str, _: bool) -> Option<IndexedExpression> {
        None
    }

    fn visit_is_null(&self, _: &str) -> Option<IndexedExpression> {
        None
    }

    fn visit_comparison(
        &self,
        _: &str,
        _: &ScalarValue,
        _: &Operator,
    ) -> Option<IndexedExpression> {
        None
    }

    /// Handles `contains_tokens(col, 'text')`.
    ///
    /// The literal is coerced to the column type, so a `LargeUtf8` column yields a
    /// `LargeUtf8` scalar. Both string flavours are accepted, consistent with
    /// `TextQueryParser`.
    fn visit_scalar_function(
        &self,
        column: &str,
        data_type: &DataType,
        func: &ScalarUDF,
        args: &[Expr],
    ) -> Option<IndexedExpression> {
        if args.len() != 2 || func.name() != "contains_tokens" {
            return None;
        }
        match maybe_scalar(&args[1], data_type)? {
            ScalarValue::Utf8(Some(tokens)) | ScalarValue::LargeUtf8(Some(tokens)) => {
                Some(IndexedExpression::index_query(
                    column.to_string(),
                    self.index_name.clone(),
                    Arc::new(TokenQuery::TokensContains(tokens)),
                ))
            }
            _ => None,
        }
    }
}
/// Query parser for geospatial indexes (only compiled with the `geo` feature).
#[cfg(feature = "geo")]
#[derive(Debug, Clone)]
pub struct GeoQueryParser {
    /// Name of the index every emitted search targets.
    index_name: String,
}

#[cfg(feature = "geo")]
impl GeoQueryParser {
    /// Builds a parser for the index `index_name`.
    pub fn new(index_name: String) -> Self {
        let parser = Self { index_name };
        parser
    }
}
#[cfg(feature = "geo")]
impl ScalarQueryParser for GeoQueryParser {
    fn visit_between(
        &self,
        _: &str,
        _: &Bound<ScalarValue>,
        _: &Bound<ScalarValue>,
    ) -> Option<IndexedExpression> {
        None
    }

    fn visit_in_list(&self, _: &str, _: &[ScalarValue]) -> Option<IndexedExpression> {
        None
    }

    fn visit_is_bool(&self, _: &str, _: bool) -> Option<IndexedExpression> {
        None
    }

    /// `col IS NULL` is answered by the index; results are always marked for recheck.
    fn visit_is_null(&self, column: &str) -> Option<IndexedExpression> {
        let query = GeoQuery::IsNull;
        Some(IndexedExpression::index_query_with_recheck(
            column.to_string(),
            self.index_name.clone(),
            Arc::new(query),
            true,
        ))
    }

    fn visit_comparison(
        &self,
        _: &str,
        _: &ScalarValue,
        _: &Operator,
    ) -> Option<IndexedExpression> {
        None
    }

    /// Handles binary spatial-relation functions whose arguments are one literal and one
    /// column, in either order.
    ///
    /// Every supported relation is emitted as an `IntersectQuery` with recheck enabled.
    /// Literal field metadata is carried on the query's `_geo` field.
    fn visit_scalar_function(
        &self,
        column: &str,
        _data_type: &DataType,
        func: &ScalarUDF,
        args: &[Expr],
    ) -> Option<IndexedExpression> {
        const SPATIAL_RELATIONS: [&str; 8] = [
            "st_intersects",
            "st_contains",
            "st_within",
            "st_touches",
            "st_crosses",
            "st_overlaps",
            "st_covers",
            "st_coveredby",
        ];
        if !SPATIAL_RELATIONS.contains(&func.name()) || args.len() != 2 {
            return None;
        }
        let (value, metadata) = match (&args[0], &args[1]) {
            (Expr::Literal(value, metadata), Expr::Column(_))
            | (Expr::Column(_), Expr::Literal(value, metadata)) => (value, metadata),
            _ => return None,
        };
        let mut field = Field::new("_geo", value.data_type(), false);
        if let Some(metadata) = metadata {
            field = field.with_metadata(metadata.to_hashmap());
        }
        let query = GeoQuery::IntersectQuery(RelationQuery {
            value: value.clone(),
            field,
        });
        Some(IndexedExpression::index_query_with_recheck(
            column.to_string(),
            self.index_name.clone(),
            Arc::new(query),
            true,
        ))
    }
}
impl IndexedExpression {
    /// A node that must be evaluated entirely as a row-level predicate.
    fn refine_only(refine_expr: Expr) -> Self {
        Self {
            refine_expr: Some(refine_expr),
            scalar_query: None,
        }
    }

    /// A node answered by a single index search whose results need no recheck.
    fn index_query(column: String, index_name: String, query: Arc<dyn AnyQuery>) -> Self {
        Self::index_query_with_recheck(column, index_name, query, false)
    }

    /// A node answered by a single index search, with an explicit recheck flag.
    fn index_query_with_recheck(
        column: String,
        index_name: String,
        query: Arc<dyn AnyQuery>,
        needs_recheck: bool,
    ) -> Self {
        let search = ScalarIndexSearch {
            column,
            index_name,
            query,
            needs_recheck,
        };
        Self {
            scalar_query: Some(ScalarIndexExpr::Query(search)),
            refine_expr: None,
        }
    }

    /// Negates this node when the negation can be represented.
    ///
    /// Returns `None` for nodes mixing an index query with a refine expression, and for
    /// index queries that need a recheck.
    ///
    /// # Panics
    /// Panics if both parts are empty.
    fn maybe_not(self) -> Option<Self> {
        match (self.scalar_query, self.refine_expr) {
            (Some(query), None) if !query.needs_recheck() => Some(Self {
                scalar_query: Some(ScalarIndexExpr::Not(Box::new(query))),
                refine_expr: None,
            }),
            // Mixed nodes, or rechecked index results, cannot be negated exactly.
            (Some(_), _) => None,
            (None, Some(refine)) => Some(Self::refine_only(Expr::Not(Box::new(refine)))),
            (None, None) => panic!("Empty node should not occur"),
        }
    }

    /// Conjunction: index parts are AND-ed together and refine parts are AND-ed together.
    fn and(self, other: Self) -> Self {
        let scalar_query = match (self.scalar_query, other.scalar_query) {
            (Some(lhs), Some(rhs)) => Some(ScalarIndexExpr::And(Box::new(lhs), Box::new(rhs))),
            (lhs, rhs) => lhs.or(rhs),
        };
        let refine_expr = match (self.refine_expr, other.refine_expr) {
            (Some(lhs), Some(rhs)) => Some(lhs.and(rhs)),
            (lhs, rhs) => lhs.or(rhs),
        };
        Self {
            scalar_query,
            refine_expr,
        }
    }

    /// Disjunction. Only two pure index nodes (no refine part on either side) can be
    /// OR-ed; anything else returns `None`.
    fn maybe_or(self, other: Self) -> Option<Self> {
        if self.refine_expr.is_some() || other.refine_expr.is_some() {
            return None;
        }
        let lhs = self.scalar_query?;
        let rhs = other.scalar_query?;
        Some(Self {
            scalar_query: Some(ScalarIndexExpr::Or(Box::new(lhs), Box::new(rhs))),
            refine_expr: None,
        })
    }

    /// AND-s `expr` into the refine part, keeping the index part unchanged.
    fn refine(self, expr: Expr) -> Self {
        let refine_expr = match self.refine_expr {
            Some(existing) => existing.and(expr),
            None => expr,
        };
        Self {
            scalar_query: self.scalar_query,
            refine_expr: Some(refine_expr),
        }
    }
}
/// Provides scalar indexes on demand while a [`ScalarIndexExpr`] is being evaluated.
#[async_trait]
pub trait ScalarIndexLoader: Send + Sync {
    /// Loads the index named `index_name` built over `column`.
    ///
    /// `metrics` is passed through for instrumentation of the load.
    async fn load_index(
        &self,
        column: &str,
        index_name: &str,
        metrics: &dyn MetricsCollector,
    ) -> Result<Arc<dyn ScalarIndex>>;
}
/// A single search against one scalar index.
#[derive(Debug, Clone)]
pub struct ScalarIndexSearch {
    /// Column (or nested field path) the index covers.
    pub column: String,
    /// Name of the index to load.
    pub index_name: String,
    /// The query to run against the index.
    pub query: Arc<dyn AnyQuery>,
    /// Whether results of this search must be rechecked against the data.
    pub needs_recheck: bool,
}

impl PartialEq for ScalarIndexSearch {
    /// Compares column, index name and query; `needs_recheck` is not part of the comparison.
    fn eq(&self, other: &Self) -> bool {
        let same_target = self.column == other.column && self.index_name == other.index_name;
        same_target && self.query.as_ref().eq(other.query.as_ref())
    }
}
/// Boolean tree of index searches.
#[derive(Debug, Clone)]
pub enum ScalarIndexExpr {
    Not(Box<Self>),
    And(Box<Self>, Box<Self>),
    Or(Box<Self>, Box<Self>),
    Query(ScalarIndexSearch),
}

impl PartialEq for ScalarIndexExpr {
    /// Structural equality: same shape with pairwise-equal children and searches.
    fn eq(&self, other: &Self) -> bool {
        match (self, other) {
            (Self::Query(lhs), Self::Query(rhs)) => lhs == rhs,
            (Self::Not(lhs), Self::Not(rhs)) => lhs == rhs,
            (Self::And(a0, a1), Self::And(b0, b1)) | (Self::Or(a0, a1), Self::Or(b0, b1)) => {
                a0 == b0 && a1 == b1
            }
            _ => false,
        }
    }
}

impl std::fmt::Display for ScalarIndexExpr {
    /// Renders the tree as `NOT(..)`, `AND(..,..)`, `OR(..,..)` and `[query]@index` leaves.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Query(search) => {
                let rendered = search.query.format(&search.column);
                write!(f, "[{}]@{}", rendered, search.index_name)
            }
            Self::Not(inner) => write!(f, "NOT({})", inner),
            Self::And(lhs, rhs) => write!(f, "AND({},{})", lhs, rhs),
            Self::Or(lhs, rhs) => write!(f, "OR({},{})", lhs, rhs),
        }
    }
}
/// Schema of the record batch produced by [`IndexExprResult::serialize_to_arrow`]:
/// the serialized mask, the result discriminant, and the serialized covered-fragment bitmap.
pub static INDEX_EXPR_RESULT_SCHEMA: LazyLock<SchemaRef> = LazyLock::new(|| {
    let fields = vec![
        Field::new("result", DataType::Binary, true),
        Field::new("discriminant", DataType::UInt32, true),
        Field::new("fragments_covered", DataType::Binary, true),
    ];
    Arc::new(Schema::new(fields))
});
/// Intermediate evaluation result of a [`ScalarIndexExpr`], before nulls are dropped.
///
/// The variant records the mask's precision, as implied by the combination rules below:
/// * `Exact` — the mask is precisely the matching set;
/// * `AtMost` — matches are a subset of the mask;
/// * `AtLeast` — matches are a superset of the mask.
#[derive(Debug)]
enum NullableIndexExprResult {
    Exact(NullableRowAddrMask),
    AtMost(NullableRowAddrMask),
    AtLeast(NullableRowAddrMask),
}

impl From<SearchResult> for NullableIndexExprResult {
    /// Wraps a raw index search result as an allow-list with the same precision.
    fn from(result: SearchResult) -> Self {
        match result {
            SearchResult::Exact(rows) => Self::Exact(NullableRowAddrMask::AllowList(rows)),
            SearchResult::AtMost(rows) => Self::AtMost(NullableRowAddrMask::AllowList(rows)),
            SearchResult::AtLeast(rows) => Self::AtLeast(NullableRowAddrMask::AllowList(rows)),
        }
    }
}

impl std::ops::BitAnd<Self> for NullableIndexExprResult {
    type Output = Self;

    /// Intersection. The left operand is always kept on the left of `&`.
    fn bitand(self, rhs: Self) -> Self {
        match (self, rhs) {
            (Self::Exact(a), Self::Exact(b)) => Self::Exact(a & b),
            // Intersecting with an upper bound yields an upper bound.
            (Self::Exact(a), Self::AtMost(b))
            | (Self::AtMost(a), Self::Exact(b))
            | (Self::AtMost(a), Self::AtMost(b)) => Self::AtMost(a & b),
            (Self::AtLeast(a), Self::AtLeast(b)) => Self::AtLeast(a & b),
            // An `AtLeast` side only gives a lower bound, so the other side's mask becomes
            // the upper bound of the intersection.
            (Self::Exact(upper) | Self::AtMost(upper), Self::AtLeast(_))
            | (Self::AtLeast(_), Self::Exact(upper) | Self::AtMost(upper)) => Self::AtMost(upper),
        }
    }
}

impl std::ops::BitOr<Self> for NullableIndexExprResult {
    type Output = Self;

    /// Union. Mixed cases put the exact mask on the left of `|`.
    fn bitor(self, rhs: Self) -> Self {
        match (self, rhs) {
            (Self::Exact(a), Self::Exact(b)) => Self::Exact(a | b),
            (Self::AtMost(a), Self::AtMost(b)) => Self::AtMost(a | b),
            (Self::AtLeast(a), Self::AtLeast(b)) => Self::AtLeast(a | b),
            (Self::Exact(exact), Self::AtMost(other))
            | (Self::AtMost(other), Self::Exact(exact)) => Self::AtMost(exact | other),
            (Self::Exact(exact), Self::AtLeast(other))
            | (Self::AtLeast(other), Self::Exact(exact)) => Self::AtLeast(exact | other),
            // An upper bound adds nothing certain, so only the lower bound survives.
            (Self::AtMost(_), Self::AtLeast(lower)) | (Self::AtLeast(lower), Self::AtMost(_)) => {
                Self::AtLeast(lower)
            }
        }
    }
}

impl NullableIndexExprResult {
    /// Converts to an [`IndexExprResult`] by dropping nulls from the mask; precision is kept.
    pub fn drop_nulls(self) -> IndexExprResult {
        match self {
            Self::Exact(nullable) => IndexExprResult::Exact(nullable.drop_nulls()),
            Self::AtMost(nullable) => IndexExprResult::AtMost(nullable.drop_nulls()),
            Self::AtLeast(nullable) => IndexExprResult::AtLeast(nullable.drop_nulls()),
        }
    }
}
/// Null-free result of [`ScalarIndexExpr::evaluate`].
///
/// `Exact` masks are the precise match set. `AtMost` masks may contain extra rows. `AtLeast`
/// masks may miss matching rows.
#[derive(Debug)]
pub enum IndexExprResult {
    Exact(RowAddrMask),
    AtMost(RowAddrMask),
    AtLeast(RowAddrMask),
}

impl IndexExprResult {
    /// Borrows the row-address mask, whatever its precision.
    pub fn row_addr_mask(&self) -> &RowAddrMask {
        match self {
            Self::Exact(mask) | Self::AtMost(mask) | Self::AtLeast(mask) => mask,
        }
    }

    /// Numeric tag for the variant (`0` exact, `1` at-most, `2` at-least); inverse of
    /// [`Self::from_parts`].
    pub fn discriminant(&self) -> u32 {
        match self {
            Self::Exact(_) => 0,
            Self::AtMost(_) => 1,
            Self::AtLeast(_) => 2,
        }
    }

    /// Rebuilds a result from a mask and a tag produced by [`Self::discriminant`].
    ///
    /// # Errors
    /// Returns an invalid-input error for an unknown discriminant.
    pub fn from_parts(mask: RowAddrMask, discriminant: u32) -> Result<Self> {
        match discriminant {
            0 => Ok(Self::Exact(mask)),
            1 => Ok(Self::AtMost(mask)),
            2 => Ok(Self::AtLeast(mask)),
            _ => Err(Error::invalid_input_source(
                format!("Invalid IndexExprResult discriminant: {}", discriminant).into(),
            )),
        }
    }

    /// Serializes this result into a two-row batch matching [`INDEX_EXPR_RESULT_SCHEMA`].
    ///
    /// The columns hold:
    /// * the mask produced by `RowAddrMask::into_arrow`;
    /// * the discriminant, repeated on both rows;
    /// * `fragments_covered_by_result`, serialized in the first row (the second row is null).
    ///
    /// # Errors
    /// Propagates failures from mask serialization, bitmap serialization and batch
    /// construction.
    #[instrument(skip_all)]
    pub fn serialize_to_arrow(
        &self,
        fragments_covered_by_result: &RoaringBitmap,
    ) -> Result<RecordBatch> {
        let row_addr_mask_arr = self.row_addr_mask().into_arrow()?;

        let discriminant = self.discriminant();
        let discriminant_arr: Arc<dyn Array> =
            Arc::new(UInt32Array::from(vec![discriminant, discriminant]));

        let mut fragments_covered_bytes =
            Vec::with_capacity(fragments_covered_by_result.serialized_size());
        fragments_covered_by_result.serialize_into(&mut fragments_covered_bytes)?;
        let mut fragments_covered_builder = BinaryBuilder::new();
        fragments_covered_builder.append_value(fragments_covered_bytes);
        fragments_covered_builder.append_null();
        let fragments_covered_arr: Arc<dyn Array> = Arc::new(fragments_covered_builder.finish());

        // The two arrays above are already `Arc<dyn Array>`; pass them as-is instead of
        // wrapping them in another `Arc`. That extra layer (`Arc<Arc<dyn Array>>`) only
        // added an indirection and a second dynamic dispatch on every access.
        let columns: Vec<Arc<dyn Array>> = vec![
            Arc::new(row_addr_mask_arr) as Arc<dyn Array>,
            discriminant_arr,
            fragments_covered_arr,
        ];
        Ok(RecordBatch::try_new(
            INDEX_EXPR_RESULT_SCHEMA.clone(),
            columns,
        )?)
    }
}
impl ScalarIndexExpr {
    /// Recursively evaluates the tree, keeping null information in the masks.
    ///
    /// Both children of `And` / `Or` are evaluated concurrently with `try_join!`. `Not`
    /// inverts the mask and swaps `AtMost` / `AtLeast`.
    #[async_recursion]
    async fn evaluate_impl(
        &self,
        index_loader: &dyn ScalarIndexLoader,
        metrics: &dyn MetricsCollector,
    ) -> Result<NullableIndexExprResult> {
        let outcome: NullableIndexExprResult = match self {
            Self::Not(inner) => match inner.evaluate_impl(index_loader, metrics).await? {
                NullableIndexExprResult::Exact(mask) => NullableIndexExprResult::Exact(!mask),
                NullableIndexExprResult::AtMost(mask) => NullableIndexExprResult::AtLeast(!mask),
                NullableIndexExprResult::AtLeast(mask) => NullableIndexExprResult::AtMost(!mask),
            },
            Self::And(lhs, rhs) => {
                let (left, right) = try_join!(
                    lhs.evaluate_impl(index_loader, metrics),
                    rhs.evaluate_impl(index_loader, metrics)
                )?;
                left & right
            }
            Self::Or(lhs, rhs) => {
                let (left, right) = try_join!(
                    lhs.evaluate_impl(index_loader, metrics),
                    rhs.evaluate_impl(index_loader, metrics)
                )?;
                left | right
            }
            Self::Query(search) => {
                let index = index_loader
                    .load_index(&search.column, &search.index_name, metrics)
                    .await?;
                index.search(search.query.as_ref(), metrics).await?.into()
            }
        };
        Ok(outcome)
    }

    /// Evaluates the tree against the indexes supplied by `index_loader` and drops nulls
    /// from the final mask.
    #[instrument(level = "debug", skip_all)]
    pub async fn evaluate(
        &self,
        index_loader: &dyn ScalarIndexLoader,
        metrics: &dyn MetricsCollector,
    ) -> Result<IndexExprResult> {
        let nullable = self.evaluate_impl(index_loader, metrics).await?;
        Ok(nullable.drop_nulls())
    }

    /// Converts the tree back into an equivalent DataFusion expression.
    pub fn to_expr(&self) -> Expr {
        match self {
            Self::Query(search) => search.query.to_expr(search.column.clone()),
            Self::Not(inner) => Expr::Not(Box::new(inner.to_expr())),
            Self::And(lhs, rhs) => lhs.to_expr().and(rhs.to_expr()),
            Self::Or(lhs, rhs) => lhs.to_expr().or(rhs.to_expr()),
        }
    }

    /// True if any search in the tree requires its results to be rechecked.
    pub fn needs_recheck(&self) -> bool {
        match self {
            Self::Query(search) => search.needs_recheck,
            Self::Not(inner) => inner.needs_recheck(),
            Self::And(lhs, rhs) | Self::Or(lhs, rhs) => lhs.needs_recheck() || rhs.needs_recheck(),
        }
    }
}
/// Returns the column name if `expr` is a plain column reference.
fn maybe_column(expr: &Expr) -> Option<&str> {
    if let Expr::Column(col) = expr {
        Some(col.name.as_str())
    } else {
        None
    }
}
/// Turns `get_field(get_field(col, 'a'), 'b')` style accessors, or a plain column, into a
/// formatted field path (via `format_field_path`).
///
/// Returns `None` if any accessor is malformed: it must take exactly two arguments, and the
/// second must be a `Utf8` literal.
fn extract_nested_column_path(expr: &Expr) -> Option<String> {
    // Field names are discovered outermost-first and reversed once the root column is found.
    let mut segments: Vec<String> = Vec::new();
    let mut cursor = expr;
    let root = loop {
        match cursor {
            Expr::Column(col) => break col.name.clone(),
            Expr::ScalarFunction(udf) if udf.name() == "get_field" => {
                let [base, Expr::Literal(ScalarValue::Utf8(Some(field_name)), _)] =
                    udf.args.as_slice()
                else {
                    return None;
                };
                segments.push(field_name.clone());
                cursor = base;
            }
            _ => return None,
        }
    };
    segments.push(root);
    segments.reverse();
    let field_refs: Vec<&str> = segments.iter().map(String::as_str).collect();
    Some(lance_core::datatypes::format_field_path(&field_refs))
}
/// Resolves `expr` to an indexed column.
///
/// Returns the column name, the type literals should be coerced to, and the parser to use.
/// A nested field path is tried first, then a plain column, then the first argument of a
/// scalar function call.
fn maybe_indexed_column<'b>(
    expr: &Expr,
    index_info: &'b dyn IndexInformationProvider,
) -> Option<(String, DataType, &'b dyn ScalarQueryParser)> {
    // Looks `column` up and lets its parser decide whether `expr` is a valid reference.
    let lookup = |column: &str| -> Option<(String, DataType, &'b dyn ScalarQueryParser)> {
        let (data_type, parser) = index_info.get_index(column)?;
        let resolved_type = parser.is_valid_reference(expr, data_type)?;
        Some((column.to_string(), resolved_type, parser))
    };

    if let Some(found) = extract_nested_column_path(expr).and_then(|path| lookup(&path)) {
        return Some(found);
    }

    let column = match expr {
        Expr::Column(col) => col.name.as_str(),
        Expr::ScalarFunction(udf) => maybe_column(udf.args.first()?)?,
        _ => return None,
    };
    lookup(column)
}
/// Extracts a literal from `expr` and safely coerces it to `expected_type`.
///
/// Accepts a bare literal, a `CAST(literal AS ..)`, or `arrow_cast(literal, 'type')`. Any
/// other expression, or a failed cast or coercion, yields `None`.
fn maybe_scalar(expr: &Expr, expected_type: &DataType) -> Option<ScalarValue> {
    match expr {
        Expr::Literal(value, _) => safe_coerce_scalar(value, expected_type),
        Expr::Cast(cast) => {
            let Expr::Literal(value, _) = cast.expr.as_ref() else {
                return None;
            };
            let casted = value.cast_to(&cast.data_type).ok()?;
            safe_coerce_scalar(&casted, expected_type)
        }
        Expr::ScalarFunction(call) if call.name() == "arrow_cast" => {
            let [Expr::Literal(value, _), Expr::Literal(cast_type, _)] = call.args.as_slice() else {
                return None;
            };
            // Ask the UDF itself which type the cast resolves to.
            let target_field = call
                .func
                .return_field_from_args(ReturnFieldArgs {
                    arg_fields: &[
                        Arc::new(Field::new("expression", value.data_type(), false)),
                        Arc::new(Field::new("datatype", cast_type.data_type(), false)),
                    ],
                    scalar_arguments: &[Some(value), Some(cast_type)],
                })
                .ok()?;
            let casted = value.cast_to(target_field.data_type()).ok()?;
            safe_coerce_scalar(&casted, expected_type)
        }
        _ => None,
    }
}
/// Converts every expression in `exprs` with [`maybe_scalar`].
///
/// Returns `None` as soon as any element is not a coercible literal.
///
/// Takes a slice rather than `&Vec<Expr>` (clippy `ptr_arg`); existing callers passing a
/// `&Vec<Expr>` still work through deref coercion.
fn maybe_scalar_list(exprs: &[Expr], expected_type: &DataType) -> Option<Vec<ScalarValue>> {
    exprs
        .iter()
        .map(|expr| maybe_scalar(expr, expected_type))
        .collect()
}
/// Translates `expr [NOT] BETWEEN low AND high` on an indexed column.
///
/// Both bounds are inclusive. The negated form goes through `maybe_not`.
fn visit_between(
    between: &Between,
    index_info: &dyn IndexInformationProvider,
) -> Option<IndexedExpression> {
    let (column, col_type, query_parser) = maybe_indexed_column(&between.expr, index_info)?;
    let low = Bound::Included(maybe_scalar(&between.low, &col_type)?);
    let high = Bound::Included(maybe_scalar(&between.high, &col_type)?);
    let indexed_expr = query_parser.visit_between(&column, &low, &high)?;
    match between.negated {
        true => indexed_expr.maybe_not(),
        false => Some(indexed_expr),
    }
}
/// Translates `expr [NOT] IN (...)` on an indexed column.
///
/// Every list element must be a coercible literal.
fn visit_in_list(
    in_list: &InList,
    index_info: &dyn IndexInformationProvider,
) -> Option<IndexedExpression> {
    let (column, col_type, query_parser) = maybe_indexed_column(&in_list.expr, index_info)?;
    let literals = maybe_scalar_list(&in_list.list, &col_type)?;
    let indexed_expr = query_parser.visit_in_list(&column, &literals)?;
    match in_list.negated {
        true => indexed_expr.maybe_not(),
        false => Some(indexed_expr),
    }
}
/// Translates `expr IS TRUE` / `expr IS FALSE`. Only boolean indexed columns qualify.
fn visit_is_bool(
    expr: &Expr,
    index_info: &dyn IndexInformationProvider,
    value: bool,
) -> Option<IndexedExpression> {
    let (column, col_type, query_parser) = maybe_indexed_column(expr, index_info)?;
    if matches!(col_type, DataType::Boolean) {
        query_parser.visit_is_bool(&column, value)
    } else {
        None
    }
}
fn visit_column(
col: &Expr,
index_info: &dyn IndexInformationProvider,
) -> Option<IndexedExpression> {
let (column, col_type, query_parser) = maybe_indexed_column(col, index_info)?;
if col_type != DataType::Boolean {
None
} else {
query_parser.visit_is_bool(&column, true)
}
}
/// Translates `expr IS NULL` (or `IS NOT NULL` when `negated`) on an indexed column
/// into an index query.
///
/// The `IS NOT NULL` form is the negation of the null query via `maybe_not`.
fn visit_is_null(
    expr: &Expr,
    index_info: &dyn IndexInformationProvider,
    negated: bool,
) -> Option<IndexedExpression> {
    let (col_name, _col_type, parser) = maybe_indexed_column(expr, index_info)?;
    parser.visit_is_null(&col_name).and_then(|null_check| {
        if negated {
            null_check.maybe_not()
        } else {
            Some(null_check)
        }
    })
}
/// Handles `NOT expr` by visiting `expr` one level deeper and negating the result.
///
/// Yields `Ok(None)` when the child cannot be indexed or its query cannot be
/// negated. Depth-limit errors from the child are propagated.
fn visit_not(
    expr: &Expr,
    index_info: &dyn IndexInformationProvider,
    depth: usize,
) -> Result<Option<IndexedExpression>> {
    let child_depth = depth + 1;
    Ok(match visit_node(expr, index_info, child_depth)? {
        Some(inner) => inner.maybe_not(),
        None => None,
    })
}
/// Translates `column <op> literal` into an index query.
///
/// Only the `column`-on-the-left shape is recognised. A literal on the left (e.g.
/// `10 > aisle`) is not handled here and yields `None`.
fn visit_comparison(
    expr: &BinaryExpr,
    index_info: &dyn IndexInformationProvider,
) -> Option<IndexedExpression> {
    let (col_name, target_type, parser) = maybe_indexed_column(&expr.left, index_info)?;
    let literal = maybe_scalar(&expr.right, &target_type)?;
    parser.visit_comparison(&col_name, &literal, &expr.op)
}
fn maybe_range(
expr: &BinaryExpr,
index_info: &dyn IndexInformationProvider,
) -> Option<IndexedExpression> {
let left_expr = match expr.left.as_ref() {
Expr::BinaryExpr(binary_expr) => Some(binary_expr),
_ => None,
}?;
let right_expr = match expr.right.as_ref() {
Expr::BinaryExpr(binary_expr) => Some(binary_expr),
_ => None,
}?;
let (left_col, dt, parser) = maybe_indexed_column(&left_expr.left, index_info)?;
let right_col = maybe_column(&right_expr.left)?;
if left_col != right_col {
return None;
}
let left_value = maybe_scalar(&left_expr.right, &dt)?;
let right_value = maybe_scalar(&right_expr.right, &dt)?;
let (low, high) = match (left_expr.op, right_expr.op) {
(Operator::GtEq, Operator::LtEq) => {
(Bound::Included(left_value), Bound::Included(right_value))
}
(Operator::GtEq, Operator::Lt) => {
(Bound::Included(left_value), Bound::Excluded(right_value))
}
(Operator::Gt, Operator::LtEq) => {
(Bound::Excluded(left_value), Bound::Included(right_value))
}
(Operator::Gt, Operator::Lt) => (Bound::Excluded(left_value), Bound::Excluded(right_value)),
(Operator::LtEq, Operator::GtEq) => {
(Bound::Included(right_value), Bound::Included(left_value))
}
(Operator::LtEq, Operator::Gt) => {
(Bound::Included(right_value), Bound::Excluded(left_value))
}
(Operator::Lt, Operator::GtEq) => {
(Bound::Excluded(right_value), Bound::Included(left_value))
}
(Operator::Lt, Operator::Gt) => (Bound::Excluded(right_value), Bound::Excluded(left_value)),
_ => return None,
};
parser.visit_between(&left_col, &low, &high)
}
/// Handles `left AND right`.
///
/// First tries to collapse the pair into a single range query. Otherwise:
/// - both sides indexed: the two queries are intersected;
/// - one side indexed: the other side is kept as a refine expression;
/// - neither side indexed: yields `None`.
fn visit_and(
    expr: &BinaryExpr,
    index_info: &dyn IndexInformationProvider,
    depth: usize,
) -> Result<Option<IndexedExpression>> {
    if let Some(range) = maybe_range(expr, index_info) {
        return Ok(Some(range));
    }
    let next_depth = depth + 1;
    let lhs = visit_node(&expr.left, index_info, next_depth)?;
    let rhs = visit_node(&expr.right, index_info, next_depth)?;
    let combined = match (lhs, rhs) {
        (Some(lhs_indexed), Some(rhs_indexed)) => Some(lhs_indexed.and(rhs_indexed)),
        (Some(indexed), None) => Some(indexed.refine(expr.right.as_ref().clone())),
        (None, Some(indexed)) => Some(indexed.refine(expr.left.as_ref().clone())),
        (None, None) => None,
    };
    Ok(combined)
}
/// Handles `left OR right`.
///
/// An `OR` can only be served by the index when both branches can. If either
/// branch cannot be indexed, or the two queries cannot be combined, this yields
/// `None` and the caller falls back. Both branches are always visited so that
/// depth-limit errors propagate.
fn visit_or(
    expr: &BinaryExpr,
    index_info: &dyn IndexInformationProvider,
    depth: usize,
) -> Result<Option<IndexedExpression>> {
    let next_depth = depth + 1;
    let lhs = visit_node(&expr.left, index_info, next_depth)?;
    let rhs = visit_node(&expr.right, index_info, next_depth)?;
    Ok(lhs
        .zip(rhs)
        .and_then(|(lhs_indexed, rhs_indexed)| lhs_indexed.maybe_or(rhs_indexed)))
}
/// Dispatches a binary expression to the visitor for its operator.
///
/// - Logical `AND`/`OR` recurse.
/// - Comparisons are translated directly.
/// - `<>` is handled as the negation of `=`.
/// - Every other operator is not indexable.
fn visit_binary_expr(
    expr: &BinaryExpr,
    index_info: &dyn IndexInformationProvider,
    depth: usize,
) -> Result<Option<IndexedExpression>> {
    match expr.op {
        Operator::And => visit_and(expr, index_info, depth),
        Operator::Or => visit_or(expr, index_info, depth),
        Operator::NotEq => {
            let equality = visit_comparison(expr, index_info);
            Ok(equality.and_then(|node| node.maybe_not()))
        }
        Operator::Eq | Operator::Lt | Operator::LtEq | Operator::Gt | Operator::GtEq => {
            Ok(visit_comparison(expr, index_info))
        }
        _ => Ok(None),
    }
}
/// Lets the index parser of the function's first argument translate a scalar
/// function call.
///
/// Yields `None` for zero-argument calls or when the first argument is not an
/// indexed column.
fn visit_scalar_fn(
    scalar_fn: &ScalarFunction,
    index_info: &dyn IndexInformationProvider,
) -> Option<IndexedExpression> {
    let first_arg = scalar_fn.args.first()?;
    let (col_name, col_type, parser) = maybe_indexed_column(first_arg, index_info)?;
    parser.visit_scalar_function(&col_name, &col_type, &scalar_fn.func, &scalar_fn.args)
}
/// Forwards `column LIKE <literal pattern>` to the column's index parser.
///
/// Non-literal patterns cannot be analysed ahead of time, so they yield `None`.
fn visit_like_expr(
    like: &Like,
    index_info: &dyn IndexInformationProvider,
) -> Option<IndexedExpression> {
    let (col_name, _col_type, parser) = maybe_indexed_column(&like.expr, index_info)?;
    let Expr::Literal(pattern, _) = like.pattern.as_ref() else {
        return None;
    };
    parser.visit_like(&col_name, like, pattern)
}
/// Recursively converts a filter expression into an [`IndexedExpression`].
///
/// Returns `Ok(None)` when `expr` cannot be served by a scalar index. `depth` is
/// the current recursion depth. Every recursive step, including unwrapping an
/// alias, increases it. Once it reaches `MAX_DEPTH` the visit is aborted, which
/// guards against unbounded recursion on pathologically nested filters.
///
/// # Errors
/// Returns an invalid-input error when the expression nests `MAX_DEPTH` or more
/// levels deep.
fn visit_node(
    expr: &Expr,
    index_info: &dyn IndexInformationProvider,
    depth: usize,
) -> Result<Option<IndexedExpression>> {
    if depth >= MAX_DEPTH {
        return Err(Error::invalid_input(format!(
            "the filter expression is too long, lance limit the max number of conditions to {}",
            MAX_DEPTH
        )));
    }
    match expr {
        Expr::Between(between) => Ok(visit_between(between, index_info)),
        // Aliases are transparent for indexing, but unwrapping one is still a
        // recursive call, so it must count towards the depth limit.
        Expr::Alias(alias) => visit_node(alias.expr.as_ref(), index_info, depth + 1),
        Expr::Column(_) => Ok(visit_column(expr, index_info)),
        Expr::InList(in_list) => Ok(visit_in_list(in_list, index_info)),
        Expr::IsFalse(inner) => Ok(visit_is_bool(inner.as_ref(), index_info, false)),
        Expr::IsTrue(inner) => Ok(visit_is_bool(inner.as_ref(), index_info, true)),
        Expr::IsNull(inner) => Ok(visit_is_null(inner.as_ref(), index_info, false)),
        Expr::IsNotNull(inner) => Ok(visit_is_null(inner.as_ref(), index_info, true)),
        Expr::Not(inner) => visit_not(inner.as_ref(), index_info, depth),
        Expr::BinaryExpr(binary_expr) => visit_binary_expr(binary_expr, index_info, depth),
        Expr::ScalarFunction(scalar_fn) => Ok(visit_scalar_fn(scalar_fn, index_info)),
        // Negated LIKE is never answered by an index.
        Expr::Like(like) if like.negated => Ok(None),
        Expr::Like(like) => Ok(visit_like_expr(like, index_info)),
        _ => Ok(None),
    }
}
/// Lookup interface the filter planner uses to find out which columns have a
/// scalar index.
pub trait IndexInformationProvider {
    /// Returns the data type recorded for `col` together with the parser that turns
    /// predicates on that column into index queries, or `None` when `col` has no
    /// usable scalar index.
    fn get_index(&self, col: &str) -> Option<(&DataType, &dyn ScalarQueryParser)>;
}
/// Splits `expr` into the part scalar indices can answer and the residual filter.
///
/// When no part of `expr` can use an index, the whole expression is returned as a
/// refine-only [`IndexedExpression`].
///
/// # Errors
/// Propagates the error raised when the filter nests too deeply (see `MAX_DEPTH`).
pub fn apply_scalar_indices(
    expr: Expr,
    index_info: &dyn IndexInformationProvider,
) -> Result<IndexedExpression> {
    let visited = visit_node(&expr, index_info, 0)?;
    Ok(match visited {
        Some(indexed) => indexed,
        None => IndexedExpression::refine_only(expr),
    })
}
/// Outcome of planning a filter.
///
/// A plan holds an optional scalar-index search plus an optional residual
/// expression that must still be evaluated against row data.
///
/// Note: the derived `Default` leaves `skip_recheck` as `false`, whereas
/// `FilterPlan::empty()` sets it to `true`.
#[derive(Clone, Default, Debug)]
pub struct FilterPlan {
    /// Portion of the filter answered by scalar indices, if any.
    pub index_query: Option<ScalarIndexExpr>,
    /// Whether the index results can be used without re-evaluating the filter.
    /// Derived from `!needs_recheck()` of the index query when one is present.
    pub skip_recheck: bool,
    /// Portion of the filter that still has to be evaluated against row data.
    pub refine_expr: Option<Expr>,
    /// The complete filter expression. Used to list referenced columns and to
    /// rebuild a refine-only plan in `make_refine_only`.
    pub full_expr: Option<Expr>,
}
impl FilterPlan {
    /// A plan that filters nothing: no index query and no refine expression.
    pub fn empty() -> Self {
        Self {
            skip_recheck: true,
            ..Self::default()
        }
    }

    /// A plan that evaluates `expr` directly against row data, without any index.
    pub fn new_refine_only(expr: Expr) -> Self {
        Self {
            refine_expr: Some(expr.clone()),
            full_expr: Some(expr),
            ..Self::empty()
        }
    }

    /// True when the plan applies neither an index query nor a refine expression.
    pub fn is_empty(&self) -> bool {
        !self.has_any_filter()
    }

    /// Column names referenced by the full filter expression, as reported by
    /// `Planner::column_names_in_expr` (empty when there is no filter).
    pub fn all_columns(&self) -> Vec<String> {
        self.full_expr
            .as_ref()
            .map_or_else(Vec::new, Planner::column_names_in_expr)
    }

    /// Column names referenced by the residual refine expression (empty when there
    /// is none).
    pub fn refine_columns(&self) -> Vec<String> {
        self.refine_expr
            .as_ref()
            .map_or_else(Vec::new, Planner::column_names_in_expr)
    }

    /// Whether some of the filter must be evaluated against row data.
    pub fn has_refine(&self) -> bool {
        matches!(self.refine_expr, Some(_))
    }

    /// Whether some of the filter is answered by a scalar index.
    pub fn has_index_query(&self) -> bool {
        matches!(self.index_query, Some(_))
    }

    /// Whether the plan filters anything at all.
    pub fn has_any_filter(&self) -> bool {
        self.has_refine() || self.has_index_query()
    }

    /// Drops the index query and falls back to evaluating the full expression.
    pub fn make_refine_only(&mut self) {
        self.refine_expr.clone_from(&self.full_expr);
        self.index_query = None;
    }

    /// True when the index alone yields the exact answer: there is an index query,
    /// nothing left to refine, and no recheck is needed.
    pub fn is_exact_index_search(&self) -> bool {
        self.skip_recheck && self.has_index_query() && !self.has_refine()
    }
}
/// Extension trait for building a [`FilterPlan`] from a filter expression.
pub trait PlannerIndexExt {
    /// Builds a [`FilterPlan`] for `filter`.
    ///
    /// The scalar indices described by `index_info` are consulted only when
    /// `use_scalar_index` is true.
    ///
    /// # Errors
    /// Returns an error if the filter cannot be planned.
    fn create_filter_plan(
        &self,
        filter: Expr,
        index_info: &dyn IndexInformationProvider,
        use_scalar_index: bool,
    ) -> Result<FilterPlan>;
}
impl PlannerIndexExt for Planner {
fn create_filter_plan(
&self,
filter: Expr,
index_info: &dyn IndexInformationProvider,
use_scalar_index: bool,
) -> Result<FilterPlan> {
let logical_expr = self.optimize_expr(filter)?;
if use_scalar_index {
let indexed_expr = apply_scalar_indices(logical_expr.clone(), index_info)?;
let mut skip_recheck = false;
if let Some(scalar_query) = indexed_expr.scalar_query.as_ref() {
skip_recheck = !scalar_query.needs_recheck();
}
Ok(FilterPlan {
index_query: indexed_expr.scalar_query,
refine_expr: indexed_expr.refine_expr,
full_expr: Some(logical_expr),
skip_recheck,
})
} else {
Ok(FilterPlan {
index_query: None,
skip_recheck: true,
refine_expr: Some(logical_expr.clone()),
full_expr: Some(logical_expr),
})
}
}
}
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use arrow_schema::{Field, Schema};
    use chrono::Utc;
    use datafusion_common::{Column, DFSchema};
    use datafusion_expr::execution_props::ExecutionProps;
    use datafusion_expr::simplify::SimplifyContext;
    use lance_datafusion::exec::{LanceExecutionOptions, get_session_context};

    use crate::scalar::json::{JsonQuery, JsonQueryParser};

    use super::*;

    /// Data type and query parser registered for one indexed column of the mock.
    struct ColInfo {
        data_type: DataType,
        parser: Box<dyn ScalarQueryParser>,
    }

    impl ColInfo {
        fn new(data_type: DataType, parser: Box<dyn ScalarQueryParser>) -> Self {
            Self { data_type, parser }
        }
    }

    /// In-memory [`IndexInformationProvider`] keyed by column name.
    struct MockIndexInfoProvider {
        indexed_columns: HashMap<String, ColInfo>,
    }

    impl MockIndexInfoProvider {
        fn new(indexed_columns: Vec<(&str, ColInfo)>) -> Self {
            Self {
                indexed_columns: HashMap::from_iter(
                    indexed_columns
                        .into_iter()
                        .map(|(s, ty)| (s.to_string(), ty)),
                ),
            }
        }
    }

    impl IndexInformationProvider for MockIndexInfoProvider {
        fn get_index(&self, col: &str) -> Option<(&DataType, &dyn ScalarQueryParser)> {
            self.indexed_columns
                .get(col)
                .map(|col_info| (&col_info.data_type, col_info.parser.as_ref()))
        }
    }

    /// Parses `expr` against a fixed test schema and applies the scalar indices.
    ///
    /// If `optimize` is set, the expression is simplified before indices are
    /// applied. The result is compared with `expected`. When `expected` is `None`,
    /// no index may be used and the (possibly simplified) expression must come
    /// back unchanged as the refine expression.
    fn check(
        index_info: &dyn IndexInformationProvider,
        expr: &str,
        expected: Option<IndexedExpression>,
        optimize: bool,
    ) {
        let schema = Schema::new(vec![
            Field::new("color", DataType::Utf8, false),
            Field::new("size", DataType::Float32, false),
            Field::new("aisle", DataType::UInt32, false),
            Field::new("on_sale", DataType::Boolean, false),
            Field::new("price", DataType::Float32, false),
            Field::new("json", DataType::LargeBinary, false),
        ]);
        let df_schema: DFSchema = schema.try_into().unwrap();
        let ctx = get_session_context(&LanceExecutionOptions::default());
        let state = ctx.state();
        let mut expr = state.create_logical_expr(expr, &df_schema).unwrap();
        if optimize {
            let props = ExecutionProps::new().with_query_execution_start_time(Utc::now());
            let simplify_context = SimplifyContext::new(&props).with_schema(Arc::new(df_schema));
            let simplifier =
                datafusion::optimizer::simplify_expressions::ExprSimplifier::new(simplify_context);
            expr = simplifier.simplify(expr).unwrap();
        }
        let actual = apply_scalar_indices(expr.clone(), index_info).unwrap();
        if let Some(expected) = expected {
            assert_eq!(actual, expected);
        } else {
            assert!(actual.scalar_query.is_none());
            assert_eq!(actual.refine_expr.unwrap(), expr);
        }
    }

    /// Asserts that `expr` is left entirely to the refine step.
    fn check_no_index(index_info: &dyn IndexInformationProvider, expr: &str) {
        check(index_info, expr, None, false)
    }

    /// Asserts that `expr` maps to exactly `query` on `{col}_idx`, without simplification.
    fn check_simple(
        index_info: &dyn IndexInformationProvider,
        expr: &str,
        col: &str,
        query: impl AnyQuery,
    ) {
        check(
            index_info,
            expr,
            Some(IndexedExpression::index_query(
                col.to_string(),
                format!("{}_idx", col),
                Arc::new(query),
            )),
            false,
        )
    }

    /// Like `check_simple`, but simplifies the expression first (used for ranges).
    fn check_range(
        index_info: &dyn IndexInformationProvider,
        expr: &str,
        col: &str,
        query: SargableQuery,
    ) {
        check(
            index_info,
            expr,
            Some(IndexedExpression::index_query(
                col.to_string(),
                format!("{}_idx", col),
                Arc::new(query),
            )),
            true,
        )
    }

    /// Asserts that `expr` maps to the negation of `query` on `{col}_idx`.
    fn check_simple_negated(
        index_info: &dyn IndexInformationProvider,
        expr: &str,
        col: &str,
        query: SargableQuery,
    ) {
        check(
            index_info,
            expr,
            Some(
                IndexedExpression::index_query(
                    col.to_string(),
                    format!("{}_idx", col),
                    Arc::new(query),
                )
                .maybe_not()
                .unwrap(),
            ),
            false,
        )
    }

    #[test]
    fn test_expressions() {
        let index_info = MockIndexInfoProvider::new(vec![
            (
                "color",
                ColInfo::new(
                    DataType::Utf8,
                    Box::new(SargableQueryParser::new("color_idx".to_string(), false)),
                ),
            ),
            (
                "aisle",
                ColInfo::new(
                    DataType::UInt32,
                    Box::new(SargableQueryParser::new("aisle_idx".to_string(), false)),
                ),
            ),
            (
                "on_sale",
                ColInfo::new(
                    DataType::Boolean,
                    Box::new(SargableQueryParser::new("on_sale_idx".to_string(), false)),
                ),
            ),
            (
                "price",
                ColInfo::new(
                    DataType::Float32,
                    Box::new(SargableQueryParser::new("price_idx".to_string(), false)),
                ),
            ),
            (
                "json",
                ColInfo::new(
                    DataType::LargeBinary,
                    Box::new(JsonQueryParser::new(
                        "$.name".to_string(),
                        Box::new(SargableQueryParser::new("json_idx".to_string(), false)),
                    )),
                ),
            ),
        ]);
        check_simple(
            &index_info,
            "json_extract(json, '$.name') = 'foo'",
            "json",
            JsonQuery::new(
                Arc::new(SargableQuery::Equals(ScalarValue::Utf8(Some(
                    "foo".to_string(),
                )))),
                "$.name".to_string(),
            ),
        );
        check_no_index(&index_info, "size BETWEEN 5 AND 10");
        check_simple(
            &index_info,
            "aisle = arrow_cast(5, 'Int16')",
            "aisle",
            SargableQuery::Equals(ScalarValue::UInt32(Some(5))),
        );
        check_range(
            &index_info,
            "aisle BETWEEN 5 AND 10",
            "aisle",
            SargableQuery::Range(
                Bound::Included(ScalarValue::UInt32(Some(5))),
                Bound::Included(ScalarValue::UInt32(Some(10))),
            ),
        );
        check_range(
            &index_info,
            "aisle >= 5 AND aisle <= 10",
            "aisle",
            SargableQuery::Range(
                Bound::Included(ScalarValue::UInt32(Some(5))),
                Bound::Included(ScalarValue::UInt32(Some(10))),
            ),
        );
        check_range(
            &index_info,
            "aisle <= 10 AND aisle >= 5",
            "aisle",
            SargableQuery::Range(
                Bound::Included(ScalarValue::UInt32(Some(5))),
                Bound::Included(ScalarValue::UInt32(Some(10))),
            ),
        );
        check_range(
            &index_info,
            "5 <= aisle AND 10 >= aisle",
            "aisle",
            SargableQuery::Range(
                Bound::Included(ScalarValue::UInt32(Some(5))),
                Bound::Included(ScalarValue::UInt32(Some(10))),
            ),
        );
        check_range(
            &index_info,
            "10 >= aisle AND 5 <= aisle",
            "aisle",
            SargableQuery::Range(
                Bound::Included(ScalarValue::UInt32(Some(5))),
                Bound::Included(ScalarValue::UInt32(Some(10))),
            ),
        );
        check_simple(
            &index_info,
            "on_sale IS TRUE",
            "on_sale",
            SargableQuery::Equals(ScalarValue::Boolean(Some(true))),
        );
        check_simple(
            &index_info,
            "on_sale",
            "on_sale",
            SargableQuery::Equals(ScalarValue::Boolean(Some(true))),
        );
        check_simple_negated(
            &index_info,
            "NOT on_sale",
            "on_sale",
            SargableQuery::Equals(ScalarValue::Boolean(Some(true))),
        );
        check_simple(
            &index_info,
            "on_sale IS FALSE",
            "on_sale",
            SargableQuery::Equals(ScalarValue::Boolean(Some(false))),
        );
        check_simple_negated(
            &index_info,
            "aisle NOT BETWEEN 5 AND 10",
            "aisle",
            SargableQuery::Range(
                Bound::Included(ScalarValue::UInt32(Some(5))),
                Bound::Included(ScalarValue::UInt32(Some(10))),
            ),
        );
        check_simple(
            &index_info,
            "aisle IN (5, 6, 7)",
            "aisle",
            SargableQuery::IsIn(vec![
                ScalarValue::UInt32(Some(5)),
                ScalarValue::UInt32(Some(6)),
                ScalarValue::UInt32(Some(7)),
            ]),
        );
        check_simple_negated(
            &index_info,
            "NOT aisle IN (5, 6, 7)",
            "aisle",
            SargableQuery::IsIn(vec![
                ScalarValue::UInt32(Some(5)),
                ScalarValue::UInt32(Some(6)),
                ScalarValue::UInt32(Some(7)),
            ]),
        );
        check_simple_negated(
            &index_info,
            "aisle NOT IN (5, 6, 7)",
            "aisle",
            SargableQuery::IsIn(vec![
                ScalarValue::UInt32(Some(5)),
                ScalarValue::UInt32(Some(6)),
                ScalarValue::UInt32(Some(7)),
            ]),
        );
        check_simple(
            &index_info,
            "aisle IN (5, 6, 7, 8, 9)",
            "aisle",
            SargableQuery::IsIn(vec![
                ScalarValue::UInt32(Some(5)),
                ScalarValue::UInt32(Some(6)),
                ScalarValue::UInt32(Some(7)),
                ScalarValue::UInt32(Some(8)),
                ScalarValue::UInt32(Some(9)),
            ]),
        );
        check_simple_negated(
            &index_info,
            "NOT aisle IN (5, 6, 7, 8, 9)",
            "aisle",
            SargableQuery::IsIn(vec![
                ScalarValue::UInt32(Some(5)),
                ScalarValue::UInt32(Some(6)),
                ScalarValue::UInt32(Some(7)),
                ScalarValue::UInt32(Some(8)),
                ScalarValue::UInt32(Some(9)),
            ]),
        );
        check_simple_negated(
            &index_info,
            "aisle NOT IN (5, 6, 7, 8, 9)",
            "aisle",
            SargableQuery::IsIn(vec![
                ScalarValue::UInt32(Some(5)),
                ScalarValue::UInt32(Some(6)),
                ScalarValue::UInt32(Some(7)),
                ScalarValue::UInt32(Some(8)),
                ScalarValue::UInt32(Some(9)),
            ]),
        );
        check_simple(
            &index_info,
            "on_sale is false",
            "on_sale",
            SargableQuery::Equals(ScalarValue::Boolean(Some(false))),
        );
        check_simple(
            &index_info,
            "on_sale is true",
            "on_sale",
            SargableQuery::Equals(ScalarValue::Boolean(Some(true))),
        );
        check_simple(
            &index_info,
            "aisle < 10",
            "aisle",
            SargableQuery::Range(
                Bound::Unbounded,
                Bound::Excluded(ScalarValue::UInt32(Some(10))),
            ),
        );
        check_simple(
            &index_info,
            "aisle <= 10",
            "aisle",
            SargableQuery::Range(
                Bound::Unbounded,
                Bound::Included(ScalarValue::UInt32(Some(10))),
            ),
        );
        check_simple(
            &index_info,
            "aisle > 10",
            "aisle",
            SargableQuery::Range(
                Bound::Excluded(ScalarValue::UInt32(Some(10))),
                Bound::Unbounded,
            ),
        );
        check_no_index(&index_info, "10 > aisle");
        check_simple(
            &index_info,
            "aisle >= 10",
            "aisle",
            SargableQuery::Range(
                Bound::Included(ScalarValue::UInt32(Some(10))),
                Bound::Unbounded,
            ),
        );
        check_simple(
            &index_info,
            "aisle = 10",
            "aisle",
            SargableQuery::Equals(ScalarValue::UInt32(Some(10))),
        );
        check_simple_negated(
            &index_info,
            "aisle <> 10",
            "aisle",
            SargableQuery::Equals(ScalarValue::UInt32(Some(10))),
        );
        let left = Box::new(ScalarIndexExpr::Query(ScalarIndexSearch {
            column: "aisle".to_string(),
            index_name: "aisle_idx".to_string(),
            query: Arc::new(SargableQuery::Equals(ScalarValue::UInt32(Some(10)))),
            needs_recheck: false,
        }));
        let right = Box::new(ScalarIndexExpr::Query(ScalarIndexSearch {
            column: "color".to_string(),
            index_name: "color_idx".to_string(),
            query: Arc::new(SargableQuery::Equals(ScalarValue::Utf8(Some(
                "blue".to_string(),
            )))),
            needs_recheck: false,
        }));
        check(
            &index_info,
            "aisle = 10 AND color = 'blue'",
            Some(IndexedExpression {
                scalar_query: Some(ScalarIndexExpr::And(left.clone(), right.clone())),
                refine_expr: None,
            }),
            false,
        );
        let refine = Expr::Column(Column::new_unqualified("size")).gt(datafusion_expr::lit(30_i64));
        check(
            &index_info,
            "aisle = 10 AND color = 'blue' AND size > 30",
            Some(IndexedExpression {
                scalar_query: Some(ScalarIndexExpr::And(left.clone(), right.clone())),
                refine_expr: Some(refine.clone()),
            }),
            false,
        );
        check(
            &index_info,
            "aisle = 10 OR color = 'blue'",
            Some(IndexedExpression {
                scalar_query: Some(ScalarIndexExpr::Or(left.clone(), right.clone())),
                refine_expr: None,
            }),
            false,
        );
        check_no_index(&index_info, "aisle = 10 OR color = 'blue' OR size > 30");
        check(
            &index_info,
            "(aisle = 10 OR color = 'blue') AND size > 30",
            Some(IndexedExpression {
                scalar_query: Some(ScalarIndexExpr::Or(left, right)),
                refine_expr: Some(refine),
            }),
            false,
        );
        check_no_index(
            &index_info,
            "(aisle = 10 AND size > 30) OR (color = 'blue' AND size > 20)",
        );
        check_no_index(&index_info, "aisle + 3 < 10");
        check_no_index(&index_info, "aisle IN (5, 6, NULL)");
        check_no_index(&index_info, "aisle = 5 OR aisle = 6 OR NULL");
        check_no_index(&index_info, "aisle IN (5, 6, 7, 8, NULL)");
        check_no_index(&index_info, "aisle = NULL");
        check_no_index(&index_info, "aisle BETWEEN 5 AND NULL");
        check_no_index(&index_info, "aisle BETWEEN NULL AND 10");
    }

    #[tokio::test]
    async fn test_not_flips_certainty() {
        use lance_core::utils::mask::{NullableRowAddrSet, RowAddrTreeMap};

        // NOTE(review): `apply_not` is a local helper that mirrors the expected NOT
        // semantics. These assertions therefore check the helper, not the library's
        // own NOT evaluation. Consider pointing them at the production code path.
        fn apply_not(result: NullableIndexExprResult) -> NullableIndexExprResult {
            match result {
                NullableIndexExprResult::Exact(mask) => NullableIndexExprResult::Exact(!mask),
                NullableIndexExprResult::AtMost(mask) => NullableIndexExprResult::AtLeast(!mask),
                NullableIndexExprResult::AtLeast(mask) => NullableIndexExprResult::AtMost(!mask),
            }
        }
        let at_most = NullableIndexExprResult::AtMost(NullableRowAddrMask::AllowList(
            NullableRowAddrSet::new(RowAddrTreeMap::from_iter(&[1, 2]), RowAddrTreeMap::new()),
        ));
        assert!(matches!(
            apply_not(at_most),
            NullableIndexExprResult::AtLeast(_)
        ));
        let at_least = NullableIndexExprResult::AtLeast(NullableRowAddrMask::AllowList(
            NullableRowAddrSet::new(RowAddrTreeMap::from_iter(&[1, 2]), RowAddrTreeMap::new()),
        ));
        assert!(matches!(
            apply_not(at_least),
            NullableIndexExprResult::AtMost(_)
        ));
        let exact = NullableIndexExprResult::Exact(NullableRowAddrMask::AllowList(
            NullableRowAddrSet::new(RowAddrTreeMap::from_iter(&[1, 2]), RowAddrTreeMap::new()),
        ));
        assert!(matches!(
            apply_not(exact),
            NullableIndexExprResult::Exact(_)
        ));
    }

    #[tokio::test]
    async fn test_and_or_preserve_certainty() {
        use lance_core::utils::mask::{NullableRowAddrSet, RowAddrTreeMap};

        let make_at_most = || {
            NullableIndexExprResult::AtMost(NullableRowAddrMask::AllowList(
                NullableRowAddrSet::new(
                    RowAddrTreeMap::from_iter(&[1, 2, 3]),
                    RowAddrTreeMap::new(),
                ),
            ))
        };
        let make_at_least = || {
            NullableIndexExprResult::AtLeast(NullableRowAddrMask::AllowList(
                NullableRowAddrSet::new(
                    RowAddrTreeMap::from_iter(&[2, 3, 4]),
                    RowAddrTreeMap::new(),
                ),
            ))
        };
        let make_exact = || {
            NullableIndexExprResult::Exact(NullableRowAddrMask::AllowList(NullableRowAddrSet::new(
                RowAddrTreeMap::from_iter(&[1, 2]),
                RowAddrTreeMap::new(),
            )))
        };
        assert!(matches!(
            make_at_most() & make_at_most(),
            NullableIndexExprResult::AtMost(_)
        ));
        assert!(matches!(
            make_at_least() & make_at_least(),
            NullableIndexExprResult::AtLeast(_)
        ));
        assert!(matches!(
            make_at_most() & make_at_least(),
            NullableIndexExprResult::AtMost(_)
        ));
        assert!(matches!(
            make_at_most() | make_at_most(),
            NullableIndexExprResult::AtMost(_)
        ));
        assert!(matches!(
            make_at_least() | make_at_least(),
            NullableIndexExprResult::AtLeast(_)
        ));
        assert!(matches!(
            make_at_most() | make_at_least(),
            NullableIndexExprResult::AtLeast(_)
        ));
        assert!(matches!(
            make_exact() & make_at_most(),
            NullableIndexExprResult::AtMost(_)
        ));
        assert!(matches!(
            make_exact() | make_at_least(),
            NullableIndexExprResult::AtLeast(_)
        ));
    }

    #[test]
    fn test_extract_like_leading_prefix() {
        assert_eq!(
            extract_like_leading_prefix("foo%", None),
            Some(("foo".to_string(), false))
        );
        assert_eq!(
            extract_like_leading_prefix("abc%", None),
            Some(("abc".to_string(), false))
        );
        assert_eq!(
            extract_like_leading_prefix("foo%bar%", None),
            Some(("foo".to_string(), true))
        );
        assert_eq!(
            extract_like_leading_prefix("foo_bar%", None),
            Some(("foo".to_string(), true))
        );
        assert_eq!(
            extract_like_leading_prefix("foo%bar", None),
            Some(("foo".to_string(), true))
        );
        assert_eq!(
            extract_like_leading_prefix("foo_", None),
            Some(("foo".to_string(), true))
        );
        assert_eq!(extract_like_leading_prefix("%foo", None), None);
        assert_eq!(extract_like_leading_prefix("_foo%", None), None);
        assert_eq!(extract_like_leading_prefix("%", None), None);
        assert_eq!(extract_like_leading_prefix("foo", None), None);
        assert_eq!(
            extract_like_leading_prefix(r"foo\%bar%", Some('\\')),
            Some(("foo%bar".to_string(), false))
        );
        assert_eq!(
            extract_like_leading_prefix(r"foo\_bar%", Some('\\')),
            Some(("foo_bar".to_string(), false))
        );
        assert_eq!(
            extract_like_leading_prefix(r"foo\\bar%", Some('\\')),
            Some(("foo\\bar".to_string(), false))
        );
        assert_eq!(extract_like_leading_prefix(r"foo\%", Some('\\')), None);
        assert_eq!(extract_like_leading_prefix(r"foo\%", None), None);
        assert_eq!(
            extract_like_leading_prefix(r"foo\bar%", None),
            Some(("foo\\bar".to_string(), false))
        );
        assert_eq!(extract_like_leading_prefix("", None), None);
        assert_eq!(
            extract_like_leading_prefix(r"foo\%bar%baz%", Some('\\')),
            Some(("foo%bar".to_string(), true))
        );
    }

    #[test]
    fn test_like_expression_parsing() {
        let index_info = MockIndexInfoProvider::new(vec![(
            "color",
            ColInfo::new(
                DataType::Utf8,
                Box::new(SargableQueryParser::new("color_idx".to_string(), false)),
            ),
        )]);
        let schema = Schema::new(vec![Field::new("color", DataType::Utf8, false)]);
        let df_schema: DFSchema = schema.try_into().unwrap();
        let ctx = get_session_context(&LanceExecutionOptions::default());
        let state = ctx.state();
        let expr = state
            .create_logical_expr("color LIKE 'foo%'", &df_schema)
            .unwrap();
        let result = apply_scalar_indices(expr, &index_info).unwrap();
        assert!(result.scalar_query.is_some(), "Should have scalar_query");
        assert!(
            result.refine_expr.is_none(),
            "Simple prefix should not need refine_expr"
        );
        if let Some(ScalarIndexExpr::Query(search)) = &result.scalar_query {
            let query = search.query.as_any().downcast_ref::<SargableQuery>();
            assert!(query.is_some(), "Query should be SargableQuery");
            match query.unwrap() {
                SargableQuery::LikePrefix(prefix) => {
                    assert_eq!(prefix, &ScalarValue::Utf8(Some("foo".to_string())));
                }
                _ => panic!("Expected LikePrefix query"),
            }
        } else {
            panic!("Expected Query variant");
        }
        let expr = state
            .create_logical_expr("color LIKE 'foo%bar%'", &df_schema)
            .unwrap();
        let result = apply_scalar_indices(expr, &index_info).unwrap();
        assert!(result.scalar_query.is_some(), "Should have scalar_query");
        assert!(
            result.refine_expr.is_some(),
            "Complex pattern should have refine_expr"
        );
        if let Some(ScalarIndexExpr::Query(search)) = &result.scalar_query {
            let query = search.query.as_any().downcast_ref::<SargableQuery>();
            assert!(query.is_some(), "Query should be SargableQuery");
            match query.unwrap() {
                SargableQuery::LikePrefix(prefix) => {
                    assert_eq!(prefix, &ScalarValue::Utf8(Some("foo".to_string())));
                }
                _ => panic!("Expected LikePrefix query"),
            }
        } else {
            // Without this branch a non-Query variant would silently skip the checks.
            panic!("Expected Query variant");
        }
        let refine = result.refine_expr.unwrap();
        match refine {
            Expr::Like(like) => {
                assert!(!like.negated);
                assert!(!like.case_insensitive);
                if let Expr::Literal(ScalarValue::Utf8(Some(pattern)), _) = like.pattern.as_ref() {
                    assert_eq!(pattern, "foo%bar%");
                } else {
                    panic!("Expected Utf8 literal pattern");
                }
            }
            _ => panic!("Expected Like expression in refine_expr"),
        }
        let expr = state
            .create_logical_expr("color LIKE '%foo'", &df_schema)
            .unwrap();
        let result = apply_scalar_indices(expr, &index_info).unwrap();
        assert!(
            result.scalar_query.is_none(),
            "Pattern starting with wildcard should not use index"
        );
        assert!(result.refine_expr.is_some(), "Should fall back to refine");
    }

    #[test]
    fn test_starts_with_with_underscore_after_optimization() {
        let index_info = MockIndexInfoProvider::new(vec![(
            "object_id",
            ColInfo::new(
                DataType::Utf8,
                Box::new(SargableQueryParser::new("object_id_idx".to_string(), false)),
            ),
        )]);
        let schema = Schema::new(vec![Field::new("object_id", DataType::Utf8, false)]);
        let df_schema: DFSchema = schema.try_into().unwrap();
        let ctx = get_session_context(&LanceExecutionOptions::default());
        let state = ctx.state();
        let expr = state
            .create_logical_expr("starts_with(object_id, 'test_ns$')", &df_schema)
            .unwrap();
        let props = ExecutionProps::new().with_query_execution_start_time(Utc::now());
        let simplify_context = SimplifyContext::new(&props).with_schema(Arc::new(df_schema));
        let simplifier =
            datafusion::optimizer::simplify_expressions::ExprSimplifier::new(simplify_context);
        let simplified_expr = simplifier.simplify(expr).unwrap();
        let result = apply_scalar_indices(simplified_expr, &index_info).unwrap();
        if let Some(ScalarIndexExpr::Query(search)) = &result.scalar_query {
            let query = search
                .query
                .as_any()
                .downcast_ref::<SargableQuery>()
                .unwrap();
            match query {
                SargableQuery::LikePrefix(prefix) => {
                    let prefix_str = match prefix {
                        ScalarValue::Utf8(Some(s)) => s.clone(),
                        _ => panic!("Expected Utf8 prefix"),
                    };
                    assert_eq!(
                        prefix_str, "test_ns$",
                        "Prefix should be 'test_ns$', not 'test' (underscore should not be a wildcard)"
                    );
                }
                _ => panic!("Expected LikePrefix query"),
            }
        } else {
            panic!("Expected scalar_query to be present");
        }
    }

    #[test]
    fn test_starts_with_to_like_conversion() {
        let index_info = MockIndexInfoProvider::new(vec![(
            "color",
            ColInfo::new(
                DataType::Utf8,
                Box::new(SargableQueryParser::new("color_idx".to_string(), false)),
            ),
        )]);
        let schema = Schema::new(vec![Field::new("color", DataType::Utf8, false)]);
        let df_schema: DFSchema = schema.try_into().unwrap();
        let ctx = get_session_context(&LanceExecutionOptions::default());
        let state = ctx.state();
        let expr = state
            .create_logical_expr("starts_with(color, 'foo')", &df_schema)
            .unwrap();
        let result = apply_scalar_indices(expr, &index_info).unwrap();
        assert!(
            result.scalar_query.is_some(),
            "starts_with should use index"
        );
        assert!(
            result.refine_expr.is_none(),
            "Pure prefix starts_with should not need refine_expr"
        );
        if let Some(ScalarIndexExpr::Query(search)) = &result.scalar_query {
            let query = search.query.as_any().downcast_ref::<SargableQuery>();
            assert!(query.is_some(), "Query should be SargableQuery");
            match query.unwrap() {
                SargableQuery::LikePrefix(prefix) => {
                    assert_eq!(prefix, &ScalarValue::Utf8(Some("foo".to_string())));
                }
                _ => panic!("Expected LikePrefix query"),
            }
        } else {
            panic!("Expected Query variant");
        }
        let like_expr = state
            .create_logical_expr("color LIKE 'foo%'", &df_schema)
            .unwrap();
        let like_result = apply_scalar_indices(like_expr, &index_info).unwrap();
        if let (
            Some(ScalarIndexExpr::Query(starts_with_search)),
            Some(ScalarIndexExpr::Query(like_search)),
        ) = (&result.scalar_query, &like_result.scalar_query)
        {
            let sw_query = starts_with_search
                .query
                .as_any()
                .downcast_ref::<SargableQuery>()
                .unwrap();
            let like_query = like_search
                .query
                .as_any()
                .downcast_ref::<SargableQuery>()
                .unwrap();
            assert_eq!(
                sw_query, like_query,
                "starts_with and LIKE 'prefix%' should produce identical queries"
            );
        } else {
            // Without this branch the equivalence check would silently be skipped.
            panic!("Expected both starts_with and LIKE to produce Query variants");
        }
    }
}