1use std::any::Any;
21use std::fmt::Debug;
22use std::hash::{Hash, Hasher};
23use std::sync::Arc;
24
25use crate::physical_expr::physical_exprs_bag_equal;
26use crate::PhysicalExpr;
27
28use arrow::array::types::{IntervalDayTime, IntervalMonthDayNano};
29use arrow::array::*;
30use arrow::buffer::BooleanBuffer;
31use arrow::compute::kernels::boolean::{not, or_kleene};
32use arrow::compute::take;
33use arrow::datatypes::*;
34use arrow::util::bit_iterator::BitIndexIterator;
35use arrow::{downcast_dictionary_array, downcast_primitive_array};
36use datafusion_common::cast::{
37 as_boolean_array, as_generic_binary_array, as_string_array,
38};
39use datafusion_common::hash_utils::HashValue;
40use datafusion_common::{
41 exec_err, internal_err, not_impl_err, DFSchema, Result, ScalarValue,
42};
43use datafusion_expr::ColumnarValue;
44use datafusion_physical_expr_common::datum::compare_with_eq;
45
46use ahash::RandomState;
47use datafusion_common::HashMap;
48use hashbrown::hash_map::RawEntryMut;
49
50pub struct InListExpr {
52 expr: Arc<dyn PhysicalExpr>,
53 list: Vec<Arc<dyn PhysicalExpr>>,
54 negated: bool,
55 static_filter: Option<Arc<dyn Set>>,
56}
57
58impl Debug for InListExpr {
59 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
60 f.debug_struct("InListExpr")
61 .field("expr", &self.expr)
62 .field("list", &self.list)
63 .field("negated", &self.negated)
64 .finish()
65 }
66}
67
68pub trait Set: Send + Sync {
70 fn contains(&self, v: &dyn Array, negated: bool) -> Result<BooleanArray>;
71 fn has_nulls(&self) -> bool;
72}
73
74struct ArrayHashSet {
75 state: RandomState,
76 map: HashMap<usize, (), ()>,
81}
82
83struct ArraySet<T> {
84 array: T,
85 hash_set: ArrayHashSet,
86}
87
88impl<T> ArraySet<T>
89where
90 T: Array + From<ArrayData>,
91{
92 fn new(array: &T, hash_set: ArrayHashSet) -> Self {
93 Self {
94 array: downcast_array(array),
95 hash_set,
96 }
97 }
98}
99
100impl<T> Set for ArraySet<T>
101where
102 T: Array + 'static,
103 for<'a> &'a T: ArrayAccessor,
104 for<'a> <&'a T as ArrayAccessor>::Item: IsEqual,
105{
106 fn contains(&self, v: &dyn Array, negated: bool) -> Result<BooleanArray> {
107 downcast_dictionary_array! {
108 v => {
109 let values_contains = self.contains(v.values().as_ref(), negated)?;
110 let result = take(&values_contains, v.keys(), None)?;
111 return Ok(downcast_array(result.as_ref()))
112 }
113 _ => {}
114 }
115
116 let v = v.as_any().downcast_ref::<T>().unwrap();
117 let in_array = &self.array;
118 let has_nulls = in_array.null_count() != 0;
119
120 Ok(ArrayIter::new(v)
121 .map(|v| {
122 v.and_then(|v| {
123 let hash = v.hash_one(&self.hash_set.state);
124 let contains = self
125 .hash_set
126 .map
127 .raw_entry()
128 .from_hash(hash, |idx| in_array.value(*idx).is_equal(&v))
129 .is_some();
130
131 match contains {
132 true => Some(!negated),
133 false if has_nulls => None,
134 false => Some(negated),
135 }
136 })
137 })
138 .collect())
139 }
140
141 fn has_nulls(&self) -> bool {
142 self.array.null_count() != 0
143 }
144}
145
146fn make_hash_set<T>(array: T) -> ArrayHashSet
153where
154 T: ArrayAccessor,
155 T::Item: IsEqual,
156{
157 let state = RandomState::new();
158 let mut map: HashMap<usize, (), ()> =
159 HashMap::with_capacity_and_hasher(array.len(), ());
160
161 let insert_value = |idx| {
162 let value = array.value(idx);
163 let hash = value.hash_one(&state);
164 if let RawEntryMut::Vacant(v) = map
165 .raw_entry_mut()
166 .from_hash(hash, |x| array.value(*x).is_equal(&value))
167 {
168 v.insert_with_hasher(hash, idx, (), |x| array.value(*x).hash_one(&state));
169 }
170 };
171
172 match array.nulls() {
173 Some(nulls) => {
174 BitIndexIterator::new(nulls.validity(), nulls.offset(), nulls.len())
175 .for_each(insert_value)
176 }
177 None => (0..array.len()).for_each(insert_value),
178 }
179
180 ArrayHashSet { state, map }
181}
182
183fn make_set(array: &dyn Array) -> Result<Arc<dyn Set>> {
185 Ok(downcast_primitive_array! {
186 array => Arc::new(ArraySet::new(array, make_hash_set(array))),
187 DataType::Boolean => {
188 let array = as_boolean_array(array)?;
189 Arc::new(ArraySet::new(array, make_hash_set(array)))
190 },
191 DataType::Utf8 => {
192 let array = as_string_array(array)?;
193 Arc::new(ArraySet::new(array, make_hash_set(array)))
194 }
195 DataType::LargeUtf8 => {
196 let array = as_largestring_array(array);
197 Arc::new(ArraySet::new(array, make_hash_set(array)))
198 }
199 DataType::Binary => {
200 let array = as_generic_binary_array::<i32>(array)?;
201 Arc::new(ArraySet::new(array, make_hash_set(array)))
202 }
203 DataType::LargeBinary => {
204 let array = as_generic_binary_array::<i64>(array)?;
205 Arc::new(ArraySet::new(array, make_hash_set(array)))
206 }
207 DataType::Dictionary(_, _) => unreachable!("dictionary should have been flattened"),
208 d => return not_impl_err!("DataType::{d} not supported in InList")
209 })
210}
211
212fn evaluate_list(
214 list: &[Arc<dyn PhysicalExpr>],
215 batch: &RecordBatch,
216) -> Result<ArrayRef> {
217 let scalars = list
218 .iter()
219 .map(|expr| {
220 expr.evaluate(batch).and_then(|r| match r {
221 ColumnarValue::Array(_) => {
222 exec_err!("InList expression must evaluate to a scalar")
223 }
224 ColumnarValue::Scalar(ScalarValue::Dictionary(_, v)) => Ok(*v),
226 ColumnarValue::Scalar(s) => Ok(s),
227 })
228 })
229 .collect::<Result<Vec<_>>>()?;
230
231 ScalarValue::iter_to_array(scalars)
232}
233
234fn try_cast_static_filter_to_set(
235 list: &[Arc<dyn PhysicalExpr>],
236 schema: &Schema,
237) -> Result<Arc<dyn Set>> {
238 let batch = RecordBatch::new_empty(Arc::new(schema.clone()));
239 make_set(evaluate_list(list, &batch)?.as_ref())
240}
241
242trait IsEqual: HashValue {
244 fn is_equal(&self, other: &Self) -> bool;
245}
246
247impl<T: IsEqual + ?Sized> IsEqual for &T {
248 fn is_equal(&self, other: &Self) -> bool {
249 T::is_equal(self, other)
250 }
251}
252
253macro_rules! is_equal {
254 ($($t:ty),+) => {
255 $(impl IsEqual for $t {
256 fn is_equal(&self, other: &Self) -> bool {
257 self == other
258 }
259 })*
260 };
261}
262is_equal!(i8, i16, i32, i64, i128, i256, u8, u16, u32, u64);
263is_equal!(bool, str, [u8]);
264is_equal!(IntervalDayTime, IntervalMonthDayNano);
265
266macro_rules! is_equal_float {
267 ($($t:ty),+) => {
268 $(impl IsEqual for $t {
269 fn is_equal(&self, other: &Self) -> bool {
270 self.to_bits() == other.to_bits()
271 }
272 })*
273 };
274}
275is_equal_float!(half::f16, f32, f64);
276
277impl InListExpr {
278 pub fn new(
280 expr: Arc<dyn PhysicalExpr>,
281 list: Vec<Arc<dyn PhysicalExpr>>,
282 negated: bool,
283 static_filter: Option<Arc<dyn Set>>,
284 ) -> Self {
285 Self {
286 expr,
287 list,
288 negated,
289 static_filter,
290 }
291 }
292
293 pub fn expr(&self) -> &Arc<dyn PhysicalExpr> {
295 &self.expr
296 }
297
298 pub fn list(&self) -> &[Arc<dyn PhysicalExpr>] {
300 &self.list
301 }
302
303 pub fn negated(&self) -> bool {
305 self.negated
306 }
307}
308
309impl std::fmt::Display for InListExpr {
310 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
311 if self.negated {
312 if self.static_filter.is_some() {
313 write!(f, "{} NOT IN (SET) ({:?})", self.expr, self.list)
314 } else {
315 write!(f, "{} NOT IN ({:?})", self.expr, self.list)
316 }
317 } else if self.static_filter.is_some() {
318 write!(f, "Use {} IN (SET) ({:?})", self.expr, self.list)
319 } else {
320 write!(f, "{} IN ({:?})", self.expr, self.list)
321 }
322 }
323}
324
325impl PhysicalExpr for InListExpr {
326 fn as_any(&self) -> &dyn Any {
328 self
329 }
330
331 fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
332 Ok(DataType::Boolean)
333 }
334
335 fn nullable(&self, input_schema: &Schema) -> Result<bool> {
336 if self.expr.nullable(input_schema)? {
337 return Ok(true);
338 }
339
340 if let Some(static_filter) = &self.static_filter {
341 Ok(static_filter.has_nulls())
342 } else {
343 for expr in &self.list {
344 if expr.nullable(input_schema)? {
345 return Ok(true);
346 }
347 }
348 Ok(false)
349 }
350 }
351
352 fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
353 let num_rows = batch.num_rows();
354 let value = self.expr.evaluate(batch)?;
355 let r = match &self.static_filter {
356 Some(f) => f.contains(value.into_array(num_rows)?.as_ref(), self.negated)?,
357 None => {
358 let value = value.into_array(num_rows)?;
359 let is_nested = value.data_type().is_nested();
360 let found = self.list.iter().map(|expr| expr.evaluate(batch)).try_fold(
361 BooleanArray::new(BooleanBuffer::new_unset(num_rows), None),
362 |result, expr| -> Result<BooleanArray> {
363 let rhs = compare_with_eq(
364 &value,
365 &expr?.into_array(num_rows)?,
366 is_nested,
367 )?;
368 Ok(or_kleene(&result, &rhs)?)
369 },
370 )?;
371
372 if self.negated {
373 not(&found)?
374 } else {
375 found
376 }
377 }
378 };
379 Ok(ColumnarValue::Array(Arc::new(r)))
380 }
381
382 fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
383 let mut children = vec![];
384 children.push(&self.expr);
385 children.extend(&self.list);
386 children
387 }
388
389 fn with_new_children(
390 self: Arc<Self>,
391 children: Vec<Arc<dyn PhysicalExpr>>,
392 ) -> Result<Arc<dyn PhysicalExpr>> {
393 Ok(Arc::new(InListExpr::new(
395 Arc::clone(&children[0]),
396 children[1..].to_vec(),
397 self.negated,
398 self.static_filter.clone(),
399 )))
400 }
401
402 fn fmt_sql(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
403 self.expr.fmt_sql(f)?;
404 if self.negated {
405 write!(f, " NOT")?;
406 }
407
408 write!(f, " IN (")?;
409 for (i, expr) in self.list.iter().enumerate() {
410 if i > 0 {
411 write!(f, ", ")?;
412 }
413 expr.fmt_sql(f)?;
414 }
415 write!(f, ")")
416 }
417}
418
419impl PartialEq for InListExpr {
420 fn eq(&self, other: &Self) -> bool {
421 self.expr.eq(&other.expr)
422 && physical_exprs_bag_equal(&self.list, &other.list)
423 && self.negated == other.negated
424 }
425}
426
427impl Eq for InListExpr {}
428
429impl Hash for InListExpr {
430 fn hash<H: Hasher>(&self, state: &mut H) {
431 self.expr.hash(state);
432 self.negated.hash(state);
433 self.list.hash(state);
434 }
436}
437
438pub fn in_list(
440 expr: Arc<dyn PhysicalExpr>,
441 list: Vec<Arc<dyn PhysicalExpr>>,
442 negated: &bool,
443 schema: &Schema,
444) -> Result<Arc<dyn PhysicalExpr>> {
445 let expr_data_type = expr.data_type(schema)?;
447 for list_expr in list.iter() {
448 let list_expr_data_type = list_expr.data_type(schema)?;
449 if !DFSchema::datatype_is_logically_equal(&expr_data_type, &list_expr_data_type) {
450 return internal_err!(
451 "The data type inlist should be same, the value type is {expr_data_type}, one of list expr type is {list_expr_data_type}"
452 );
453 }
454 }
455 let static_filter = try_cast_static_filter_to_set(&list, schema).ok();
456 Ok(Arc::new(InListExpr::new(
457 expr,
458 list,
459 *negated,
460 static_filter,
461 )))
462}
463
464#[cfg(test)]
465mod tests {
466
467 use super::*;
468 use crate::expressions;
469 use crate::expressions::{col, lit, try_cast};
470 use datafusion_common::plan_err;
471 use datafusion_expr::type_coercion::binary::comparison_coercion;
472 use datafusion_physical_expr_common::physical_expr::fmt_sql;
473
474 type InListCastResult = (Arc<dyn PhysicalExpr>, Vec<Arc<dyn PhysicalExpr>>);
475
476 fn in_list_cast(
479 expr: Arc<dyn PhysicalExpr>,
480 list: Vec<Arc<dyn PhysicalExpr>>,
481 input_schema: &Schema,
482 ) -> Result<InListCastResult> {
483 let expr_type = &expr.data_type(input_schema)?;
484 let list_types: Vec<DataType> = list
485 .iter()
486 .map(|list_expr| list_expr.data_type(input_schema).unwrap())
487 .collect();
488 let result_type = get_coerce_type(expr_type, &list_types);
489 match result_type {
490 None => plan_err!(
491 "Can not find compatible types to compare {expr_type:?} with {list_types:?}"
492 ),
493 Some(data_type) => {
494 let cast_expr = try_cast(expr, input_schema, data_type.clone())?;
496 let cast_list_expr = list
497 .into_iter()
498 .map(|list_expr| {
499 try_cast(list_expr, input_schema, data_type.clone()).unwrap()
500 })
501 .collect();
502 Ok((cast_expr, cast_list_expr))
503 }
504 }
505 }
506
507 fn get_coerce_type(expr_type: &DataType, list_type: &[DataType]) -> Option<DataType> {
510 list_type
511 .iter()
512 .try_fold(expr_type.clone(), |left_type, right_type| {
513 comparison_coercion(&left_type, right_type)
514 })
515 }
516
517 macro_rules! in_list {
519 ($BATCH:expr, $LIST:expr, $NEGATED:expr, $EXPECTED:expr, $COL:expr, $SCHEMA:expr) => {{
520 let (cast_expr, cast_list_exprs) = in_list_cast($COL, $LIST, $SCHEMA)?;
521 in_list_raw!(
522 $BATCH,
523 cast_list_exprs,
524 $NEGATED,
525 $EXPECTED,
526 cast_expr,
527 $SCHEMA
528 );
529 }};
530 }
531
532 macro_rules! in_list_raw {
534 ($BATCH:expr, $LIST:expr, $NEGATED:expr, $EXPECTED:expr, $COL:expr, $SCHEMA:expr) => {{
535 let expr = in_list($COL, $LIST, $NEGATED, $SCHEMA).unwrap();
536 let result = expr
537 .evaluate(&$BATCH)?
538 .into_array($BATCH.num_rows())
539 .expect("Failed to convert to array");
540 let result =
541 as_boolean_array(&result).expect("failed to downcast to BooleanArray");
542 let expected = &BooleanArray::from($EXPECTED);
543 assert_eq!(expected, result);
544 }};
545 }
546
547 #[test]
548 fn in_list_utf8() -> Result<()> {
549 let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]);
550 let a = StringArray::from(vec![Some("a"), Some("d"), None]);
551 let col_a = col("a", &schema)?;
552 let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;
553
554 let list = vec![lit("a"), lit("b")];
556 in_list!(
557 batch,
558 list,
559 &false,
560 vec![Some(true), Some(false), None],
561 Arc::clone(&col_a),
562 &schema
563 );
564
565 let list = vec![lit("a"), lit("b")];
567 in_list!(
568 batch,
569 list,
570 &true,
571 vec![Some(false), Some(true), None],
572 Arc::clone(&col_a),
573 &schema
574 );
575
576 let list = vec![lit("a"), lit("b"), lit(ScalarValue::Utf8(None))];
578 in_list!(
579 batch,
580 list,
581 &false,
582 vec![Some(true), None, None],
583 Arc::clone(&col_a),
584 &schema
585 );
586
587 let list = vec![lit("a"), lit("b"), lit(ScalarValue::Utf8(None))];
589 in_list!(
590 batch,
591 list,
592 &true,
593 vec![Some(false), None, None],
594 Arc::clone(&col_a),
595 &schema
596 );
597
598 Ok(())
599 }
600
601 #[test]
602 fn in_list_binary() -> Result<()> {
603 let schema = Schema::new(vec![Field::new("a", DataType::Binary, true)]);
604 let a = BinaryArray::from(vec![
605 Some([1, 2, 3].as_slice()),
606 Some([1, 2, 2].as_slice()),
607 None,
608 ]);
609 let col_a = col("a", &schema)?;
610 let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;
611
612 let list = vec![lit([1, 2, 3].as_slice()), lit([4, 5, 6].as_slice())];
614 in_list!(
615 batch,
616 list.clone(),
617 &false,
618 vec![Some(true), Some(false), None],
619 Arc::clone(&col_a),
620 &schema
621 );
622
623 in_list!(
625 batch,
626 list,
627 &true,
628 vec![Some(false), Some(true), None],
629 Arc::clone(&col_a),
630 &schema
631 );
632
633 let list = vec![
635 lit([1, 2, 3].as_slice()),
636 lit([4, 5, 6].as_slice()),
637 lit(ScalarValue::Binary(None)),
638 ];
639 in_list!(
640 batch,
641 list.clone(),
642 &false,
643 vec![Some(true), None, None],
644 Arc::clone(&col_a),
645 &schema
646 );
647
648 in_list!(
650 batch,
651 list,
652 &true,
653 vec![Some(false), None, None],
654 Arc::clone(&col_a),
655 &schema
656 );
657
658 Ok(())
659 }
660
661 #[test]
662 fn in_list_int64() -> Result<()> {
663 let schema = Schema::new(vec![Field::new("a", DataType::Int64, true)]);
664 let a = Int64Array::from(vec![Some(0), Some(2), None]);
665 let col_a = col("a", &schema)?;
666 let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;
667
668 let list = vec![lit(0i64), lit(1i64)];
670 in_list!(
671 batch,
672 list,
673 &false,
674 vec![Some(true), Some(false), None],
675 Arc::clone(&col_a),
676 &schema
677 );
678
679 let list = vec![lit(0i64), lit(1i64)];
681 in_list!(
682 batch,
683 list,
684 &true,
685 vec![Some(false), Some(true), None],
686 Arc::clone(&col_a),
687 &schema
688 );
689
690 let list = vec![lit(0i64), lit(1i64), lit(ScalarValue::Null)];
692 in_list!(
693 batch,
694 list,
695 &false,
696 vec![Some(true), None, None],
697 Arc::clone(&col_a),
698 &schema
699 );
700
701 let list = vec![lit(0i64), lit(1i64), lit(ScalarValue::Null)];
703 in_list!(
704 batch,
705 list,
706 &true,
707 vec![Some(false), None, None],
708 Arc::clone(&col_a),
709 &schema
710 );
711
712 Ok(())
713 }
714
715 #[test]
716 fn in_list_float64() -> Result<()> {
717 let schema = Schema::new(vec![Field::new("a", DataType::Float64, true)]);
718 let a = Float64Array::from(vec![
719 Some(0.0),
720 Some(0.2),
721 None,
722 Some(f64::NAN),
723 Some(-f64::NAN),
724 ]);
725 let col_a = col("a", &schema)?;
726 let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;
727
728 let list = vec![lit(0.0f64), lit(0.1f64)];
730 in_list!(
731 batch,
732 list,
733 &false,
734 vec![Some(true), Some(false), None, Some(false), Some(false)],
735 Arc::clone(&col_a),
736 &schema
737 );
738
739 let list = vec![lit(0.0f64), lit(0.1f64)];
741 in_list!(
742 batch,
743 list,
744 &true,
745 vec![Some(false), Some(true), None, Some(true), Some(true)],
746 Arc::clone(&col_a),
747 &schema
748 );
749
750 let list = vec![lit(0.0f64), lit(0.1f64), lit(ScalarValue::Null)];
752 in_list!(
753 batch,
754 list,
755 &false,
756 vec![Some(true), None, None, None, None],
757 Arc::clone(&col_a),
758 &schema
759 );
760
761 let list = vec![lit(0.0f64), lit(0.1f64), lit(ScalarValue::Null)];
763 in_list!(
764 batch,
765 list,
766 &true,
767 vec![Some(false), None, None, None, None],
768 Arc::clone(&col_a),
769 &schema
770 );
771
772 let list = vec![lit(0.0f64), lit(0.1f64), lit(f64::NAN)];
774 in_list!(
775 batch,
776 list,
777 &false,
778 vec![Some(true), Some(false), None, Some(true), Some(false)],
779 Arc::clone(&col_a),
780 &schema
781 );
782
783 let list = vec![lit(0.0f64), lit(0.1f64), lit(f64::NAN)];
785 in_list!(
786 batch,
787 list,
788 &true,
789 vec![Some(false), Some(true), None, Some(false), Some(true)],
790 Arc::clone(&col_a),
791 &schema
792 );
793
794 let list = vec![lit(0.0f64), lit(0.1f64), lit(-f64::NAN)];
796 in_list!(
797 batch,
798 list,
799 &false,
800 vec![Some(true), Some(false), None, Some(false), Some(true)],
801 Arc::clone(&col_a),
802 &schema
803 );
804
805 let list = vec![lit(0.0f64), lit(0.1f64), lit(-f64::NAN)];
807 in_list!(
808 batch,
809 list,
810 &true,
811 vec![Some(false), Some(true), None, Some(true), Some(false)],
812 Arc::clone(&col_a),
813 &schema
814 );
815
816 Ok(())
817 }
818
819 #[test]
820 fn in_list_bool() -> Result<()> {
821 let schema = Schema::new(vec![Field::new("a", DataType::Boolean, true)]);
822 let a = BooleanArray::from(vec![Some(true), None]);
823 let col_a = col("a", &schema)?;
824 let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;
825
826 let list = vec![lit(true)];
828 in_list!(
829 batch,
830 list,
831 &false,
832 vec![Some(true), None],
833 Arc::clone(&col_a),
834 &schema
835 );
836
837 let list = vec![lit(true)];
839 in_list!(
840 batch,
841 list,
842 &true,
843 vec![Some(false), None],
844 Arc::clone(&col_a),
845 &schema
846 );
847
848 let list = vec![lit(true), lit(ScalarValue::Null)];
850 in_list!(
851 batch,
852 list,
853 &false,
854 vec![Some(true), None],
855 Arc::clone(&col_a),
856 &schema
857 );
858
859 let list = vec![lit(true), lit(ScalarValue::Null)];
861 in_list!(
862 batch,
863 list,
864 &true,
865 vec![Some(false), None],
866 Arc::clone(&col_a),
867 &schema
868 );
869
870 Ok(())
871 }
872
873 #[test]
874 fn in_list_date64() -> Result<()> {
875 let schema = Schema::new(vec![Field::new("a", DataType::Date64, true)]);
876 let a = Date64Array::from(vec![Some(0), Some(2), None]);
877 let col_a = col("a", &schema)?;
878 let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;
879
880 let list = vec![
882 lit(ScalarValue::Date64(Some(0))),
883 lit(ScalarValue::Date64(Some(1))),
884 ];
885 in_list!(
886 batch,
887 list,
888 &false,
889 vec![Some(true), Some(false), None],
890 Arc::clone(&col_a),
891 &schema
892 );
893
894 let list = vec![
896 lit(ScalarValue::Date64(Some(0))),
897 lit(ScalarValue::Date64(Some(1))),
898 ];
899 in_list!(
900 batch,
901 list,
902 &true,
903 vec![Some(false), Some(true), None],
904 Arc::clone(&col_a),
905 &schema
906 );
907
908 let list = vec![
910 lit(ScalarValue::Date64(Some(0))),
911 lit(ScalarValue::Date64(Some(1))),
912 lit(ScalarValue::Null),
913 ];
914 in_list!(
915 batch,
916 list,
917 &false,
918 vec![Some(true), None, None],
919 Arc::clone(&col_a),
920 &schema
921 );
922
923 let list = vec![
925 lit(ScalarValue::Date64(Some(0))),
926 lit(ScalarValue::Date64(Some(1))),
927 lit(ScalarValue::Null),
928 ];
929 in_list!(
930 batch,
931 list,
932 &true,
933 vec![Some(false), None, None],
934 Arc::clone(&col_a),
935 &schema
936 );
937
938 Ok(())
939 }
940
941 #[test]
942 fn in_list_date32() -> Result<()> {
943 let schema = Schema::new(vec![Field::new("a", DataType::Date32, true)]);
944 let a = Date32Array::from(vec![Some(0), Some(2), None]);
945 let col_a = col("a", &schema)?;
946 let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;
947
948 let list = vec![
950 lit(ScalarValue::Date32(Some(0))),
951 lit(ScalarValue::Date32(Some(1))),
952 ];
953 in_list!(
954 batch,
955 list,
956 &false,
957 vec![Some(true), Some(false), None],
958 Arc::clone(&col_a),
959 &schema
960 );
961
962 let list = vec![
964 lit(ScalarValue::Date32(Some(0))),
965 lit(ScalarValue::Date32(Some(1))),
966 ];
967 in_list!(
968 batch,
969 list,
970 &true,
971 vec![Some(false), Some(true), None],
972 Arc::clone(&col_a),
973 &schema
974 );
975
976 let list = vec![
978 lit(ScalarValue::Date32(Some(0))),
979 lit(ScalarValue::Date32(Some(1))),
980 lit(ScalarValue::Null),
981 ];
982 in_list!(
983 batch,
984 list,
985 &false,
986 vec![Some(true), None, None],
987 Arc::clone(&col_a),
988 &schema
989 );
990
991 let list = vec![
993 lit(ScalarValue::Date32(Some(0))),
994 lit(ScalarValue::Date32(Some(1))),
995 lit(ScalarValue::Null),
996 ];
997 in_list!(
998 batch,
999 list,
1000 &true,
1001 vec![Some(false), None, None],
1002 Arc::clone(&col_a),
1003 &schema
1004 );
1005
1006 Ok(())
1007 }
1008
1009 #[test]
1010 fn in_list_decimal() -> Result<()> {
1011 let schema =
1013 Schema::new(vec![Field::new("a", DataType::Decimal128(13, 4), true)]);
1014 let array = vec![Some(100_0000_i128), None, Some(200_5000_i128)]
1015 .into_iter()
1016 .collect::<Decimal128Array>();
1017 let array = array.with_precision_and_scale(13, 4).unwrap();
1018 let col_a = col("a", &schema)?;
1019 let batch =
1020 RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(array)])?;
1021
1022 let list = vec![lit(100i32), lit(200i32)];
1024 in_list!(
1025 batch,
1026 list,
1027 &false,
1028 vec![Some(true), None, Some(false)],
1029 Arc::clone(&col_a),
1030 &schema
1031 );
1032 let list = vec![lit(100i32), lit(200i32)];
1034 in_list!(
1035 batch,
1036 list,
1037 &true,
1038 vec![Some(false), None, Some(true)],
1039 Arc::clone(&col_a),
1040 &schema
1041 );
1042
1043 let list = vec![lit(ScalarValue::Int32(Some(100))), lit(ScalarValue::Null)];
1045 in_list!(
1046 batch,
1047 list.clone(),
1048 &false,
1049 vec![Some(true), None, None],
1050 Arc::clone(&col_a),
1051 &schema
1052 );
1053 in_list!(
1055 batch,
1056 list,
1057 &true,
1058 vec![Some(false), None, None],
1059 Arc::clone(&col_a),
1060 &schema
1061 );
1062
1063 let list = vec![lit(200.50f32), lit(100i32)];
1065 in_list!(
1066 batch,
1067 list,
1068 &false,
1069 vec![Some(true), None, Some(true)],
1070 Arc::clone(&col_a),
1071 &schema
1072 );
1073
1074 let list = vec![lit(200.50f32), lit(101i32)];
1076 in_list!(
1077 batch,
1078 list,
1079 &true,
1080 vec![Some(true), None, Some(false)],
1081 Arc::clone(&col_a),
1082 &schema
1083 );
1084
1085 let list = (99i32..300).map(lit).collect::<Vec<_>>();
1088
1089 in_list!(
1090 batch,
1091 list.clone(),
1092 &false,
1093 vec![Some(true), None, Some(false)],
1094 Arc::clone(&col_a),
1095 &schema
1096 );
1097
1098 in_list!(
1099 batch,
1100 list,
1101 &true,
1102 vec![Some(false), None, Some(true)],
1103 Arc::clone(&col_a),
1104 &schema
1105 );
1106
1107 Ok(())
1108 }
1109
1110 #[test]
1111 fn test_cast_static_filter_to_set() -> Result<()> {
1112 let schema =
1114 Schema::new(vec![Field::new("a", DataType::Decimal128(13, 4), true)]);
1115
1116 let mut phy_exprs = vec![
1118 lit(1i64),
1119 expressions::cast(lit(2i32), &schema, DataType::Int64)?,
1120 try_cast(lit(3.13f32), &schema, DataType::Int64)?,
1121 ];
1122 let result = try_cast_static_filter_to_set(&phy_exprs, &schema).unwrap();
1123
1124 let array = Int64Array::from(vec![1, 2, 3, 4]);
1125 let r = result.contains(&array, false).unwrap();
1126 assert_eq!(r, BooleanArray::from(vec![true, true, true, false]));
1127
1128 try_cast_static_filter_to_set(&phy_exprs, &schema).unwrap();
1129 phy_exprs.push(expressions::cast(
1131 expressions::cast(lit(2i32), &schema, DataType::Int64)?,
1132 &schema,
1133 DataType::Int64,
1134 )?);
1135 try_cast_static_filter_to_set(&phy_exprs, &schema).unwrap();
1136
1137 phy_exprs.clear();
1138
1139 phy_exprs.push(expressions::cast(
1141 expressions::cast(lit(2i32), &schema, DataType::Int64)?,
1142 &schema,
1143 DataType::Int32,
1144 )?);
1145 try_cast_static_filter_to_set(&phy_exprs, &schema).unwrap();
1146
1147 phy_exprs.push(col("a", &schema)?);
1149 assert!(try_cast_static_filter_to_set(&phy_exprs, &schema).is_err());
1150
1151 Ok(())
1152 }
1153
1154 #[test]
1155 fn in_list_timestamp() -> Result<()> {
1156 let schema = Schema::new(vec![Field::new(
1157 "a",
1158 DataType::Timestamp(TimeUnit::Microsecond, None),
1159 true,
1160 )]);
1161 let a = TimestampMicrosecondArray::from(vec![
1162 Some(1388588401000000000),
1163 Some(1288588501000000000),
1164 None,
1165 ]);
1166 let col_a = col("a", &schema)?;
1167 let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;
1168
1169 let list = vec![
1170 lit(ScalarValue::TimestampMicrosecond(
1171 Some(1388588401000000000),
1172 None,
1173 )),
1174 lit(ScalarValue::TimestampMicrosecond(
1175 Some(1388588401000000001),
1176 None,
1177 )),
1178 lit(ScalarValue::TimestampMicrosecond(
1179 Some(1388588401000000002),
1180 None,
1181 )),
1182 ];
1183
1184 in_list!(
1185 batch,
1186 list.clone(),
1187 &false,
1188 vec![Some(true), Some(false), None],
1189 Arc::clone(&col_a),
1190 &schema
1191 );
1192
1193 in_list!(
1194 batch,
1195 list.clone(),
1196 &true,
1197 vec![Some(false), Some(true), None],
1198 Arc::clone(&col_a),
1199 &schema
1200 );
1201 Ok(())
1202 }
1203
1204 #[test]
1205 fn in_expr_with_multiple_element_in_list() -> Result<()> {
1206 let schema = Schema::new(vec![
1207 Field::new("a", DataType::Float64, true),
1208 Field::new("b", DataType::Float64, true),
1209 Field::new("c", DataType::Float64, true),
1210 ]);
1211 let a = Float64Array::from(vec![
1212 Some(0.0),
1213 Some(1.0),
1214 Some(2.0),
1215 Some(f64::NAN),
1216 Some(-f64::NAN),
1217 ]);
1218 let b = Float64Array::from(vec![
1219 Some(8.0),
1220 Some(1.0),
1221 Some(5.0),
1222 Some(f64::NAN),
1223 Some(3.0),
1224 ]);
1225 let c = Float64Array::from(vec![
1226 Some(6.0),
1227 Some(7.0),
1228 None,
1229 Some(5.0),
1230 Some(-f64::NAN),
1231 ]);
1232 let col_a = col("a", &schema)?;
1233 let col_b = col("b", &schema)?;
1234 let col_c = col("c", &schema)?;
1235 let batch = RecordBatch::try_new(
1236 Arc::new(schema.clone()),
1237 vec![Arc::new(a), Arc::new(b), Arc::new(c)],
1238 )?;
1239
1240 let list = vec![Arc::clone(&col_b), Arc::clone(&col_c)];
1241 in_list!(
1242 batch,
1243 list.clone(),
1244 &false,
1245 vec![Some(false), Some(true), None, Some(true), Some(true)],
1246 Arc::clone(&col_a),
1247 &schema
1248 );
1249
1250 in_list!(
1251 batch,
1252 list,
1253 &true,
1254 vec![Some(true), Some(false), None, Some(false), Some(false)],
1255 Arc::clone(&col_a),
1256 &schema
1257 );
1258
1259 Ok(())
1260 }
1261
1262 macro_rules! test_nullable {
1263 ($COL:expr, $LIST:expr, $SCHEMA:expr, $EXPECTED:expr) => {{
1264 let (cast_expr, cast_list_exprs) = in_list_cast($COL, $LIST, $SCHEMA)?;
1265 let expr = in_list(cast_expr, cast_list_exprs, &false, $SCHEMA).unwrap();
1266 let result = expr.nullable($SCHEMA)?;
1267 assert_eq!($EXPECTED, result);
1268 }};
1269 }
1270
1271 #[test]
1272 fn in_list_nullable() -> Result<()> {
1273 let schema = Schema::new(vec![
1274 Field::new("c1_nullable", DataType::Int64, true),
1275 Field::new("c2_non_nullable", DataType::Int64, false),
1276 ]);
1277
1278 let c1_nullable = col("c1_nullable", &schema)?;
1279 let c2_non_nullable = col("c2_non_nullable", &schema)?;
1280
1281 let list = vec![lit(1_i64), lit(2_i64)];
1283 test_nullable!(Arc::clone(&c1_nullable), list.clone(), &schema, true);
1284 test_nullable!(Arc::clone(&c2_non_nullable), list.clone(), &schema, false);
1285
1286 let list = vec![lit(1_i64), lit(2_i64), lit(ScalarValue::Null)];
1288 test_nullable!(Arc::clone(&c1_nullable), list.clone(), &schema, true);
1289 test_nullable!(Arc::clone(&c2_non_nullable), list.clone(), &schema, true);
1290
1291 let list = vec![Arc::clone(&c1_nullable)];
1292 test_nullable!(Arc::clone(&c2_non_nullable), list.clone(), &schema, true);
1293
1294 let list = vec![Arc::clone(&c2_non_nullable)];
1295 test_nullable!(Arc::clone(&c1_nullable), list.clone(), &schema, true);
1296
1297 let list = vec![Arc::clone(&c2_non_nullable), Arc::clone(&c2_non_nullable)];
1298 test_nullable!(Arc::clone(&c2_non_nullable), list.clone(), &schema, false);
1299
1300 Ok(())
1301 }
1302
1303 #[test]
1304 fn in_list_no_cols() -> Result<()> {
1305 let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);
1307 let a = Int32Array::from(vec![Some(1), Some(2), None]);
1308 let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;
1309
1310 let list = vec![lit(ScalarValue::from(1i32)), lit(ScalarValue::from(6i32))];
1311
1312 let expr = lit(ScalarValue::Int32(Some(1)));
1314 in_list!(
1315 batch,
1316 list.clone(),
1317 &false,
1318 vec![Some(true), Some(true), Some(true)],
1320 expr,
1321 &schema
1322 );
1323
1324 let expr = lit(ScalarValue::Int32(Some(2)));
1326 in_list!(
1327 batch,
1328 list.clone(),
1329 &false,
1330 vec![Some(false), Some(false), Some(false)],
1332 expr,
1333 &schema
1334 );
1335
1336 let expr = lit(ScalarValue::Int32(None));
1338 in_list!(
1339 batch,
1340 list.clone(),
1341 &false,
1342 vec![None, None, None],
1344 expr,
1345 &schema
1346 );
1347
1348 Ok(())
1349 }
1350
1351 #[test]
1352 fn in_list_utf8_with_dict_types() -> Result<()> {
1353 fn dict_lit(key_type: DataType, value: &str) -> Arc<dyn PhysicalExpr> {
1354 lit(ScalarValue::Dictionary(
1355 Box::new(key_type),
1356 Box::new(ScalarValue::new_utf8(value.to_string())),
1357 ))
1358 }
1359
1360 fn null_dict_lit(key_type: DataType) -> Arc<dyn PhysicalExpr> {
1361 lit(ScalarValue::Dictionary(
1362 Box::new(key_type),
1363 Box::new(ScalarValue::Utf8(None)),
1364 ))
1365 }
1366
1367 let schema = Schema::new(vec![Field::new(
1368 "a",
1369 DataType::Dictionary(Box::new(DataType::UInt16), Box::new(DataType::Utf8)),
1370 true,
1371 )]);
1372 let a: UInt16DictionaryArray =
1373 vec![Some("a"), Some("d"), None].into_iter().collect();
1374 let col_a = col("a", &schema)?;
1375 let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;
1376
1377 let lists = [
1379 vec![lit("a"), lit("b")],
1380 vec![
1381 dict_lit(DataType::Int8, "a"),
1382 dict_lit(DataType::UInt16, "b"),
1383 ],
1384 ];
1385 for list in lists.iter() {
1386 in_list_raw!(
1387 batch,
1388 list.clone(),
1389 &false,
1390 vec![Some(true), Some(false), None],
1391 Arc::clone(&col_a),
1392 &schema
1393 );
1394 }
1395
1396 for list in lists.iter() {
1398 in_list_raw!(
1399 batch,
1400 list.clone(),
1401 &true,
1402 vec![Some(false), Some(true), None],
1403 Arc::clone(&col_a),
1404 &schema
1405 );
1406 }
1407
1408 let lists = [
1410 vec![lit("a"), lit("b"), lit(ScalarValue::Utf8(None))],
1411 vec![
1412 dict_lit(DataType::Int8, "a"),
1413 dict_lit(DataType::UInt16, "b"),
1414 null_dict_lit(DataType::UInt16),
1415 ],
1416 ];
1417 for list in lists.iter() {
1418 in_list_raw!(
1419 batch,
1420 list.clone(),
1421 &false,
1422 vec![Some(true), None, None],
1423 Arc::clone(&col_a),
1424 &schema
1425 );
1426 }
1427
1428 for list in lists.iter() {
1430 in_list_raw!(
1431 batch,
1432 list.clone(),
1433 &true,
1434 vec![Some(false), None, None],
1435 Arc::clone(&col_a),
1436 &schema
1437 );
1438 }
1439
1440 Ok(())
1441 }
1442
1443 #[test]
1444 fn test_fmt_sql() -> Result<()> {
1445 let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]);
1446 let col_a = col("a", &schema)?;
1447
1448 let list = vec![lit("a"), lit("b")];
1450 let expr = in_list(Arc::clone(&col_a), list, &false, &schema)?;
1451 let sql_string = fmt_sql(expr.as_ref()).to_string();
1452 let display_string = expr.to_string();
1453 assert_eq!(sql_string, "a IN (a, b)");
1454 assert_eq!(display_string, "Use a@0 IN (SET) ([Literal { value: Utf8(\"a\"), field: Field { name: \"lit\", data_type: Utf8, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8(\"b\"), field: Field { name: \"lit\", data_type: Utf8, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }])");
1455
1456 let list = vec![lit("a"), lit("b")];
1458 let expr = in_list(Arc::clone(&col_a), list, &true, &schema)?;
1459 let sql_string = fmt_sql(expr.as_ref()).to_string();
1460 let display_string = expr.to_string();
1461 assert_eq!(sql_string, "a NOT IN (a, b)");
1462 assert_eq!(display_string, "a@0 NOT IN (SET) ([Literal { value: Utf8(\"a\"), field: Field { name: \"lit\", data_type: Utf8, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8(\"b\"), field: Field { name: \"lit\", data_type: Utf8, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }])");
1463
1464 let list = vec![lit("a"), lit("b"), lit(ScalarValue::Utf8(None))];
1466 let expr = in_list(Arc::clone(&col_a), list, &false, &schema)?;
1467 let sql_string = fmt_sql(expr.as_ref()).to_string();
1468 let display_string = expr.to_string();
1469 assert_eq!(sql_string, "a IN (a, b, NULL)");
1470 assert_eq!(display_string, "Use a@0 IN (SET) ([Literal { value: Utf8(\"a\"), field: Field { name: \"lit\", data_type: Utf8, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8(\"b\"), field: Field { name: \"lit\", data_type: Utf8, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8(NULL), field: Field { name: \"lit\", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} } }])");
1471
1472 let list = vec![lit("a"), lit("b"), lit(ScalarValue::Utf8(None))];
1474 let expr = in_list(Arc::clone(&col_a), list, &true, &schema)?;
1475 let sql_string = fmt_sql(expr.as_ref()).to_string();
1476 let display_string = expr.to_string();
1477 assert_eq!(sql_string, "a NOT IN (a, b, NULL)");
1478 assert_eq!(display_string, "a@0 NOT IN (SET) ([Literal { value: Utf8(\"a\"), field: Field { name: \"lit\", data_type: Utf8, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8(\"b\"), field: Field { name: \"lit\", data_type: Utf8, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8(NULL), field: Field { name: \"lit\", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} } }])");
1479
1480 Ok(())
1481 }
1482}