1use std::any::Any;
21use std::fmt::Debug;
22use std::hash::{Hash, Hasher};
23use std::sync::Arc;
24
25use crate::physical_expr::physical_exprs_bag_equal;
26use crate::PhysicalExpr;
27
28use arrow::array::types::{IntervalDayTime, IntervalMonthDayNano};
29use arrow::array::*;
30use arrow::buffer::BooleanBuffer;
31use arrow::compute::kernels::boolean::{not, or_kleene};
32use arrow::compute::take;
33use arrow::datatypes::*;
34use arrow::util::bit_iterator::BitIndexIterator;
35use arrow::{downcast_dictionary_array, downcast_primitive_array};
36use datafusion_common::cast::{
37 as_boolean_array, as_generic_binary_array, as_string_array,
38};
39use datafusion_common::hash_utils::HashValue;
40use datafusion_common::{
41 exec_err, internal_err, not_impl_err, DFSchema, Result, ScalarValue,
42};
43use datafusion_expr::ColumnarValue;
44use datafusion_physical_expr_common::datum::compare_with_eq;
45
46use ahash::RandomState;
47use datafusion_common::HashMap;
48use hashbrown::hash_map::RawEntryMut;
49
/// Physical expression for `expr [NOT] IN (list...)`.
pub struct InListExpr {
    /// Expression whose value is tested for membership.
    expr: Arc<dyn PhysicalExpr>,
    /// Candidate expressions the value is compared against.
    list: Vec<Arc<dyn PhysicalExpr>>,
    /// When `true` the expression is `NOT IN` rather than `IN`.
    negated: bool,
    /// Optional pre-built membership set; when present, `evaluate` probes it
    /// instead of comparing against every entry of `list`.
    static_filter: Option<Arc<dyn Set>>,
}
57
58impl Debug for InListExpr {
59 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
60 f.debug_struct("InListExpr")
61 .field("expr", &self.expr)
62 .field("list", &self.list)
63 .field("negated", &self.negated)
64 .finish()
65 }
66}
67
/// A membership set that arrays can be probed against.
pub trait Set: Send + Sync {
    /// Tests every element of `v` for membership and returns one boolean
    /// (or null) per element; `negated` selects `NOT IN` semantics.
    fn contains(&self, v: &dyn Array, negated: bool) -> Result<BooleanArray>;
    /// Returns `true` when the set itself holds at least one null.
    fn has_nulls(&self) -> bool;
}
73
/// Hash-based index over the values of an array.
struct ArrayHashSet {
    /// Hasher state used both when building the map and when probing it.
    state: RandomState,
    /// Stores *indices* into the backing array rather than the values.
    /// The map's own hasher is `()`: hashes are computed externally from the
    /// array values and supplied through the raw-entry API.
    map: HashMap<usize, (), ()>,
}
82
/// A [`Set`] backed by a concrete array type `T` plus a hash index over it.
struct ArraySet<T> {
    /// The list values; indices stored in `hash_set` refer to this array.
    array: T,
    /// Index of the non-null positions of `array`.
    hash_set: ArrayHashSet,
}
87
88impl<T> ArraySet<T>
89where
90 T: Array + From<ArrayData>,
91{
92 fn new(array: &T, hash_set: ArrayHashSet) -> Self {
93 Self {
94 array: downcast_array(array),
95 hash_set,
96 }
97 }
98}
99
100impl<T> Set for ArraySet<T>
101where
102 T: Array + 'static,
103 for<'a> &'a T: ArrayAccessor,
104 for<'a> <&'a T as ArrayAccessor>::Item: IsEqual,
105{
106 fn contains(&self, v: &dyn Array, negated: bool) -> Result<BooleanArray> {
107 downcast_dictionary_array! {
108 v => {
109 let values_contains = self.contains(v.values().as_ref(), negated)?;
110 let result = take(&values_contains, v.keys(), None)?;
111 return Ok(downcast_array(result.as_ref()))
112 }
113 _ => {}
114 }
115
116 let v = v.as_any().downcast_ref::<T>().unwrap();
117 let in_array = &self.array;
118 let has_nulls = in_array.null_count() != 0;
119
120 Ok(ArrayIter::new(v)
121 .map(|v| {
122 v.and_then(|v| {
123 let hash = v.hash_one(&self.hash_set.state);
124 let contains = self
125 .hash_set
126 .map
127 .raw_entry()
128 .from_hash(hash, |idx| in_array.value(*idx).is_equal(&v))
129 .is_some();
130
131 match contains {
132 true => Some(!negated),
133 false if has_nulls => None,
134 false => Some(negated),
135 }
136 })
137 })
138 .collect())
139 }
140
141 fn has_nulls(&self) -> bool {
142 self.array.null_count() != 0
143 }
144}
145
146fn make_hash_set<T>(array: T) -> ArrayHashSet
153where
154 T: ArrayAccessor,
155 T::Item: IsEqual,
156{
157 let state = RandomState::new();
158 let mut map: HashMap<usize, (), ()> =
159 HashMap::with_capacity_and_hasher(array.len(), ());
160
161 let insert_value = |idx| {
162 let value = array.value(idx);
163 let hash = value.hash_one(&state);
164 if let RawEntryMut::Vacant(v) = map
165 .raw_entry_mut()
166 .from_hash(hash, |x| array.value(*x).is_equal(&value))
167 {
168 v.insert_with_hasher(hash, idx, (), |x| array.value(*x).hash_one(&state));
169 }
170 };
171
172 match array.nulls() {
173 Some(nulls) => {
174 BitIndexIterator::new(nulls.validity(), nulls.offset(), nulls.len())
175 .for_each(insert_value)
176 }
177 None => (0..array.len()).for_each(insert_value),
178 }
179
180 ArrayHashSet { state, map }
181}
182
183fn make_set(array: &dyn Array) -> Result<Arc<dyn Set>> {
185 Ok(downcast_primitive_array! {
186 array => Arc::new(ArraySet::new(array, make_hash_set(array))),
187 DataType::Boolean => {
188 let array = as_boolean_array(array)?;
189 Arc::new(ArraySet::new(array, make_hash_set(array)))
190 },
191 DataType::Utf8 => {
192 let array = as_string_array(array)?;
193 Arc::new(ArraySet::new(array, make_hash_set(array)))
194 }
195 DataType::LargeUtf8 => {
196 let array = as_largestring_array(array);
197 Arc::new(ArraySet::new(array, make_hash_set(array)))
198 }
199 DataType::Binary => {
200 let array = as_generic_binary_array::<i32>(array)?;
201 Arc::new(ArraySet::new(array, make_hash_set(array)))
202 }
203 DataType::LargeBinary => {
204 let array = as_generic_binary_array::<i64>(array)?;
205 Arc::new(ArraySet::new(array, make_hash_set(array)))
206 }
207 DataType::Dictionary(_, _) => unreachable!("dictionary should have been flattened"),
208 d => return not_impl_err!("DataType::{d} not supported in InList")
209 })
210}
211
212fn evaluate_list(
214 list: &[Arc<dyn PhysicalExpr>],
215 batch: &RecordBatch,
216) -> Result<ArrayRef> {
217 let scalars = list
218 .iter()
219 .map(|expr| {
220 expr.evaluate(batch).and_then(|r| match r {
221 ColumnarValue::Array(_) => {
222 exec_err!("InList expression must evaluate to a scalar")
223 }
224 ColumnarValue::Scalar(ScalarValue::Dictionary(_, v)) => Ok(*v),
226 ColumnarValue::Scalar(s) => Ok(s),
227 })
228 })
229 .collect::<Result<Vec<_>>>()?;
230
231 ScalarValue::iter_to_array(scalars)
232}
233
234fn try_cast_static_filter_to_set(
235 list: &[Arc<dyn PhysicalExpr>],
236 schema: &Schema,
237) -> Result<Arc<dyn Set>> {
238 let batch = RecordBatch::new_empty(Arc::new(schema.clone()));
239 make_set(evaluate_list(list, &batch)?.as_ref())
240}
241
/// Equality used for set membership, paired with the hashing provided by
/// the `HashValue` supertrait.
trait IsEqual: HashValue {
    /// Returns `true` when `self` and `other` count as the same set element.
    fn is_equal(&self, other: &Self) -> bool;
}
246
247impl<T: IsEqual + ?Sized> IsEqual for &T {
248 fn is_equal(&self, other: &Self) -> bool {
249 T::is_equal(self, other)
250 }
251}
252
253macro_rules! is_equal {
254 ($($t:ty),+) => {
255 $(impl IsEqual for $t {
256 fn is_equal(&self, other: &Self) -> bool {
257 self == other
258 }
259 })*
260 };
261}
262is_equal!(i8, i16, i32, i64, i128, i256, u8, u16, u32, u64);
263is_equal!(bool, str, [u8]);
264is_equal!(IntervalDayTime, IntervalMonthDayNano);
265
266macro_rules! is_equal_float {
267 ($($t:ty),+) => {
268 $(impl IsEqual for $t {
269 fn is_equal(&self, other: &Self) -> bool {
270 self.to_bits() == other.to_bits()
271 }
272 })*
273 };
274}
275is_equal_float!(half::f16, f32, f64);
276
impl InListExpr {
    /// Create a new InList expression from its parts.
    ///
    /// `static_filter` is stored as given; this constructor does not build
    /// or validate it.
    pub fn new(
        expr: Arc<dyn PhysicalExpr>,
        list: Vec<Arc<dyn PhysicalExpr>>,
        negated: bool,
        static_filter: Option<Arc<dyn Set>>,
    ) -> Self {
        Self {
            expr,
            list,
            negated,
            static_filter,
        }
    }

    /// Input expression
    pub fn expr(&self) -> &Arc<dyn PhysicalExpr> {
        &self.expr
    }

    /// List to search in
    pub fn list(&self) -> &[Arc<dyn PhysicalExpr>] {
        &self.list
    }

    /// Is this negated e.g. NOT IN LIST
    pub fn negated(&self) -> bool {
        self.negated
    }
}
308
/// Formats every element of an iterable (anything with `.iter()` whose items
/// implement `Display`) and joins the results with `", "`.
#[macro_export]
macro_rules! expr_vec_fmt {
    ( $ARRAY:expr ) => {{
        let rendered = $ARRAY
            .iter()
            .map(|item| item.to_string())
            .collect::<Vec<String>>();
        rendered.join(", ")
    }};
}
319
320impl std::fmt::Display for InListExpr {
321 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
322 let list = expr_vec_fmt!(self.list);
323
324 if self.negated {
325 if self.static_filter.is_some() {
326 write!(f, "{} NOT IN (SET) ([{list}])", self.expr)
327 } else {
328 write!(f, "{} NOT IN ([{list}])", self.expr)
329 }
330 } else if self.static_filter.is_some() {
331 write!(f, "{} IN (SET) ([{list}])", self.expr)
332 } else {
333 write!(f, "{} IN ([{list}])", self.expr)
334 }
335 }
336}
337
338impl PhysicalExpr for InListExpr {
339 fn as_any(&self) -> &dyn Any {
341 self
342 }
343
344 fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
345 Ok(DataType::Boolean)
346 }
347
348 fn nullable(&self, input_schema: &Schema) -> Result<bool> {
349 if self.expr.nullable(input_schema)? {
350 return Ok(true);
351 }
352
353 if let Some(static_filter) = &self.static_filter {
354 Ok(static_filter.has_nulls())
355 } else {
356 for expr in &self.list {
357 if expr.nullable(input_schema)? {
358 return Ok(true);
359 }
360 }
361 Ok(false)
362 }
363 }
364
365 fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
366 let num_rows = batch.num_rows();
367 let value = self.expr.evaluate(batch)?;
368 let r = match &self.static_filter {
369 Some(f) => f.contains(value.into_array(num_rows)?.as_ref(), self.negated)?,
370 None => {
371 let value = value.into_array(num_rows)?;
372 let is_nested = value.data_type().is_nested();
373 let found = self.list.iter().map(|expr| expr.evaluate(batch)).try_fold(
374 BooleanArray::new(BooleanBuffer::new_unset(num_rows), None),
375 |result, expr| -> Result<BooleanArray> {
376 let rhs = compare_with_eq(
377 &value,
378 &expr?.into_array(num_rows)?,
379 is_nested,
380 )?;
381 Ok(or_kleene(&result, &rhs)?)
382 },
383 )?;
384
385 if self.negated {
386 not(&found)?
387 } else {
388 found
389 }
390 }
391 };
392 Ok(ColumnarValue::Array(Arc::new(r)))
393 }
394
395 fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
396 let mut children = vec![];
397 children.push(&self.expr);
398 children.extend(&self.list);
399 children
400 }
401
402 fn with_new_children(
403 self: Arc<Self>,
404 children: Vec<Arc<dyn PhysicalExpr>>,
405 ) -> Result<Arc<dyn PhysicalExpr>> {
406 Ok(Arc::new(InListExpr::new(
408 Arc::clone(&children[0]),
409 children[1..].to_vec(),
410 self.negated,
411 self.static_filter.clone(),
412 )))
413 }
414
415 fn fmt_sql(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
416 self.expr.fmt_sql(f)?;
417 if self.negated {
418 write!(f, " NOT")?;
419 }
420
421 write!(f, " IN (")?;
422 for (i, expr) in self.list.iter().enumerate() {
423 if i > 0 {
424 write!(f, ", ")?;
425 }
426 expr.fmt_sql(f)?;
427 }
428 write!(f, ")")
429 }
430}
431
432impl PartialEq for InListExpr {
433 fn eq(&self, other: &Self) -> bool {
434 self.expr.eq(&other.expr)
435 && physical_exprs_bag_equal(&self.list, &other.list)
436 && self.negated == other.negated
437 }
438}
439
440impl Eq for InListExpr {}
441
442impl Hash for InListExpr {
443 fn hash<H: Hasher>(&self, state: &mut H) {
444 self.expr.hash(state);
445 self.negated.hash(state);
446 self.list.hash(state);
447 }
449}
450
451pub fn in_list(
453 expr: Arc<dyn PhysicalExpr>,
454 list: Vec<Arc<dyn PhysicalExpr>>,
455 negated: &bool,
456 schema: &Schema,
457) -> Result<Arc<dyn PhysicalExpr>> {
458 let expr_data_type = expr.data_type(schema)?;
460 for list_expr in list.iter() {
461 let list_expr_data_type = list_expr.data_type(schema)?;
462 if !DFSchema::datatype_is_logically_equal(&expr_data_type, &list_expr_data_type) {
463 return internal_err!(
464 "The data type inlist should be same, the value type is {expr_data_type}, one of list expr type is {list_expr_data_type}"
465 );
466 }
467 }
468 let static_filter = try_cast_static_filter_to_set(&list, schema).ok();
469 Ok(Arc::new(InListExpr::new(
470 expr,
471 list,
472 *negated,
473 static_filter,
474 )))
475}
476
#[cfg(test)]
mod tests {
    use super::*;
    use crate::expressions;
    use crate::expressions::{col, lit, try_cast};
    use datafusion_common::plan_err;
    use datafusion_expr::type_coercion::binary::comparison_coercion;
    use datafusion_physical_expr_common::physical_expr::fmt_sql;
    use insta::assert_snapshot;
    use itertools::Itertools as _;

    type InListCastResult = (Arc<dyn PhysicalExpr>, Vec<Arc<dyn PhysicalExpr>>);

    /// One scenario: the list entries, the `negated` flag and the expected
    /// per-row outcome.
    type InListCase = (Vec<Arc<dyn PhysicalExpr>>, bool, Vec<Option<bool>>);

    /// Coerces `expr` and every entry of `list` to one common comparison
    /// type (found with `get_coerce_type`) by wrapping them in `try_cast`.
    fn in_list_cast(
        expr: Arc<dyn PhysicalExpr>,
        list: Vec<Arc<dyn PhysicalExpr>>,
        input_schema: &Schema,
    ) -> Result<InListCastResult> {
        let expr_type = &expr.data_type(input_schema)?;
        let mut list_types: Vec<DataType> = Vec::with_capacity(list.len());
        for list_expr in &list {
            list_types.push(list_expr.data_type(input_schema).unwrap());
        }

        let data_type = match get_coerce_type(expr_type, &list_types) {
            Some(data_type) => data_type,
            None => {
                return plan_err!(
                    "Can not find compatible types to compare {expr_type} with [{}]",
                    list_types.iter().join(", ")
                )
            }
        };

        // find the coerced type
        let cast_expr = try_cast(expr, input_schema, data_type.clone())?;
        let mut cast_list_expr = Vec::with_capacity(list.len());
        for list_expr in list {
            cast_list_expr
                .push(try_cast(list_expr, input_schema, data_type.clone()).unwrap());
        }
        Ok((cast_expr, cast_list_expr))
    }

    /// Folds `comparison_coercion` over `expr_type` and every list type;
    /// `None` as soon as one step has no common type.
    fn get_coerce_type(expr_type: &DataType, list_type: &[DataType]) -> Option<DataType> {
        let mut coerced = expr_type.clone();
        for candidate in list_type {
            coerced = comparison_coercion(&coerced, candidate)?;
        }
        Some(coerced)
    }

    // Coerces the column and list to a common type, then checks the result
    // of evaluating the IN-list against the expected values.
    macro_rules! in_list {
        ($BATCH:expr, $LIST:expr, $NEGATED:expr, $EXPECTED:expr, $COL:expr, $SCHEMA:expr) => {{
            let (coerced_expr, coerced_list) = in_list_cast($COL, $LIST, $SCHEMA)?;
            in_list_raw!(
                $BATCH,
                coerced_list,
                $NEGATED,
                $EXPECTED,
                coerced_expr,
                $SCHEMA
            );
        }};
    }

    // Builds the IN-list without any coercion, evaluates it on the batch and
    // compares the boolean output with the expected values.
    macro_rules! in_list_raw {
        ($BATCH:expr, $LIST:expr, $NEGATED:expr, $EXPECTED:expr, $COL:expr, $SCHEMA:expr) => {{
            let in_list_expr = in_list($COL, $LIST, $NEGATED, $SCHEMA).unwrap();
            let evaluated = in_list_expr
                .evaluate(&$BATCH)?
                .into_array($BATCH.num_rows())
                .expect("Failed to convert to array");
            let actual = as_boolean_array(&evaluated)
                .expect("failed to downcast to BooleanArray");
            let expected = &BooleanArray::from($EXPECTED);
            assert_eq!(expected, actual);
        }};
    }

    #[test]
    fn in_list_utf8() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]);
        let col_a = col("a", &schema)?;
        let values = StringArray::from(vec![Some("a"), Some("d"), None]);
        let batch =
            RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(values)])?;

        let plain = vec![lit("a"), lit("b")];
        let with_null = vec![lit("a"), lit("b"), lit(ScalarValue::Utf8(None))];
        let cases: Vec<InListCase> = vec![
            // a IN ("a", "b")
            (plain.clone(), false, vec![Some(true), Some(false), None]),
            // a NOT IN ("a", "b")
            (plain, true, vec![Some(false), Some(true), None]),
            // a IN ("a", "b", NULL)
            (with_null.clone(), false, vec![Some(true), None, None]),
            // a NOT IN ("a", "b", NULL)
            (with_null, true, vec![Some(false), None, None]),
        ];
        for (list, negated, expected) in cases {
            in_list!(batch, list, &negated, expected, Arc::clone(&col_a), &schema);
        }

        Ok(())
    }

    #[test]
    fn in_list_binary() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::Binary, true)]);
        let col_a = col("a", &schema)?;
        let values = BinaryArray::from(vec![
            Some([1, 2, 3].as_slice()),
            Some([1, 2, 2].as_slice()),
            None,
        ]);
        let batch =
            RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(values)])?;

        let plain = vec![lit([1, 2, 3].as_slice()), lit([4, 5, 6].as_slice())];
        let with_null = vec![
            lit([1, 2, 3].as_slice()),
            lit([4, 5, 6].as_slice()),
            lit(ScalarValue::Binary(None)),
        ];
        let cases: Vec<InListCase> = vec![
            // a IN ([1, 2, 3], [4, 5, 6])
            (plain.clone(), false, vec![Some(true), Some(false), None]),
            // a NOT IN ([1, 2, 3], [4, 5, 6])
            (plain, true, vec![Some(false), Some(true), None]),
            // a IN ([1, 2, 3], [4, 5, 6], NULL)
            (with_null.clone(), false, vec![Some(true), None, None]),
            // a NOT IN ([1, 2, 3], [4, 5, 6], NULL)
            (with_null, true, vec![Some(false), None, None]),
        ];
        for (list, negated, expected) in cases {
            in_list!(batch, list, &negated, expected, Arc::clone(&col_a), &schema);
        }

        Ok(())
    }

    #[test]
    fn in_list_int64() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::Int64, true)]);
        let col_a = col("a", &schema)?;
        let values = Int64Array::from(vec![Some(0), Some(2), None]);
        let batch =
            RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(values)])?;

        let plain = vec![lit(0i64), lit(1i64)];
        let with_null = vec![lit(0i64), lit(1i64), lit(ScalarValue::Null)];
        let cases: Vec<InListCase> = vec![
            // a IN (0, 1)
            (plain.clone(), false, vec![Some(true), Some(false), None]),
            // a NOT IN (0, 1)
            (plain, true, vec![Some(false), Some(true), None]),
            // a IN (0, 1, NULL)
            (with_null.clone(), false, vec![Some(true), None, None]),
            // a NOT IN (0, 1, NULL)
            (with_null, true, vec![Some(false), None, None]),
        ];
        for (list, negated, expected) in cases {
            in_list!(batch, list, &negated, expected, Arc::clone(&col_a), &schema);
        }

        Ok(())
    }

    #[test]
    fn in_list_float64() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::Float64, true)]);
        let col_a = col("a", &schema)?;
        let values = Float64Array::from(vec![
            Some(0.0),
            Some(0.2),
            None,
            Some(f64::NAN),
            Some(-f64::NAN),
        ]);
        let batch =
            RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(values)])?;

        // (0.0, 0.1) optionally followed by one extra entry.
        let list_with =
            |extra: Option<Arc<dyn PhysicalExpr>>| -> Vec<Arc<dyn PhysicalExpr>> {
                let mut items = vec![lit(0.0f64), lit(0.1f64)];
                items.extend(extra);
                items
            };

        let cases: Vec<InListCase> = vec![
            // a IN (0.0, 0.1)
            (
                list_with(None),
                false,
                vec![Some(true), Some(false), None, Some(false), Some(false)],
            ),
            // a NOT IN (0.0, 0.1)
            (
                list_with(None),
                true,
                vec![Some(false), Some(true), None, Some(true), Some(true)],
            ),
            // a IN (0.0, 0.1, NULL)
            (
                list_with(Some(lit(ScalarValue::Null))),
                false,
                vec![Some(true), None, None, None, None],
            ),
            // a NOT IN (0.0, 0.1, NULL)
            (
                list_with(Some(lit(ScalarValue::Null))),
                true,
                vec![Some(false), None, None, None, None],
            ),
            // a IN (0.0, 0.1, NaN)
            (
                list_with(Some(lit(f64::NAN))),
                false,
                vec![Some(true), Some(false), None, Some(true), Some(false)],
            ),
            // a NOT IN (0.0, 0.1, NaN)
            (
                list_with(Some(lit(f64::NAN))),
                true,
                vec![Some(false), Some(true), None, Some(false), Some(true)],
            ),
            // a IN (0.0, 0.1, -NaN)
            (
                list_with(Some(lit(-f64::NAN))),
                false,
                vec![Some(true), Some(false), None, Some(false), Some(true)],
            ),
            // a NOT IN (0.0, 0.1, -NaN)
            (
                list_with(Some(lit(-f64::NAN))),
                true,
                vec![Some(false), Some(true), None, Some(true), Some(false)],
            ),
        ];
        for (list, negated, expected) in cases {
            in_list!(batch, list, &negated, expected, Arc::clone(&col_a), &schema);
        }

        Ok(())
    }

    #[test]
    fn in_list_bool() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::Boolean, true)]);
        let col_a = col("a", &schema)?;
        let values = BooleanArray::from(vec![Some(true), None]);
        let batch =
            RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(values)])?;

        let plain = vec![lit(true)];
        let with_null = vec![lit(true), lit(ScalarValue::Null)];
        let cases: Vec<InListCase> = vec![
            // a IN (true)
            (plain.clone(), false, vec![Some(true), None]),
            // a NOT IN (true)
            (plain, true, vec![Some(false), None]),
            // a IN (true, NULL)
            (with_null.clone(), false, vec![Some(true), None]),
            // a NOT IN (true, NULL)
            (with_null, true, vec![Some(false), None]),
        ];
        for (list, negated, expected) in cases {
            in_list!(batch, list, &negated, expected, Arc::clone(&col_a), &schema);
        }

        Ok(())
    }

    #[test]
    fn in_list_date64() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::Date64, true)]);
        let col_a = col("a", &schema)?;
        let values = Date64Array::from(vec![Some(0), Some(2), None]);
        let batch =
            RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(values)])?;

        let plain = vec![
            lit(ScalarValue::Date64(Some(0))),
            lit(ScalarValue::Date64(Some(1))),
        ];
        let with_null = vec![
            lit(ScalarValue::Date64(Some(0))),
            lit(ScalarValue::Date64(Some(1))),
            lit(ScalarValue::Null),
        ];
        let cases: Vec<InListCase> = vec![
            // a IN (0, 1)
            (plain.clone(), false, vec![Some(true), Some(false), None]),
            // a NOT IN (0, 1)
            (plain, true, vec![Some(false), Some(true), None]),
            // a IN (0, 1, NULL)
            (with_null.clone(), false, vec![Some(true), None, None]),
            // a NOT IN (0, 1, NULL)
            (with_null, true, vec![Some(false), None, None]),
        ];
        for (list, negated, expected) in cases {
            in_list!(batch, list, &negated, expected, Arc::clone(&col_a), &schema);
        }

        Ok(())
    }

    #[test]
    fn in_list_date32() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::Date32, true)]);
        let col_a = col("a", &schema)?;
        let values = Date32Array::from(vec![Some(0), Some(2), None]);
        let batch =
            RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(values)])?;

        let plain = vec![
            lit(ScalarValue::Date32(Some(0))),
            lit(ScalarValue::Date32(Some(1))),
        ];
        let with_null = vec![
            lit(ScalarValue::Date32(Some(0))),
            lit(ScalarValue::Date32(Some(1))),
            lit(ScalarValue::Null),
        ];
        let cases: Vec<InListCase> = vec![
            // a IN (0, 1)
            (plain.clone(), false, vec![Some(true), Some(false), None]),
            // a NOT IN (0, 1)
            (plain, true, vec![Some(false), Some(true), None]),
            // a IN (0, 1, NULL)
            (with_null.clone(), false, vec![Some(true), None, None]),
            // a NOT IN (0, 1, NULL)
            (with_null, true, vec![Some(false), None, None]),
        ];
        for (list, negated, expected) in cases {
            in_list!(batch, list, &negated, expected, Arc::clone(&col_a), &schema);
        }

        Ok(())
    }

    #[test]
    fn in_list_decimal() -> Result<()> {
        // Decimal128(13, 4) column probed with integer and float literals.
        let schema =
            Schema::new(vec![Field::new("a", DataType::Decimal128(13, 4), true)]);
        let array = vec![Some(100_0000_i128), None, Some(200_5000_i128)]
            .into_iter()
            .collect::<Decimal128Array>()
            .with_precision_and_scale(13, 4)
            .unwrap();
        let col_a = col("a", &schema)?;
        let batch =
            RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(array)])?;

        let ints = vec![lit(100i32), lit(200i32)];
        let int_and_null =
            vec![lit(ScalarValue::Int32(Some(100))), lit(ScalarValue::Null)];
        // A list long enough to contain 100 but not 200.5.
        let long_list = (99i32..300).map(lit).collect::<Vec<_>>();

        let cases: Vec<InListCase> = vec![
            // a IN (100, 200)
            (ints.clone(), false, vec![Some(true), None, Some(false)]),
            // a NOT IN (100, 200)
            (ints, true, vec![Some(false), None, Some(true)]),
            // a IN (100, NULL)
            (int_and_null.clone(), false, vec![Some(true), None, None]),
            // a NOT IN (100, NULL)
            (int_and_null, true, vec![Some(false), None, None]),
            // a IN (200.5, 100)
            (
                vec![lit(200.50f32), lit(100i32)],
                false,
                vec![Some(true), None, Some(true)],
            ),
            // a NOT IN (200.5, 101)
            (
                vec![lit(200.50f32), lit(101i32)],
                true,
                vec![Some(true), None, Some(false)],
            ),
            // a IN (99..300)
            (long_list.clone(), false, vec![Some(true), None, Some(false)]),
            // a NOT IN (99..300)
            (long_list, true, vec![Some(false), None, Some(true)]),
        ];
        for (list, negated, expected) in cases {
            in_list!(batch, list, &negated, expected, Arc::clone(&col_a), &schema);
        }

        Ok(())
    }

    #[test]
    fn test_cast_static_filter_to_set() -> Result<()> {
        let schema =
            Schema::new(vec![Field::new("a", DataType::Decimal128(13, 4), true)]);

        // cast(cast(2i32 AS inner) AS outer)
        let nested_cast =
            |inner: DataType, outer: DataType| -> Result<Arc<dyn PhysicalExpr>> {
                let innermost = expressions::cast(lit(2i32), &schema, inner)?;
                expressions::cast(innermost, &schema, outer)
            };

        // A literal, a cast and a try_cast: all evaluate to scalars.
        let mut phy_exprs = vec![
            lit(1i64),
            expressions::cast(lit(2i32), &schema, DataType::Int64)?,
            try_cast(lit(3.13f32), &schema, DataType::Int64)?,
        ];
        let set = try_cast_static_filter_to_set(&phy_exprs, &schema).unwrap();

        let probe = Int64Array::from(vec![1, 2, 3, 4]);
        let found = set.contains(&probe, false).unwrap();
        assert_eq!(found, BooleanArray::from(vec![true, true, true, false]));

        try_cast_static_filter_to_set(&phy_exprs, &schema).unwrap();

        // Nested cast where both layers target the same type.
        phy_exprs.push(nested_cast(DataType::Int64, DataType::Int64)?);
        try_cast_static_filter_to_set(&phy_exprs, &schema).unwrap();

        // Nested cast where the layers target different types.
        phy_exprs = vec![nested_cast(DataType::Int64, DataType::Int32)?];
        try_cast_static_filter_to_set(&phy_exprs, &schema).unwrap();

        // A column reference cannot be turned into a static set.
        phy_exprs.push(col("a", &schema)?);
        assert!(try_cast_static_filter_to_set(&phy_exprs, &schema).is_err());

        Ok(())
    }

    #[test]
    fn in_list_timestamp() -> Result<()> {
        let schema = Schema::new(vec![Field::new(
            "a",
            DataType::Timestamp(TimeUnit::Microsecond, None),
            true,
        )]);
        let col_a = col("a", &schema)?;
        let values = TimestampMicrosecondArray::from(vec![
            Some(1388588401000000000),
            Some(1288588501000000000),
            None,
        ]);
        let batch =
            RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(values)])?;

        let list: Vec<Arc<dyn PhysicalExpr>> = [
            1388588401000000000,
            1388588401000000001,
            1388588401000000002,
        ]
        .into_iter()
        .map(|micros| lit(ScalarValue::TimestampMicrosecond(Some(micros), None)))
        .collect();

        let cases: Vec<InListCase> = vec![
            (list.clone(), false, vec![Some(true), Some(false), None]),
            (list.clone(), true, vec![Some(false), Some(true), None]),
        ];
        for (list, negated, expected) in cases {
            in_list!(batch, list, &negated, expected, Arc::clone(&col_a), &schema);
        }
        Ok(())
    }

    #[test]
    fn in_expr_with_multiple_element_in_list() -> Result<()> {
        let schema = Schema::new(vec![
            Field::new("a", DataType::Float64, true),
            Field::new("b", DataType::Float64, true),
            Field::new("c", DataType::Float64, true),
        ]);
        let col_a = col("a", &schema)?;
        let col_b = col("b", &schema)?;
        let col_c = col("c", &schema)?;

        let a = Float64Array::from(vec![
            Some(0.0),
            Some(1.0),
            Some(2.0),
            Some(f64::NAN),
            Some(-f64::NAN),
        ]);
        let b = Float64Array::from(vec![
            Some(8.0),
            Some(1.0),
            Some(5.0),
            Some(f64::NAN),
            Some(3.0),
        ]);
        let c = Float64Array::from(vec![
            Some(6.0),
            Some(7.0),
            None,
            Some(5.0),
            Some(-f64::NAN),
        ]);
        let batch = RecordBatch::try_new(
            Arc::new(schema.clone()),
            vec![Arc::new(a), Arc::new(b), Arc::new(c)],
        )?;

        // The list entries are columns here, so no static set can be built.
        let list = vec![Arc::clone(&col_b), Arc::clone(&col_c)];
        let cases: Vec<InListCase> = vec![
            // a IN (b, c)
            (
                list.clone(),
                false,
                vec![Some(false), Some(true), None, Some(true), Some(true)],
            ),
            // a NOT IN (b, c)
            (
                list,
                true,
                vec![Some(true), Some(false), None, Some(false), Some(false)],
            ),
        ];
        for (list, negated, expected) in cases {
            in_list!(batch, list, &negated, expected, Arc::clone(&col_a), &schema);
        }

        Ok(())
    }

    // Builds a (non-negated) IN-list after coercion and checks `nullable`.
    macro_rules! test_nullable {
        ($COL:expr, $LIST:expr, $SCHEMA:expr, $EXPECTED:expr) => {{
            let (coerced_expr, coerced_list) = in_list_cast($COL, $LIST, $SCHEMA)?;
            let in_list_expr =
                in_list(coerced_expr, coerced_list, &false, $SCHEMA).unwrap();
            let actual = in_list_expr.nullable($SCHEMA)?;
            assert_eq!($EXPECTED, actual);
        }};
    }

    #[test]
    fn in_list_nullable() -> Result<()> {
        let schema = Schema::new(vec![
            Field::new("c1_nullable", DataType::Int64, true),
            Field::new("c2_non_nullable", DataType::Int64, false),
        ]);

        let c1_nullable = col("c1_nullable", &schema)?;
        let c2_non_nullable = col("c2_non_nullable", &schema)?;

        let ints = vec![lit(1_i64), lit(2_i64)];
        let ints_and_null = vec![lit(1_i64), lit(2_i64), lit(ScalarValue::Null)];

        let cases: Vec<(Arc<dyn PhysicalExpr>, Vec<Arc<dyn PhysicalExpr>>, bool)> = vec![
            // static_filter has no nulls
            (Arc::clone(&c1_nullable), ints.clone(), true),
            (Arc::clone(&c2_non_nullable), ints, false),
            // static_filter has nulls
            (Arc::clone(&c1_nullable), ints_and_null.clone(), true),
            (Arc::clone(&c2_non_nullable), ints_and_null, true),
            // list made of columns
            (
                Arc::clone(&c2_non_nullable),
                vec![Arc::clone(&c1_nullable)],
                true,
            ),
            (
                Arc::clone(&c1_nullable),
                vec![Arc::clone(&c2_non_nullable)],
                true,
            ),
            (
                Arc::clone(&c2_non_nullable),
                vec![Arc::clone(&c2_non_nullable), Arc::clone(&c2_non_nullable)],
                false,
            ),
        ];
        for (column, list, expected) in cases {
            test_nullable!(column, list, &schema, expected);
        }

        Ok(())
    }

    #[test]
    fn in_list_no_cols() -> Result<()> {
        // The probed expression is a literal, not a column of the batch.
        let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);
        let values = Int32Array::from(vec![Some(1), Some(2), None]);
        let batch =
            RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(values)])?;

        let list = vec![lit(ScalarValue::from(1i32)), lit(ScalarValue::from(6i32))];

        let cases = vec![
            // 1 IN (1, 6)
            (
                ScalarValue::Int32(Some(1)),
                vec![Some(true), Some(true), Some(true)],
            ),
            // 2 IN (1, 6)
            (
                ScalarValue::Int32(Some(2)),
                vec![Some(false), Some(false), Some(false)],
            ),
            // NULL IN (1, 6)
            (ScalarValue::Int32(None), vec![None, None, None]),
        ];
        for (probe, expected) in cases {
            in_list!(batch, list.clone(), &false, expected, lit(probe), &schema);
        }

        Ok(())
    }

    #[test]
    fn in_list_utf8_with_dict_types() -> Result<()> {
        fn dict_lit(key_type: DataType, value: &str) -> Arc<dyn PhysicalExpr> {
            let inner = ScalarValue::new_utf8(value.to_string());
            lit(ScalarValue::Dictionary(Box::new(key_type), Box::new(inner)))
        }

        fn null_dict_lit(key_type: DataType) -> Arc<dyn PhysicalExpr> {
            let inner = ScalarValue::Utf8(None);
            lit(ScalarValue::Dictionary(Box::new(key_type), Box::new(inner)))
        }

        let schema = Schema::new(vec![Field::new(
            "a",
            DataType::Dictionary(Box::new(DataType::UInt16), Box::new(DataType::Utf8)),
            true,
        )]);
        let values: UInt16DictionaryArray =
            vec![Some("a"), Some("d"), None].into_iter().collect();
        let col_a = col("a", &schema)?;
        let batch =
            RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(values)])?;

        // Each group holds the same logical list once as plain Utf8 literals
        // and once as dictionary literals.
        let plain_lists = [
            vec![lit("a"), lit("b")],
            vec![
                dict_lit(DataType::Int8, "a"),
                dict_lit(DataType::UInt16, "b"),
            ],
        ];
        let null_lists = [
            vec![lit("a"), lit("b"), lit(ScalarValue::Utf8(None))],
            vec![
                dict_lit(DataType::Int8, "a"),
                dict_lit(DataType::UInt16, "b"),
                null_dict_lit(DataType::UInt16),
            ],
        ];

        let scenarios = [
            // a IN ("a", "b")
            (&plain_lists, false, vec![Some(true), Some(false), None]),
            // a NOT IN ("a", "b")
            (&plain_lists, true, vec![Some(false), Some(true), None]),
            // a IN ("a", "b", NULL)
            (&null_lists, false, vec![Some(true), None, None]),
            // a NOT IN ("a", "b", NULL)
            (&null_lists, true, vec![Some(false), None, None]),
        ];
        for (lists, negated, expected) in scenarios {
            for list in lists.iter() {
                in_list_raw!(
                    batch,
                    list.clone(),
                    &negated,
                    expected.clone(),
                    Arc::clone(&col_a),
                    &schema
                );
            }
        }

        Ok(())
    }

    /// Builds `a [NOT] IN (list)` over a nullable Utf8 column `a`.
    fn utf8_in_list(
        list: Vec<Arc<dyn PhysicalExpr>>,
        negated: bool,
    ) -> Result<Arc<dyn PhysicalExpr>> {
        let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]);
        let col_a = col("a", &schema)?;
        in_list(col_a, list, &negated, &schema)
    }

    #[test]
    fn test_fmt_sql_1() -> Result<()> {
        // a IN ("a", "b")
        let expr = utf8_in_list(vec![lit("a"), lit("b")], false)?;
        let sql = fmt_sql(expr.as_ref()).to_string();
        let display = expr.to_string();
        assert_snapshot!(sql, @"a IN (a, b)");
        assert_snapshot!(display, @"a@0 IN (SET) ([a, b])");
        Ok(())
    }

    #[test]
    fn test_fmt_sql_2() -> Result<()> {
        // a NOT IN ("a", "b")
        let expr = utf8_in_list(vec![lit("a"), lit("b")], true)?;
        let sql = fmt_sql(expr.as_ref()).to_string();
        let display = expr.to_string();
        assert_snapshot!(sql, @"a NOT IN (a, b)");
        assert_snapshot!(display, @"a@0 NOT IN (SET) ([a, b])");
        Ok(())
    }

    #[test]
    fn test_fmt_sql_3() -> Result<()> {
        // a IN ("a", "b", NULL)
        let list = vec![lit("a"), lit("b"), lit(ScalarValue::Utf8(None))];
        let expr = utf8_in_list(list, false)?;
        let sql = fmt_sql(expr.as_ref()).to_string();
        let display = expr.to_string();
        assert_snapshot!(sql, @"a IN (a, b, NULL)");
        assert_snapshot!(display, @"a@0 IN (SET) ([a, b, NULL])");
        Ok(())
    }

    #[test]
    fn test_fmt_sql_4() -> Result<()> {
        // a NOT IN ("a", "b", NULL)
        let list = vec![lit("a"), lit("b"), lit(ScalarValue::Utf8(None))];
        let expr = utf8_in_list(list, true)?;
        let sql = fmt_sql(expr.as_ref()).to_string();
        let display = expr.to_string();
        assert_snapshot!(sql, @"a NOT IN (a, b, NULL)");
        assert_snapshot!(display, @"a@0 NOT IN (SET) ([a, b, NULL])");
        Ok(())
    }
}