datafusion_physical_expr/expressions/
binary.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18mod kernels;
19
20use crate::expressions::binary::kernels::concat_elements_utf8view;
21use crate::intervals::cp_solver::{propagate_arithmetic, propagate_comparison};
22use crate::PhysicalExpr;
23use std::hash::Hash;
24use std::{any::Any, sync::Arc};
25
26use arrow::array::*;
27use arrow::compute::kernels::boolean::{and_kleene, not, or_kleene};
28use arrow::compute::kernels::cmp::*;
29use arrow::compute::kernels::comparison::{regexp_is_match, regexp_is_match_scalar};
30use arrow::compute::kernels::concat_elements::concat_elements_utf8;
31use arrow::compute::{
32    cast, filter_record_batch, ilike, like, nilike, nlike, SlicesIterator,
33};
34use arrow::datatypes::*;
35use arrow::error::ArrowError;
36use datafusion_common::cast::as_boolean_array;
37use datafusion_common::{internal_err, not_impl_err, Result, ScalarValue};
38use datafusion_expr::binary::BinaryTypeCoercer;
39use datafusion_expr::interval_arithmetic::{apply_operator, Interval};
40use datafusion_expr::sort_properties::ExprProperties;
41use datafusion_expr::statistics::Distribution::{Bernoulli, Gaussian};
42use datafusion_expr::statistics::{
43    combine_bernoullis, combine_gaussians, create_bernoulli_from_comparison,
44    new_generic_from_binary_op, Distribution,
45};
46use datafusion_expr::{ColumnarValue, Operator};
47use datafusion_physical_expr_common::datum::{apply, apply_cmp, apply_cmp_for_nested};
48
49use kernels::{
50    bitwise_and_dyn, bitwise_and_dyn_scalar, bitwise_or_dyn, bitwise_or_dyn_scalar,
51    bitwise_shift_left_dyn, bitwise_shift_left_dyn_scalar, bitwise_shift_right_dyn,
52    bitwise_shift_right_dyn_scalar, bitwise_xor_dyn, bitwise_xor_dyn_scalar,
53};
54
/// Binary expression
///
/// Applies the operator `op` to the results of the `left` and `right`
/// child expressions.
#[derive(Debug, Clone, Eq)]
pub struct BinaryExpr {
    /// Left-hand operand.
    left: Arc<dyn PhysicalExpr>,
    /// Operator applied to the two operands.
    op: Operator,
    /// Right-hand operand.
    right: Arc<dyn PhysicalExpr>,
    /// Specifies whether an error is returned on overflow or not.
    /// During evaluation this selects the checked `add`/`sub`/`mul` kernels
    /// instead of their `*_wrapping` counterparts for `+`, `-` and `*`.
    fail_on_overflow: bool,
}
64
65// Manually derive PartialEq and Hash to work around https://github.com/rust-lang/rust/issues/78808
66impl PartialEq for BinaryExpr {
67    fn eq(&self, other: &Self) -> bool {
68        self.left.eq(&other.left)
69            && self.op.eq(&other.op)
70            && self.right.eq(&other.right)
71            && self.fail_on_overflow.eq(&other.fail_on_overflow)
72    }
73}
74impl Hash for BinaryExpr {
75    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
76        self.left.hash(state);
77        self.op.hash(state);
78        self.right.hash(state);
79        self.fail_on_overflow.hash(state);
80    }
81}
82
83impl BinaryExpr {
84    /// Create new binary expression
85    pub fn new(
86        left: Arc<dyn PhysicalExpr>,
87        op: Operator,
88        right: Arc<dyn PhysicalExpr>,
89    ) -> Self {
90        Self {
91            left,
92            op,
93            right,
94            fail_on_overflow: false,
95        }
96    }
97
98    /// Create new binary expression with explicit fail_on_overflow value
99    pub fn with_fail_on_overflow(self, fail_on_overflow: bool) -> Self {
100        Self {
101            left: self.left,
102            op: self.op,
103            right: self.right,
104            fail_on_overflow,
105        }
106    }
107
108    /// Get the left side of the binary expression
109    pub fn left(&self) -> &Arc<dyn PhysicalExpr> {
110        &self.left
111    }
112
113    /// Get the right side of the binary expression
114    pub fn right(&self) -> &Arc<dyn PhysicalExpr> {
115        &self.right
116    }
117
118    /// Get the operator for this binary expression
119    pub fn op(&self) -> &Operator {
120        &self.op
121    }
122}
123
124impl std::fmt::Display for BinaryExpr {
125    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
126        // Put parentheses around child binary expressions so that we can see the difference
127        // between `(a OR b) AND c` and `a OR (b AND c)`. We only insert parentheses when needed,
128        // based on operator precedence. For example, `(a AND b) OR c` and `a AND b OR c` are
129        // equivalent and the parentheses are not necessary.
130
131        fn write_child(
132            f: &mut std::fmt::Formatter,
133            expr: &dyn PhysicalExpr,
134            precedence: u8,
135        ) -> std::fmt::Result {
136            if let Some(child) = expr.as_any().downcast_ref::<BinaryExpr>() {
137                let p = child.op.precedence();
138                if p == 0 || p < precedence {
139                    write!(f, "({child})")?;
140                } else {
141                    write!(f, "{child}")?;
142                }
143            } else {
144                write!(f, "{expr}")?;
145            }
146
147            Ok(())
148        }
149
150        let precedence = self.op.precedence();
151        write_child(f, self.left.as_ref(), precedence)?;
152        write!(f, " {} ", self.op)?;
153        write_child(f, self.right.as_ref(), precedence)
154    }
155}
156
157/// Invoke a boolean kernel on a pair of arrays
158#[inline]
159fn boolean_op(
160    left: &dyn Array,
161    right: &dyn Array,
162    op: impl FnOnce(&BooleanArray, &BooleanArray) -> Result<BooleanArray, ArrowError>,
163) -> Result<Arc<(dyn Array + 'static)>, ArrowError> {
164    let ll = as_boolean_array(left).expect("boolean_op failed to downcast left array");
165    let rr = as_boolean_array(right).expect("boolean_op failed to downcast right array");
166    op(ll, rr).map(|t| Arc::new(t) as _)
167}
168
/// Dispatches a flag-taking string kernel `$OP` over a pair of arrays.
///
/// The concrete string array type is chosen from the data type of `$LEFT`
/// (`Utf8`, `Utf8View` or `LargeUtf8`); any other type yields an internal
/// error. `$NOT` and `$FLAG` are forwarded unchanged to the per-type macros.
macro_rules! binary_string_array_flag_op {
    ($LEFT:expr, $RIGHT:expr, $OP:ident, $NOT:expr, $FLAG:expr) => {{
        match $LEFT.data_type() {
            DataType::Utf8 => compute_utf8_flag_op!($LEFT, $RIGHT, $OP, StringArray, $NOT, $FLAG),
            DataType::Utf8View => compute_utf8view_flag_op!($LEFT, $RIGHT, $OP, StringViewArray, $NOT, $FLAG),
            DataType::LargeUtf8 => compute_utf8_flag_op!($LEFT, $RIGHT, $OP, LargeStringArray, $NOT, $FLAG),
            unsupported => internal_err!(
                "Data type {:?} not supported for binary_string_array_flag_op operation '{}' on string array",
                unsupported, stringify!($OP)
            ),
        }
    }};
}
188
/// Invoke a flag-taking compute kernel on a pair of string arrays.
///
/// * `$LEFT`, `$RIGHT` - arrays that are downcast to `$ARRAYTYPE`
///   (the expansion panics if a downcast fails).
/// * `$OP` - kernel called as `$OP(left, right, Option<&flags>)`.
/// * `$NOT` - when `true`, the boolean kernel output is negated.
/// * `$FLAG` - when `true`, a flag array holding `"i"` for every row of the
///   left input is passed to the kernel; otherwise no flags are passed.
///
/// Kernel errors, including a failure of the negation, are propagated with
/// `?`, so the macro must be expanded inside a function whose error type can
/// be built from `ArrowError`.
macro_rules! compute_utf8_flag_op {
    ($LEFT:expr, $RIGHT:expr, $OP:ident, $ARRAYTYPE:ident, $NOT:expr, $FLAG:expr) => {{
        let ll = $LEFT
            .as_any()
            .downcast_ref::<$ARRAYTYPE>()
            .expect("compute_utf8_flag_op failed to downcast array");
        let rr = $RIGHT
            .as_any()
            .downcast_ref::<$ARRAYTYPE>()
            .expect("compute_utf8_flag_op failed to downcast array");

        let flag = if $FLAG {
            Some($ARRAYTYPE::from(vec!["i"; ll.len()]))
        } else {
            None
        };
        let matched = $OP(ll, rr, flag.as_ref())?;
        // Propagate a negation failure instead of panicking on it.
        let array = if $NOT { not(&matched)? } else { matched };
        Ok(Arc::new(array))
    }};
}
213
/// Invoke a flag-taking compute kernel on a pair of string view arrays.
///
/// * `$LEFT`, `$RIGHT` - arrays that are downcast to `$ARRAYTYPE`
///   (the expansion panics if a downcast fails).
/// * `$OP` - kernel called as `$OP(left, right, Option<&flags>)`.
/// * `$NOT` - when `true`, the boolean kernel output is negated.
/// * `$FLAG` - when `true`, a flag array holding `"i"` for every row of the
///   left input is passed to the kernel; otherwise no flags are passed.
///
/// Kernel errors, including a failure of the negation, are propagated with
/// `?`, so the macro must be expanded inside a function whose error type can
/// be built from `ArrowError`.
macro_rules! compute_utf8view_flag_op {
    ($LEFT:expr, $RIGHT:expr, $OP:ident, $ARRAYTYPE:ident, $NOT:expr, $FLAG:expr) => {{
        let ll = $LEFT
            .as_any()
            .downcast_ref::<$ARRAYTYPE>()
            .expect("compute_utf8view_flag_op failed to downcast array");
        let rr = $RIGHT
            .as_any()
            .downcast_ref::<$ARRAYTYPE>()
            .expect("compute_utf8view_flag_op failed to downcast array");

        let flag = if $FLAG {
            Some($ARRAYTYPE::from(vec!["i"; ll.len()]))
        } else {
            None
        };
        let matched = $OP(ll, rr, flag.as_ref())?;
        // Propagate a negation failure instead of panicking on it.
        let array = if $NOT { not(&matched)? } else { matched };
        Ok(Arc::new(array))
    }};
}
238
/// Dispatches a flag-taking string kernel `$OP` over an array and a scalar.
///
/// Evaluates to `Some(Result<Arc<dyn Array>>)`. Besides the plain string
/// types, dictionary arrays with string values are handled: the kernel runs
/// once over the dictionary values and the per-value results are then
/// expanded through the keys.
macro_rules! binary_string_array_flag_op_scalar {
    ($LEFT:ident, $RIGHT:expr, $OP:ident, $NOT:expr, $FLAG:expr) => {{
        // Unlike binary_string_array_flag_op, this macro must accept dictionary operands:
        // when comparing against a scalar value the query can be optimized so that the operands are dicts.
        let outcome: Result<Arc<dyn Array>> = match $LEFT.data_type() {
            DataType::Utf8 => compute_utf8_flag_op_scalar!($LEFT, $RIGHT, $OP, StringArray, $NOT, $FLAG),
            DataType::Utf8View => compute_utf8view_flag_op_scalar!($LEFT, $RIGHT, $OP, StringViewArray, $NOT, $FLAG),
            DataType::LargeUtf8 => compute_utf8_flag_op_scalar!($LEFT, $RIGHT, $OP, LargeStringArray, $NOT, $FLAG),
            DataType::Dictionary(_, _) => {
                let dict_values = $LEFT.as_any_dictionary().values();

                // Run the kernel once per distinct dictionary value.
                let per_value = match dict_values.data_type() {
                    DataType::Utf8 => compute_utf8_flag_op_scalar!(dict_values, $RIGHT, $OP, StringArray, $NOT, $FLAG),
                    DataType::Utf8View => compute_utf8view_flag_op_scalar!(dict_values, $RIGHT, $OP, StringViewArray, $NOT, $FLAG),
                    DataType::LargeUtf8 => compute_utf8_flag_op_scalar!(dict_values, $RIGHT, $OP, LargeStringArray, $NOT, $FLAG),
                    unsupported => internal_err!(
                        "Data type {:?} not supported as a dictionary value type for binary_string_array_flag_op_scalar operation '{}' on string array",
                        unsupported, stringify!($OP)
                    ),
                };

                // downcast_dictionary_array duplicates code per possible key type, so all prep work is done before it
                per_value.map(|value_matches| downcast_dictionary_array! {
                    $LEFT => {
                        let unpacked_dict = value_matches.take_iter($LEFT.keys().iter().map(|key| key.map(|k| k as _))).collect::<BooleanArray>();
                        Arc::new(unpacked_dict) as _
                    },
                    _ => unreachable!(),
                })
            }
            unsupported => internal_err!(
                "Data type {:?} not supported for binary_string_array_flag_op_scalar operation '{}' on string array",
                unsupported, stringify!($OP)
            ),
        };
        Some(outcome)
    }};
}
283
/// Invoke a flag-taking compute kernel on a string array and a scalar value.
///
/// * `$LEFT` - array that is downcast to `$ARRAYTYPE` (the expansion panics
///   if the downcast fails).
/// * `$RIGHT` - scalar that must yield a non-null string from `try_as_str()`;
///   otherwise the *enclosing function* returns an internal error.
/// * `$OP` - kernel name; the `<$OP>_scalar` variant is the one called.
/// * `$NOT` - when `true`, the boolean kernel output is negated.
/// * `$FLAG` - when `true`, the flag `"i"` is passed to the kernel.
///
/// Kernel errors, including a failure of the negation, are propagated with
/// `?`, so the macro must be expanded inside a function whose error type can
/// be built from `ArrowError`.
macro_rules! compute_utf8_flag_op_scalar {
    ($LEFT:expr, $RIGHT:expr, $OP:ident, $ARRAYTYPE:ident, $NOT:expr, $FLAG:expr) => {{
        let ll = $LEFT
            .as_any()
            .downcast_ref::<$ARRAYTYPE>()
            .expect("compute_utf8_flag_op_scalar failed to downcast array");

        let string_value = match $RIGHT.try_as_str() {
            Some(Some(string_value)) => string_value,
            // null literal or non string
            _ => return internal_err!(
                        "compute_utf8_flag_op_scalar failed to cast literal value {} for operation '{}'",
                        $RIGHT, stringify!($OP)
                    )
        };

        let flag = $FLAG.then_some("i");
        let matched =
            paste::expr! {[<$OP _scalar>]}(ll, &string_value, flag)?;
        // Propagate a negation failure instead of panicking on it.
        let array = if $NOT { not(&matched)? } else { matched };

        Ok(Arc::new(array))
    }};
}
311
/// Invoke a flag-taking compute kernel on a string view array and a scalar value.
///
/// * `$LEFT` - array that is downcast to `$ARRAYTYPE` (the expansion panics
///   if the downcast fails).
/// * `$RIGHT` - scalar that must yield a non-null string from `try_as_str()`;
///   otherwise the *enclosing function* returns an internal error.
/// * `$OP` - kernel name; the `<$OP>_scalar` variant is the one called.
/// * `$NOT` - when `true`, the boolean kernel output is negated.
/// * `$FLAG` - when `true`, the flag `"i"` is passed to the kernel.
///
/// Kernel errors, including a failure of the negation, are propagated with
/// `?`, so the macro must be expanded inside a function whose error type can
/// be built from `ArrowError`.
macro_rules! compute_utf8view_flag_op_scalar {
    ($LEFT:expr, $RIGHT:expr, $OP:ident, $ARRAYTYPE:ident, $NOT:expr, $FLAG:expr) => {{
        let ll = $LEFT
            .as_any()
            .downcast_ref::<$ARRAYTYPE>()
            .expect("compute_utf8view_flag_op_scalar failed to downcast array");

        let string_value = match $RIGHT.try_as_str() {
            Some(Some(string_value)) => string_value,
            // null literal or non string
            _ => return internal_err!(
                        "compute_utf8view_flag_op_scalar failed to cast literal value {} for operation '{}'",
                        $RIGHT, stringify!($OP)
                    )
        };

        let flag = $FLAG.then_some("i");
        let matched =
            paste::expr! {[<$OP _scalar>]}(ll, &string_value, flag)?;
        // Propagate a negation failure instead of panicking on it.
        let array = if $NOT { not(&matched)? } else { matched };

        Ok(Arc::new(array))
    }};
}
339
340impl PhysicalExpr for BinaryExpr {
    /// Return a reference to Any that can be used for downcasting
    /// (for example to recover the concrete `BinaryExpr` from a
    /// `dyn PhysicalExpr`).
    fn as_any(&self) -> &dyn Any {
        self
    }
345
346    fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
347        BinaryTypeCoercer::new(
348            &self.left.data_type(input_schema)?,
349            &self.op,
350            &self.right.data_type(input_schema)?,
351        )
352        .get_result_type()
353    }
354
355    fn nullable(&self, input_schema: &Schema) -> Result<bool> {
356        Ok(self.left.nullable(input_schema)? || self.right.nullable(input_schema)?)
357    }
358
359    fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
360        use arrow::compute::kernels::numeric::*;
361
362        // Evaluate left-hand side expression.
363        let lhs = self.left.evaluate(batch)?;
364
365        // Check if we can apply short-circuit evaluation.
366        match check_short_circuit(&lhs, &self.op) {
367            ShortCircuitStrategy::None => {}
368            ShortCircuitStrategy::ReturnLeft => return Ok(lhs),
369            ShortCircuitStrategy::ReturnRight => {
370                let rhs = self.right.evaluate(batch)?;
371                return Ok(rhs);
372            }
373            ShortCircuitStrategy::PreSelection(selection) => {
374                // The function `evaluate_selection` was not called for filtering and calculation,
375                // as it takes into account cases where the selection contains null values.
376                let batch = filter_record_batch(batch, selection)?;
377                let right_ret = self.right.evaluate(&batch)?;
378                return pre_selection_scatter(selection, right_ret);
379            }
380        }
381
382        let rhs = self.right.evaluate(batch)?;
383        let left_data_type = lhs.data_type();
384        let right_data_type = rhs.data_type();
385
386        let schema = batch.schema();
387        let input_schema = schema.as_ref();
388
389        if left_data_type.is_nested() {
390            if !left_data_type.equals_datatype(&right_data_type) {
391                return internal_err!("Cannot evaluate binary expression because of type mismatch: left {}, right {} ", left_data_type, right_data_type);
392            }
393            return apply_cmp_for_nested(self.op, &lhs, &rhs);
394        }
395
396        match self.op {
397            Operator::Plus if self.fail_on_overflow => return apply(&lhs, &rhs, add),
398            Operator::Plus => return apply(&lhs, &rhs, add_wrapping),
399            Operator::Minus if self.fail_on_overflow => return apply(&lhs, &rhs, sub),
400            Operator::Minus => return apply(&lhs, &rhs, sub_wrapping),
401            Operator::Multiply if self.fail_on_overflow => return apply(&lhs, &rhs, mul),
402            Operator::Multiply => return apply(&lhs, &rhs, mul_wrapping),
403            Operator::Divide => return apply(&lhs, &rhs, div),
404            Operator::Modulo => return apply(&lhs, &rhs, rem),
405            Operator::Eq => return apply_cmp(&lhs, &rhs, eq),
406            Operator::NotEq => return apply_cmp(&lhs, &rhs, neq),
407            Operator::Lt => return apply_cmp(&lhs, &rhs, lt),
408            Operator::Gt => return apply_cmp(&lhs, &rhs, gt),
409            Operator::LtEq => return apply_cmp(&lhs, &rhs, lt_eq),
410            Operator::GtEq => return apply_cmp(&lhs, &rhs, gt_eq),
411            Operator::IsDistinctFrom => return apply_cmp(&lhs, &rhs, distinct),
412            Operator::IsNotDistinctFrom => return apply_cmp(&lhs, &rhs, not_distinct),
413            Operator::LikeMatch => return apply_cmp(&lhs, &rhs, like),
414            Operator::ILikeMatch => return apply_cmp(&lhs, &rhs, ilike),
415            Operator::NotLikeMatch => return apply_cmp(&lhs, &rhs, nlike),
416            Operator::NotILikeMatch => return apply_cmp(&lhs, &rhs, nilike),
417            _ => {}
418        }
419
420        let result_type = self.data_type(input_schema)?;
421
422        // If the left-hand side is an array and the right-hand side is a non-null scalar, try the optimized kernel.
423        if let (ColumnarValue::Array(array), ColumnarValue::Scalar(ref scalar)) =
424            (&lhs, &rhs)
425        {
426            if !scalar.is_null() {
427                if let Some(result_array) =
428                    self.evaluate_array_scalar(array, scalar.clone())?
429                {
430                    let final_array = result_array
431                        .and_then(|a| to_result_type_array(&self.op, a, &result_type));
432                    return final_array.map(ColumnarValue::Array);
433                }
434            }
435        }
436
437        // if both arrays or both literals - extract arrays and continue execution
438        let (left, right) = (
439            lhs.into_array(batch.num_rows())?,
440            rhs.into_array(batch.num_rows())?,
441        );
442        self.evaluate_with_resolved_args(left, &left_data_type, right, &right_data_type)
443            .map(ColumnarValue::Array)
444    }
445
446    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
447        vec![&self.left, &self.right]
448    }
449
450    fn with_new_children(
451        self: Arc<Self>,
452        children: Vec<Arc<dyn PhysicalExpr>>,
453    ) -> Result<Arc<dyn PhysicalExpr>> {
454        Ok(Arc::new(
455            BinaryExpr::new(Arc::clone(&children[0]), self.op, Arc::clone(&children[1]))
456                .with_fail_on_overflow(self.fail_on_overflow),
457        ))
458    }
459
460    fn evaluate_bounds(&self, children: &[&Interval]) -> Result<Interval> {
461        // Get children intervals:
462        let left_interval = children[0];
463        let right_interval = children[1];
464        // Calculate current node's interval:
465        apply_operator(&self.op, left_interval, right_interval)
466    }
467
468    fn propagate_constraints(
469        &self,
470        interval: &Interval,
471        children: &[&Interval],
472    ) -> Result<Option<Vec<Interval>>> {
473        // Get children intervals.
474        let left_interval = children[0];
475        let right_interval = children[1];
476
477        if self.op.eq(&Operator::And) {
478            if interval.eq(&Interval::CERTAINLY_TRUE) {
479                // A certainly true logical conjunction can only derive from possibly
480                // true operands. Otherwise, we prove infeasibility.
481                Ok((!left_interval.eq(&Interval::CERTAINLY_FALSE)
482                    && !right_interval.eq(&Interval::CERTAINLY_FALSE))
483                .then(|| vec![Interval::CERTAINLY_TRUE, Interval::CERTAINLY_TRUE]))
484            } else if interval.eq(&Interval::CERTAINLY_FALSE) {
485                // If the logical conjunction is certainly false, one of the
486                // operands must be false. However, it's not always possible to
487                // determine which operand is false, leading to different scenarios.
488
489                // If one operand is certainly true and the other one is uncertain,
490                // then the latter must be certainly false.
491                if left_interval.eq(&Interval::CERTAINLY_TRUE)
492                    && right_interval.eq(&Interval::UNCERTAIN)
493                {
494                    Ok(Some(vec![
495                        Interval::CERTAINLY_TRUE,
496                        Interval::CERTAINLY_FALSE,
497                    ]))
498                } else if right_interval.eq(&Interval::CERTAINLY_TRUE)
499                    && left_interval.eq(&Interval::UNCERTAIN)
500                {
501                    Ok(Some(vec![
502                        Interval::CERTAINLY_FALSE,
503                        Interval::CERTAINLY_TRUE,
504                    ]))
505                }
506                // If both children are uncertain, or if one is certainly false,
507                // we cannot conclusively refine their intervals. In this case,
508                // propagation does not result in any interval changes.
509                else {
510                    Ok(Some(vec![]))
511                }
512            } else {
513                // An uncertain logical conjunction result can not shrink the
514                // end-points of its children.
515                Ok(Some(vec![]))
516            }
517        } else if self.op.eq(&Operator::Or) {
518            if interval.eq(&Interval::CERTAINLY_FALSE) {
519                // A certainly false logical disjunction can only derive from certainly
520                // false operands. Otherwise, we prove infeasibility.
521                Ok((!left_interval.eq(&Interval::CERTAINLY_TRUE)
522                    && !right_interval.eq(&Interval::CERTAINLY_TRUE))
523                .then(|| vec![Interval::CERTAINLY_FALSE, Interval::CERTAINLY_FALSE]))
524            } else if interval.eq(&Interval::CERTAINLY_TRUE) {
525                // If the logical disjunction is certainly true, one of the
526                // operands must be true. However, it's not always possible to
527                // determine which operand is true, leading to different scenarios.
528
529                // If one operand is certainly false and the other one is uncertain,
530                // then the latter must be certainly true.
531                if left_interval.eq(&Interval::CERTAINLY_FALSE)
532                    && right_interval.eq(&Interval::UNCERTAIN)
533                {
534                    Ok(Some(vec![
535                        Interval::CERTAINLY_FALSE,
536                        Interval::CERTAINLY_TRUE,
537                    ]))
538                } else if right_interval.eq(&Interval::CERTAINLY_FALSE)
539                    && left_interval.eq(&Interval::UNCERTAIN)
540                {
541                    Ok(Some(vec![
542                        Interval::CERTAINLY_TRUE,
543                        Interval::CERTAINLY_FALSE,
544                    ]))
545                }
546                // If both children are uncertain, or if one is certainly true,
547                // we cannot conclusively refine their intervals. In this case,
548                // propagation does not result in any interval changes.
549                else {
550                    Ok(Some(vec![]))
551                }
552            } else {
553                // An uncertain logical disjunction result can not shrink the
554                // end-points of its children.
555                Ok(Some(vec![]))
556            }
557        } else if self.op.supports_propagation() {
558            Ok(
559                propagate_comparison(&self.op, interval, left_interval, right_interval)?
560                    .map(|(left, right)| vec![left, right]),
561            )
562        } else {
563            Ok(
564                propagate_arithmetic(&self.op, interval, left_interval, right_interval)?
565                    .map(|(left, right)| vec![left, right]),
566            )
567        }
568    }
569
570    fn evaluate_statistics(&self, children: &[&Distribution]) -> Result<Distribution> {
571        let (left, right) = (children[0], children[1]);
572
573        if self.op.is_numerical_operators() {
574            // We might be able to construct the output statistics more accurately,
575            // without falling back to an unknown distribution, if we are dealing
576            // with Gaussian distributions and numerical operations.
577            if let (Gaussian(left), Gaussian(right)) = (left, right) {
578                if let Some(result) = combine_gaussians(&self.op, left, right)? {
579                    return Ok(Gaussian(result));
580                }
581            }
582        } else if self.op.is_logic_operator() {
583            // If we are dealing with logical operators, we expect (and can only
584            // operate on) Bernoulli distributions.
585            return if let (Bernoulli(left), Bernoulli(right)) = (left, right) {
586                combine_bernoullis(&self.op, left, right).map(Bernoulli)
587            } else {
588                internal_err!(
589                    "Logical operators are only compatible with `Bernoulli` distributions"
590                )
591            };
592        } else if self.op.supports_propagation() {
593            // If we are handling comparison operators, we expect (and can only
594            // operate on) numeric distributions.
595            return create_bernoulli_from_comparison(&self.op, left, right);
596        }
597        // Fall back to an unknown distribution with only summary statistics:
598        new_generic_from_binary_op(&self.op, left, right)
599    }
600
601    /// For each operator, [`BinaryExpr`] has distinct rules.
602    /// TODO: There may be rules specific to some data types and expression ranges.
603    fn get_properties(&self, children: &[ExprProperties]) -> Result<ExprProperties> {
604        let (l_order, l_range) = (children[0].sort_properties, &children[0].range);
605        let (r_order, r_range) = (children[1].sort_properties, &children[1].range);
606        match self.op() {
607            Operator::Plus => Ok(ExprProperties {
608                sort_properties: l_order.add(&r_order),
609                range: l_range.add(r_range)?,
610                preserves_lex_ordering: false,
611            }),
612            Operator::Minus => Ok(ExprProperties {
613                sort_properties: l_order.sub(&r_order),
614                range: l_range.sub(r_range)?,
615                preserves_lex_ordering: false,
616            }),
617            Operator::Gt => Ok(ExprProperties {
618                sort_properties: l_order.gt_or_gteq(&r_order),
619                range: l_range.gt(r_range)?,
620                preserves_lex_ordering: false,
621            }),
622            Operator::GtEq => Ok(ExprProperties {
623                sort_properties: l_order.gt_or_gteq(&r_order),
624                range: l_range.gt_eq(r_range)?,
625                preserves_lex_ordering: false,
626            }),
627            Operator::Lt => Ok(ExprProperties {
628                sort_properties: r_order.gt_or_gteq(&l_order),
629                range: l_range.lt(r_range)?,
630                preserves_lex_ordering: false,
631            }),
632            Operator::LtEq => Ok(ExprProperties {
633                sort_properties: r_order.gt_or_gteq(&l_order),
634                range: l_range.lt_eq(r_range)?,
635                preserves_lex_ordering: false,
636            }),
637            Operator::And => Ok(ExprProperties {
638                sort_properties: r_order.and_or(&l_order),
639                range: l_range.and(r_range)?,
640                preserves_lex_ordering: false,
641            }),
642            Operator::Or => Ok(ExprProperties {
643                sort_properties: r_order.and_or(&l_order),
644                range: l_range.or(r_range)?,
645                preserves_lex_ordering: false,
646            }),
647            _ => Ok(ExprProperties::new_unknown()),
648        }
649    }
650
651    fn fmt_sql(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
652        fn write_child(
653            f: &mut std::fmt::Formatter,
654            expr: &dyn PhysicalExpr,
655            precedence: u8,
656        ) -> std::fmt::Result {
657            if let Some(child) = expr.as_any().downcast_ref::<BinaryExpr>() {
658                let p = child.op.precedence();
659                if p == 0 || p < precedence {
660                    write!(f, "(")?;
661                    child.fmt_sql(f)?;
662                    write!(f, ")")
663                } else {
664                    child.fmt_sql(f)
665                }
666            } else {
667                expr.fmt_sql(f)
668            }
669        }
670
671        let precedence = self.op.precedence();
672        write_child(f, self.left.as_ref(), precedence)?;
673        write!(f, " {} ", self.op)?;
674        write_child(f, self.right.as_ref(), precedence)
675    }
676}
677
678/// Casts dictionary array to result type for binary numerical operators. Such operators
679/// between array and scalar produce a dictionary array other than primitive array of the
680/// same operators between array and array. This leads to inconsistent result types causing
681/// errors in the following query execution. For such operators between array and scalar,
682/// we cast the dictionary array to primitive array.
683fn to_result_type_array(
684    op: &Operator,
685    array: ArrayRef,
686    result_type: &DataType,
687) -> Result<ArrayRef> {
688    if array.data_type() == result_type {
689        Ok(array)
690    } else if op.is_numerical_operators() {
691        match array.data_type() {
692            DataType::Dictionary(_, value_type) => {
693                if value_type.as_ref() == result_type {
694                    Ok(cast(&array, result_type)?)
695                } else {
696                    internal_err!(
697                            "Incompatible Dictionary value type {value_type:?} with result type {result_type:?} of Binary operator {op:?}"
698                        )
699                }
700            }
701            _ => Ok(array),
702        }
703    } else {
704        Ok(array)
705    }
706}
707
708impl BinaryExpr {
709    /// Evaluate the expression of the left input is an array and
710    /// right is literal - use scalar operations
711    fn evaluate_array_scalar(
712        &self,
713        array: &dyn Array,
714        scalar: ScalarValue,
715    ) -> Result<Option<Result<ArrayRef>>> {
716        use Operator::*;
717        let scalar_result = match &self.op {
718            RegexMatch => binary_string_array_flag_op_scalar!(
719                array,
720                scalar,
721                regexp_is_match,
722                false,
723                false
724            ),
725            RegexIMatch => binary_string_array_flag_op_scalar!(
726                array,
727                scalar,
728                regexp_is_match,
729                false,
730                true
731            ),
732            RegexNotMatch => binary_string_array_flag_op_scalar!(
733                array,
734                scalar,
735                regexp_is_match,
736                true,
737                false
738            ),
739            RegexNotIMatch => binary_string_array_flag_op_scalar!(
740                array,
741                scalar,
742                regexp_is_match,
743                true,
744                true
745            ),
746            BitwiseAnd => bitwise_and_dyn_scalar(array, scalar),
747            BitwiseOr => bitwise_or_dyn_scalar(array, scalar),
748            BitwiseXor => bitwise_xor_dyn_scalar(array, scalar),
749            BitwiseShiftRight => bitwise_shift_right_dyn_scalar(array, scalar),
750            BitwiseShiftLeft => bitwise_shift_left_dyn_scalar(array, scalar),
751            // if scalar operation is not supported - fallback to array implementation
752            _ => None,
753        };
754
755        Ok(scalar_result)
756    }
757
758    fn evaluate_with_resolved_args(
759        &self,
760        left: Arc<dyn Array>,
761        left_data_type: &DataType,
762        right: Arc<dyn Array>,
763        right_data_type: &DataType,
764    ) -> Result<ArrayRef> {
765        use Operator::*;
766        match &self.op {
767            IsDistinctFrom | IsNotDistinctFrom | Lt | LtEq | Gt | GtEq | Eq | NotEq
768            | Plus | Minus | Multiply | Divide | Modulo | LikeMatch | ILikeMatch
769            | NotLikeMatch | NotILikeMatch => unreachable!(),
770            And => {
771                if left_data_type == &DataType::Boolean {
772                    Ok(boolean_op(&left, &right, and_kleene)?)
773                } else {
774                    internal_err!(
775                        "Cannot evaluate binary expression {:?} with types {:?} and {:?}",
776                        self.op,
777                        left.data_type(),
778                        right.data_type()
779                    )
780                }
781            }
782            Or => {
783                if left_data_type == &DataType::Boolean {
784                    Ok(boolean_op(&left, &right, or_kleene)?)
785                } else {
786                    internal_err!(
787                        "Cannot evaluate binary expression {:?} with types {:?} and {:?}",
788                        self.op,
789                        left_data_type,
790                        right_data_type
791                    )
792                }
793            }
794            RegexMatch => {
795                binary_string_array_flag_op!(left, right, regexp_is_match, false, false)
796            }
797            RegexIMatch => {
798                binary_string_array_flag_op!(left, right, regexp_is_match, false, true)
799            }
800            RegexNotMatch => {
801                binary_string_array_flag_op!(left, right, regexp_is_match, true, false)
802            }
803            RegexNotIMatch => {
804                binary_string_array_flag_op!(left, right, regexp_is_match, true, true)
805            }
806            BitwiseAnd => bitwise_and_dyn(left, right),
807            BitwiseOr => bitwise_or_dyn(left, right),
808            BitwiseXor => bitwise_xor_dyn(left, right),
809            BitwiseShiftRight => bitwise_shift_right_dyn(left, right),
810            BitwiseShiftLeft => bitwise_shift_left_dyn(left, right),
811            StringConcat => concat_elements(left, right),
812            AtArrow | ArrowAt | Arrow | LongArrow | HashArrow | HashLongArrow | AtAt
813            | HashMinus | AtQuestion | Question | QuestionAnd | QuestionPipe
814            | IntegerDivide => {
815                not_impl_err!(
816                    "Binary operator '{:?}' is not supported in the physical expr",
817                    self.op
818                )
819            }
820        }
821    }
822}
823
/// Outcome of [`check_short_circuit`]: how (if at all) the evaluation of a
/// logical `AND` / `OR` can be shortened once the left-hand side is known.
enum ShortCircuitStrategy<'a> {
    /// No shortcut applies; normal evaluation should continue.
    None,
    /// The left-hand side alone decides the result (e.g. `AND` with an all-false
    /// left side, `OR` with an all-true left side).
    ReturnLeft,
    /// The left-hand side is neutral, so the result is the right-hand side
    /// (e.g. `AND` with an all-true left side, `OR` with an all-false left side).
    ReturnRight,
    /// `AND` only: the left-hand side has few `true` values; carries the
    /// left-hand boolean array to be used as a selection mask.
    PreSelection(&'a BooleanArray),
}
830
/// Based on the results calculated from the left side of the short-circuit operation,
/// if the proportion of `true` is at most 0.2 (the comparison is `<=`) and the current
/// operation is an `and`, the `RecordBatch` will be filtered in advance.
const PRE_SELECTION_THRESHOLD: f32 = 0.2;
835
836/// Checks if a logical operator (`AND`/`OR`) can short-circuit evaluation based on the left-hand side (lhs) result.
837///
838/// Short-circuiting occurs under these circumstances:
839/// - For `AND`:
840///    - if LHS is all false => short-circuit → return LHS
841///    - if LHS is all true  => short-circuit → return RHS
842///    - if LHS is mixed and true_count/sum_count <= [`PRE_SELECTION_THRESHOLD`] -> pre-selection
843/// - For `OR`:
844///    - if LHS is all true  => short-circuit → return LHS
845///    - if LHS is all false => short-circuit → return RHS
846/// # Arguments
847/// * `lhs` - The left-hand side (lhs) columnar value (array or scalar)
848/// * `lhs` - The left-hand side (lhs) columnar value (array or scalar)
849/// * `op` - The logical operator (`AND` or `OR`)
850///
851/// # Implementation Notes
852/// 1. Only works with Boolean-typed arguments (other types automatically return `false`)
853/// 2. Handles both scalar values and array values
854/// 3. For arrays, uses optimized bit counting techniques for boolean arrays
855fn check_short_circuit<'a>(
856    lhs: &'a ColumnarValue,
857    op: &Operator,
858) -> ShortCircuitStrategy<'a> {
859    // Quick reject for non-logical operators,and quick judgment when op is and
860    let is_and = match op {
861        Operator::And => true,
862        Operator::Or => false,
863        _ => return ShortCircuitStrategy::None,
864    };
865
866    // Non-boolean types can't be short-circuited
867    if lhs.data_type() != DataType::Boolean {
868        return ShortCircuitStrategy::None;
869    }
870
871    match lhs {
872        ColumnarValue::Array(array) => {
873            // Fast path for arrays - try to downcast to boolean array
874            if let Ok(bool_array) = as_boolean_array(array) {
875                // Arrays with nulls can't be short-circuited
876                if bool_array.null_count() > 0 {
877                    return ShortCircuitStrategy::None;
878                }
879
880                let len = bool_array.len();
881                if len == 0 {
882                    return ShortCircuitStrategy::None;
883                }
884
885                let true_count = bool_array.values().count_set_bits();
886                if is_and {
887                    // For AND, prioritize checking for all-false (short circuit case)
888                    // Uses optimized false_count() method provided by Arrow
889
890                    // Short circuit if all values are false
891                    if true_count == 0 {
892                        return ShortCircuitStrategy::ReturnLeft;
893                    }
894
895                    // If no false values, then all must be true
896                    if true_count == len {
897                        return ShortCircuitStrategy::ReturnRight;
898                    }
899
900                    // determine if we can pre-selection
901                    if true_count as f32 / len as f32 <= PRE_SELECTION_THRESHOLD {
902                        return ShortCircuitStrategy::PreSelection(bool_array);
903                    }
904                } else {
905                    // For OR, prioritize checking for all-true (short circuit case)
906                    // Uses optimized true_count() method provided by Arrow
907
908                    // Short circuit if all values are true
909                    if true_count == len {
910                        return ShortCircuitStrategy::ReturnLeft;
911                    }
912
913                    // If no true values, then all must be false
914                    if true_count == 0 {
915                        return ShortCircuitStrategy::ReturnRight;
916                    }
917                }
918            }
919        }
920        ColumnarValue::Scalar(scalar) => {
921            // Fast path for scalar values
922            if let ScalarValue::Boolean(Some(is_true)) = scalar {
923                // Return Left for:
924                // - AND with false value
925                // - OR with true value
926                if (is_and && !is_true) || (!is_and && *is_true) {
927                    return ShortCircuitStrategy::ReturnLeft;
928                } else {
929                    return ShortCircuitStrategy::ReturnRight;
930                }
931            }
932        }
933    }
934
935    // If we can't short-circuit, indicate that normal evaluation should continue
936    ShortCircuitStrategy::None
937}
938
939/// Creates a new boolean array based on the evaluation of the right expression,
940/// but only for positions where the left_result is true.
941///
942/// This function is used for short-circuit evaluation optimization of logical AND operations:
943/// - When left_result has few true values, we only evaluate the right expression for those positions
944/// - Values are copied from right_array where left_result is true
945/// - All other positions are filled with false values
946///
947/// # Parameters
948/// - `left_result` Boolean array with selection mask (typically from left side of AND)
949/// - `right_result` Result of evaluating right side of expression (only for selected positions)
950///
951/// # Returns
952/// A combined ColumnarValue with values from right_result where left_result is true
953///
954/// # Example
955///  Initial Data: { 1, 2, 3, 4, 5 }
956///  Left Evaluation
957///     (Condition: Equal to 2 or 3)
958///          ↓
959///  Filtered Data: {2, 3}
960///    Left Bitmap: { 0, 1, 1, 0, 0 }
961///          ↓
962///   Right Evaluation
963///     (Condition: Even numbers)
964///          ↓
965///  Right Data: { 2 }
966///    Right Bitmap: { 1, 0 }
967///          ↓
968///   Combine Results
969///  Final Bitmap: { 0, 1, 0, 0, 0 }
970///
971/// # Note
972/// Perhaps it would be better to modify `left_result` directly without creating a copy?
973/// In practice, `left_result` should have only one owner, so making changes should be safe.
974/// However, this is difficult to achieve under the immutable constraints of [`Arc`] and [`BooleanArray`].
975fn pre_selection_scatter(
976    left_result: &BooleanArray,
977    right_result: ColumnarValue,
978) -> Result<ColumnarValue> {
979    let right_boolean_array = match &right_result {
980        ColumnarValue::Array(array) => array.as_boolean(),
981        ColumnarValue::Scalar(_) => return Ok(right_result),
982    };
983
984    let result_len = left_result.len();
985
986    let mut result_array_builder = BooleanArray::builder(result_len);
987
988    // keep track of current position we have in right boolean array
989    let mut right_array_pos = 0;
990
991    // keep track of how much is filled
992    let mut last_end = 0;
993    SlicesIterator::new(left_result).for_each(|(start, end)| {
994        // the gap needs to be filled with false
995        if start > last_end {
996            result_array_builder.append_n(start - last_end, false);
997        }
998
999        // copy values from right array for this slice
1000        let len = end - start;
1001        right_boolean_array
1002            .slice(right_array_pos, len)
1003            .iter()
1004            .for_each(|v| result_array_builder.append_option(v));
1005
1006        right_array_pos += len;
1007        last_end = end;
1008    });
1009
1010    // Fill any remaining positions with false
1011    if last_end < result_len {
1012        result_array_builder.append_n(result_len - last_end, false);
1013    }
1014    let boolean_result = result_array_builder.finish();
1015
1016    Ok(ColumnarValue::Array(Arc::new(boolean_result)))
1017}
1018
1019fn concat_elements(left: Arc<dyn Array>, right: Arc<dyn Array>) -> Result<ArrayRef> {
1020    Ok(match left.data_type() {
1021        DataType::Utf8 => Arc::new(concat_elements_utf8(
1022            left.as_string::<i32>(),
1023            right.as_string::<i32>(),
1024        )?),
1025        DataType::LargeUtf8 => Arc::new(concat_elements_utf8(
1026            left.as_string::<i64>(),
1027            right.as_string::<i64>(),
1028        )?),
1029        DataType::Utf8View => Arc::new(concat_elements_utf8view(
1030            left.as_string_view(),
1031            right.as_string_view(),
1032        )?),
1033        other => {
1034            return internal_err!(
1035                "Data type {other:?} not supported for binary operation 'concat_elements' on string arrays"
1036            );
1037        }
1038    })
1039}
1040
1041/// Create a binary expression whose arguments are correctly coerced.
1042/// This function errors if it is not possible to coerce the arguments
1043/// to computational types supported by the operator.
1044pub fn binary(
1045    lhs: Arc<dyn PhysicalExpr>,
1046    op: Operator,
1047    rhs: Arc<dyn PhysicalExpr>,
1048    _input_schema: &Schema,
1049) -> Result<Arc<dyn PhysicalExpr>> {
1050    Ok(Arc::new(BinaryExpr::new(lhs, op, rhs)))
1051}
1052
1053/// Create a similar to expression
1054pub fn similar_to(
1055    negated: bool,
1056    case_insensitive: bool,
1057    expr: Arc<dyn PhysicalExpr>,
1058    pattern: Arc<dyn PhysicalExpr>,
1059) -> Result<Arc<dyn PhysicalExpr>> {
1060    let binary_op = match (negated, case_insensitive) {
1061        (false, false) => Operator::RegexMatch,
1062        (false, true) => Operator::RegexIMatch,
1063        (true, false) => Operator::RegexNotMatch,
1064        (true, true) => Operator::RegexNotIMatch,
1065    };
1066    Ok(Arc::new(BinaryExpr::new(expr, binary_op, pattern)))
1067}
1068
1069#[cfg(test)]
1070mod tests {
1071    use super::*;
1072    use crate::expressions::{col, lit, try_cast, Column, Literal};
1073    use datafusion_expr::lit as expr_lit;
1074
1075    use datafusion_common::plan_datafusion_err;
1076    use datafusion_physical_expr_common::physical_expr::fmt_sql;
1077
1078    use crate::planner::logical2physical;
1079    use arrow::array::BooleanArray;
1080    use datafusion_expr::col as logical_col;
1081    /// Performs a binary operation, applying any type coercion necessary
1082    fn binary_op(
1083        left: Arc<dyn PhysicalExpr>,
1084        op: Operator,
1085        right: Arc<dyn PhysicalExpr>,
1086        schema: &Schema,
1087    ) -> Result<Arc<dyn PhysicalExpr>> {
1088        let left_type = left.data_type(schema)?;
1089        let right_type = right.data_type(schema)?;
1090        let (lhs, rhs) =
1091            BinaryTypeCoercer::new(&left_type, &op, &right_type).get_input_types()?;
1092
1093        let left_expr = try_cast(left, schema, lhs)?;
1094        let right_expr = try_cast(right, schema, rhs)?;
1095        binary(left_expr, op, right_expr, schema)
1096    }
1097
1098    #[test]
1099    fn binary_comparison() -> Result<()> {
1100        let schema = Schema::new(vec![
1101            Field::new("a", DataType::Int32, false),
1102            Field::new("b", DataType::Int32, false),
1103        ]);
1104        let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
1105        let b = Int32Array::from(vec![1, 2, 4, 8, 16]);
1106
1107        // expression: "a < b"
1108        let lt = binary(
1109            col("a", &schema)?,
1110            Operator::Lt,
1111            col("b", &schema)?,
1112            &schema,
1113        )?;
1114        let batch =
1115            RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)])?;
1116
1117        let result = lt
1118            .evaluate(&batch)?
1119            .into_array(batch.num_rows())
1120            .expect("Failed to convert to array");
1121        assert_eq!(result.len(), 5);
1122
1123        let expected = [false, false, true, true, true];
1124        let result =
1125            as_boolean_array(&result).expect("failed to downcast to BooleanArray");
1126        for (i, &expected_item) in expected.iter().enumerate().take(5) {
1127            assert_eq!(result.value(i), expected_item);
1128        }
1129
1130        Ok(())
1131    }
1132
1133    #[test]
1134    fn binary_nested() -> Result<()> {
1135        let schema = Schema::new(vec![
1136            Field::new("a", DataType::Int32, false),
1137            Field::new("b", DataType::Int32, false),
1138        ]);
1139        let a = Int32Array::from(vec![2, 4, 6, 8, 10]);
1140        let b = Int32Array::from(vec![2, 5, 4, 8, 8]);
1141
1142        // expression: "a < b OR a == b"
1143        let expr = binary(
1144            binary(
1145                col("a", &schema)?,
1146                Operator::Lt,
1147                col("b", &schema)?,
1148                &schema,
1149            )?,
1150            Operator::Or,
1151            binary(
1152                col("a", &schema)?,
1153                Operator::Eq,
1154                col("b", &schema)?,
1155                &schema,
1156            )?,
1157            &schema,
1158        )?;
1159        let batch =
1160            RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)])?;
1161
1162        assert_eq!("a@0 < b@1 OR a@0 = b@1", format!("{expr}"));
1163
1164        let result = expr
1165            .evaluate(&batch)?
1166            .into_array(batch.num_rows())
1167            .expect("Failed to convert to array");
1168        assert_eq!(result.len(), 5);
1169
1170        let expected = [true, true, false, true, false];
1171        let result =
1172            as_boolean_array(&result).expect("failed to downcast to BooleanArray");
1173        for (i, &expected_item) in expected.iter().enumerate().take(5) {
1174            assert_eq!(result.value(i), expected_item);
1175        }
1176
1177        Ok(())
1178    }
1179
1180    // runs an end-to-end test of physical type coercion:
1181    // 1. construct a record batch with two columns of type A and B
1182    //  (*_ARRAY is the Rust Arrow array type, and *_TYPE is the DataType of the elements)
1183    // 2. construct a physical expression of A OP B
1184    // 3. evaluate the expression
1185    // 4. verify that the resulting expression is of type C
1186    // 5. verify that the results of evaluation are $VEC
1187    macro_rules! test_coercion {
1188        ($A_ARRAY:ident, $A_TYPE:expr, $A_VEC:expr, $B_ARRAY:ident, $B_TYPE:expr, $B_VEC:expr, $OP:expr, $C_ARRAY:ident, $C_TYPE:expr, $VEC:expr,) => {{
1189            let schema = Schema::new(vec![
1190                Field::new("a", $A_TYPE, false),
1191                Field::new("b", $B_TYPE, false),
1192            ]);
1193            let a = $A_ARRAY::from($A_VEC);
1194            let b = $B_ARRAY::from($B_VEC);
1195            let (lhs, rhs) = BinaryTypeCoercer::new(&$A_TYPE, &$OP, &$B_TYPE).get_input_types()?;
1196
1197            let left = try_cast(col("a", &schema)?, &schema, lhs)?;
1198            let right = try_cast(col("b", &schema)?, &schema, rhs)?;
1199
1200            // verify that we can construct the expression
1201            let expression = binary(left, $OP, right, &schema)?;
1202            let batch = RecordBatch::try_new(
1203                Arc::new(schema.clone()),
1204                vec![Arc::new(a), Arc::new(b)],
1205            )?;
1206
1207            // verify that the expression's type is correct
1208            assert_eq!(expression.data_type(&schema)?, $C_TYPE);
1209
1210            // compute
1211            let result = expression.evaluate(&batch)?.into_array(batch.num_rows()).expect("Failed to convert to array");
1212
1213            // verify that the array's data_type is correct
1214            assert_eq!(*result.data_type(), $C_TYPE);
1215
1216            // verify that the data itself is downcastable
1217            let result = result
1218                .as_any()
1219                .downcast_ref::<$C_ARRAY>()
1220                .expect("failed to downcast");
1221            // verify that the result itself is correct
1222            for (i, x) in $VEC.iter().enumerate() {
1223                let v = result.value(i);
1224                assert_eq!(
1225                    v,
1226                    *x,
1227                    "Unexpected output at position {i}:\n\nActual:\n{v}\n\nExpected:\n{x}"
1228                );
1229            }
1230        }};
1231    }
1232
1233    #[test]
1234    fn test_type_coercion() -> Result<()> {
1235        test_coercion!(
1236            Int32Array,
1237            DataType::Int32,
1238            vec![1i32, 2i32],
1239            UInt32Array,
1240            DataType::UInt32,
1241            vec![1u32, 2u32],
1242            Operator::Plus,
1243            Int64Array,
1244            DataType::Int64,
1245            [2i64, 4i64],
1246        );
1247        test_coercion!(
1248            Int32Array,
1249            DataType::Int32,
1250            vec![1i32],
1251            UInt16Array,
1252            DataType::UInt16,
1253            vec![1u16],
1254            Operator::Plus,
1255            Int32Array,
1256            DataType::Int32,
1257            [2i32],
1258        );
1259        test_coercion!(
1260            Float32Array,
1261            DataType::Float32,
1262            vec![1f32],
1263            UInt16Array,
1264            DataType::UInt16,
1265            vec![1u16],
1266            Operator::Plus,
1267            Float32Array,
1268            DataType::Float32,
1269            [2f32],
1270        );
1271        test_coercion!(
1272            Float32Array,
1273            DataType::Float32,
1274            vec![2f32],
1275            UInt16Array,
1276            DataType::UInt16,
1277            vec![1u16],
1278            Operator::Multiply,
1279            Float32Array,
1280            DataType::Float32,
1281            [2f32],
1282        );
1283        test_coercion!(
1284            StringArray,
1285            DataType::Utf8,
1286            vec!["1994-12-13", "1995-01-26"],
1287            Date32Array,
1288            DataType::Date32,
1289            vec![9112, 9156],
1290            Operator::Eq,
1291            BooleanArray,
1292            DataType::Boolean,
1293            [true, true],
1294        );
1295        test_coercion!(
1296            StringArray,
1297            DataType::Utf8,
1298            vec!["1994-12-13", "1995-01-26"],
1299            Date32Array,
1300            DataType::Date32,
1301            vec![9113, 9154],
1302            Operator::Lt,
1303            BooleanArray,
1304            DataType::Boolean,
1305            [true, false],
1306        );
1307        test_coercion!(
1308            StringArray,
1309            DataType::Utf8,
1310            vec!["1994-12-13T12:34:56", "1995-01-26T01:23:45"],
1311            Date64Array,
1312            DataType::Date64,
1313            vec![787322096000, 791083425000],
1314            Operator::Eq,
1315            BooleanArray,
1316            DataType::Boolean,
1317            [true, true],
1318        );
1319        test_coercion!(
1320            StringArray,
1321            DataType::Utf8,
1322            vec!["1994-12-13T12:34:56", "1995-01-26T01:23:45"],
1323            Date64Array,
1324            DataType::Date64,
1325            vec![787322096001, 791083424999],
1326            Operator::Lt,
1327            BooleanArray,
1328            DataType::Boolean,
1329            [true, false],
1330        );
1331        test_coercion!(
1332            StringViewArray,
1333            DataType::Utf8View,
1334            vec!["abc"; 5],
1335            StringArray,
1336            DataType::Utf8,
1337            vec!["^a", "^A", "(b|d)", "(B|D)", "^(b|c)"],
1338            Operator::RegexMatch,
1339            BooleanArray,
1340            DataType::Boolean,
1341            [true, false, true, false, false],
1342        );
1343        test_coercion!(
1344            StringViewArray,
1345            DataType::Utf8View,
1346            vec!["abc"; 5],
1347            StringArray,
1348            DataType::Utf8,
1349            vec!["^a", "^A", "(b|d)", "(B|D)", "^(b|c)"],
1350            Operator::RegexIMatch,
1351            BooleanArray,
1352            DataType::Boolean,
1353            [true, true, true, true, false],
1354        );
1355        test_coercion!(
1356            StringArray,
1357            DataType::Utf8,
1358            vec!["abc"; 5],
1359            StringViewArray,
1360            DataType::Utf8View,
1361            vec!["^a", "^A", "(b|d)", "(B|D)", "^(b|c)"],
1362            Operator::RegexNotMatch,
1363            BooleanArray,
1364            DataType::Boolean,
1365            [false, true, false, true, true],
1366        );
1367        test_coercion!(
1368            StringArray,
1369            DataType::Utf8,
1370            vec!["abc"; 5],
1371            StringViewArray,
1372            DataType::Utf8View,
1373            vec!["^a", "^A", "(b|d)", "(B|D)", "^(b|c)"],
1374            Operator::RegexNotIMatch,
1375            BooleanArray,
1376            DataType::Boolean,
1377            [false, false, false, false, true],
1378        );
1379        test_coercion!(
1380            StringArray,
1381            DataType::Utf8,
1382            vec!["abc"; 5],
1383            StringArray,
1384            DataType::Utf8,
1385            vec!["^a", "^A", "(b|d)", "(B|D)", "^(b|c)"],
1386            Operator::RegexMatch,
1387            BooleanArray,
1388            DataType::Boolean,
1389            [true, false, true, false, false],
1390        );
1391        test_coercion!(
1392            StringArray,
1393            DataType::Utf8,
1394            vec!["abc"; 5],
1395            StringArray,
1396            DataType::Utf8,
1397            vec!["^a", "^A", "(b|d)", "(B|D)", "^(b|c)"],
1398            Operator::RegexIMatch,
1399            BooleanArray,
1400            DataType::Boolean,
1401            [true, true, true, true, false],
1402        );
1403        test_coercion!(
1404            StringArray,
1405            DataType::Utf8,
1406            vec!["abc"; 5],
1407            StringArray,
1408            DataType::Utf8,
1409            vec!["^a", "^A", "(b|d)", "(B|D)", "^(b|c)"],
1410            Operator::RegexNotMatch,
1411            BooleanArray,
1412            DataType::Boolean,
1413            [false, true, false, true, true],
1414        );
1415        test_coercion!(
1416            StringArray,
1417            DataType::Utf8,
1418            vec!["abc"; 5],
1419            StringArray,
1420            DataType::Utf8,
1421            vec!["^a", "^A", "(b|d)", "(B|D)", "^(b|c)"],
1422            Operator::RegexNotIMatch,
1423            BooleanArray,
1424            DataType::Boolean,
1425            [false, false, false, false, true],
1426        );
1427        test_coercion!(
1428            LargeStringArray,
1429            DataType::LargeUtf8,
1430            vec!["abc"; 5],
1431            LargeStringArray,
1432            DataType::LargeUtf8,
1433            vec!["^a", "^A", "(b|d)", "(B|D)", "^(b|c)"],
1434            Operator::RegexMatch,
1435            BooleanArray,
1436            DataType::Boolean,
1437            [true, false, true, false, false],
1438        );
1439        test_coercion!(
1440            LargeStringArray,
1441            DataType::LargeUtf8,
1442            vec!["abc"; 5],
1443            LargeStringArray,
1444            DataType::LargeUtf8,
1445            vec!["^a", "^A", "(b|d)", "(B|D)", "^(b|c)"],
1446            Operator::RegexIMatch,
1447            BooleanArray,
1448            DataType::Boolean,
1449            [true, true, true, true, false],
1450        );
1451        test_coercion!(
1452            LargeStringArray,
1453            DataType::LargeUtf8,
1454            vec!["abc"; 5],
1455            LargeStringArray,
1456            DataType::LargeUtf8,
1457            vec!["^a", "^A", "(b|d)", "(B|D)", "^(b|c)"],
1458            Operator::RegexNotMatch,
1459            BooleanArray,
1460            DataType::Boolean,
1461            [false, true, false, true, true],
1462        );
1463        test_coercion!(
1464            LargeStringArray,
1465            DataType::LargeUtf8,
1466            vec!["abc"; 5],
1467            LargeStringArray,
1468            DataType::LargeUtf8,
1469            vec!["^a", "^A", "(b|d)", "(B|D)", "^(b|c)"],
1470            Operator::RegexNotIMatch,
1471            BooleanArray,
1472            DataType::Boolean,
1473            [false, false, false, false, true],
1474        );
1475        test_coercion!(
1476            StringArray,
1477            DataType::Utf8,
1478            vec!["abc"; 5],
1479            StringArray,
1480            DataType::Utf8,
1481            vec!["a__", "A%BC", "A_BC", "abc", "a%C"],
1482            Operator::LikeMatch,
1483            BooleanArray,
1484            DataType::Boolean,
1485            [true, false, false, true, false],
1486        );
1487        test_coercion!(
1488            StringArray,
1489            DataType::Utf8,
1490            vec!["abc"; 5],
1491            StringArray,
1492            DataType::Utf8,
1493            vec!["a__", "A%BC", "A_BC", "abc", "a%C"],
1494            Operator::ILikeMatch,
1495            BooleanArray,
1496            DataType::Boolean,
1497            [true, true, false, true, true],
1498        );
1499        test_coercion!(
1500            StringArray,
1501            DataType::Utf8,
1502            vec!["abc"; 5],
1503            StringArray,
1504            DataType::Utf8,
1505            vec!["a__", "A%BC", "A_BC", "abc", "a%C"],
1506            Operator::NotLikeMatch,
1507            BooleanArray,
1508            DataType::Boolean,
1509            [false, true, true, false, true],
1510        );
1511        test_coercion!(
1512            StringArray,
1513            DataType::Utf8,
1514            vec!["abc"; 5],
1515            StringArray,
1516            DataType::Utf8,
1517            vec!["a__", "A%BC", "A_BC", "abc", "a%C"],
1518            Operator::NotILikeMatch,
1519            BooleanArray,
1520            DataType::Boolean,
1521            [false, false, true, false, false],
1522        );
1523        test_coercion!(
1524            LargeStringArray,
1525            DataType::LargeUtf8,
1526            vec!["abc"; 5],
1527            LargeStringArray,
1528            DataType::LargeUtf8,
1529            vec!["a__", "A%BC", "A_BC", "abc", "a%C"],
1530            Operator::LikeMatch,
1531            BooleanArray,
1532            DataType::Boolean,
1533            [true, false, false, true, false],
1534        );
1535        test_coercion!(
1536            LargeStringArray,
1537            DataType::LargeUtf8,
1538            vec!["abc"; 5],
1539            LargeStringArray,
1540            DataType::LargeUtf8,
1541            vec!["a__", "A%BC", "A_BC", "abc", "a%C"],
1542            Operator::ILikeMatch,
1543            BooleanArray,
1544            DataType::Boolean,
1545            [true, true, false, true, true],
1546        );
1547        test_coercion!(
1548            LargeStringArray,
1549            DataType::LargeUtf8,
1550            vec!["abc"; 5],
1551            LargeStringArray,
1552            DataType::LargeUtf8,
1553            vec!["a__", "A%BC", "A_BC", "abc", "a%C"],
1554            Operator::NotLikeMatch,
1555            BooleanArray,
1556            DataType::Boolean,
1557            [false, true, true, false, true],
1558        );
1559        test_coercion!(
1560            LargeStringArray,
1561            DataType::LargeUtf8,
1562            vec!["abc"; 5],
1563            LargeStringArray,
1564            DataType::LargeUtf8,
1565            vec!["a__", "A%BC", "A_BC", "abc", "a%C"],
1566            Operator::NotILikeMatch,
1567            BooleanArray,
1568            DataType::Boolean,
1569            [false, false, true, false, false],
1570        );
1571        test_coercion!(
1572            Int16Array,
1573            DataType::Int16,
1574            vec![1i16, 2i16, 3i16],
1575            Int64Array,
1576            DataType::Int64,
1577            vec![10i64, 4i64, 5i64],
1578            Operator::BitwiseAnd,
1579            Int64Array,
1580            DataType::Int64,
1581            [0i64, 0i64, 1i64],
1582        );
1583        test_coercion!(
1584            UInt16Array,
1585            DataType::UInt16,
1586            vec![1u16, 2u16, 3u16],
1587            UInt64Array,
1588            DataType::UInt64,
1589            vec![10u64, 4u64, 5u64],
1590            Operator::BitwiseAnd,
1591            UInt64Array,
1592            DataType::UInt64,
1593            [0u64, 0u64, 1u64],
1594        );
1595        test_coercion!(
1596            Int16Array,
1597            DataType::Int16,
1598            vec![3i16, 2i16, 3i16],
1599            Int64Array,
1600            DataType::Int64,
1601            vec![10i64, 6i64, 5i64],
1602            Operator::BitwiseOr,
1603            Int64Array,
1604            DataType::Int64,
1605            [11i64, 6i64, 7i64],
1606        );
1607        test_coercion!(
1608            UInt16Array,
1609            DataType::UInt16,
1610            vec![1u16, 2u16, 3u16],
1611            UInt64Array,
1612            DataType::UInt64,
1613            vec![10u64, 4u64, 5u64],
1614            Operator::BitwiseOr,
1615            UInt64Array,
1616            DataType::UInt64,
1617            [11u64, 6u64, 7u64],
1618        );
1619        test_coercion!(
1620            Int16Array,
1621            DataType::Int16,
1622            vec![3i16, 2i16, 3i16],
1623            Int64Array,
1624            DataType::Int64,
1625            vec![10i64, 6i64, 5i64],
1626            Operator::BitwiseXor,
1627            Int64Array,
1628            DataType::Int64,
1629            [9i64, 4i64, 6i64],
1630        );
1631        test_coercion!(
1632            UInt16Array,
1633            DataType::UInt16,
1634            vec![3u16, 2u16, 3u16],
1635            UInt64Array,
1636            DataType::UInt64,
1637            vec![10u64, 6u64, 5u64],
1638            Operator::BitwiseXor,
1639            UInt64Array,
1640            DataType::UInt64,
1641            [9u64, 4u64, 6u64],
1642        );
1643        test_coercion!(
1644            Int16Array,
1645            DataType::Int16,
1646            vec![4i16, 27i16, 35i16],
1647            Int64Array,
1648            DataType::Int64,
1649            vec![2i64, 3i64, 4i64],
1650            Operator::BitwiseShiftRight,
1651            Int64Array,
1652            DataType::Int64,
1653            [1i64, 3i64, 2i64],
1654        );
1655        test_coercion!(
1656            UInt16Array,
1657            DataType::UInt16,
1658            vec![4u16, 27u16, 35u16],
1659            UInt64Array,
1660            DataType::UInt64,
1661            vec![2u64, 3u64, 4u64],
1662            Operator::BitwiseShiftRight,
1663            UInt64Array,
1664            DataType::UInt64,
1665            [1u64, 3u64, 2u64],
1666        );
1667        test_coercion!(
1668            Int16Array,
1669            DataType::Int16,
1670            vec![2i16, 3i16, 4i16],
1671            Int64Array,
1672            DataType::Int64,
1673            vec![4i64, 12i64, 7i64],
1674            Operator::BitwiseShiftLeft,
1675            Int64Array,
1676            DataType::Int64,
1677            [32i64, 12288i64, 512i64],
1678        );
1679        test_coercion!(
1680            UInt16Array,
1681            DataType::UInt16,
1682            vec![2u16, 3u16, 4u16],
1683            UInt64Array,
1684            DataType::UInt64,
1685            vec![4u64, 12u64, 7u64],
1686            Operator::BitwiseShiftLeft,
1687            UInt64Array,
1688            DataType::UInt64,
1689            [32u64, 12288u64, 512u64],
1690        );
1691        Ok(())
1692    }
1693
// Ideally this test would reuse the `test_coercion` macro from above.
// However, the type of a dictionary's values is not encoded in the Rust
// type of the `DictionaryArray`, so at the time of this writing there is
// no way to create a dictionary array using the `From` trait.
#[test]
fn test_dictionary_type_to_array_coercion() -> Result<()> {
    // Compare a string column against a string dictionary column.
    let utf8_type = DataType::Utf8;
    let dictionary_type =
        DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8));

    // Dictionary column, built entry by entry (second entry is null).
    let mut builder = StringDictionaryBuilder::<Int32Type>::new();
    builder.append("one")?;
    builder.append_null();
    builder.append("three")?;
    builder.append("four")?;
    let dictionary_column: ArrayRef = Arc::new(builder.finish());

    // Plain string column (third entry is null).
    let string_column: ArrayRef = Arc::new(StringArray::from(vec![
        Some("not one"),
        Some("two"),
        None,
        Some("four"),
    ]));

    // Case 1: dictionary = string
    let dict_first_schema = Arc::new(Schema::new(vec![
        Field::new("a", dictionary_type.clone(), true),
        Field::new("b", utf8_type.clone(), true),
    ]));
    let expected = BooleanArray::from(vec![Some(false), None, None, Some(true)]);
    apply_logic_op(
        &dict_first_schema,
        &dictionary_column,
        &string_column,
        Operator::Eq,
        expected,
    )?;

    // Case 2: the other direction, string = dictionary
    let string_first_schema = Arc::new(Schema::new(vec![
        Field::new("a", utf8_type, true),
        Field::new("b", dictionary_type, true),
    ]));
    let expected = BooleanArray::from(vec![Some(false), None, None, Some(true)]);
    apply_logic_op(
        &string_first_schema,
        &string_column,
        &dictionary_column,
        Operator::Eq,
        expected,
    )?;

    Ok(())
}
1742
/// `Plus` over two non-nullable Int32 columns.
#[test]
fn plus_op() -> Result<()> {
    let lhs: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));
    let rhs: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 4, 8, 16]));
    let expected_sum = Int32Array::from(vec![2, 4, 7, 12, 21]);

    let fields = vec![
        Field::new("a", DataType::Int32, false),
        Field::new("b", DataType::Int32, false),
    ];

    apply_arithmetic::<Int32Type>(
        Arc::new(Schema::new(fields)),
        vec![lhs, rhs],
        Operator::Plus,
        expected_sum,
    )?;

    Ok(())
}
1761
/// `Plus` over two Int8-keyed Int32 dictionary columns; the left column has
/// null keys, and the expected output is null in those rows.
#[test]
fn plus_op_dict() -> Result<()> {
    let dict_i32 =
        DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Int32));
    let schema = Arc::new(Schema::new(vec![
        Field::new("a", dict_i32.clone(), true),
        Field::new("b", dict_i32, true),
    ]));

    // Left operand: keys with nulls over values [1, 2, 3, 4, 5].
    let lhs_keys = Int8Array::from(vec![Some(0), None, Some(1), Some(3), None]);
    let lhs_values = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));
    let lhs = DictionaryArray::try_new(lhs_keys, lhs_values)?;

    // Right operand: non-null keys over values [1, 2, 4, 8, 16].
    let rhs_keys = Int8Array::from(vec![0, 1, 1, 2, 1]);
    let rhs_values = Arc::new(Int32Array::from(vec![1, 2, 4, 8, 16]));
    let rhs = DictionaryArray::try_new(rhs_keys, rhs_values)?;

    let expected = Int32Array::from(vec![Some(2), None, Some(4), Some(8), None]);
    apply_arithmetic::<Int32Type>(
        schema,
        vec![Arc::new(lhs), Arc::new(rhs)],
        Operator::Plus,
        expected,
    )?;

    Ok(())
}
1794
/// `Plus` over two Int8-keyed Decimal128(10, 0) dictionary columns, each of
/// which has one null key.
#[test]
fn plus_op_dict_decimal() -> Result<()> {
    let dict_decimal = DataType::Dictionary(
        Box::new(DataType::Int8),
        Box::new(DataType::Decimal128(10, 0)),
    );
    let schema = Arc::new(Schema::new(vec![
        Field::new("a", dict_decimal.clone(), true),
        Field::new("b", dict_decimal, true),
    ]));

    let base = 123;

    // Left operand
    let lhs_values = Arc::new(create_decimal_array(
        &[Some(base), Some(base + 2), Some(base - 1), Some(base + 1)],
        10,
        0,
    ));
    let lhs_keys = Int8Array::from(vec![Some(0), Some(2), None, Some(3), Some(0)]);
    let lhs = DictionaryArray::try_new(lhs_keys, lhs_values)?;

    // Right operand
    let rhs_values = Arc::new(create_decimal_array(
        &[Some(base + 1), Some(base + 3), Some(base), Some(base + 2)],
        10,
        0,
    ));
    let rhs_keys = Int8Array::from(vec![Some(0), None, Some(3), Some(2), Some(2)]);
    let rhs = DictionaryArray::try_new(rhs_keys, rhs_values)?;

    let expected =
        create_decimal_array(&[Some(247), None, None, Some(247), Some(246)], 11, 0);
    apply_arithmetic(
        schema,
        vec![Arc::new(lhs), Arc::new(rhs)],
        Operator::Plus,
        expected,
    )?;

    Ok(())
}
1853
/// `Plus` between an Int32 column and the Int32 literal 1.
#[test]
fn plus_op_scalar() -> Result<()> {
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
    let input: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));
    let expected: ArrayRef = Arc::new(Int32Array::from(vec![2, 3, 4, 5, 6]));
    let addend = ScalarValue::Int32(Some(1));

    apply_arithmetic_scalar(schema, vec![input], Operator::Plus, addend, expected)?;

    Ok(())
}
1869
/// `Plus` between an Int32 dictionary column and a dictionary-typed Int32
/// literal 1.
#[test]
fn plus_op_dict_scalar() -> Result<()> {
    let schema = Arc::new(Schema::new(vec![Field::new(
        "a",
        DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Int32)),
        true,
    )]));

    // Input column: 1, NULL, 2, 5
    let mut builder = PrimitiveDictionaryBuilder::<Int8Type, Int32Type>::new();
    for entry in [Some(1), None, Some(2), Some(5)] {
        match entry {
            Some(v) => {
                builder.append(v)?;
            }
            None => builder.append_null(),
        }
    }
    let input = builder.finish();

    let expected = Int32Array::from(vec![Some(2), None, Some(3), Some(6)]);
    let addend = ScalarValue::Dictionary(
        Box::new(DataType::Int8),
        Box::new(ScalarValue::Int32(Some(1))),
    );

    apply_arithmetic_scalar(
        schema,
        vec![Arc::new(input)],
        Operator::Plus,
        addend,
        Arc::new(expected),
    )?;

    Ok(())
}
1903
/// `Plus` between a Decimal128(10, 0) dictionary column and a
/// dictionary-typed decimal literal 1.
#[test]
fn plus_op_dict_scalar_decimal() -> Result<()> {
    let column_type = DataType::Dictionary(
        Box::new(DataType::Int8),
        Box::new(DataType::Decimal128(10, 0)),
    );
    let schema = Arc::new(Schema::new(vec![Field::new("a", column_type, true)]));

    let base = 123;

    // Input column: the dictionary values contain a null at index 1.
    let values = Arc::new(create_decimal_array(
        &[Some(base), None, Some(base - 1), Some(base + 1)],
        10,
        0,
    ));
    let keys = Int8Array::from(vec![0, 2, 1, 3, 0]);
    let input = DictionaryArray::try_new(keys, values)?;

    let expected = Arc::new(create_decimal_array(
        &[
            Some(base + 1),
            Some(base),
            None,
            Some(base + 2),
            Some(base + 1),
        ],
        11,
        0,
    ));
    let addend = ScalarValue::Dictionary(
        Box::new(DataType::Int8),
        Box::new(ScalarValue::Decimal128(Some(1), 10, 0)),
    );

    apply_arithmetic_scalar(
        schema,
        vec![Arc::new(input)],
        Operator::Plus,
        addend,
        expected,
    )?;

    Ok(())
}
1950
/// `Minus` over two Int32 columns, evaluated in both operand orders.
#[test]
fn minus_op() -> Result<()> {
    let fields = vec![
        Field::new("a", DataType::Int32, false),
        Field::new("b", DataType::Int32, false),
    ];
    let schema = Arc::new(Schema::new(fields));

    let larger: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 4, 8, 16]));
    let smaller: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));

    // larger - smaller: every difference is non-negative
    let non_negative = Int32Array::from(vec![0, 0, 1, 4, 11]);
    apply_arithmetic::<Int32Type>(
        Arc::clone(&schema),
        vec![Arc::clone(&larger), Arc::clone(&smaller)],
        Operator::Minus,
        non_negative,
    )?;

    // smaller - larger: a signed type must produce negative values
    let negative = Int32Array::from(vec![0, 0, -1, -4, -11]);
    apply_arithmetic::<Int32Type>(
        schema,
        vec![smaller, larger],
        Operator::Minus,
        negative,
    )?;

    Ok(())
}
1980
/// `Minus` over two Int8-keyed Int32 dictionary columns; the left column has
/// null keys, and the expected output is null in those rows.
#[test]
fn minus_op_dict() -> Result<()> {
    let dict_i32 =
        DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Int32));
    let schema = Arc::new(Schema::new(vec![
        Field::new("a", dict_i32.clone(), true),
        Field::new("b", dict_i32, true),
    ]));

    // Left operand: keys with nulls over values [1, 2, 3, 4, 5].
    let lhs_keys = Int8Array::from(vec![Some(0), None, Some(1), Some(3), None]);
    let lhs_values = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));
    let lhs = DictionaryArray::try_new(lhs_keys, lhs_values)?;

    // Right operand: non-null keys over values [1, 2, 4, 8, 16].
    let rhs_keys = Int8Array::from(vec![0, 1, 1, 2, 1]);
    let rhs_values = Arc::new(Int32Array::from(vec![1, 2, 4, 8, 16]));
    let rhs = DictionaryArray::try_new(rhs_keys, rhs_values)?;

    let expected = Int32Array::from(vec![Some(0), None, Some(0), Some(0), None]);
    apply_arithmetic::<Int32Type>(
        schema,
        vec![Arc::new(lhs), Arc::new(rhs)],
        Operator::Minus,
        expected,
    )?;

    Ok(())
}
2013
/// `Minus` over two Int8-keyed Decimal128(10, 0) dictionary columns, each of
/// which has one null key.
#[test]
fn minus_op_dict_decimal() -> Result<()> {
    let dict_decimal = DataType::Dictionary(
        Box::new(DataType::Int8),
        Box::new(DataType::Decimal128(10, 0)),
    );
    let schema = Arc::new(Schema::new(vec![
        Field::new("a", dict_decimal.clone(), true),
        Field::new("b", dict_decimal, true),
    ]));

    let base = 123;

    // Left operand
    let lhs_values = Arc::new(create_decimal_array(
        &[Some(base), Some(base + 2), Some(base - 1), Some(base + 1)],
        10,
        0,
    ));
    let lhs_keys = Int8Array::from(vec![Some(0), Some(2), None, Some(3), Some(0)]);
    let lhs = DictionaryArray::try_new(lhs_keys, lhs_values)?;

    // Right operand
    let rhs_values = Arc::new(create_decimal_array(
        &[Some(base + 1), Some(base + 3), Some(base), Some(base + 2)],
        10,
        0,
    ));
    let rhs_keys = Int8Array::from(vec![Some(0), None, Some(3), Some(2), Some(2)]);
    let rhs = DictionaryArray::try_new(rhs_keys, rhs_values)?;

    let expected =
        create_decimal_array(&[Some(-1), None, None, Some(1), Some(0)], 11, 0);
    apply_arithmetic(
        schema,
        vec![Arc::new(lhs), Arc::new(rhs)],
        Operator::Minus,
        expected,
    )?;

    Ok(())
}
2072
/// `Minus` between an Int32 column and the Int32 literal 1.
#[test]
fn minus_op_scalar() -> Result<()> {
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
    let input: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));
    let expected: ArrayRef = Arc::new(Int32Array::from(vec![0, 1, 2, 3, 4]));
    let subtrahend = ScalarValue::Int32(Some(1));

    apply_arithmetic_scalar(schema, vec![input], Operator::Minus, subtrahend, expected)?;

    Ok(())
}
2088
/// `Minus` between an Int32 dictionary column and a dictionary-typed Int32
/// literal 1.
#[test]
fn minus_op_dict_scalar() -> Result<()> {
    let schema = Arc::new(Schema::new(vec![Field::new(
        "a",
        DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Int32)),
        true,
    )]));

    // Input column: 1, NULL, 2, 5
    let mut builder = PrimitiveDictionaryBuilder::<Int8Type, Int32Type>::new();
    for entry in [Some(1), None, Some(2), Some(5)] {
        match entry {
            Some(v) => {
                builder.append(v)?;
            }
            None => builder.append_null(),
        }
    }
    let input = builder.finish();

    let expected = Int32Array::from(vec![Some(0), None, Some(1), Some(4)]);
    let subtrahend = ScalarValue::Dictionary(
        Box::new(DataType::Int8),
        Box::new(ScalarValue::Int32(Some(1))),
    );

    apply_arithmetic_scalar(
        schema,
        vec![Arc::new(input)],
        Operator::Minus,
        subtrahend,
        Arc::new(expected),
    )?;

    Ok(())
}
2122
/// `Minus` between a Decimal128(10, 0) dictionary column and a
/// dictionary-typed decimal literal 1.
#[test]
fn minus_op_dict_scalar_decimal() -> Result<()> {
    let column_type = DataType::Dictionary(
        Box::new(DataType::Int8),
        Box::new(DataType::Decimal128(10, 0)),
    );
    let schema = Arc::new(Schema::new(vec![Field::new("a", column_type, true)]));

    let base = 123;

    // Input column: the dictionary values contain a null at index 1.
    let values = Arc::new(create_decimal_array(
        &[Some(base), None, Some(base - 1), Some(base + 1)],
        10,
        0,
    ));
    let keys = Int8Array::from(vec![0, 2, 1, 3, 0]);
    let input = DictionaryArray::try_new(keys, values)?;

    let expected = Arc::new(create_decimal_array(
        &[
            Some(base - 1),
            Some(base - 2),
            None,
            Some(base),
            Some(base - 1),
        ],
        11,
        0,
    ));
    let subtrahend = ScalarValue::Dictionary(
        Box::new(DataType::Int8),
        Box::new(ScalarValue::Decimal128(Some(1), 10, 0)),
    );

    apply_arithmetic_scalar(
        schema,
        vec![Arc::new(input)],
        Operator::Minus,
        subtrahend,
        expected,
    )?;

    Ok(())
}
2169
/// `Multiply` over two non-nullable Int32 columns.
#[test]
fn multiply_op() -> Result<()> {
    let lhs: ArrayRef = Arc::new(Int32Array::from(vec![4, 8, 16, 32, 64]));
    let rhs: ArrayRef = Arc::new(Int32Array::from(vec![2, 4, 8, 16, 32]));
    let expected_product = Int32Array::from(vec![8, 32, 128, 512, 2048]);

    let fields = vec![
        Field::new("a", DataType::Int32, false),
        Field::new("b", DataType::Int32, false),
    ];

    apply_arithmetic::<Int32Type>(
        Arc::new(Schema::new(fields)),
        vec![lhs, rhs],
        Operator::Multiply,
        expected_product,
    )?;

    Ok(())
}
2188
/// `Multiply` over two Int8-keyed Int32 dictionary columns; the left column
/// has null keys, and the expected output is null in those rows.
#[test]
fn multiply_op_dict() -> Result<()> {
    let dict_i32 =
        DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Int32));
    let schema = Arc::new(Schema::new(vec![
        Field::new("a", dict_i32.clone(), true),
        Field::new("b", dict_i32, true),
    ]));

    // Left operand: keys with nulls over values [1, 2, 3, 4, 5].
    let lhs_keys = Int8Array::from(vec![Some(0), None, Some(1), Some(3), None]);
    let lhs_values = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));
    let lhs = DictionaryArray::try_new(lhs_keys, lhs_values)?;

    // Right operand: non-null keys over values [1, 2, 4, 8, 16].
    let rhs_keys = Int8Array::from(vec![0, 1, 1, 2, 1]);
    let rhs_values = Arc::new(Int32Array::from(vec![1, 2, 4, 8, 16]));
    let rhs = DictionaryArray::try_new(rhs_keys, rhs_values)?;

    let expected = Int32Array::from(vec![Some(1), None, Some(4), Some(16), None]);
    apply_arithmetic::<Int32Type>(
        schema,
        vec![Arc::new(lhs), Arc::new(rhs)],
        Operator::Multiply,
        expected,
    )?;

    Ok(())
}
2221
/// `Multiply` over two Int8-keyed Decimal128(10, 0) dictionary columns, each
/// of which has one null key.
#[test]
fn multiply_op_dict_decimal() -> Result<()> {
    let dict_decimal = DataType::Dictionary(
        Box::new(DataType::Int8),
        Box::new(DataType::Decimal128(10, 0)),
    );
    let schema = Arc::new(Schema::new(vec![
        Field::new("a", dict_decimal.clone(), true),
        Field::new("b", dict_decimal, true),
    ]));

    let base = 123;

    // Left operand
    let lhs_values: ArrayRef = Arc::new(create_decimal_array(
        &[Some(base), Some(base + 2), Some(base - 1), Some(base + 1)],
        10,
        0,
    ));
    let lhs_keys = Int8Array::from(vec![Some(0), Some(2), None, Some(3), Some(0)]);
    let lhs = DictionaryArray::try_new(lhs_keys, lhs_values)?;

    // Right operand
    let rhs_values = Arc::new(create_decimal_array(
        &[Some(base + 1), Some(base + 3), Some(base), Some(base + 2)],
        10,
        0,
    ));
    let rhs_keys = Int8Array::from(vec![Some(0), None, Some(3), Some(2), Some(2)]);
    let rhs = DictionaryArray::try_new(rhs_keys, rhs_values)?;

    let expected = create_decimal_array(
        &[Some(15252), None, None, Some(15252), Some(15129)],
        21,
        0,
    );
    apply_arithmetic(
        schema,
        vec![Arc::new(lhs), Arc::new(rhs)],
        Operator::Multiply,
        expected,
    )?;

    Ok(())
}
2284
/// `Multiply` between an Int32 column and the Int32 literal 2.
#[test]
fn multiply_op_scalar() -> Result<()> {
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
    let input: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));
    let expected: ArrayRef = Arc::new(Int32Array::from(vec![2, 4, 6, 8, 10]));
    let factor = ScalarValue::Int32(Some(2));

    apply_arithmetic_scalar(schema, vec![input], Operator::Multiply, factor, expected)?;

    Ok(())
}
2300
/// `Multiply` between an Int32 dictionary column and a dictionary-typed
/// Int32 literal 2.
#[test]
fn multiply_op_dict_scalar() -> Result<()> {
    let schema = Arc::new(Schema::new(vec![Field::new(
        "a",
        DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Int32)),
        true,
    )]));

    // Input column: 1, NULL, 2, 5
    let mut builder = PrimitiveDictionaryBuilder::<Int8Type, Int32Type>::new();
    for entry in [Some(1), None, Some(2), Some(5)] {
        match entry {
            Some(v) => {
                builder.append(v)?;
            }
            None => builder.append_null(),
        }
    }
    let input = builder.finish();

    let expected = Int32Array::from(vec![Some(2), None, Some(4), Some(10)]);
    let factor = ScalarValue::Dictionary(
        Box::new(DataType::Int8),
        Box::new(ScalarValue::Int32(Some(2))),
    );

    apply_arithmetic_scalar(
        schema,
        vec![Arc::new(input)],
        Operator::Multiply,
        factor,
        Arc::new(expected),
    )?;

    Ok(())
}
2334
/// `Multiply` between a Decimal128(10, 0) dictionary column and a
/// dictionary-typed decimal literal 2.
#[test]
fn multiply_op_dict_scalar_decimal() -> Result<()> {
    let column_type = DataType::Dictionary(
        Box::new(DataType::Int8),
        Box::new(DataType::Decimal128(10, 0)),
    );
    let schema = Arc::new(Schema::new(vec![Field::new("a", column_type, true)]));

    let base = 123;

    // Input column: the dictionary values contain a null at index 1.
    let values = Arc::new(create_decimal_array(
        &[Some(base), None, Some(base - 1), Some(base + 1)],
        10,
        0,
    ));
    let keys = Int8Array::from(vec![0, 2, 1, 3, 0]);
    let input = DictionaryArray::try_new(keys, values)?;

    let expected = Arc::new(create_decimal_array(
        &[Some(246), Some(244), None, Some(248), Some(246)],
        21,
        0,
    ));
    let factor = ScalarValue::Dictionary(
        Box::new(DataType::Int8),
        Box::new(ScalarValue::Decimal128(Some(2), 10, 0)),
    );

    apply_arithmetic_scalar(
        schema,
        vec![Arc::new(input)],
        Operator::Multiply,
        factor,
        expected,
    )?;

    Ok(())
}
2375
/// `Divide` over two non-nullable Int32 columns.
#[test]
fn divide_op() -> Result<()> {
    let dividends: ArrayRef = Arc::new(Int32Array::from(vec![8, 32, 128, 512, 2048]));
    let divisors: ArrayRef = Arc::new(Int32Array::from(vec![2, 4, 8, 16, 32]));
    let expected_quotient = Int32Array::from(vec![4, 8, 16, 32, 64]);

    let fields = vec![
        Field::new("a", DataType::Int32, false),
        Field::new("b", DataType::Int32, false),
    ];

    apply_arithmetic::<Int32Type>(
        Arc::new(Schema::new(fields)),
        vec![dividends, divisors],
        Operator::Divide,
        expected_quotient,
    )?;

    Ok(())
}
2394
/// `Divide` over two Int8-keyed Int32 dictionary columns; the dividend
/// column contains one null entry and a zero.
#[test]
fn divide_op_dict() -> Result<()> {
    let dict_i32 =
        DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Int32));
    let schema = Arc::new(Schema::new(vec![
        Field::new("a", dict_i32.clone(), true),
        Field::new("b", dict_i32, true),
    ]));

    // Dividend column: 1, NULL, 2, 5, 0
    let mut builder = PrimitiveDictionaryBuilder::<Int8Type, Int32Type>::new();
    for entry in [Some(1), None, Some(2), Some(5), Some(0)] {
        match entry {
            Some(v) => {
                builder.append(v)?;
            }
            None => builder.append_null(),
        }
    }
    let dividends = builder.finish();

    // Divisor column: non-null keys over values [1, 2, 4, 8, 16].
    let divisor_keys = Int8Array::from(vec![0, 1, 1, 2, 1]);
    let divisor_values = Arc::new(Int32Array::from(vec![1, 2, 4, 8, 16]));
    let divisors = DictionaryArray::try_new(divisor_keys, divisor_values)?;

    let expected = Int32Array::from(vec![Some(1), None, Some(1), Some(1), Some(0)]);
    apply_arithmetic::<Int32Type>(
        schema,
        vec![Arc::new(dividends), Arc::new(divisors)],
        Operator::Divide,
        expected,
    )?;

    Ok(())
}
2433
/// `Divide` over two Int8-keyed Decimal128(10, 0) dictionary columns, each
/// of which has one null key.
#[test]
fn divide_op_dict_decimal() -> Result<()> {
    let dict_decimal = DataType::Dictionary(
        Box::new(DataType::Int8),
        Box::new(DataType::Decimal128(10, 0)),
    );
    let schema = Arc::new(Schema::new(vec![
        Field::new("a", dict_decimal.clone(), true),
        Field::new("b", dict_decimal, true),
    ]));

    let base = 123;

    // Dividend column
    let dividend_values = Arc::new(create_decimal_array(
        &[Some(base), Some(base + 2), Some(base - 1), Some(base + 1)],
        10,
        0,
    ));
    let dividend_keys =
        Int8Array::from(vec![Some(0), Some(2), None, Some(3), Some(0)]);
    let dividends = DictionaryArray::try_new(dividend_keys, dividend_values)?;

    // Divisor column
    let divisor_values = Arc::new(create_decimal_array(
        &[Some(base + 1), Some(base + 3), Some(base), Some(base + 2)],
        10,
        0,
    ));
    let divisor_keys = Int8Array::from(vec![Some(0), None, Some(3), Some(2), Some(2)]);
    let divisors = DictionaryArray::try_new(divisor_keys, divisor_values)?;

    let expected = create_decimal_array(
        &[
            Some(9919), // 0.9919
            None,
            None,
            Some(10081), // 1.0081
            Some(10000), // 1.0
        ],
        14,
        4,
    );
    apply_arithmetic(
        schema,
        vec![Arc::new(dividends), Arc::new(divisors)],
        Operator::Divide,
        expected,
    )?;

    Ok(())
}
2502
/// `Divide` between an Int32 column and the Int32 literal 2.
#[test]
fn divide_op_scalar() -> Result<()> {
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
    let input: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));
    let expected: ArrayRef = Arc::new(Int32Array::from(vec![0, 1, 1, 2, 2]));
    let divisor = ScalarValue::Int32(Some(2));

    apply_arithmetic_scalar(schema, vec![input], Operator::Divide, divisor, expected)?;

    Ok(())
}
2518
/// `Divide` between an Int32 dictionary column and a dictionary-typed Int32
/// literal 2.
#[test]
fn divide_op_dict_scalar() -> Result<()> {
    let schema = Arc::new(Schema::new(vec![Field::new(
        "a",
        DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Int32)),
        true,
    )]));

    // Input column: 1, NULL, 2, 5
    let mut builder = PrimitiveDictionaryBuilder::<Int8Type, Int32Type>::new();
    for entry in [Some(1), None, Some(2), Some(5)] {
        match entry {
            Some(v) => {
                builder.append(v)?;
            }
            None => builder.append_null(),
        }
    }
    let input = builder.finish();

    let expected = Int32Array::from(vec![Some(0), None, Some(1), Some(2)]);
    let divisor = ScalarValue::Dictionary(
        Box::new(DataType::Int8),
        Box::new(ScalarValue::Int32(Some(2))),
    );

    apply_arithmetic_scalar(
        schema,
        vec![Arc::new(input)],
        Operator::Divide,
        divisor,
        Arc::new(expected),
    )?;

    Ok(())
}
2552
2553    #[test]
2554    fn divide_op_dict_scalar_decimal() -> Result<()> {
2555        let schema = Schema::new(vec![Field::new(
2556            "a",
2557            DataType::Dictionary(
2558                Box::new(DataType::Int8),
2559                Box::new(DataType::Decimal128(10, 0)),
2560            ),
2561            true,
2562        )]);
2563
2564        let value = 123;
2565        let decimal_array = Arc::new(create_decimal_array(
2566            &[Some(value), None, Some(value - 1), Some(value + 1)],
2567            10,
2568            0,
2569        ));
2570
2571        let keys = Int8Array::from(vec![0, 2, 1, 3, 0]);
2572        let a = DictionaryArray::try_new(keys, decimal_array)?;
2573
2574        let decimal_array = Arc::new(create_decimal_array(
2575            &[Some(615000), Some(610000), None, Some(620000), Some(615000)],
2576            14,
2577            4,
2578        ));
2579
2580        apply_arithmetic_scalar(
2581            Arc::new(schema),
2582            vec![Arc::new(a)],
2583            Operator::Divide,
2584            ScalarValue::Dictionary(
2585                Box::new(DataType::Int8),
2586                Box::new(ScalarValue::Decimal128(Some(2), 10, 0)),
2587            ),
2588            decimal_array,
2589        )?;
2590
2591        Ok(())
2592    }
2593
2594    #[test]
2595    fn modulus_op() -> Result<()> {
2596        let schema = Arc::new(Schema::new(vec![
2597            Field::new("a", DataType::Int32, false),
2598            Field::new("b", DataType::Int32, false),
2599        ]));
2600        let a = Arc::new(Int32Array::from(vec![8, 32, 128, 512, 2048]));
2601        let b = Arc::new(Int32Array::from(vec![2, 4, 7, 14, 32]));
2602
2603        apply_arithmetic::<Int32Type>(
2604            schema,
2605            vec![a, b],
2606            Operator::Modulo,
2607            Int32Array::from(vec![0, 0, 2, 8, 0]),
2608        )?;
2609
2610        Ok(())
2611    }
2612
2613    #[test]
2614    fn modulus_op_dict() -> Result<()> {
2615        let schema = Schema::new(vec![
2616            Field::new(
2617                "a",
2618                DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Int32)),
2619                true,
2620            ),
2621            Field::new(
2622                "b",
2623                DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Int32)),
2624                true,
2625            ),
2626        ]);
2627
2628        let mut dict_builder = PrimitiveDictionaryBuilder::<Int8Type, Int32Type>::new();
2629
2630        dict_builder.append(1)?;
2631        dict_builder.append_null();
2632        dict_builder.append(2)?;
2633        dict_builder.append(5)?;
2634        dict_builder.append(0)?;
2635
2636        let a = dict_builder.finish();
2637
2638        let b = Int32Array::from(vec![1, 2, 4, 8, 16]);
2639        let keys = Int8Array::from(vec![0, 1, 1, 2, 1]);
2640        let b = DictionaryArray::try_new(keys, Arc::new(b))?;
2641
2642        apply_arithmetic::<Int32Type>(
2643            Arc::new(schema),
2644            vec![Arc::new(a), Arc::new(b)],
2645            Operator::Modulo,
2646            Int32Array::from(vec![Some(0), None, Some(0), Some(1), Some(0)]),
2647        )?;
2648
2649        Ok(())
2650    }
2651
2652    #[test]
2653    fn modulus_op_dict_decimal() -> Result<()> {
2654        let schema = Schema::new(vec![
2655            Field::new(
2656                "a",
2657                DataType::Dictionary(
2658                    Box::new(DataType::Int8),
2659                    Box::new(DataType::Decimal128(10, 0)),
2660                ),
2661                true,
2662            ),
2663            Field::new(
2664                "b",
2665                DataType::Dictionary(
2666                    Box::new(DataType::Int8),
2667                    Box::new(DataType::Decimal128(10, 0)),
2668                ),
2669                true,
2670            ),
2671        ]);
2672
2673        let value = 123;
2674        let decimal_array = Arc::new(create_decimal_array(
2675            &[
2676                Some(value),
2677                Some(value + 2),
2678                Some(value - 1),
2679                Some(value + 1),
2680            ],
2681            10,
2682            0,
2683        ));
2684
2685        let keys = Int8Array::from(vec![Some(0), Some(2), None, Some(3), Some(0)]);
2686        let a = DictionaryArray::try_new(keys, decimal_array)?;
2687
2688        let keys = Int8Array::from(vec![Some(0), None, Some(3), Some(2), Some(2)]);
2689        let decimal_array = Arc::new(create_decimal_array(
2690            &[
2691                Some(value + 1),
2692                Some(value + 3),
2693                Some(value),
2694                Some(value + 2),
2695            ],
2696            10,
2697            0,
2698        ));
2699        let b = DictionaryArray::try_new(keys, decimal_array)?;
2700
2701        apply_arithmetic(
2702            Arc::new(schema),
2703            vec![Arc::new(a), Arc::new(b)],
2704            Operator::Modulo,
2705            create_decimal_array(&[Some(123), None, None, Some(1), Some(0)], 10, 0),
2706        )?;
2707
2708        Ok(())
2709    }
2710
2711    #[test]
2712    fn modulus_op_scalar() -> Result<()> {
2713        let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
2714        let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
2715
2716        apply_arithmetic_scalar(
2717            Arc::new(schema),
2718            vec![Arc::new(a)],
2719            Operator::Modulo,
2720            ScalarValue::Int32(Some(2)),
2721            Arc::new(Int32Array::from(vec![1, 0, 1, 0, 1])),
2722        )?;
2723
2724        Ok(())
2725    }
2726
2727    #[test]
2728    fn modules_op_dict_scalar() -> Result<()> {
2729        let schema = Schema::new(vec![Field::new(
2730            "a",
2731            DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Int32)),
2732            true,
2733        )]);
2734
2735        let mut dict_builder = PrimitiveDictionaryBuilder::<Int8Type, Int32Type>::new();
2736
2737        dict_builder.append(1)?;
2738        dict_builder.append_null();
2739        dict_builder.append(2)?;
2740        dict_builder.append(5)?;
2741
2742        let a = dict_builder.finish();
2743
2744        let expected: PrimitiveArray<Int32Type> =
2745            PrimitiveArray::from(vec![Some(1), None, Some(0), Some(1)]);
2746
2747        apply_arithmetic_scalar(
2748            Arc::new(schema),
2749            vec![Arc::new(a)],
2750            Operator::Modulo,
2751            ScalarValue::Dictionary(
2752                Box::new(DataType::Int8),
2753                Box::new(ScalarValue::Int32(Some(2))),
2754            ),
2755            Arc::new(expected),
2756        )?;
2757
2758        Ok(())
2759    }
2760
2761    #[test]
2762    fn modulus_op_dict_scalar_decimal() -> Result<()> {
2763        let schema = Schema::new(vec![Field::new(
2764            "a",
2765            DataType::Dictionary(
2766                Box::new(DataType::Int8),
2767                Box::new(DataType::Decimal128(10, 0)),
2768            ),
2769            true,
2770        )]);
2771
2772        let value = 123;
2773        let decimal_array = Arc::new(create_decimal_array(
2774            &[Some(value), None, Some(value - 1), Some(value + 1)],
2775            10,
2776            0,
2777        ));
2778
2779        let keys = Int8Array::from(vec![0, 2, 1, 3, 0]);
2780        let a = DictionaryArray::try_new(keys, decimal_array)?;
2781
2782        let decimal_array = Arc::new(create_decimal_array(
2783            &[Some(1), Some(0), None, Some(0), Some(1)],
2784            10,
2785            0,
2786        ));
2787
2788        apply_arithmetic_scalar(
2789            Arc::new(schema),
2790            vec![Arc::new(a)],
2791            Operator::Modulo,
2792            ScalarValue::Dictionary(
2793                Box::new(DataType::Int8),
2794                Box::new(ScalarValue::Decimal128(Some(2), 10, 0)),
2795            ),
2796            decimal_array,
2797        )?;
2798
2799        Ok(())
2800    }
2801
2802    fn apply_arithmetic<T: ArrowNumericType>(
2803        schema: SchemaRef,
2804        data: Vec<ArrayRef>,
2805        op: Operator,
2806        expected: PrimitiveArray<T>,
2807    ) -> Result<()> {
2808        let arithmetic_op =
2809            binary_op(col("a", &schema)?, op, col("b", &schema)?, &schema)?;
2810        let batch = RecordBatch::try_new(schema, data)?;
2811        let result = arithmetic_op
2812            .evaluate(&batch)?
2813            .into_array(batch.num_rows())
2814            .expect("Failed to convert to array");
2815
2816        assert_eq!(result.as_ref(), &expected);
2817        Ok(())
2818    }
2819
2820    fn apply_arithmetic_scalar(
2821        schema: SchemaRef,
2822        data: Vec<ArrayRef>,
2823        op: Operator,
2824        literal: ScalarValue,
2825        expected: ArrayRef,
2826    ) -> Result<()> {
2827        let lit = Arc::new(Literal::new(literal));
2828        let arithmetic_op = binary_op(col("a", &schema)?, op, lit, &schema)?;
2829        let batch = RecordBatch::try_new(schema, data)?;
2830        let result = arithmetic_op
2831            .evaluate(&batch)?
2832            .into_array(batch.num_rows())
2833            .expect("Failed to convert to array");
2834
2835        assert_eq!(&result, &expected);
2836        Ok(())
2837    }
2838
2839    fn apply_logic_op(
2840        schema: &SchemaRef,
2841        left: &ArrayRef,
2842        right: &ArrayRef,
2843        op: Operator,
2844        expected: BooleanArray,
2845    ) -> Result<()> {
2846        let op = binary_op(col("a", schema)?, op, col("b", schema)?, schema)?;
2847        let data: Vec<ArrayRef> = vec![Arc::clone(left), Arc::clone(right)];
2848        let batch = RecordBatch::try_new(Arc::clone(schema), data)?;
2849        let result = op
2850            .evaluate(&batch)?
2851            .into_array(batch.num_rows())
2852            .expect("Failed to convert to array");
2853
2854        assert_eq!(result.as_ref(), &expected);
2855        Ok(())
2856    }
2857
2858    // Test `scalar <op> arr` produces expected
2859    fn apply_logic_op_scalar_arr(
2860        schema: &SchemaRef,
2861        scalar: &ScalarValue,
2862        arr: &ArrayRef,
2863        op: Operator,
2864        expected: &BooleanArray,
2865    ) -> Result<()> {
2866        let scalar = lit(scalar.clone());
2867        let op = binary_op(scalar, op, col("a", schema)?, schema)?;
2868        let batch = RecordBatch::try_new(Arc::clone(schema), vec![Arc::clone(arr)])?;
2869        let result = op
2870            .evaluate(&batch)?
2871            .into_array(batch.num_rows())
2872            .expect("Failed to convert to array");
2873        assert_eq!(result.as_ref(), expected);
2874
2875        Ok(())
2876    }
2877
2878    // Test `arr <op> scalar` produces expected
2879    fn apply_logic_op_arr_scalar(
2880        schema: &SchemaRef,
2881        arr: &ArrayRef,
2882        scalar: &ScalarValue,
2883        op: Operator,
2884        expected: &BooleanArray,
2885    ) -> Result<()> {
2886        let scalar = lit(scalar.clone());
2887        let op = binary_op(col("a", schema)?, op, scalar, schema)?;
2888        let batch = RecordBatch::try_new(Arc::clone(schema), vec![Arc::clone(arr)])?;
2889        let result = op
2890            .evaluate(&batch)?
2891            .into_array(batch.num_rows())
2892            .expect("Failed to convert to array");
2893        assert_eq!(result.as_ref(), expected);
2894
2895        Ok(())
2896    }
2897
2898    #[test]
2899    fn and_with_nulls_op() -> Result<()> {
2900        let schema = Schema::new(vec![
2901            Field::new("a", DataType::Boolean, true),
2902            Field::new("b", DataType::Boolean, true),
2903        ]);
2904        let a = Arc::new(BooleanArray::from(vec![
2905            Some(true),
2906            Some(false),
2907            None,
2908            Some(true),
2909            Some(false),
2910            None,
2911            Some(true),
2912            Some(false),
2913            None,
2914        ])) as ArrayRef;
2915        let b = Arc::new(BooleanArray::from(vec![
2916            Some(true),
2917            Some(true),
2918            Some(true),
2919            Some(false),
2920            Some(false),
2921            Some(false),
2922            None,
2923            None,
2924            None,
2925        ])) as ArrayRef;
2926
2927        let expected = BooleanArray::from(vec![
2928            Some(true),
2929            Some(false),
2930            None,
2931            Some(false),
2932            Some(false),
2933            Some(false),
2934            None,
2935            Some(false),
2936            None,
2937        ]);
2938        apply_logic_op(&Arc::new(schema), &a, &b, Operator::And, expected)?;
2939
2940        Ok(())
2941    }
2942
2943    #[test]
2944    fn regex_with_nulls() -> Result<()> {
2945        let schema = Schema::new(vec![
2946            Field::new("a", DataType::Utf8, true),
2947            Field::new("b", DataType::Utf8, true),
2948        ]);
2949        let a = Arc::new(StringArray::from(vec![
2950            Some("abc"),
2951            None,
2952            Some("abc"),
2953            None,
2954            Some("abc"),
2955        ])) as ArrayRef;
2956        let b = Arc::new(StringArray::from(vec![
2957            Some("^a"),
2958            Some("^A"),
2959            None,
2960            None,
2961            Some("^(b|c)"),
2962        ])) as ArrayRef;
2963
2964        let regex_expected =
2965            BooleanArray::from(vec![Some(true), None, None, None, Some(false)]);
2966        let regex_not_expected =
2967            BooleanArray::from(vec![Some(false), None, None, None, Some(true)]);
2968        apply_logic_op(
2969            &Arc::new(schema.clone()),
2970            &a,
2971            &b,
2972            Operator::RegexMatch,
2973            regex_expected.clone(),
2974        )?;
2975        apply_logic_op(
2976            &Arc::new(schema.clone()),
2977            &a,
2978            &b,
2979            Operator::RegexIMatch,
2980            regex_expected.clone(),
2981        )?;
2982        apply_logic_op(
2983            &Arc::new(schema.clone()),
2984            &a,
2985            &b,
2986            Operator::RegexNotMatch,
2987            regex_not_expected.clone(),
2988        )?;
2989        apply_logic_op(
2990            &Arc::new(schema),
2991            &a,
2992            &b,
2993            Operator::RegexNotIMatch,
2994            regex_not_expected.clone(),
2995        )?;
2996
2997        let schema = Schema::new(vec![
2998            Field::new("a", DataType::LargeUtf8, true),
2999            Field::new("b", DataType::LargeUtf8, true),
3000        ]);
3001        let a = Arc::new(LargeStringArray::from(vec![
3002            Some("abc"),
3003            None,
3004            Some("abc"),
3005            None,
3006            Some("abc"),
3007        ])) as ArrayRef;
3008        let b = Arc::new(LargeStringArray::from(vec![
3009            Some("^a"),
3010            Some("^A"),
3011            None,
3012            None,
3013            Some("^(b|c)"),
3014        ])) as ArrayRef;
3015
3016        apply_logic_op(
3017            &Arc::new(schema.clone()),
3018            &a,
3019            &b,
3020            Operator::RegexMatch,
3021            regex_expected.clone(),
3022        )?;
3023        apply_logic_op(
3024            &Arc::new(schema.clone()),
3025            &a,
3026            &b,
3027            Operator::RegexIMatch,
3028            regex_expected,
3029        )?;
3030        apply_logic_op(
3031            &Arc::new(schema.clone()),
3032            &a,
3033            &b,
3034            Operator::RegexNotMatch,
3035            regex_not_expected.clone(),
3036        )?;
3037        apply_logic_op(
3038            &Arc::new(schema),
3039            &a,
3040            &b,
3041            Operator::RegexNotIMatch,
3042            regex_not_expected,
3043        )?;
3044
3045        Ok(())
3046    }
3047
3048    #[test]
3049    fn or_with_nulls_op() -> Result<()> {
3050        let schema = Schema::new(vec![
3051            Field::new("a", DataType::Boolean, true),
3052            Field::new("b", DataType::Boolean, true),
3053        ]);
3054        let a = Arc::new(BooleanArray::from(vec![
3055            Some(true),
3056            Some(false),
3057            None,
3058            Some(true),
3059            Some(false),
3060            None,
3061            Some(true),
3062            Some(false),
3063            None,
3064        ])) as ArrayRef;
3065        let b = Arc::new(BooleanArray::from(vec![
3066            Some(true),
3067            Some(true),
3068            Some(true),
3069            Some(false),
3070            Some(false),
3071            Some(false),
3072            None,
3073            None,
3074            None,
3075        ])) as ArrayRef;
3076
3077        let expected = BooleanArray::from(vec![
3078            Some(true),
3079            Some(true),
3080            Some(true),
3081            Some(true),
3082            Some(false),
3083            None,
3084            Some(true),
3085            None,
3086            None,
3087        ]);
3088        apply_logic_op(&Arc::new(schema), &a, &b, Operator::Or, expected)?;
3089
3090        Ok(())
3091    }
3092
3093    /// Returns (schema, a: BooleanArray, b: BooleanArray) with all possible inputs
3094    ///
3095    /// a: [true, true, true,  NULL, NULL, NULL,  false, false, false]
3096    /// b: [true, NULL, false, true, NULL, false, true,  NULL,  false]
3097    fn bool_test_arrays() -> (SchemaRef, ArrayRef, ArrayRef) {
3098        let schema = Schema::new(vec![
3099            Field::new("a", DataType::Boolean, true),
3100            Field::new("b", DataType::Boolean, true),
3101        ]);
3102        let a: BooleanArray = [
3103            Some(true),
3104            Some(true),
3105            Some(true),
3106            None,
3107            None,
3108            None,
3109            Some(false),
3110            Some(false),
3111            Some(false),
3112        ]
3113        .iter()
3114        .collect();
3115        let b: BooleanArray = [
3116            Some(true),
3117            None,
3118            Some(false),
3119            Some(true),
3120            None,
3121            Some(false),
3122            Some(true),
3123            None,
3124            Some(false),
3125        ]
3126        .iter()
3127        .collect();
3128        (Arc::new(schema), Arc::new(a), Arc::new(b))
3129    }
3130
3131    /// Returns (schema, BooleanArray) with [true, NULL, false]
3132    fn scalar_bool_test_array() -> (SchemaRef, ArrayRef) {
3133        let schema = Schema::new(vec![Field::new("a", DataType::Boolean, true)]);
3134        let a: BooleanArray = [Some(true), None, Some(false)].iter().collect();
3135        (Arc::new(schema), Arc::new(a))
3136    }
3137
3138    #[test]
3139    fn eq_op_bool() {
3140        let (schema, a, b) = bool_test_arrays();
3141        let expected = [
3142            Some(true),
3143            None,
3144            Some(false),
3145            None,
3146            None,
3147            None,
3148            Some(false),
3149            None,
3150            Some(true),
3151        ]
3152        .iter()
3153        .collect();
3154        apply_logic_op(&schema, &a, &b, Operator::Eq, expected).unwrap();
3155    }
3156
3157    #[test]
3158    fn eq_op_bool_scalar() {
3159        let (schema, a) = scalar_bool_test_array();
3160        let expected = [Some(true), None, Some(false)].iter().collect();
3161        apply_logic_op_scalar_arr(
3162            &schema,
3163            &ScalarValue::from(true),
3164            &a,
3165            Operator::Eq,
3166            &expected,
3167        )
3168        .unwrap();
3169        apply_logic_op_arr_scalar(
3170            &schema,
3171            &a,
3172            &ScalarValue::from(true),
3173            Operator::Eq,
3174            &expected,
3175        )
3176        .unwrap();
3177
3178        let expected = [Some(false), None, Some(true)].iter().collect();
3179        apply_logic_op_scalar_arr(
3180            &schema,
3181            &ScalarValue::from(false),
3182            &a,
3183            Operator::Eq,
3184            &expected,
3185        )
3186        .unwrap();
3187        apply_logic_op_arr_scalar(
3188            &schema,
3189            &a,
3190            &ScalarValue::from(false),
3191            Operator::Eq,
3192            &expected,
3193        )
3194        .unwrap();
3195    }
3196
3197    #[test]
3198    fn neq_op_bool() {
3199        let (schema, a, b) = bool_test_arrays();
3200        let expected = [
3201            Some(false),
3202            None,
3203            Some(true),
3204            None,
3205            None,
3206            None,
3207            Some(true),
3208            None,
3209            Some(false),
3210        ]
3211        .iter()
3212        .collect();
3213        apply_logic_op(&schema, &a, &b, Operator::NotEq, expected).unwrap();
3214    }
3215
3216    #[test]
3217    fn neq_op_bool_scalar() {
3218        let (schema, a) = scalar_bool_test_array();
3219        let expected = [Some(false), None, Some(true)].iter().collect();
3220        apply_logic_op_scalar_arr(
3221            &schema,
3222            &ScalarValue::from(true),
3223            &a,
3224            Operator::NotEq,
3225            &expected,
3226        )
3227        .unwrap();
3228        apply_logic_op_arr_scalar(
3229            &schema,
3230            &a,
3231            &ScalarValue::from(true),
3232            Operator::NotEq,
3233            &expected,
3234        )
3235        .unwrap();
3236
3237        let expected = [Some(true), None, Some(false)].iter().collect();
3238        apply_logic_op_scalar_arr(
3239            &schema,
3240            &ScalarValue::from(false),
3241            &a,
3242            Operator::NotEq,
3243            &expected,
3244        )
3245        .unwrap();
3246        apply_logic_op_arr_scalar(
3247            &schema,
3248            &a,
3249            &ScalarValue::from(false),
3250            Operator::NotEq,
3251            &expected,
3252        )
3253        .unwrap();
3254    }
3255
3256    #[test]
3257    fn lt_op_bool() {
3258        let (schema, a, b) = bool_test_arrays();
3259        let expected = [
3260            Some(false),
3261            None,
3262            Some(false),
3263            None,
3264            None,
3265            None,
3266            Some(true),
3267            None,
3268            Some(false),
3269        ]
3270        .iter()
3271        .collect();
3272        apply_logic_op(&schema, &a, &b, Operator::Lt, expected).unwrap();
3273    }
3274
3275    #[test]
3276    fn lt_op_bool_scalar() {
3277        let (schema, a) = scalar_bool_test_array();
3278        let expected = [Some(false), None, Some(false)].iter().collect();
3279        apply_logic_op_scalar_arr(
3280            &schema,
3281            &ScalarValue::from(true),
3282            &a,
3283            Operator::Lt,
3284            &expected,
3285        )
3286        .unwrap();
3287
3288        let expected = [Some(false), None, Some(true)].iter().collect();
3289        apply_logic_op_arr_scalar(
3290            &schema,
3291            &a,
3292            &ScalarValue::from(true),
3293            Operator::Lt,
3294            &expected,
3295        )
3296        .unwrap();
3297
3298        let expected = [Some(true), None, Some(false)].iter().collect();
3299        apply_logic_op_scalar_arr(
3300            &schema,
3301            &ScalarValue::from(false),
3302            &a,
3303            Operator::Lt,
3304            &expected,
3305        )
3306        .unwrap();
3307
3308        let expected = [Some(false), None, Some(false)].iter().collect();
3309        apply_logic_op_arr_scalar(
3310            &schema,
3311            &a,
3312            &ScalarValue::from(false),
3313            Operator::Lt,
3314            &expected,
3315        )
3316        .unwrap();
3317    }
3318
3319    #[test]
3320    fn lt_eq_op_bool() {
3321        let (schema, a, b) = bool_test_arrays();
3322        let expected = [
3323            Some(true),
3324            None,
3325            Some(false),
3326            None,
3327            None,
3328            None,
3329            Some(true),
3330            None,
3331            Some(true),
3332        ]
3333        .iter()
3334        .collect();
3335        apply_logic_op(&schema, &a, &b, Operator::LtEq, expected).unwrap();
3336    }
3337
3338    #[test]
3339    fn lt_eq_op_bool_scalar() {
3340        let (schema, a) = scalar_bool_test_array();
3341        let expected = [Some(true), None, Some(false)].iter().collect();
3342        apply_logic_op_scalar_arr(
3343            &schema,
3344            &ScalarValue::from(true),
3345            &a,
3346            Operator::LtEq,
3347            &expected,
3348        )
3349        .unwrap();
3350
3351        let expected = [Some(true), None, Some(true)].iter().collect();
3352        apply_logic_op_arr_scalar(
3353            &schema,
3354            &a,
3355            &ScalarValue::from(true),
3356            Operator::LtEq,
3357            &expected,
3358        )
3359        .unwrap();
3360
3361        let expected = [Some(true), None, Some(true)].iter().collect();
3362        apply_logic_op_scalar_arr(
3363            &schema,
3364            &ScalarValue::from(false),
3365            &a,
3366            Operator::LtEq,
3367            &expected,
3368        )
3369        .unwrap();
3370
3371        let expected = [Some(false), None, Some(true)].iter().collect();
3372        apply_logic_op_arr_scalar(
3373            &schema,
3374            &a,
3375            &ScalarValue::from(false),
3376            Operator::LtEq,
3377            &expected,
3378        )
3379        .unwrap();
3380    }
3381
3382    #[test]
3383    fn gt_op_bool() {
3384        let (schema, a, b) = bool_test_arrays();
3385        let expected = [
3386            Some(false),
3387            None,
3388            Some(true),
3389            None,
3390            None,
3391            None,
3392            Some(false),
3393            None,
3394            Some(false),
3395        ]
3396        .iter()
3397        .collect();
3398        apply_logic_op(&schema, &a, &b, Operator::Gt, expected).unwrap();
3399    }
3400
3401    #[test]
3402    fn gt_op_bool_scalar() {
3403        let (schema, a) = scalar_bool_test_array();
3404        let expected = [Some(false), None, Some(true)].iter().collect();
3405        apply_logic_op_scalar_arr(
3406            &schema,
3407            &ScalarValue::from(true),
3408            &a,
3409            Operator::Gt,
3410            &expected,
3411        )
3412        .unwrap();
3413
3414        let expected = [Some(false), None, Some(false)].iter().collect();
3415        apply_logic_op_arr_scalar(
3416            &schema,
3417            &a,
3418            &ScalarValue::from(true),
3419            Operator::Gt,
3420            &expected,
3421        )
3422        .unwrap();
3423
3424        let expected = [Some(false), None, Some(false)].iter().collect();
3425        apply_logic_op_scalar_arr(
3426            &schema,
3427            &ScalarValue::from(false),
3428            &a,
3429            Operator::Gt,
3430            &expected,
3431        )
3432        .unwrap();
3433
3434        let expected = [Some(true), None, Some(false)].iter().collect();
3435        apply_logic_op_arr_scalar(
3436            &schema,
3437            &a,
3438            &ScalarValue::from(false),
3439            Operator::Gt,
3440            &expected,
3441        )
3442        .unwrap();
3443    }
3444
3445    #[test]
3446    fn gt_eq_op_bool() {
3447        let (schema, a, b) = bool_test_arrays();
3448        let expected = [
3449            Some(true),
3450            None,
3451            Some(true),
3452            None,
3453            None,
3454            None,
3455            Some(false),
3456            None,
3457            Some(true),
3458        ]
3459        .iter()
3460        .collect();
3461        apply_logic_op(&schema, &a, &b, Operator::GtEq, expected).unwrap();
3462    }
3463
3464    #[test]
3465    fn gt_eq_op_bool_scalar() {
3466        let (schema, a) = scalar_bool_test_array();
3467        let expected = [Some(true), None, Some(true)].iter().collect();
3468        apply_logic_op_scalar_arr(
3469            &schema,
3470            &ScalarValue::from(true),
3471            &a,
3472            Operator::GtEq,
3473            &expected,
3474        )
3475        .unwrap();
3476
3477        let expected = [Some(true), None, Some(false)].iter().collect();
3478        apply_logic_op_arr_scalar(
3479            &schema,
3480            &a,
3481            &ScalarValue::from(true),
3482            Operator::GtEq,
3483            &expected,
3484        )
3485        .unwrap();
3486
3487        let expected = [Some(false), None, Some(true)].iter().collect();
3488        apply_logic_op_scalar_arr(
3489            &schema,
3490            &ScalarValue::from(false),
3491            &a,
3492            Operator::GtEq,
3493            &expected,
3494        )
3495        .unwrap();
3496
3497        let expected = [Some(true), None, Some(true)].iter().collect();
3498        apply_logic_op_arr_scalar(
3499            &schema,
3500            &a,
3501            &ScalarValue::from(false),
3502            Operator::GtEq,
3503            &expected,
3504        )
3505        .unwrap();
3506    }
3507
3508    #[test]
3509    fn is_distinct_from_op_bool() {
3510        let (schema, a, b) = bool_test_arrays();
3511        let expected = [
3512            Some(false),
3513            Some(true),
3514            Some(true),
3515            Some(true),
3516            Some(false),
3517            Some(true),
3518            Some(true),
3519            Some(true),
3520            Some(false),
3521        ]
3522        .iter()
3523        .collect();
3524        apply_logic_op(&schema, &a, &b, Operator::IsDistinctFrom, expected).unwrap();
3525    }
3526
3527    #[test]
3528    fn is_not_distinct_from_op_bool() {
3529        let (schema, a, b) = bool_test_arrays();
3530        let expected = [
3531            Some(true),
3532            Some(false),
3533            Some(false),
3534            Some(false),
3535            Some(true),
3536            Some(false),
3537            Some(false),
3538            Some(false),
3539            Some(true),
3540        ]
3541        .iter()
3542        .collect();
3543        apply_logic_op(&schema, &a, &b, Operator::IsNotDistinctFrom, expected).unwrap();
3544    }
3545
/// Reproducer for https://github.com/apache/datafusion/issues/419, where even
/// relatively shallow binary expressions overflowed the stack in debug builds.
///
/// Builds the left-deep sum `(((a + a) + a) + ...)` out of 100 references to
/// column `a` and expects every row to equal its input multiplied by 100.
#[test]
fn relatively_deeply_nested() {
    let values: Vec<Option<i32>> = (1..=5).map(Some).collect();
    let column = Int32Array::from(values.clone());

    let batch =
        RecordBatch::try_from_iter(vec![("a", Arc::new(column) as ArrayRef)]).unwrap();
    let schema = batch.schema();

    // Start from a single column reference and keep wrapping the running
    // expression as the left operand, so the tree only grows on its left side.
    let tree_depth: i32 = 100;
    let column_a = || col("a", schema.as_ref()).unwrap();
    let mut expr = column_a();
    for _ in 1..tree_depth {
        expr = binary(expr, Operator::Plus, column_a(), &schema).unwrap();
    }

    let evaluated = expr.evaluate(&batch).expect("evaluation");
    let result = evaluated
        .into_array(batch.num_rows())
        .expect("Failed to convert to array");

    // Adding `a` to itself `tree_depth` times is `a * tree_depth`.
    let expected: Int32Array = values
        .iter()
        .map(|value| value.map(|v| v * tree_depth))
        .collect();
    assert_eq!(result.as_ref(), &expected);
}
3578
3579    fn create_decimal_array(
3580        array: &[Option<i128>],
3581        precision: u8,
3582        scale: i8,
3583    ) -> Decimal128Array {
3584        let mut decimal_builder = Decimal128Builder::with_capacity(array.len());
3585        for value in array.iter().copied() {
3586            decimal_builder.append_option(value)
3587        }
3588        decimal_builder
3589            .finish()
3590            .with_precision_and_scale(precision, scale)
3591            .unwrap()
3592    }
3593
/// Compares a dictionary-encoded `Decimal128(25, 3)` column (Int8 keys)
/// against a dictionary scalar of the same value type, once per comparison
/// operator.
///
/// The column's logical rows are `[123, NULL, 122, 124]` (unscaled) and the
/// scalar is `123`; the null row is expected to stay null for every operator.
#[test]
fn comparison_dict_decimal_scalar_expr_test() -> Result<()> {
    let raw: i128 = 123;

    let dictionary_type = DataType::Dictionary(
        Box::new(DataType::Int8),
        Box::new(DataType::Decimal128(25, 3)),
    );
    let schema = Arc::new(Schema::new(vec![Field::new("a", dictionary_type, true)]));

    let decimal_scalar = ScalarValue::Dictionary(
        Box::new(DataType::Int8),
        Box::new(ScalarValue::Decimal128(Some(raw), 25, 3)),
    );

    // Dictionary values; key 1 is never referenced because that row's key is null.
    let dictionary_values = Arc::new(create_decimal_array(
        &[Some(raw), None, Some(raw - 1), Some(raw + 1)],
        25,
        3,
    ));
    let keys = Int8Array::from(vec![Some(0), None, Some(2), Some(3)]);
    let dictionary: ArrayRef =
        Arc::new(DictionaryArray::try_new(keys, dictionary_values)?);

    // (operator, expected `array <op> scalar` result per row)
    let cases = [
        (Operator::Eq, [Some(true), None, Some(false), Some(false)]),
        (Operator::NotEq, [Some(false), None, Some(true), Some(true)]),
        (Operator::Lt, [Some(false), None, Some(true), Some(false)]),
        (Operator::LtEq, [Some(true), None, Some(true), Some(false)]),
        (Operator::Gt, [Some(false), None, Some(false), Some(true)]),
        (Operator::GtEq, [Some(true), None, Some(false), Some(true)]),
    ];
    for (op, expected) in cases {
        apply_logic_op_arr_scalar(
            &schema,
            &dictionary,
            &decimal_scalar,
            op,
            &BooleanArray::from(expected.to_vec()),
        )
        .unwrap();
    }

    Ok(())
}
3684
/// Exercises comparison operators where at least one operand is a decimal.
///
/// Sections, in order:
/// 1. `Decimal128(25, 3)` array against a scalar of the same type.
/// 2. `Decimal128(10, 3)` scalar (123.456) against `Int64` and `Float64` arrays.
/// 3. Two `Decimal128(10, 0)` arrays against each other.
/// 4. `Int64` array against a `Decimal128(10, 0)` array.
/// 5. `Float64` array against a `Decimal128(10, 2)` array, including
///    `IS [NOT] DISTINCT FROM`.
#[test]
fn comparison_decimal_expr_test() -> Result<()> {
    // ---- 1. Decimal128(25, 3) array vs. scalar of the same type ----
    let raw: i128 = 123;
    let decimal_scalar = ScalarValue::Decimal128(Some(raw), 25, 3);
    let schema = Arc::new(Schema::new(vec![Field::new(
        "a",
        DataType::Decimal128(25, 3),
        true,
    )]));
    let decimal_array: ArrayRef = Arc::new(create_decimal_array(
        &[Some(raw), None, Some(raw - 1), Some(raw + 1)],
        25,
        3,
    ));
    // (operator, expected `array <op> scalar` result per row)
    let same_type_cases = [
        (Operator::Eq, [Some(true), None, Some(false), Some(false)]),
        (Operator::NotEq, [Some(false), None, Some(true), Some(true)]),
        (Operator::Lt, [Some(false), None, Some(true), Some(false)]),
        (Operator::LtEq, [Some(true), None, Some(true), Some(false)]),
        (Operator::Gt, [Some(false), None, Some(false), Some(true)]),
        (Operator::GtEq, [Some(true), None, Some(false), Some(true)]),
    ];
    for (op, expected) in same_type_cases {
        apply_logic_op_arr_scalar(
            &schema,
            &decimal_array,
            &decimal_scalar,
            op,
            &BooleanArray::from(expected.to_vec()),
        )
        .unwrap();
    }

    // ---- 2a. Decimal scalar 123.456 vs. Int64 arrays ----
    let decimal_scalar = ScalarValue::Decimal128(Some(123_456), 10, 3);
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, true)]));

    // scalar == array
    let ints: ArrayRef = Arc::new(Int64Array::from(vec![Some(124), None]));
    apply_logic_op_scalar_arr(
        &schema,
        &decimal_scalar,
        &ints,
        Operator::Eq,
        &BooleanArray::from(vec![Some(false), None]),
    )
    .unwrap();

    // array != scalar
    let ints: ArrayRef = Arc::new(Int64Array::from(vec![Some(123), None, Some(1)]));
    apply_logic_op_arr_scalar(
        &schema,
        &ints,
        &decimal_scalar,
        Operator::NotEq,
        &BooleanArray::from(vec![Some(true), None, Some(true)]),
    )
    .unwrap();

    // array < scalar, then array > scalar, over the same input
    let ints: ArrayRef = Arc::new(Int64Array::from(vec![Some(123), None, Some(124)]));
    let int_ordering_cases = [
        (Operator::Lt, [Some(true), None, Some(false)]),
        (Operator::Gt, [Some(false), None, Some(true)]),
    ];
    for (op, expected) in int_ordering_cases {
        apply_logic_op_arr_scalar(
            &schema,
            &ints,
            &decimal_scalar,
            op,
            &BooleanArray::from(expected.to_vec()),
        )
        .unwrap();
    }

    // ---- 2b. Decimal scalar 123.456 vs. Float64 arrays ----
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Float64, true)]));

    // array == scalar
    let floats: ArrayRef = Arc::new(Float64Array::from(vec![
        Some(123.456),
        None,
        Some(123.457),
    ]));
    apply_logic_op_arr_scalar(
        &schema,
        &floats,
        &decimal_scalar,
        Operator::Eq,
        &BooleanArray::from(vec![Some(true), None, Some(false)]),
    )
    .unwrap();

    // array <= scalar, then array >= scalar, over the same input
    let floats: ArrayRef = Arc::new(Float64Array::from(vec![
        Some(123.456),
        None,
        Some(123.457),
        Some(123.45),
    ]));
    let float_ordering_cases = [
        (Operator::LtEq, [Some(true), None, Some(false), Some(true)]),
        (Operator::GtEq, [Some(true), None, Some(true), Some(false)]),
    ];
    for (op, expected) in float_ordering_cases {
        apply_logic_op_arr_scalar(
            &schema,
            &floats,
            &decimal_scalar,
            op,
            &BooleanArray::from(expected.to_vec()),
        )
        .unwrap();
    }

    // ---- 3. Decimal128(10, 0) array vs. Decimal128(10, 0) array ----
    let raw: i128 = 123;
    let decimal_array: ArrayRef = Arc::new(create_decimal_array(
        &[Some(raw), None, Some(raw - 1), Some(raw + 1)],
        10,
        0,
    ));
    let schema = Arc::new(Schema::new(vec![
        Field::new("a", DataType::Decimal128(10, 0), true),
        Field::new("b", DataType::Decimal128(10, 0), true),
    ]));
    let right_decimal_array: ArrayRef = Arc::new(create_decimal_array(
        &[Some(raw - 1), Some(raw), Some(raw + 1), Some(raw + 1)],
        10,
        0,
    ));
    let decimal_pair_cases = [
        (Operator::Eq, [Some(false), None, Some(false), Some(true)]),
        (Operator::NotEq, [Some(true), None, Some(true), Some(false)]),
        (Operator::Lt, [Some(false), None, Some(true), Some(false)]),
        (Operator::LtEq, [Some(false), None, Some(true), Some(true)]),
        (Operator::Gt, [Some(true), None, Some(false), Some(false)]),
        (Operator::GtEq, [Some(true), None, Some(false), Some(true)]),
    ];
    for (op, expected) in decimal_pair_cases {
        apply_logic_op(
            &schema,
            &decimal_array,
            &right_decimal_array,
            op,
            BooleanArray::from(expected.to_vec()),
        )
        .unwrap();
    }

    // ---- 4. Int64 array vs. the Decimal128(10, 0) array from section 3 ----
    let int_value: i64 = 123;
    let schema = Arc::new(Schema::new(vec![
        Field::new("a", DataType::Int64, true),
        Field::new("b", DataType::Decimal128(10, 0), true),
    ]));
    let int64_array: ArrayRef = Arc::new(Int64Array::from(vec![
        Some(int_value),
        Some(int_value - 1),
        Some(int_value),
        Some(int_value + 1),
    ]));
    let int_vs_decimal_cases = [
        // eq: int64array == decimal array
        (Operator::Eq, [Some(true), None, Some(false), Some(true)]),
        // neq: int64array != decimal array
        (Operator::NotEq, [Some(false), None, Some(true), Some(false)]),
    ];
    for (op, expected) in int_vs_decimal_cases {
        apply_logic_op(
            &schema,
            &int64_array,
            &decimal_array,
            op,
            BooleanArray::from(expected.to_vec()),
        )
        .unwrap();
    }

    // ---- 5. Float64 array vs. Decimal128(10, 2) array ----
    let schema = Arc::new(Schema::new(vec![
        Field::new("a", DataType::Float64, true),
        Field::new("b", DataType::Decimal128(10, 2), true),
    ]));
    let raw: i128 = 123;
    let decimal_array: ArrayRef = Arc::new(create_decimal_array(
        &[
            Some(raw), // 1.23
            None,
            Some(raw - 1), // 1.22
            Some(raw + 1), // 1.24
        ],
        10,
        2,
    ));
    let float64_array: ArrayRef = Arc::new(Float64Array::from(vec![
        Some(1.23),
        Some(1.22),
        Some(1.23),
        Some(1.24),
    ]));
    let float_vs_decimal_cases = [
        // lt: float64array < decimal array
        (Operator::Lt, [Some(false), None, Some(false), Some(false)]),
        // lt_eq: float64array <= decimal array
        (Operator::LtEq, [Some(true), None, Some(false), Some(true)]),
        // gt: float64array > decimal array
        (Operator::Gt, [Some(false), None, Some(true), Some(false)]),
        // gt_eq: float64array >= decimal array
        (Operator::GtEq, [Some(true), None, Some(true), Some(true)]),
        // is distinct: float64array is distinct decimal array
        // TODO: the coercion rule for `is distinct` / `is not distinct` has not
        // been refactored yet, tracked by
        // https://github.com/apache/datafusion/issues/1590
        // The decimal array will be cast to a float64 array.
        (
            Operator::IsDistinctFrom,
            [Some(false), Some(true), Some(true), Some(false)],
        ),
        // is not distinct
        (
            Operator::IsNotDistinctFrom,
            [Some(true), Some(false), Some(false), Some(true)],
        ),
    ];
    for (op, expected) in float_vs_decimal_cases {
        apply_logic_op(
            &schema,
            &float64_array,
            &decimal_array,
            op,
            BooleanArray::from(expected.to_vec()),
        )
        .unwrap();
    }

    Ok(())
}
4038
4039    fn apply_decimal_arithmetic_op(
4040        schema: &SchemaRef,
4041        left: &ArrayRef,
4042        right: &ArrayRef,
4043        op: Operator,
4044        expected: ArrayRef,
4045    ) -> Result<()> {
4046        let arithmetic_op = binary_op(col("a", schema)?, op, col("b", schema)?, schema)?;
4047        let data: Vec<ArrayRef> = vec![Arc::clone(left), Arc::clone(right)];
4048        let batch = RecordBatch::try_new(Arc::clone(schema), data)?;
4049        let result = arithmetic_op
4050            .evaluate(&batch)?
4051            .into_array(batch.num_rows())
4052            .expect("Failed to convert to array");
4053
4054        assert_eq!(result.as_ref(), expected.as_ref());
4055        Ok(())
4056    }
4057
/// Arithmetic between an `Int32` column and a `Decimal128(10, 2)` column.
///
/// Each case lists the schema (which fixes the operand order), the operator,
/// and the expected unscaled decimal values together with the expected result
/// precision and scale.
#[test]
fn arithmetic_decimal_expr_test() -> Result<()> {
    let int_field = |name: &str| Field::new(name, DataType::Int32, true);
    let decimal_field = |name: &str| Field::new(name, DataType::Decimal128(10, 2), true);

    // a: Int32, b: Decimal128(10, 2)
    let int_then_decimal =
        Arc::new(Schema::new(vec![int_field("a"), decimal_field("b")]));
    // a: Decimal128(10, 2), b: Int32
    let decimal_then_int =
        Arc::new(Schema::new(vec![decimal_field("a"), int_field("b")]));

    let unscaled: i128 = 123;
    let decimal_array: ArrayRef = Arc::new(create_decimal_array(
        &[
            Some(unscaled), // 1.23
            None,
            Some(unscaled - 1), // 1.22
            Some(unscaled + 1), // 1.24
        ],
        10,
        2,
    ));
    let int32_array: ArrayRef = Arc::new(Int32Array::from(vec![
        Some(123),
        Some(122),
        Some(123),
        Some(124),
    ]));

    // (schema, left, right, operator, expected unscaled values, precision, scale)
    let cases = [
        // add: int32 array + decimal array
        (
            &int_then_decimal,
            &int32_array,
            &decimal_array,
            Operator::Plus,
            [Some(12423_i128), None, Some(12422), Some(12524)],
            13,
            2,
        ),
        // subtract: decimal array - int32 array
        (
            &decimal_then_int,
            &decimal_array,
            &int32_array,
            Operator::Minus,
            [Some(-12177), None, Some(-12178), Some(-12276)],
            13,
            2,
        ),
        // multiply: decimal array * int32 array
        (
            &decimal_then_int,
            &decimal_array,
            &int32_array,
            Operator::Multiply,
            [Some(15129), None, Some(15006), Some(15376)],
            21,
            2,
        ),
        // divide: int32 array / decimal array
        (
            &int_then_decimal,
            &int32_array,
            &decimal_array,
            Operator::Divide,
            [Some(1000000), None, Some(1008196), Some(1000000)],
            16,
            4,
        ),
        // modulus: int32 array % decimal array
        (
            &int_then_decimal,
            &int32_array,
            &decimal_array,
            Operator::Modulo,
            [Some(0), None, Some(100), Some(0)],
            10,
            2,
        ),
    ];
    for (schema, left, right, op, values, precision, scale) in cases {
        let expect: ArrayRef =
            Arc::new(create_decimal_array(&values, precision, scale));
        apply_decimal_arithmetic_op(schema, left, right, op, expect).unwrap();
    }

    Ok(())
}
4171
/// Arithmetic with a `Float64` column on the left and a `Decimal128(10, 2)`
/// column on the right; every expected result is a `Float64Array`.
///
/// The same schema (`a: Float64`, `b: Decimal128(10, 2)`) and the same operand
/// order are used for every operator.
#[test]
fn arithmetic_decimal_float_expr_test() -> Result<()> {
    let schema = Arc::new(Schema::new(vec![
        Field::new("a", DataType::Float64, true),
        Field::new("b", DataType::Decimal128(10, 2), true),
    ]));

    let unscaled: i128 = 123;
    let decimal_array: ArrayRef = Arc::new(create_decimal_array(
        &[
            Some(unscaled), // 1.23
            None,
            Some(unscaled - 1), // 1.22
            Some(unscaled + 1), // 1.24
        ],
        10,
        2,
    ));
    let float64_array: ArrayRef = Arc::new(Float64Array::from(vec![
        Some(123.0),
        Some(122.0),
        Some(123.0),
        Some(124.0),
    ]));

    // (operator, expected `float64 <op> decimal` result per row)
    let cases = [
        // add: float64 array + decimal array
        (
            Operator::Plus,
            [Some(124.23), None, Some(124.22), Some(125.24)],
        ),
        // subtract: float64 array - decimal array
        (
            Operator::Minus,
            [Some(121.77), None, Some(121.78), Some(122.76)],
        ),
        // multiply: float64 array * decimal array
        (
            Operator::Multiply,
            [Some(151.29), None, Some(150.06), Some(153.76)],
        ),
        // divide: float64 array / decimal array
        (
            Operator::Divide,
            [Some(100.0), None, Some(100.81967213114754), Some(100.0)],
        ),
        // modulus: float64 array % decimal array
        (
            Operator::Modulo,
            [
                Some(1.7763568394002505e-15),
                None,
                Some(1.0000000000000027),
                Some(8.881784197001252e-16),
            ],
        ),
    ];
    for (op, expected) in cases {
        let expect: ArrayRef = Arc::new(Float64Array::from(expected.to_vec()));
        apply_decimal_arithmetic_op(&schema, &float64_array, &decimal_array, op, expect)
            .unwrap();
    }

    Ok(())
}
4290
/// Dividing by zero must fail with a divide-by-zero error, both for `Int32`
/// operands and for `Decimal128(25, 3)` operands.
///
/// The `expected` arrays handed to `apply_arithmetic` are placeholders: each
/// call is required to return an error (`unwrap_err`), so they are not the
/// subject of the test.
#[test]
fn arithmetic_divide_zero() -> Result<()> {
    // other data type
    let schema = Arc::new(Schema::new(vec![
        Field::new("a", DataType::Int32, true),
        Field::new("b", DataType::Int32, true),
    ]));
    let a = Arc::new(Int32Array::from(vec![100]));
    let b = Arc::new(Int32Array::from(vec![0]));

    let err = apply_arithmetic::<Int32Type>(
        schema,
        vec![a, b],
        Operator::Divide,
        Int32Array::from(vec![Some(4), Some(8), Some(16), Some(32), Some(64)]),
    )
    .unwrap_err();

    // Check the rendered message. A `matches!(err, ref name)` check would be
    // vacuous: `ref name` is a binding pattern that matches every value.
    let message = err.to_string();
    assert!(
        message.contains("Divide by zero"),
        "unexpected error: {message}"
    );

    // decimal
    let schema = Arc::new(Schema::new(vec![
        Field::new("a", DataType::Decimal128(25, 3), true),
        Field::new("b", DataType::Decimal128(25, 3), true),
    ]));
    let left_decimal_array = Arc::new(create_decimal_array(&[Some(1234567)], 25, 3));
    let right_decimal_array = Arc::new(create_decimal_array(&[Some(0)], 25, 3));

    let err = apply_arithmetic::<Decimal128Type>(
        schema,
        vec![left_decimal_array, right_decimal_array],
        Operator::Divide,
        create_decimal_array(
            &[Some(12345670000000000000000000000000000), None],
            38,
            29,
        ),
    )
    .unwrap_err();

    let message = err.to_string();
    assert!(
        message.contains("Divide by zero"),
        "unexpected error: {message}"
    );

    Ok(())
}
4337
/// Array-vs-array bitwise AND / OR / XOR for `Int32` and `UInt32` inputs.
///
/// The left operand has a null in the middle row, which is expected to stay
/// null in every result.
#[test]
fn bitwise_array_test() -> Result<()> {
    // Signed 32-bit operands.
    {
        let lhs: ArrayRef = Arc::new(Int32Array::from(vec![Some(12), None, Some(11)]));
        let rhs: ArrayRef =
            Arc::new(Int32Array::from(vec![Some(1), Some(3), Some(7)]));

        let and = bitwise_and_dyn(Arc::clone(&lhs), Arc::clone(&rhs))?;
        assert_eq!(
            and.as_ref(),
            &Int32Array::from(vec![Some(0), None, Some(3)])
        );

        let or = bitwise_or_dyn(Arc::clone(&lhs), Arc::clone(&rhs))?;
        assert_eq!(
            or.as_ref(),
            &Int32Array::from(vec![Some(13), None, Some(15)])
        );

        let xor = bitwise_xor_dyn(Arc::clone(&lhs), Arc::clone(&rhs))?;
        assert_eq!(
            xor.as_ref(),
            &Int32Array::from(vec![Some(13), None, Some(12)])
        );
    }

    // Unsigned 32-bit operands, same values.
    {
        let lhs: ArrayRef =
            Arc::new(UInt32Array::from(vec![Some(12), None, Some(11)]));
        let rhs: ArrayRef =
            Arc::new(UInt32Array::from(vec![Some(1), Some(3), Some(7)]));

        let and = bitwise_and_dyn(Arc::clone(&lhs), Arc::clone(&rhs))?;
        assert_eq!(
            and.as_ref(),
            &UInt32Array::from(vec![Some(0), None, Some(3)])
        );

        let or = bitwise_or_dyn(Arc::clone(&lhs), Arc::clone(&rhs))?;
        assert_eq!(
            or.as_ref(),
            &UInt32Array::from(vec![Some(13), None, Some(15)])
        );

        let xor = bitwise_xor_dyn(Arc::clone(&lhs), Arc::clone(&rhs))?;
        assert_eq!(
            xor.as_ref(),
            &UInt32Array::from(vec![Some(13), None, Some(12)])
        );
    }

    Ok(())
}
4373
/// Array-vs-array shift-left followed by shift-right with the same shift
/// amounts, for `Int32` and `UInt32` inputs.
///
/// The right shift is expected to restore the original input exactly.
#[test]
fn bitwise_shift_array_test() -> Result<()> {
    // Signed 32-bit values.
    {
        let values: ArrayRef =
            Arc::new(Int32Array::from(vec![Some(2), None, Some(10)]));
        let shifts: ArrayRef =
            Arc::new(Int32Array::from(vec![Some(2), Some(4), Some(8)]));

        let shifted = bitwise_shift_left_dyn(Arc::clone(&values), Arc::clone(&shifts))?;
        assert_eq!(
            shifted.as_ref(),
            &Int32Array::from(vec![Some(8), None, Some(2560)])
        );

        let restored =
            bitwise_shift_right_dyn(Arc::clone(&shifted), Arc::clone(&shifts))?;
        assert_eq!(restored.as_ref(), &values);
    }

    // Unsigned 32-bit values, same inputs.
    {
        let values: ArrayRef =
            Arc::new(UInt32Array::from(vec![Some(2), None, Some(10)]));
        let shifts: ArrayRef =
            Arc::new(UInt32Array::from(vec![Some(2), Some(4), Some(8)]));

        let shifted = bitwise_shift_left_dyn(Arc::clone(&values), Arc::clone(&shifts))?;
        assert_eq!(
            shifted.as_ref(),
            &UInt32Array::from(vec![Some(8), None, Some(2560)])
        );

        let restored =
            bitwise_shift_right_dyn(Arc::clone(&shifted), Arc::clone(&shifts))?;
        assert_eq!(restored.as_ref(), &values);
    }

    Ok(())
}
4402
/// Shift-left with a shift amount (100) larger than the 32-bit width.
///
/// For both `Int32` and `UInt32` the expected result of `2 << 100` is `32`,
/// i.e. the same as `2 << 4`, which is consistent with the shift amount being
/// reduced modulo 32 rather than raising an error.
#[test]
fn bitwise_shift_array_overflow_test() -> Result<()> {
    // Signed 32-bit value.
    {
        let value: ArrayRef = Arc::new(Int32Array::from(vec![Some(2)]));
        let shift: ArrayRef = Arc::new(Int32Array::from(vec![Some(100)]));

        let shifted = bitwise_shift_left_dyn(Arc::clone(&value), Arc::clone(&shift))?;
        assert_eq!(shifted.as_ref(), &Int32Array::from(vec![Some(32)]));
    }

    // Unsigned 32-bit value.
    {
        let value: ArrayRef = Arc::new(UInt32Array::from(vec![Some(2)]));
        let shift: ArrayRef = Arc::new(UInt32Array::from(vec![Some(100)]));

        let shifted = bitwise_shift_left_dyn(Arc::clone(&value), Arc::clone(&shift))?;
        assert_eq!(shifted.as_ref(), &UInt32Array::from(vec![Some(32)]));
    }

    Ok(())
}
4420
/// Array-vs-scalar bitwise AND / OR / XOR with the scalar `3`, for `Int32`
/// and `UInt32` inputs.
///
/// Each `*_dyn_scalar` kernel returns an `Option<Result<_>>`; the test
/// requires `Some` (via `unwrap`) and propagates any inner error with `?`.
#[test]
fn bitwise_scalar_test() -> Result<()> {
    // Signed 32-bit array and scalar.
    {
        let array: ArrayRef =
            Arc::new(Int32Array::from(vec![Some(12), None, Some(11)]));
        let scalar = ScalarValue::from(3i32);

        let and = bitwise_and_dyn_scalar(&array, scalar.clone()).unwrap()?;
        assert_eq!(
            and.as_ref(),
            &Int32Array::from(vec![Some(0), None, Some(3)])
        );

        let or = bitwise_or_dyn_scalar(&array, scalar.clone()).unwrap()?;
        assert_eq!(
            or.as_ref(),
            &Int32Array::from(vec![Some(15), None, Some(11)])
        );

        let xor = bitwise_xor_dyn_scalar(&array, scalar).unwrap()?;
        assert_eq!(
            xor.as_ref(),
            &Int32Array::from(vec![Some(15), None, Some(8)])
        );
    }

    // Unsigned 32-bit array and scalar, same values.
    {
        let array: ArrayRef =
            Arc::new(UInt32Array::from(vec![Some(12), None, Some(11)]));
        let scalar = ScalarValue::from(3u32);

        let and = bitwise_and_dyn_scalar(&array, scalar.clone()).unwrap()?;
        assert_eq!(
            and.as_ref(),
            &UInt32Array::from(vec![Some(0), None, Some(3)])
        );

        let or = bitwise_or_dyn_scalar(&array, scalar.clone()).unwrap()?;
        assert_eq!(
            or.as_ref(),
            &UInt32Array::from(vec![Some(15), None, Some(11)])
        );

        let xor = bitwise_xor_dyn_scalar(&array, scalar).unwrap()?;
        assert_eq!(
            xor.as_ref(),
            &UInt32Array::from(vec![Some(15), None, Some(8)])
        );
    }

    Ok(())
}
4453
4454    #[test]
4455    fn bitwise_shift_scalar_test() -> Result<()> {
4456        let input = Arc::new(Int32Array::from(vec![Some(2), None, Some(4)])) as ArrayRef;
4457        let module = ScalarValue::from(10i32);
4458        let mut result =
4459            bitwise_shift_left_dyn_scalar(&input, module.clone()).unwrap()?;
4460
4461        let expected = Int32Array::from(vec![Some(2048), None, Some(4096)]);
4462        assert_eq!(result.as_ref(), &expected);
4463
4464        result = bitwise_shift_right_dyn_scalar(&result, module).unwrap()?;
4465        assert_eq!(result.as_ref(), &input);
4466
4467        let input = Arc::new(UInt32Array::from(vec![Some(2), None, Some(4)])) as ArrayRef;
4468        let module = ScalarValue::from(10u32);
4469        let mut result =
4470            bitwise_shift_left_dyn_scalar(&input, module.clone()).unwrap()?;
4471
4472        let expected = UInt32Array::from(vec![Some(2048), None, Some(4096)]);
4473        assert_eq!(result.as_ref(), &expected);
4474
4475        result = bitwise_shift_right_dyn_scalar(&result, module).unwrap()?;
4476        assert_eq!(result.as_ref(), &input);
4477        Ok(())
4478    }
4479
4480    #[test]
4481    fn test_display_and_or_combo() {
4482        let expr = BinaryExpr::new(
4483            Arc::new(BinaryExpr::new(
4484                lit(ScalarValue::from(1)),
4485                Operator::And,
4486                lit(ScalarValue::from(2)),
4487            )),
4488            Operator::And,
4489            Arc::new(BinaryExpr::new(
4490                lit(ScalarValue::from(3)),
4491                Operator::And,
4492                lit(ScalarValue::from(4)),
4493            )),
4494        );
4495        assert_eq!(expr.to_string(), "1 AND 2 AND 3 AND 4");
4496
4497        let expr = BinaryExpr::new(
4498            Arc::new(BinaryExpr::new(
4499                lit(ScalarValue::from(1)),
4500                Operator::Or,
4501                lit(ScalarValue::from(2)),
4502            )),
4503            Operator::Or,
4504            Arc::new(BinaryExpr::new(
4505                lit(ScalarValue::from(3)),
4506                Operator::Or,
4507                lit(ScalarValue::from(4)),
4508            )),
4509        );
4510        assert_eq!(expr.to_string(), "1 OR 2 OR 3 OR 4");
4511
4512        let expr = BinaryExpr::new(
4513            Arc::new(BinaryExpr::new(
4514                lit(ScalarValue::from(1)),
4515                Operator::And,
4516                lit(ScalarValue::from(2)),
4517            )),
4518            Operator::Or,
4519            Arc::new(BinaryExpr::new(
4520                lit(ScalarValue::from(3)),
4521                Operator::And,
4522                lit(ScalarValue::from(4)),
4523            )),
4524        );
4525        assert_eq!(expr.to_string(), "1 AND 2 OR 3 AND 4");
4526
4527        let expr = BinaryExpr::new(
4528            Arc::new(BinaryExpr::new(
4529                lit(ScalarValue::from(1)),
4530                Operator::Or,
4531                lit(ScalarValue::from(2)),
4532            )),
4533            Operator::And,
4534            Arc::new(BinaryExpr::new(
4535                lit(ScalarValue::from(3)),
4536                Operator::Or,
4537                lit(ScalarValue::from(4)),
4538            )),
4539        );
4540        assert_eq!(expr.to_string(), "(1 OR 2) AND (3 OR 4)");
4541    }
4542
4543    #[test]
4544    fn test_to_result_type_array() {
4545        let values = Arc::new(Int32Array::from(vec![1, 2, 3, 4]));
4546        let keys = Int8Array::from(vec![Some(0), None, Some(2), Some(3)]);
4547        let dictionary =
4548            Arc::new(DictionaryArray::try_new(keys, values).unwrap()) as ArrayRef;
4549
4550        // Casting Dictionary to Int32
4551        let casted = to_result_type_array(
4552            &Operator::Plus,
4553            Arc::clone(&dictionary),
4554            &DataType::Int32,
4555        )
4556        .unwrap();
4557        assert_eq!(
4558            &casted,
4559            &(Arc::new(Int32Array::from(vec![Some(1), None, Some(3), Some(4)]))
4560                as ArrayRef)
4561        );
4562
4563        // Array has same datatype as result type, no casting
4564        let casted = to_result_type_array(
4565            &Operator::Plus,
4566            Arc::clone(&dictionary),
4567            dictionary.data_type(),
4568        )
4569        .unwrap();
4570        assert_eq!(&casted, &dictionary);
4571
4572        // Not numerical operator, no casting
4573        let casted = to_result_type_array(
4574            &Operator::Eq,
4575            Arc::clone(&dictionary),
4576            &DataType::Int32,
4577        )
4578        .unwrap();
4579        assert_eq!(&casted, &dictionary);
4580    }
4581
4582    #[test]
4583    fn test_add_with_overflow() -> Result<()> {
4584        // create test data
4585        let l = Arc::new(Int32Array::from(vec![1, i32::MAX]));
4586        let r = Arc::new(Int32Array::from(vec![2, 1]));
4587        let schema = Arc::new(Schema::new(vec![
4588            Field::new("l", DataType::Int32, false),
4589            Field::new("r", DataType::Int32, false),
4590        ]));
4591        let batch = RecordBatch::try_new(schema, vec![l, r])?;
4592
4593        // create expression
4594        let expr = BinaryExpr::new(
4595            Arc::new(Column::new("l", 0)),
4596            Operator::Plus,
4597            Arc::new(Column::new("r", 1)),
4598        )
4599        .with_fail_on_overflow(true);
4600
4601        // evaluate expression
4602        let result = expr.evaluate(&batch);
4603        assert!(result
4604            .err()
4605            .unwrap()
4606            .to_string()
4607            .contains("Overflow happened on: 2147483647 + 1"));
4608        Ok(())
4609    }
4610
4611    #[test]
4612    fn test_subtract_with_overflow() -> Result<()> {
4613        // create test data
4614        let l = Arc::new(Int32Array::from(vec![1, i32::MIN]));
4615        let r = Arc::new(Int32Array::from(vec![2, 1]));
4616        let schema = Arc::new(Schema::new(vec![
4617            Field::new("l", DataType::Int32, false),
4618            Field::new("r", DataType::Int32, false),
4619        ]));
4620        let batch = RecordBatch::try_new(schema, vec![l, r])?;
4621
4622        // create expression
4623        let expr = BinaryExpr::new(
4624            Arc::new(Column::new("l", 0)),
4625            Operator::Minus,
4626            Arc::new(Column::new("r", 1)),
4627        )
4628        .with_fail_on_overflow(true);
4629
4630        // evaluate expression
4631        let result = expr.evaluate(&batch);
4632        assert!(result
4633            .err()
4634            .unwrap()
4635            .to_string()
4636            .contains("Overflow happened on: -2147483648 - 1"));
4637        Ok(())
4638    }
4639
4640    #[test]
4641    fn test_mul_with_overflow() -> Result<()> {
4642        // create test data
4643        let l = Arc::new(Int32Array::from(vec![1, i32::MAX]));
4644        let r = Arc::new(Int32Array::from(vec![2, 2]));
4645        let schema = Arc::new(Schema::new(vec![
4646            Field::new("l", DataType::Int32, false),
4647            Field::new("r", DataType::Int32, false),
4648        ]));
4649        let batch = RecordBatch::try_new(schema, vec![l, r])?;
4650
4651        // create expression
4652        let expr = BinaryExpr::new(
4653            Arc::new(Column::new("l", 0)),
4654            Operator::Multiply,
4655            Arc::new(Column::new("r", 1)),
4656        )
4657        .with_fail_on_overflow(true);
4658
4659        // evaluate expression
4660        let result = expr.evaluate(&batch);
4661        assert!(result
4662            .err()
4663            .unwrap()
4664            .to_string()
4665            .contains("Overflow happened on: 2147483647 * 2"));
4666        Ok(())
4667    }
4668
4669    /// Test helper for SIMILAR TO binary operation
4670    fn apply_similar_to(
4671        schema: &SchemaRef,
4672        va: Vec<&str>,
4673        vb: Vec<&str>,
4674        negated: bool,
4675        case_insensitive: bool,
4676        expected: &BooleanArray,
4677    ) -> Result<()> {
4678        let a = StringArray::from(va);
4679        let b = StringArray::from(vb);
4680        let op = similar_to(
4681            negated,
4682            case_insensitive,
4683            col("a", schema)?,
4684            col("b", schema)?,
4685        )?;
4686        let batch =
4687            RecordBatch::try_new(Arc::clone(schema), vec![Arc::new(a), Arc::new(b)])?;
4688        let result = op
4689            .evaluate(&batch)?
4690            .into_array(batch.num_rows())
4691            .expect("Failed to convert to array");
4692        assert_eq!(result.as_ref(), expected);
4693
4694        Ok(())
4695    }
4696
4697    #[test]
4698    fn test_similar_to() {
4699        let schema = Arc::new(Schema::new(vec![
4700            Field::new("a", DataType::Utf8, false),
4701            Field::new("b", DataType::Utf8, false),
4702        ]));
4703
4704        let expected = [Some(true), Some(false)].iter().collect();
4705        // case-sensitive
4706        apply_similar_to(
4707            &schema,
4708            vec!["hello world", "Hello World"],
4709            vec!["hello.*", "hello.*"],
4710            false,
4711            false,
4712            &expected,
4713        )
4714        .unwrap();
4715        // case-insensitive
4716        apply_similar_to(
4717            &schema,
4718            vec!["hello world", "bye"],
4719            vec!["hello.*", "hello.*"],
4720            false,
4721            true,
4722            &expected,
4723        )
4724        .unwrap();
4725    }
4726
4727    pub fn binary_expr(
4728        left: Arc<dyn PhysicalExpr>,
4729        op: Operator,
4730        right: Arc<dyn PhysicalExpr>,
4731        schema: &Schema,
4732    ) -> Result<BinaryExpr> {
4733        Ok(binary_op(left, op, right, schema)?
4734            .as_any()
4735            .downcast_ref::<BinaryExpr>()
4736            .unwrap()
4737            .clone())
4738    }
4739
4740    /// Test for Uniform-Uniform, Unknown-Uniform, Uniform-Unknown and Unknown-Unknown evaluation.
4741    #[test]
4742    fn test_evaluate_statistics_combination_of_range_holders() -> Result<()> {
4743        let schema = &Schema::new(vec![Field::new("a", DataType::Float64, false)]);
4744        let a = Arc::new(Column::new("a", 0)) as _;
4745        let b = lit(ScalarValue::from(12.0));
4746
4747        let left_interval = Interval::make(Some(0.0), Some(12.0))?;
4748        let right_interval = Interval::make(Some(12.0), Some(36.0))?;
4749        let (left_mean, right_mean) = (ScalarValue::from(6.0), ScalarValue::from(24.0));
4750        let (left_med, right_med) = (ScalarValue::from(6.0), ScalarValue::from(24.0));
4751
4752        for children in [
4753            vec![
4754                &Distribution::new_uniform(left_interval.clone())?,
4755                &Distribution::new_uniform(right_interval.clone())?,
4756            ],
4757            vec![
4758                &Distribution::new_generic(
4759                    left_mean.clone(),
4760                    left_med.clone(),
4761                    ScalarValue::Float64(None),
4762                    left_interval.clone(),
4763                )?,
4764                &Distribution::new_uniform(right_interval.clone())?,
4765            ],
4766            vec![
4767                &Distribution::new_uniform(right_interval.clone())?,
4768                &Distribution::new_generic(
4769                    right_mean.clone(),
4770                    right_med.clone(),
4771                    ScalarValue::Float64(None),
4772                    right_interval.clone(),
4773                )?,
4774            ],
4775            vec![
4776                &Distribution::new_generic(
4777                    left_mean.clone(),
4778                    left_med.clone(),
4779                    ScalarValue::Float64(None),
4780                    left_interval.clone(),
4781                )?,
4782                &Distribution::new_generic(
4783                    right_mean.clone(),
4784                    right_med.clone(),
4785                    ScalarValue::Float64(None),
4786                    right_interval.clone(),
4787                )?,
4788            ],
4789        ] {
4790            let ops = vec![
4791                Operator::Plus,
4792                Operator::Minus,
4793                Operator::Multiply,
4794                Operator::Divide,
4795            ];
4796
4797            for op in ops {
4798                let expr = binary_expr(Arc::clone(&a), op, Arc::clone(&b), schema)?;
4799                assert_eq!(
4800                    expr.evaluate_statistics(&children)?,
4801                    new_generic_from_binary_op(&op, children[0], children[1])?
4802                );
4803            }
4804        }
4805        Ok(())
4806    }
4807
4808    #[test]
4809    fn test_evaluate_statistics_bernoulli() -> Result<()> {
4810        let schema = &Schema::new(vec![
4811            Field::new("a", DataType::Int64, false),
4812            Field::new("b", DataType::Int64, false),
4813        ]);
4814        let a = Arc::new(Column::new("a", 0)) as _;
4815        let b = Arc::new(Column::new("b", 1)) as _;
4816        let eq = Arc::new(binary_expr(
4817            Arc::clone(&a),
4818            Operator::Eq,
4819            Arc::clone(&b),
4820            schema,
4821        )?);
4822        let neq = Arc::new(binary_expr(a, Operator::NotEq, b, schema)?);
4823
4824        let left_stat = &Distribution::new_uniform(Interval::make(Some(0), Some(7))?)?;
4825        let right_stat = &Distribution::new_uniform(Interval::make(Some(4), Some(11))?)?;
4826
4827        // Intervals: [0, 7], [4, 11].
4828        // The intersection is [4, 7], so the probability of equality is 4 / 64 = 1 / 16.
4829        assert_eq!(
4830            eq.evaluate_statistics(&[left_stat, right_stat])?,
4831            Distribution::new_bernoulli(ScalarValue::from(1.0 / 16.0))?
4832        );
4833
4834        // The probability of being distinct is 1 - 1 / 16 = 15 / 16.
4835        assert_eq!(
4836            neq.evaluate_statistics(&[left_stat, right_stat])?,
4837            Distribution::new_bernoulli(ScalarValue::from(15.0 / 16.0))?
4838        );
4839
4840        Ok(())
4841    }
4842
4843    #[test]
4844    fn test_propagate_statistics_combination_of_range_holders_arithmetic() -> Result<()> {
4845        let schema = &Schema::new(vec![Field::new("a", DataType::Float64, false)]);
4846        let a = Arc::new(Column::new("a", 0)) as _;
4847        let b = lit(ScalarValue::from(12.0));
4848
4849        let left_interval = Interval::make(Some(0.0), Some(12.0))?;
4850        let right_interval = Interval::make(Some(12.0), Some(36.0))?;
4851
4852        let parent = Distribution::new_uniform(Interval::make(Some(-432.), Some(432.))?)?;
4853        let children = vec![
4854            vec![
4855                Distribution::new_uniform(left_interval.clone())?,
4856                Distribution::new_uniform(right_interval.clone())?,
4857            ],
4858            vec![
4859                Distribution::new_generic(
4860                    ScalarValue::from(6.),
4861                    ScalarValue::from(6.),
4862                    ScalarValue::Float64(None),
4863                    left_interval.clone(),
4864                )?,
4865                Distribution::new_uniform(right_interval.clone())?,
4866            ],
4867            vec![
4868                Distribution::new_uniform(left_interval.clone())?,
4869                Distribution::new_generic(
4870                    ScalarValue::from(12.),
4871                    ScalarValue::from(12.),
4872                    ScalarValue::Float64(None),
4873                    right_interval.clone(),
4874                )?,
4875            ],
4876            vec![
4877                Distribution::new_generic(
4878                    ScalarValue::from(6.),
4879                    ScalarValue::from(6.),
4880                    ScalarValue::Float64(None),
4881                    left_interval.clone(),
4882                )?,
4883                Distribution::new_generic(
4884                    ScalarValue::from(12.),
4885                    ScalarValue::from(12.),
4886                    ScalarValue::Float64(None),
4887                    right_interval.clone(),
4888                )?,
4889            ],
4890        ];
4891
4892        let ops = vec![
4893            Operator::Plus,
4894            Operator::Minus,
4895            Operator::Multiply,
4896            Operator::Divide,
4897        ];
4898
4899        for child_view in children {
4900            let child_refs = child_view.iter().collect::<Vec<_>>();
4901            for op in &ops {
4902                let expr = binary_expr(Arc::clone(&a), *op, Arc::clone(&b), schema)?;
4903                assert_eq!(
4904                    expr.propagate_statistics(&parent, child_refs.as_slice())?,
4905                    Some(child_view.clone())
4906                );
4907            }
4908        }
4909        Ok(())
4910    }
4911
4912    #[test]
4913    fn test_propagate_statistics_combination_of_range_holders_comparison() -> Result<()> {
4914        let schema = &Schema::new(vec![Field::new("a", DataType::Float64, false)]);
4915        let a = Arc::new(Column::new("a", 0)) as _;
4916        let b = lit(ScalarValue::from(12.0));
4917
4918        let left_interval = Interval::make(Some(0.0), Some(12.0))?;
4919        let right_interval = Interval::make(Some(6.0), Some(18.0))?;
4920
4921        let one = ScalarValue::from(1.0);
4922        let parent = Distribution::new_bernoulli(one)?;
4923        let children = vec![
4924            vec![
4925                Distribution::new_uniform(left_interval.clone())?,
4926                Distribution::new_uniform(right_interval.clone())?,
4927            ],
4928            vec![
4929                Distribution::new_generic(
4930                    ScalarValue::from(6.),
4931                    ScalarValue::from(6.),
4932                    ScalarValue::Float64(None),
4933                    left_interval.clone(),
4934                )?,
4935                Distribution::new_uniform(right_interval.clone())?,
4936            ],
4937            vec![
4938                Distribution::new_uniform(left_interval.clone())?,
4939                Distribution::new_generic(
4940                    ScalarValue::from(12.),
4941                    ScalarValue::from(12.),
4942                    ScalarValue::Float64(None),
4943                    right_interval.clone(),
4944                )?,
4945            ],
4946            vec![
4947                Distribution::new_generic(
4948                    ScalarValue::from(6.),
4949                    ScalarValue::from(6.),
4950                    ScalarValue::Float64(None),
4951                    left_interval.clone(),
4952                )?,
4953                Distribution::new_generic(
4954                    ScalarValue::from(12.),
4955                    ScalarValue::from(12.),
4956                    ScalarValue::Float64(None),
4957                    right_interval.clone(),
4958                )?,
4959            ],
4960        ];
4961
4962        let ops = vec![
4963            Operator::Eq,
4964            Operator::Gt,
4965            Operator::GtEq,
4966            Operator::Lt,
4967            Operator::LtEq,
4968        ];
4969
4970        for child_view in children {
4971            let child_refs = child_view.iter().collect::<Vec<_>>();
4972            for op in &ops {
4973                let expr = binary_expr(Arc::clone(&a), *op, Arc::clone(&b), schema)?;
4974                assert!(expr
4975                    .propagate_statistics(&parent, child_refs.as_slice())?
4976                    .is_some());
4977            }
4978        }
4979
4980        Ok(())
4981    }
4982
4983    #[test]
4984    fn test_fmt_sql() -> Result<()> {
4985        let schema = Schema::new(vec![
4986            Field::new("a", DataType::Int32, false),
4987            Field::new("b", DataType::Int32, false),
4988        ]);
4989
4990        // Test basic binary expressions
4991        let simple_expr = binary_expr(
4992            col("a", &schema)?,
4993            Operator::Plus,
4994            col("b", &schema)?,
4995            &schema,
4996        )?;
4997        let display_string = simple_expr.to_string();
4998        assert_eq!(display_string, "a@0 + b@1");
4999        let sql_string = fmt_sql(&simple_expr).to_string();
5000        assert_eq!(sql_string, "a + b");
5001
5002        // Test nested expressions with different operator precedence
5003        let nested_expr = binary_expr(
5004            Arc::new(binary_expr(
5005                col("a", &schema)?,
5006                Operator::Plus,
5007                col("b", &schema)?,
5008                &schema,
5009            )?),
5010            Operator::Multiply,
5011            col("b", &schema)?,
5012            &schema,
5013        )?;
5014        let display_string = nested_expr.to_string();
5015        assert_eq!(display_string, "(a@0 + b@1) * b@1");
5016        let sql_string = fmt_sql(&nested_expr).to_string();
5017        assert_eq!(sql_string, "(a + b) * b");
5018
5019        // Test nested expressions with same operator precedence
5020        let nested_same_prec = binary_expr(
5021            Arc::new(binary_expr(
5022                col("a", &schema)?,
5023                Operator::Plus,
5024                col("b", &schema)?,
5025                &schema,
5026            )?),
5027            Operator::Plus,
5028            col("b", &schema)?,
5029            &schema,
5030        )?;
5031        let display_string = nested_same_prec.to_string();
5032        assert_eq!(display_string, "a@0 + b@1 + b@1");
5033        let sql_string = fmt_sql(&nested_same_prec).to_string();
5034        assert_eq!(sql_string, "a + b + b");
5035
5036        // Test with literals
5037        let lit_expr = binary_expr(
5038            col("a", &schema)?,
5039            Operator::Eq,
5040            lit(ScalarValue::Int32(Some(42))),
5041            &schema,
5042        )?;
5043        let display_string = lit_expr.to_string();
5044        assert_eq!(display_string, "a@0 = 42");
5045        let sql_string = fmt_sql(&lit_expr).to_string();
5046        assert_eq!(sql_string, "a = 42");
5047
5048        Ok(())
5049    }
5050
5051    #[test]
5052    fn test_check_short_circuit() {
5053        // Test with non-nullable arrays
5054        let schema = Arc::new(Schema::new(vec![
5055            Field::new("a", DataType::Int32, false),
5056            Field::new("b", DataType::Int32, false),
5057        ]));
5058        let a_array = Int32Array::from(vec![1, 3, 4, 5, 6]);
5059        let b_array = Int32Array::from(vec![1, 2, 3, 4, 5]);
5060        let batch = RecordBatch::try_new(
5061            Arc::clone(&schema),
5062            vec![Arc::new(a_array), Arc::new(b_array)],
5063        )
5064        .unwrap();
5065
5066        // op: AND left: all false
5067        let left_expr = logical2physical(&logical_col("a").eq(expr_lit(2)), &schema);
5068        let left_value = left_expr.evaluate(&batch).unwrap();
5069        assert!(matches!(
5070            check_short_circuit(&left_value, &Operator::And),
5071            ShortCircuitStrategy::ReturnLeft
5072        ));
5073
5074        // op: AND left: not all false
5075        let left_expr = logical2physical(&logical_col("a").eq(expr_lit(3)), &schema);
5076        let left_value = left_expr.evaluate(&batch).unwrap();
5077        let ColumnarValue::Array(array) = &left_value else {
5078            panic!("Expected ColumnarValue::Array");
5079        };
5080        let ShortCircuitStrategy::PreSelection(value) =
5081            check_short_circuit(&left_value, &Operator::And)
5082        else {
5083            panic!("Expected ShortCircuitStrategy::PreSelection");
5084        };
5085        let expected_boolean_arr: Vec<_> =
5086            as_boolean_array(array).unwrap().iter().collect();
5087        let boolean_arr: Vec<_> = value.iter().collect();
5088        assert_eq!(expected_boolean_arr, boolean_arr);
5089
5090        // op: OR left: all true
5091        let left_expr = logical2physical(&logical_col("a").gt(expr_lit(0)), &schema);
5092        let left_value = left_expr.evaluate(&batch).unwrap();
5093        assert!(matches!(
5094            check_short_circuit(&left_value, &Operator::Or),
5095            ShortCircuitStrategy::ReturnLeft
5096        ));
5097
5098        // op: OR left: not all true
5099        let left_expr: Arc<dyn PhysicalExpr> =
5100            logical2physical(&logical_col("a").gt(expr_lit(2)), &schema);
5101        let left_value = left_expr.evaluate(&batch).unwrap();
5102        assert!(matches!(
5103            check_short_circuit(&left_value, &Operator::Or),
5104            ShortCircuitStrategy::None
5105        ));
5106
5107        // Test with nullable arrays and null values
5108        let schema_nullable = Arc::new(Schema::new(vec![
5109            Field::new("c", DataType::Boolean, true),
5110            Field::new("d", DataType::Boolean, true),
5111        ]));
5112
5113        // Create arrays with null values
5114        let c_array = Arc::new(BooleanArray::from(vec![
5115            Some(true),
5116            Some(false),
5117            None,
5118            Some(true),
5119            None,
5120        ])) as ArrayRef;
5121        let d_array = Arc::new(BooleanArray::from(vec![
5122            Some(false),
5123            Some(true),
5124            Some(false),
5125            None,
5126            Some(true),
5127        ])) as ArrayRef;
5128
5129        let batch_nullable = RecordBatch::try_new(
5130            Arc::clone(&schema_nullable),
5131            vec![Arc::clone(&c_array), Arc::clone(&d_array)],
5132        )
5133        .unwrap();
5134
5135        // Case: Mixed values with nulls - shouldn't short-circuit for AND
5136        let mixed_nulls = logical2physical(&logical_col("c"), &schema_nullable);
5137        let mixed_nulls_value = mixed_nulls.evaluate(&batch_nullable).unwrap();
5138        assert!(matches!(
5139            check_short_circuit(&mixed_nulls_value, &Operator::And),
5140            ShortCircuitStrategy::None
5141        ));
5142
5143        // Case: Mixed values with nulls - shouldn't short-circuit for OR
5144        assert!(matches!(
5145            check_short_circuit(&mixed_nulls_value, &Operator::Or),
5146            ShortCircuitStrategy::None
5147        ));
5148
5149        // Test with all nulls
5150        let all_nulls = Arc::new(BooleanArray::from(vec![None, None, None])) as ArrayRef;
5151        let null_batch = RecordBatch::try_new(
5152            Arc::new(Schema::new(vec![Field::new("e", DataType::Boolean, true)])),
5153            vec![all_nulls],
5154        )
5155        .unwrap();
5156
5157        let null_expr = logical2physical(&logical_col("e"), &null_batch.schema());
5158        let null_value = null_expr.evaluate(&null_batch).unwrap();
5159
5160        // All nulls shouldn't short-circuit for AND or OR
5161        assert!(matches!(
5162            check_short_circuit(&null_value, &Operator::And),
5163            ShortCircuitStrategy::None
5164        ));
5165        assert!(matches!(
5166            check_short_circuit(&null_value, &Operator::Or),
5167            ShortCircuitStrategy::None
5168        ));
5169
5170        // Test with scalar values
5171        // Scalar true
5172        let scalar_true = ColumnarValue::Scalar(ScalarValue::Boolean(Some(true)));
5173        assert!(matches!(
5174            check_short_circuit(&scalar_true, &Operator::Or),
5175            ShortCircuitStrategy::ReturnLeft
5176        )); // Should short-circuit OR
5177        assert!(matches!(
5178            check_short_circuit(&scalar_true, &Operator::And),
5179            ShortCircuitStrategy::ReturnRight
5180        )); // Should return the RHS for AND
5181
5182        // Scalar false
5183        let scalar_false = ColumnarValue::Scalar(ScalarValue::Boolean(Some(false)));
5184        assert!(matches!(
5185            check_short_circuit(&scalar_false, &Operator::And),
5186            ShortCircuitStrategy::ReturnLeft
5187        )); // Should short-circuit AND
5188        assert!(matches!(
5189            check_short_circuit(&scalar_false, &Operator::Or),
5190            ShortCircuitStrategy::ReturnRight
5191        )); // Should return the RHS for OR
5192
5193        // Scalar null
5194        let scalar_null = ColumnarValue::Scalar(ScalarValue::Boolean(None));
5195        assert!(matches!(
5196            check_short_circuit(&scalar_null, &Operator::And),
5197            ShortCircuitStrategy::None
5198        ));
5199        assert!(matches!(
5200            check_short_circuit(&scalar_null, &Operator::Or),
5201            ShortCircuitStrategy::None
5202        ));
5203    }
5204
5205    /// Test for [pre_selection_scatter]
5206    /// Since [check_short_circuit] ensures that the left side does not contain null and is neither all_true nor all_false, as well as not being empty,
5207    /// the following tests have been designed:
5208    /// 1. Test sparse left with interleaved true/false
5209    /// 2. Test multiple consecutive true blocks
    /// 3. Test single true at first position
    /// 4. Test single true at last position
    /// 5. Test nulls in right array
    /// 6. Test scalar right handling
5215    #[test]
5216    fn test_pre_selection_scatter() {
5217        fn create_bool_array(bools: Vec<bool>) -> BooleanArray {
5218            BooleanArray::from(bools.into_iter().map(Some).collect::<Vec<_>>())
5219        }
5220        // Test sparse left with interleaved true/false
5221        {
5222            // Left: [T, F, T, F, T]
5223            // Right: [F, T, F] (values for 3 true positions)
5224            let left = create_bool_array(vec![true, false, true, false, true]);
5225            let right = ColumnarValue::Array(Arc::new(create_bool_array(vec![
5226                false, true, false,
5227            ])));
5228
5229            let result = pre_selection_scatter(&left, right).unwrap();
5230            let result_arr = result.into_array(left.len()).unwrap();
5231
5232            let expected = create_bool_array(vec![false, false, true, false, false]);
5233            assert_eq!(&expected, result_arr.as_boolean());
5234        }
5235        // Test multiple consecutive true blocks
5236        {
5237            // Left: [F, T, T, F, T, T, T]
5238            // Right: [T, F, F, T, F]
5239            let left =
5240                create_bool_array(vec![false, true, true, false, true, true, true]);
5241            let right = ColumnarValue::Array(Arc::new(create_bool_array(vec![
5242                true, false, false, true, false,
5243            ])));
5244
5245            let result = pre_selection_scatter(&left, right).unwrap();
5246            let result_arr = result.into_array(left.len()).unwrap();
5247
5248            let expected =
5249                create_bool_array(vec![false, true, false, false, false, true, false]);
5250            assert_eq!(&expected, result_arr.as_boolean());
5251        }
5252        // Test single true at first position
5253        {
5254            // Left: [T, F, F]
5255            // Right: [F]
5256            let left = create_bool_array(vec![true, false, false]);
5257            let right = ColumnarValue::Array(Arc::new(create_bool_array(vec![false])));
5258
5259            let result = pre_selection_scatter(&left, right).unwrap();
5260            let result_arr = result.into_array(left.len()).unwrap();
5261
5262            let expected = create_bool_array(vec![false, false, false]);
5263            assert_eq!(&expected, result_arr.as_boolean());
5264        }
5265        // Test single true at last position
5266        {
5267            // Left: [F, F, T]
5268            // Right: [F]
5269            let left = create_bool_array(vec![false, false, true]);
5270            let right = ColumnarValue::Array(Arc::new(create_bool_array(vec![false])));
5271
5272            let result = pre_selection_scatter(&left, right).unwrap();
5273            let result_arr = result.into_array(left.len()).unwrap();
5274
5275            let expected = create_bool_array(vec![false, false, false]);
5276            assert_eq!(&expected, result_arr.as_boolean());
5277        }
5278        // Test nulls in right array
5279        {
5280            // Left: [F, T, F, T]
5281            // Right: [None, Some(false)] (with null at first position)
5282            let left = create_bool_array(vec![false, true, false, true]);
5283            let right_arr = BooleanArray::from(vec![None, Some(false)]);
5284            let right = ColumnarValue::Array(Arc::new(right_arr));
5285
5286            let result = pre_selection_scatter(&left, right).unwrap();
5287            let result_arr = result.into_array(left.len()).unwrap();
5288
5289            let expected = BooleanArray::from(vec![
5290                Some(false),
5291                None, // null from right
5292                Some(false),
5293                Some(false),
5294            ]);
5295            assert_eq!(&expected, result_arr.as_boolean());
5296        }
5297        // Test scalar right handling
5298        {
5299            // Left: [T, F, T]
5300            // Right: Scalar true
5301            let left = create_bool_array(vec![true, false, true]);
5302            let right = ColumnarValue::Scalar(ScalarValue::Boolean(Some(true)));
5303
5304            let result = pre_selection_scatter(&left, right).unwrap();
5305            assert!(matches!(result, ColumnarValue::Scalar(_)));
5306        }
5307    }
5308
5309    #[test]
5310    fn test_evaluate_bounds_int32() {
5311        let schema = Schema::new(vec![
5312            Field::new("a", DataType::Int32, false),
5313            Field::new("b", DataType::Int32, false),
5314        ]);
5315
5316        let a = Arc::new(Column::new("a", 0)) as _;
5317        let b = Arc::new(Column::new("b", 1)) as _;
5318
5319        // Test addition bounds
5320        let add_expr =
5321            binary_expr(Arc::clone(&a), Operator::Plus, Arc::clone(&b), &schema).unwrap();
5322        let add_bounds = add_expr
5323            .evaluate_bounds(&[
5324                &Interval::make(Some(1), Some(10)).unwrap(),
5325                &Interval::make(Some(5), Some(15)).unwrap(),
5326            ])
5327            .unwrap();
5328        assert_eq!(add_bounds, Interval::make(Some(6), Some(25)).unwrap());
5329
5330        // Test subtraction bounds
5331        let sub_expr =
5332            binary_expr(Arc::clone(&a), Operator::Minus, Arc::clone(&b), &schema)
5333                .unwrap();
5334        let sub_bounds = sub_expr
5335            .evaluate_bounds(&[
5336                &Interval::make(Some(1), Some(10)).unwrap(),
5337                &Interval::make(Some(5), Some(15)).unwrap(),
5338            ])
5339            .unwrap();
5340        assert_eq!(sub_bounds, Interval::make(Some(-14), Some(5)).unwrap());
5341
5342        // Test multiplication bounds
5343        let mul_expr =
5344            binary_expr(Arc::clone(&a), Operator::Multiply, Arc::clone(&b), &schema)
5345                .unwrap();
5346        let mul_bounds = mul_expr
5347            .evaluate_bounds(&[
5348                &Interval::make(Some(1), Some(10)).unwrap(),
5349                &Interval::make(Some(5), Some(15)).unwrap(),
5350            ])
5351            .unwrap();
5352        assert_eq!(mul_bounds, Interval::make(Some(5), Some(150)).unwrap());
5353
5354        // Test division bounds
5355        let div_expr =
5356            binary_expr(Arc::clone(&a), Operator::Divide, Arc::clone(&b), &schema)
5357                .unwrap();
5358        let div_bounds = div_expr
5359            .evaluate_bounds(&[
5360                &Interval::make(Some(10), Some(20)).unwrap(),
5361                &Interval::make(Some(2), Some(5)).unwrap(),
5362            ])
5363            .unwrap();
5364        assert_eq!(div_bounds, Interval::make(Some(2), Some(10)).unwrap());
5365    }
5366
5367    #[test]
5368    fn test_evaluate_bounds_bool() {
5369        let schema = Schema::new(vec![
5370            Field::new("a", DataType::Boolean, false),
5371            Field::new("b", DataType::Boolean, false),
5372        ]);
5373
5374        let a = Arc::new(Column::new("a", 0)) as _;
5375        let b = Arc::new(Column::new("b", 1)) as _;
5376
5377        // Test OR bounds
5378        let or_expr =
5379            binary_expr(Arc::clone(&a), Operator::Or, Arc::clone(&b), &schema).unwrap();
5380        let or_bounds = or_expr
5381            .evaluate_bounds(&[
5382                &Interval::make(Some(true), Some(true)).unwrap(),
5383                &Interval::make(Some(false), Some(false)).unwrap(),
5384            ])
5385            .unwrap();
5386        assert_eq!(or_bounds, Interval::make(Some(true), Some(true)).unwrap());
5387
5388        // Test AND bounds
5389        let and_expr =
5390            binary_expr(Arc::clone(&a), Operator::And, Arc::clone(&b), &schema).unwrap();
5391        let and_bounds = and_expr
5392            .evaluate_bounds(&[
5393                &Interval::make(Some(true), Some(true)).unwrap(),
5394                &Interval::make(Some(false), Some(false)).unwrap(),
5395            ])
5396            .unwrap();
5397        assert_eq!(
5398            and_bounds,
5399            Interval::make(Some(false), Some(false)).unwrap()
5400        );
5401    }
5402
5403    #[test]
5404    fn test_evaluate_nested_type() {
5405        let batch_schema = Arc::new(Schema::new(vec![
5406            Field::new(
5407                "a",
5408                DataType::List(Arc::new(Field::new_list_field(DataType::Int32, true))),
5409                true,
5410            ),
5411            Field::new(
5412                "b",
5413                DataType::List(Arc::new(Field::new_list_field(DataType::Int32, true))),
5414                true,
5415            ),
5416        ]));
5417
5418        let mut list_builder_a = ListBuilder::new(Int32Builder::new());
5419
5420        list_builder_a.append_value([Some(1)]);
5421        list_builder_a.append_value([Some(2)]);
5422        list_builder_a.append_value([]);
5423        list_builder_a.append_value([None]);
5424
5425        let list_array_a: ArrayRef = Arc::new(list_builder_a.finish());
5426
5427        let mut list_builder_b = ListBuilder::new(Int32Builder::new());
5428
5429        list_builder_b.append_value([Some(1)]);
5430        list_builder_b.append_value([Some(2)]);
5431        list_builder_b.append_value([]);
5432        list_builder_b.append_value([None]);
5433
5434        let list_array_b: ArrayRef = Arc::new(list_builder_b.finish());
5435
5436        let batch =
5437            RecordBatch::try_new(batch_schema, vec![list_array_a, list_array_b]).unwrap();
5438
5439        let schema = Arc::new(Schema::new(vec![
5440            Field::new(
5441                "a",
5442                DataType::List(Arc::new(Field::new("foo", DataType::Int32, true))),
5443                true,
5444            ),
5445            Field::new(
5446                "b",
5447                DataType::List(Arc::new(Field::new("bar", DataType::Int32, true))),
5448                true,
5449            ),
5450        ]));
5451
5452        let a = Arc::new(Column::new("a", 0)) as _;
5453        let b = Arc::new(Column::new("b", 1)) as _;
5454
5455        let eq_expr =
5456            binary_expr(Arc::clone(&a), Operator::Eq, Arc::clone(&b), &schema).unwrap();
5457
5458        let eq_result = eq_expr.evaluate(&batch).unwrap();
5459        let expected =
5460            BooleanArray::from_iter(vec![Some(true), Some(true), Some(true), Some(true)]);
5461        assert_eq!(eq_result.into_array(4).unwrap().as_boolean(), &expected);
5462    }
5463}