datafusion_physical_expr/expressions/
binary.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18mod kernels;
19
20use crate::PhysicalExpr;
21use crate::intervals::cp_solver::{propagate_arithmetic, propagate_comparison};
22use std::hash::Hash;
23use std::{any::Any, sync::Arc};
24
25use arrow::array::*;
26use arrow::compute::kernels::boolean::{and_kleene, or_kleene};
27use arrow::compute::kernels::concat_elements::concat_elements_utf8;
28use arrow::compute::{SlicesIterator, cast, filter_record_batch};
29use arrow::datatypes::*;
30use arrow::error::ArrowError;
31use datafusion_common::cast::as_boolean_array;
32use datafusion_common::{Result, ScalarValue, internal_err, not_impl_err};
33use datafusion_expr::binary::BinaryTypeCoercer;
34use datafusion_expr::interval_arithmetic::{Interval, apply_operator};
35use datafusion_expr::sort_properties::ExprProperties;
36use datafusion_expr::statistics::Distribution::{Bernoulli, Gaussian};
37use datafusion_expr::statistics::{
38    Distribution, combine_bernoullis, combine_gaussians,
39    create_bernoulli_from_comparison, new_generic_from_binary_op,
40};
41use datafusion_expr::{ColumnarValue, Operator};
42use datafusion_physical_expr_common::datum::{apply, apply_cmp};
43
44use kernels::{
45    bitwise_and_dyn, bitwise_and_dyn_scalar, bitwise_or_dyn, bitwise_or_dyn_scalar,
46    bitwise_shift_left_dyn, bitwise_shift_left_dyn_scalar, bitwise_shift_right_dyn,
47    bitwise_shift_right_dyn_scalar, bitwise_xor_dyn, bitwise_xor_dyn_scalar,
48    concat_elements_utf8view, regex_match_dyn, regex_match_dyn_scalar,
49};
50
51/// Binary expression
52#[derive(Debug, Clone, Eq)]
53pub struct BinaryExpr {
54    left: Arc<dyn PhysicalExpr>,
55    op: Operator,
56    right: Arc<dyn PhysicalExpr>,
57    /// Specifies whether an error is returned on overflow or not
58    fail_on_overflow: bool,
59}
60
61// Manually derive PartialEq and Hash to work around https://github.com/rust-lang/rust/issues/78808
62impl PartialEq for BinaryExpr {
63    fn eq(&self, other: &Self) -> bool {
64        self.left.eq(&other.left)
65            && self.op.eq(&other.op)
66            && self.right.eq(&other.right)
67            && self.fail_on_overflow.eq(&other.fail_on_overflow)
68    }
69}
70impl Hash for BinaryExpr {
71    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
72        self.left.hash(state);
73        self.op.hash(state);
74        self.right.hash(state);
75        self.fail_on_overflow.hash(state);
76    }
77}
78
79impl BinaryExpr {
80    /// Create new binary expression
81    pub fn new(
82        left: Arc<dyn PhysicalExpr>,
83        op: Operator,
84        right: Arc<dyn PhysicalExpr>,
85    ) -> Self {
86        Self {
87            left,
88            op,
89            right,
90            fail_on_overflow: false,
91        }
92    }
93
94    /// Create new binary expression with explicit fail_on_overflow value
95    pub fn with_fail_on_overflow(self, fail_on_overflow: bool) -> Self {
96        Self {
97            left: self.left,
98            op: self.op,
99            right: self.right,
100            fail_on_overflow,
101        }
102    }
103
104    /// Get the left side of the binary expression
105    pub fn left(&self) -> &Arc<dyn PhysicalExpr> {
106        &self.left
107    }
108
109    /// Get the right side of the binary expression
110    pub fn right(&self) -> &Arc<dyn PhysicalExpr> {
111        &self.right
112    }
113
114    /// Get the operator for this binary expression
115    pub fn op(&self) -> &Operator {
116        &self.op
117    }
118}
119
120impl std::fmt::Display for BinaryExpr {
121    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
122        // Put parentheses around child binary expressions so that we can see the difference
123        // between `(a OR b) AND c` and `a OR (b AND c)`. We only insert parentheses when needed,
124        // based on operator precedence. For example, `(a AND b) OR c` and `a AND b OR c` are
125        // equivalent and the parentheses are not necessary.
126
127        fn write_child(
128            f: &mut std::fmt::Formatter,
129            expr: &dyn PhysicalExpr,
130            precedence: u8,
131        ) -> std::fmt::Result {
132            if let Some(child) = expr.as_any().downcast_ref::<BinaryExpr>() {
133                let p = child.op.precedence();
134                if p == 0 || p < precedence {
135                    write!(f, "({child})")?;
136                } else {
137                    write!(f, "{child}")?;
138                }
139            } else {
140                write!(f, "{expr}")?;
141            }
142
143            Ok(())
144        }
145
146        let precedence = self.op.precedence();
147        write_child(f, self.left.as_ref(), precedence)?;
148        write!(f, " {} ", self.op)?;
149        write_child(f, self.right.as_ref(), precedence)
150    }
151}
152
153/// Invoke a boolean kernel on a pair of arrays
154#[inline]
155fn boolean_op(
156    left: &dyn Array,
157    right: &dyn Array,
158    op: impl FnOnce(&BooleanArray, &BooleanArray) -> Result<BooleanArray, ArrowError>,
159) -> Result<Arc<dyn Array + 'static>, ArrowError> {
160    let ll = as_boolean_array(left).expect("boolean_op failed to downcast left array");
161    let rr = as_boolean_array(right).expect("boolean_op failed to downcast right array");
162    op(ll, rr).map(|t| Arc::new(t) as _)
163}
164
165impl PhysicalExpr for BinaryExpr {
166    /// Return a reference to Any that can be used for downcasting
167    fn as_any(&self) -> &dyn Any {
168        self
169    }
170
171    fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
172        BinaryTypeCoercer::new(
173            &self.left.data_type(input_schema)?,
174            &self.op,
175            &self.right.data_type(input_schema)?,
176        )
177        .get_result_type()
178    }
179
180    fn nullable(&self, input_schema: &Schema) -> Result<bool> {
181        Ok(self.left.nullable(input_schema)? || self.right.nullable(input_schema)?)
182    }
183
184    fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
185        use arrow::compute::kernels::numeric::*;
186
187        // Evaluate left-hand side expression.
188        let lhs = self.left.evaluate(batch)?;
189
190        // Check if we can apply short-circuit evaluation.
191        match check_short_circuit(&lhs, &self.op) {
192            ShortCircuitStrategy::None => {}
193            ShortCircuitStrategy::ReturnLeft => return Ok(lhs),
194            ShortCircuitStrategy::ReturnRight => {
195                let rhs = self.right.evaluate(batch)?;
196                return Ok(rhs);
197            }
198            ShortCircuitStrategy::PreSelection(selection) => {
199                // The function `evaluate_selection` was not called for filtering and calculation,
200                // as it takes into account cases where the selection contains null values.
201                let batch = filter_record_batch(batch, selection)?;
202                let right_ret = self.right.evaluate(&batch)?;
203
204                match &right_ret {
205                    ColumnarValue::Array(array) => {
206                        // When the array on the right is all true or all false, skip the scatter process
207                        let boolean_array = array.as_boolean();
208                        let true_count = boolean_array.true_count();
209                        let length = boolean_array.len();
210                        if true_count == length {
211                            return Ok(lhs);
212                        } else if true_count == 0 && boolean_array.null_count() == 0 {
213                            // If the right-hand array is returned at this point,the lengths will be inconsistent;
214                            // returning a scalar can avoid this issue
215                            return Ok(ColumnarValue::Scalar(ScalarValue::Boolean(
216                                Some(false),
217                            )));
218                        }
219
220                        return pre_selection_scatter(selection, Some(boolean_array));
221                    }
222                    ColumnarValue::Scalar(scalar) => {
223                        if let ScalarValue::Boolean(v) = scalar {
224                            // When the scalar is true or false, skip the scatter process
225                            if let Some(v) = v {
226                                if *v {
227                                    return Ok(lhs);
228                                } else {
229                                    return Ok(right_ret);
230                                }
231                            } else {
232                                return pre_selection_scatter(selection, None);
233                            }
234                        } else {
235                            return internal_err!(
236                                "Expected boolean scalar value, found: {right_ret:?}"
237                            );
238                        }
239                    }
240                }
241            }
242        }
243
244        let rhs = self.right.evaluate(batch)?;
245        let left_data_type = lhs.data_type();
246        let right_data_type = rhs.data_type();
247
248        let schema = batch.schema();
249        let input_schema = schema.as_ref();
250
251        match self.op {
252            Operator::Plus if self.fail_on_overflow => return apply(&lhs, &rhs, add),
253            Operator::Plus => return apply(&lhs, &rhs, add_wrapping),
254            Operator::Minus if self.fail_on_overflow => return apply(&lhs, &rhs, sub),
255            Operator::Minus => return apply(&lhs, &rhs, sub_wrapping),
256            Operator::Multiply if self.fail_on_overflow => return apply(&lhs, &rhs, mul),
257            Operator::Multiply => return apply(&lhs, &rhs, mul_wrapping),
258            Operator::Divide => return apply(&lhs, &rhs, div),
259            Operator::Modulo => return apply(&lhs, &rhs, rem),
260
261            Operator::Eq
262            | Operator::NotEq
263            | Operator::Lt
264            | Operator::Gt
265            | Operator::LtEq
266            | Operator::GtEq
267            | Operator::IsDistinctFrom
268            | Operator::IsNotDistinctFrom
269            | Operator::LikeMatch
270            | Operator::ILikeMatch
271            | Operator::NotLikeMatch
272            | Operator::NotILikeMatch => {
273                return apply_cmp(self.op, &lhs, &rhs);
274            }
275            _ => {}
276        }
277
278        let result_type = self.data_type(input_schema)?;
279
280        // If the left-hand side is an array and the right-hand side is a non-null scalar, try the optimized kernel.
281        if let (ColumnarValue::Array(array), ColumnarValue::Scalar(scalar)) = (&lhs, &rhs)
282            && !scalar.is_null()
283            && let Some(result_array) =
284                self.evaluate_array_scalar(array, scalar.clone())?
285        {
286            let final_array = result_array
287                .and_then(|a| to_result_type_array(&self.op, a, &result_type));
288            return final_array.map(ColumnarValue::Array);
289        }
290
291        // if both arrays or both literals - extract arrays and continue execution
292        let (left, right) = (
293            lhs.into_array(batch.num_rows())?,
294            rhs.into_array(batch.num_rows())?,
295        );
296        self.evaluate_with_resolved_args(left, &left_data_type, right, &right_data_type)
297            .map(ColumnarValue::Array)
298    }
299
300    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
301        vec![&self.left, &self.right]
302    }
303
304    fn with_new_children(
305        self: Arc<Self>,
306        children: Vec<Arc<dyn PhysicalExpr>>,
307    ) -> Result<Arc<dyn PhysicalExpr>> {
308        Ok(Arc::new(
309            BinaryExpr::new(Arc::clone(&children[0]), self.op, Arc::clone(&children[1]))
310                .with_fail_on_overflow(self.fail_on_overflow),
311        ))
312    }
313
314    fn evaluate_bounds(&self, children: &[&Interval]) -> Result<Interval> {
315        // Get children intervals:
316        let left_interval = children[0];
317        let right_interval = children[1];
318        // Calculate current node's interval:
319        apply_operator(&self.op, left_interval, right_interval)
320    }
321
322    fn propagate_constraints(
323        &self,
324        interval: &Interval,
325        children: &[&Interval],
326    ) -> Result<Option<Vec<Interval>>> {
327        // Get children intervals.
328        let left_interval = children[0];
329        let right_interval = children[1];
330
331        if self.op.eq(&Operator::And) {
332            if interval.eq(&Interval::TRUE) {
333                // A certainly true logical conjunction can only derive from possibly
334                // true operands. Otherwise, we prove infeasibility.
335                Ok((!left_interval.eq(&Interval::FALSE)
336                    && !right_interval.eq(&Interval::FALSE))
337                .then(|| vec![Interval::TRUE, Interval::TRUE]))
338            } else if interval.eq(&Interval::FALSE) {
339                // If the logical conjunction is certainly false, one of the
340                // operands must be false. However, it's not always possible to
341                // determine which operand is false, leading to different scenarios.
342
343                // If one operand is certainly true and the other one is uncertain,
344                // then the latter must be certainly false.
345                if left_interval.eq(&Interval::TRUE)
346                    && right_interval.eq(&Interval::TRUE_OR_FALSE)
347                {
348                    Ok(Some(vec![Interval::TRUE, Interval::FALSE]))
349                } else if right_interval.eq(&Interval::TRUE)
350                    && left_interval.eq(&Interval::TRUE_OR_FALSE)
351                {
352                    Ok(Some(vec![Interval::FALSE, Interval::TRUE]))
353                }
354                // If both children are uncertain, or if one is certainly false,
355                // we cannot conclusively refine their intervals. In this case,
356                // propagation does not result in any interval changes.
357                else {
358                    Ok(Some(vec![]))
359                }
360            } else {
361                // An uncertain logical conjunction result can not shrink the
362                // end-points of its children.
363                Ok(Some(vec![]))
364            }
365        } else if self.op.eq(&Operator::Or) {
366            if interval.eq(&Interval::FALSE) {
367                // A certainly false logical disjunction can only derive from certainly
368                // false operands. Otherwise, we prove infeasibility.
369                Ok((!left_interval.eq(&Interval::TRUE)
370                    && !right_interval.eq(&Interval::TRUE))
371                .then(|| vec![Interval::FALSE, Interval::FALSE]))
372            } else if interval.eq(&Interval::TRUE) {
373                // If the logical disjunction is certainly true, one of the
374                // operands must be true. However, it's not always possible to
375                // determine which operand is true, leading to different scenarios.
376
377                // If one operand is certainly false and the other one is uncertain,
378                // then the latter must be certainly true.
379                if left_interval.eq(&Interval::FALSE)
380                    && right_interval.eq(&Interval::TRUE_OR_FALSE)
381                {
382                    Ok(Some(vec![Interval::FALSE, Interval::TRUE]))
383                } else if right_interval.eq(&Interval::FALSE)
384                    && left_interval.eq(&Interval::TRUE_OR_FALSE)
385                {
386                    Ok(Some(vec![Interval::TRUE, Interval::FALSE]))
387                }
388                // If both children are uncertain, or if one is certainly true,
389                // we cannot conclusively refine their intervals. In this case,
390                // propagation does not result in any interval changes.
391                else {
392                    Ok(Some(vec![]))
393                }
394            } else {
395                // An uncertain logical disjunction result can not shrink the
396                // end-points of its children.
397                Ok(Some(vec![]))
398            }
399        } else if self.op.supports_propagation() {
400            Ok(
401                propagate_comparison(&self.op, interval, left_interval, right_interval)?
402                    .map(|(left, right)| vec![left, right]),
403            )
404        } else {
405            Ok(
406                propagate_arithmetic(&self.op, interval, left_interval, right_interval)?
407                    .map(|(left, right)| vec![left, right]),
408            )
409        }
410    }
411
412    fn evaluate_statistics(&self, children: &[&Distribution]) -> Result<Distribution> {
413        let (left, right) = (children[0], children[1]);
414
415        if self.op.is_numerical_operators() {
416            // We might be able to construct the output statistics more accurately,
417            // without falling back to an unknown distribution, if we are dealing
418            // with Gaussian distributions and numerical operations.
419            if let (Gaussian(left), Gaussian(right)) = (left, right)
420                && let Some(result) = combine_gaussians(&self.op, left, right)?
421            {
422                return Ok(Gaussian(result));
423            }
424        } else if self.op.is_logic_operator() {
425            // If we are dealing with logical operators, we expect (and can only
426            // operate on) Bernoulli distributions.
427            return if let (Bernoulli(left), Bernoulli(right)) = (left, right) {
428                combine_bernoullis(&self.op, left, right).map(Bernoulli)
429            } else {
430                internal_err!(
431                    "Logical operators are only compatible with `Bernoulli` distributions"
432                )
433            };
434        } else if self.op.supports_propagation() {
435            // If we are handling comparison operators, we expect (and can only
436            // operate on) numeric distributions.
437            return create_bernoulli_from_comparison(&self.op, left, right);
438        }
439        // Fall back to an unknown distribution with only summary statistics:
440        new_generic_from_binary_op(&self.op, left, right)
441    }
442
443    /// For each operator, [`BinaryExpr`] has distinct rules.
444    /// TODO: There may be rules specific to some data types and expression ranges.
445    fn get_properties(&self, children: &[ExprProperties]) -> Result<ExprProperties> {
446        let (l_order, l_range) = (children[0].sort_properties, &children[0].range);
447        let (r_order, r_range) = (children[1].sort_properties, &children[1].range);
448        match self.op() {
449            Operator::Plus => Ok(ExprProperties {
450                sort_properties: l_order.add(&r_order),
451                range: l_range.add(r_range)?,
452                preserves_lex_ordering: false,
453            }),
454            Operator::Minus => Ok(ExprProperties {
455                sort_properties: l_order.sub(&r_order),
456                range: l_range.sub(r_range)?,
457                preserves_lex_ordering: false,
458            }),
459            Operator::Gt => Ok(ExprProperties {
460                sort_properties: l_order.gt_or_gteq(&r_order),
461                range: l_range.gt(r_range)?,
462                preserves_lex_ordering: false,
463            }),
464            Operator::GtEq => Ok(ExprProperties {
465                sort_properties: l_order.gt_or_gteq(&r_order),
466                range: l_range.gt_eq(r_range)?,
467                preserves_lex_ordering: false,
468            }),
469            Operator::Lt => Ok(ExprProperties {
470                sort_properties: r_order.gt_or_gteq(&l_order),
471                range: l_range.lt(r_range)?,
472                preserves_lex_ordering: false,
473            }),
474            Operator::LtEq => Ok(ExprProperties {
475                sort_properties: r_order.gt_or_gteq(&l_order),
476                range: l_range.lt_eq(r_range)?,
477                preserves_lex_ordering: false,
478            }),
479            Operator::And => Ok(ExprProperties {
480                sort_properties: r_order.and_or(&l_order),
481                range: l_range.and(r_range)?,
482                preserves_lex_ordering: false,
483            }),
484            Operator::Or => Ok(ExprProperties {
485                sort_properties: r_order.and_or(&l_order),
486                range: l_range.or(r_range)?,
487                preserves_lex_ordering: false,
488            }),
489            _ => Ok(ExprProperties::new_unknown()),
490        }
491    }
492
493    fn fmt_sql(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
494        fn write_child(
495            f: &mut std::fmt::Formatter,
496            expr: &dyn PhysicalExpr,
497            precedence: u8,
498        ) -> std::fmt::Result {
499            if let Some(child) = expr.as_any().downcast_ref::<BinaryExpr>() {
500                let p = child.op.precedence();
501                if p == 0 || p < precedence {
502                    write!(f, "(")?;
503                    child.fmt_sql(f)?;
504                    write!(f, ")")
505                } else {
506                    child.fmt_sql(f)
507                }
508            } else {
509                expr.fmt_sql(f)
510            }
511        }
512
513        let precedence = self.op.precedence();
514        write_child(f, self.left.as_ref(), precedence)?;
515        write!(f, " {} ", self.op)?;
516        write_child(f, self.right.as_ref(), precedence)
517    }
518}
519
520/// Casts dictionary array to result type for binary numerical operators. Such operators
521/// between array and scalar produce a dictionary array other than primitive array of the
522/// same operators between array and array. This leads to inconsistent result types causing
523/// errors in the following query execution. For such operators between array and scalar,
524/// we cast the dictionary array to primitive array.
525fn to_result_type_array(
526    op: &Operator,
527    array: ArrayRef,
528    result_type: &DataType,
529) -> Result<ArrayRef> {
530    if array.data_type() == result_type {
531        Ok(array)
532    } else if op.is_numerical_operators() {
533        match array.data_type() {
534            DataType::Dictionary(_, value_type) => {
535                if value_type.as_ref() == result_type {
536                    Ok(cast(&array, result_type)?)
537                } else {
538                    internal_err!(
539                        "Incompatible Dictionary value type {value_type} with result type {result_type} of Binary operator {op:?}"
540                    )
541                }
542            }
543            _ => Ok(array),
544        }
545    } else {
546        Ok(array)
547    }
548}
549
550impl BinaryExpr {
551    /// Evaluate the expression of the left input is an array and
552    /// right is literal - use scalar operations
553    fn evaluate_array_scalar(
554        &self,
555        array: &dyn Array,
556        scalar: ScalarValue,
557    ) -> Result<Option<Result<ArrayRef>>> {
558        use Operator::*;
559        let scalar_result = match &self.op {
560            RegexMatch => regex_match_dyn_scalar(array, &scalar, false, false),
561            RegexIMatch => regex_match_dyn_scalar(array, &scalar, false, true),
562            RegexNotMatch => regex_match_dyn_scalar(array, &scalar, true, false),
563            RegexNotIMatch => regex_match_dyn_scalar(array, &scalar, true, true),
564            BitwiseAnd => bitwise_and_dyn_scalar(array, scalar),
565            BitwiseOr => bitwise_or_dyn_scalar(array, scalar),
566            BitwiseXor => bitwise_xor_dyn_scalar(array, scalar),
567            BitwiseShiftRight => bitwise_shift_right_dyn_scalar(array, scalar),
568            BitwiseShiftLeft => bitwise_shift_left_dyn_scalar(array, scalar),
569            // if scalar operation is not supported - fallback to array implementation
570            _ => None,
571        };
572
573        Ok(scalar_result)
574    }
575
576    fn evaluate_with_resolved_args(
577        &self,
578        left: Arc<dyn Array>,
579        left_data_type: &DataType,
580        right: Arc<dyn Array>,
581        right_data_type: &DataType,
582    ) -> Result<ArrayRef> {
583        use Operator::*;
584        match &self.op {
585            IsDistinctFrom | IsNotDistinctFrom | Lt | LtEq | Gt | GtEq | Eq | NotEq
586            | Plus | Minus | Multiply | Divide | Modulo | LikeMatch | ILikeMatch
587            | NotLikeMatch | NotILikeMatch => unreachable!(),
588            And => {
589                if left_data_type == &DataType::Boolean {
590                    Ok(boolean_op(&left, &right, and_kleene)?)
591                } else {
592                    internal_err!(
593                        "Cannot evaluate binary expression {:?} with types {:?} and {:?}",
594                        self.op,
595                        left.data_type(),
596                        right.data_type()
597                    )
598                }
599            }
600            Or => {
601                if left_data_type == &DataType::Boolean {
602                    Ok(boolean_op(&left, &right, or_kleene)?)
603                } else {
604                    internal_err!(
605                        "Cannot evaluate binary expression {:?} with types {:?} and {:?}",
606                        self.op,
607                        left_data_type,
608                        right_data_type
609                    )
610                }
611            }
612            RegexMatch => regex_match_dyn(&left, &right, false, false),
613            RegexIMatch => regex_match_dyn(&left, &right, false, true),
614            RegexNotMatch => regex_match_dyn(&left, &right, true, false),
615            RegexNotIMatch => regex_match_dyn(&left, &right, true, true),
616            BitwiseAnd => bitwise_and_dyn(left, right),
617            BitwiseOr => bitwise_or_dyn(left, right),
618            BitwiseXor => bitwise_xor_dyn(left, right),
619            BitwiseShiftRight => bitwise_shift_right_dyn(left, right),
620            BitwiseShiftLeft => bitwise_shift_left_dyn(left, right),
621            StringConcat => concat_elements(&left, &right),
622            AtArrow | ArrowAt | Arrow | LongArrow | HashArrow | HashLongArrow | AtAt
623            | HashMinus | AtQuestion | Question | QuestionAnd | QuestionPipe
624            | IntegerDivide => {
625                not_impl_err!(
626                    "Binary operator '{:?}' is not supported in the physical expr",
627                    self.op
628                )
629            }
630        }
631    }
632}
633
634enum ShortCircuitStrategy<'a> {
635    None,
636    ReturnLeft,
637    ReturnRight,
638    PreSelection(&'a BooleanArray),
639}
640
/// Pre-selection cut-off for short-circuit evaluation of `and`.
///
/// When the proportion of `true` values in the left-hand result is at most
/// this value (and the operation is an `and`), the `RecordBatch` is filtered
/// in advance, before the right-hand side is evaluated.
const PRE_SELECTION_THRESHOLD: f32 = 0.2;
645
646/// Checks if a logical operator (`AND`/`OR`) can short-circuit evaluation based on the left-hand side (lhs) result.
647///
648/// Short-circuiting occurs under these circumstances:
649/// - For `AND`:
650///    - if LHS is all false => short-circuit → return LHS
651///    - if LHS is all true  => short-circuit → return RHS
652///    - if LHS is mixed and true_count/sum_count <= [`PRE_SELECTION_THRESHOLD`] -> pre-selection
653/// - For `OR`:
654///    - if LHS is all true  => short-circuit → return LHS
655///    - if LHS is all false => short-circuit → return RHS
656/// # Arguments
657/// * `lhs` - The left-hand side (lhs) columnar value (array or scalar)
658/// * `lhs` - The left-hand side (lhs) columnar value (array or scalar)
659/// * `op` - The logical operator (`AND` or `OR`)
660///
661/// # Implementation Notes
662/// 1. Only works with Boolean-typed arguments (other types automatically return `false`)
663/// 2. Handles both scalar values and array values
664/// 3. For arrays, uses optimized bit counting techniques for boolean arrays
665fn check_short_circuit<'a>(
666    lhs: &'a ColumnarValue,
667    op: &Operator,
668) -> ShortCircuitStrategy<'a> {
669    // Quick reject for non-logical operators,and quick judgment when op is and
670    let is_and = match op {
671        Operator::And => true,
672        Operator::Or => false,
673        _ => return ShortCircuitStrategy::None,
674    };
675
676    // Non-boolean types can't be short-circuited
677    if lhs.data_type() != DataType::Boolean {
678        return ShortCircuitStrategy::None;
679    }
680
681    match lhs {
682        ColumnarValue::Array(array) => {
683            // Fast path for arrays - try to downcast to boolean array
684            if let Ok(bool_array) = as_boolean_array(array) {
685                // Arrays with nulls can't be short-circuited
686                if bool_array.null_count() > 0 {
687                    return ShortCircuitStrategy::None;
688                }
689
690                let len = bool_array.len();
691                if len == 0 {
692                    return ShortCircuitStrategy::None;
693                }
694
695                let true_count = bool_array.values().count_set_bits();
696                if is_and {
697                    // For AND, prioritize checking for all-false (short circuit case)
698                    // Uses optimized false_count() method provided by Arrow
699
700                    // Short circuit if all values are false
701                    if true_count == 0 {
702                        return ShortCircuitStrategy::ReturnLeft;
703                    }
704
705                    // If no false values, then all must be true
706                    if true_count == len {
707                        return ShortCircuitStrategy::ReturnRight;
708                    }
709
710                    // determine if we can pre-selection
711                    if true_count as f32 / len as f32 <= PRE_SELECTION_THRESHOLD {
712                        return ShortCircuitStrategy::PreSelection(bool_array);
713                    }
714                } else {
715                    // For OR, prioritize checking for all-true (short circuit case)
716                    // Uses optimized true_count() method provided by Arrow
717
718                    // Short circuit if all values are true
719                    if true_count == len {
720                        return ShortCircuitStrategy::ReturnLeft;
721                    }
722
723                    // If no true values, then all must be false
724                    if true_count == 0 {
725                        return ShortCircuitStrategy::ReturnRight;
726                    }
727                }
728            }
729        }
730        ColumnarValue::Scalar(scalar) => {
731            // Fast path for scalar values
732            if let ScalarValue::Boolean(Some(is_true)) = scalar {
733                // Return Left for:
734                // - AND with false value
735                // - OR with true value
736                if (is_and && !is_true) || (!is_and && *is_true) {
737                    return ShortCircuitStrategy::ReturnLeft;
738                } else {
739                    return ShortCircuitStrategy::ReturnRight;
740                }
741            }
742        }
743    }
744
745    // If we can't short-circuit, indicate that normal evaluation should continue
746    ShortCircuitStrategy::None
747}
748
749/// Creates a new boolean array based on the evaluation of the right expression,
750/// but only for positions where the left_result is true.
751///
752/// This function is used for short-circuit evaluation optimization of logical AND operations:
753/// - When left_result has few true values, we only evaluate the right expression for those positions
754/// - Values are copied from right_array where left_result is true
755/// - All other positions are filled with false values
756///
757/// # Parameters
758/// - `left_result` Boolean array with selection mask (typically from left side of AND)
759/// - `right_result` Result of evaluating right side of expression (only for selected positions)
760///
761/// # Returns
762/// A combined ColumnarValue with values from right_result where left_result is true
763///
764/// # Example
765///  Initial Data: { 1, 2, 3, 4, 5 }
766///  Left Evaluation
767///     (Condition: Equal to 2 or 3)
768///          ↓
769///  Filtered Data: {2, 3}
770///    Left Bitmap: { 0, 1, 1, 0, 0 }
771///          ↓
772///   Right Evaluation
773///     (Condition: Even numbers)
774///          ↓
775///  Right Data: { 2 }
776///    Right Bitmap: { 1, 0 }
777///          ↓
778///   Combine Results
779///  Final Bitmap: { 0, 1, 0, 0, 0 }
780///
781/// # Note
782/// Perhaps it would be better to modify `left_result` directly without creating a copy?
783/// In practice, `left_result` should have only one owner, so making changes should be safe.
784/// However, this is difficult to achieve under the immutable constraints of [`Arc`] and [`BooleanArray`].
785fn pre_selection_scatter(
786    left_result: &BooleanArray,
787    right_result: Option<&BooleanArray>,
788) -> Result<ColumnarValue> {
789    let result_len = left_result.len();
790
791    let mut result_array_builder = BooleanArray::builder(result_len);
792
793    // keep track of current position we have in right boolean array
794    let mut right_array_pos = 0;
795
796    // keep track of how much is filled
797    let mut last_end = 0;
798    // reduce if condition in for_each
799    match right_result {
800        Some(right_result) => {
801            SlicesIterator::new(left_result).for_each(|(start, end)| {
802                // the gap needs to be filled with false
803                if start > last_end {
804                    result_array_builder.append_n(start - last_end, false);
805                }
806
807                // copy values from right array for this slice
808                let len = end - start;
809                right_result
810                    .slice(right_array_pos, len)
811                    .iter()
812                    .for_each(|v| result_array_builder.append_option(v));
813
814                right_array_pos += len;
815                last_end = end;
816            });
817        }
818        None => SlicesIterator::new(left_result).for_each(|(start, end)| {
819            // the gap needs to be filled with false
820            if start > last_end {
821                result_array_builder.append_n(start - last_end, false);
822            }
823
824            // append nulls for this slice derictly
825            let len = end - start;
826            result_array_builder.append_nulls(len);
827
828            last_end = end;
829        }),
830    }
831
832    // Fill any remaining positions with false
833    if last_end < result_len {
834        result_array_builder.append_n(result_len - last_end, false);
835    }
836    let boolean_result = result_array_builder.finish();
837
838    Ok(ColumnarValue::Array(Arc::new(boolean_result)))
839}
840
841fn concat_elements(left: &ArrayRef, right: &ArrayRef) -> Result<ArrayRef> {
842    Ok(match left.data_type() {
843        DataType::Utf8 => Arc::new(concat_elements_utf8(
844            left.as_string::<i32>(),
845            right.as_string::<i32>(),
846        )?),
847        DataType::LargeUtf8 => Arc::new(concat_elements_utf8(
848            left.as_string::<i64>(),
849            right.as_string::<i64>(),
850        )?),
851        DataType::Utf8View => Arc::new(concat_elements_utf8view(
852            left.as_string_view(),
853            right.as_string_view(),
854        )?),
855        other => {
856            return internal_err!(
857                "Data type {other:?} not supported for binary operation 'concat_elements' on string arrays"
858            );
859        }
860    })
861}
862
863/// Create a binary expression whose arguments are correctly coerced.
864/// This function errors if it is not possible to coerce the arguments
865/// to computational types supported by the operator.
866pub fn binary(
867    lhs: Arc<dyn PhysicalExpr>,
868    op: Operator,
869    rhs: Arc<dyn PhysicalExpr>,
870    _input_schema: &Schema,
871) -> Result<Arc<dyn PhysicalExpr>> {
872    Ok(Arc::new(BinaryExpr::new(lhs, op, rhs)))
873}
874
875/// Create a similar to expression
876pub fn similar_to(
877    negated: bool,
878    case_insensitive: bool,
879    expr: Arc<dyn PhysicalExpr>,
880    pattern: Arc<dyn PhysicalExpr>,
881) -> Result<Arc<dyn PhysicalExpr>> {
882    let binary_op = match (negated, case_insensitive) {
883        (false, false) => Operator::RegexMatch,
884        (false, true) => Operator::RegexIMatch,
885        (true, false) => Operator::RegexNotMatch,
886        (true, true) => Operator::RegexNotIMatch,
887    };
888    Ok(Arc::new(BinaryExpr::new(expr, binary_op, pattern)))
889}
890
891#[cfg(test)]
892mod tests {
893    use super::*;
894    use crate::expressions::{Column, Literal, col, lit, try_cast};
895    use datafusion_expr::lit as expr_lit;
896
897    use datafusion_common::plan_datafusion_err;
898    use datafusion_physical_expr_common::physical_expr::fmt_sql;
899
900    use crate::planner::logical2physical;
901    use arrow::array::BooleanArray;
902    use datafusion_expr::col as logical_col;
903    /// Performs a binary operation, applying any type coercion necessary
904    fn binary_op(
905        left: Arc<dyn PhysicalExpr>,
906        op: Operator,
907        right: Arc<dyn PhysicalExpr>,
908        schema: &Schema,
909    ) -> Result<Arc<dyn PhysicalExpr>> {
910        let left_type = left.data_type(schema)?;
911        let right_type = right.data_type(schema)?;
912        let (lhs, rhs) =
913            BinaryTypeCoercer::new(&left_type, &op, &right_type).get_input_types()?;
914
915        let left_expr = try_cast(left, schema, lhs)?;
916        let right_expr = try_cast(right, schema, rhs)?;
917        binary(left_expr, op, right_expr, schema)
918    }
919
920    #[test]
921    fn binary_comparison() -> Result<()> {
922        let schema = Schema::new(vec![
923            Field::new("a", DataType::Int32, false),
924            Field::new("b", DataType::Int32, false),
925        ]);
926        let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
927        let b = Int32Array::from(vec![1, 2, 4, 8, 16]);
928
929        // expression: "a < b"
930        let lt = binary(
931            col("a", &schema)?,
932            Operator::Lt,
933            col("b", &schema)?,
934            &schema,
935        )?;
936        let batch =
937            RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)])?;
938
939        let result = lt
940            .evaluate(&batch)?
941            .into_array(batch.num_rows())
942            .expect("Failed to convert to array");
943        assert_eq!(result.len(), 5);
944
945        let expected = [false, false, true, true, true];
946        let result =
947            as_boolean_array(&result).expect("failed to downcast to BooleanArray");
948        for (i, &expected_item) in expected.iter().enumerate().take(5) {
949            assert_eq!(result.value(i), expected_item);
950        }
951
952        Ok(())
953    }
954
955    #[test]
956    fn binary_nested() -> Result<()> {
957        let schema = Schema::new(vec![
958            Field::new("a", DataType::Int32, false),
959            Field::new("b", DataType::Int32, false),
960        ]);
961        let a = Int32Array::from(vec![2, 4, 6, 8, 10]);
962        let b = Int32Array::from(vec![2, 5, 4, 8, 8]);
963
964        // expression: "a < b OR a == b"
965        let expr = binary(
966            binary(
967                col("a", &schema)?,
968                Operator::Lt,
969                col("b", &schema)?,
970                &schema,
971            )?,
972            Operator::Or,
973            binary(
974                col("a", &schema)?,
975                Operator::Eq,
976                col("b", &schema)?,
977                &schema,
978            )?,
979            &schema,
980        )?;
981        let batch =
982            RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)])?;
983
984        assert_eq!("a@0 < b@1 OR a@0 = b@1", format!("{expr}"));
985
986        let result = expr
987            .evaluate(&batch)?
988            .into_array(batch.num_rows())
989            .expect("Failed to convert to array");
990        assert_eq!(result.len(), 5);
991
992        let expected = [true, true, false, true, false];
993        let result =
994            as_boolean_array(&result).expect("failed to downcast to BooleanArray");
995        for (i, &expected_item) in expected.iter().enumerate().take(5) {
996            assert_eq!(result.value(i), expected_item);
997        }
998
999        Ok(())
1000    }
1001
    // Runs an end-to-end test of physical type coercion:
    // 1. construct a record batch with two columns of type A and B
    //  (*_ARRAY is the Rust Arrow array type, and *_TYPE is the DataType of the elements)
    // 2. construct a physical expression of A OP B
    // 3. evaluate the expression
    // 4. verify that the resulting expression is of type C
    // 5. verify that the results of evaluation are $VEC
    //
    // Usage notes:
    // - the argument list must end with a trailing comma (the pattern requires it)
    // - the expansion uses `?`, so it can only be invoked inside a function
    //   that returns a compatible `Result`
    // - both columns are declared non-nullable in the generated schema
    macro_rules! test_coercion {
        ($A_ARRAY:ident, $A_TYPE:expr, $A_VEC:expr, $B_ARRAY:ident, $B_TYPE:expr, $B_VEC:expr, $OP:expr, $C_ARRAY:ident, $C_TYPE:expr, $VEC:expr,) => {{
            let schema = Schema::new(vec![
                Field::new("a", $A_TYPE, false),
                Field::new("b", $B_TYPE, false),
            ]);
            let a = $A_ARRAY::from($A_VEC);
            let b = $B_ARRAY::from($B_VEC);
            // ask the coercer which input types the operator wants for A and B
            let (lhs, rhs) =
                BinaryTypeCoercer::new(&$A_TYPE, &$OP, &$B_TYPE).get_input_types()?;

            // cast each column reference to its coerced input type
            let left = try_cast(col("a", &schema)?, &schema, lhs)?;
            let right = try_cast(col("b", &schema)?, &schema, rhs)?;

            // verify that we can construct the expression
            let expression = binary(left, $OP, right, &schema)?;
            let batch = RecordBatch::try_new(
                Arc::new(schema.clone()),
                vec![Arc::new(a), Arc::new(b)],
            )?;

            // verify that the expression's type is correct
            assert_eq!(expression.data_type(&schema)?, $C_TYPE);

            // compute
            let result = expression
                .evaluate(&batch)?
                .into_array(batch.num_rows())
                .expect("Failed to convert to array");

            // verify that the array's data_type is correct
            assert_eq!(*result.data_type(), $C_TYPE);

            // verify that the data itself is downcastable
            let result = result
                .as_any()
                .downcast_ref::<$C_ARRAY>()
                .expect("failed to downcast");
            // verify that the result itself is correct
            // (only the first `$VEC.len()` positions are compared; the length
            // of `result` is not asserted here)
            for (i, x) in $VEC.iter().enumerate() {
                let v = result.value(i);
                assert_eq!(
                    v, *x,
                    "Unexpected output at position {i}:\n\nActual:\n{v}\n\nExpected:\n{x}"
                );
            }
        }};
    }
1057
    /// Drives `test_coercion!` over a table of operand types and operators.
    ///
    /// Each invocation lists, in order: left array type / data type / values,
    /// right array type / data type / values, the operator, and the expected
    /// result array type / data type / values.
    #[test]
    fn test_type_coercion() -> Result<()> {
        // Arithmetic with mixed integer and float operand types.
        test_coercion!(
            Int32Array,
            DataType::Int32,
            vec![1i32, 2i32],
            UInt32Array,
            DataType::UInt32,
            vec![1u32, 2u32],
            Operator::Plus,
            Int64Array,
            DataType::Int64,
            [2i64, 4i64],
        );
        test_coercion!(
            Int32Array,
            DataType::Int32,
            vec![1i32],
            UInt16Array,
            DataType::UInt16,
            vec![1u16],
            Operator::Plus,
            Int32Array,
            DataType::Int32,
            [2i32],
        );
        test_coercion!(
            Float32Array,
            DataType::Float32,
            vec![1f32],
            UInt16Array,
            DataType::UInt16,
            vec![1u16],
            Operator::Plus,
            Float32Array,
            DataType::Float32,
            [2f32],
        );
        test_coercion!(
            Float32Array,
            DataType::Float32,
            vec![2f32],
            UInt16Array,
            DataType::UInt16,
            vec![1u16],
            Operator::Multiply,
            Float32Array,
            DataType::Float32,
            [2f32],
        );
        // Utf8 strings compared with Date32 values.
        test_coercion!(
            StringArray,
            DataType::Utf8,
            vec!["1994-12-13", "1995-01-26"],
            Date32Array,
            DataType::Date32,
            vec![9112, 9156],
            Operator::Eq,
            BooleanArray,
            DataType::Boolean,
            [true, true],
        );
        test_coercion!(
            StringArray,
            DataType::Utf8,
            vec!["1994-12-13", "1995-01-26"],
            Date32Array,
            DataType::Date32,
            vec![9113, 9154],
            Operator::Lt,
            BooleanArray,
            DataType::Boolean,
            [true, false],
        );
        // Utf8 strings compared with Date64 values.
        test_coercion!(
            StringArray,
            DataType::Utf8,
            vec!["1994-12-13T12:34:56", "1995-01-26T01:23:45"],
            Date64Array,
            DataType::Date64,
            vec![787322096000, 791083425000],
            Operator::Eq,
            BooleanArray,
            DataType::Boolean,
            [true, true],
        );
        test_coercion!(
            StringArray,
            DataType::Utf8,
            vec!["1994-12-13T12:34:56", "1995-01-26T01:23:45"],
            Date64Array,
            DataType::Date64,
            vec![787322096001, 791083424999],
            Operator::Lt,
            BooleanArray,
            DataType::Boolean,
            [true, false],
        );
        // Regex operators with mixed Utf8View / Utf8 operands.
        test_coercion!(
            StringViewArray,
            DataType::Utf8View,
            vec!["abc"; 5],
            StringArray,
            DataType::Utf8,
            vec!["^a", "^A", "(b|d)", "(B|D)", "^(b|c)"],
            Operator::RegexMatch,
            BooleanArray,
            DataType::Boolean,
            [true, false, true, false, false],
        );
        test_coercion!(
            StringViewArray,
            DataType::Utf8View,
            vec!["abc"; 5],
            StringArray,
            DataType::Utf8,
            vec!["^a", "^A", "(b|d)", "(B|D)", "^(b|c)"],
            Operator::RegexIMatch,
            BooleanArray,
            DataType::Boolean,
            [true, true, true, true, false],
        );
        test_coercion!(
            StringArray,
            DataType::Utf8,
            vec!["abc"; 5],
            StringViewArray,
            DataType::Utf8View,
            vec!["^a", "^A", "(b|d)", "(B|D)", "^(b|c)"],
            Operator::RegexNotMatch,
            BooleanArray,
            DataType::Boolean,
            [false, true, false, true, true],
        );
        test_coercion!(
            StringArray,
            DataType::Utf8,
            vec!["abc"; 5],
            StringViewArray,
            DataType::Utf8View,
            vec!["^a", "^A", "(b|d)", "(B|D)", "^(b|c)"],
            Operator::RegexNotIMatch,
            BooleanArray,
            DataType::Boolean,
            [false, false, false, false, true],
        );
        // Regex operators with Utf8 on both sides.
        test_coercion!(
            StringArray,
            DataType::Utf8,
            vec!["abc"; 5],
            StringArray,
            DataType::Utf8,
            vec!["^a", "^A", "(b|d)", "(B|D)", "^(b|c)"],
            Operator::RegexMatch,
            BooleanArray,
            DataType::Boolean,
            [true, false, true, false, false],
        );
        test_coercion!(
            StringArray,
            DataType::Utf8,
            vec!["abc"; 5],
            StringArray,
            DataType::Utf8,
            vec!["^a", "^A", "(b|d)", "(B|D)", "^(b|c)"],
            Operator::RegexIMatch,
            BooleanArray,
            DataType::Boolean,
            [true, true, true, true, false],
        );
        test_coercion!(
            StringArray,
            DataType::Utf8,
            vec!["abc"; 5],
            StringArray,
            DataType::Utf8,
            vec!["^a", "^A", "(b|d)", "(B|D)", "^(b|c)"],
            Operator::RegexNotMatch,
            BooleanArray,
            DataType::Boolean,
            [false, true, false, true, true],
        );
        test_coercion!(
            StringArray,
            DataType::Utf8,
            vec!["abc"; 5],
            StringArray,
            DataType::Utf8,
            vec!["^a", "^A", "(b|d)", "(B|D)", "^(b|c)"],
            Operator::RegexNotIMatch,
            BooleanArray,
            DataType::Boolean,
            [false, false, false, false, true],
        );
        // Regex operators with LargeUtf8 on both sides.
        test_coercion!(
            LargeStringArray,
            DataType::LargeUtf8,
            vec!["abc"; 5],
            LargeStringArray,
            DataType::LargeUtf8,
            vec!["^a", "^A", "(b|d)", "(B|D)", "^(b|c)"],
            Operator::RegexMatch,
            BooleanArray,
            DataType::Boolean,
            [true, false, true, false, false],
        );
        test_coercion!(
            LargeStringArray,
            DataType::LargeUtf8,
            vec!["abc"; 5],
            LargeStringArray,
            DataType::LargeUtf8,
            vec!["^a", "^A", "(b|d)", "(B|D)", "^(b|c)"],
            Operator::RegexIMatch,
            BooleanArray,
            DataType::Boolean,
            [true, true, true, true, false],
        );
        test_coercion!(
            LargeStringArray,
            DataType::LargeUtf8,
            vec!["abc"; 5],
            LargeStringArray,
            DataType::LargeUtf8,
            vec!["^a", "^A", "(b|d)", "(B|D)", "^(b|c)"],
            Operator::RegexNotMatch,
            BooleanArray,
            DataType::Boolean,
            [false, true, false, true, true],
        );
        test_coercion!(
            LargeStringArray,
            DataType::LargeUtf8,
            vec!["abc"; 5],
            LargeStringArray,
            DataType::LargeUtf8,
            vec!["^a", "^A", "(b|d)", "(B|D)", "^(b|c)"],
            Operator::RegexNotIMatch,
            BooleanArray,
            DataType::Boolean,
            [false, false, false, false, true],
        );
        // LIKE-family operators with Utf8 on both sides.
        test_coercion!(
            StringArray,
            DataType::Utf8,
            vec!["abc"; 5],
            StringArray,
            DataType::Utf8,
            vec!["a__", "A%BC", "A_BC", "abc", "a%C"],
            Operator::LikeMatch,
            BooleanArray,
            DataType::Boolean,
            [true, false, false, true, false],
        );
        test_coercion!(
            StringArray,
            DataType::Utf8,
            vec!["abc"; 5],
            StringArray,
            DataType::Utf8,
            vec!["a__", "A%BC", "A_BC", "abc", "a%C"],
            Operator::ILikeMatch,
            BooleanArray,
            DataType::Boolean,
            [true, true, false, true, true],
        );
        test_coercion!(
            StringArray,
            DataType::Utf8,
            vec!["abc"; 5],
            StringArray,
            DataType::Utf8,
            vec!["a__", "A%BC", "A_BC", "abc", "a%C"],
            Operator::NotLikeMatch,
            BooleanArray,
            DataType::Boolean,
            [false, true, true, false, true],
        );
        test_coercion!(
            StringArray,
            DataType::Utf8,
            vec!["abc"; 5],
            StringArray,
            DataType::Utf8,
            vec!["a__", "A%BC", "A_BC", "abc", "a%C"],
            Operator::NotILikeMatch,
            BooleanArray,
            DataType::Boolean,
            [false, false, true, false, false],
        );
        // LIKE-family operators with LargeUtf8 on both sides.
        test_coercion!(
            LargeStringArray,
            DataType::LargeUtf8,
            vec!["abc"; 5],
            LargeStringArray,
            DataType::LargeUtf8,
            vec!["a__", "A%BC", "A_BC", "abc", "a%C"],
            Operator::LikeMatch,
            BooleanArray,
            DataType::Boolean,
            [true, false, false, true, false],
        );
        test_coercion!(
            LargeStringArray,
            DataType::LargeUtf8,
            vec!["abc"; 5],
            LargeStringArray,
            DataType::LargeUtf8,
            vec!["a__", "A%BC", "A_BC", "abc", "a%C"],
            Operator::ILikeMatch,
            BooleanArray,
            DataType::Boolean,
            [true, true, false, true, true],
        );
        test_coercion!(
            LargeStringArray,
            DataType::LargeUtf8,
            vec!["abc"; 5],
            LargeStringArray,
            DataType::LargeUtf8,
            vec!["a__", "A%BC", "A_BC", "abc", "a%C"],
            Operator::NotLikeMatch,
            BooleanArray,
            DataType::Boolean,
            [false, true, true, false, true],
        );
        test_coercion!(
            LargeStringArray,
            DataType::LargeUtf8,
            vec!["abc"; 5],
            LargeStringArray,
            DataType::LargeUtf8,
            vec!["a__", "A%BC", "A_BC", "abc", "a%C"],
            Operator::NotILikeMatch,
            BooleanArray,
            DataType::Boolean,
            [false, false, true, false, false],
        );
        // Bitwise operators: a 16-bit left operand against a 64-bit right
        // operand, signed and unsigned, with a 64-bit expected result.
        test_coercion!(
            Int16Array,
            DataType::Int16,
            vec![1i16, 2i16, 3i16],
            Int64Array,
            DataType::Int64,
            vec![10i64, 4i64, 5i64],
            Operator::BitwiseAnd,
            Int64Array,
            DataType::Int64,
            [0i64, 0i64, 1i64],
        );
        test_coercion!(
            UInt16Array,
            DataType::UInt16,
            vec![1u16, 2u16, 3u16],
            UInt64Array,
            DataType::UInt64,
            vec![10u64, 4u64, 5u64],
            Operator::BitwiseAnd,
            UInt64Array,
            DataType::UInt64,
            [0u64, 0u64, 1u64],
        );
        test_coercion!(
            Int16Array,
            DataType::Int16,
            vec![3i16, 2i16, 3i16],
            Int64Array,
            DataType::Int64,
            vec![10i64, 6i64, 5i64],
            Operator::BitwiseOr,
            Int64Array,
            DataType::Int64,
            [11i64, 6i64, 7i64],
        );
        test_coercion!(
            UInt16Array,
            DataType::UInt16,
            vec![1u16, 2u16, 3u16],
            UInt64Array,
            DataType::UInt64,
            vec![10u64, 4u64, 5u64],
            Operator::BitwiseOr,
            UInt64Array,
            DataType::UInt64,
            [11u64, 6u64, 7u64],
        );
        test_coercion!(
            Int16Array,
            DataType::Int16,
            vec![3i16, 2i16, 3i16],
            Int64Array,
            DataType::Int64,
            vec![10i64, 6i64, 5i64],
            Operator::BitwiseXor,
            Int64Array,
            DataType::Int64,
            [9i64, 4i64, 6i64],
        );
        test_coercion!(
            UInt16Array,
            DataType::UInt16,
            vec![3u16, 2u16, 3u16],
            UInt64Array,
            DataType::UInt64,
            vec![10u64, 6u64, 5u64],
            Operator::BitwiseXor,
            UInt64Array,
            DataType::UInt64,
            [9u64, 4u64, 6u64],
        );
        test_coercion!(
            Int16Array,
            DataType::Int16,
            vec![4i16, 27i16, 35i16],
            Int64Array,
            DataType::Int64,
            vec![2i64, 3i64, 4i64],
            Operator::BitwiseShiftRight,
            Int64Array,
            DataType::Int64,
            [1i64, 3i64, 2i64],
        );
        test_coercion!(
            UInt16Array,
            DataType::UInt16,
            vec![4u16, 27u16, 35u16],
            UInt64Array,
            DataType::UInt64,
            vec![2u64, 3u64, 4u64],
            Operator::BitwiseShiftRight,
            UInt64Array,
            DataType::UInt64,
            [1u64, 3u64, 2u64],
        );
        test_coercion!(
            Int16Array,
            DataType::Int16,
            vec![2i16, 3i16, 4i16],
            Int64Array,
            DataType::Int64,
            vec![4i64, 12i64, 7i64],
            Operator::BitwiseShiftLeft,
            Int64Array,
            DataType::Int64,
            [32i64, 12288i64, 512i64],
        );
        test_coercion!(
            UInt16Array,
            DataType::UInt16,
            vec![2u16, 3u16, 4u16],
            UInt64Array,
            DataType::UInt64,
            vec![4u64, 12u64, 7u64],
            Operator::BitwiseShiftLeft,
            UInt64Array,
            DataType::UInt64,
            [32u64, 12288u64, 512u64],
        );
        Ok(())
    }
1518
1519    // Note it would be nice to use the same test_coercion macro as
1520    // above, but sadly the type of the values of the dictionary are
1521    // not encoded in the rust type of the DictionaryArray. Thus there
1522    // is no way at the time of this writing to create a dictionary
1523    // array using the `From` trait
1524    #[test]
1525    fn test_dictionary_type_to_array_coercion() -> Result<()> {
1526        // Test string  a string dictionary
1527        let dict_type =
1528            DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8));
1529        let string_type = DataType::Utf8;
1530
1531        // build dictionary
1532        let mut dict_builder = StringDictionaryBuilder::<Int32Type>::new();
1533
1534        dict_builder.append("one")?;
1535        dict_builder.append_null();
1536        dict_builder.append("three")?;
1537        dict_builder.append("four")?;
1538        let dict_array = Arc::new(dict_builder.finish()) as ArrayRef;
1539
1540        let str_array = Arc::new(StringArray::from(vec![
1541            Some("not one"),
1542            Some("two"),
1543            None,
1544            Some("four"),
1545        ])) as ArrayRef;
1546
1547        let schema = Arc::new(Schema::new(vec![
1548            Field::new("a", dict_type.clone(), true),
1549            Field::new("b", string_type.clone(), true),
1550        ]));
1551
1552        // Test 1: a = b
1553        let result = BooleanArray::from(vec![Some(false), None, None, Some(true)]);
1554        apply_logic_op(&schema, &dict_array, &str_array, Operator::Eq, result)?;
1555
1556        // Test 2: now test the other direction
1557        // b = a
1558        let schema = Arc::new(Schema::new(vec![
1559            Field::new("a", string_type, true),
1560            Field::new("b", dict_type, true),
1561        ]));
1562        let result = BooleanArray::from(vec![Some(false), None, None, Some(true)]);
1563        apply_logic_op(&schema, &str_array, &dict_array, Operator::Eq, result)?;
1564
1565        Ok(())
1566    }
1567
1568    #[test]
1569    fn plus_op() -> Result<()> {
1570        let schema = Schema::new(vec![
1571            Field::new("a", DataType::Int32, false),
1572            Field::new("b", DataType::Int32, false),
1573        ]);
1574        let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
1575        let b = Int32Array::from(vec![1, 2, 4, 8, 16]);
1576
1577        apply_arithmetic::<Int32Type>(
1578            Arc::new(schema),
1579            vec![Arc::new(a), Arc::new(b)],
1580            Operator::Plus,
1581            Int32Array::from(vec![2, 4, 7, 12, 21]),
1582        )?;
1583
1584        Ok(())
1585    }
1586
1587    #[test]
1588    fn plus_op_dict() -> Result<()> {
1589        let schema = Schema::new(vec![
1590            Field::new(
1591                "a",
1592                DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Int32)),
1593                true,
1594            ),
1595            Field::new(
1596                "b",
1597                DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Int32)),
1598                true,
1599            ),
1600        ]);
1601
1602        let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
1603        let keys = Int8Array::from(vec![Some(0), None, Some(1), Some(3), None]);
1604        let a = DictionaryArray::try_new(keys, Arc::new(a))?;
1605
1606        let b = Int32Array::from(vec![1, 2, 4, 8, 16]);
1607        let keys = Int8Array::from(vec![0, 1, 1, 2, 1]);
1608        let b = DictionaryArray::try_new(keys, Arc::new(b))?;
1609
1610        apply_arithmetic::<Int32Type>(
1611            Arc::new(schema),
1612            vec![Arc::new(a), Arc::new(b)],
1613            Operator::Plus,
1614            Int32Array::from(vec![Some(2), None, Some(4), Some(8), None]),
1615        )?;
1616
1617        Ok(())
1618    }
1619
1620    #[test]
1621    fn plus_op_dict_decimal() -> Result<()> {
1622        let schema = Schema::new(vec![
1623            Field::new(
1624                "a",
1625                DataType::Dictionary(
1626                    Box::new(DataType::Int8),
1627                    Box::new(DataType::Decimal128(10, 0)),
1628                ),
1629                true,
1630            ),
1631            Field::new(
1632                "b",
1633                DataType::Dictionary(
1634                    Box::new(DataType::Int8),
1635                    Box::new(DataType::Decimal128(10, 0)),
1636                ),
1637                true,
1638            ),
1639        ]);
1640
1641        let value = 123;
1642        let decimal_array = Arc::new(create_decimal_array(
1643            &[
1644                Some(value),
1645                Some(value + 2),
1646                Some(value - 1),
1647                Some(value + 1),
1648            ],
1649            10,
1650            0,
1651        ));
1652
1653        let keys = Int8Array::from(vec![Some(0), Some(2), None, Some(3), Some(0)]);
1654        let a = DictionaryArray::try_new(keys, decimal_array)?;
1655
1656        let keys = Int8Array::from(vec![Some(0), None, Some(3), Some(2), Some(2)]);
1657        let decimal_array = Arc::new(create_decimal_array(
1658            &[
1659                Some(value + 1),
1660                Some(value + 3),
1661                Some(value),
1662                Some(value + 2),
1663            ],
1664            10,
1665            0,
1666        ));
1667        let b = DictionaryArray::try_new(keys, decimal_array)?;
1668
1669        apply_arithmetic(
1670            Arc::new(schema),
1671            vec![Arc::new(a), Arc::new(b)],
1672            Operator::Plus,
1673            create_decimal_array(&[Some(247), None, None, Some(247), Some(246)], 11, 0),
1674        )?;
1675
1676        Ok(())
1677    }
1678
1679    #[test]
1680    fn plus_op_scalar() -> Result<()> {
1681        let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
1682        let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
1683
1684        apply_arithmetic_scalar(
1685            Arc::new(schema),
1686            vec![Arc::new(a)],
1687            Operator::Plus,
1688            ScalarValue::Int32(Some(1)),
1689            Arc::new(Int32Array::from(vec![2, 3, 4, 5, 6])),
1690        )?;
1691
1692        Ok(())
1693    }
1694
1695    #[test]
1696    fn plus_op_dict_scalar() -> Result<()> {
1697        let schema = Schema::new(vec![Field::new(
1698            "a",
1699            DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Int32)),
1700            true,
1701        )]);
1702
1703        let mut dict_builder = PrimitiveDictionaryBuilder::<Int8Type, Int32Type>::new();
1704
1705        dict_builder.append(1)?;
1706        dict_builder.append_null();
1707        dict_builder.append(2)?;
1708        dict_builder.append(5)?;
1709
1710        let a = dict_builder.finish();
1711
1712        let expected: PrimitiveArray<Int32Type> =
1713            PrimitiveArray::from(vec![Some(2), None, Some(3), Some(6)]);
1714
1715        apply_arithmetic_scalar(
1716            Arc::new(schema),
1717            vec![Arc::new(a)],
1718            Operator::Plus,
1719            ScalarValue::Dictionary(
1720                Box::new(DataType::Int8),
1721                Box::new(ScalarValue::Int32(Some(1))),
1722            ),
1723            Arc::new(expected),
1724        )?;
1725
1726        Ok(())
1727    }
1728
1729    #[test]
1730    fn plus_op_dict_scalar_decimal() -> Result<()> {
1731        let schema = Schema::new(vec![Field::new(
1732            "a",
1733            DataType::Dictionary(
1734                Box::new(DataType::Int8),
1735                Box::new(DataType::Decimal128(10, 0)),
1736            ),
1737            true,
1738        )]);
1739
1740        let value = 123;
1741        let decimal_array = Arc::new(create_decimal_array(
1742            &[Some(value), None, Some(value - 1), Some(value + 1)],
1743            10,
1744            0,
1745        ));
1746
1747        let keys = Int8Array::from(vec![0, 2, 1, 3, 0]);
1748        let a = DictionaryArray::try_new(keys, decimal_array)?;
1749
1750        let decimal_array = Arc::new(create_decimal_array(
1751            &[
1752                Some(value + 1),
1753                Some(value),
1754                None,
1755                Some(value + 2),
1756                Some(value + 1),
1757            ],
1758            11,
1759            0,
1760        ));
1761
1762        apply_arithmetic_scalar(
1763            Arc::new(schema),
1764            vec![Arc::new(a)],
1765            Operator::Plus,
1766            ScalarValue::Dictionary(
1767                Box::new(DataType::Int8),
1768                Box::new(ScalarValue::Decimal128(Some(1), 10, 0)),
1769            ),
1770            decimal_array,
1771        )?;
1772
1773        Ok(())
1774    }
1775
1776    #[test]
1777    fn minus_op() -> Result<()> {
1778        let schema = Arc::new(Schema::new(vec![
1779            Field::new("a", DataType::Int32, false),
1780            Field::new("b", DataType::Int32, false),
1781        ]));
1782        let a = Arc::new(Int32Array::from(vec![1, 2, 4, 8, 16]));
1783        let b = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));
1784
1785        apply_arithmetic::<Int32Type>(
1786            Arc::clone(&schema),
1787            vec![
1788                Arc::clone(&a) as Arc<dyn Array>,
1789                Arc::clone(&b) as Arc<dyn Array>,
1790            ],
1791            Operator::Minus,
1792            Int32Array::from(vec![0, 0, 1, 4, 11]),
1793        )?;
1794
1795        // should handle have negative values in result (for signed)
1796        apply_arithmetic::<Int32Type>(
1797            schema,
1798            vec![b, a],
1799            Operator::Minus,
1800            Int32Array::from(vec![0, 0, -1, -4, -11]),
1801        )?;
1802
1803        Ok(())
1804    }
1805
1806    #[test]
1807    fn minus_op_dict() -> Result<()> {
1808        let schema = Schema::new(vec![
1809            Field::new(
1810                "a",
1811                DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Int32)),
1812                true,
1813            ),
1814            Field::new(
1815                "b",
1816                DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Int32)),
1817                true,
1818            ),
1819        ]);
1820
1821        let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
1822        let keys = Int8Array::from(vec![Some(0), None, Some(1), Some(3), None]);
1823        let a = DictionaryArray::try_new(keys, Arc::new(a))?;
1824
1825        let b = Int32Array::from(vec![1, 2, 4, 8, 16]);
1826        let keys = Int8Array::from(vec![0, 1, 1, 2, 1]);
1827        let b = DictionaryArray::try_new(keys, Arc::new(b))?;
1828
1829        apply_arithmetic::<Int32Type>(
1830            Arc::new(schema),
1831            vec![Arc::new(a), Arc::new(b)],
1832            Operator::Minus,
1833            Int32Array::from(vec![Some(0), None, Some(0), Some(0), None]),
1834        )?;
1835
1836        Ok(())
1837    }
1838
1839    #[test]
1840    fn minus_op_dict_decimal() -> Result<()> {
1841        let schema = Schema::new(vec![
1842            Field::new(
1843                "a",
1844                DataType::Dictionary(
1845                    Box::new(DataType::Int8),
1846                    Box::new(DataType::Decimal128(10, 0)),
1847                ),
1848                true,
1849            ),
1850            Field::new(
1851                "b",
1852                DataType::Dictionary(
1853                    Box::new(DataType::Int8),
1854                    Box::new(DataType::Decimal128(10, 0)),
1855                ),
1856                true,
1857            ),
1858        ]);
1859
1860        let value = 123;
1861        let decimal_array = Arc::new(create_decimal_array(
1862            &[
1863                Some(value),
1864                Some(value + 2),
1865                Some(value - 1),
1866                Some(value + 1),
1867            ],
1868            10,
1869            0,
1870        ));
1871
1872        let keys = Int8Array::from(vec![Some(0), Some(2), None, Some(3), Some(0)]);
1873        let a = DictionaryArray::try_new(keys, decimal_array)?;
1874
1875        let keys = Int8Array::from(vec![Some(0), None, Some(3), Some(2), Some(2)]);
1876        let decimal_array = Arc::new(create_decimal_array(
1877            &[
1878                Some(value + 1),
1879                Some(value + 3),
1880                Some(value),
1881                Some(value + 2),
1882            ],
1883            10,
1884            0,
1885        ));
1886        let b = DictionaryArray::try_new(keys, decimal_array)?;
1887
1888        apply_arithmetic(
1889            Arc::new(schema),
1890            vec![Arc::new(a), Arc::new(b)],
1891            Operator::Minus,
1892            create_decimal_array(&[Some(-1), None, None, Some(1), Some(0)], 11, 0),
1893        )?;
1894
1895        Ok(())
1896    }
1897
1898    #[test]
1899    fn minus_op_scalar() -> Result<()> {
1900        let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
1901        let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
1902
1903        apply_arithmetic_scalar(
1904            Arc::new(schema),
1905            vec![Arc::new(a)],
1906            Operator::Minus,
1907            ScalarValue::Int32(Some(1)),
1908            Arc::new(Int32Array::from(vec![0, 1, 2, 3, 4])),
1909        )?;
1910
1911        Ok(())
1912    }
1913
1914    #[test]
1915    fn minus_op_dict_scalar() -> Result<()> {
1916        let schema = Schema::new(vec![Field::new(
1917            "a",
1918            DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Int32)),
1919            true,
1920        )]);
1921
1922        let mut dict_builder = PrimitiveDictionaryBuilder::<Int8Type, Int32Type>::new();
1923
1924        dict_builder.append(1)?;
1925        dict_builder.append_null();
1926        dict_builder.append(2)?;
1927        dict_builder.append(5)?;
1928
1929        let a = dict_builder.finish();
1930
1931        let expected: PrimitiveArray<Int32Type> =
1932            PrimitiveArray::from(vec![Some(0), None, Some(1), Some(4)]);
1933
1934        apply_arithmetic_scalar(
1935            Arc::new(schema),
1936            vec![Arc::new(a)],
1937            Operator::Minus,
1938            ScalarValue::Dictionary(
1939                Box::new(DataType::Int8),
1940                Box::new(ScalarValue::Int32(Some(1))),
1941            ),
1942            Arc::new(expected),
1943        )?;
1944
1945        Ok(())
1946    }
1947
1948    #[test]
1949    fn minus_op_dict_scalar_decimal() -> Result<()> {
1950        let schema = Schema::new(vec![Field::new(
1951            "a",
1952            DataType::Dictionary(
1953                Box::new(DataType::Int8),
1954                Box::new(DataType::Decimal128(10, 0)),
1955            ),
1956            true,
1957        )]);
1958
1959        let value = 123;
1960        let decimal_array = Arc::new(create_decimal_array(
1961            &[Some(value), None, Some(value - 1), Some(value + 1)],
1962            10,
1963            0,
1964        ));
1965
1966        let keys = Int8Array::from(vec![0, 2, 1, 3, 0]);
1967        let a = DictionaryArray::try_new(keys, decimal_array)?;
1968
1969        let decimal_array = Arc::new(create_decimal_array(
1970            &[
1971                Some(value - 1),
1972                Some(value - 2),
1973                None,
1974                Some(value),
1975                Some(value - 1),
1976            ],
1977            11,
1978            0,
1979        ));
1980
1981        apply_arithmetic_scalar(
1982            Arc::new(schema),
1983            vec![Arc::new(a)],
1984            Operator::Minus,
1985            ScalarValue::Dictionary(
1986                Box::new(DataType::Int8),
1987                Box::new(ScalarValue::Decimal128(Some(1), 10, 0)),
1988            ),
1989            decimal_array,
1990        )?;
1991
1992        Ok(())
1993    }
1994
1995    #[test]
1996    fn multiply_op() -> Result<()> {
1997        let schema = Arc::new(Schema::new(vec![
1998            Field::new("a", DataType::Int32, false),
1999            Field::new("b", DataType::Int32, false),
2000        ]));
2001        let a = Arc::new(Int32Array::from(vec![4, 8, 16, 32, 64]));
2002        let b = Arc::new(Int32Array::from(vec![2, 4, 8, 16, 32]));
2003
2004        apply_arithmetic::<Int32Type>(
2005            schema,
2006            vec![a, b],
2007            Operator::Multiply,
2008            Int32Array::from(vec![8, 32, 128, 512, 2048]),
2009        )?;
2010
2011        Ok(())
2012    }
2013
2014    #[test]
2015    fn multiply_op_dict() -> Result<()> {
2016        let schema = Schema::new(vec![
2017            Field::new(
2018                "a",
2019                DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Int32)),
2020                true,
2021            ),
2022            Field::new(
2023                "b",
2024                DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Int32)),
2025                true,
2026            ),
2027        ]);
2028
2029        let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
2030        let keys = Int8Array::from(vec![Some(0), None, Some(1), Some(3), None]);
2031        let a = DictionaryArray::try_new(keys, Arc::new(a))?;
2032
2033        let b = Int32Array::from(vec![1, 2, 4, 8, 16]);
2034        let keys = Int8Array::from(vec![0, 1, 1, 2, 1]);
2035        let b = DictionaryArray::try_new(keys, Arc::new(b))?;
2036
2037        apply_arithmetic::<Int32Type>(
2038            Arc::new(schema),
2039            vec![Arc::new(a), Arc::new(b)],
2040            Operator::Multiply,
2041            Int32Array::from(vec![Some(1), None, Some(4), Some(16), None]),
2042        )?;
2043
2044        Ok(())
2045    }
2046
2047    #[test]
2048    fn multiply_op_dict_decimal() -> Result<()> {
2049        let schema = Schema::new(vec![
2050            Field::new(
2051                "a",
2052                DataType::Dictionary(
2053                    Box::new(DataType::Int8),
2054                    Box::new(DataType::Decimal128(10, 0)),
2055                ),
2056                true,
2057            ),
2058            Field::new(
2059                "b",
2060                DataType::Dictionary(
2061                    Box::new(DataType::Int8),
2062                    Box::new(DataType::Decimal128(10, 0)),
2063                ),
2064                true,
2065            ),
2066        ]);
2067
2068        let value = 123;
2069        let decimal_array = Arc::new(create_decimal_array(
2070            &[
2071                Some(value),
2072                Some(value + 2),
2073                Some(value - 1),
2074                Some(value + 1),
2075            ],
2076            10,
2077            0,
2078        )) as ArrayRef;
2079
2080        let keys = Int8Array::from(vec![Some(0), Some(2), None, Some(3), Some(0)]);
2081        let a = DictionaryArray::try_new(keys, decimal_array)?;
2082
2083        let keys = Int8Array::from(vec![Some(0), None, Some(3), Some(2), Some(2)]);
2084        let decimal_array = Arc::new(create_decimal_array(
2085            &[
2086                Some(value + 1),
2087                Some(value + 3),
2088                Some(value),
2089                Some(value + 2),
2090            ],
2091            10,
2092            0,
2093        ));
2094        let b = DictionaryArray::try_new(keys, decimal_array)?;
2095
2096        apply_arithmetic(
2097            Arc::new(schema),
2098            vec![Arc::new(a), Arc::new(b)],
2099            Operator::Multiply,
2100            create_decimal_array(
2101                &[Some(15252), None, None, Some(15252), Some(15129)],
2102                21,
2103                0,
2104            ),
2105        )?;
2106
2107        Ok(())
2108    }
2109
2110    #[test]
2111    fn multiply_op_scalar() -> Result<()> {
2112        let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
2113        let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
2114
2115        apply_arithmetic_scalar(
2116            Arc::new(schema),
2117            vec![Arc::new(a)],
2118            Operator::Multiply,
2119            ScalarValue::Int32(Some(2)),
2120            Arc::new(Int32Array::from(vec![2, 4, 6, 8, 10])),
2121        )?;
2122
2123        Ok(())
2124    }
2125
2126    #[test]
2127    fn multiply_op_dict_scalar() -> Result<()> {
2128        let schema = Schema::new(vec![Field::new(
2129            "a",
2130            DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Int32)),
2131            true,
2132        )]);
2133
2134        let mut dict_builder = PrimitiveDictionaryBuilder::<Int8Type, Int32Type>::new();
2135
2136        dict_builder.append(1)?;
2137        dict_builder.append_null();
2138        dict_builder.append(2)?;
2139        dict_builder.append(5)?;
2140
2141        let a = dict_builder.finish();
2142
2143        let expected: PrimitiveArray<Int32Type> =
2144            PrimitiveArray::from(vec![Some(2), None, Some(4), Some(10)]);
2145
2146        apply_arithmetic_scalar(
2147            Arc::new(schema),
2148            vec![Arc::new(a)],
2149            Operator::Multiply,
2150            ScalarValue::Dictionary(
2151                Box::new(DataType::Int8),
2152                Box::new(ScalarValue::Int32(Some(2))),
2153            ),
2154            Arc::new(expected),
2155        )?;
2156
2157        Ok(())
2158    }
2159
2160    #[test]
2161    fn multiply_op_dict_scalar_decimal() -> Result<()> {
2162        let schema = Schema::new(vec![Field::new(
2163            "a",
2164            DataType::Dictionary(
2165                Box::new(DataType::Int8),
2166                Box::new(DataType::Decimal128(10, 0)),
2167            ),
2168            true,
2169        )]);
2170
2171        let value = 123;
2172        let decimal_array = Arc::new(create_decimal_array(
2173            &[Some(value), None, Some(value - 1), Some(value + 1)],
2174            10,
2175            0,
2176        ));
2177
2178        let keys = Int8Array::from(vec![0, 2, 1, 3, 0]);
2179        let a = DictionaryArray::try_new(keys, decimal_array)?;
2180
2181        let decimal_array = Arc::new(create_decimal_array(
2182            &[Some(246), Some(244), None, Some(248), Some(246)],
2183            21,
2184            0,
2185        ));
2186
2187        apply_arithmetic_scalar(
2188            Arc::new(schema),
2189            vec![Arc::new(a)],
2190            Operator::Multiply,
2191            ScalarValue::Dictionary(
2192                Box::new(DataType::Int8),
2193                Box::new(ScalarValue::Decimal128(Some(2), 10, 0)),
2194            ),
2195            decimal_array,
2196        )?;
2197
2198        Ok(())
2199    }
2200
2201    #[test]
2202    fn divide_op() -> Result<()> {
2203        let schema = Arc::new(Schema::new(vec![
2204            Field::new("a", DataType::Int32, false),
2205            Field::new("b", DataType::Int32, false),
2206        ]));
2207        let a = Arc::new(Int32Array::from(vec![8, 32, 128, 512, 2048]));
2208        let b = Arc::new(Int32Array::from(vec![2, 4, 8, 16, 32]));
2209
2210        apply_arithmetic::<Int32Type>(
2211            schema,
2212            vec![a, b],
2213            Operator::Divide,
2214            Int32Array::from(vec![4, 8, 16, 32, 64]),
2215        )?;
2216
2217        Ok(())
2218    }
2219
2220    #[test]
2221    fn divide_op_dict() -> Result<()> {
2222        let schema = Schema::new(vec![
2223            Field::new(
2224                "a",
2225                DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Int32)),
2226                true,
2227            ),
2228            Field::new(
2229                "b",
2230                DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Int32)),
2231                true,
2232            ),
2233        ]);
2234
2235        let mut dict_builder = PrimitiveDictionaryBuilder::<Int8Type, Int32Type>::new();
2236
2237        dict_builder.append(1)?;
2238        dict_builder.append_null();
2239        dict_builder.append(2)?;
2240        dict_builder.append(5)?;
2241        dict_builder.append(0)?;
2242
2243        let a = dict_builder.finish();
2244
2245        let b = Int32Array::from(vec![1, 2, 4, 8, 16]);
2246        let keys = Int8Array::from(vec![0, 1, 1, 2, 1]);
2247        let b = DictionaryArray::try_new(keys, Arc::new(b))?;
2248
2249        apply_arithmetic::<Int32Type>(
2250            Arc::new(schema),
2251            vec![Arc::new(a), Arc::new(b)],
2252            Operator::Divide,
2253            Int32Array::from(vec![Some(1), None, Some(1), Some(1), Some(0)]),
2254        )?;
2255
2256        Ok(())
2257    }
2258
2259    #[test]
2260    fn divide_op_dict_decimal() -> Result<()> {
2261        let schema = Schema::new(vec![
2262            Field::new(
2263                "a",
2264                DataType::Dictionary(
2265                    Box::new(DataType::Int8),
2266                    Box::new(DataType::Decimal128(10, 0)),
2267                ),
2268                true,
2269            ),
2270            Field::new(
2271                "b",
2272                DataType::Dictionary(
2273                    Box::new(DataType::Int8),
2274                    Box::new(DataType::Decimal128(10, 0)),
2275                ),
2276                true,
2277            ),
2278        ]);
2279
2280        let value = 123;
2281        let decimal_array = Arc::new(create_decimal_array(
2282            &[
2283                Some(value),
2284                Some(value + 2),
2285                Some(value - 1),
2286                Some(value + 1),
2287            ],
2288            10,
2289            0,
2290        ));
2291
2292        let keys = Int8Array::from(vec![Some(0), Some(2), None, Some(3), Some(0)]);
2293        let a = DictionaryArray::try_new(keys, decimal_array)?;
2294
2295        let keys = Int8Array::from(vec![Some(0), None, Some(3), Some(2), Some(2)]);
2296        let decimal_array = Arc::new(create_decimal_array(
2297            &[
2298                Some(value + 1),
2299                Some(value + 3),
2300                Some(value),
2301                Some(value + 2),
2302            ],
2303            10,
2304            0,
2305        ));
2306        let b = DictionaryArray::try_new(keys, decimal_array)?;
2307
2308        apply_arithmetic(
2309            Arc::new(schema),
2310            vec![Arc::new(a), Arc::new(b)],
2311            Operator::Divide,
2312            create_decimal_array(
2313                &[
2314                    Some(9919), // 0.9919
2315                    None,
2316                    None,
2317                    Some(10081), // 1.0081
2318                    Some(10000), // 1.0
2319                ],
2320                14,
2321                4,
2322            ),
2323        )?;
2324
2325        Ok(())
2326    }
2327
2328    #[test]
2329    fn divide_op_scalar() -> Result<()> {
2330        let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
2331        let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
2332
2333        apply_arithmetic_scalar(
2334            Arc::new(schema),
2335            vec![Arc::new(a)],
2336            Operator::Divide,
2337            ScalarValue::Int32(Some(2)),
2338            Arc::new(Int32Array::from(vec![0, 1, 1, 2, 2])),
2339        )?;
2340
2341        Ok(())
2342    }
2343
2344    #[test]
2345    fn divide_op_dict_scalar() -> Result<()> {
2346        let schema = Schema::new(vec![Field::new(
2347            "a",
2348            DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Int32)),
2349            true,
2350        )]);
2351
2352        let mut dict_builder = PrimitiveDictionaryBuilder::<Int8Type, Int32Type>::new();
2353
2354        dict_builder.append(1)?;
2355        dict_builder.append_null();
2356        dict_builder.append(2)?;
2357        dict_builder.append(5)?;
2358
2359        let a = dict_builder.finish();
2360
2361        let expected: PrimitiveArray<Int32Type> =
2362            PrimitiveArray::from(vec![Some(0), None, Some(1), Some(2)]);
2363
2364        apply_arithmetic_scalar(
2365            Arc::new(schema),
2366            vec![Arc::new(a)],
2367            Operator::Divide,
2368            ScalarValue::Dictionary(
2369                Box::new(DataType::Int8),
2370                Box::new(ScalarValue::Int32(Some(2))),
2371            ),
2372            Arc::new(expected),
2373        )?;
2374
2375        Ok(())
2376    }
2377
2378    #[test]
2379    fn divide_op_dict_scalar_decimal() -> Result<()> {
2380        let schema = Schema::new(vec![Field::new(
2381            "a",
2382            DataType::Dictionary(
2383                Box::new(DataType::Int8),
2384                Box::new(DataType::Decimal128(10, 0)),
2385            ),
2386            true,
2387        )]);
2388
2389        let value = 123;
2390        let decimal_array = Arc::new(create_decimal_array(
2391            &[Some(value), None, Some(value - 1), Some(value + 1)],
2392            10,
2393            0,
2394        ));
2395
2396        let keys = Int8Array::from(vec![0, 2, 1, 3, 0]);
2397        let a = DictionaryArray::try_new(keys, decimal_array)?;
2398
2399        let decimal_array = Arc::new(create_decimal_array(
2400            &[Some(615000), Some(610000), None, Some(620000), Some(615000)],
2401            14,
2402            4,
2403        ));
2404
2405        apply_arithmetic_scalar(
2406            Arc::new(schema),
2407            vec![Arc::new(a)],
2408            Operator::Divide,
2409            ScalarValue::Dictionary(
2410                Box::new(DataType::Int8),
2411                Box::new(ScalarValue::Decimal128(Some(2), 10, 0)),
2412            ),
2413            decimal_array,
2414        )?;
2415
2416        Ok(())
2417    }
2418
2419    #[test]
2420    fn modulus_op() -> Result<()> {
2421        let schema = Arc::new(Schema::new(vec![
2422            Field::new("a", DataType::Int32, false),
2423            Field::new("b", DataType::Int32, false),
2424        ]));
2425        let a = Arc::new(Int32Array::from(vec![8, 32, 128, 512, 2048]));
2426        let b = Arc::new(Int32Array::from(vec![2, 4, 7, 14, 32]));
2427
2428        apply_arithmetic::<Int32Type>(
2429            schema,
2430            vec![a, b],
2431            Operator::Modulo,
2432            Int32Array::from(vec![0, 0, 2, 8, 0]),
2433        )?;
2434
2435        Ok(())
2436    }
2437
2438    #[test]
2439    fn modulus_op_dict() -> Result<()> {
2440        let schema = Schema::new(vec![
2441            Field::new(
2442                "a",
2443                DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Int32)),
2444                true,
2445            ),
2446            Field::new(
2447                "b",
2448                DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Int32)),
2449                true,
2450            ),
2451        ]);
2452
2453        let mut dict_builder = PrimitiveDictionaryBuilder::<Int8Type, Int32Type>::new();
2454
2455        dict_builder.append(1)?;
2456        dict_builder.append_null();
2457        dict_builder.append(2)?;
2458        dict_builder.append(5)?;
2459        dict_builder.append(0)?;
2460
2461        let a = dict_builder.finish();
2462
2463        let b = Int32Array::from(vec![1, 2, 4, 8, 16]);
2464        let keys = Int8Array::from(vec![0, 1, 1, 2, 1]);
2465        let b = DictionaryArray::try_new(keys, Arc::new(b))?;
2466
2467        apply_arithmetic::<Int32Type>(
2468            Arc::new(schema),
2469            vec![Arc::new(a), Arc::new(b)],
2470            Operator::Modulo,
2471            Int32Array::from(vec![Some(0), None, Some(0), Some(1), Some(0)]),
2472        )?;
2473
2474        Ok(())
2475    }
2476
2477    #[test]
2478    fn modulus_op_dict_decimal() -> Result<()> {
2479        let schema = Schema::new(vec![
2480            Field::new(
2481                "a",
2482                DataType::Dictionary(
2483                    Box::new(DataType::Int8),
2484                    Box::new(DataType::Decimal128(10, 0)),
2485                ),
2486                true,
2487            ),
2488            Field::new(
2489                "b",
2490                DataType::Dictionary(
2491                    Box::new(DataType::Int8),
2492                    Box::new(DataType::Decimal128(10, 0)),
2493                ),
2494                true,
2495            ),
2496        ]);
2497
2498        let value = 123;
2499        let decimal_array = Arc::new(create_decimal_array(
2500            &[
2501                Some(value),
2502                Some(value + 2),
2503                Some(value - 1),
2504                Some(value + 1),
2505            ],
2506            10,
2507            0,
2508        ));
2509
2510        let keys = Int8Array::from(vec![Some(0), Some(2), None, Some(3), Some(0)]);
2511        let a = DictionaryArray::try_new(keys, decimal_array)?;
2512
2513        let keys = Int8Array::from(vec![Some(0), None, Some(3), Some(2), Some(2)]);
2514        let decimal_array = Arc::new(create_decimal_array(
2515            &[
2516                Some(value + 1),
2517                Some(value + 3),
2518                Some(value),
2519                Some(value + 2),
2520            ],
2521            10,
2522            0,
2523        ));
2524        let b = DictionaryArray::try_new(keys, decimal_array)?;
2525
2526        apply_arithmetic(
2527            Arc::new(schema),
2528            vec![Arc::new(a), Arc::new(b)],
2529            Operator::Modulo,
2530            create_decimal_array(&[Some(123), None, None, Some(1), Some(0)], 10, 0),
2531        )?;
2532
2533        Ok(())
2534    }
2535
2536    #[test]
2537    fn modulus_op_scalar() -> Result<()> {
2538        let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
2539        let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
2540
2541        apply_arithmetic_scalar(
2542            Arc::new(schema),
2543            vec![Arc::new(a)],
2544            Operator::Modulo,
2545            ScalarValue::Int32(Some(2)),
2546            Arc::new(Int32Array::from(vec![1, 0, 1, 0, 1])),
2547        )?;
2548
2549        Ok(())
2550    }
2551
2552    #[test]
2553    fn modules_op_dict_scalar() -> Result<()> {
2554        let schema = Schema::new(vec![Field::new(
2555            "a",
2556            DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Int32)),
2557            true,
2558        )]);
2559
2560        let mut dict_builder = PrimitiveDictionaryBuilder::<Int8Type, Int32Type>::new();
2561
2562        dict_builder.append(1)?;
2563        dict_builder.append_null();
2564        dict_builder.append(2)?;
2565        dict_builder.append(5)?;
2566
2567        let a = dict_builder.finish();
2568
2569        let expected: PrimitiveArray<Int32Type> =
2570            PrimitiveArray::from(vec![Some(1), None, Some(0), Some(1)]);
2571
2572        apply_arithmetic_scalar(
2573            Arc::new(schema),
2574            vec![Arc::new(a)],
2575            Operator::Modulo,
2576            ScalarValue::Dictionary(
2577                Box::new(DataType::Int8),
2578                Box::new(ScalarValue::Int32(Some(2))),
2579            ),
2580            Arc::new(expected),
2581        )?;
2582
2583        Ok(())
2584    }
2585
/// Dictionary-encoded `Decimal128(10, 0)` column modulo a dictionary-typed
/// decimal literal.
#[test]
fn modulus_op_dict_scalar_decimal() -> Result<()> {
    let dict_type = DataType::Dictionary(
        Box::new(DataType::Int8),
        Box::new(DataType::Decimal128(10, 0)),
    );
    let schema = Arc::new(Schema::new(vec![Field::new("a", dict_type, true)]));

    // Dictionary values [123, NULL, 122, 124] addressed through keys
    // [0, 2, 1, 3, 0] give the rows [123, 122, NULL, 124, 123].
    let base = 123;
    let values = Arc::new(create_decimal_array(
        &[Some(base), None, Some(base - 1), Some(base + 1)],
        10,
        0,
    ));
    let keys = Int8Array::from(vec![0, 2, 1, 3, 0]);
    let input: ArrayRef = Arc::new(DictionaryArray::try_new(keys, values)?);

    let expected: ArrayRef = Arc::new(create_decimal_array(
        &[Some(1), Some(0), None, Some(0), Some(1)],
        10,
        0,
    ));
    let divisor = ScalarValue::Dictionary(
        Box::new(DataType::Int8),
        Box::new(ScalarValue::Decimal128(Some(2), 10, 0)),
    );

    apply_arithmetic_scalar(schema, vec![input], Operator::Modulo, divisor, expected)
}
2626
2627    fn apply_arithmetic<T: ArrowNumericType>(
2628        schema: SchemaRef,
2629        data: Vec<ArrayRef>,
2630        op: Operator,
2631        expected: PrimitiveArray<T>,
2632    ) -> Result<()> {
2633        let arithmetic_op =
2634            binary_op(col("a", &schema)?, op, col("b", &schema)?, &schema)?;
2635        let batch = RecordBatch::try_new(schema, data)?;
2636        let result = arithmetic_op
2637            .evaluate(&batch)?
2638            .into_array(batch.num_rows())
2639            .expect("Failed to convert to array");
2640
2641        assert_eq!(result.as_ref(), &expected);
2642        Ok(())
2643    }
2644
2645    fn apply_arithmetic_scalar(
2646        schema: SchemaRef,
2647        data: Vec<ArrayRef>,
2648        op: Operator,
2649        literal: ScalarValue,
2650        expected: ArrayRef,
2651    ) -> Result<()> {
2652        let lit = Arc::new(Literal::new(literal));
2653        let arithmetic_op = binary_op(col("a", &schema)?, op, lit, &schema)?;
2654        let batch = RecordBatch::try_new(schema, data)?;
2655        let result = arithmetic_op
2656            .evaluate(&batch)?
2657            .into_array(batch.num_rows())
2658            .expect("Failed to convert to array");
2659
2660        assert_eq!(&result, &expected);
2661        Ok(())
2662    }
2663
2664    fn apply_logic_op(
2665        schema: &SchemaRef,
2666        left: &ArrayRef,
2667        right: &ArrayRef,
2668        op: Operator,
2669        expected: BooleanArray,
2670    ) -> Result<()> {
2671        let op = binary_op(col("a", schema)?, op, col("b", schema)?, schema)?;
2672        let data: Vec<ArrayRef> = vec![Arc::clone(left), Arc::clone(right)];
2673        let batch = RecordBatch::try_new(Arc::clone(schema), data)?;
2674        let result = op
2675            .evaluate(&batch)?
2676            .into_array(batch.num_rows())
2677            .expect("Failed to convert to array");
2678
2679        assert_eq!(result.as_ref(), &expected);
2680        Ok(())
2681    }
2682
2683    // Test `scalar <op> arr` produces expected
2684    fn apply_logic_op_scalar_arr(
2685        schema: &SchemaRef,
2686        scalar: &ScalarValue,
2687        arr: &ArrayRef,
2688        op: Operator,
2689        expected: &BooleanArray,
2690    ) -> Result<()> {
2691        let scalar = lit(scalar.clone());
2692        let op = binary_op(scalar, op, col("a", schema)?, schema)?;
2693        let batch = RecordBatch::try_new(Arc::clone(schema), vec![Arc::clone(arr)])?;
2694        let result = op
2695            .evaluate(&batch)?
2696            .into_array(batch.num_rows())
2697            .expect("Failed to convert to array");
2698        assert_eq!(result.as_ref(), expected);
2699
2700        Ok(())
2701    }
2702
2703    // Test `arr <op> scalar` produces expected
2704    fn apply_logic_op_arr_scalar(
2705        schema: &SchemaRef,
2706        arr: &ArrayRef,
2707        scalar: &ScalarValue,
2708        op: Operator,
2709        expected: &BooleanArray,
2710    ) -> Result<()> {
2711        let scalar = lit(scalar.clone());
2712        let op = binary_op(col("a", schema)?, op, scalar, schema)?;
2713        let batch = RecordBatch::try_new(Arc::clone(schema), vec![Arc::clone(arr)])?;
2714        let result = op
2715            .evaluate(&batch)?
2716            .into_array(batch.num_rows())
2717            .expect("Failed to convert to array");
2718        assert_eq!(result.as_ref(), expected);
2719
2720        Ok(())
2721    }
2722
/// `AND` over every combination of true / false / NULL operands.
#[test]
fn and_with_nulls_op() -> Result<()> {
    let schema = Arc::new(Schema::new(vec![
        Field::new("a", DataType::Boolean, true),
        Field::new("b", DataType::Boolean, true),
    ]));

    let (t, f, n) = (Some(true), Some(false), None);
    let a: ArrayRef = Arc::new(BooleanArray::from(vec![t, f, n, t, f, n, t, f, n]));
    let b: ArrayRef = Arc::new(BooleanArray::from(vec![t, t, t, f, f, f, n, n, n]));

    // The expectation has a `false` operand win over NULL; NULL otherwise
    // stays NULL.
    let expected = BooleanArray::from(vec![t, f, n, f, f, f, n, f, n]);

    apply_logic_op(&schema, &a, &b, Operator::And, expected)
}
2767
/// Runs the four regex operators over `Utf8` and `LargeUtf8` inputs that
/// contain NULLs on either side; any NULL operand yields a NULL result.
#[test]
fn regex_with_nulls() -> Result<()> {
    let subjects = vec![Some("abc"), None, Some("abc"), None, Some("abc")];
    let patterns = vec![Some("^a"), Some("^A"), None, None, Some("^(b|c)")];

    let regex_expected =
        BooleanArray::from(vec![Some(true), None, None, None, Some(false)]);
    let regex_not_expected =
        BooleanArray::from(vec![Some(false), None, None, None, Some(true)]);

    // Applies every regex operator, in a fixed order, to one pair of columns
    // of the given string type.
    let check_all = |data_type: DataType, a: ArrayRef, b: ArrayRef| -> Result<()> {
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", data_type.clone(), true),
            Field::new("b", data_type, true),
        ]));
        let cases = [
            (Operator::RegexMatch, &regex_expected),
            (Operator::RegexIMatch, &regex_expected),
            (Operator::RegexNotMatch, &regex_not_expected),
            (Operator::RegexNotIMatch, &regex_not_expected),
        ];
        for (op, expected) in cases {
            apply_logic_op(&schema, &a, &b, op, expected.clone())?;
        }
        Ok(())
    };

    check_all(
        DataType::Utf8,
        Arc::new(StringArray::from(subjects.clone())) as ArrayRef,
        Arc::new(StringArray::from(patterns.clone())) as ArrayRef,
    )?;
    check_all(
        DataType::LargeUtf8,
        Arc::new(LargeStringArray::from(subjects)) as ArrayRef,
        Arc::new(LargeStringArray::from(patterns)) as ArrayRef,
    )
}
2872
/// `OR` over every combination of true / false / NULL operands.
#[test]
fn or_with_nulls_op() -> Result<()> {
    let schema = Arc::new(Schema::new(vec![
        Field::new("a", DataType::Boolean, true),
        Field::new("b", DataType::Boolean, true),
    ]));

    let (t, f, n) = (Some(true), Some(false), None);
    let a: ArrayRef = Arc::new(BooleanArray::from(vec![t, f, n, t, f, n, t, f, n]));
    let b: ArrayRef = Arc::new(BooleanArray::from(vec![t, t, t, f, f, f, n, n, n]));

    // The expectation has a `true` operand win over NULL; NULL otherwise
    // stays NULL.
    let expected = BooleanArray::from(vec![t, t, t, t, f, n, t, n, n]);

    apply_logic_op(&schema, &a, &b, Operator::Or, expected)
}
2917
2918    /// Returns (schema, a: BooleanArray, b: BooleanArray) with all possible inputs
2919    ///
2920    /// a: [true, true, true,  NULL, NULL, NULL,  false, false, false]
2921    /// b: [true, NULL, false, true, NULL, false, true,  NULL,  false]
2922    fn bool_test_arrays() -> (SchemaRef, ArrayRef, ArrayRef) {
2923        let schema = Schema::new(vec![
2924            Field::new("a", DataType::Boolean, true),
2925            Field::new("b", DataType::Boolean, true),
2926        ]);
2927        let a: BooleanArray = [
2928            Some(true),
2929            Some(true),
2930            Some(true),
2931            None,
2932            None,
2933            None,
2934            Some(false),
2935            Some(false),
2936            Some(false),
2937        ]
2938        .iter()
2939        .collect();
2940        let b: BooleanArray = [
2941            Some(true),
2942            None,
2943            Some(false),
2944            Some(true),
2945            None,
2946            Some(false),
2947            Some(true),
2948            None,
2949            Some(false),
2950        ]
2951        .iter()
2952        .collect();
2953        (Arc::new(schema), Arc::new(a), Arc::new(b))
2954    }
2955
2956    /// Returns (schema, BooleanArray) with [true, NULL, false]
2957    fn scalar_bool_test_array() -> (SchemaRef, ArrayRef) {
2958        let schema = Schema::new(vec![Field::new("a", DataType::Boolean, true)]);
2959        let a: BooleanArray = [Some(true), None, Some(false)].iter().collect();
2960        (Arc::new(schema), Arc::new(a))
2961    }
2962
/// `a = b` over the nine-row boolean fixture; a NULL on either side gives NULL.
#[test]
fn eq_op_bool() {
    let (schema, a, b) = bool_test_arrays();
    let (t, f, n) = (Some(true), Some(false), None);

    let expected = BooleanArray::from(vec![t, n, f, n, n, n, f, n, t]);
    apply_logic_op(&schema, &a, &b, Operator::Eq, expected).unwrap();
}
2981
/// `=` between a boolean literal and `[true, NULL, false]`, with the literal
/// on either side of the operator.
#[test]
fn eq_op_bool_scalar() {
    let (schema, arr) = scalar_bool_test_array();
    let (t, f, n) = (Some(true), Some(false), None);
    let lit_true = ScalarValue::from(true);
    let lit_false = ScalarValue::from(false);

    // true = arr, arr = true
    let expected = BooleanArray::from(vec![t, n, f]);
    apply_logic_op_scalar_arr(&schema, &lit_true, &arr, Operator::Eq, &expected)
        .unwrap();
    apply_logic_op_arr_scalar(&schema, &arr, &lit_true, Operator::Eq, &expected)
        .unwrap();

    // false = arr, arr = false
    let expected = BooleanArray::from(vec![f, n, t]);
    apply_logic_op_scalar_arr(&schema, &lit_false, &arr, Operator::Eq, &expected)
        .unwrap();
    apply_logic_op_arr_scalar(&schema, &arr, &lit_false, Operator::Eq, &expected)
        .unwrap();
}
3021
/// `a != b` over the nine-row boolean fixture; a NULL on either side gives NULL.
#[test]
fn neq_op_bool() {
    let (schema, a, b) = bool_test_arrays();
    let (t, f, n) = (Some(true), Some(false), None);

    let expected = BooleanArray::from(vec![f, n, t, n, n, n, t, n, f]);
    apply_logic_op(&schema, &a, &b, Operator::NotEq, expected).unwrap();
}
3040
/// `!=` between a boolean literal and `[true, NULL, false]`, with the literal
/// on either side of the operator.
#[test]
fn neq_op_bool_scalar() {
    let (schema, arr) = scalar_bool_test_array();
    let (t, f, n) = (Some(true), Some(false), None);
    let lit_true = ScalarValue::from(true);
    let lit_false = ScalarValue::from(false);

    // true != arr, arr != true
    let expected = BooleanArray::from(vec![f, n, t]);
    apply_logic_op_scalar_arr(&schema, &lit_true, &arr, Operator::NotEq, &expected)
        .unwrap();
    apply_logic_op_arr_scalar(&schema, &arr, &lit_true, Operator::NotEq, &expected)
        .unwrap();

    // false != arr, arr != false
    let expected = BooleanArray::from(vec![t, n, f]);
    apply_logic_op_scalar_arr(&schema, &lit_false, &arr, Operator::NotEq, &expected)
        .unwrap();
    apply_logic_op_arr_scalar(&schema, &arr, &lit_false, Operator::NotEq, &expected)
        .unwrap();
}
3080
/// `a < b` over the nine-row boolean fixture; only `false < true` is true.
#[test]
fn lt_op_bool() {
    let (schema, a, b) = bool_test_arrays();
    let (t, f, n) = (Some(true), Some(false), None);

    let expected = BooleanArray::from(vec![f, n, f, n, n, n, t, n, f]);
    apply_logic_op(&schema, &a, &b, Operator::Lt, expected).unwrap();
}
3099
/// `<` between a boolean literal and `[true, NULL, false]`, with the literal
/// on either side of the operator.
#[test]
fn lt_op_bool_scalar() {
    let (schema, arr) = scalar_bool_test_array();
    let (t, f, n) = (Some(true), Some(false), None);
    let lit_true = ScalarValue::from(true);
    let lit_false = ScalarValue::from(false);

    // true < arr
    let expected = BooleanArray::from(vec![f, n, f]);
    apply_logic_op_scalar_arr(&schema, &lit_true, &arr, Operator::Lt, &expected)
        .unwrap();

    // arr < true
    let expected = BooleanArray::from(vec![f, n, t]);
    apply_logic_op_arr_scalar(&schema, &arr, &lit_true, Operator::Lt, &expected)
        .unwrap();

    // false < arr
    let expected = BooleanArray::from(vec![t, n, f]);
    apply_logic_op_scalar_arr(&schema, &lit_false, &arr, Operator::Lt, &expected)
        .unwrap();

    // arr < false
    let expected = BooleanArray::from(vec![f, n, f]);
    apply_logic_op_arr_scalar(&schema, &arr, &lit_false, Operator::Lt, &expected)
        .unwrap();
}
3143
/// `a <= b` over the nine-row boolean fixture; a NULL on either side gives NULL.
#[test]
fn lt_eq_op_bool() {
    let (schema, a, b) = bool_test_arrays();
    let (t, f, n) = (Some(true), Some(false), None);

    let expected = BooleanArray::from(vec![t, n, f, n, n, n, t, n, t]);
    apply_logic_op(&schema, &a, &b, Operator::LtEq, expected).unwrap();
}
3162
/// `<=` between a boolean literal and `[true, NULL, false]`, with the literal
/// on either side of the operator.
#[test]
fn lt_eq_op_bool_scalar() {
    let (schema, arr) = scalar_bool_test_array();
    let (t, f, n) = (Some(true), Some(false), None);
    let lit_true = ScalarValue::from(true);
    let lit_false = ScalarValue::from(false);

    // true <= arr
    let expected = BooleanArray::from(vec![t, n, f]);
    apply_logic_op_scalar_arr(&schema, &lit_true, &arr, Operator::LtEq, &expected)
        .unwrap();

    // arr <= true
    let expected = BooleanArray::from(vec![t, n, t]);
    apply_logic_op_arr_scalar(&schema, &arr, &lit_true, Operator::LtEq, &expected)
        .unwrap();

    // false <= arr
    let expected = BooleanArray::from(vec![t, n, t]);
    apply_logic_op_scalar_arr(&schema, &lit_false, &arr, Operator::LtEq, &expected)
        .unwrap();

    // arr <= false
    let expected = BooleanArray::from(vec![f, n, t]);
    apply_logic_op_arr_scalar(&schema, &arr, &lit_false, Operator::LtEq, &expected)
        .unwrap();
}
3206
/// `a > b` over the nine-row boolean fixture; only `true > false` is true.
#[test]
fn gt_op_bool() {
    let (schema, a, b) = bool_test_arrays();
    let (t, f, n) = (Some(true), Some(false), None);

    let expected = BooleanArray::from(vec![f, n, t, n, n, n, f, n, f]);
    apply_logic_op(&schema, &a, &b, Operator::Gt, expected).unwrap();
}
3225
/// `>` between a boolean literal and `[true, NULL, false]`, with the literal
/// on either side of the operator.
#[test]
fn gt_op_bool_scalar() {
    let (schema, arr) = scalar_bool_test_array();
    let (t, f, n) = (Some(true), Some(false), None);
    let lit_true = ScalarValue::from(true);
    let lit_false = ScalarValue::from(false);

    // true > arr
    let expected = BooleanArray::from(vec![f, n, t]);
    apply_logic_op_scalar_arr(&schema, &lit_true, &arr, Operator::Gt, &expected)
        .unwrap();

    // arr > true
    let expected = BooleanArray::from(vec![f, n, f]);
    apply_logic_op_arr_scalar(&schema, &arr, &lit_true, Operator::Gt, &expected)
        .unwrap();

    // false > arr
    let expected = BooleanArray::from(vec![f, n, f]);
    apply_logic_op_scalar_arr(&schema, &lit_false, &arr, Operator::Gt, &expected)
        .unwrap();

    // arr > false
    let expected = BooleanArray::from(vec![t, n, f]);
    apply_logic_op_arr_scalar(&schema, &arr, &lit_false, Operator::Gt, &expected)
        .unwrap();
}
3269
/// `a >= b` over the nine-row boolean fixture; a NULL on either side gives NULL.
#[test]
fn gt_eq_op_bool() {
    let (schema, a, b) = bool_test_arrays();
    let (t, f, n) = (Some(true), Some(false), None);

    let expected = BooleanArray::from(vec![t, n, t, n, n, n, f, n, t]);
    apply_logic_op(&schema, &a, &b, Operator::GtEq, expected).unwrap();
}
3288
/// `>=` between a boolean literal and `[true, NULL, false]`, with the literal
/// on either side of the operator.
#[test]
fn gt_eq_op_bool_scalar() {
    let (schema, arr) = scalar_bool_test_array();
    let (t, f, n) = (Some(true), Some(false), None);
    let lit_true = ScalarValue::from(true);
    let lit_false = ScalarValue::from(false);

    // true >= arr
    let expected = BooleanArray::from(vec![t, n, t]);
    apply_logic_op_scalar_arr(&schema, &lit_true, &arr, Operator::GtEq, &expected)
        .unwrap();

    // arr >= true
    let expected = BooleanArray::from(vec![t, n, f]);
    apply_logic_op_arr_scalar(&schema, &arr, &lit_true, Operator::GtEq, &expected)
        .unwrap();

    // false >= arr
    let expected = BooleanArray::from(vec![f, n, t]);
    apply_logic_op_scalar_arr(&schema, &lit_false, &arr, Operator::GtEq, &expected)
        .unwrap();

    // arr >= false
    let expected = BooleanArray::from(vec![t, n, t]);
    apply_logic_op_arr_scalar(&schema, &arr, &lit_false, Operator::GtEq, &expected)
        .unwrap();
}
3332
/// `a IS DISTINCT FROM b` over the nine-row boolean fixture. The expected
/// output contains no NULLs: NULL vs NULL is `false`, NULL vs a value is `true`.
#[test]
fn is_distinct_from_op_bool() {
    let (schema, a, b) = bool_test_arrays();
    let (t, f) = (Some(true), Some(false));

    let expected = BooleanArray::from(vec![f, t, t, t, f, t, t, t, f]);
    apply_logic_op(&schema, &a, &b, Operator::IsDistinctFrom, expected).unwrap();
}
3351
/// `a IS NOT DISTINCT FROM b` over the nine-row boolean fixture. The expected
/// output contains no NULLs: NULL vs NULL is `true`, NULL vs a value is `false`.
#[test]
fn is_not_distinct_from_op_bool() {
    let (schema, a, b) = bool_test_arrays();
    let (t, f) = (Some(true), Some(false));

    let expected = BooleanArray::from(vec![t, f, f, f, t, f, f, f, t]);
    apply_logic_op(&schema, &a, &b, Operator::IsNotDistinctFrom, expected).unwrap();
}
3370
/// Evaluates a 100-term left-deep chain of additions over a single column.
#[test]
fn relatively_deeply_nested() {
    // Reproducer for https://github.com/apache/datafusion/issues/419
    // where even relatively shallow binary expressions overflowed
    // the stack in debug builds

    let input: Vec<Option<i32>> = (1..=5).map(Some).collect();
    let a: Int32Array = input.iter().collect();

    let batch = RecordBatch::try_from_iter(vec![("a", Arc::new(a) as _)]).unwrap();
    let schema = batch.schema();

    // build a left deep tree ((((a + a) + a) + a ....
    // `tree_depth` leaves in total, i.e. `tree_depth - 1` additions.
    let tree_depth: i32 = 100;
    let mut expr = col("a", schema.as_ref()).unwrap();
    for _ in 1..tree_depth {
        let leaf = col("a", schema.as_ref()).unwrap();
        expr = binary(expr, Operator::Plus, leaf, &schema).unwrap();
    }

    let evaluated = expr.evaluate(&batch).expect("evaluation");
    let result = evaluated
        .into_array(batch.num_rows())
        .expect("Failed to convert to array");

    // Summing `a` with itself `tree_depth` times multiplies every value.
    let expected: Int32Array = input
        .iter()
        .map(|value| value.map(|v| v * tree_depth))
        .collect();
    assert_eq!(result.as_ref(), &expected);
}
3403
3404    fn create_decimal_array(
3405        array: &[Option<i128>],
3406        precision: u8,
3407        scale: i8,
3408    ) -> Decimal128Array {
3409        let mut decimal_builder = Decimal128Builder::with_capacity(array.len());
3410        for value in array.iter().copied() {
3411            decimal_builder.append_option(value)
3412        }
3413        decimal_builder
3414            .finish()
3415            .with_precision_and_scale(precision, scale)
3416            .unwrap()
3417    }
3418
/// Compares a dictionary-encoded `Decimal128(25, 3)` column with a
/// dictionary-typed decimal literal using all six comparison operators.
#[test]
fn comparison_dict_decimal_scalar_expr_test() -> Result<()> {
    // scalar of decimal compare with dictionary decimal array
    let value_i128 = 123;
    let decimal_scalar = ScalarValue::Dictionary(
        Box::new(DataType::Int8),
        Box::new(ScalarValue::Decimal128(Some(value_i128), 25, 3)),
    );

    let dict_type = DataType::Dictionary(
        Box::new(DataType::Int8),
        Box::new(DataType::Decimal128(25, 3)),
    );
    let schema = Arc::new(Schema::new(vec![Field::new("a", dict_type, true)]));

    // Dictionary values [123, NULL, 122, 124] with keys [0, NULL, 2, 3]
    // give the rows [123, NULL, 122, 124].
    let values = Arc::new(create_decimal_array(
        &[
            Some(value_i128),
            None,
            Some(value_i128 - 1),
            Some(value_i128 + 1),
        ],
        25,
        3,
    ));
    let keys = Int8Array::from(vec![Some(0), None, Some(2), Some(3)]);
    let dictionary: ArrayRef = Arc::new(DictionaryArray::try_new(keys, values)?);

    // One entry per `array <op> scalar` check, in execution order.
    let cases = [
        (Operator::Eq, vec![Some(true), None, Some(false), Some(false)]),
        (Operator::NotEq, vec![Some(false), None, Some(true), Some(true)]),
        (Operator::Lt, vec![Some(false), None, Some(true), Some(false)]),
        (Operator::LtEq, vec![Some(true), None, Some(true), Some(false)]),
        (Operator::Gt, vec![Some(false), None, Some(false), Some(true)]),
        (Operator::GtEq, vec![Some(true), None, Some(false), Some(true)]),
    ];
    for (op, expected) in cases {
        apply_logic_op_arr_scalar(
            &schema,
            &dictionary,
            &decimal_scalar,
            op,
            &BooleanArray::from(expected),
        )
        .unwrap();
    }

    Ok(())
}
3509
3510    #[test]
3511    fn comparison_decimal_expr_test() -> Result<()> {
3512        // scalar of decimal compare with decimal array
3513        let value_i128 = 123;
3514        let decimal_scalar = ScalarValue::Decimal128(Some(value_i128), 25, 3);
3515        let schema = Arc::new(Schema::new(vec![Field::new(
3516            "a",
3517            DataType::Decimal128(25, 3),
3518            true,
3519        )]));
3520        let decimal_array = Arc::new(create_decimal_array(
3521            &[
3522                Some(value_i128),
3523                None,
3524                Some(value_i128 - 1),
3525                Some(value_i128 + 1),
3526            ],
3527            25,
3528            3,
3529        )) as ArrayRef;
3530        // array = scalar
3531        apply_logic_op_arr_scalar(
3532            &schema,
3533            &decimal_array,
3534            &decimal_scalar,
3535            Operator::Eq,
3536            &BooleanArray::from(vec![Some(true), None, Some(false), Some(false)]),
3537        )
3538        .unwrap();
3539        // array != scalar
3540        apply_logic_op_arr_scalar(
3541            &schema,
3542            &decimal_array,
3543            &decimal_scalar,
3544            Operator::NotEq,
3545            &BooleanArray::from(vec![Some(false), None, Some(true), Some(true)]),
3546        )
3547        .unwrap();
3548        //  array < scalar
3549        apply_logic_op_arr_scalar(
3550            &schema,
3551            &decimal_array,
3552            &decimal_scalar,
3553            Operator::Lt,
3554            &BooleanArray::from(vec![Some(false), None, Some(true), Some(false)]),
3555        )
3556        .unwrap();
3557
3558        //  array <= scalar
3559        apply_logic_op_arr_scalar(
3560            &schema,
3561            &decimal_array,
3562            &decimal_scalar,
3563            Operator::LtEq,
3564            &BooleanArray::from(vec![Some(true), None, Some(true), Some(false)]),
3565        )
3566        .unwrap();
3567        // array > scalar
3568        apply_logic_op_arr_scalar(
3569            &schema,
3570            &decimal_array,
3571            &decimal_scalar,
3572            Operator::Gt,
3573            &BooleanArray::from(vec![Some(false), None, Some(false), Some(true)]),
3574        )
3575        .unwrap();
3576
3577        // array >= scalar
3578        apply_logic_op_arr_scalar(
3579            &schema,
3580            &decimal_array,
3581            &decimal_scalar,
3582            Operator::GtEq,
3583            &BooleanArray::from(vec![Some(true), None, Some(false), Some(true)]),
3584        )
3585        .unwrap();
3586
3587        // scalar of different data type with decimal array
3588        let decimal_scalar = ScalarValue::Decimal128(Some(123_456), 10, 3);
3589        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, true)]));
3590        // scalar == array
3591        apply_logic_op_scalar_arr(
3592            &schema,
3593            &decimal_scalar,
3594            &(Arc::new(Int64Array::from(vec![Some(124), None])) as ArrayRef),
3595            Operator::Eq,
3596            &BooleanArray::from(vec![Some(false), None]),
3597        )
3598        .unwrap();
3599
3600        // array != scalar
3601        apply_logic_op_arr_scalar(
3602            &schema,
3603            &(Arc::new(Int64Array::from(vec![Some(123), None, Some(1)])) as ArrayRef),
3604            &decimal_scalar,
3605            Operator::NotEq,
3606            &BooleanArray::from(vec![Some(true), None, Some(true)]),
3607        )
3608        .unwrap();
3609
3610        // array < scalar
3611        apply_logic_op_arr_scalar(
3612            &schema,
3613            &(Arc::new(Int64Array::from(vec![Some(123), None, Some(124)])) as ArrayRef),
3614            &decimal_scalar,
3615            Operator::Lt,
3616            &BooleanArray::from(vec![Some(true), None, Some(false)]),
3617        )
3618        .unwrap();
3619
3620        // array > scalar
3621        apply_logic_op_arr_scalar(
3622            &schema,
3623            &(Arc::new(Int64Array::from(vec![Some(123), None, Some(124)])) as ArrayRef),
3624            &decimal_scalar,
3625            Operator::Gt,
3626            &BooleanArray::from(vec![Some(false), None, Some(true)]),
3627        )
3628        .unwrap();
3629
3630        let schema =
3631            Arc::new(Schema::new(vec![Field::new("a", DataType::Float64, true)]));
3632        // array == scalar
3633        apply_logic_op_arr_scalar(
3634            &schema,
3635            &(Arc::new(Float64Array::from(vec![Some(123.456), None, Some(123.457)]))
3636                as ArrayRef),
3637            &decimal_scalar,
3638            Operator::Eq,
3639            &BooleanArray::from(vec![Some(true), None, Some(false)]),
3640        )
3641        .unwrap();
3642
3643        // array <= scalar
3644        apply_logic_op_arr_scalar(
3645            &schema,
3646            &(Arc::new(Float64Array::from(vec![
3647                Some(123.456),
3648                None,
3649                Some(123.457),
3650                Some(123.45),
3651            ])) as ArrayRef),
3652            &decimal_scalar,
3653            Operator::LtEq,
3654            &BooleanArray::from(vec![Some(true), None, Some(false), Some(true)]),
3655        )
3656        .unwrap();
3657        // array >= scalar
3658        apply_logic_op_arr_scalar(
3659            &schema,
3660            &(Arc::new(Float64Array::from(vec![
3661                Some(123.456),
3662                None,
3663                Some(123.457),
3664                Some(123.45),
3665            ])) as ArrayRef),
3666            &decimal_scalar,
3667            Operator::GtEq,
3668            &BooleanArray::from(vec![Some(true), None, Some(true), Some(false)]),
3669        )
3670        .unwrap();
3671
3672        let value: i128 = 123;
3673        let decimal_array = Arc::new(create_decimal_array(
3674            &[Some(value), None, Some(value - 1), Some(value + 1)],
3675            10,
3676            0,
3677        )) as ArrayRef;
3678
3679        // comparison array op for decimal array
3680        let schema = Arc::new(Schema::new(vec![
3681            Field::new("a", DataType::Decimal128(10, 0), true),
3682            Field::new("b", DataType::Decimal128(10, 0), true),
3683        ]));
3684        let right_decimal_array = Arc::new(create_decimal_array(
3685            &[
3686                Some(value - 1),
3687                Some(value),
3688                Some(value + 1),
3689                Some(value + 1),
3690            ],
3691            10,
3692            0,
3693        )) as ArrayRef;
3694
3695        apply_logic_op(
3696            &schema,
3697            &decimal_array,
3698            &right_decimal_array,
3699            Operator::Eq,
3700            BooleanArray::from(vec![Some(false), None, Some(false), Some(true)]),
3701        )
3702        .unwrap();
3703
3704        apply_logic_op(
3705            &schema,
3706            &decimal_array,
3707            &right_decimal_array,
3708            Operator::NotEq,
3709            BooleanArray::from(vec![Some(true), None, Some(true), Some(false)]),
3710        )
3711        .unwrap();
3712
3713        apply_logic_op(
3714            &schema,
3715            &decimal_array,
3716            &right_decimal_array,
3717            Operator::Lt,
3718            BooleanArray::from(vec![Some(false), None, Some(true), Some(false)]),
3719        )
3720        .unwrap();
3721
3722        apply_logic_op(
3723            &schema,
3724            &decimal_array,
3725            &right_decimal_array,
3726            Operator::LtEq,
3727            BooleanArray::from(vec![Some(false), None, Some(true), Some(true)]),
3728        )
3729        .unwrap();
3730
3731        apply_logic_op(
3732            &schema,
3733            &decimal_array,
3734            &right_decimal_array,
3735            Operator::Gt,
3736            BooleanArray::from(vec![Some(true), None, Some(false), Some(false)]),
3737        )
3738        .unwrap();
3739
3740        apply_logic_op(
3741            &schema,
3742            &decimal_array,
3743            &right_decimal_array,
3744            Operator::GtEq,
3745            BooleanArray::from(vec![Some(true), None, Some(false), Some(true)]),
3746        )
3747        .unwrap();
3748
3749        // compare decimal array with other array type
3750        let value: i64 = 123;
3751        let schema = Arc::new(Schema::new(vec![
3752            Field::new("a", DataType::Int64, true),
3753            Field::new("b", DataType::Decimal128(10, 0), true),
3754        ]));
3755
3756        let int64_array = Arc::new(Int64Array::from(vec![
3757            Some(value),
3758            Some(value - 1),
3759            Some(value),
3760            Some(value + 1),
3761        ])) as ArrayRef;
3762
3763        // eq: int64array == decimal array
3764        apply_logic_op(
3765            &schema,
3766            &int64_array,
3767            &decimal_array,
3768            Operator::Eq,
3769            BooleanArray::from(vec![Some(true), None, Some(false), Some(true)]),
3770        )
3771        .unwrap();
3772        // neq: int64array != decimal array
3773        apply_logic_op(
3774            &schema,
3775            &int64_array,
3776            &decimal_array,
3777            Operator::NotEq,
3778            BooleanArray::from(vec![Some(false), None, Some(true), Some(false)]),
3779        )
3780        .unwrap();
3781
3782        let schema = Arc::new(Schema::new(vec![
3783            Field::new("a", DataType::Float64, true),
3784            Field::new("b", DataType::Decimal128(10, 2), true),
3785        ]));
3786
3787        let value: i128 = 123;
3788        let decimal_array = Arc::new(create_decimal_array(
3789            &[
3790                Some(value), // 1.23
3791                None,
3792                Some(value - 1), // 1.22
3793                Some(value + 1), // 1.24
3794            ],
3795            10,
3796            2,
3797        )) as ArrayRef;
3798        let float64_array = Arc::new(Float64Array::from(vec![
3799            Some(1.23),
3800            Some(1.22),
3801            Some(1.23),
3802            Some(1.24),
3803        ])) as ArrayRef;
3804        // lt: float64array < decimal array
3805        apply_logic_op(
3806            &schema,
3807            &float64_array,
3808            &decimal_array,
3809            Operator::Lt,
3810            BooleanArray::from(vec![Some(false), None, Some(false), Some(false)]),
3811        )
3812        .unwrap();
3813        // lt_eq: float64array <= decimal array
3814        apply_logic_op(
3815            &schema,
3816            &float64_array,
3817            &decimal_array,
3818            Operator::LtEq,
3819            BooleanArray::from(vec![Some(true), None, Some(false), Some(true)]),
3820        )
3821        .unwrap();
3822        // gt: float64array > decimal array
3823        apply_logic_op(
3824            &schema,
3825            &float64_array,
3826            &decimal_array,
3827            Operator::Gt,
3828            BooleanArray::from(vec![Some(false), None, Some(true), Some(false)]),
3829        )
3830        .unwrap();
3831        apply_logic_op(
3832            &schema,
3833            &float64_array,
3834            &decimal_array,
3835            Operator::GtEq,
3836            BooleanArray::from(vec![Some(true), None, Some(true), Some(true)]),
3837        )
3838        .unwrap();
3839        // is distinct: float64array is distinct decimal array
3840        // TODO: now we do not refactor the `is distinct or is not distinct` rule of coercion.
3841        // traced by https://github.com/apache/datafusion/issues/1590
3842        // the decimal array will be casted to float64array
3843        apply_logic_op(
3844            &schema,
3845            &float64_array,
3846            &decimal_array,
3847            Operator::IsDistinctFrom,
3848            BooleanArray::from(vec![Some(false), Some(true), Some(true), Some(false)]),
3849        )
3850        .unwrap();
3851        // is not distinct
3852        apply_logic_op(
3853            &schema,
3854            &float64_array,
3855            &decimal_array,
3856            Operator::IsNotDistinctFrom,
3857            BooleanArray::from(vec![Some(true), Some(false), Some(false), Some(true)]),
3858        )
3859        .unwrap();
3860
3861        Ok(())
3862    }
3863
3864    fn apply_decimal_arithmetic_op(
3865        schema: &SchemaRef,
3866        left: &ArrayRef,
3867        right: &ArrayRef,
3868        op: Operator,
3869        expected: ArrayRef,
3870    ) -> Result<()> {
3871        let arithmetic_op = binary_op(col("a", schema)?, op, col("b", schema)?, schema)?;
3872        let data: Vec<ArrayRef> = vec![Arc::clone(left), Arc::clone(right)];
3873        let batch = RecordBatch::try_new(Arc::clone(schema), data)?;
3874        let result = arithmetic_op
3875            .evaluate(&batch)?
3876            .into_array(batch.num_rows())
3877            .expect("Failed to convert to array");
3878
3879        assert_eq!(result.as_ref(), expected.as_ref());
3880        Ok(())
3881    }
3882
3883    #[test]
3884    fn arithmetic_decimal_expr_test() -> Result<()> {
3885        let schema = Arc::new(Schema::new(vec![
3886            Field::new("a", DataType::Int32, true),
3887            Field::new("b", DataType::Decimal128(10, 2), true),
3888        ]));
3889        let value: i128 = 123;
3890        let decimal_array = Arc::new(create_decimal_array(
3891            &[
3892                Some(value), // 1.23
3893                None,
3894                Some(value - 1), // 1.22
3895                Some(value + 1), // 1.24
3896            ],
3897            10,
3898            2,
3899        )) as ArrayRef;
3900        let int32_array = Arc::new(Int32Array::from(vec![
3901            Some(123),
3902            Some(122),
3903            Some(123),
3904            Some(124),
3905        ])) as ArrayRef;
3906
3907        // add: Int32array add decimal array
3908        let expect = Arc::new(create_decimal_array(
3909            &[Some(12423), None, Some(12422), Some(12524)],
3910            13,
3911            2,
3912        )) as ArrayRef;
3913        apply_decimal_arithmetic_op(
3914            &schema,
3915            &int32_array,
3916            &decimal_array,
3917            Operator::Plus,
3918            expect,
3919        )
3920        .unwrap();
3921
3922        // subtract: decimal array subtract int32 array
3923        let schema = Arc::new(Schema::new(vec![
3924            Field::new("a", DataType::Decimal128(10, 2), true),
3925            Field::new("b", DataType::Int32, true),
3926        ]));
3927        let expect = Arc::new(create_decimal_array(
3928            &[Some(-12177), None, Some(-12178), Some(-12276)],
3929            13,
3930            2,
3931        )) as ArrayRef;
3932        apply_decimal_arithmetic_op(
3933            &schema,
3934            &decimal_array,
3935            &int32_array,
3936            Operator::Minus,
3937            expect,
3938        )
3939        .unwrap();
3940
3941        // multiply: decimal array multiply int32 array
3942        let expect = Arc::new(create_decimal_array(
3943            &[Some(15129), None, Some(15006), Some(15376)],
3944            21,
3945            2,
3946        )) as ArrayRef;
3947        apply_decimal_arithmetic_op(
3948            &schema,
3949            &decimal_array,
3950            &int32_array,
3951            Operator::Multiply,
3952            expect,
3953        )
3954        .unwrap();
3955
3956        // divide: int32 array divide decimal array
3957        let schema = Arc::new(Schema::new(vec![
3958            Field::new("a", DataType::Int32, true),
3959            Field::new("b", DataType::Decimal128(10, 2), true),
3960        ]));
3961        let expect = Arc::new(create_decimal_array(
3962            &[Some(1000000), None, Some(1008196), Some(1000000)],
3963            16,
3964            4,
3965        )) as ArrayRef;
3966        apply_decimal_arithmetic_op(
3967            &schema,
3968            &int32_array,
3969            &decimal_array,
3970            Operator::Divide,
3971            expect,
3972        )
3973        .unwrap();
3974
3975        // modulus: int32 array modulus decimal array
3976        let schema = Arc::new(Schema::new(vec![
3977            Field::new("a", DataType::Int32, true),
3978            Field::new("b", DataType::Decimal128(10, 2), true),
3979        ]));
3980        let expect = Arc::new(create_decimal_array(
3981            &[Some(000), None, Some(100), Some(000)],
3982            10,
3983            2,
3984        )) as ArrayRef;
3985        apply_decimal_arithmetic_op(
3986            &schema,
3987            &int32_array,
3988            &decimal_array,
3989            Operator::Modulo,
3990            expect,
3991        )
3992        .unwrap();
3993
3994        Ok(())
3995    }
3996
3997    #[test]
3998    fn arithmetic_decimal_float_expr_test() -> Result<()> {
3999        let schema = Arc::new(Schema::new(vec![
4000            Field::new("a", DataType::Float64, true),
4001            Field::new("b", DataType::Decimal128(10, 2), true),
4002        ]));
4003        let value: i128 = 123;
4004        let decimal_array = Arc::new(create_decimal_array(
4005            &[
4006                Some(value), // 1.23
4007                None,
4008                Some(value - 1), // 1.22
4009                Some(value + 1), // 1.24
4010            ],
4011            10,
4012            2,
4013        )) as ArrayRef;
4014        let float64_array = Arc::new(Float64Array::from(vec![
4015            Some(123.0),
4016            Some(122.0),
4017            Some(123.0),
4018            Some(124.0),
4019        ])) as ArrayRef;
4020
4021        // add: float64 array add decimal array
4022        let expect = Arc::new(Float64Array::from(vec![
4023            Some(124.23),
4024            None,
4025            Some(124.22),
4026            Some(125.24),
4027        ])) as ArrayRef;
4028        apply_decimal_arithmetic_op(
4029            &schema,
4030            &float64_array,
4031            &decimal_array,
4032            Operator::Plus,
4033            expect,
4034        )
4035        .unwrap();
4036
4037        // subtract: decimal array subtract float64 array
4038        let schema = Arc::new(Schema::new(vec![
4039            Field::new("a", DataType::Float64, true),
4040            Field::new("b", DataType::Decimal128(10, 2), true),
4041        ]));
4042        let expect = Arc::new(Float64Array::from(vec![
4043            Some(121.77),
4044            None,
4045            Some(121.78),
4046            Some(122.76),
4047        ])) as ArrayRef;
4048        apply_decimal_arithmetic_op(
4049            &schema,
4050            &float64_array,
4051            &decimal_array,
4052            Operator::Minus,
4053            expect,
4054        )
4055        .unwrap();
4056
4057        // multiply: decimal array multiply float64 array
4058        let expect = Arc::new(Float64Array::from(vec![
4059            Some(151.29),
4060            None,
4061            Some(150.06),
4062            Some(153.76),
4063        ])) as ArrayRef;
4064        apply_decimal_arithmetic_op(
4065            &schema,
4066            &float64_array,
4067            &decimal_array,
4068            Operator::Multiply,
4069            expect,
4070        )
4071        .unwrap();
4072
4073        // divide: float64 array divide decimal array
4074        let schema = Arc::new(Schema::new(vec![
4075            Field::new("a", DataType::Float64, true),
4076            Field::new("b", DataType::Decimal128(10, 2), true),
4077        ]));
4078        let expect = Arc::new(Float64Array::from(vec![
4079            Some(100.0),
4080            None,
4081            Some(100.81967213114754),
4082            Some(100.0),
4083        ])) as ArrayRef;
4084        apply_decimal_arithmetic_op(
4085            &schema,
4086            &float64_array,
4087            &decimal_array,
4088            Operator::Divide,
4089            expect,
4090        )
4091        .unwrap();
4092
4093        // modulus: float64 array modulus decimal array
4094        let schema = Arc::new(Schema::new(vec![
4095            Field::new("a", DataType::Float64, true),
4096            Field::new("b", DataType::Decimal128(10, 2), true),
4097        ]));
4098        let expect = Arc::new(Float64Array::from(vec![
4099            Some(1.7763568394002505e-15),
4100            None,
4101            Some(1.0000000000000027),
4102            Some(8.881784197001252e-16),
4103        ])) as ArrayRef;
4104        apply_decimal_arithmetic_op(
4105            &schema,
4106            &float64_array,
4107            &decimal_array,
4108            Operator::Modulo,
4109            expect,
4110        )
4111        .unwrap();
4112
4113        Ok(())
4114    }
4115
4116    #[test]
4117    fn arithmetic_divide_zero() -> Result<()> {
4118        // other data type
4119        let schema = Arc::new(Schema::new(vec![
4120            Field::new("a", DataType::Int32, true),
4121            Field::new("b", DataType::Int32, true),
4122        ]));
4123        let a = Arc::new(Int32Array::from(vec![100]));
4124        let b = Arc::new(Int32Array::from(vec![0]));
4125
4126        let err = apply_arithmetic::<Int32Type>(
4127            schema,
4128            vec![a, b],
4129            Operator::Divide,
4130            Int32Array::from(vec![Some(4), Some(8), Some(16), Some(32), Some(64)]),
4131        )
4132        .unwrap_err();
4133
4134        let _expected = plan_datafusion_err!("Divide by zero");
4135
4136        assert!(matches!(err, ref _expected), "{err}");
4137
4138        // decimal
4139        let schema = Arc::new(Schema::new(vec![
4140            Field::new("a", DataType::Decimal128(25, 3), true),
4141            Field::new("b", DataType::Decimal128(25, 3), true),
4142        ]));
4143        let left_decimal_array = Arc::new(create_decimal_array(&[Some(1234567)], 25, 3));
4144        let right_decimal_array = Arc::new(create_decimal_array(&[Some(0)], 25, 3));
4145
4146        let err = apply_arithmetic::<Decimal128Type>(
4147            schema,
4148            vec![left_decimal_array, right_decimal_array],
4149            Operator::Divide,
4150            create_decimal_array(
4151                &[Some(12345670000000000000000000000000000), None],
4152                38,
4153                29,
4154            ),
4155        )
4156        .unwrap_err();
4157
4158        assert!(matches!(err, ref _expected), "{err}");
4159
4160        Ok(())
4161    }
4162
4163    #[test]
4164    fn bitwise_array_test() -> Result<()> {
4165        let left = Arc::new(Int32Array::from(vec![Some(12), None, Some(11)])) as ArrayRef;
4166        let right =
4167            Arc::new(Int32Array::from(vec![Some(1), Some(3), Some(7)])) as ArrayRef;
4168        let mut result = bitwise_and_dyn(Arc::clone(&left), Arc::clone(&right))?;
4169        let expected = Int32Array::from(vec![Some(0), None, Some(3)]);
4170        assert_eq!(result.as_ref(), &expected);
4171
4172        result = bitwise_or_dyn(Arc::clone(&left), Arc::clone(&right))?;
4173        let expected = Int32Array::from(vec![Some(13), None, Some(15)]);
4174        assert_eq!(result.as_ref(), &expected);
4175
4176        result = bitwise_xor_dyn(Arc::clone(&left), Arc::clone(&right))?;
4177        let expected = Int32Array::from(vec![Some(13), None, Some(12)]);
4178        assert_eq!(result.as_ref(), &expected);
4179
4180        let left =
4181            Arc::new(UInt32Array::from(vec![Some(12), None, Some(11)])) as ArrayRef;
4182        let right =
4183            Arc::new(UInt32Array::from(vec![Some(1), Some(3), Some(7)])) as ArrayRef;
4184        let mut result = bitwise_and_dyn(Arc::clone(&left), Arc::clone(&right))?;
4185        let expected = UInt32Array::from(vec![Some(0), None, Some(3)]);
4186        assert_eq!(result.as_ref(), &expected);
4187
4188        result = bitwise_or_dyn(Arc::clone(&left), Arc::clone(&right))?;
4189        let expected = UInt32Array::from(vec![Some(13), None, Some(15)]);
4190        assert_eq!(result.as_ref(), &expected);
4191
4192        result = bitwise_xor_dyn(Arc::clone(&left), Arc::clone(&right))?;
4193        let expected = UInt32Array::from(vec![Some(13), None, Some(12)]);
4194        assert_eq!(result.as_ref(), &expected);
4195
4196        Ok(())
4197    }
4198
4199    #[test]
4200    fn bitwise_shift_array_test() -> Result<()> {
4201        let input = Arc::new(Int32Array::from(vec![Some(2), None, Some(10)])) as ArrayRef;
4202        let modules =
4203            Arc::new(Int32Array::from(vec![Some(2), Some(4), Some(8)])) as ArrayRef;
4204        let mut result =
4205            bitwise_shift_left_dyn(Arc::clone(&input), Arc::clone(&modules))?;
4206
4207        let expected = Int32Array::from(vec![Some(8), None, Some(2560)]);
4208        assert_eq!(result.as_ref(), &expected);
4209
4210        result = bitwise_shift_right_dyn(Arc::clone(&result), Arc::clone(&modules))?;
4211        assert_eq!(result.as_ref(), &input);
4212
4213        let input =
4214            Arc::new(UInt32Array::from(vec![Some(2), None, Some(10)])) as ArrayRef;
4215        let modules =
4216            Arc::new(UInt32Array::from(vec![Some(2), Some(4), Some(8)])) as ArrayRef;
4217        let mut result =
4218            bitwise_shift_left_dyn(Arc::clone(&input), Arc::clone(&modules))?;
4219
4220        let expected = UInt32Array::from(vec![Some(8), None, Some(2560)]);
4221        assert_eq!(result.as_ref(), &expected);
4222
4223        result = bitwise_shift_right_dyn(Arc::clone(&result), Arc::clone(&modules))?;
4224        assert_eq!(result.as_ref(), &input);
4225        Ok(())
4226    }
4227
4228    #[test]
4229    fn bitwise_shift_array_overflow_test() -> Result<()> {
4230        let input = Arc::new(Int32Array::from(vec![Some(2)])) as ArrayRef;
4231        let modules = Arc::new(Int32Array::from(vec![Some(100)])) as ArrayRef;
4232        let result = bitwise_shift_left_dyn(Arc::clone(&input), Arc::clone(&modules))?;
4233
4234        let expected = Int32Array::from(vec![Some(32)]);
4235        assert_eq!(result.as_ref(), &expected);
4236
4237        let input = Arc::new(UInt32Array::from(vec![Some(2)])) as ArrayRef;
4238        let modules = Arc::new(UInt32Array::from(vec![Some(100)])) as ArrayRef;
4239        let result = bitwise_shift_left_dyn(Arc::clone(&input), Arc::clone(&modules))?;
4240
4241        let expected = UInt32Array::from(vec![Some(32)]);
4242        assert_eq!(result.as_ref(), &expected);
4243        Ok(())
4244    }
4245
4246    #[test]
4247    fn bitwise_scalar_test() -> Result<()> {
4248        let left = Arc::new(Int32Array::from(vec![Some(12), None, Some(11)])) as ArrayRef;
4249        let right = ScalarValue::from(3i32);
4250        let mut result = bitwise_and_dyn_scalar(&left, right.clone()).unwrap()?;
4251        let expected = Int32Array::from(vec![Some(0), None, Some(3)]);
4252        assert_eq!(result.as_ref(), &expected);
4253
4254        result = bitwise_or_dyn_scalar(&left, right.clone()).unwrap()?;
4255        let expected = Int32Array::from(vec![Some(15), None, Some(11)]);
4256        assert_eq!(result.as_ref(), &expected);
4257
4258        result = bitwise_xor_dyn_scalar(&left, right).unwrap()?;
4259        let expected = Int32Array::from(vec![Some(15), None, Some(8)]);
4260        assert_eq!(result.as_ref(), &expected);
4261
4262        let left =
4263            Arc::new(UInt32Array::from(vec![Some(12), None, Some(11)])) as ArrayRef;
4264        let right = ScalarValue::from(3u32);
4265        let mut result = bitwise_and_dyn_scalar(&left, right.clone()).unwrap()?;
4266        let expected = UInt32Array::from(vec![Some(0), None, Some(3)]);
4267        assert_eq!(result.as_ref(), &expected);
4268
4269        result = bitwise_or_dyn_scalar(&left, right.clone()).unwrap()?;
4270        let expected = UInt32Array::from(vec![Some(15), None, Some(11)]);
4271        assert_eq!(result.as_ref(), &expected);
4272
4273        result = bitwise_xor_dyn_scalar(&left, right).unwrap()?;
4274        let expected = UInt32Array::from(vec![Some(15), None, Some(8)]);
4275        assert_eq!(result.as_ref(), &expected);
4276        Ok(())
4277    }
4278
4279    #[test]
4280    fn bitwise_shift_scalar_test() -> Result<()> {
4281        let input = Arc::new(Int32Array::from(vec![Some(2), None, Some(4)])) as ArrayRef;
4282        let module = ScalarValue::from(10i32);
4283        let mut result =
4284            bitwise_shift_left_dyn_scalar(&input, module.clone()).unwrap()?;
4285
4286        let expected = Int32Array::from(vec![Some(2048), None, Some(4096)]);
4287        assert_eq!(result.as_ref(), &expected);
4288
4289        result = bitwise_shift_right_dyn_scalar(&result, module).unwrap()?;
4290        assert_eq!(result.as_ref(), &input);
4291
4292        let input = Arc::new(UInt32Array::from(vec![Some(2), None, Some(4)])) as ArrayRef;
4293        let module = ScalarValue::from(10u32);
4294        let mut result =
4295            bitwise_shift_left_dyn_scalar(&input, module.clone()).unwrap()?;
4296
4297        let expected = UInt32Array::from(vec![Some(2048), None, Some(4096)]);
4298        assert_eq!(result.as_ref(), &expected);
4299
4300        result = bitwise_shift_right_dyn_scalar(&result, module).unwrap()?;
4301        assert_eq!(result.as_ref(), &input);
4302        Ok(())
4303    }
4304
4305    #[test]
4306    fn test_display_and_or_combo() {
4307        let expr = BinaryExpr::new(
4308            Arc::new(BinaryExpr::new(
4309                lit(ScalarValue::from(1)),
4310                Operator::And,
4311                lit(ScalarValue::from(2)),
4312            )),
4313            Operator::And,
4314            Arc::new(BinaryExpr::new(
4315                lit(ScalarValue::from(3)),
4316                Operator::And,
4317                lit(ScalarValue::from(4)),
4318            )),
4319        );
4320        assert_eq!(expr.to_string(), "1 AND 2 AND 3 AND 4");
4321
4322        let expr = BinaryExpr::new(
4323            Arc::new(BinaryExpr::new(
4324                lit(ScalarValue::from(1)),
4325                Operator::Or,
4326                lit(ScalarValue::from(2)),
4327            )),
4328            Operator::Or,
4329            Arc::new(BinaryExpr::new(
4330                lit(ScalarValue::from(3)),
4331                Operator::Or,
4332                lit(ScalarValue::from(4)),
4333            )),
4334        );
4335        assert_eq!(expr.to_string(), "1 OR 2 OR 3 OR 4");
4336
4337        let expr = BinaryExpr::new(
4338            Arc::new(BinaryExpr::new(
4339                lit(ScalarValue::from(1)),
4340                Operator::And,
4341                lit(ScalarValue::from(2)),
4342            )),
4343            Operator::Or,
4344            Arc::new(BinaryExpr::new(
4345                lit(ScalarValue::from(3)),
4346                Operator::And,
4347                lit(ScalarValue::from(4)),
4348            )),
4349        );
4350        assert_eq!(expr.to_string(), "1 AND 2 OR 3 AND 4");
4351
4352        let expr = BinaryExpr::new(
4353            Arc::new(BinaryExpr::new(
4354                lit(ScalarValue::from(1)),
4355                Operator::Or,
4356                lit(ScalarValue::from(2)),
4357            )),
4358            Operator::And,
4359            Arc::new(BinaryExpr::new(
4360                lit(ScalarValue::from(3)),
4361                Operator::Or,
4362                lit(ScalarValue::from(4)),
4363            )),
4364        );
4365        assert_eq!(expr.to_string(), "(1 OR 2) AND (3 OR 4)");
4366    }
4367
4368    #[test]
4369    fn test_to_result_type_array() {
4370        let values = Arc::new(Int32Array::from(vec![1, 2, 3, 4]));
4371        let keys = Int8Array::from(vec![Some(0), None, Some(2), Some(3)]);
4372        let dictionary =
4373            Arc::new(DictionaryArray::try_new(keys, values).unwrap()) as ArrayRef;
4374
4375        // Casting Dictionary to Int32
4376        let casted = to_result_type_array(
4377            &Operator::Plus,
4378            Arc::clone(&dictionary),
4379            &DataType::Int32,
4380        )
4381        .unwrap();
4382        assert_eq!(
4383            &casted,
4384            &(Arc::new(Int32Array::from(vec![Some(1), None, Some(3), Some(4)]))
4385                as ArrayRef)
4386        );
4387
4388        // Array has same datatype as result type, no casting
4389        let casted = to_result_type_array(
4390            &Operator::Plus,
4391            Arc::clone(&dictionary),
4392            dictionary.data_type(),
4393        )
4394        .unwrap();
4395        assert_eq!(&casted, &dictionary);
4396
4397        // Not numerical operator, no casting
4398        let casted = to_result_type_array(
4399            &Operator::Eq,
4400            Arc::clone(&dictionary),
4401            &DataType::Int32,
4402        )
4403        .unwrap();
4404        assert_eq!(&casted, &dictionary);
4405    }
4406
4407    #[test]
4408    fn test_add_with_overflow() -> Result<()> {
4409        // create test data
4410        let l = Arc::new(Int32Array::from(vec![1, i32::MAX]));
4411        let r = Arc::new(Int32Array::from(vec![2, 1]));
4412        let schema = Arc::new(Schema::new(vec![
4413            Field::new("l", DataType::Int32, false),
4414            Field::new("r", DataType::Int32, false),
4415        ]));
4416        let batch = RecordBatch::try_new(schema, vec![l, r])?;
4417
4418        // create expression
4419        let expr = BinaryExpr::new(
4420            Arc::new(Column::new("l", 0)),
4421            Operator::Plus,
4422            Arc::new(Column::new("r", 1)),
4423        )
4424        .with_fail_on_overflow(true);
4425
4426        // evaluate expression
4427        let result = expr.evaluate(&batch);
4428        assert!(
4429            result
4430                .err()
4431                .unwrap()
4432                .to_string()
4433                .contains("Overflow happened on: 2147483647 + 1")
4434        );
4435        Ok(())
4436    }
4437
4438    #[test]
4439    fn test_subtract_with_overflow() -> Result<()> {
4440        // create test data
4441        let l = Arc::new(Int32Array::from(vec![1, i32::MIN]));
4442        let r = Arc::new(Int32Array::from(vec![2, 1]));
4443        let schema = Arc::new(Schema::new(vec![
4444            Field::new("l", DataType::Int32, false),
4445            Field::new("r", DataType::Int32, false),
4446        ]));
4447        let batch = RecordBatch::try_new(schema, vec![l, r])?;
4448
4449        // create expression
4450        let expr = BinaryExpr::new(
4451            Arc::new(Column::new("l", 0)),
4452            Operator::Minus,
4453            Arc::new(Column::new("r", 1)),
4454        )
4455        .with_fail_on_overflow(true);
4456
4457        // evaluate expression
4458        let result = expr.evaluate(&batch);
4459        assert!(
4460            result
4461                .err()
4462                .unwrap()
4463                .to_string()
4464                .contains("Overflow happened on: -2147483648 - 1")
4465        );
4466        Ok(())
4467    }
4468
4469    #[test]
4470    fn test_mul_with_overflow() -> Result<()> {
4471        // create test data
4472        let l = Arc::new(Int32Array::from(vec![1, i32::MAX]));
4473        let r = Arc::new(Int32Array::from(vec![2, 2]));
4474        let schema = Arc::new(Schema::new(vec![
4475            Field::new("l", DataType::Int32, false),
4476            Field::new("r", DataType::Int32, false),
4477        ]));
4478        let batch = RecordBatch::try_new(schema, vec![l, r])?;
4479
4480        // create expression
4481        let expr = BinaryExpr::new(
4482            Arc::new(Column::new("l", 0)),
4483            Operator::Multiply,
4484            Arc::new(Column::new("r", 1)),
4485        )
4486        .with_fail_on_overflow(true);
4487
4488        // evaluate expression
4489        let result = expr.evaluate(&batch);
4490        assert!(
4491            result
4492                .err()
4493                .unwrap()
4494                .to_string()
4495                .contains("Overflow happened on: 2147483647 * 2")
4496        );
4497        Ok(())
4498    }
4499
4500    /// Test helper for SIMILAR TO binary operation
4501    fn apply_similar_to(
4502        schema: &SchemaRef,
4503        va: Vec<&str>,
4504        vb: Vec<&str>,
4505        negated: bool,
4506        case_insensitive: bool,
4507        expected: &BooleanArray,
4508    ) -> Result<()> {
4509        let a = StringArray::from(va);
4510        let b = StringArray::from(vb);
4511        let op = similar_to(
4512            negated,
4513            case_insensitive,
4514            col("a", schema)?,
4515            col("b", schema)?,
4516        )?;
4517        let batch =
4518            RecordBatch::try_new(Arc::clone(schema), vec![Arc::new(a), Arc::new(b)])?;
4519        let result = op
4520            .evaluate(&batch)?
4521            .into_array(batch.num_rows())
4522            .expect("Failed to convert to array");
4523        assert_eq!(result.as_ref(), expected);
4524
4525        Ok(())
4526    }
4527
4528    #[test]
4529    fn test_similar_to() {
4530        let schema = Arc::new(Schema::new(vec![
4531            Field::new("a", DataType::Utf8, false),
4532            Field::new("b", DataType::Utf8, false),
4533        ]));
4534
4535        let expected = [Some(true), Some(false)].iter().collect();
4536        // case-sensitive
4537        apply_similar_to(
4538            &schema,
4539            vec!["hello world", "Hello World"],
4540            vec!["hello.*", "hello.*"],
4541            false,
4542            false,
4543            &expected,
4544        )
4545        .unwrap();
4546        // case-insensitive
4547        apply_similar_to(
4548            &schema,
4549            vec!["hello world", "bye"],
4550            vec!["hello.*", "hello.*"],
4551            false,
4552            true,
4553            &expected,
4554        )
4555        .unwrap();
4556    }
4557
4558    pub fn binary_expr(
4559        left: Arc<dyn PhysicalExpr>,
4560        op: Operator,
4561        right: Arc<dyn PhysicalExpr>,
4562        schema: &Schema,
4563    ) -> Result<BinaryExpr> {
4564        Ok(binary_op(left, op, right, schema)?
4565            .as_any()
4566            .downcast_ref::<BinaryExpr>()
4567            .unwrap()
4568            .clone())
4569    }
4570
4571    /// Test for Uniform-Uniform, Unknown-Uniform, Uniform-Unknown and Unknown-Unknown evaluation.
4572    #[test]
4573    fn test_evaluate_statistics_combination_of_range_holders() -> Result<()> {
4574        let schema = &Schema::new(vec![Field::new("a", DataType::Float64, false)]);
4575        let a = Arc::new(Column::new("a", 0)) as _;
4576        let b = lit(ScalarValue::from(12.0));
4577
4578        let left_interval = Interval::make(Some(0.0), Some(12.0))?;
4579        let right_interval = Interval::make(Some(12.0), Some(36.0))?;
4580        let (left_mean, right_mean) = (ScalarValue::from(6.0), ScalarValue::from(24.0));
4581        let (left_med, right_med) = (ScalarValue::from(6.0), ScalarValue::from(24.0));
4582
4583        for children in [
4584            vec![
4585                &Distribution::new_uniform(left_interval.clone())?,
4586                &Distribution::new_uniform(right_interval.clone())?,
4587            ],
4588            vec![
4589                &Distribution::new_generic(
4590                    left_mean.clone(),
4591                    left_med.clone(),
4592                    ScalarValue::Float64(None),
4593                    left_interval.clone(),
4594                )?,
4595                &Distribution::new_uniform(right_interval.clone())?,
4596            ],
4597            vec![
4598                &Distribution::new_uniform(right_interval.clone())?,
4599                &Distribution::new_generic(
4600                    right_mean.clone(),
4601                    right_med.clone(),
4602                    ScalarValue::Float64(None),
4603                    right_interval.clone(),
4604                )?,
4605            ],
4606            vec![
4607                &Distribution::new_generic(
4608                    left_mean.clone(),
4609                    left_med.clone(),
4610                    ScalarValue::Float64(None),
4611                    left_interval.clone(),
4612                )?,
4613                &Distribution::new_generic(
4614                    right_mean.clone(),
4615                    right_med.clone(),
4616                    ScalarValue::Float64(None),
4617                    right_interval.clone(),
4618                )?,
4619            ],
4620        ] {
4621            let ops = vec![
4622                Operator::Plus,
4623                Operator::Minus,
4624                Operator::Multiply,
4625                Operator::Divide,
4626            ];
4627
4628            for op in ops {
4629                let expr = binary_expr(Arc::clone(&a), op, Arc::clone(&b), schema)?;
4630                assert_eq!(
4631                    expr.evaluate_statistics(&children)?,
4632                    new_generic_from_binary_op(&op, children[0], children[1])?
4633                );
4634            }
4635        }
4636        Ok(())
4637    }
4638
4639    #[test]
4640    fn test_evaluate_statistics_bernoulli() -> Result<()> {
4641        let schema = &Schema::new(vec![
4642            Field::new("a", DataType::Int64, false),
4643            Field::new("b", DataType::Int64, false),
4644        ]);
4645        let a = Arc::new(Column::new("a", 0)) as _;
4646        let b = Arc::new(Column::new("b", 1)) as _;
4647        let eq = Arc::new(binary_expr(
4648            Arc::clone(&a),
4649            Operator::Eq,
4650            Arc::clone(&b),
4651            schema,
4652        )?);
4653        let neq = Arc::new(binary_expr(a, Operator::NotEq, b, schema)?);
4654
4655        let left_stat = &Distribution::new_uniform(Interval::make(Some(0), Some(7))?)?;
4656        let right_stat = &Distribution::new_uniform(Interval::make(Some(4), Some(11))?)?;
4657
4658        // Intervals: [0, 7], [4, 11].
4659        // The intersection is [4, 7], so the probability of equality is 4 / 64 = 1 / 16.
4660        assert_eq!(
4661            eq.evaluate_statistics(&[left_stat, right_stat])?,
4662            Distribution::new_bernoulli(ScalarValue::from(1.0 / 16.0))?
4663        );
4664
4665        // The probability of being distinct is 1 - 1 / 16 = 15 / 16.
4666        assert_eq!(
4667            neq.evaluate_statistics(&[left_stat, right_stat])?,
4668            Distribution::new_bernoulli(ScalarValue::from(15.0 / 16.0))?
4669        );
4670
4671        Ok(())
4672    }
4673
4674    #[test]
4675    fn test_propagate_statistics_combination_of_range_holders_arithmetic() -> Result<()> {
4676        let schema = &Schema::new(vec![Field::new("a", DataType::Float64, false)]);
4677        let a = Arc::new(Column::new("a", 0)) as _;
4678        let b = lit(ScalarValue::from(12.0));
4679
4680        let left_interval = Interval::make(Some(0.0), Some(12.0))?;
4681        let right_interval = Interval::make(Some(12.0), Some(36.0))?;
4682
4683        let parent = Distribution::new_uniform(Interval::make(Some(-432.), Some(432.))?)?;
4684        let children = vec![
4685            vec![
4686                Distribution::new_uniform(left_interval.clone())?,
4687                Distribution::new_uniform(right_interval.clone())?,
4688            ],
4689            vec![
4690                Distribution::new_generic(
4691                    ScalarValue::from(6.),
4692                    ScalarValue::from(6.),
4693                    ScalarValue::Float64(None),
4694                    left_interval.clone(),
4695                )?,
4696                Distribution::new_uniform(right_interval.clone())?,
4697            ],
4698            vec![
4699                Distribution::new_uniform(left_interval.clone())?,
4700                Distribution::new_generic(
4701                    ScalarValue::from(12.),
4702                    ScalarValue::from(12.),
4703                    ScalarValue::Float64(None),
4704                    right_interval.clone(),
4705                )?,
4706            ],
4707            vec![
4708                Distribution::new_generic(
4709                    ScalarValue::from(6.),
4710                    ScalarValue::from(6.),
4711                    ScalarValue::Float64(None),
4712                    left_interval.clone(),
4713                )?,
4714                Distribution::new_generic(
4715                    ScalarValue::from(12.),
4716                    ScalarValue::from(12.),
4717                    ScalarValue::Float64(None),
4718                    right_interval.clone(),
4719                )?,
4720            ],
4721        ];
4722
4723        let ops = vec![
4724            Operator::Plus,
4725            Operator::Minus,
4726            Operator::Multiply,
4727            Operator::Divide,
4728        ];
4729
4730        for child_view in children {
4731            let child_refs = child_view.iter().collect::<Vec<_>>();
4732            for op in &ops {
4733                let expr = binary_expr(Arc::clone(&a), *op, Arc::clone(&b), schema)?;
4734                assert_eq!(
4735                    expr.propagate_statistics(&parent, child_refs.as_slice())?,
4736                    Some(child_view.clone())
4737                );
4738            }
4739        }
4740        Ok(())
4741    }
4742
4743    #[test]
4744    fn test_propagate_statistics_combination_of_range_holders_comparison() -> Result<()> {
4745        let schema = &Schema::new(vec![Field::new("a", DataType::Float64, false)]);
4746        let a = Arc::new(Column::new("a", 0)) as _;
4747        let b = lit(ScalarValue::from(12.0));
4748
4749        let left_interval = Interval::make(Some(0.0), Some(12.0))?;
4750        let right_interval = Interval::make(Some(6.0), Some(18.0))?;
4751
4752        let one = ScalarValue::from(1.0);
4753        let parent = Distribution::new_bernoulli(one)?;
4754        let children = vec![
4755            vec![
4756                Distribution::new_uniform(left_interval.clone())?,
4757                Distribution::new_uniform(right_interval.clone())?,
4758            ],
4759            vec![
4760                Distribution::new_generic(
4761                    ScalarValue::from(6.),
4762                    ScalarValue::from(6.),
4763                    ScalarValue::Float64(None),
4764                    left_interval.clone(),
4765                )?,
4766                Distribution::new_uniform(right_interval.clone())?,
4767            ],
4768            vec![
4769                Distribution::new_uniform(left_interval.clone())?,
4770                Distribution::new_generic(
4771                    ScalarValue::from(12.),
4772                    ScalarValue::from(12.),
4773                    ScalarValue::Float64(None),
4774                    right_interval.clone(),
4775                )?,
4776            ],
4777            vec![
4778                Distribution::new_generic(
4779                    ScalarValue::from(6.),
4780                    ScalarValue::from(6.),
4781                    ScalarValue::Float64(None),
4782                    left_interval.clone(),
4783                )?,
4784                Distribution::new_generic(
4785                    ScalarValue::from(12.),
4786                    ScalarValue::from(12.),
4787                    ScalarValue::Float64(None),
4788                    right_interval.clone(),
4789                )?,
4790            ],
4791        ];
4792
4793        let ops = vec![
4794            Operator::Eq,
4795            Operator::Gt,
4796            Operator::GtEq,
4797            Operator::Lt,
4798            Operator::LtEq,
4799        ];
4800
4801        for child_view in children {
4802            let child_refs = child_view.iter().collect::<Vec<_>>();
4803            for op in &ops {
4804                let expr = binary_expr(Arc::clone(&a), *op, Arc::clone(&b), schema)?;
4805                assert!(
4806                    expr.propagate_statistics(&parent, child_refs.as_slice())?
4807                        .is_some()
4808                );
4809            }
4810        }
4811
4812        Ok(())
4813    }
4814
4815    #[test]
4816    fn test_fmt_sql() -> Result<()> {
4817        let schema = Schema::new(vec![
4818            Field::new("a", DataType::Int32, false),
4819            Field::new("b", DataType::Int32, false),
4820        ]);
4821
4822        // Test basic binary expressions
4823        let simple_expr = binary_expr(
4824            col("a", &schema)?,
4825            Operator::Plus,
4826            col("b", &schema)?,
4827            &schema,
4828        )?;
4829        let display_string = simple_expr.to_string();
4830        assert_eq!(display_string, "a@0 + b@1");
4831        let sql_string = fmt_sql(&simple_expr).to_string();
4832        assert_eq!(sql_string, "a + b");
4833
4834        // Test nested expressions with different operator precedence
4835        let nested_expr = binary_expr(
4836            Arc::new(binary_expr(
4837                col("a", &schema)?,
4838                Operator::Plus,
4839                col("b", &schema)?,
4840                &schema,
4841            )?),
4842            Operator::Multiply,
4843            col("b", &schema)?,
4844            &schema,
4845        )?;
4846        let display_string = nested_expr.to_string();
4847        assert_eq!(display_string, "(a@0 + b@1) * b@1");
4848        let sql_string = fmt_sql(&nested_expr).to_string();
4849        assert_eq!(sql_string, "(a + b) * b");
4850
4851        // Test nested expressions with same operator precedence
4852        let nested_same_prec = binary_expr(
4853            Arc::new(binary_expr(
4854                col("a", &schema)?,
4855                Operator::Plus,
4856                col("b", &schema)?,
4857                &schema,
4858            )?),
4859            Operator::Plus,
4860            col("b", &schema)?,
4861            &schema,
4862        )?;
4863        let display_string = nested_same_prec.to_string();
4864        assert_eq!(display_string, "a@0 + b@1 + b@1");
4865        let sql_string = fmt_sql(&nested_same_prec).to_string();
4866        assert_eq!(sql_string, "a + b + b");
4867
4868        // Test with literals
4869        let lit_expr = binary_expr(
4870            col("a", &schema)?,
4871            Operator::Eq,
4872            lit(ScalarValue::Int32(Some(42))),
4873            &schema,
4874        )?;
4875        let display_string = lit_expr.to_string();
4876        assert_eq!(display_string, "a@0 = 42");
4877        let sql_string = fmt_sql(&lit_expr).to_string();
4878        assert_eq!(sql_string, "a = 42");
4879
4880        Ok(())
4881    }
4882
4883    #[test]
4884    fn test_check_short_circuit() {
4885        // Test with non-nullable arrays
4886        let schema = Arc::new(Schema::new(vec![
4887            Field::new("a", DataType::Int32, false),
4888            Field::new("b", DataType::Int32, false),
4889        ]));
4890        let a_array = Int32Array::from(vec![1, 3, 4, 5, 6]);
4891        let b_array = Int32Array::from(vec![1, 2, 3, 4, 5]);
4892        let batch = RecordBatch::try_new(
4893            Arc::clone(&schema),
4894            vec![Arc::new(a_array), Arc::new(b_array)],
4895        )
4896        .unwrap();
4897
4898        // op: AND left: all false
4899        let left_expr = logical2physical(&logical_col("a").eq(expr_lit(2)), &schema);
4900        let left_value = left_expr.evaluate(&batch).unwrap();
4901        assert!(matches!(
4902            check_short_circuit(&left_value, &Operator::And),
4903            ShortCircuitStrategy::ReturnLeft
4904        ));
4905
4906        // op: AND left: not all false
4907        let left_expr = logical2physical(&logical_col("a").eq(expr_lit(3)), &schema);
4908        let left_value = left_expr.evaluate(&batch).unwrap();
4909        let ColumnarValue::Array(array) = &left_value else {
4910            panic!("Expected ColumnarValue::Array");
4911        };
4912        let ShortCircuitStrategy::PreSelection(value) =
4913            check_short_circuit(&left_value, &Operator::And)
4914        else {
4915            panic!("Expected ShortCircuitStrategy::PreSelection");
4916        };
4917        let expected_boolean_arr: Vec<_> =
4918            as_boolean_array(array).unwrap().iter().collect();
4919        let boolean_arr: Vec<_> = value.iter().collect();
4920        assert_eq!(expected_boolean_arr, boolean_arr);
4921
4922        // op: OR left: all true
4923        let left_expr = logical2physical(&logical_col("a").gt(expr_lit(0)), &schema);
4924        let left_value = left_expr.evaluate(&batch).unwrap();
4925        assert!(matches!(
4926            check_short_circuit(&left_value, &Operator::Or),
4927            ShortCircuitStrategy::ReturnLeft
4928        ));
4929
4930        // op: OR left: not all true
4931        let left_expr: Arc<dyn PhysicalExpr> =
4932            logical2physical(&logical_col("a").gt(expr_lit(2)), &schema);
4933        let left_value = left_expr.evaluate(&batch).unwrap();
4934        assert!(matches!(
4935            check_short_circuit(&left_value, &Operator::Or),
4936            ShortCircuitStrategy::None
4937        ));
4938
4939        // Test with nullable arrays and null values
4940        let schema_nullable = Arc::new(Schema::new(vec![
4941            Field::new("c", DataType::Boolean, true),
4942            Field::new("d", DataType::Boolean, true),
4943        ]));
4944
4945        // Create arrays with null values
4946        let c_array = Arc::new(BooleanArray::from(vec![
4947            Some(true),
4948            Some(false),
4949            None,
4950            Some(true),
4951            None,
4952        ])) as ArrayRef;
4953        let d_array = Arc::new(BooleanArray::from(vec![
4954            Some(false),
4955            Some(true),
4956            Some(false),
4957            None,
4958            Some(true),
4959        ])) as ArrayRef;
4960
4961        let batch_nullable = RecordBatch::try_new(
4962            Arc::clone(&schema_nullable),
4963            vec![Arc::clone(&c_array), Arc::clone(&d_array)],
4964        )
4965        .unwrap();
4966
4967        // Case: Mixed values with nulls - shouldn't short-circuit for AND
4968        let mixed_nulls = logical2physical(&logical_col("c"), &schema_nullable);
4969        let mixed_nulls_value = mixed_nulls.evaluate(&batch_nullable).unwrap();
4970        assert!(matches!(
4971            check_short_circuit(&mixed_nulls_value, &Operator::And),
4972            ShortCircuitStrategy::None
4973        ));
4974
4975        // Case: Mixed values with nulls - shouldn't short-circuit for OR
4976        assert!(matches!(
4977            check_short_circuit(&mixed_nulls_value, &Operator::Or),
4978            ShortCircuitStrategy::None
4979        ));
4980
4981        // Test with all nulls
4982        let all_nulls = Arc::new(BooleanArray::from(vec![None, None, None])) as ArrayRef;
4983        let null_batch = RecordBatch::try_new(
4984            Arc::new(Schema::new(vec![Field::new("e", DataType::Boolean, true)])),
4985            vec![all_nulls],
4986        )
4987        .unwrap();
4988
4989        let null_expr = logical2physical(&logical_col("e"), &null_batch.schema());
4990        let null_value = null_expr.evaluate(&null_batch).unwrap();
4991
4992        // All nulls shouldn't short-circuit for AND or OR
4993        assert!(matches!(
4994            check_short_circuit(&null_value, &Operator::And),
4995            ShortCircuitStrategy::None
4996        ));
4997        assert!(matches!(
4998            check_short_circuit(&null_value, &Operator::Or),
4999            ShortCircuitStrategy::None
5000        ));
5001
5002        // Test with scalar values
5003        // Scalar true
5004        let scalar_true = ColumnarValue::Scalar(ScalarValue::Boolean(Some(true)));
5005        assert!(matches!(
5006            check_short_circuit(&scalar_true, &Operator::Or),
5007            ShortCircuitStrategy::ReturnLeft
5008        )); // Should short-circuit OR
5009        assert!(matches!(
5010            check_short_circuit(&scalar_true, &Operator::And),
5011            ShortCircuitStrategy::ReturnRight
5012        )); // Should return the RHS for AND
5013
5014        // Scalar false
5015        let scalar_false = ColumnarValue::Scalar(ScalarValue::Boolean(Some(false)));
5016        assert!(matches!(
5017            check_short_circuit(&scalar_false, &Operator::And),
5018            ShortCircuitStrategy::ReturnLeft
5019        )); // Should short-circuit AND
5020        assert!(matches!(
5021            check_short_circuit(&scalar_false, &Operator::Or),
5022            ShortCircuitStrategy::ReturnRight
5023        )); // Should return the RHS for OR
5024
5025        // Scalar null
5026        let scalar_null = ColumnarValue::Scalar(ScalarValue::Boolean(None));
5027        assert!(matches!(
5028            check_short_circuit(&scalar_null, &Operator::And),
5029            ShortCircuitStrategy::None
5030        ));
5031        assert!(matches!(
5032            check_short_circuit(&scalar_null, &Operator::Or),
5033            ShortCircuitStrategy::None
5034        ));
5035    }
5036
5037    /// Test for [pre_selection_scatter]
5038    /// Since [check_short_circuit] ensures that the left side does not contain null and is neither all_true nor all_false, as well as not being empty,
5039    /// the following tests have been designed:
5040    /// 1. Test sparse left with interleaved true/false
5041    /// 2. Test multiple consecutive true blocks
5042    /// 3. Test multiple consecutive true blocks
5043    /// 4. Test single true at first position
5044    /// 5. Test single true at last position
5045    /// 6. Test nulls in right array
5046    #[test]
5047    fn test_pre_selection_scatter() {
5048        fn create_bool_array(bools: Vec<bool>) -> BooleanArray {
5049            BooleanArray::from(bools.into_iter().map(Some).collect::<Vec<_>>())
5050        }
5051        // Test sparse left with interleaved true/false
5052        {
5053            // Left: [T, F, T, F, T]
5054            // Right: [F, T, F] (values for 3 true positions)
5055            let left = create_bool_array(vec![true, false, true, false, true]);
5056            let right = create_bool_array(vec![false, true, false]);
5057
5058            let result = pre_selection_scatter(&left, Some(&right)).unwrap();
5059            let result_arr = result.into_array(left.len()).unwrap();
5060
5061            let expected = create_bool_array(vec![false, false, true, false, false]);
5062            assert_eq!(&expected, result_arr.as_boolean());
5063        }
5064        // Test multiple consecutive true blocks
5065        {
5066            // Left: [F, T, T, F, T, T, T]
5067            // Right: [T, F, F, T, F]
5068            let left =
5069                create_bool_array(vec![false, true, true, false, true, true, true]);
5070            let right = create_bool_array(vec![true, false, false, true, false]);
5071
5072            let result = pre_selection_scatter(&left, Some(&right)).unwrap();
5073            let result_arr = result.into_array(left.len()).unwrap();
5074
5075            let expected =
5076                create_bool_array(vec![false, true, false, false, false, true, false]);
5077            assert_eq!(&expected, result_arr.as_boolean());
5078        }
5079        // Test single true at first position
5080        {
5081            // Left: [T, F, F]
5082            // Right: [F]
5083            let left = create_bool_array(vec![true, false, false]);
5084            let right = create_bool_array(vec![false]);
5085
5086            let result = pre_selection_scatter(&left, Some(&right)).unwrap();
5087            let result_arr = result.into_array(left.len()).unwrap();
5088
5089            let expected = create_bool_array(vec![false, false, false]);
5090            assert_eq!(&expected, result_arr.as_boolean());
5091        }
5092        // Test single true at last position
5093        {
5094            // Left: [F, F, T]
5095            // Right: [F]
5096            let left = create_bool_array(vec![false, false, true]);
5097            let right = create_bool_array(vec![false]);
5098
5099            let result = pre_selection_scatter(&left, Some(&right)).unwrap();
5100            let result_arr = result.into_array(left.len()).unwrap();
5101
5102            let expected = create_bool_array(vec![false, false, false]);
5103            assert_eq!(&expected, result_arr.as_boolean());
5104        }
5105        // Test nulls in right array
5106        {
5107            // Left: [F, T, F, T]
5108            // Right: [None, Some(false)] (with null at first position)
5109            let left = create_bool_array(vec![false, true, false, true]);
5110            let right = BooleanArray::from(vec![None, Some(false)]);
5111
5112            let result = pre_selection_scatter(&left, Some(&right)).unwrap();
5113            let result_arr = result.into_array(left.len()).unwrap();
5114
5115            let expected = BooleanArray::from(vec![
5116                Some(false),
5117                None, // null from right
5118                Some(false),
5119                Some(false),
5120            ]);
5121            assert_eq!(&expected, result_arr.as_boolean());
5122        }
5123    }
5124
5125    #[test]
5126    fn test_and_true_preselection_returns_lhs() {
5127        let schema =
5128            Arc::new(Schema::new(vec![Field::new("c", DataType::Boolean, false)]));
5129        let c_array = Arc::new(BooleanArray::from(vec![false, true, false, false, false]))
5130            as ArrayRef;
5131        let batch = RecordBatch::try_new(Arc::clone(&schema), vec![Arc::clone(&c_array)])
5132            .unwrap();
5133
5134        let expr = logical2physical(&logical_col("c").and(expr_lit(true)), &schema);
5135
5136        let result = expr.evaluate(&batch).unwrap();
5137        let ColumnarValue::Array(result_arr) = result else {
5138            panic!("Expected ColumnarValue::Array");
5139        };
5140
5141        let expected: Vec<_> = c_array.as_boolean().iter().collect();
5142        let actual: Vec<_> = result_arr.as_boolean().iter().collect();
5143        assert_eq!(
5144            expected, actual,
5145            "AND with TRUE must equal LHS even with PreSelection"
5146        );
5147    }
5148
5149    #[test]
5150    fn test_evaluate_bounds_int32() {
5151        let schema = Schema::new(vec![
5152            Field::new("a", DataType::Int32, false),
5153            Field::new("b", DataType::Int32, false),
5154        ]);
5155
5156        let a = Arc::new(Column::new("a", 0)) as _;
5157        let b = Arc::new(Column::new("b", 1)) as _;
5158
5159        // Test addition bounds
5160        let add_expr =
5161            binary_expr(Arc::clone(&a), Operator::Plus, Arc::clone(&b), &schema).unwrap();
5162        let add_bounds = add_expr
5163            .evaluate_bounds(&[
5164                &Interval::make(Some(1), Some(10)).unwrap(),
5165                &Interval::make(Some(5), Some(15)).unwrap(),
5166            ])
5167            .unwrap();
5168        assert_eq!(add_bounds, Interval::make(Some(6), Some(25)).unwrap());
5169
5170        // Test subtraction bounds
5171        let sub_expr =
5172            binary_expr(Arc::clone(&a), Operator::Minus, Arc::clone(&b), &schema)
5173                .unwrap();
5174        let sub_bounds = sub_expr
5175            .evaluate_bounds(&[
5176                &Interval::make(Some(1), Some(10)).unwrap(),
5177                &Interval::make(Some(5), Some(15)).unwrap(),
5178            ])
5179            .unwrap();
5180        assert_eq!(sub_bounds, Interval::make(Some(-14), Some(5)).unwrap());
5181
5182        // Test multiplication bounds
5183        let mul_expr =
5184            binary_expr(Arc::clone(&a), Operator::Multiply, Arc::clone(&b), &schema)
5185                .unwrap();
5186        let mul_bounds = mul_expr
5187            .evaluate_bounds(&[
5188                &Interval::make(Some(1), Some(10)).unwrap(),
5189                &Interval::make(Some(5), Some(15)).unwrap(),
5190            ])
5191            .unwrap();
5192        assert_eq!(mul_bounds, Interval::make(Some(5), Some(150)).unwrap());
5193
5194        // Test division bounds
5195        let div_expr =
5196            binary_expr(Arc::clone(&a), Operator::Divide, Arc::clone(&b), &schema)
5197                .unwrap();
5198        let div_bounds = div_expr
5199            .evaluate_bounds(&[
5200                &Interval::make(Some(10), Some(20)).unwrap(),
5201                &Interval::make(Some(2), Some(5)).unwrap(),
5202            ])
5203            .unwrap();
5204        assert_eq!(div_bounds, Interval::make(Some(2), Some(10)).unwrap());
5205    }
5206
5207    #[test]
5208    fn test_evaluate_bounds_bool() {
5209        let schema = Schema::new(vec![
5210            Field::new("a", DataType::Boolean, false),
5211            Field::new("b", DataType::Boolean, false),
5212        ]);
5213
5214        let a = Arc::new(Column::new("a", 0)) as _;
5215        let b = Arc::new(Column::new("b", 1)) as _;
5216
5217        // Test OR bounds
5218        let or_expr =
5219            binary_expr(Arc::clone(&a), Operator::Or, Arc::clone(&b), &schema).unwrap();
5220        let or_bounds = or_expr
5221            .evaluate_bounds(&[
5222                &Interval::make(Some(true), Some(true)).unwrap(),
5223                &Interval::make(Some(false), Some(false)).unwrap(),
5224            ])
5225            .unwrap();
5226        assert_eq!(or_bounds, Interval::make(Some(true), Some(true)).unwrap());
5227
5228        // Test AND bounds
5229        let and_expr =
5230            binary_expr(Arc::clone(&a), Operator::And, Arc::clone(&b), &schema).unwrap();
5231        let and_bounds = and_expr
5232            .evaluate_bounds(&[
5233                &Interval::make(Some(true), Some(true)).unwrap(),
5234                &Interval::make(Some(false), Some(false)).unwrap(),
5235            ])
5236            .unwrap();
5237        assert_eq!(
5238            and_bounds,
5239            Interval::make(Some(false), Some(false)).unwrap()
5240        );
5241    }
5242
5243    #[test]
5244    fn test_evaluate_nested_type() {
5245        let batch_schema = Arc::new(Schema::new(vec![
5246            Field::new(
5247                "a",
5248                DataType::List(Arc::new(Field::new_list_field(DataType::Int32, true))),
5249                true,
5250            ),
5251            Field::new(
5252                "b",
5253                DataType::List(Arc::new(Field::new_list_field(DataType::Int32, true))),
5254                true,
5255            ),
5256        ]));
5257
5258        let mut list_builder_a = ListBuilder::new(Int32Builder::new());
5259
5260        list_builder_a.append_value([Some(1)]);
5261        list_builder_a.append_value([Some(2)]);
5262        list_builder_a.append_value([]);
5263        list_builder_a.append_value([None]);
5264
5265        let list_array_a: ArrayRef = Arc::new(list_builder_a.finish());
5266
5267        let mut list_builder_b = ListBuilder::new(Int32Builder::new());
5268
5269        list_builder_b.append_value([Some(1)]);
5270        list_builder_b.append_value([Some(2)]);
5271        list_builder_b.append_value([]);
5272        list_builder_b.append_value([None]);
5273
5274        let list_array_b: ArrayRef = Arc::new(list_builder_b.finish());
5275
5276        let batch =
5277            RecordBatch::try_new(batch_schema, vec![list_array_a, list_array_b]).unwrap();
5278
5279        let schema = Arc::new(Schema::new(vec![
5280            Field::new(
5281                "a",
5282                DataType::List(Arc::new(Field::new("foo", DataType::Int32, true))),
5283                true,
5284            ),
5285            Field::new(
5286                "b",
5287                DataType::List(Arc::new(Field::new("bar", DataType::Int32, true))),
5288                true,
5289            ),
5290        ]));
5291
5292        let a = Arc::new(Column::new("a", 0)) as _;
5293        let b = Arc::new(Column::new("b", 1)) as _;
5294
5295        let eq_expr =
5296            binary_expr(Arc::clone(&a), Operator::Eq, Arc::clone(&b), &schema).unwrap();
5297
5298        let eq_result = eq_expr.evaluate(&batch).unwrap();
5299        let expected =
5300            BooleanArray::from_iter(vec![Some(true), Some(true), Some(true), Some(true)]);
5301        assert_eq!(eq_result.into_array(4).unwrap().as_boolean(), &expected);
5302    }
5303}