1use std::{
5 ops::Bound,
6 sync::{Arc, LazyLock},
7};
8
9use arrow::array::BinaryBuilder;
10use arrow_array::{Array, RecordBatch, UInt32Array};
11use arrow_schema::{DataType, Field, Schema, SchemaRef};
12use async_recursion::async_recursion;
13use async_trait::async_trait;
14use datafusion_common::ScalarValue;
15use datafusion_expr::{
16 Between, BinaryExpr, Expr, Operator, ReturnFieldArgs, ScalarUDF,
17 expr::{InList, Like, ScalarFunction},
18};
19use tokio::try_join;
20
21use super::{
22 AnyQuery, BloomFilterQuery, LabelListQuery, MetricsCollector, SargableQuery, ScalarIndex,
23 SearchResult, TextQuery, TokenQuery,
24};
25#[cfg(feature = "geo")]
26use super::{GeoQuery, RelationQuery};
27use lance_core::{
28 Error, Result,
29 utils::mask::{NullableRowAddrMask, RowAddrMask},
30};
31use lance_datafusion::{expr::safe_coerce_scalar, planner::Planner};
32use roaring::RoaringBitmap;
33use tracing::instrument;
34
// Fixed limit of 500. The code that consults this constant is outside this
// excerpt. NOTE(review): presumably bounds how deeply a filter expression may
// be nested before planning gives up; confirm at the use site.
const MAX_DEPTH: usize = 500;
36
/// A filter split into an index-backed part and a residual expression.
///
/// Either half may be absent. The combinators on this type (`and`,
/// `maybe_or`, `maybe_not`, `refine`) merge two of these values.
#[derive(Debug, PartialEq)]
pub struct IndexedExpression {
    /// The part of the filter expressed as a tree of scalar index searches.
    pub scalar_query: Option<ScalarIndexExpr>,
    /// A DataFusion expression carried alongside the index query.
    /// NOTE(review): presumably re-applied to rows after the index search; confirm with the consumer.
    pub refine_expr: Option<Expr>,
}
71
72pub trait ScalarQueryParser: std::fmt::Debug + Send + Sync {
73 fn visit_between(
77 &self,
78 column: &str,
79 low: &Bound<ScalarValue>,
80 high: &Bound<ScalarValue>,
81 ) -> Option<IndexedExpression>;
82 fn visit_in_list(&self, column: &str, in_list: &[ScalarValue]) -> Option<IndexedExpression>;
86 fn visit_is_bool(&self, column: &str, value: bool) -> Option<IndexedExpression>;
90 fn visit_is_null(&self, column: &str) -> Option<IndexedExpression>;
94 fn visit_comparison(
98 &self,
99 column: &str,
100 value: &ScalarValue,
101 op: &Operator,
102 ) -> Option<IndexedExpression>;
103 fn visit_scalar_function(
108 &self,
109 column: &str,
110 data_type: &DataType,
111 func: &ScalarUDF,
112 args: &[Expr],
113 ) -> Option<IndexedExpression>;
114
115 fn visit_like(
130 &self,
131 _column: &str,
132 _like: &Like,
133 _pattern: &ScalarValue,
134 ) -> Option<IndexedExpression> {
135 None
136 }
137
138 fn is_valid_reference(&self, func: &Expr, data_type: &DataType) -> Option<DataType> {
162 match func {
163 Expr::Column(_) => Some(data_type.clone()),
164 _ => None,
165 }
166 }
167}
168
/// A [`ScalarQueryParser`] that holds several other parsers.
#[derive(Debug)]
pub struct MultiQueryParser {
    /// The wrapped parsers, kept in insertion order.
    parsers: Vec<Box<dyn ScalarQueryParser>>,
}
176
177impl MultiQueryParser {
178 pub fn single(parser: Box<dyn ScalarQueryParser>) -> Self {
180 Self {
181 parsers: vec![parser],
182 }
183 }
184
185 pub fn add(&mut self, other: Box<dyn ScalarQueryParser>) {
187 self.parsers.push(other);
188 }
189}
190
191impl ScalarQueryParser for MultiQueryParser {
192 fn visit_between(
193 &self,
194 column: &str,
195 low: &Bound<ScalarValue>,
196 high: &Bound<ScalarValue>,
197 ) -> Option<IndexedExpression> {
198 self.parsers
199 .iter()
200 .find_map(|parser| parser.visit_between(column, low, high))
201 }
202 fn visit_in_list(&self, column: &str, in_list: &[ScalarValue]) -> Option<IndexedExpression> {
203 self.parsers
204 .iter()
205 .find_map(|parser| parser.visit_in_list(column, in_list))
206 }
207 fn visit_is_bool(&self, column: &str, value: bool) -> Option<IndexedExpression> {
208 self.parsers
209 .iter()
210 .find_map(|parser| parser.visit_is_bool(column, value))
211 }
212 fn visit_is_null(&self, column: &str) -> Option<IndexedExpression> {
213 self.parsers
214 .iter()
215 .find_map(|parser| parser.visit_is_null(column))
216 }
217 fn visit_comparison(
218 &self,
219 column: &str,
220 value: &ScalarValue,
221 op: &Operator,
222 ) -> Option<IndexedExpression> {
223 self.parsers
224 .iter()
225 .find_map(|parser| parser.visit_comparison(column, value, op))
226 }
227 fn visit_scalar_function(
228 &self,
229 column: &str,
230 data_type: &DataType,
231 func: &ScalarUDF,
232 args: &[Expr],
233 ) -> Option<IndexedExpression> {
234 self.parsers
235 .iter()
236 .find_map(|parser| parser.visit_scalar_function(column, data_type, func, args))
237 }
238 fn visit_like(
239 &self,
240 column: &str,
241 like: &Like,
242 pattern: &ScalarValue,
243 ) -> Option<IndexedExpression> {
244 self.parsers
245 .iter()
246 .find_map(|parser| parser.visit_like(column, like, pattern))
247 }
248 fn is_valid_reference(&self, func: &Expr, data_type: &DataType) -> Option<DataType> {
255 self.parsers
256 .iter()
257 .find_map(|parser| parser.is_valid_reference(func, data_type))
258 }
259}
260
/// A parser that emits [`SargableQuery`] searches (ranges, equality,
/// membership, null tests and string prefixes).
#[derive(Debug)]
pub struct SargableQueryParser {
    /// Name of the index, copied into every search this parser produces.
    index_name: String,
    /// Type label of the index, copied into every search this parser produces.
    index_type: String,
    /// Value stored in the `needs_recheck` flag of every produced search.
    needs_recheck: bool,
}

impl SargableQueryParser {
    /// Create a parser for the index `index_name` of kind `index_type`.
    pub fn new(index_name: String, index_type: String, needs_recheck: bool) -> Self {
        Self {
            index_name,
            index_type,
            needs_recheck,
        }
    }
}
278
279impl ScalarQueryParser for SargableQueryParser {
280 fn is_valid_reference(&self, func: &Expr, data_type: &DataType) -> Option<DataType> {
281 match func {
282 Expr::Column(_) => Some(data_type.clone()),
283 Expr::ScalarFunction(udf) if udf.name() == "get_field" => Some(data_type.clone()),
285 _ => None,
286 }
287 }
288
289 fn visit_between(
290 &self,
291 column: &str,
292 low: &Bound<ScalarValue>,
293 high: &Bound<ScalarValue>,
294 ) -> Option<IndexedExpression> {
295 if let Bound::Included(val) | Bound::Excluded(val) = low
296 && val.is_null()
297 {
298 return None;
299 }
300 if let Bound::Included(val) | Bound::Excluded(val) = high
301 && val.is_null()
302 {
303 return None;
304 }
305 let query = SargableQuery::Range(low.clone(), high.clone());
306 Some(IndexedExpression::index_query_with_recheck(
307 column.to_string(),
308 self.index_name.clone(),
309 self.index_type.clone(),
310 Arc::new(query),
311 self.needs_recheck,
312 ))
313 }
314
315 fn visit_in_list(&self, column: &str, in_list: &[ScalarValue]) -> Option<IndexedExpression> {
316 if in_list.iter().any(|val| val.is_null()) {
317 return None;
318 }
319 let query = SargableQuery::IsIn(in_list.to_vec());
320 Some(IndexedExpression::index_query_with_recheck(
321 column.to_string(),
322 self.index_name.clone(),
323 self.index_type.clone(),
324 Arc::new(query),
325 self.needs_recheck,
326 ))
327 }
328
329 fn visit_is_bool(&self, column: &str, value: bool) -> Option<IndexedExpression> {
330 Some(IndexedExpression::index_query_with_recheck(
331 column.to_string(),
332 self.index_name.clone(),
333 self.index_type.clone(),
334 Arc::new(SargableQuery::Equals(ScalarValue::Boolean(Some(value)))),
335 self.needs_recheck,
336 ))
337 }
338
339 fn visit_is_null(&self, column: &str) -> Option<IndexedExpression> {
340 Some(IndexedExpression::index_query_with_recheck(
341 column.to_string(),
342 self.index_name.clone(),
343 self.index_type.clone(),
344 Arc::new(SargableQuery::IsNull()),
345 self.needs_recheck,
346 ))
347 }
348
349 fn visit_comparison(
350 &self,
351 column: &str,
352 value: &ScalarValue,
353 op: &Operator,
354 ) -> Option<IndexedExpression> {
355 if value.is_null() {
356 return None;
357 }
358 let query = match op {
359 Operator::Lt => SargableQuery::Range(Bound::Unbounded, Bound::Excluded(value.clone())),
360 Operator::LtEq => {
361 SargableQuery::Range(Bound::Unbounded, Bound::Included(value.clone()))
362 }
363 Operator::Gt => SargableQuery::Range(Bound::Excluded(value.clone()), Bound::Unbounded),
364 Operator::GtEq => {
365 SargableQuery::Range(Bound::Included(value.clone()), Bound::Unbounded)
366 }
367 Operator::Eq => SargableQuery::Equals(value.clone()),
368 Operator::NotEq => SargableQuery::Equals(value.clone()),
370 _ => unreachable!(),
371 };
372 Some(IndexedExpression::index_query_with_recheck(
373 column.to_string(),
374 self.index_name.clone(),
375 self.index_type.clone(),
376 Arc::new(query),
377 self.needs_recheck,
378 ))
379 }
380
381 fn visit_scalar_function(
382 &self,
383 column: &str,
384 _data_type: &DataType,
385 func: &ScalarUDF,
386 args: &[Expr],
387 ) -> Option<IndexedExpression> {
388 if func.name() == "starts_with" && args.len() == 2 {
390 let prefix = match &args[1] {
392 Expr::Literal(ScalarValue::Utf8(Some(s)), _) => ScalarValue::Utf8(Some(s.clone())),
393 Expr::Literal(ScalarValue::LargeUtf8(Some(s)), _) => {
394 ScalarValue::LargeUtf8(Some(s.clone()))
395 }
396 _ => return None,
397 };
398
399 let query = SargableQuery::LikePrefix(prefix);
400 return Some(IndexedExpression::index_query_with_recheck(
401 column.to_string(),
402 self.index_name.clone(),
403 self.index_type.clone(),
404 Arc::new(query),
405 self.needs_recheck,
406 ));
407 }
408
409 None
410 }
411
412 fn visit_like(
413 &self,
414 column: &str,
415 like: &Like,
416 pattern: &ScalarValue,
417 ) -> Option<IndexedExpression> {
418 if like.case_insensitive {
420 return None;
421 }
422
423 let pattern_str = match pattern {
425 ScalarValue::Utf8(Some(s)) => s.as_str(),
426 ScalarValue::LargeUtf8(Some(s)) => s.as_str(),
427 _ => return None,
428 };
429
430 let (prefix, needs_refine) = extract_like_leading_prefix(pattern_str, like.escape_char)?;
432
433 let prefix_value = match pattern {
435 ScalarValue::Utf8(_) => ScalarValue::Utf8(Some(prefix)),
436 ScalarValue::LargeUtf8(_) => ScalarValue::LargeUtf8(Some(prefix)),
437 _ => return None,
438 };
439
440 let query = SargableQuery::LikePrefix(prefix_value);
441 let scalar_query = Some(ScalarIndexExpr::Query(ScalarIndexSearch {
442 column: column.to_string(),
443 index_name: self.index_name.clone(),
444 index_type: self.index_type.clone(),
445 query: Arc::new(query),
446 needs_recheck: self.needs_recheck,
447 }));
448
449 let refine_expr = if needs_refine {
451 Some(Expr::Like(like.clone()))
452 } else {
453 None
454 };
455
456 Some(IndexedExpression {
457 scalar_query,
458 refine_expr,
459 })
460 }
461}
462
/// Extract the literal text that precedes the first unescaped wildcard of a
/// LIKE pattern.
///
/// `escape_char` defaults to a backslash when `None`. An escape character
/// followed by `%`, `_` or another escape character contributes that second
/// character literally; an escape character followed by anything else (or
/// sitting at the very end) is kept as a literal character itself.
///
/// Returns `Some((prefix, needs_refine))`, where `needs_refine` is `false`
/// only when the pattern is exactly the prefix followed by one trailing `%`.
///
/// Returns `None` when the pattern is empty, starts with `%` or `_`, has no
/// unescaped wildcard at all, or has an empty literal prefix.
fn extract_like_leading_prefix(pattern: &str, escape_char: Option<char>) -> Option<(String, bool)> {
    let escape = escape_char.unwrap_or('\\');
    let is_wildcard = |c: char| c == '%' || c == '_';

    // A pattern that opens with a wildcard character never yields a prefix.
    if pattern.starts_with(['%', '_']) {
        return None;
    }

    let mut prefix = String::new();
    let mut first_wildcard = None;
    let mut chars = pattern.chars().peekable();

    // One left-to-right scan pairs every escape with the character it
    // protects, so runs of escape characters of any length are handled and
    // wildcard detection cannot disagree with prefix extraction.
    while let Some(c) = chars.next() {
        let escaped = if c == escape {
            chars.next_if(|&next| is_wildcard(next) || next == escape)
        } else {
            None
        };
        if let Some(literal) = escaped {
            prefix.push(literal);
            continue;
        }

        if is_wildcard(c) {
            first_wildcard = Some(c);
            break;
        }

        prefix.push(c);
    }

    // No unescaped wildcard means the pattern is not a prefix match.
    let first_wildcard = first_wildcard?;
    if prefix.is_empty() {
        return None;
    }

    // Only `prefix%` is fully answered by a prefix search; everything else
    // still has pattern left to verify.
    let is_lone_trailing_percent = first_wildcard == '%' && chars.next().is_none();
    Some((prefix, !is_lone_trailing_percent))
}
581
/// A parser that emits [`BloomFilterQuery`] searches (equality, membership
/// and null tests only).
#[derive(Debug)]
pub struct BloomFilterQueryParser {
    /// Name of the index, copied into every search this parser produces.
    index_name: String,
    /// Type label of the index, copied into every search this parser produces.
    index_type: String,
    /// Value stored in the `needs_recheck` flag of every produced search.
    needs_recheck: bool,
}

impl BloomFilterQueryParser {
    /// Create a parser for the index `index_name` of kind `index_type`.
    pub fn new(index_name: String, index_type: String, needs_recheck: bool) -> Self {
        Self {
            index_name,
            index_type,
            needs_recheck,
        }
    }
}
599
600impl ScalarQueryParser for BloomFilterQueryParser {
601 fn visit_between(
602 &self,
603 _: &str,
604 _: &Bound<ScalarValue>,
605 _: &Bound<ScalarValue>,
606 ) -> Option<IndexedExpression> {
607 None
609 }
610
611 fn visit_in_list(&self, column: &str, in_list: &[ScalarValue]) -> Option<IndexedExpression> {
612 let query = BloomFilterQuery::IsIn(in_list.to_vec());
613 Some(IndexedExpression::index_query_with_recheck(
614 column.to_string(),
615 self.index_name.clone(),
616 self.index_type.clone(),
617 Arc::new(query),
618 self.needs_recheck,
619 ))
620 }
621
622 fn visit_is_bool(&self, column: &str, value: bool) -> Option<IndexedExpression> {
623 Some(IndexedExpression::index_query_with_recheck(
624 column.to_string(),
625 self.index_name.clone(),
626 self.index_type.clone(),
627 Arc::new(BloomFilterQuery::Equals(ScalarValue::Boolean(Some(value)))),
628 self.needs_recheck,
629 ))
630 }
631
632 fn visit_is_null(&self, column: &str) -> Option<IndexedExpression> {
633 Some(IndexedExpression::index_query_with_recheck(
634 column.to_string(),
635 self.index_name.clone(),
636 self.index_type.clone(),
637 Arc::new(BloomFilterQuery::IsNull()),
638 self.needs_recheck,
639 ))
640 }
641
642 fn visit_comparison(
643 &self,
644 column: &str,
645 value: &ScalarValue,
646 op: &Operator,
647 ) -> Option<IndexedExpression> {
648 let query = match op {
649 Operator::Eq => BloomFilterQuery::Equals(value.clone()),
651 Operator::NotEq => BloomFilterQuery::Equals(value.clone()),
653 _ => return None,
655 };
656 Some(IndexedExpression::index_query_with_recheck(
657 column.to_string(),
658 self.index_name.clone(),
659 self.index_type.clone(),
660 Arc::new(query),
661 self.needs_recheck,
662 ))
663 }
664
665 fn visit_scalar_function(
666 &self,
667 _: &str,
668 _: &DataType,
669 _: &ScalarUDF,
670 _: &[Expr],
671 ) -> Option<IndexedExpression> {
672 None
674 }
675}
676
/// A parser that emits [`LabelListQuery`] searches for the `array_has*`
/// family of functions.
#[derive(Debug)]
pub struct LabelListQueryParser {
    /// Name of the index, copied into every search this parser produces.
    index_name: String,
    /// Type label of the index, copied into every search this parser produces.
    index_type: String,
}

impl LabelListQueryParser {
    /// Create a parser for the index `index_name` of kind `index_type`.
    pub fn new(index_name: String, index_type: String) -> Self {
        Self {
            index_name,
            index_type,
        }
    }
}
692
693impl ScalarQueryParser for LabelListQueryParser {
694 fn visit_between(
695 &self,
696 _: &str,
697 _: &Bound<ScalarValue>,
698 _: &Bound<ScalarValue>,
699 ) -> Option<IndexedExpression> {
700 None
701 }
702
703 fn visit_in_list(&self, _: &str, _: &[ScalarValue]) -> Option<IndexedExpression> {
704 None
705 }
706
707 fn visit_is_bool(&self, _: &str, _: bool) -> Option<IndexedExpression> {
708 None
709 }
710
711 fn visit_is_null(&self, _: &str) -> Option<IndexedExpression> {
712 None
713 }
714
715 fn visit_comparison(
716 &self,
717 _: &str,
718 _: &ScalarValue,
719 _: &Operator,
720 ) -> Option<IndexedExpression> {
721 None
722 }
723
724 fn visit_scalar_function(
725 &self,
726 column: &str,
727 data_type: &DataType,
728 func: &ScalarUDF,
729 args: &[Expr],
730 ) -> Option<IndexedExpression> {
731 if args.len() != 2 {
732 return None;
733 }
734 if func.name() == "array_has" {
736 let inner_type = match data_type {
737 DataType::List(field) | DataType::LargeList(field) => field.data_type(),
738 _ => return None,
739 };
740 let scalar = maybe_scalar(&args[1], inner_type)?;
741 if scalar.is_null() {
744 return None;
745 }
746 let query = LabelListQuery::HasAnyLabel(vec![scalar]);
747 return Some(IndexedExpression::index_query(
748 column.to_string(),
749 self.index_name.clone(),
750 self.index_type.clone(),
751 Arc::new(query),
752 ));
753 }
754
755 let label_list = maybe_scalar(&args[1], data_type)?;
756 if let ScalarValue::List(list_arr) = label_list {
757 let list_values = list_arr.values();
758 if list_values.is_empty() {
759 return None;
760 }
761 let mut scalars = Vec::with_capacity(list_values.len());
762 for idx in 0..list_values.len() {
763 scalars.push(ScalarValue::try_from_array(list_values.as_ref(), idx).ok()?);
764 }
765 if func.name() == "array_has_all" {
766 let query = LabelListQuery::HasAllLabels(scalars);
767 Some(IndexedExpression::index_query(
768 column.to_string(),
769 self.index_name.clone(),
770 self.index_type.clone(),
771 Arc::new(query),
772 ))
773 } else if func.name() == "array_has_any" {
774 let query = LabelListQuery::HasAnyLabel(scalars);
775 Some(IndexedExpression::index_query(
776 column.to_string(),
777 self.index_name.clone(),
778 self.index_type.clone(),
779 Arc::new(query),
780 ))
781 } else {
782 None
783 }
784 } else {
785 None
786 }
787 }
788}
789
/// A parser that emits [`TextQuery`] searches for the `contains` function.
#[derive(Debug, Clone)]
pub struct TextQueryParser {
    /// Name of the index, copied into every search this parser produces.
    index_name: String,
    /// Type label of the index, copied into every search this parser produces.
    index_type: String,
    /// Value stored in the `needs_recheck` flag of every produced search.
    needs_recheck: bool,
}

impl TextQueryParser {
    /// Create a parser for the index `index_name` of kind `index_type`.
    pub fn new(index_name: String, index_type: String, needs_recheck: bool) -> Self {
        Self {
            index_name,
            index_type,
            needs_recheck,
        }
    }
}
807
808impl ScalarQueryParser for TextQueryParser {
809 fn visit_between(
810 &self,
811 _: &str,
812 _: &Bound<ScalarValue>,
813 _: &Bound<ScalarValue>,
814 ) -> Option<IndexedExpression> {
815 None
816 }
817
818 fn visit_in_list(&self, _: &str, _: &[ScalarValue]) -> Option<IndexedExpression> {
819 None
820 }
821
822 fn visit_is_bool(&self, _: &str, _: bool) -> Option<IndexedExpression> {
823 None
824 }
825
826 fn visit_is_null(&self, _: &str) -> Option<IndexedExpression> {
827 None
828 }
829
830 fn visit_comparison(
831 &self,
832 _: &str,
833 _: &ScalarValue,
834 _: &Operator,
835 ) -> Option<IndexedExpression> {
836 None
837 }
838
839 fn visit_scalar_function(
840 &self,
841 column: &str,
842 data_type: &DataType,
843 func: &ScalarUDF,
844 args: &[Expr],
845 ) -> Option<IndexedExpression> {
846 if args.len() != 2 {
847 return None;
848 }
849 let scalar = maybe_scalar(&args[1], data_type)?;
850 match scalar {
851 ScalarValue::Utf8(Some(scalar_str)) | ScalarValue::LargeUtf8(Some(scalar_str)) => {
852 if func.name() == "contains" {
853 let query = TextQuery::StringContains(scalar_str);
854 Some(IndexedExpression::index_query_with_recheck(
855 column.to_string(),
856 self.index_name.clone(),
857 self.index_type.clone(),
858 Arc::new(query),
859 self.needs_recheck,
860 ))
861 } else {
862 None
863 }
864 }
865 _ => {
866 None
868 }
869 }
870 }
871}
872
/// A parser that emits [`TokenQuery`] searches for the `contains_tokens`
/// function.
#[derive(Debug, Clone)]
pub struct FtsQueryParser {
    /// Name of the index, copied into every search this parser produces.
    index_name: String,
    /// Type label of the index, copied into every search this parser produces.
    index_type: String,
}

impl FtsQueryParser {
    /// Create a parser for the index called `name` of kind `index_type`.
    pub fn new(name: String, index_type: String) -> Self {
        Self {
            index_name: name,
            index_type,
        }
    }
}
888
889impl ScalarQueryParser for FtsQueryParser {
890 fn visit_between(
891 &self,
892 _: &str,
893 _: &Bound<ScalarValue>,
894 _: &Bound<ScalarValue>,
895 ) -> Option<IndexedExpression> {
896 None
897 }
898
899 fn visit_in_list(&self, _: &str, _: &[ScalarValue]) -> Option<IndexedExpression> {
900 None
901 }
902
903 fn visit_is_bool(&self, _: &str, _: bool) -> Option<IndexedExpression> {
904 None
905 }
906
907 fn visit_is_null(&self, _: &str) -> Option<IndexedExpression> {
908 None
909 }
910
911 fn visit_comparison(
912 &self,
913 _: &str,
914 _: &ScalarValue,
915 _: &Operator,
916 ) -> Option<IndexedExpression> {
917 None
918 }
919
920 fn visit_scalar_function(
921 &self,
922 column: &str,
923 data_type: &DataType,
924 func: &ScalarUDF,
925 args: &[Expr],
926 ) -> Option<IndexedExpression> {
927 if args.len() != 2 {
928 return None;
929 }
930 let scalar = maybe_scalar(&args[1], data_type)?;
931 if let ScalarValue::Utf8(Some(scalar_str)) = scalar
932 && func.name() == "contains_tokens"
933 {
934 let query = TokenQuery::TokensContains(scalar_str);
935 return Some(IndexedExpression::index_query(
936 column.to_string(),
937 self.index_name.clone(),
938 self.index_type.clone(),
939 Arc::new(query),
940 ));
941 }
942 None
943 }
944}
945
/// A parser that emits [`GeoQuery`] searches for spatial relation functions
/// and null tests. Only compiled with the `geo` feature.
#[cfg(feature = "geo")]
#[derive(Debug, Clone)]
pub struct GeoQueryParser {
    /// Name of the index, copied into every search this parser produces.
    index_name: String,
    /// Type label of the index, copied into every search this parser produces.
    index_type: String,
}

#[cfg(feature = "geo")]
impl GeoQueryParser {
    /// Create a parser for the index `index_name` of kind `index_type`.
    pub fn new(index_name: String, index_type: String) -> Self {
        Self {
            index_name,
            index_type,
        }
    }
}
963
#[cfg(feature = "geo")]
impl ScalarQueryParser for GeoQueryParser {
    /// Range predicates are never translated by this parser.
    fn visit_between(
        &self,
        _: &str,
        _: &Bound<ScalarValue>,
        _: &Bound<ScalarValue>,
    ) -> Option<IndexedExpression> {
        None
    }

    /// Membership predicates are never translated by this parser.
    fn visit_in_list(&self, _: &str, _: &[ScalarValue]) -> Option<IndexedExpression> {
        None
    }

    /// Boolean tests are never translated by this parser.
    fn visit_is_bool(&self, _: &str, _: bool) -> Option<IndexedExpression> {
        None
    }

    /// Builds a null-test query; the search is always flagged for recheck.
    fn visit_is_null(&self, column: &str) -> Option<IndexedExpression> {
        let query = GeoQuery::IsNull;
        Some(IndexedExpression::index_query_with_recheck(
            column.to_string(),
            self.index_name.clone(),
            self.index_type.clone(),
            Arc::new(query),
            true,
        ))
    }

    /// Comparisons are never translated by this parser.
    fn visit_comparison(
        &self,
        _: &str,
        _: &ScalarValue,
        _: &Operator,
    ) -> Option<IndexedExpression> {
        None
    }

    /// Translates a two-argument spatial relation call where one argument is
    /// a column and the other a literal (in either order).
    ///
    /// Every recognised function name maps to the same intersect query built
    /// from the literal; the search is always flagged for recheck.
    fn visit_scalar_function(
        &self,
        column: &str,
        _data_type: &DataType,
        func: &ScalarUDF,
        args: &[Expr],
    ) -> Option<IndexedExpression> {
        let is_relation = matches!(
            func.name(),
            "st_intersects"
                | "st_contains"
                | "st_within"
                | "st_touches"
                | "st_crosses"
                | "st_overlaps"
                | "st_covers"
                | "st_coveredby"
        );
        if !is_relation {
            return None;
        }
        let [first, second] = args else {
            return None;
        };

        // The literal may sit on either side of the column.
        let (value, metadata) = match (first, second) {
            (Expr::Literal(value, metadata), Expr::Column(_))
            | (Expr::Column(_), Expr::Literal(value, metadata)) => (value, metadata),
            _ => return None,
        };

        let mut field = Field::new("_geo", value.data_type(), false);
        if let Some(metadata) = metadata {
            field = field.with_metadata(metadata.to_hashmap());
        }
        let query = GeoQuery::IntersectQuery(RelationQuery {
            value: value.clone(),
            field,
        });
        Some(IndexedExpression::index_query_with_recheck(
            column.to_string(),
            self.index_name.clone(),
            self.index_type.clone(),
            Arc::new(query),
            true,
        ))
    }
}
1062
1063impl IndexedExpression {
1064 fn refine_only(refine_expr: Expr) -> Self {
1066 Self {
1067 scalar_query: None,
1068 refine_expr: Some(refine_expr),
1069 }
1070 }
1071
1072 fn index_query(
1074 column: String,
1075 index_name: String,
1076 index_type: String,
1077 query: Arc<dyn AnyQuery>,
1078 ) -> Self {
1079 Self {
1080 scalar_query: Some(ScalarIndexExpr::Query(ScalarIndexSearch {
1081 column,
1082 index_name,
1083 index_type,
1084 query,
1085 needs_recheck: false, })),
1087 refine_expr: None,
1088 }
1089 }
1090
1091 fn index_query_with_recheck(
1093 column: String,
1094 index_name: String,
1095 index_type: String,
1096 query: Arc<dyn AnyQuery>,
1097 needs_recheck: bool,
1098 ) -> Self {
1099 Self {
1100 scalar_query: Some(ScalarIndexExpr::Query(ScalarIndexSearch {
1101 column,
1102 index_name,
1103 index_type,
1104 query,
1105 needs_recheck,
1106 })),
1107 refine_expr: None,
1108 }
1109 }
1110
1111 fn maybe_not(self) -> Option<Self> {
1116 match (self.scalar_query, self.refine_expr) {
1117 (Some(_), Some(_)) => None,
1118 (Some(scalar_query), None) => {
1119 if scalar_query.needs_recheck() {
1120 return None;
1121 }
1122 Some(Self {
1123 scalar_query: Some(ScalarIndexExpr::Not(Box::new(scalar_query))),
1124 refine_expr: None,
1125 })
1126 }
1127 (None, Some(refine_expr)) => Some(Self {
1128 scalar_query: None,
1129 refine_expr: Some(Expr::Not(Box::new(refine_expr))),
1130 }),
1131 (None, None) => panic!("Empty node should not occur"),
1132 }
1133 }
1134
1135 fn and(self, other: Self) -> Self {
1140 let scalar_query = match (self.scalar_query, other.scalar_query) {
1141 (Some(scalar_query), Some(other_scalar_query)) => Some(ScalarIndexExpr::And(
1142 Box::new(scalar_query),
1143 Box::new(other_scalar_query),
1144 )),
1145 (Some(scalar_query), None) => Some(scalar_query),
1146 (None, Some(scalar_query)) => Some(scalar_query),
1147 (None, None) => None,
1148 };
1149 let refine_expr = match (self.refine_expr, other.refine_expr) {
1150 (Some(refine_expr), Some(other_refine_expr)) => {
1151 Some(refine_expr.and(other_refine_expr))
1152 }
1153 (Some(refine_expr), None) => Some(refine_expr),
1154 (None, Some(refine_expr)) => Some(refine_expr),
1155 (None, None) => None,
1156 };
1157 Self {
1158 scalar_query,
1159 refine_expr,
1160 }
1161 }
1162
1163 fn maybe_or(self, other: Self) -> Option<Self> {
1170 let scalar_query = self.scalar_query?;
1173 let other_scalar_query = other.scalar_query?;
1174 let scalar_query = Some(ScalarIndexExpr::Or(
1175 Box::new(scalar_query),
1176 Box::new(other_scalar_query),
1177 ));
1178
1179 let refine_expr = match (self.refine_expr, other.refine_expr) {
1180 (Some(_), Some(_)) => {
1190 return None;
1191 }
1192 (Some(_), None) => {
1193 return None;
1194 }
1195 (None, Some(_)) => {
1196 return None;
1197 }
1198 (None, None) => None,
1199 };
1200 Some(Self {
1201 scalar_query,
1202 refine_expr,
1203 })
1204 }
1205
1206 fn refine(self, expr: Expr) -> Self {
1207 match self.refine_expr {
1208 Some(refine_expr) => Self {
1209 scalar_query: self.scalar_query,
1210 refine_expr: Some(refine_expr.and(expr)),
1211 },
1212 None => Self {
1213 scalar_query: self.scalar_query,
1214 refine_expr: Some(expr),
1215 },
1216 }
1217 }
1218}
1219
/// Source of scalar indices for query evaluation.
///
/// `ScalarIndexExpr::evaluate_impl` calls this once for every `Query` node it
/// evaluates.
#[async_trait]
pub trait ScalarIndexLoader: Send + Sync {
    /// Load the index called `index_name` for `column`.
    ///
    /// `metrics` is the collector handed through from the evaluation call.
    async fn load_index(
        &self,
        column: &str,
        index_name: &str,
        metrics: &dyn MetricsCollector,
    ) -> Result<Arc<dyn ScalarIndex>>;
}
1233
/// One search against one scalar index: the leaf of a [`ScalarIndexExpr`].
#[derive(Debug, Clone)]
pub struct ScalarIndexSearch {
    /// Column the search applies to; passed to the index loader and to the
    /// query's `format` when displaying.
    pub column: String,
    /// Name of the index to load.
    pub index_name: String,
    /// Type label of the index; shown in the `Display` output.
    pub index_type: String,
    /// The query handed to the index's `search`.
    pub query: Arc<dyn AnyQuery>,
    /// Recheck flag set by the parser that produced this search.
    /// NOTE(review): presumably marks results that may contain false positives; confirm.
    pub needs_recheck: bool,
}
1248
1249impl PartialEq for ScalarIndexSearch {
1250 fn eq(&self, other: &Self) -> bool {
1251 self.column == other.column
1252 && self.index_name == other.index_name
1253 && self.query.as_ref().eq(other.query.as_ref())
1254 }
1255}
1256
/// A boolean tree of scalar index searches.
#[derive(Debug, Clone)]
pub enum ScalarIndexExpr {
    /// Negation of the inner expression.
    Not(Box<Self>),
    /// Conjunction of two expressions.
    And(Box<Self>, Box<Self>),
    /// Disjunction of two expressions.
    Or(Box<Self>, Box<Self>),
    /// A single index search (leaf node).
    Query(ScalarIndexSearch),
}
1268
1269impl PartialEq for ScalarIndexExpr {
1270 fn eq(&self, other: &Self) -> bool {
1271 match (self, other) {
1272 (Self::Not(l0), Self::Not(r0)) => l0 == r0,
1273 (Self::And(l0, l1), Self::And(r0, r1)) => l0 == r0 && l1 == r1,
1274 (Self::Or(l0, l1), Self::Or(r0, r1)) => l0 == r0 && l1 == r1,
1275 (Self::Query(l_search), Self::Query(r_search)) => l_search == r_search,
1276 _ => false,
1277 }
1278 }
1279}
1280
1281impl std::fmt::Display for ScalarIndexExpr {
1282 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1283 match self {
1284 Self::Not(inner) => write!(f, "NOT({})", inner),
1285 Self::And(lhs, rhs) => write!(f, "AND({},{})", lhs, rhs),
1286 Self::Or(lhs, rhs) => write!(f, "OR({},{})", lhs, rhs),
1287 Self::Query(search) => write!(
1288 f,
1289 "[{}]@{}({})",
1290 search.query.format(&search.column),
1291 search.index_name,
1292 search.index_type
1293 ),
1294 }
1295 }
1296}
1297
1298pub static INDEX_EXPR_RESULT_SCHEMA: LazyLock<SchemaRef> = LazyLock::new(|| {
1304 Arc::new(Schema::new(vec![
1305 Field::new("result".to_string(), DataType::Binary, true),
1306 Field::new("discriminant".to_string(), DataType::UInt32, true),
1307 Field::new("fragments_covered".to_string(), DataType::Binary, true),
1308 ]))
1309});
1310
/// Intermediate result of evaluating a [`ScalarIndexExpr`], carrying a
/// null-aware row mask. Converted to [`IndexExprResult`] by `drop_nulls`.
///
/// NOTE(review): variant names suggest `AtMost` is a superset of the true
/// matches and `AtLeast` a subset; confirm against `SearchResult`.
#[derive(Debug)]
enum NullableIndexExprResult {
    /// The mask is the exact answer.
    Exact(NullableRowAddrMask),
    /// Negation turns this into `AtLeast`.
    AtMost(NullableRowAddrMask),
    /// Negation turns this into `AtMost`.
    AtLeast(NullableRowAddrMask),
}
1317
1318impl From<SearchResult> for NullableIndexExprResult {
1319 fn from(result: SearchResult) -> Self {
1320 match result {
1321 SearchResult::Exact(mask) => Self::Exact(NullableRowAddrMask::AllowList(mask)),
1322 SearchResult::AtMost(mask) => Self::AtMost(NullableRowAddrMask::AllowList(mask)),
1323 SearchResult::AtLeast(mask) => Self::AtLeast(NullableRowAddrMask::AllowList(mask)),
1324 }
1325 }
1326}
1327
/// AND of two index results, tracking how exact the combined mask is.
impl std::ops::BitAnd<Self> for NullableIndexExprResult {
    type Output = Self;

    fn bitand(self, rhs: Self) -> Self {
        match (self, rhs) {
            // Two exact masks intersect to an exact mask.
            (Self::Exact(lhs), Self::Exact(rhs)) => Self::Exact(lhs & rhs),
            // Exact with AtMost (either order): intersect, result is AtMost.
            (Self::Exact(lhs), Self::AtMost(rhs)) | (Self::AtMost(lhs), Self::Exact(rhs)) => {
                Self::AtMost(lhs & rhs)
            }
            // Exact with AtLeast (either order): the AtLeast mask is discarded
            // and the exact mask is returned unchanged, downgraded to AtMost.
            (Self::Exact(exact), Self::AtLeast(_)) | (Self::AtLeast(_), Self::Exact(exact)) => {
                Self::AtMost(exact)
            }
            (Self::AtMost(lhs), Self::AtMost(rhs)) => Self::AtMost(lhs & rhs),
            (Self::AtLeast(lhs), Self::AtLeast(rhs)) => Self::AtLeast(lhs & rhs),
            // AtMost with AtLeast (either order): only the AtMost mask is kept.
            (Self::AtMost(most), Self::AtLeast(_)) | (Self::AtLeast(_), Self::AtMost(most)) => {
                Self::AtMost(most)
            }
        }
    }
}
1351
/// OR of two index results, tracking how exact the combined mask is.
impl std::ops::BitOr<Self> for NullableIndexExprResult {
    type Output = Self;

    fn bitor(self, rhs: Self) -> Self {
        match (self, rhs) {
            // Two exact masks union to an exact mask.
            (Self::Exact(lhs), Self::Exact(rhs)) => Self::Exact(lhs | rhs),
            // Exact with AtMost (either order): union, result is AtMost.
            // In both alternatives `lhs` binds the exact mask.
            (Self::Exact(lhs), Self::AtMost(rhs)) | (Self::AtMost(rhs), Self::Exact(lhs)) => {
                Self::AtMost(lhs | rhs)
            }
            // Exact with AtLeast (either order): union, result is AtLeast.
            // In both alternatives `lhs` binds the exact mask.
            (Self::Exact(lhs), Self::AtLeast(rhs)) | (Self::AtLeast(rhs), Self::Exact(lhs)) => {
                Self::AtLeast(lhs | rhs)
            }
            (Self::AtMost(lhs), Self::AtMost(rhs)) => Self::AtMost(lhs | rhs),
            (Self::AtLeast(lhs), Self::AtLeast(rhs)) => Self::AtLeast(lhs | rhs),
            // AtMost with AtLeast (either order): only the AtLeast mask is kept.
            (Self::AtMost(_), Self::AtLeast(least)) | (Self::AtLeast(least), Self::AtMost(_)) => {
                Self::AtLeast(least)
            }
        }
    }
}
1375
impl NullableIndexExprResult {
    /// Convert into the public [`IndexExprResult`], keeping the same
    /// exactness variant and calling `drop_nulls` on the mask.
    pub fn drop_nulls(self) -> IndexExprResult {
        match self {
            Self::Exact(mask) => IndexExprResult::Exact(mask.drop_nulls()),
            Self::AtMost(mask) => IndexExprResult::AtMost(mask.drop_nulls()),
            Self::AtLeast(mask) => IndexExprResult::AtLeast(mask.drop_nulls()),
        }
    }
}
1385
/// Final result of evaluating a [`ScalarIndexExpr`]: a row mask tagged with
/// how exact it is.
///
/// NOTE(review): variant names suggest `AtMost` is a superset of the true
/// matches and `AtLeast` a subset; confirm against `SearchResult`.
#[derive(Debug)]
pub enum IndexExprResult {
    /// Serialized with discriminant 0.
    Exact(RowAddrMask),
    /// Serialized with discriminant 1.
    AtMost(RowAddrMask),
    /// Serialized with discriminant 2.
    AtLeast(RowAddrMask),
}
1399
1400impl IndexExprResult {
1401 pub fn row_addr_mask(&self) -> &RowAddrMask {
1402 match self {
1403 Self::Exact(mask) => mask,
1404 Self::AtMost(mask) => mask,
1405 Self::AtLeast(mask) => mask,
1406 }
1407 }
1408
1409 pub fn discriminant(&self) -> u32 {
1410 match self {
1411 Self::Exact(_) => 0,
1412 Self::AtMost(_) => 1,
1413 Self::AtLeast(_) => 2,
1414 }
1415 }
1416
1417 pub fn from_parts(mask: RowAddrMask, discriminant: u32) -> Result<Self> {
1418 match discriminant {
1419 0 => Ok(Self::Exact(mask)),
1420 1 => Ok(Self::AtMost(mask)),
1421 2 => Ok(Self::AtLeast(mask)),
1422 _ => Err(Error::invalid_input_source(
1423 format!("Invalid IndexExprResult discriminant: {}", discriminant).into(),
1424 )),
1425 }
1426 }
1427
1428 #[instrument(skip_all)]
1429 pub fn serialize_to_arrow(
1430 &self,
1431 fragments_covered_by_result: &RoaringBitmap,
1432 ) -> Result<RecordBatch> {
1433 let row_addr_mask = self.row_addr_mask();
1434 let row_addr_mask_arr = row_addr_mask.into_arrow()?;
1435 let discriminant = self.discriminant();
1436 let discriminant_arr =
1437 Arc::new(UInt32Array::from(vec![discriminant, discriminant])) as Arc<dyn Array>;
1438 let mut fragments_covered_builder = BinaryBuilder::new();
1439 let fragments_covered_bytes_len = fragments_covered_by_result.serialized_size();
1440 let mut fragments_covered_bytes = Vec::with_capacity(fragments_covered_bytes_len);
1441 fragments_covered_by_result.serialize_into(&mut fragments_covered_bytes)?;
1442 fragments_covered_builder.append_value(fragments_covered_bytes);
1443 fragments_covered_builder.append_null();
1444 let fragments_covered_arr = Arc::new(fragments_covered_builder.finish()) as Arc<dyn Array>;
1445 Ok(RecordBatch::try_new(
1446 INDEX_EXPR_RESULT_SCHEMA.clone(),
1447 vec![
1448 Arc::new(row_addr_mask_arr),
1449 Arc::new(discriminant_arr),
1450 Arc::new(fragments_covered_arr),
1451 ],
1452 )?)
1453 }
1454}
1455
1456impl ScalarIndexExpr {
1457 #[async_recursion]
1464 async fn evaluate_impl(
1465 &self,
1466 index_loader: &dyn ScalarIndexLoader,
1467 metrics: &dyn MetricsCollector,
1468 ) -> Result<NullableIndexExprResult> {
1469 match self {
1470 Self::Not(inner) => {
1471 let result = inner.evaluate_impl(index_loader, metrics).await?;
1472 Ok(match result {
1474 NullableIndexExprResult::Exact(mask) => NullableIndexExprResult::Exact(!mask),
1475 NullableIndexExprResult::AtMost(mask) => {
1476 NullableIndexExprResult::AtLeast(!mask)
1477 }
1478 NullableIndexExprResult::AtLeast(mask) => {
1479 NullableIndexExprResult::AtMost(!mask)
1480 }
1481 })
1482 }
1483 Self::And(lhs, rhs) => {
1484 let lhs_result = lhs.evaluate_impl(index_loader, metrics);
1485 let rhs_result = rhs.evaluate_impl(index_loader, metrics);
1486 let (lhs_result, rhs_result) = try_join!(lhs_result, rhs_result)?;
1487 Ok(lhs_result & rhs_result)
1488 }
1489 Self::Or(lhs, rhs) => {
1490 let lhs_result = lhs.evaluate_impl(index_loader, metrics);
1491 let rhs_result = rhs.evaluate_impl(index_loader, metrics);
1492 let (lhs_result, rhs_result) = try_join!(lhs_result, rhs_result)?;
1493 Ok(lhs_result | rhs_result)
1494 }
1495 Self::Query(search) => {
1496 let index = index_loader
1497 .load_index(&search.column, &search.index_name, metrics)
1498 .await?;
1499 let search_result = index.search(search.query.as_ref(), metrics).await?;
1500 Ok(search_result.into())
1501 }
1502 }
1503 }
1504
1505 #[instrument(level = "debug", skip_all)]
1506 pub async fn evaluate(
1507 &self,
1508 index_loader: &dyn ScalarIndexLoader,
1509 metrics: &dyn MetricsCollector,
1510 ) -> Result<IndexExprResult> {
1511 Ok(self
1512 .evaluate_impl(index_loader, metrics)
1513 .await?
1514 .drop_nulls())
1515 }
1516
1517 pub fn to_expr(&self) -> Expr {
1518 match self {
1519 Self::Not(inner) => Expr::Not(inner.to_expr().into()),
1520 Self::And(lhs, rhs) => {
1521 let lhs = lhs.to_expr();
1522 let rhs = rhs.to_expr();
1523 lhs.and(rhs)
1524 }
1525 Self::Or(lhs, rhs) => {
1526 let lhs = lhs.to_expr();
1527 let rhs = rhs.to_expr();
1528 lhs.or(rhs)
1529 }
1530 Self::Query(search) => search.query.to_expr(search.column.clone()),
1531 }
1532 }
1533
1534 pub fn needs_recheck(&self) -> bool {
1535 match self {
1536 Self::Not(inner) => inner.needs_recheck(),
1537 Self::And(lhs, rhs) | Self::Or(lhs, rhs) => lhs.needs_recheck() || rhs.needs_recheck(),
1538 Self::Query(search) => search.needs_recheck,
1539 }
1540 }
1541}
1542
1543fn maybe_column(expr: &Expr) -> Option<&str> {
1545 match expr {
1546 Expr::Column(col) => Some(&col.name),
1547 _ => None,
1548 }
1549}
1550
1551fn extract_nested_column_path(expr: &Expr) -> Option<String> {
1554 let mut current_expr = expr;
1555 let mut parts = Vec::new();
1556
1557 loop {
1559 match current_expr {
1560 Expr::ScalarFunction(udf) if udf.name() == "get_field" => {
1561 if udf.args.len() != 2 {
1562 return None;
1563 }
1564 if let Expr::Literal(ScalarValue::Utf8(Some(field_name)), _) = &udf.args[1] {
1567 parts.push(field_name.clone());
1568 } else {
1569 return None;
1570 }
1571 current_expr = &udf.args[0];
1573 }
1574 Expr::Column(col) => {
1575 parts.push(col.name.clone());
1577 break;
1578 }
1579 _ => {
1580 return None;
1581 }
1582 }
1583 }
1584
1585 parts.reverse();
1587
1588 let field_refs: Vec<&str> = parts.iter().map(|s| s.as_str()).collect();
1590 Some(lance_core::datatypes::format_field_path(&field_refs))
1591}
1592
1593fn maybe_indexed_column<'b>(
1600 expr: &Expr,
1601 index_info: &'b dyn IndexInformationProvider,
1602) -> Option<(String, DataType, &'b dyn ScalarQueryParser)> {
1603 if let Some(nested_path) = extract_nested_column_path(expr)
1605 && let Some((data_type, parser)) = index_info.get_index(&nested_path)
1606 && let Some(data_type) = parser.is_valid_reference(expr, data_type)
1607 {
1608 return Some((nested_path, data_type, parser));
1609 }
1610
1611 match expr {
1612 Expr::Column(col) => {
1613 let col = col.name.as_str();
1614 let (data_type, parser) = index_info.get_index(col)?;
1615 if let Some(data_type) = parser.is_valid_reference(expr, data_type) {
1616 Some((col.to_string(), data_type, parser))
1617 } else {
1618 None
1619 }
1620 }
1621 Expr::ScalarFunction(udf) => {
1622 if udf.args.is_empty() {
1623 return None;
1624 }
1625 let col = maybe_column(&udf.args[0])?;
1627 let (data_type, parser) = index_info.get_index(col)?;
1628 if let Some(data_type) = parser.is_valid_reference(expr, data_type) {
1629 Some((col.to_string(), data_type, parser))
1630 } else {
1631 None
1632 }
1633 }
1634 _ => None,
1635 }
1636}
1637
1638fn maybe_scalar(expr: &Expr, expected_type: &DataType) -> Option<ScalarValue> {
1640 match expr {
1641 Expr::Literal(value, _) => safe_coerce_scalar(value, expected_type),
1642 Expr::Cast(cast) => match cast.expr.as_ref() {
1650 Expr::Literal(value, _) => {
1651 let casted = value.cast_to(&cast.data_type).ok()?;
1652 safe_coerce_scalar(&casted, expected_type)
1653 }
1654 _ => None,
1655 },
1656 Expr::ScalarFunction(scalar_function) => {
1657 if scalar_function.name() == "arrow_cast" {
1658 if scalar_function.args.len() != 2 {
1659 return None;
1660 }
1661 match (&scalar_function.args[0], &scalar_function.args[1]) {
1662 (Expr::Literal(value, _), Expr::Literal(cast_type, _)) => {
1663 let target_type = scalar_function
1664 .func
1665 .return_field_from_args(ReturnFieldArgs {
1666 arg_fields: &[
1667 Arc::new(Field::new("expression", value.data_type(), false)),
1668 Arc::new(Field::new("datatype", cast_type.data_type(), false)),
1669 ],
1670 scalar_arguments: &[Some(value), Some(cast_type)],
1671 })
1672 .ok()?;
1673 let casted = value.cast_to(target_type.data_type()).ok()?;
1674 safe_coerce_scalar(&casted, expected_type)
1675 }
1676 _ => None,
1677 }
1678 } else {
1679 None
1680 }
1681 }
1682 _ => None,
1683 }
1684}
1685
1686fn maybe_scalar_list(exprs: &Vec<Expr>, expected_type: &DataType) -> Option<Vec<ScalarValue>> {
1688 let mut scalar_values = Vec::with_capacity(exprs.len());
1689 for expr in exprs {
1690 match maybe_scalar(expr, expected_type) {
1691 Some(scalar_val) => {
1692 scalar_values.push(scalar_val);
1693 }
1694 None => {
1695 return None;
1696 }
1697 }
1698 }
1699 Some(scalar_values)
1700}
1701
1702fn visit_between(
1703 between: &Between,
1704 index_info: &dyn IndexInformationProvider,
1705) -> Option<IndexedExpression> {
1706 let (column, col_type, query_parser) = maybe_indexed_column(&between.expr, index_info)?;
1707 let low = maybe_scalar(&between.low, &col_type)?;
1708 let high = maybe_scalar(&between.high, &col_type)?;
1709
1710 let indexed_expr =
1711 query_parser.visit_between(&column, &Bound::Included(low), &Bound::Included(high))?;
1712
1713 if between.negated {
1714 indexed_expr.maybe_not()
1715 } else {
1716 Some(indexed_expr)
1717 }
1718}
1719
1720fn visit_in_list(
1721 in_list: &InList,
1722 index_info: &dyn IndexInformationProvider,
1723) -> Option<IndexedExpression> {
1724 let (column, col_type, query_parser) = maybe_indexed_column(&in_list.expr, index_info)?;
1725 let values = maybe_scalar_list(&in_list.list, &col_type)?;
1726
1727 let indexed_expr = query_parser.visit_in_list(&column, &values)?;
1728
1729 if in_list.negated {
1730 indexed_expr.maybe_not()
1731 } else {
1732 Some(indexed_expr)
1733 }
1734}
1735
1736fn visit_is_bool(
1737 expr: &Expr,
1738 index_info: &dyn IndexInformationProvider,
1739 value: bool,
1740) -> Option<IndexedExpression> {
1741 let (column, col_type, query_parser) = maybe_indexed_column(expr, index_info)?;
1742 if col_type != DataType::Boolean {
1743 None
1744 } else {
1745 query_parser.visit_is_bool(&column, value)
1746 }
1747}
1748
1749fn visit_column(
1751 col: &Expr,
1752 index_info: &dyn IndexInformationProvider,
1753) -> Option<IndexedExpression> {
1754 let (column, col_type, query_parser) = maybe_indexed_column(col, index_info)?;
1755 if col_type != DataType::Boolean {
1756 None
1757 } else {
1758 query_parser.visit_is_bool(&column, true)
1759 }
1760}
1761
1762fn visit_is_null(
1763 expr: &Expr,
1764 index_info: &dyn IndexInformationProvider,
1765 negated: bool,
1766) -> Option<IndexedExpression> {
1767 let (column, _, query_parser) = maybe_indexed_column(expr, index_info)?;
1768 let indexed_expr = query_parser.visit_is_null(&column)?;
1769 if negated {
1770 indexed_expr.maybe_not()
1771 } else {
1772 Some(indexed_expr)
1773 }
1774}
1775
1776fn visit_not(
1777 expr: &Expr,
1778 index_info: &dyn IndexInformationProvider,
1779 depth: usize,
1780) -> Result<Option<IndexedExpression>> {
1781 let node = visit_node(expr, index_info, depth + 1)?;
1782 Ok(node.and_then(|node| node.maybe_not()))
1783}
1784
1785fn visit_comparison(
1786 expr: &BinaryExpr,
1787 index_info: &dyn IndexInformationProvider,
1788) -> Option<IndexedExpression> {
1789 let left_col = maybe_indexed_column(&expr.left, index_info);
1790 if let Some((column, col_type, query_parser)) = left_col {
1791 let scalar = maybe_scalar(&expr.right, &col_type)?;
1792 query_parser.visit_comparison(&column, &scalar, &expr.op)
1793 } else {
1794 None
1797 }
1798}
1799
1800fn maybe_range(
1801 expr: &BinaryExpr,
1802 index_info: &dyn IndexInformationProvider,
1803) -> Option<IndexedExpression> {
1804 let left_expr = match expr.left.as_ref() {
1805 Expr::BinaryExpr(binary_expr) => Some(binary_expr),
1806 _ => None,
1807 }?;
1808 let right_expr = match expr.right.as_ref() {
1809 Expr::BinaryExpr(binary_expr) => Some(binary_expr),
1810 _ => None,
1811 }?;
1812
1813 let (left_col, dt, parser) = maybe_indexed_column(&left_expr.left, index_info)?;
1814 let right_col = maybe_column(&right_expr.left)?;
1815
1816 if left_col != right_col {
1817 return None;
1818 }
1819
1820 let left_value = maybe_scalar(&left_expr.right, &dt)?;
1821 let right_value = maybe_scalar(&right_expr.right, &dt)?;
1822
1823 let (low, high) = match (left_expr.op, right_expr.op) {
1824 (Operator::GtEq, Operator::LtEq) => {
1826 (Bound::Included(left_value), Bound::Included(right_value))
1827 }
1828 (Operator::GtEq, Operator::Lt) => {
1830 (Bound::Included(left_value), Bound::Excluded(right_value))
1831 }
1832 (Operator::Gt, Operator::LtEq) => {
1834 (Bound::Excluded(left_value), Bound::Included(right_value))
1835 }
1836 (Operator::Gt, Operator::Lt) => (Bound::Excluded(left_value), Bound::Excluded(right_value)),
1838 (Operator::LtEq, Operator::GtEq) => {
1840 (Bound::Included(right_value), Bound::Included(left_value))
1841 }
1842 (Operator::LtEq, Operator::Gt) => {
1844 (Bound::Excluded(right_value), Bound::Included(left_value))
1845 }
1846 (Operator::Lt, Operator::GtEq) => {
1848 (Bound::Included(right_value), Bound::Excluded(left_value))
1849 }
1850 (Operator::Lt, Operator::Gt) => (Bound::Excluded(right_value), Bound::Excluded(left_value)),
1852 _ => return None,
1853 };
1854
1855 parser.visit_between(&left_col, &low, &high)
1856}
1857
1858fn visit_and(
1859 expr: &BinaryExpr,
1860 index_info: &dyn IndexInformationProvider,
1861 depth: usize,
1862) -> Result<Option<IndexedExpression>> {
1863 if let Some(range_expr) = maybe_range(expr, index_info) {
1871 return Ok(Some(range_expr));
1872 }
1873
1874 let left = visit_node(&expr.left, index_info, depth + 1)?;
1875 let right = visit_node(&expr.right, index_info, depth + 1)?;
1876 Ok(match (left, right) {
1877 (Some(left), Some(right)) => Some(left.and(right)),
1878 (Some(left), None) => Some(left.refine((*expr.right).clone())),
1879 (None, Some(right)) => Some(right.refine((*expr.left).clone())),
1880 (None, None) => None,
1881 })
1882}
1883
1884fn visit_or(
1885 expr: &BinaryExpr,
1886 index_info: &dyn IndexInformationProvider,
1887 depth: usize,
1888) -> Result<Option<IndexedExpression>> {
1889 let left = visit_node(&expr.left, index_info, depth + 1)?;
1890 let right = visit_node(&expr.right, index_info, depth + 1)?;
1891 Ok(match (left, right) {
1892 (Some(left), Some(right)) => left.maybe_or(right),
1893 (Some(_), None) => None,
1899 (None, Some(_)) => None,
1900 (None, None) => None,
1901 })
1902}
1903
1904fn visit_binary_expr(
1905 expr: &BinaryExpr,
1906 index_info: &dyn IndexInformationProvider,
1907 depth: usize,
1908) -> Result<Option<IndexedExpression>> {
1909 match &expr.op {
1910 Operator::Lt | Operator::LtEq | Operator::Gt | Operator::GtEq | Operator::Eq => {
1911 Ok(visit_comparison(expr, index_info))
1912 }
1913 Operator::NotEq => Ok(visit_comparison(expr, index_info).and_then(|node| node.maybe_not())),
1915 Operator::And => visit_and(expr, index_info, depth),
1916 Operator::Or => visit_or(expr, index_info, depth),
1917 _ => Ok(None),
1918 }
1919}
1920
1921fn visit_scalar_fn(
1922 scalar_fn: &ScalarFunction,
1923 index_info: &dyn IndexInformationProvider,
1924) -> Option<IndexedExpression> {
1925 if scalar_fn.args.is_empty() {
1926 return None;
1927 }
1928 let (col, data_type, query_parser) = maybe_indexed_column(&scalar_fn.args[0], index_info)?;
1929 query_parser.visit_scalar_function(&col, &data_type, &scalar_fn.func, &scalar_fn.args)
1930}
1931
1932fn visit_like_expr(
1933 like: &Like,
1934 index_info: &dyn IndexInformationProvider,
1935) -> Option<IndexedExpression> {
1936 let (column, _, query_parser) = maybe_indexed_column(&like.expr, index_info)?;
1937
1938 let pattern = match like.pattern.as_ref() {
1940 Expr::Literal(scalar, _) => scalar.clone(),
1941 _ => return None,
1942 };
1943
1944 query_parser.visit_like(&column, like, &pattern)
1945}
1946
1947fn visit_node(
1948 expr: &Expr,
1949 index_info: &dyn IndexInformationProvider,
1950 depth: usize,
1951) -> Result<Option<IndexedExpression>> {
1952 if depth >= MAX_DEPTH {
1953 return Err(Error::invalid_input(format!(
1954 "the filter expression is too long, lance limit the max number of conditions to {}",
1955 MAX_DEPTH
1956 )));
1957 }
1958 match expr {
1959 Expr::Between(between) => Ok(visit_between(between, index_info)),
1960 Expr::Alias(alias) => visit_node(alias.expr.as_ref(), index_info, depth),
1961 Expr::Column(_) => Ok(visit_column(expr, index_info)),
1962 Expr::InList(in_list) => Ok(visit_in_list(in_list, index_info)),
1963 Expr::IsFalse(expr) => Ok(visit_is_bool(expr.as_ref(), index_info, false)),
1964 Expr::IsTrue(expr) => Ok(visit_is_bool(expr.as_ref(), index_info, true)),
1965 Expr::IsNull(expr) => Ok(visit_is_null(expr.as_ref(), index_info, false)),
1966 Expr::IsNotNull(expr) => Ok(visit_is_null(expr.as_ref(), index_info, true)),
1967 Expr::Not(expr) => visit_not(expr.as_ref(), index_info, depth),
1968 Expr::BinaryExpr(binary_expr) => visit_binary_expr(binary_expr, index_info, depth),
1969 Expr::ScalarFunction(scalar_fn) => Ok(visit_scalar_fn(scalar_fn, index_info)),
1970 Expr::Like(like) => {
1971 if like.negated {
1972 Ok(None)
1974 } else {
1975 Ok(visit_like_expr(like, index_info))
1976 }
1977 }
1978 _ => Ok(None),
1979 }
1980}
1981
/// Source of scalar index information used while planning a filter.
pub trait IndexInformationProvider {
    /// Looks up the index for `col`.
    ///
    /// Returns the data type and the query parser registered for the column,
    /// or `None` when the column has no index.
    fn get_index(&self, col: &str) -> Option<(&DataType, &dyn ScalarQueryParser)>;
}
1988
1989pub fn apply_scalar_indices(
1992 expr: Expr,
1993 index_info: &dyn IndexInformationProvider,
1994) -> Result<IndexedExpression> {
1995 Ok(visit_node(&expr, index_info, 0)?.unwrap_or(IndexedExpression::refine_only(expr)))
1996}
1997
/// The outcome of planning a filter: which part is answered by scalar
/// indices and which part must still be evaluated as an expression.
#[derive(Clone, Default, Debug)]
pub struct FilterPlan {
    /// The part of the filter that is evaluated through scalar indices, if any.
    pub index_query: Option<ScalarIndexExpr>,
    /// Whether the index result can be used without a recheck. When an index
    /// query exists, the planner sets this to the negation of its
    /// `needs_recheck()`.
    pub skip_recheck: bool,
    /// The part of the filter that is not covered by `index_query`, if any.
    pub refine_expr: Option<Expr>,
    /// The complete filter expression this plan was built from, if any.
    pub full_expr: Option<Expr>,
}
2006
2007impl FilterPlan {
2008 pub fn empty() -> Self {
2009 Self {
2010 index_query: None,
2011 skip_recheck: true,
2012 refine_expr: None,
2013 full_expr: None,
2014 }
2015 }
2016
2017 pub fn new_refine_only(expr: Expr) -> Self {
2018 Self {
2019 index_query: None,
2020 skip_recheck: true,
2021 refine_expr: Some(expr.clone()),
2022 full_expr: Some(expr),
2023 }
2024 }
2025
2026 pub fn is_empty(&self) -> bool {
2027 self.refine_expr.is_none() && self.index_query.is_none()
2028 }
2029
2030 pub fn all_columns(&self) -> Vec<String> {
2031 self.full_expr
2032 .as_ref()
2033 .map(Planner::column_names_in_expr)
2034 .unwrap_or_default()
2035 }
2036
2037 pub fn refine_columns(&self) -> Vec<String> {
2038 self.refine_expr
2039 .as_ref()
2040 .map(Planner::column_names_in_expr)
2041 .unwrap_or_default()
2042 }
2043
2044 pub fn has_refine(&self) -> bool {
2046 self.refine_expr.is_some()
2047 }
2048
2049 pub fn has_index_query(&self) -> bool {
2051 self.index_query.is_some()
2052 }
2053
2054 pub fn has_any_filter(&self) -> bool {
2055 self.refine_expr.is_some() || self.index_query.is_some()
2056 }
2057
2058 pub fn make_refine_only(&mut self) {
2059 self.index_query = None;
2060 self.refine_expr = self.full_expr.clone();
2061 }
2062
2063 pub fn is_exact_index_search(&self) -> bool {
2065 self.index_query.is_some() && self.refine_expr.is_none() && self.skip_recheck
2066 }
2067}
2068
/// Extension trait that adds scalar-index-aware filter planning.
pub trait PlannerIndexExt {
    /// Builds a [`FilterPlan`] for `filter`.
    ///
    /// * `filter` - the filter expression to plan.
    /// * `index_info` - describes which columns have scalar indices.
    /// * `use_scalar_index` - when `false`, the indices are ignored.
    fn create_filter_plan(
        &self,
        filter: Expr,
        index_info: &dyn IndexInformationProvider,
        use_scalar_index: bool,
    ) -> Result<FilterPlan>;
}
2083
2084impl PlannerIndexExt for Planner {
2085 fn create_filter_plan(
2086 &self,
2087 filter: Expr,
2088 index_info: &dyn IndexInformationProvider,
2089 use_scalar_index: bool,
2090 ) -> Result<FilterPlan> {
2091 let logical_expr = self.optimize_expr(filter)?;
2092 if use_scalar_index {
2093 let indexed_expr = apply_scalar_indices(logical_expr.clone(), index_info)?;
2094 let mut skip_recheck = false;
2095 if let Some(scalar_query) = indexed_expr.scalar_query.as_ref() {
2096 skip_recheck = !scalar_query.needs_recheck();
2097 }
2098 Ok(FilterPlan {
2099 index_query: indexed_expr.scalar_query,
2100 refine_expr: indexed_expr.refine_expr,
2101 full_expr: Some(logical_expr),
2102 skip_recheck,
2103 })
2104 } else {
2105 Ok(FilterPlan {
2106 index_query: None,
2107 skip_recheck: true,
2108 refine_expr: Some(logical_expr.clone()),
2109 full_expr: Some(logical_expr),
2110 })
2111 }
2112 }
2113}
2114
2115#[cfg(test)]
2116mod tests {
2117 use std::collections::HashMap;
2118
2119 use arrow_schema::{Field, Schema};
2120 use chrono::Utc;
2121 use datafusion_common::{Column, DFSchema};
2122 use datafusion_expr::simplify::SimplifyContext;
2123 use lance_datafusion::exec::{LanceExecutionOptions, get_session_context};
2124
2125 use crate::scalar::json::{JsonQuery, JsonQueryParser};
2126
2127 use super::*;
2128
    /// Index description for one column of the mock provider.
    struct ColInfo {
        /// Data type reported for the indexed column.
        data_type: DataType,
        /// Parser used to build index queries for the column.
        parser: Box<dyn ScalarQueryParser>,
    }
2133
    impl ColInfo {
        /// Bundles a column's data type with its query parser.
        fn new(data_type: DataType, parser: Box<dyn ScalarQueryParser>) -> Self {
            Self { data_type, parser }
        }
    }
2139
    /// In-memory [`IndexInformationProvider`] used by the tests.
    struct MockIndexInfoProvider {
        /// Indexed columns keyed by column name.
        indexed_columns: HashMap<String, ColInfo>,
    }
2143
2144 impl MockIndexInfoProvider {
2145 fn new(indexed_columns: Vec<(&str, ColInfo)>) -> Self {
2146 Self {
2147 indexed_columns: HashMap::from_iter(
2148 indexed_columns
2149 .into_iter()
2150 .map(|(s, ty)| (s.to_string(), ty)),
2151 ),
2152 }
2153 }
2154 }
2155
2156 impl IndexInformationProvider for MockIndexInfoProvider {
2157 fn get_index(&self, col: &str) -> Option<(&DataType, &dyn ScalarQueryParser)> {
2158 self.indexed_columns
2159 .get(col)
2160 .map(|col_info| (&col_info.data_type, col_info.parser.as_ref()))
2161 }
2162 }
2163
2164 fn check(
2165 index_info: &dyn IndexInformationProvider,
2166 expr: &str,
2167 expected: Option<IndexedExpression>,
2168 optimize: bool,
2169 ) {
2170 let schema = Schema::new(vec![
2171 Field::new("color", DataType::Utf8, false),
2172 Field::new("size", DataType::Float32, false),
2173 Field::new("aisle", DataType::UInt32, false),
2174 Field::new("on_sale", DataType::Boolean, false),
2175 Field::new("price", DataType::Float32, false),
2176 Field::new("json", DataType::LargeBinary, false),
2177 ]);
2178 let df_schema: DFSchema = schema.try_into().unwrap();
2179
2180 let ctx = get_session_context(&LanceExecutionOptions::default());
2181 let state = ctx.state();
2182 let mut expr = state.create_logical_expr(expr, &df_schema).unwrap();
2183 if optimize {
2184 let simplify_context = SimplifyContext::default()
2185 .with_schema(Arc::new(df_schema))
2186 .with_query_execution_start_time(Some(Utc::now()));
2187 let simplifier =
2188 datafusion::optimizer::simplify_expressions::ExprSimplifier::new(simplify_context);
2189 expr = simplifier.simplify(expr).unwrap();
2190 }
2191
2192 let actual = apply_scalar_indices(expr.clone(), index_info).unwrap();
2193 if let Some(expected) = expected {
2194 assert_eq!(actual, expected);
2195 } else {
2196 assert!(actual.scalar_query.is_none());
2197 assert_eq!(actual.refine_expr.unwrap(), expr);
2198 }
2199 }
2200
2201 fn check_no_index(index_info: &dyn IndexInformationProvider, expr: &str) {
2202 check(index_info, expr, None, false)
2203 }
2204
2205 fn check_simple(
2206 index_info: &dyn IndexInformationProvider,
2207 expr: &str,
2208 col: &str,
2209 query: impl AnyQuery,
2210 ) {
2211 check(
2212 index_info,
2213 expr,
2214 Some(IndexedExpression::index_query(
2215 col.to_string(),
2216 format!("{}_idx", col),
2217 "BTree".to_string(),
2218 Arc::new(query),
2219 )),
2220 false,
2221 )
2222 }
2223
2224 fn check_range(
2225 index_info: &dyn IndexInformationProvider,
2226 expr: &str,
2227 col: &str,
2228 query: SargableQuery,
2229 ) {
2230 check(
2231 index_info,
2232 expr,
2233 Some(IndexedExpression::index_query(
2234 col.to_string(),
2235 format!("{}_idx", col),
2236 "BTree".to_string(),
2237 Arc::new(query),
2238 )),
2239 true,
2240 )
2241 }
2242
2243 fn check_simple_negated(
2244 index_info: &dyn IndexInformationProvider,
2245 expr: &str,
2246 col: &str,
2247 query: SargableQuery,
2248 ) {
2249 check(
2250 index_info,
2251 expr,
2252 Some(
2253 IndexedExpression::index_query(
2254 col.to_string(),
2255 format!("{}_idx", col),
2256 "BTree".to_string(),
2257 Arc::new(query),
2258 )
2259 .maybe_not()
2260 .unwrap(),
2261 ),
2262 false,
2263 )
2264 }
2265
    /// Table of filter strings and the indexed expression each one is
    /// expected to produce with the mock indices below.
    #[test]
    fn test_expressions() {
        // Indexed columns: color, aisle, on_sale, price (BTree) and json
        // (BTree wrapped in a JSON parser for the `$.name` path). `size` is
        // deliberately left unindexed.
        let index_info = MockIndexInfoProvider::new(vec![
            (
                "color",
                ColInfo::new(
                    DataType::Utf8,
                    Box::new(SargableQueryParser::new(
                        "color_idx".to_string(),
                        "BTree".to_string(),
                        false,
                    )),
                ),
            ),
            (
                "aisle",
                ColInfo::new(
                    DataType::UInt32,
                    Box::new(SargableQueryParser::new(
                        "aisle_idx".to_string(),
                        "BTree".to_string(),
                        false,
                    )),
                ),
            ),
            (
                "on_sale",
                ColInfo::new(
                    DataType::Boolean,
                    Box::new(SargableQueryParser::new(
                        "on_sale_idx".to_string(),
                        "BTree".to_string(),
                        false,
                    )),
                ),
            ),
            (
                "price",
                ColInfo::new(
                    DataType::Float32,
                    Box::new(SargableQueryParser::new(
                        "price_idx".to_string(),
                        "BTree".to_string(),
                        false,
                    )),
                ),
            ),
            (
                "json",
                ColInfo::new(
                    DataType::LargeBinary,
                    Box::new(JsonQueryParser::new(
                        "$.name".to_string(),
                        Box::new(SargableQueryParser::new(
                            "json_idx".to_string(),
                            "BTree".to_string(),
                            false,
                        )),
                    )),
                ),
            ),
        ]);

        // A scalar function over an indexed column is routed to the column's parser.
        check_simple(
            &index_info,
            "json_extract(json, '$.name') = 'foo'",
            "json",
            JsonQuery::new(
                Arc::new(SargableQuery::Equals(ScalarValue::Utf8(Some(
                    "foo".to_string(),
                )))),
                "$.name".to_string(),
            ),
        );

        // `size` has no index.
        check_no_index(&index_info, "size BETWEEN 5 AND 10");
        // An `arrow_cast` literal is cast and then coerced to the column type.
        check_simple(
            &index_info,
            "aisle = arrow_cast(5, 'Int16')",
            "aisle",
            SargableQuery::Equals(ScalarValue::UInt32(Some(5))),
        );
        // BETWEEN and the equivalent pairs of comparisons all become one
        // range query (these cases run through the simplifier).
        check_range(
            &index_info,
            "aisle BETWEEN 5 AND 10",
            "aisle",
            SargableQuery::Range(
                Bound::Included(ScalarValue::UInt32(Some(5))),
                Bound::Included(ScalarValue::UInt32(Some(10))),
            ),
        );
        check_range(
            &index_info,
            "aisle >= 5 AND aisle <= 10",
            "aisle",
            SargableQuery::Range(
                Bound::Included(ScalarValue::UInt32(Some(5))),
                Bound::Included(ScalarValue::UInt32(Some(10))),
            ),
        );

        check_range(
            &index_info,
            "aisle <= 10 AND aisle >= 5",
            "aisle",
            SargableQuery::Range(
                Bound::Included(ScalarValue::UInt32(Some(5))),
                Bound::Included(ScalarValue::UInt32(Some(10))),
            ),
        );

        check_range(
            &index_info,
            "5 <= aisle AND 10 >= aisle",
            "aisle",
            SargableQuery::Range(
                Bound::Included(ScalarValue::UInt32(Some(5))),
                Bound::Included(ScalarValue::UInt32(Some(10))),
            ),
        );

        check_range(
            &index_info,
            "10 >= aisle AND 5 <= aisle",
            "aisle",
            SargableQuery::Range(
                Bound::Included(ScalarValue::UInt32(Some(5))),
                Bound::Included(ScalarValue::UInt32(Some(10))),
            ),
        );
        // Mixed inclusive / exclusive bounds.
        check_range(
            &index_info,
            "aisle <= 10 AND aisle > 5",
            "aisle",
            SargableQuery::Range(
                Bound::Excluded(ScalarValue::UInt32(Some(5))),
                Bound::Included(ScalarValue::UInt32(Some(10))),
            ),
        );
        check_range(
            &index_info,
            "aisle < 10 AND aisle >= 5",
            "aisle",
            SargableQuery::Range(
                Bound::Included(ScalarValue::UInt32(Some(5))),
                Bound::Excluded(ScalarValue::UInt32(Some(10))),
            ),
        );
        // Boolean predicates on an indexed boolean column.
        check_simple(
            &index_info,
            "on_sale IS TRUE",
            "on_sale",
            SargableQuery::Equals(ScalarValue::Boolean(Some(true))),
        );
        // A bare boolean column is treated like `IS TRUE`.
        check_simple(
            &index_info,
            "on_sale",
            "on_sale",
            SargableQuery::Equals(ScalarValue::Boolean(Some(true))),
        );
        check_simple_negated(
            &index_info,
            "NOT on_sale",
            "on_sale",
            SargableQuery::Equals(ScalarValue::Boolean(Some(true))),
        );
        check_simple(
            &index_info,
            "on_sale IS FALSE",
            "on_sale",
            SargableQuery::Equals(ScalarValue::Boolean(Some(false))),
        );
        // Negated BETWEEN is the negation of the range query.
        check_simple_negated(
            &index_info,
            "aisle NOT BETWEEN 5 AND 10",
            "aisle",
            SargableQuery::Range(
                Bound::Included(ScalarValue::UInt32(Some(5))),
                Bound::Included(ScalarValue::UInt32(Some(10))),
            ),
        );
        // IN lists, in plain and both negated spellings, for two list lengths.
        check_simple(
            &index_info,
            "aisle IN (5, 6, 7)",
            "aisle",
            SargableQuery::IsIn(vec![
                ScalarValue::UInt32(Some(5)),
                ScalarValue::UInt32(Some(6)),
                ScalarValue::UInt32(Some(7)),
            ]),
        );
        check_simple_negated(
            &index_info,
            "NOT aisle IN (5, 6, 7)",
            "aisle",
            SargableQuery::IsIn(vec![
                ScalarValue::UInt32(Some(5)),
                ScalarValue::UInt32(Some(6)),
                ScalarValue::UInt32(Some(7)),
            ]),
        );
        check_simple_negated(
            &index_info,
            "aisle NOT IN (5, 6, 7)",
            "aisle",
            SargableQuery::IsIn(vec![
                ScalarValue::UInt32(Some(5)),
                ScalarValue::UInt32(Some(6)),
                ScalarValue::UInt32(Some(7)),
            ]),
        );
        check_simple(
            &index_info,
            "aisle IN (5, 6, 7, 8, 9)",
            "aisle",
            SargableQuery::IsIn(vec![
                ScalarValue::UInt32(Some(5)),
                ScalarValue::UInt32(Some(6)),
                ScalarValue::UInt32(Some(7)),
                ScalarValue::UInt32(Some(8)),
                ScalarValue::UInt32(Some(9)),
            ]),
        );
        check_simple_negated(
            &index_info,
            "NOT aisle IN (5, 6, 7, 8, 9)",
            "aisle",
            SargableQuery::IsIn(vec![
                ScalarValue::UInt32(Some(5)),
                ScalarValue::UInt32(Some(6)),
                ScalarValue::UInt32(Some(7)),
                ScalarValue::UInt32(Some(8)),
                ScalarValue::UInt32(Some(9)),
            ]),
        );
        check_simple_negated(
            &index_info,
            "aisle NOT IN (5, 6, 7, 8, 9)",
            "aisle",
            SargableQuery::IsIn(vec![
                ScalarValue::UInt32(Some(5)),
                ScalarValue::UInt32(Some(6)),
                ScalarValue::UInt32(Some(7)),
                ScalarValue::UInt32(Some(8)),
                ScalarValue::UInt32(Some(9)),
            ]),
        );
        // Lower-case keyword spelling of the boolean predicates.
        check_simple(
            &index_info,
            "on_sale is false",
            "on_sale",
            SargableQuery::Equals(ScalarValue::Boolean(Some(false))),
        );
        check_simple(
            &index_info,
            "on_sale is true",
            "on_sale",
            SargableQuery::Equals(ScalarValue::Boolean(Some(true))),
        );
        // Single comparisons become half-open ranges.
        check_simple(
            &index_info,
            "aisle < 10",
            "aisle",
            SargableQuery::Range(
                Bound::Unbounded,
                Bound::Excluded(ScalarValue::UInt32(Some(10))),
            ),
        );
        check_simple(
            &index_info,
            "aisle <= 10",
            "aisle",
            SargableQuery::Range(
                Bound::Unbounded,
                Bound::Included(ScalarValue::UInt32(Some(10))),
            ),
        );
        check_simple(
            &index_info,
            "aisle > 10",
            "aisle",
            SargableQuery::Range(
                Bound::Excluded(ScalarValue::UInt32(Some(10))),
                Bound::Unbounded,
            ),
        );
        // Without simplification the column sits on the right-hand side,
        // which the comparison visitor does not look at.
        check_no_index(&index_info, "10 > aisle");
        check_simple(
            &index_info,
            "aisle >= 10",
            "aisle",
            SargableQuery::Range(
                Bound::Included(ScalarValue::UInt32(Some(10))),
                Bound::Unbounded,
            ),
        );
        check_simple(
            &index_info,
            "aisle = 10",
            "aisle",
            SargableQuery::Equals(ScalarValue::UInt32(Some(10))),
        );
        // `<>` is the negation of the equality query.
        check_simple_negated(
            &index_info,
            "aisle <> 10",
            "aisle",
            SargableQuery::Equals(ScalarValue::UInt32(Some(10))),
        );
        // Building blocks for the AND / OR combinations below.
        let left = Box::new(ScalarIndexExpr::Query(ScalarIndexSearch {
            column: "aisle".to_string(),
            index_name: "aisle_idx".to_string(),
            index_type: "BTree".to_string(),
            query: Arc::new(SargableQuery::Equals(ScalarValue::UInt32(Some(10)))),
            needs_recheck: false,
        }));
        let right = Box::new(ScalarIndexExpr::Query(ScalarIndexSearch {
            column: "color".to_string(),
            index_name: "color_idx".to_string(),
            index_type: "BTree".to_string(),
            query: Arc::new(SargableQuery::Equals(ScalarValue::Utf8(Some(
                "blue".to_string(),
            )))),
            needs_recheck: false,
        }));
        // Two indexed operands joined by AND.
        check(
            &index_info,
            "aisle = 10 AND color = 'blue'",
            Some(IndexedExpression {
                scalar_query: Some(ScalarIndexExpr::And(left.clone(), right.clone())),
                refine_expr: None,
            }),
            false,
        );
        let refine = Expr::Column(Column::new_unqualified("size")).gt(datafusion_expr::lit(30_i64));
        // An unindexed conjunct is kept as the refine expression.
        check(
            &index_info,
            "aisle = 10 AND color = 'blue' AND size > 30",
            Some(IndexedExpression {
                scalar_query: Some(ScalarIndexExpr::And(left.clone(), right.clone())),
                refine_expr: Some(refine.clone()),
            }),
            false,
        );
        // Two indexed operands joined by OR.
        check(
            &index_info,
            "aisle = 10 OR color = 'blue'",
            Some(IndexedExpression {
                scalar_query: Some(ScalarIndexExpr::Or(left.clone(), right.clone())),
                refine_expr: None,
            }),
            false,
        );
        // An unindexed operand anywhere in an OR disables the index for the whole OR.
        check_no_index(&index_info, "aisle = 10 OR color = 'blue' OR size > 30");
        // An indexed OR combined with an unindexed conjunct: index plus refine.
        check(
            &index_info,
            "(aisle = 10 OR color = 'blue') AND size > 30",
            Some(IndexedExpression {
                scalar_query: Some(ScalarIndexExpr::Or(left, right)),
                refine_expr: Some(refine),
            }),
            false,
        );
        // Each OR operand carries its own refine expression; no index is used.
        check_no_index(
            &index_info,
            "(aisle = 10 AND size > 30) OR (color = 'blue' AND size > 20)",
        );

        // Arithmetic on the column side is not recognised as a column reference.
        check_no_index(&index_info, "aisle + 3 < 10");

        // Expressions involving NULL literals are expected to stay unindexed.
        check_no_index(&index_info, "aisle IN (5, 6, NULL)");
        check_no_index(&index_info, "aisle = 5 OR aisle = 6 OR NULL");
        check_no_index(&index_info, "aisle IN (5, 6, 7, 8, NULL)");
        check_no_index(&index_info, "aisle = NULL");
        check_no_index(&index_info, "aisle BETWEEN 5 AND NULL");
        check_no_index(&index_info, "aisle BETWEEN NULL AND 10");
    }
2663
/// Checks how logical NOT changes the certainty variant of a
/// `NullableIndexExprResult`.
///
/// The nested `apply_not` helper spells out the expected rule: an `Exact`
/// result stays `Exact`, while `AtMost` and `AtLeast` swap places; in every
/// case the wrapped mask is negated with `!`.
///
/// NOTE(review): the assertions exercise the local helper only. Presumably it
/// mirrors the NOT handling of the production evaluator; confirm the two are
/// kept in sync.
///
/// This is a plain `#[test]`: nothing in the body awaits, so no async runtime
/// is required.
#[test]
fn test_not_flips_certainty() {
    use lance_core::utils::mask::{NullableRowAddrSet, RowAddrTreeMap};

    /// Negates the mask and flips the certainty bound (`Exact` is unchanged).
    fn apply_not(result: NullableIndexExprResult) -> NullableIndexExprResult {
        match result {
            NullableIndexExprResult::Exact(mask) => NullableIndexExprResult::Exact(!mask),
            NullableIndexExprResult::AtMost(mask) => NullableIndexExprResult::AtLeast(!mask),
            NullableIndexExprResult::AtLeast(mask) => NullableIndexExprResult::AtMost(!mask),
        }
    }

    // Every case uses the same allow-list mask built from addresses 1 and 2,
    // paired with an empty second map.
    let sample_mask = || {
        NullableRowAddrMask::AllowList(NullableRowAddrSet::new(
            RowAddrTreeMap::from_iter(&[1, 2]),
            RowAddrTreeMap::new(),
        ))
    };

    let negated_at_most = apply_not(NullableIndexExprResult::AtMost(sample_mask()));
    assert!(
        matches!(negated_at_most, NullableIndexExprResult::AtLeast(_)),
        "NOT(AtMost) should be AtLeast"
    );

    let negated_at_least = apply_not(NullableIndexExprResult::AtLeast(sample_mask()));
    assert!(
        matches!(negated_at_least, NullableIndexExprResult::AtMost(_)),
        "NOT(AtLeast) should be AtMost"
    );

    let negated_exact = apply_not(NullableIndexExprResult::Exact(sample_mask()));
    assert!(
        matches!(negated_exact, NullableIndexExprResult::Exact(_)),
        "NOT(Exact) should stay Exact"
    );
}
2709
/// Checks which certainty variant results from combining two
/// `NullableIndexExprResult` values with `&` and `|`.
///
/// Expected outcomes asserted below:
///
/// | left    | op  | right   | result  |
/// |---------|-----|---------|---------|
/// | AtMost  | `&` | AtMost  | AtMost  |
/// | AtLeast | `&` | AtLeast | AtLeast |
/// | AtMost  | `&` | AtLeast | AtMost  |
/// | AtMost  | `\|` | AtMost  | AtMost  |
/// | AtLeast | `\|` | AtLeast | AtLeast |
/// | AtMost  | `\|` | AtLeast | AtLeast |
/// | Exact   | `&` | AtMost  | AtMost  |
/// | Exact   | `\|` | AtLeast | AtLeast |
///
/// Only the resulting variant is inspected, not the combined mask. This is a
/// plain `#[test]`: nothing in the body awaits, so no async runtime is needed.
#[test]
fn test_and_or_preserve_certainty() {
    use lance_core::utils::mask::{NullableRowAddrSet, RowAddrTreeMap};

    // Wraps a set of row addresses in an allow-list mask whose second map is empty.
    let allow = |rows: RowAddrTreeMap| {
        NullableRowAddrMask::AllowList(NullableRowAddrSet::new(rows, RowAddrTreeMap::new()))
    };

    // Fresh operands are built for every assertion because `&` / `|` consume them.
    let make_at_most =
        || NullableIndexExprResult::AtMost(allow(RowAddrTreeMap::from_iter(&[1, 2, 3])));
    let make_at_least =
        || NullableIndexExprResult::AtLeast(allow(RowAddrTreeMap::from_iter(&[2, 3, 4])));
    let make_exact = || NullableIndexExprResult::Exact(allow(RowAddrTreeMap::from_iter(&[1, 2])));

    // --- AND ---
    assert!(
        matches!(
            make_at_most() & make_at_most(),
            NullableIndexExprResult::AtMost(_)
        ),
        "AtMost & AtMost should be AtMost"
    );
    assert!(
        matches!(
            make_at_least() & make_at_least(),
            NullableIndexExprResult::AtLeast(_)
        ),
        "AtLeast & AtLeast should be AtLeast"
    );
    assert!(
        matches!(
            make_at_most() & make_at_least(),
            NullableIndexExprResult::AtMost(_)
        ),
        "AtMost & AtLeast should be AtMost"
    );

    // --- OR ---
    assert!(
        matches!(
            make_at_most() | make_at_most(),
            NullableIndexExprResult::AtMost(_)
        ),
        "AtMost | AtMost should be AtMost"
    );
    assert!(
        matches!(
            make_at_least() | make_at_least(),
            NullableIndexExprResult::AtLeast(_)
        ),
        "AtLeast | AtLeast should be AtLeast"
    );
    assert!(
        matches!(
            make_at_most() | make_at_least(),
            NullableIndexExprResult::AtLeast(_)
        ),
        "AtMost | AtLeast should be AtLeast"
    );

    // --- Exact mixed with an inexact operand ---
    assert!(
        matches!(
            make_exact() & make_at_most(),
            NullableIndexExprResult::AtMost(_)
        ),
        "Exact & AtMost should be AtMost"
    );
    assert!(
        matches!(
            make_exact() | make_at_least(),
            NullableIndexExprResult::AtLeast(_)
        ),
        "Exact | AtLeast should be AtLeast"
    );
}
2788
/// Table-driven check of `extract_like_leading_prefix`.
///
/// Each row is `(pattern, escape_char, expected)`, where `expected` is the
/// literal prefix plus a boolean flag, or `None` when no prefix is extracted.
///
/// NOTE(review): judging by the cases below, the flag is `false` when the
/// pattern is exactly "prefix followed by one trailing `%`" and `true` when
/// more pattern follows the prefix. Presumably it signals that the index
/// result still needs refining; confirm against the function's documentation.
#[test]
fn test_extract_like_leading_prefix() {
    let cases: &[(&str, Option<char>, Option<(&str, bool)>)] = &[
        // Literal text followed by a single trailing `%`.
        ("foo%", None, Some(("foo", false))),
        ("abc%", None, Some(("abc", false))),
        // Literal text followed by anything beyond one trailing `%`.
        ("foo%bar%", None, Some(("foo", true))),
        ("foo_bar%", None, Some(("foo", true))),
        ("foo%bar", None, Some(("foo", true))),
        ("foo_", None, Some(("foo", true))),
        // Patterns that begin with a wildcard yield nothing.
        ("%foo", None, None),
        ("_foo%", None, None),
        ("%", None, None),
        // No wildcard at all.
        ("foo", None, None),
        // Escaped wildcards (and an escaped escape char) become part of the prefix.
        (r"foo\%bar%", Some('\\'), Some(("foo%bar", false))),
        (r"foo\_bar%", Some('\\'), Some(("foo_bar", false))),
        (r"foo\\bar%", Some('\\'), Some(("foo\\bar", false))),
        // The only `%` is escaped, so no real wildcard remains.
        (r"foo\%", Some('\\'), None),
        // Backslashes in the pattern when no escape char is passed explicitly.
        (r"foo\%", None, None),
        (r"foo\bar%", None, Some(("foo\\bar", false))),
        // Empty pattern.
        ("", None, None),
        // Escaped wildcard inside the prefix, with more pattern after it.
        (r"foo\%bar%baz%", Some('\\'), Some(("foo%bar", true))),
    ];

    for &(pattern, escape_char, expected) in cases {
        let expected = expected.map(|(prefix, flag)| (prefix.to_string(), flag));
        assert_eq!(
            extract_like_leading_prefix(pattern, escape_char),
            expected,
            "pattern {pattern:?} with escape char {escape_char:?}"
        );
    }
}
2862
/// Checks how `LIKE` predicates on an indexed `Utf8` column (a
/// `SargableQueryParser` registered as `"BTree"`) are split by
/// `apply_scalar_indices` into an index query and a refine expression.
///
/// Three patterns are covered:
/// * `'foo%'`: becomes `SargableQuery::LikePrefix("foo")` with no refine
///   expression.
/// * `'foo%bar%'`: still yields `LikePrefix("foo")`, and a `LIKE` with pattern
///   `foo%bar%` (not negated, case sensitive) is kept as the refine expression.
/// * `'%foo'`: produces no index query; only a refine expression remains.
#[test]
fn test_like_expression_parsing() {
    /// Extracts the `LikePrefix` value from an index query. Panics when the
    /// expression is missing, is not a single `Query`, does not wrap a
    /// `SargableQuery`, or is not a `LikePrefix`, so no case can be skipped
    /// silently.
    fn expect_like_prefix(scalar_query: Option<&ScalarIndexExpr>) -> &ScalarValue {
        let Some(ScalarIndexExpr::Query(search)) = scalar_query else {
            panic!("Expected Query variant");
        };
        let Some(query) = search.query.as_any().downcast_ref::<SargableQuery>() else {
            panic!("Query should be SargableQuery");
        };
        match query {
            SargableQuery::LikePrefix(prefix) => prefix,
            _ => panic!("Expected LikePrefix query"),
        }
    }

    let index_info = MockIndexInfoProvider::new(vec![(
        "color",
        ColInfo::new(
            DataType::Utf8,
            Box::new(SargableQueryParser::new(
                "color_idx".to_string(),
                "BTree".to_string(),
                false,
            )),
        ),
    )]);

    let schema = Schema::new(vec![Field::new("color", DataType::Utf8, false)]);
    let df_schema: DFSchema = schema.try_into().unwrap();
    let ctx = get_session_context(&LanceExecutionOptions::default());
    let state = ctx.state();

    // Parses a SQL predicate against the schema and runs index selection on it.
    let plan = |sql: &str| {
        let expr = state.create_logical_expr(sql, &df_schema).unwrap();
        apply_scalar_indices(expr, &index_info).unwrap()
    };
    let foo_prefix = ScalarValue::Utf8(Some("foo".to_string()));

    // Case 1: prefix followed by a single trailing wildcard.
    let simple = plan("color LIKE 'foo%'");
    assert!(simple.scalar_query.is_some(), "Should have scalar_query");
    assert!(
        simple.refine_expr.is_none(),
        "Simple prefix should not need refine_expr"
    );
    assert_eq!(expect_like_prefix(simple.scalar_query.as_ref()), &foo_prefix);

    // Case 2: more pattern after the prefix, so the LIKE is also kept for refinement.
    let complex = plan("color LIKE 'foo%bar%'");
    assert!(complex.scalar_query.is_some(), "Should have scalar_query");
    assert!(
        complex.refine_expr.is_some(),
        "Complex pattern should have refine_expr"
    );
    assert_eq!(
        expect_like_prefix(complex.scalar_query.as_ref()),
        &foo_prefix
    );

    let Some(Expr::Like(like)) = complex.refine_expr else {
        panic!("Expected Like expression in refine_expr");
    };
    assert!(!like.negated, "refine LIKE should not be negated");
    assert!(
        !like.case_insensitive,
        "refine LIKE should stay case sensitive"
    );
    let Expr::Literal(ScalarValue::Utf8(Some(pattern)), _) = like.pattern.as_ref() else {
        panic!("Expected Utf8 literal pattern");
    };
    assert_eq!(pattern, "foo%bar%");

    // Case 3: leading wildcard, so there is no usable prefix for the index.
    let leading_wildcard = plan("color LIKE '%foo'");
    assert!(
        leading_wildcard.scalar_query.is_none(),
        "Pattern starting with wildcard should not use index"
    );
    assert!(
        leading_wildcard.refine_expr.is_some(),
        "Should fall back to refine"
    );
}
2961
/// Checks that `starts_with(object_id, 'test_ns$')`, after being run through
/// DataFusion's `ExprSimplifier`, is still planned as a `LikePrefix` index
/// query whose prefix is the full literal `test_ns$`.
///
/// NOTE(review): presumably the simplifier rewrites `starts_with` into a
/// `LIKE` expression, which is why the underscore could be mistaken for a
/// single-character wildcard; confirm against the simplifier's rules.
#[test]
fn test_starts_with_with_underscore_after_optimization() {
    let parser = SargableQueryParser::new("object_id_idx".to_string(), "BTree".to_string(), false);
    let index_info = MockIndexInfoProvider::new(vec![(
        "object_id",
        ColInfo::new(DataType::Utf8, Box::new(parser)),
    )]);

    let arrow_schema = Schema::new(vec![Field::new("object_id", DataType::Utf8, false)]);
    let df_schema: DFSchema = arrow_schema.try_into().unwrap();
    let session = get_session_context(&LanceExecutionOptions::default());
    let state = session.state();

    let raw_expr = state
        .create_logical_expr("starts_with(object_id, 'test_ns$')", &df_schema)
        .unwrap();

    // The schema is moved into the simplifier context only after parsing is done.
    let simplify_context = SimplifyContext::default()
        .with_schema(Arc::new(df_schema))
        .with_query_execution_start_time(Some(Utc::now()));
    let simplified_expr =
        datafusion::optimizer::simplify_expressions::ExprSimplifier::new(simplify_context)
            .simplify(raw_expr)
            .unwrap();

    let result = apply_scalar_indices(simplified_expr, &index_info).unwrap();

    // Peel the result one layer at a time; each mismatch fails with its own message.
    let Some(ScalarIndexExpr::Query(search)) = &result.scalar_query else {
        panic!("Expected scalar_query to be present");
    };
    let sargable = search
        .query
        .as_any()
        .downcast_ref::<SargableQuery>()
        .unwrap();
    let SargableQuery::LikePrefix(prefix) = sargable else {
        panic!("Expected LikePrefix query");
    };
    let ScalarValue::Utf8(Some(prefix_str)) = prefix else {
        panic!("Expected Utf8 prefix");
    };
    assert_eq!(
        prefix_str.as_str(),
        "test_ns$",
        "Prefix should be 'test_ns$', not 'test' (underscore should not be a wildcard)"
    );
}
3027
/// Checks that `starts_with(color, 'foo')` is planned by
/// `apply_scalar_indices` as `SargableQuery::LikePrefix("foo")` with no refine
/// expression, and that it produces exactly the same `SargableQuery` as
/// `color LIKE 'foo%'`.
#[test]
fn test_starts_with_to_like_conversion() {
    /// Extracts the `SargableQuery` from an index query. Panics when the
    /// expression is missing, is not a single `Query`, or does not wrap a
    /// `SargableQuery`, so the comparison below can never be skipped silently.
    fn expect_sargable(scalar_query: Option<&ScalarIndexExpr>) -> &SargableQuery {
        let Some(ScalarIndexExpr::Query(search)) = scalar_query else {
            panic!("Expected Query variant");
        };
        search
            .query
            .as_any()
            .downcast_ref::<SargableQuery>()
            .expect("Query should be SargableQuery")
    }

    let index_info = MockIndexInfoProvider::new(vec![(
        "color",
        ColInfo::new(
            DataType::Utf8,
            Box::new(SargableQueryParser::new(
                "color_idx".to_string(),
                "BTree".to_string(),
                false,
            )),
        ),
    )]);

    let schema = Schema::new(vec![Field::new("color", DataType::Utf8, false)]);
    let df_schema: DFSchema = schema.try_into().unwrap();
    let ctx = get_session_context(&LanceExecutionOptions::default());
    let state = ctx.state();

    // Plan the `starts_with` form.
    let expr = state
        .create_logical_expr("starts_with(color, 'foo')", &df_schema)
        .unwrap();
    let result = apply_scalar_indices(expr, &index_info).unwrap();

    assert!(
        result.scalar_query.is_some(),
        "starts_with should use index"
    );
    assert!(
        result.refine_expr.is_none(),
        "Pure prefix starts_with should not need refine_expr"
    );

    let sw_query = expect_sargable(result.scalar_query.as_ref());
    match sw_query {
        SargableQuery::LikePrefix(prefix) => {
            assert_eq!(prefix, &ScalarValue::Utf8(Some("foo".to_string())));
        }
        _ => panic!("Expected LikePrefix query"),
    }

    // Plan the equivalent `LIKE` form and require both to match.
    let like_expr = state
        .create_logical_expr("color LIKE 'foo%'", &df_schema)
        .unwrap();
    let like_result = apply_scalar_indices(like_expr, &index_info).unwrap();
    let like_query = expect_sargable(like_result.scalar_query.as_ref());

    assert_eq!(
        sw_query, like_query,
        "starts_with and LIKE 'prefix%' should produce identical queries"
    );
}
3105}