1use std::{
5 ops::Bound,
6 sync::{Arc, LazyLock},
7};
8
9use arrow::array::BinaryBuilder;
10use arrow_array::{Array, RecordBatch, UInt32Array};
11use arrow_schema::{DataType, Field, Schema, SchemaRef};
12use async_recursion::async_recursion;
13use async_trait::async_trait;
14use datafusion_common::ScalarValue;
15use datafusion_expr::{
16 expr::{InList, ScalarFunction},
17 Between, BinaryExpr, Expr, Operator, ReturnFieldArgs, ScalarUDF,
18};
19
20use super::{
21 AnyQuery, BloomFilterQuery, LabelListQuery, MetricsCollector, SargableQuery, ScalarIndex,
22 SearchResult, TextQuery, TokenQuery,
23};
24use futures::join;
25use lance_core::{utils::mask::RowIdMask, Error, Result};
26use lance_datafusion::{expr::safe_coerce_scalar, planner::Planner};
27use roaring::RoaringBitmap;
28use snafu::location;
29use tracing::instrument;
30
// Depth limit of 500. NOTE(review): presumably the cap on how deeply the
// expression visitors (which take a `depth` argument) may recurse; the check
// itself is outside this view — confirm.
const MAX_DEPTH: usize = 500;
32
/// A filter split into a part answered by scalar indices and a leftover part.
///
/// Either field may be absent. The helpers on this type treat a value with
/// both fields set to `None` as invalid (`maybe_not` panics on it).
#[derive(Debug, PartialEq)]
pub struct IndexedExpression {
    /// The portion of the filter expressed as scalar index searches, if any.
    pub scalar_query: Option<ScalarIndexExpr>,
    /// The portion of the filter kept as a plain expression, if any.
    /// NOTE(review): presumably applied to rows after the index search — confirm.
    pub refine_expr: Option<Expr>,
}
67
68pub trait ScalarQueryParser: std::fmt::Debug + Send + Sync {
69 fn visit_between(
73 &self,
74 column: &str,
75 low: &Bound<ScalarValue>,
76 high: &Bound<ScalarValue>,
77 ) -> Option<IndexedExpression>;
78 fn visit_in_list(&self, column: &str, in_list: &[ScalarValue]) -> Option<IndexedExpression>;
82 fn visit_is_bool(&self, column: &str, value: bool) -> Option<IndexedExpression>;
86 fn visit_is_null(&self, column: &str) -> Option<IndexedExpression>;
90 fn visit_comparison(
94 &self,
95 column: &str,
96 value: &ScalarValue,
97 op: &Operator,
98 ) -> Option<IndexedExpression>;
99 fn visit_scalar_function(
104 &self,
105 column: &str,
106 data_type: &DataType,
107 func: &ScalarUDF,
108 args: &[Expr],
109 ) -> Option<IndexedExpression>;
110
111 fn is_valid_reference(&self, func: &Expr, data_type: &DataType) -> Option<DataType> {
135 match func {
136 Expr::Column(_) => Some(data_type.clone()),
137 _ => None,
138 }
139 }
140}
141
/// A [`ScalarQueryParser`] that delegates to an ordered list of parsers.
#[derive(Debug)]
pub struct MultiQueryParser {
    // Consulted in insertion order; the first parser that returns `Some` wins.
    parsers: Vec<Box<dyn ScalarQueryParser>>,
}
149
150impl MultiQueryParser {
151 pub fn single(parser: Box<dyn ScalarQueryParser>) -> Self {
153 Self {
154 parsers: vec![parser],
155 }
156 }
157
158 pub fn add(&mut self, other: Box<dyn ScalarQueryParser>) {
160 self.parsers.push(other);
161 }
162}
163
164impl ScalarQueryParser for MultiQueryParser {
165 fn visit_between(
166 &self,
167 column: &str,
168 low: &Bound<ScalarValue>,
169 high: &Bound<ScalarValue>,
170 ) -> Option<IndexedExpression> {
171 self.parsers
172 .iter()
173 .find_map(|parser| parser.visit_between(column, low, high))
174 }
175 fn visit_in_list(&self, column: &str, in_list: &[ScalarValue]) -> Option<IndexedExpression> {
176 self.parsers
177 .iter()
178 .find_map(|parser| parser.visit_in_list(column, in_list))
179 }
180 fn visit_is_bool(&self, column: &str, value: bool) -> Option<IndexedExpression> {
181 self.parsers
182 .iter()
183 .find_map(|parser| parser.visit_is_bool(column, value))
184 }
185 fn visit_is_null(&self, column: &str) -> Option<IndexedExpression> {
186 self.parsers
187 .iter()
188 .find_map(|parser| parser.visit_is_null(column))
189 }
190 fn visit_comparison(
191 &self,
192 column: &str,
193 value: &ScalarValue,
194 op: &Operator,
195 ) -> Option<IndexedExpression> {
196 self.parsers
197 .iter()
198 .find_map(|parser| parser.visit_comparison(column, value, op))
199 }
200 fn visit_scalar_function(
201 &self,
202 column: &str,
203 data_type: &DataType,
204 func: &ScalarUDF,
205 args: &[Expr],
206 ) -> Option<IndexedExpression> {
207 self.parsers
208 .iter()
209 .find_map(|parser| parser.visit_scalar_function(column, data_type, func, args))
210 }
211 fn is_valid_reference(&self, func: &Expr, data_type: &DataType) -> Option<DataType> {
218 self.parsers
219 .iter()
220 .find_map(|parser| parser.is_valid_reference(func, data_type))
221 }
222}
223
/// A parser that produces [`SargableQuery`] searches (ranges, equality,
/// membership and null tests) against a named index.
#[derive(Debug)]
pub struct SargableQueryParser {
    // Name of the index every generated search targets.
    index_name: String,
    // Copied into each generated search's `needs_recheck` flag.
    needs_recheck: bool,
}
230
231impl SargableQueryParser {
232 pub fn new(index_name: String, needs_recheck: bool) -> Self {
233 Self {
234 index_name,
235 needs_recheck,
236 }
237 }
238}
239
240impl ScalarQueryParser for SargableQueryParser {
241 fn is_valid_reference(&self, func: &Expr, data_type: &DataType) -> Option<DataType> {
242 match func {
243 Expr::Column(_) => Some(data_type.clone()),
244 Expr::ScalarFunction(udf) if udf.name() == "get_field" => Some(data_type.clone()),
246 _ => None,
247 }
248 }
249
250 fn visit_between(
251 &self,
252 column: &str,
253 low: &Bound<ScalarValue>,
254 high: &Bound<ScalarValue>,
255 ) -> Option<IndexedExpression> {
256 if let Bound::Included(val) | Bound::Excluded(val) = low {
257 if val.is_null() {
258 return None;
259 }
260 }
261 if let Bound::Included(val) | Bound::Excluded(val) = high {
262 if val.is_null() {
263 return None;
264 }
265 }
266 let query = SargableQuery::Range(low.clone(), high.clone());
267 Some(IndexedExpression::index_query_with_recheck(
268 column.to_string(),
269 self.index_name.clone(),
270 Arc::new(query),
271 self.needs_recheck,
272 ))
273 }
274
275 fn visit_in_list(&self, column: &str, in_list: &[ScalarValue]) -> Option<IndexedExpression> {
276 if in_list.iter().any(|val| val.is_null()) {
277 return None;
278 }
279 let query = SargableQuery::IsIn(in_list.to_vec());
280 Some(IndexedExpression::index_query_with_recheck(
281 column.to_string(),
282 self.index_name.clone(),
283 Arc::new(query),
284 self.needs_recheck,
285 ))
286 }
287
288 fn visit_is_bool(&self, column: &str, value: bool) -> Option<IndexedExpression> {
289 Some(IndexedExpression::index_query_with_recheck(
290 column.to_string(),
291 self.index_name.clone(),
292 Arc::new(SargableQuery::Equals(ScalarValue::Boolean(Some(value)))),
293 self.needs_recheck,
294 ))
295 }
296
297 fn visit_is_null(&self, column: &str) -> Option<IndexedExpression> {
298 Some(IndexedExpression::index_query_with_recheck(
299 column.to_string(),
300 self.index_name.clone(),
301 Arc::new(SargableQuery::IsNull()),
302 self.needs_recheck,
303 ))
304 }
305
306 fn visit_comparison(
307 &self,
308 column: &str,
309 value: &ScalarValue,
310 op: &Operator,
311 ) -> Option<IndexedExpression> {
312 if value.is_null() {
313 return None;
314 }
315 let query = match op {
316 Operator::Lt => SargableQuery::Range(Bound::Unbounded, Bound::Excluded(value.clone())),
317 Operator::LtEq => {
318 SargableQuery::Range(Bound::Unbounded, Bound::Included(value.clone()))
319 }
320 Operator::Gt => SargableQuery::Range(Bound::Excluded(value.clone()), Bound::Unbounded),
321 Operator::GtEq => {
322 SargableQuery::Range(Bound::Included(value.clone()), Bound::Unbounded)
323 }
324 Operator::Eq => SargableQuery::Equals(value.clone()),
325 Operator::NotEq => SargableQuery::Equals(value.clone()),
327 _ => unreachable!(),
328 };
329 Some(IndexedExpression::index_query_with_recheck(
330 column.to_string(),
331 self.index_name.clone(),
332 Arc::new(query),
333 self.needs_recheck,
334 ))
335 }
336
337 fn visit_scalar_function(
338 &self,
339 _: &str,
340 _: &DataType,
341 _: &ScalarUDF,
342 _: &[Expr],
343 ) -> Option<IndexedExpression> {
344 None
345 }
346}
347
/// A parser that produces [`BloomFilterQuery`] searches (equality, membership
/// and null tests) against a named index.
#[derive(Debug)]
pub struct BloomFilterQueryParser {
    // Name of the index every generated search targets.
    index_name: String,
    // Copied into each generated search's `needs_recheck` flag.
    needs_recheck: bool,
}
354
355impl BloomFilterQueryParser {
356 pub fn new(index_name: String, needs_recheck: bool) -> Self {
357 Self {
358 index_name,
359 needs_recheck,
360 }
361 }
362}
363
364impl ScalarQueryParser for BloomFilterQueryParser {
365 fn visit_between(
366 &self,
367 _: &str,
368 _: &Bound<ScalarValue>,
369 _: &Bound<ScalarValue>,
370 ) -> Option<IndexedExpression> {
371 None
373 }
374
375 fn visit_in_list(&self, column: &str, in_list: &[ScalarValue]) -> Option<IndexedExpression> {
376 let query = BloomFilterQuery::IsIn(in_list.to_vec());
377 Some(IndexedExpression::index_query_with_recheck(
378 column.to_string(),
379 self.index_name.clone(),
380 Arc::new(query),
381 self.needs_recheck,
382 ))
383 }
384
385 fn visit_is_bool(&self, column: &str, value: bool) -> Option<IndexedExpression> {
386 Some(IndexedExpression::index_query_with_recheck(
387 column.to_string(),
388 self.index_name.clone(),
389 Arc::new(BloomFilterQuery::Equals(ScalarValue::Boolean(Some(value)))),
390 self.needs_recheck,
391 ))
392 }
393
394 fn visit_is_null(&self, column: &str) -> Option<IndexedExpression> {
395 Some(IndexedExpression::index_query_with_recheck(
396 column.to_string(),
397 self.index_name.clone(),
398 Arc::new(BloomFilterQuery::IsNull()),
399 self.needs_recheck,
400 ))
401 }
402
403 fn visit_comparison(
404 &self,
405 column: &str,
406 value: &ScalarValue,
407 op: &Operator,
408 ) -> Option<IndexedExpression> {
409 let query = match op {
410 Operator::Eq => BloomFilterQuery::Equals(value.clone()),
412 Operator::NotEq => BloomFilterQuery::Equals(value.clone()),
414 _ => return None,
416 };
417 Some(IndexedExpression::index_query_with_recheck(
418 column.to_string(),
419 self.index_name.clone(),
420 Arc::new(query),
421 self.needs_recheck,
422 ))
423 }
424
425 fn visit_scalar_function(
426 &self,
427 _: &str,
428 _: &DataType,
429 _: &ScalarUDF,
430 _: &[Expr],
431 ) -> Option<IndexedExpression> {
432 None
434 }
435}
436
/// A parser that produces [`LabelListQuery`] searches for the
/// `array_has_all` / `array_has_any` functions against a named index.
#[derive(Debug)]
pub struct LabelListQueryParser {
    // Name of the index every generated search targets.
    index_name: String,
}
442
443impl LabelListQueryParser {
444 pub fn new(index_name: String) -> Self {
445 Self { index_name }
446 }
447}
448
449impl ScalarQueryParser for LabelListQueryParser {
450 fn visit_between(
451 &self,
452 _: &str,
453 _: &Bound<ScalarValue>,
454 _: &Bound<ScalarValue>,
455 ) -> Option<IndexedExpression> {
456 None
457 }
458
459 fn visit_in_list(&self, _: &str, _: &[ScalarValue]) -> Option<IndexedExpression> {
460 None
461 }
462
463 fn visit_is_bool(&self, _: &str, _: bool) -> Option<IndexedExpression> {
464 None
465 }
466
467 fn visit_is_null(&self, _: &str) -> Option<IndexedExpression> {
468 None
469 }
470
471 fn visit_comparison(
472 &self,
473 _: &str,
474 _: &ScalarValue,
475 _: &Operator,
476 ) -> Option<IndexedExpression> {
477 None
478 }
479
480 fn visit_scalar_function(
481 &self,
482 column: &str,
483 data_type: &DataType,
484 func: &ScalarUDF,
485 args: &[Expr],
486 ) -> Option<IndexedExpression> {
487 if args.len() != 2 {
488 return None;
489 }
490 let label_list = maybe_scalar(&args[1], data_type)?;
491 if let ScalarValue::List(list_arr) = label_list {
492 let list_values = list_arr.values();
493 let mut scalars = Vec::with_capacity(list_values.len());
494 for idx in 0..list_values.len() {
495 scalars.push(ScalarValue::try_from_array(list_values.as_ref(), idx).ok()?);
496 }
497 if func.name() == "array_has_all" {
498 let query = LabelListQuery::HasAllLabels(scalars);
499 Some(IndexedExpression::index_query(
500 column.to_string(),
501 self.index_name.clone(),
502 Arc::new(query),
503 ))
504 } else if func.name() == "array_has_any" {
505 let query = LabelListQuery::HasAnyLabel(scalars);
506 Some(IndexedExpression::index_query(
507 column.to_string(),
508 self.index_name.clone(),
509 Arc::new(query),
510 ))
511 } else {
512 None
513 }
514 } else {
515 None
516 }
517 }
518}
519
/// A parser that produces [`TextQuery`] searches for the `contains` function
/// against a named index.
#[derive(Debug, Clone)]
pub struct TextQueryParser {
    // Name of the index every generated search targets.
    index_name: String,
    // Copied into each generated search's `needs_recheck` flag.
    needs_recheck: bool,
}
526
527impl TextQueryParser {
528 pub fn new(index_name: String, needs_recheck: bool) -> Self {
529 Self {
530 index_name,
531 needs_recheck,
532 }
533 }
534}
535
536impl ScalarQueryParser for TextQueryParser {
537 fn visit_between(
538 &self,
539 _: &str,
540 _: &Bound<ScalarValue>,
541 _: &Bound<ScalarValue>,
542 ) -> Option<IndexedExpression> {
543 None
544 }
545
546 fn visit_in_list(&self, _: &str, _: &[ScalarValue]) -> Option<IndexedExpression> {
547 None
548 }
549
550 fn visit_is_bool(&self, _: &str, _: bool) -> Option<IndexedExpression> {
551 None
552 }
553
554 fn visit_is_null(&self, _: &str) -> Option<IndexedExpression> {
555 None
556 }
557
558 fn visit_comparison(
559 &self,
560 _: &str,
561 _: &ScalarValue,
562 _: &Operator,
563 ) -> Option<IndexedExpression> {
564 None
565 }
566
567 fn visit_scalar_function(
568 &self,
569 column: &str,
570 data_type: &DataType,
571 func: &ScalarUDF,
572 args: &[Expr],
573 ) -> Option<IndexedExpression> {
574 if args.len() != 2 {
575 return None;
576 }
577 let scalar = maybe_scalar(&args[1], data_type)?;
578 match scalar {
579 ScalarValue::Utf8(Some(scalar_str)) | ScalarValue::LargeUtf8(Some(scalar_str)) => {
580 if func.name() == "contains" {
581 let query = TextQuery::StringContains(scalar_str);
582 Some(IndexedExpression::index_query_with_recheck(
583 column.to_string(),
584 self.index_name.clone(),
585 Arc::new(query),
586 self.needs_recheck,
587 ))
588 } else {
589 None
590 }
591 }
592 _ => {
593 None
595 }
596 }
597 }
598}
599
/// A parser that produces [`TokenQuery`] searches for the `contains_tokens`
/// function against a named index.
#[derive(Debug, Clone)]
pub struct FtsQueryParser {
    // Name of the index every generated search targets.
    index_name: String,
}
605
606impl FtsQueryParser {
607 pub fn new(name: String) -> Self {
608 Self { index_name: name }
609 }
610}
611
612impl ScalarQueryParser for FtsQueryParser {
613 fn visit_between(
614 &self,
615 _: &str,
616 _: &Bound<ScalarValue>,
617 _: &Bound<ScalarValue>,
618 ) -> Option<IndexedExpression> {
619 None
620 }
621
622 fn visit_in_list(&self, _: &str, _: &[ScalarValue]) -> Option<IndexedExpression> {
623 None
624 }
625
626 fn visit_is_bool(&self, _: &str, _: bool) -> Option<IndexedExpression> {
627 None
628 }
629
630 fn visit_is_null(&self, _: &str) -> Option<IndexedExpression> {
631 None
632 }
633
634 fn visit_comparison(
635 &self,
636 _: &str,
637 _: &ScalarValue,
638 _: &Operator,
639 ) -> Option<IndexedExpression> {
640 None
641 }
642
643 fn visit_scalar_function(
644 &self,
645 column: &str,
646 data_type: &DataType,
647 func: &ScalarUDF,
648 args: &[Expr],
649 ) -> Option<IndexedExpression> {
650 if args.len() != 2 {
651 return None;
652 }
653 let scalar = maybe_scalar(&args[1], data_type)?;
654 if let ScalarValue::Utf8(Some(scalar_str)) = scalar {
655 if func.name() == "contains_tokens" {
656 let query = TokenQuery::TokensContains(scalar_str);
657 return Some(IndexedExpression::index_query(
658 column.to_string(),
659 self.index_name.clone(),
660 Arc::new(query),
661 ));
662 }
663 }
664 None
665 }
666}
667
668impl IndexedExpression {
669 fn refine_only(refine_expr: Expr) -> Self {
671 Self {
672 scalar_query: None,
673 refine_expr: Some(refine_expr),
674 }
675 }
676
677 fn index_query(column: String, index_name: String, query: Arc<dyn AnyQuery>) -> Self {
679 Self {
680 scalar_query: Some(ScalarIndexExpr::Query(ScalarIndexSearch {
681 column,
682 index_name,
683 query,
684 needs_recheck: false, })),
686 refine_expr: None,
687 }
688 }
689
690 fn index_query_with_recheck(
692 column: String,
693 index_name: String,
694 query: Arc<dyn AnyQuery>,
695 needs_recheck: bool,
696 ) -> Self {
697 Self {
698 scalar_query: Some(ScalarIndexExpr::Query(ScalarIndexSearch {
699 column,
700 index_name,
701 query,
702 needs_recheck,
703 })),
704 refine_expr: None,
705 }
706 }
707
708 fn maybe_not(self) -> Option<Self> {
713 match (self.scalar_query, self.refine_expr) {
714 (Some(_), Some(_)) => None,
715 (Some(scalar_query), None) => {
716 if scalar_query.needs_recheck() {
717 return None;
718 }
719 Some(Self {
720 scalar_query: Some(ScalarIndexExpr::Not(Box::new(scalar_query))),
721 refine_expr: None,
722 })
723 }
724 (None, Some(refine_expr)) => Some(Self {
725 scalar_query: None,
726 refine_expr: Some(Expr::Not(Box::new(refine_expr))),
727 }),
728 (None, None) => panic!("Empty node should not occur"),
729 }
730 }
731
732 fn and(self, other: Self) -> Self {
737 let scalar_query = match (self.scalar_query, other.scalar_query) {
738 (Some(scalar_query), Some(other_scalar_query)) => Some(ScalarIndexExpr::And(
739 Box::new(scalar_query),
740 Box::new(other_scalar_query),
741 )),
742 (Some(scalar_query), None) => Some(scalar_query),
743 (None, Some(scalar_query)) => Some(scalar_query),
744 (None, None) => None,
745 };
746 let refine_expr = match (self.refine_expr, other.refine_expr) {
747 (Some(refine_expr), Some(other_refine_expr)) => {
748 Some(refine_expr.and(other_refine_expr))
749 }
750 (Some(refine_expr), None) => Some(refine_expr),
751 (None, Some(refine_expr)) => Some(refine_expr),
752 (None, None) => None,
753 };
754 Self {
755 scalar_query,
756 refine_expr,
757 }
758 }
759
760 fn maybe_or(self, other: Self) -> Option<Self> {
767 let scalar_query = self.scalar_query?;
770 let other_scalar_query = other.scalar_query?;
771 let scalar_query = Some(ScalarIndexExpr::Or(
772 Box::new(scalar_query),
773 Box::new(other_scalar_query),
774 ));
775
776 let refine_expr = match (self.refine_expr, other.refine_expr) {
777 (Some(_), Some(_)) => {
787 return None;
788 }
789 (Some(_), None) => {
790 return None;
791 }
792 (None, Some(_)) => {
793 return None;
794 }
795 (None, None) => None,
796 };
797 Some(Self {
798 scalar_query,
799 refine_expr,
800 })
801 }
802
803 fn refine(self, expr: Expr) -> Self {
804 match self.refine_expr {
805 Some(refine_expr) => Self {
806 scalar_query: self.scalar_query,
807 refine_expr: Some(refine_expr.and(expr)),
808 },
809 None => Self {
810 scalar_query: self.scalar_query,
811 refine_expr: Some(expr),
812 },
813 }
814 }
815}
816
/// Source of the scalar indices needed to evaluate a [`ScalarIndexExpr`].
#[async_trait]
pub trait ScalarIndexLoader: Send + Sync {
    /// Loads the index called `index_name` for `column`.
    ///
    /// `metrics` is handed through to the implementation; how it is used is
    /// up to the loader.
    async fn load_index(
        &self,
        column: &str,
        index_name: &str,
        metrics: &dyn MetricsCollector,
    ) -> Result<Arc<dyn ScalarIndex>>;
}
830
/// A single search against one scalar index.
#[derive(Debug, Clone)]
pub struct ScalarIndexSearch {
    /// Column the search applies to.
    pub column: String,
    /// Name of the index to load for the search.
    pub index_name: String,
    /// The query handed to the index's `search` method.
    pub query: Arc<dyn AnyQuery>,
    /// Flag reported by `ScalarIndexExpr::needs_recheck`; an expression whose
    /// search sets it cannot be negated by `IndexedExpression::maybe_not`.
    pub needs_recheck: bool,
}
843
844impl PartialEq for ScalarIndexSearch {
845 fn eq(&self, other: &Self) -> bool {
846 self.column == other.column
847 && self.index_name == other.index_name
848 && self.query.as_ref().eq(other.query.as_ref())
849 }
850}
851
/// A boolean tree of scalar index searches.
#[derive(Debug, Clone)]
pub enum ScalarIndexExpr {
    /// Logical negation of the inner expression.
    Not(Box<ScalarIndexExpr>),
    /// Logical conjunction of the two inner expressions.
    And(Box<ScalarIndexExpr>, Box<ScalarIndexExpr>),
    /// Logical disjunction of the two inner expressions.
    Or(Box<ScalarIndexExpr>, Box<ScalarIndexExpr>),
    /// Leaf node: one search against one index.
    Query(ScalarIndexSearch),
}
863
864impl PartialEq for ScalarIndexExpr {
865 fn eq(&self, other: &Self) -> bool {
866 match (self, other) {
867 (Self::Not(l0), Self::Not(r0)) => l0 == r0,
868 (Self::And(l0, l1), Self::And(r0, r1)) => l0 == r0 && l1 == r1,
869 (Self::Or(l0, l1), Self::Or(r0, r1)) => l0 == r0 && l1 == r1,
870 (Self::Query(l_search), Self::Query(r_search)) => l_search == r_search,
871 _ => false,
872 }
873 }
874}
875
876impl std::fmt::Display for ScalarIndexExpr {
877 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
878 match self {
879 Self::Not(inner) => write!(f, "NOT({})", inner),
880 Self::And(lhs, rhs) => write!(f, "AND({},{})", lhs, rhs),
881 Self::Or(lhs, rhs) => write!(f, "OR({},{})", lhs, rhs),
882 Self::Query(search) => write!(
883 f,
884 "[{}]@{}",
885 search.query.format(&search.column),
886 search.index_name
887 ),
888 }
889 }
890}
891
892pub static INDEX_EXPR_RESULT_SCHEMA: LazyLock<SchemaRef> = LazyLock::new(|| {
898 Arc::new(Schema::new(vec![
899 Field::new("result".to_string(), DataType::Binary, true),
900 Field::new("discriminant".to_string(), DataType::UInt32, true),
901 Field::new("fragments_covered".to_string(), DataType::Binary, true),
902 ]))
903});
904
/// A row-id mask produced by evaluating a [`ScalarIndexExpr`], tagged with
/// how precise that mask is.
#[derive(Debug)]
pub enum IndexExprResult {
    /// Serialized with discriminant 0.
    /// NOTE(review): presumably the mask is exactly the set of matching rows — confirm.
    Exact(RowIdMask),
    /// Serialized with discriminant 1.
    /// NOTE(review): presumably the mask is a superset of the matching rows — confirm.
    AtMost(RowIdMask),
    /// Serialized with discriminant 2.
    /// NOTE(review): presumably the mask is a subset of the matching rows — confirm.
    AtLeast(RowIdMask),
}
918
919impl IndexExprResult {
920 pub fn row_id_mask(&self) -> &RowIdMask {
921 match self {
922 Self::Exact(mask) => mask,
923 Self::AtMost(mask) => mask,
924 Self::AtLeast(mask) => mask,
925 }
926 }
927
928 pub fn discriminant(&self) -> u32 {
929 match self {
930 Self::Exact(_) => 0,
931 Self::AtMost(_) => 1,
932 Self::AtLeast(_) => 2,
933 }
934 }
935
936 pub fn from_parts(mask: RowIdMask, discriminant: u32) -> Result<Self> {
937 match discriminant {
938 0 => Ok(Self::Exact(mask)),
939 1 => Ok(Self::AtMost(mask)),
940 2 => Ok(Self::AtLeast(mask)),
941 _ => Err(Error::InvalidInput {
942 source: format!("Invalid IndexExprResult discriminant: {}", discriminant).into(),
943 location: location!(),
944 }),
945 }
946 }
947
948 #[instrument(skip_all)]
949 pub fn serialize_to_arrow(
950 &self,
951 fragments_covered_by_result: &RoaringBitmap,
952 ) -> Result<RecordBatch> {
953 let row_id_mask = self.row_id_mask();
954 let row_id_mask_arr = row_id_mask.into_arrow()?;
955 let discriminant = self.discriminant();
956 let discriminant_arr =
957 Arc::new(UInt32Array::from(vec![discriminant, discriminant])) as Arc<dyn Array>;
958 let mut fragments_covered_builder = BinaryBuilder::new();
959 let fragments_covered_bytes_len = fragments_covered_by_result.serialized_size();
960 let mut fragments_covered_bytes = Vec::with_capacity(fragments_covered_bytes_len);
961 fragments_covered_by_result.serialize_into(&mut fragments_covered_bytes)?;
962 fragments_covered_builder.append_value(fragments_covered_bytes);
963 fragments_covered_builder.append_null();
964 let fragments_covered_arr = Arc::new(fragments_covered_builder.finish()) as Arc<dyn Array>;
965 Ok(RecordBatch::try_new(
966 INDEX_EXPR_RESULT_SCHEMA.clone(),
967 vec![
968 Arc::new(row_id_mask_arr),
969 Arc::new(discriminant_arr),
970 Arc::new(fragments_covered_arr),
971 ],
972 )?)
973 }
974}
975
impl ScalarIndexExpr {
    /// Evaluates the expression tree against the indices provided by
    /// `index_loader`, returning a row-id mask tagged with its precision.
    ///
    /// Leaves load their index and run the search; `Not`, `And` and `Or`
    /// combine the masks of their children. The two children of `And` / `Or`
    /// are polled together via `join!`.
    ///
    /// # Errors
    ///
    /// Propagates any error from loading an index or running a search.
    #[async_recursion]
    #[instrument(level = "debug", skip_all)]
    pub async fn evaluate(
        &self,
        index_loader: &dyn ScalarIndexLoader,
        metrics: &dyn MetricsCollector,
    ) -> Result<IndexExprResult> {
        match self {
            Self::Not(inner) => {
                let result = inner.evaluate(index_loader, metrics).await?;
                // Negating flips the mask; AtMost and AtLeast swap roles.
                match result {
                    IndexExprResult::Exact(mask) => Ok(IndexExprResult::Exact(!mask)),
                    IndexExprResult::AtMost(mask) => Ok(IndexExprResult::AtLeast(!mask)),
                    IndexExprResult::AtLeast(mask) => Ok(IndexExprResult::AtMost(!mask)),
                }
            }
            Self::And(lhs, rhs) => {
                let lhs_result = lhs.evaluate(index_loader, metrics);
                let rhs_result = rhs.evaluate(index_loader, metrics);
                let (lhs_result, rhs_result) = join!(lhs_result, rhs_result);
                match (lhs_result?, rhs_result?) {
                    (IndexExprResult::Exact(lhs), IndexExprResult::Exact(rhs)) => {
                        Ok(IndexExprResult::Exact(lhs & rhs))
                    }
                    (IndexExprResult::Exact(lhs), IndexExprResult::AtMost(rhs))
                    | (IndexExprResult::AtMost(lhs), IndexExprResult::Exact(rhs)) => {
                        Ok(IndexExprResult::AtMost(lhs & rhs))
                    }
                    // When an AtLeast side meets an Exact or AtMost side, the
                    // AtLeast mask is discarded and the other side's mask is
                    // reported as AtMost.
                    (IndexExprResult::Exact(lhs), IndexExprResult::AtLeast(_)) => {
                        Ok(IndexExprResult::AtMost(lhs))
                    }
                    (IndexExprResult::AtLeast(_), IndexExprResult::Exact(rhs)) => {
                        Ok(IndexExprResult::AtMost(rhs))
                    }
                    (IndexExprResult::AtMost(lhs), IndexExprResult::AtMost(rhs)) => {
                        Ok(IndexExprResult::AtMost(lhs & rhs))
                    }
                    (IndexExprResult::AtLeast(lhs), IndexExprResult::AtLeast(rhs)) => {
                        Ok(IndexExprResult::AtLeast(lhs & rhs))
                    }
                    (IndexExprResult::AtLeast(_), IndexExprResult::AtMost(rhs)) => {
                        Ok(IndexExprResult::AtMost(rhs))
                    }
                    (IndexExprResult::AtMost(lhs), IndexExprResult::AtLeast(_)) => {
                        Ok(IndexExprResult::AtMost(lhs))
                    }
                }
            }
            Self::Or(lhs, rhs) => {
                let lhs_result = lhs.evaluate(index_loader, metrics);
                let rhs_result = rhs.evaluate(index_loader, metrics);
                let (lhs_result, rhs_result) = join!(lhs_result, rhs_result);
                match (lhs_result?, rhs_result?) {
                    (IndexExprResult::Exact(lhs), IndexExprResult::Exact(rhs)) => {
                        Ok(IndexExprResult::Exact(lhs | rhs))
                    }
                    (IndexExprResult::Exact(lhs), IndexExprResult::AtMost(rhs))
                    | (IndexExprResult::AtMost(lhs), IndexExprResult::Exact(rhs)) => {
                        Ok(IndexExprResult::AtMost(lhs | rhs))
                    }
                    (IndexExprResult::Exact(lhs), IndexExprResult::AtLeast(rhs)) => {
                        Ok(IndexExprResult::AtLeast(lhs | rhs))
                    }
                    (IndexExprResult::AtLeast(lhs), IndexExprResult::Exact(rhs)) => {
                        Ok(IndexExprResult::AtLeast(lhs | rhs))
                    }
                    (IndexExprResult::AtMost(lhs), IndexExprResult::AtMost(rhs)) => {
                        Ok(IndexExprResult::AtMost(lhs | rhs))
                    }
                    (IndexExprResult::AtLeast(lhs), IndexExprResult::AtLeast(rhs)) => {
                        Ok(IndexExprResult::AtLeast(lhs | rhs))
                    }
                    // When an AtLeast side meets an AtMost side, the AtMost
                    // mask is discarded and the AtLeast mask is kept.
                    (IndexExprResult::AtLeast(lhs), IndexExprResult::AtMost(_)) => {
                        Ok(IndexExprResult::AtLeast(lhs))
                    }
                    (IndexExprResult::AtMost(_), IndexExprResult::AtLeast(rhs)) => {
                        Ok(IndexExprResult::AtLeast(rhs))
                    }
                }
            }
            Self::Query(search) => {
                let index = index_loader
                    .load_index(&search.column, &search.index_name, metrics)
                    .await?;
                let search_result = index.search(search.query.as_ref(), metrics).await?;
                // The returned row ids become the allow list of a mask with no
                // block list; the precision tag is carried over unchanged.
                match search_result {
                    SearchResult::Exact(matching_row_ids) => {
                        Ok(IndexExprResult::Exact(RowIdMask {
                            block_list: None,
                            allow_list: Some(matching_row_ids),
                        }))
                    }
                    SearchResult::AtMost(row_ids) => Ok(IndexExprResult::AtMost(RowIdMask {
                        block_list: None,
                        allow_list: Some(row_ids),
                    })),
                    SearchResult::AtLeast(row_ids) => Ok(IndexExprResult::AtLeast(RowIdMask {
                        block_list: None,
                        allow_list: Some(row_ids),
                    })),
                }
            }
        }
    }

    /// Converts the tree back into a plain expression: `Not` / `And` / `Or`
    /// map to the corresponding expression combinators and each leaf is
    /// rendered by its query's `to_expr` for the leaf's column.
    pub fn to_expr(&self) -> Expr {
        match self {
            Self::Not(inner) => Expr::Not(inner.to_expr().into()),
            Self::And(lhs, rhs) => {
                let lhs = lhs.to_expr();
                let rhs = rhs.to_expr();
                lhs.and(rhs)
            }
            Self::Or(lhs, rhs) => {
                let lhs = lhs.to_expr();
                let rhs = rhs.to_expr();
                lhs.or(rhs)
            }
            Self::Query(search) => search.query.to_expr(search.column.clone()),
        }
    }

    /// Returns `true` when any search in the tree has `needs_recheck` set.
    pub fn needs_recheck(&self) -> bool {
        match self {
            Self::Not(inner) => inner.needs_recheck(),
            Self::And(lhs, rhs) | Self::Or(lhs, rhs) => lhs.needs_recheck() || rhs.needs_recheck(),
            Self::Query(search) => search.needs_recheck,
        }
    }
}
1120
1121fn maybe_column(expr: &Expr) -> Option<&str> {
1123 match expr {
1124 Expr::Column(col) => Some(&col.name),
1125 _ => None,
1126 }
1127}
1128
1129fn extract_nested_column_path(expr: &Expr) -> Option<String> {
1132 let mut current_expr = expr;
1133 let mut parts = Vec::new();
1134
1135 loop {
1137 match current_expr {
1138 Expr::ScalarFunction(udf) if udf.name() == "get_field" => {
1139 if udf.args.len() != 2 {
1140 return None;
1141 }
1142 if let Expr::Literal(ScalarValue::Utf8(Some(field_name)), _) = &udf.args[1] {
1145 parts.push(field_name.clone());
1146 } else {
1147 return None;
1148 }
1149 current_expr = &udf.args[0];
1151 }
1152 Expr::Column(col) => {
1153 parts.push(col.name.clone());
1155 break;
1156 }
1157 _ => {
1158 return None;
1159 }
1160 }
1161 }
1162
1163 parts.reverse();
1165
1166 let field_refs: Vec<&str> = parts.iter().map(|s| s.as_str()).collect();
1168 Some(lance_core::datatypes::format_field_path(&field_refs))
1169}
1170
1171fn maybe_indexed_column<'b>(
1178 expr: &Expr,
1179 index_info: &'b dyn IndexInformationProvider,
1180) -> Option<(String, DataType, &'b dyn ScalarQueryParser)> {
1181 if let Some(nested_path) = extract_nested_column_path(expr) {
1183 if let Some((data_type, parser)) = index_info.get_index(&nested_path) {
1184 if let Some(data_type) = parser.is_valid_reference(expr, data_type) {
1185 return Some((nested_path, data_type, parser));
1186 }
1187 }
1188 }
1189
1190 match expr {
1191 Expr::Column(col) => {
1192 let col = col.name.as_str();
1193 let (data_type, parser) = index_info.get_index(col)?;
1194 if let Some(data_type) = parser.is_valid_reference(expr, data_type) {
1195 Some((col.to_string(), data_type, parser))
1196 } else {
1197 None
1198 }
1199 }
1200 Expr::ScalarFunction(udf) => {
1201 if udf.args.is_empty() {
1202 return None;
1203 }
1204 let col = maybe_column(&udf.args[0])?;
1206 let (data_type, parser) = index_info.get_index(col)?;
1207 if let Some(data_type) = parser.is_valid_reference(expr, data_type) {
1208 Some((col.to_string(), data_type, parser))
1209 } else {
1210 None
1211 }
1212 }
1213 _ => None,
1214 }
1215}
1216
1217fn maybe_scalar(expr: &Expr, expected_type: &DataType) -> Option<ScalarValue> {
1219 match expr {
1220 Expr::Literal(value, _) => safe_coerce_scalar(value, expected_type),
1221 Expr::Cast(cast) => match cast.expr.as_ref() {
1229 Expr::Literal(value, _) => {
1230 let casted = value.cast_to(&cast.data_type).ok()?;
1231 safe_coerce_scalar(&casted, expected_type)
1232 }
1233 _ => None,
1234 },
1235 Expr::ScalarFunction(scalar_function) => {
1236 if scalar_function.name() == "arrow_cast" {
1237 if scalar_function.args.len() != 2 {
1238 return None;
1239 }
1240 match (&scalar_function.args[0], &scalar_function.args[1]) {
1241 (Expr::Literal(value, _), Expr::Literal(cast_type, _)) => {
1242 let target_type = scalar_function
1243 .func
1244 .return_field_from_args(ReturnFieldArgs {
1245 arg_fields: &[
1246 Arc::new(Field::new("expression", value.data_type(), false)),
1247 Arc::new(Field::new("datatype", cast_type.data_type(), false)),
1248 ],
1249 scalar_arguments: &[Some(value), Some(cast_type)],
1250 })
1251 .ok()?;
1252 let casted = value.cast_to(target_type.data_type()).ok()?;
1253 safe_coerce_scalar(&casted, expected_type)
1254 }
1255 _ => None,
1256 }
1257 } else {
1258 None
1259 }
1260 }
1261 _ => None,
1262 }
1263}
1264
1265fn maybe_scalar_list(exprs: &Vec<Expr>, expected_type: &DataType) -> Option<Vec<ScalarValue>> {
1267 let mut scalar_values = Vec::with_capacity(exprs.len());
1268 for expr in exprs {
1269 match maybe_scalar(expr, expected_type) {
1270 Some(scalar_val) => {
1271 scalar_values.push(scalar_val);
1272 }
1273 None => {
1274 return None;
1275 }
1276 }
1277 }
1278 Some(scalar_values)
1279}
1280
1281fn visit_between(
1282 between: &Between,
1283 index_info: &dyn IndexInformationProvider,
1284) -> Option<IndexedExpression> {
1285 let (column, col_type, query_parser) = maybe_indexed_column(&between.expr, index_info)?;
1286 let low = maybe_scalar(&between.low, &col_type)?;
1287 let high = maybe_scalar(&between.high, &col_type)?;
1288
1289 let indexed_expr =
1290 query_parser.visit_between(&column, &Bound::Included(low), &Bound::Included(high))?;
1291
1292 if between.negated {
1293 indexed_expr.maybe_not()
1294 } else {
1295 Some(indexed_expr)
1296 }
1297}
1298
1299fn visit_in_list(
1300 in_list: &InList,
1301 index_info: &dyn IndexInformationProvider,
1302) -> Option<IndexedExpression> {
1303 let (column, col_type, query_parser) = maybe_indexed_column(&in_list.expr, index_info)?;
1304 let values = maybe_scalar_list(&in_list.list, &col_type)?;
1305
1306 let indexed_expr = query_parser.visit_in_list(&column, &values)?;
1307
1308 if in_list.negated {
1309 indexed_expr.maybe_not()
1310 } else {
1311 Some(indexed_expr)
1312 }
1313}
1314
1315fn visit_is_bool(
1316 expr: &Expr,
1317 index_info: &dyn IndexInformationProvider,
1318 value: bool,
1319) -> Option<IndexedExpression> {
1320 let (column, col_type, query_parser) = maybe_indexed_column(expr, index_info)?;
1321 if col_type != DataType::Boolean {
1322 None
1323 } else {
1324 query_parser.visit_is_bool(&column, value)
1325 }
1326}
1327
1328fn visit_column(
1330 col: &Expr,
1331 index_info: &dyn IndexInformationProvider,
1332) -> Option<IndexedExpression> {
1333 let (column, col_type, query_parser) = maybe_indexed_column(col, index_info)?;
1334 if col_type != DataType::Boolean {
1335 None
1336 } else {
1337 query_parser.visit_is_bool(&column, true)
1338 }
1339}
1340
1341fn visit_is_null(
1342 expr: &Expr,
1343 index_info: &dyn IndexInformationProvider,
1344 negated: bool,
1345) -> Option<IndexedExpression> {
1346 let (column, _, query_parser) = maybe_indexed_column(expr, index_info)?;
1347 let indexed_expr = query_parser.visit_is_null(&column)?;
1348 if negated {
1349 indexed_expr.maybe_not()
1350 } else {
1351 Some(indexed_expr)
1352 }
1353}
1354
1355fn visit_not(
1356 expr: &Expr,
1357 index_info: &dyn IndexInformationProvider,
1358 depth: usize,
1359) -> Result<Option<IndexedExpression>> {
1360 let node = visit_node(expr, index_info, depth + 1)?;
1361 Ok(node.and_then(|node| node.maybe_not()))
1362}
1363
1364fn visit_comparison(
1365 expr: &BinaryExpr,
1366 index_info: &dyn IndexInformationProvider,
1367) -> Option<IndexedExpression> {
1368 let left_col = maybe_indexed_column(&expr.left, index_info);
1369 if let Some((column, col_type, query_parser)) = left_col {
1370 let scalar = maybe_scalar(&expr.right, &col_type)?;
1371 query_parser.visit_comparison(&column, &scalar, &expr.op)
1372 } else {
1373 None
1376 }
1377}
1378
1379fn maybe_range(
1380 expr: &BinaryExpr,
1381 index_info: &dyn IndexInformationProvider,
1382) -> Option<IndexedExpression> {
1383 let left_expr = match expr.left.as_ref() {
1384 Expr::BinaryExpr(binary_expr) => Some(binary_expr),
1385 _ => None,
1386 }?;
1387 let right_expr = match expr.right.as_ref() {
1388 Expr::BinaryExpr(binary_expr) => Some(binary_expr),
1389 _ => None,
1390 }?;
1391
1392 let (left_col, dt, parser) = maybe_indexed_column(&left_expr.left, index_info)?;
1393 let right_col = maybe_column(&right_expr.left)?;
1394
1395 if left_col != right_col {
1396 return None;
1397 }
1398
1399 let left_value = maybe_scalar(&left_expr.right, &dt)?;
1400 let right_value = maybe_scalar(&right_expr.right, &dt)?;
1401
1402 let (low, high) = match (left_expr.op, right_expr.op) {
1403 (Operator::GtEq, Operator::LtEq) => {
1405 (Bound::Included(left_value), Bound::Included(right_value))
1406 }
1407 (Operator::GtEq, Operator::Lt) => {
1409 (Bound::Included(left_value), Bound::Excluded(right_value))
1410 }
1411 (Operator::Gt, Operator::LtEq) => {
1413 (Bound::Excluded(left_value), Bound::Included(right_value))
1414 }
1415 (Operator::Gt, Operator::Lt) => (Bound::Excluded(left_value), Bound::Excluded(right_value)),
1417 (Operator::LtEq, Operator::GtEq) => {
1419 (Bound::Included(right_value), Bound::Included(left_value))
1420 }
1421 (Operator::LtEq, Operator::Gt) => {
1423 (Bound::Included(right_value), Bound::Excluded(left_value))
1424 }
1425 (Operator::Lt, Operator::GtEq) => {
1427 (Bound::Excluded(right_value), Bound::Included(left_value))
1428 }
1429 (Operator::Lt, Operator::Gt) => (Bound::Excluded(right_value), Bound::Excluded(left_value)),
1431 _ => return None,
1432 };
1433
1434 parser.visit_between(&left_col, &low, &high)
1435}
1436
1437fn visit_and(
1438 expr: &BinaryExpr,
1439 index_info: &dyn IndexInformationProvider,
1440 depth: usize,
1441) -> Result<Option<IndexedExpression>> {
1442 if let Some(range_expr) = maybe_range(expr, index_info) {
1450 return Ok(Some(range_expr));
1451 }
1452
1453 let left = visit_node(&expr.left, index_info, depth + 1)?;
1454 let right = visit_node(&expr.right, index_info, depth + 1)?;
1455 Ok(match (left, right) {
1456 (Some(left), Some(right)) => Some(left.and(right)),
1457 (Some(left), None) => Some(left.refine((*expr.right).clone())),
1458 (None, Some(right)) => Some(right.refine((*expr.left).clone())),
1459 (None, None) => None,
1460 })
1461}
1462
1463fn visit_or(
1464 expr: &BinaryExpr,
1465 index_info: &dyn IndexInformationProvider,
1466 depth: usize,
1467) -> Result<Option<IndexedExpression>> {
1468 let left = visit_node(&expr.left, index_info, depth + 1)?;
1469 let right = visit_node(&expr.right, index_info, depth + 1)?;
1470 Ok(match (left, right) {
1471 (Some(left), Some(right)) => left.maybe_or(right),
1472 (Some(_), None) => None,
1478 (None, Some(_)) => None,
1479 (None, None) => None,
1480 })
1481}
1482
1483fn visit_binary_expr(
1484 expr: &BinaryExpr,
1485 index_info: &dyn IndexInformationProvider,
1486 depth: usize,
1487) -> Result<Option<IndexedExpression>> {
1488 match &expr.op {
1489 Operator::Lt | Operator::LtEq | Operator::Gt | Operator::GtEq | Operator::Eq => {
1490 Ok(visit_comparison(expr, index_info))
1491 }
1492 Operator::NotEq => Ok(visit_comparison(expr, index_info).and_then(|node| node.maybe_not())),
1494 Operator::And => visit_and(expr, index_info, depth),
1495 Operator::Or => visit_or(expr, index_info, depth),
1496 _ => Ok(None),
1497 }
1498}
1499
1500fn visit_scalar_fn(
1501 scalar_fn: &ScalarFunction,
1502 index_info: &dyn IndexInformationProvider,
1503) -> Option<IndexedExpression> {
1504 if scalar_fn.args.is_empty() {
1505 return None;
1506 }
1507 let (col, data_type, query_parser) = maybe_indexed_column(&scalar_fn.args[0], index_info)?;
1508 query_parser.visit_scalar_function(&col, &data_type, &scalar_fn.func, &scalar_fn.args)
1509}
1510
1511fn visit_node(
1512 expr: &Expr,
1513 index_info: &dyn IndexInformationProvider,
1514 depth: usize,
1515) -> Result<Option<IndexedExpression>> {
1516 if depth >= MAX_DEPTH {
1517 return Err(Error::invalid_input(
1518 format!(
1519 "the filter expression is too long, lance limit the max number of conditions to {}",
1520 MAX_DEPTH
1521 ),
1522 location!(),
1523 ));
1524 }
1525 match expr {
1526 Expr::Between(between) => Ok(visit_between(between, index_info)),
1527 Expr::Column(_) => Ok(visit_column(expr, index_info)),
1528 Expr::InList(in_list) => Ok(visit_in_list(in_list, index_info)),
1529 Expr::IsFalse(expr) => Ok(visit_is_bool(expr.as_ref(), index_info, false)),
1530 Expr::IsTrue(expr) => Ok(visit_is_bool(expr.as_ref(), index_info, true)),
1531 Expr::IsNull(expr) => Ok(visit_is_null(expr.as_ref(), index_info, false)),
1532 Expr::IsNotNull(expr) => Ok(visit_is_null(expr.as_ref(), index_info, true)),
1533 Expr::Not(expr) => visit_not(expr.as_ref(), index_info, depth),
1534 Expr::BinaryExpr(binary_expr) => visit_binary_expr(binary_expr, index_info, depth),
1535 Expr::ScalarFunction(scalar_fn) => Ok(visit_scalar_fn(scalar_fn, index_info)),
1536 _ => Ok(None),
1537 }
1538}
1539
/// Source of per-column index information consulted while visiting a filter
/// expression.
pub trait IndexInformationProvider {
    /// Looks up `col` and, when an entry exists, returns the column's data type
    /// together with the [`ScalarQueryParser`] to use for it.
    ///
    /// Returns `None` when there is nothing registered for `col`.
    fn get_index(&self, col: &str) -> Option<(&DataType, &dyn ScalarQueryParser)>;
}
1546
1547pub fn apply_scalar_indices(
1550 expr: Expr,
1551 index_info: &dyn IndexInformationProvider,
1552) -> Result<IndexedExpression> {
1553 Ok(visit_node(&expr, index_info, 0)?.unwrap_or(IndexedExpression::refine_only(expr)))
1554}
1555
/// The result of planning a filter: which part is answered by scalar indices
/// and which part still has to be evaluated as an ordinary expression.
#[derive(Clone, Default, Debug)]
pub struct FilterPlan {
    /// The portion of the filter that scalar indices can answer, if any.
    pub index_query: Option<ScalarIndexExpr>,
    /// Set from `!needs_recheck()` of the index query when a plan is created
    /// with scalar indices; `true` in the refine-only and empty constructors.
    pub skip_recheck: bool,
    /// The portion of the filter that must be evaluated as an expression, if any.
    pub refine_expr: Option<Expr>,
    /// The complete filter expression the plan was built from, if any.
    pub full_expr: Option<Expr>,
}
1564
1565impl FilterPlan {
1566 pub fn empty() -> Self {
1567 Self {
1568 index_query: None,
1569 skip_recheck: true,
1570 refine_expr: None,
1571 full_expr: None,
1572 }
1573 }
1574
1575 pub fn new_refine_only(expr: Expr) -> Self {
1576 Self {
1577 index_query: None,
1578 skip_recheck: true,
1579 refine_expr: Some(expr.clone()),
1580 full_expr: Some(expr),
1581 }
1582 }
1583
1584 pub fn is_empty(&self) -> bool {
1585 self.refine_expr.is_none() && self.index_query.is_none()
1586 }
1587
1588 pub fn all_columns(&self) -> Vec<String> {
1589 self.full_expr
1590 .as_ref()
1591 .map(Planner::column_names_in_expr)
1592 .unwrap_or_default()
1593 }
1594
1595 pub fn refine_columns(&self) -> Vec<String> {
1596 self.refine_expr
1597 .as_ref()
1598 .map(Planner::column_names_in_expr)
1599 .unwrap_or_default()
1600 }
1601
1602 pub fn has_refine(&self) -> bool {
1604 self.refine_expr.is_some()
1605 }
1606
1607 pub fn has_index_query(&self) -> bool {
1609 self.index_query.is_some()
1610 }
1611
1612 pub fn has_any_filter(&self) -> bool {
1613 self.refine_expr.is_some() || self.index_query.is_some()
1614 }
1615
1616 pub fn make_refine_only(&mut self) {
1617 self.index_query = None;
1618 self.refine_expr = self.full_expr.clone();
1619 }
1620
1621 pub fn is_exact_index_search(&self) -> bool {
1623 self.index_query.is_some() && self.refine_expr.is_none() && self.skip_recheck
1624 }
1625}
1626
/// Extension trait that adds scalar-index-aware filter planning.
pub trait PlannerIndexExt {
    /// Builds a [`FilterPlan`] for `filter`.
    ///
    /// * `filter` - the filter expression to plan.
    /// * `index_info` - supplies the index information for referenced columns.
    /// * `use_scalar_index` - whether scalar indices may be used for the plan.
    ///
    /// Returns the resulting plan, or an error if planning fails.
    fn create_filter_plan(
        &self,
        filter: Expr,
        index_info: &dyn IndexInformationProvider,
        use_scalar_index: bool,
    ) -> Result<FilterPlan>;
}
1641
1642impl PlannerIndexExt for Planner {
1643 fn create_filter_plan(
1644 &self,
1645 filter: Expr,
1646 index_info: &dyn IndexInformationProvider,
1647 use_scalar_index: bool,
1648 ) -> Result<FilterPlan> {
1649 let logical_expr = self.optimize_expr(filter)?;
1650 if use_scalar_index {
1651 let indexed_expr = apply_scalar_indices(logical_expr.clone(), index_info)?;
1652 let mut skip_recheck = false;
1653 if let Some(scalar_query) = indexed_expr.scalar_query.as_ref() {
1654 skip_recheck = !scalar_query.needs_recheck();
1655 }
1656 Ok(FilterPlan {
1657 index_query: indexed_expr.scalar_query,
1658 refine_expr: indexed_expr.refine_expr,
1659 full_expr: Some(logical_expr),
1660 skip_recheck,
1661 })
1662 } else {
1663 Ok(FilterPlan {
1664 index_query: None,
1665 skip_recheck: true,
1666 refine_expr: Some(logical_expr.clone()),
1667 full_expr: Some(logical_expr),
1668 })
1669 }
1670 }
1671}
1672
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use arrow_schema::{Field, Schema};
    use chrono::Utc;
    use datafusion_common::{Column, DFSchema};
    use datafusion_expr::execution_props::ExecutionProps;
    use datafusion_expr::simplify::SimplifyContext;
    use lance_datafusion::exec::{get_session_context, LanceExecutionOptions};

    use crate::scalar::json::{JsonQuery, JsonQueryParser};

    use super::*;

    /// What the mock provider knows about one indexed column: its data type
    /// and the parser used to build queries against it.
    struct ColInfo {
        data_type: DataType,
        parser: Box<dyn ScalarQueryParser>,
    }

    impl ColInfo {
        fn new(data_type: DataType, parser: Box<dyn ScalarQueryParser>) -> Self {
            Self { data_type, parser }
        }
    }

    /// An in-memory `IndexInformationProvider` backed by a column-name map.
    struct MockIndexInfoProvider {
        indexed_columns: HashMap<String, ColInfo>,
    }

    impl MockIndexInfoProvider {
        /// Builds the provider from `(column name, column info)` pairs.
        fn new(indexed_columns: Vec<(&str, ColInfo)>) -> Self {
            Self {
                indexed_columns: HashMap::from_iter(
                    indexed_columns
                        .into_iter()
                        .map(|(s, ty)| (s.to_string(), ty)),
                ),
            }
        }
    }

    impl IndexInformationProvider for MockIndexInfoProvider {
        // Columns missing from the map are reported as having no index.
        fn get_index(&self, col: &str) -> Option<(&DataType, &dyn ScalarQueryParser)> {
            self.indexed_columns
                .get(col)
                .map(|col_info| (&col_info.data_type, col_info.parser.as_ref()))
        }
    }

    /// Parses the SQL predicate `expr` against a fixed test schema, optionally
    /// runs it through the expression simplifier (`optimize`), applies the
    /// scalar indices and compares the outcome with `expected`.
    ///
    /// With `expected == None` the result must contain no index query and its
    /// refine expression must equal the parsed (and possibly simplified)
    /// expression.
    fn check(
        index_info: &dyn IndexInformationProvider,
        expr: &str,
        expected: Option<IndexedExpression>,
        optimize: bool,
    ) {
        let schema = Schema::new(vec![
            Field::new("color", DataType::Utf8, false),
            Field::new("size", DataType::Float32, false),
            Field::new("aisle", DataType::UInt32, false),
            Field::new("on_sale", DataType::Boolean, false),
            Field::new("price", DataType::Float32, false),
            Field::new("json", DataType::LargeBinary, false),
        ]);
        let df_schema: DFSchema = schema.try_into().unwrap();

        let ctx = get_session_context(&LanceExecutionOptions::default());
        let state = ctx.state();
        let mut expr = state.create_logical_expr(expr, &df_schema).unwrap();
        if optimize {
            let props = ExecutionProps::new().with_query_execution_start_time(Utc::now());
            let simplify_context = SimplifyContext::new(&props).with_schema(Arc::new(df_schema));
            let simplifier =
                datafusion::optimizer::simplify_expressions::ExprSimplifier::new(simplify_context);
            expr = simplifier.simplify(expr).unwrap();
        }

        let actual = apply_scalar_indices(expr.clone(), index_info).unwrap();
        if let Some(expected) = expected {
            assert_eq!(actual, expected);
        } else {
            assert!(actual.scalar_query.is_none());
            assert_eq!(actual.refine_expr.unwrap(), expr);
        }
    }

    /// Asserts that `expr` (unsimplified) cannot use any index.
    fn check_no_index(index_info: &dyn IndexInformationProvider, expr: &str) {
        check(index_info, expr, None, false)
    }

    /// Asserts that `expr` (unsimplified) becomes exactly one index query
    /// `query` on column `col`, against the index named `{col}_idx`.
    fn check_simple(
        index_info: &dyn IndexInformationProvider,
        expr: &str,
        col: &str,
        query: impl AnyQuery,
    ) {
        check(
            index_info,
            expr,
            Some(IndexedExpression::index_query(
                col.to_string(),
                format!("{}_idx", col),
                Arc::new(query),
            )),
            false,
        )
    }

    /// Same as `check_simple`, but the expression is simplified first
    /// (`optimize = true`).
    fn check_range(
        index_info: &dyn IndexInformationProvider,
        expr: &str,
        col: &str,
        query: SargableQuery,
    ) {
        check(
            index_info,
            expr,
            Some(IndexedExpression::index_query(
                col.to_string(),
                format!("{}_idx", col),
                Arc::new(query),
            )),
            true,
        )
    }

    /// Asserts that `expr` (unsimplified) becomes the negation (`maybe_not`)
    /// of the single index query `query` on column `col`.
    fn check_simple_negated(
        index_info: &dyn IndexInformationProvider,
        expr: &str,
        col: &str,
        query: SargableQuery,
    ) {
        check(
            index_info,
            expr,
            Some(
                IndexedExpression::index_query(
                    col.to_string(),
                    format!("{}_idx", col),
                    Arc::new(query),
                )
                .maybe_not()
                .unwrap(),
            ),
            false,
        )
    }

    #[test]
    fn test_expressions() {
        // Every schema column except `size` is registered with the mock provider.
        let index_info = MockIndexInfoProvider::new(vec![
            (
                "color",
                ColInfo::new(
                    DataType::Utf8,
                    Box::new(SargableQueryParser::new("color_idx".to_string(), false)),
                ),
            ),
            (
                "aisle",
                ColInfo::new(
                    DataType::UInt32,
                    Box::new(SargableQueryParser::new("aisle_idx".to_string(), false)),
                ),
            ),
            (
                "on_sale",
                ColInfo::new(
                    DataType::Boolean,
                    Box::new(SargableQueryParser::new("on_sale_idx".to_string(), false)),
                ),
            ),
            (
                "price",
                ColInfo::new(
                    DataType::Float32,
                    Box::new(SargableQueryParser::new("price_idx".to_string(), false)),
                ),
            ),
            (
                "json",
                ColInfo::new(
                    DataType::LargeBinary,
                    Box::new(JsonQueryParser::new(
                        "$.name".to_string(),
                        Box::new(SargableQueryParser::new("json_idx".to_string(), false)),
                    )),
                ),
            ),
        ]);

        // A comparison on a JSON path is expected to wrap the inner query in a JsonQuery.
        check_simple(
            &index_info,
            "json_extract(json, '$.name') = 'foo'",
            "json",
            JsonQuery::new(
                Arc::new(SargableQuery::Equals(ScalarValue::Utf8(Some(
                    "foo".to_string(),
                )))),
                "$.name".to_string(),
            ),
        );

        // `size` is not registered with the provider, so nothing can be indexed.
        check_no_index(&index_info, "size BETWEEN 5 AND 10");
        // A literal of another integer type is expected to end up as the column's UInt32.
        check_simple(
            &index_info,
            "aisle = arrow_cast(5, 'Int16')",
            "aisle",
            SargableQuery::Equals(ScalarValue::UInt32(Some(5))),
        );
        // Range forms. These run with the simplifier enabled (`check_range`).
        check_range(
            &index_info,
            "aisle BETWEEN 5 AND 10",
            "aisle",
            SargableQuery::Range(
                Bound::Included(ScalarValue::UInt32(Some(5))),
                Bound::Included(ScalarValue::UInt32(Some(10))),
            ),
        );
        check_range(
            &index_info,
            "aisle >= 5 AND aisle <= 10",
            "aisle",
            SargableQuery::Range(
                Bound::Included(ScalarValue::UInt32(Some(5))),
                Bound::Included(ScalarValue::UInt32(Some(10))),
            ),
        );

        // Same range with the two comparisons swapped.
        check_range(
            &index_info,
            "aisle <= 10 AND aisle >= 5",
            "aisle",
            SargableQuery::Range(
                Bound::Included(ScalarValue::UInt32(Some(5))),
                Bound::Included(ScalarValue::UInt32(Some(10))),
            ),
        );

        // Same range written literal-first; the expected result is unchanged.
        check_range(
            &index_info,
            "5 <= aisle AND 10 >= aisle",
            "aisle",
            SargableQuery::Range(
                Bound::Included(ScalarValue::UInt32(Some(5))),
                Bound::Included(ScalarValue::UInt32(Some(10))),
            ),
        );

        check_range(
            &index_info,
            "10 >= aisle AND 5 <= aisle",
            "aisle",
            SargableQuery::Range(
                Bound::Included(ScalarValue::UInt32(Some(5))),
                Bound::Included(ScalarValue::UInt32(Some(10))),
            ),
        );
        // Boolean column predicates.
        check_simple(
            &index_info,
            "on_sale IS TRUE",
            "on_sale",
            SargableQuery::Equals(ScalarValue::Boolean(Some(true))),
        );
        // A bare boolean column behaves like `IS TRUE`.
        check_simple(
            &index_info,
            "on_sale",
            "on_sale",
            SargableQuery::Equals(ScalarValue::Boolean(Some(true))),
        );
        check_simple_negated(
            &index_info,
            "NOT on_sale",
            "on_sale",
            SargableQuery::Equals(ScalarValue::Boolean(Some(true))),
        );
        check_simple(
            &index_info,
            "on_sale IS FALSE",
            "on_sale",
            SargableQuery::Equals(ScalarValue::Boolean(Some(false))),
        );
        // NOT BETWEEN is the negation of the corresponding range.
        check_simple_negated(
            &index_info,
            "aisle NOT BETWEEN 5 AND 10",
            "aisle",
            SargableQuery::Range(
                Bound::Included(ScalarValue::UInt32(Some(5))),
                Bound::Included(ScalarValue::UInt32(Some(10))),
            ),
        );
        // IN lists, in plain and negated forms, with three and five elements.
        check_simple(
            &index_info,
            "aisle IN (5, 6, 7)",
            "aisle",
            SargableQuery::IsIn(vec![
                ScalarValue::UInt32(Some(5)),
                ScalarValue::UInt32(Some(6)),
                ScalarValue::UInt32(Some(7)),
            ]),
        );
        check_simple_negated(
            &index_info,
            "NOT aisle IN (5, 6, 7)",
            "aisle",
            SargableQuery::IsIn(vec![
                ScalarValue::UInt32(Some(5)),
                ScalarValue::UInt32(Some(6)),
                ScalarValue::UInt32(Some(7)),
            ]),
        );
        check_simple_negated(
            &index_info,
            "aisle NOT IN (5, 6, 7)",
            "aisle",
            SargableQuery::IsIn(vec![
                ScalarValue::UInt32(Some(5)),
                ScalarValue::UInt32(Some(6)),
                ScalarValue::UInt32(Some(7)),
            ]),
        );
        check_simple(
            &index_info,
            "aisle IN (5, 6, 7, 8, 9)",
            "aisle",
            SargableQuery::IsIn(vec![
                ScalarValue::UInt32(Some(5)),
                ScalarValue::UInt32(Some(6)),
                ScalarValue::UInt32(Some(7)),
                ScalarValue::UInt32(Some(8)),
                ScalarValue::UInt32(Some(9)),
            ]),
        );
        check_simple_negated(
            &index_info,
            "NOT aisle IN (5, 6, 7, 8, 9)",
            "aisle",
            SargableQuery::IsIn(vec![
                ScalarValue::UInt32(Some(5)),
                ScalarValue::UInt32(Some(6)),
                ScalarValue::UInt32(Some(7)),
                ScalarValue::UInt32(Some(8)),
                ScalarValue::UInt32(Some(9)),
            ]),
        );
        check_simple_negated(
            &index_info,
            "aisle NOT IN (5, 6, 7, 8, 9)",
            "aisle",
            SargableQuery::IsIn(vec![
                ScalarValue::UInt32(Some(5)),
                ScalarValue::UInt32(Some(6)),
                ScalarValue::UInt32(Some(7)),
                ScalarValue::UInt32(Some(8)),
                ScalarValue::UInt32(Some(9)),
            ]),
        );
        // Lower-case keyword spellings of the boolean tests.
        check_simple(
            &index_info,
            "on_sale is false",
            "on_sale",
            SargableQuery::Equals(ScalarValue::Boolean(Some(false))),
        );
        check_simple(
            &index_info,
            "on_sale is true",
            "on_sale",
            SargableQuery::Equals(ScalarValue::Boolean(Some(true))),
        );
        // Single comparisons become half-open ranges.
        check_simple(
            &index_info,
            "aisle < 10",
            "aisle",
            SargableQuery::Range(
                Bound::Unbounded,
                Bound::Excluded(ScalarValue::UInt32(Some(10))),
            ),
        );
        check_simple(
            &index_info,
            "aisle <= 10",
            "aisle",
            SargableQuery::Range(
                Bound::Unbounded,
                Bound::Included(ScalarValue::UInt32(Some(10))),
            ),
        );
        check_simple(
            &index_info,
            "aisle > 10",
            "aisle",
            SargableQuery::Range(
                Bound::Excluded(ScalarValue::UInt32(Some(10))),
                Bound::Unbounded,
            ),
        );
        // Without simplification, a comparison with the column on the right is
        // not recognised: only the left operand is checked for a column.
        check_no_index(&index_info, "10 > aisle");
        check_simple(
            &index_info,
            "aisle >= 10",
            "aisle",
            SargableQuery::Range(
                Bound::Included(ScalarValue::UInt32(Some(10))),
                Bound::Unbounded,
            ),
        );
        check_simple(
            &index_info,
            "aisle = 10",
            "aisle",
            SargableQuery::Equals(ScalarValue::UInt32(Some(10))),
        );
        // `<>` is the negation of the equality query.
        check_simple_negated(
            &index_info,
            "aisle <> 10",
            "aisle",
            SargableQuery::Equals(ScalarValue::UInt32(Some(10))),
        );
        // Building blocks for the compound (AND / OR) expectations below.
        let left = Box::new(ScalarIndexExpr::Query(ScalarIndexSearch {
            column: "aisle".to_string(),
            index_name: "aisle_idx".to_string(),
            query: Arc::new(SargableQuery::Equals(ScalarValue::UInt32(Some(10)))),
            needs_recheck: false,
        }));
        let right = Box::new(ScalarIndexExpr::Query(ScalarIndexSearch {
            column: "color".to_string(),
            index_name: "color_idx".to_string(),
            query: Arc::new(SargableQuery::Equals(ScalarValue::Utf8(Some(
                "blue".to_string(),
            )))),
            needs_recheck: false,
        }));
        // Two indexed predicates joined by AND are combined into one index query.
        check(
            &index_info,
            "aisle = 10 AND color = 'blue'",
            Some(IndexedExpression {
                scalar_query: Some(ScalarIndexExpr::And(left.clone(), right.clone())),
                refine_expr: None,
            }),
            false,
        );
        // The predicate on the unindexed `size` column is kept as the refine expression.
        let refine = Expr::Column(Column::new_unqualified("size")).gt(datafusion_expr::lit(30_i64));
        check(
            &index_info,
            "aisle = 10 AND color = 'blue' AND size > 30",
            Some(IndexedExpression {
                scalar_query: Some(ScalarIndexExpr::And(left.clone(), right.clone())),
                refine_expr: Some(refine.clone()),
            }),
            false,
        );
        // Two indexed predicates joined by OR are combined into one index query.
        check(
            &index_info,
            "aisle = 10 OR color = 'blue'",
            Some(IndexedExpression {
                scalar_query: Some(ScalarIndexExpr::Or(left.clone(), right.clone())),
                refine_expr: None,
            }),
            false,
        );
        // An OR that includes an unindexed predicate cannot use the index at all.
        check_no_index(&index_info, "aisle = 10 OR color = 'blue' OR size > 30");
        // An indexed OR group AND-ed with an unindexed predicate: index query plus refine.
        check(
            &index_info,
            "(aisle = 10 OR color = 'blue') AND size > 30",
            Some(IndexedExpression {
                scalar_query: Some(ScalarIndexExpr::Or(left, right)),
                refine_expr: Some(refine),
            }),
            false,
        );
        // An OR whose branches each carry an unindexed predicate is expected to
        // fall back to no index.
        check_no_index(
            &index_info,
            "(aisle = 10 AND size > 30) OR (color = 'blue' AND size > 20)",
        );

        // The left operand is an arithmetic expression, not a bare column.
        check_no_index(&index_info, "aisle + 3 < 10");

        // Predicates involving NULL literals are expected to use no index.
        check_no_index(&index_info, "aisle IN (5, 6, NULL)");
        check_no_index(&index_info, "aisle = 5 OR aisle = 6 OR NULL");
        check_no_index(&index_info, "aisle IN (5, 6, 7, 8, NULL)");
        check_no_index(&index_info, "aisle = NULL");
        check_no_index(&index_info, "aisle BETWEEN 5 AND NULL");
        check_no_index(&index_info, "aisle BETWEEN NULL AND 10");
    }
}