Skip to main content

datafusion_physical_expr/expressions/
binary.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18mod kernels;
19
20use crate::PhysicalExpr;
21use crate::intervals::cp_solver::{propagate_arithmetic, propagate_comparison};
22use std::hash::Hash;
23use std::{any::Any, sync::Arc};
24
25use arrow::array::*;
26use arrow::compute::kernels::boolean::{and_kleene, or_kleene};
27use arrow::compute::kernels::concat_elements::concat_elements_utf8;
28use arrow::compute::{SlicesIterator, cast, filter_record_batch};
29use arrow::datatypes::*;
30use arrow::error::ArrowError;
31use datafusion_common::cast::as_boolean_array;
32use datafusion_common::{Result, ScalarValue, internal_err, not_impl_err};
33
34use datafusion_expr::binary::BinaryTypeCoercer;
35use datafusion_expr::interval_arithmetic::{Interval, apply_operator};
36use datafusion_expr::sort_properties::ExprProperties;
37use datafusion_expr::statistics::Distribution::{Bernoulli, Gaussian};
38use datafusion_expr::statistics::{
39    Distribution, combine_bernoullis, combine_gaussians,
40    create_bernoulli_from_comparison, new_generic_from_binary_op,
41};
42use datafusion_expr::{ColumnarValue, Operator};
43use datafusion_physical_expr_common::datum::{apply, apply_cmp};
44
45use kernels::{
46    bitwise_and_dyn, bitwise_and_dyn_scalar, bitwise_or_dyn, bitwise_or_dyn_scalar,
47    bitwise_shift_left_dyn, bitwise_shift_left_dyn_scalar, bitwise_shift_right_dyn,
48    bitwise_shift_right_dyn_scalar, bitwise_xor_dyn, bitwise_xor_dyn_scalar,
49    concat_elements_utf8view, regex_match_dyn, regex_match_dyn_scalar,
50};
51
/// Binary expression
///
/// Physical expression evaluating `left op right` for a binary [`Operator`].
/// `PartialEq` and `Hash` are implemented by hand for this type rather than derived.
#[derive(Debug, Clone, Eq)]
pub struct BinaryExpr {
    /// Left-hand operand.
    left: Arc<dyn PhysicalExpr>,
    /// Operator applied to the two operands.
    op: Operator,
    /// Right-hand operand.
    right: Arc<dyn PhysicalExpr>,
    /// Specifies whether an error is returned on overflow or not
    /// (when `false`, `+`, `-` and `*` are evaluated with wrapping kernels).
    fail_on_overflow: bool,
}
61
62// Manually derive PartialEq and Hash to work around https://github.com/rust-lang/rust/issues/78808
63impl PartialEq for BinaryExpr {
64    fn eq(&self, other: &Self) -> bool {
65        self.left.eq(&other.left)
66            && self.op.eq(&other.op)
67            && self.right.eq(&other.right)
68            && self.fail_on_overflow.eq(&other.fail_on_overflow)
69    }
70}
71impl Hash for BinaryExpr {
72    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
73        self.left.hash(state);
74        self.op.hash(state);
75        self.right.hash(state);
76        self.fail_on_overflow.hash(state);
77    }
78}
79
80impl BinaryExpr {
81    /// Create new binary expression
82    pub fn new(
83        left: Arc<dyn PhysicalExpr>,
84        op: Operator,
85        right: Arc<dyn PhysicalExpr>,
86    ) -> Self {
87        Self {
88            left,
89            op,
90            right,
91            fail_on_overflow: false,
92        }
93    }
94
95    /// Create new binary expression with explicit fail_on_overflow value
96    pub fn with_fail_on_overflow(self, fail_on_overflow: bool) -> Self {
97        Self {
98            left: self.left,
99            op: self.op,
100            right: self.right,
101            fail_on_overflow,
102        }
103    }
104
105    /// Get the left side of the binary expression
106    pub fn left(&self) -> &Arc<dyn PhysicalExpr> {
107        &self.left
108    }
109
110    /// Get the right side of the binary expression
111    pub fn right(&self) -> &Arc<dyn PhysicalExpr> {
112        &self.right
113    }
114
115    /// Get the operator for this binary expression
116    pub fn op(&self) -> &Operator {
117        &self.op
118    }
119}
120
121impl std::fmt::Display for BinaryExpr {
122    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
123        // Put parentheses around child binary expressions so that we can see the difference
124        // between `(a OR b) AND c` and `a OR (b AND c)`. We only insert parentheses when needed,
125        // based on operator precedence. For example, `(a AND b) OR c` and `a AND b OR c` are
126        // equivalent and the parentheses are not necessary.
127
128        fn write_child(
129            f: &mut std::fmt::Formatter,
130            expr: &dyn PhysicalExpr,
131            precedence: u8,
132        ) -> std::fmt::Result {
133            if let Some(child) = expr.as_any().downcast_ref::<BinaryExpr>() {
134                let p = child.op.precedence();
135                if p == 0 || p < precedence {
136                    write!(f, "({child})")?;
137                } else {
138                    write!(f, "{child}")?;
139                }
140            } else {
141                write!(f, "{expr}")?;
142            }
143
144            Ok(())
145        }
146
147        let precedence = self.op.precedence();
148        write_child(f, self.left.as_ref(), precedence)?;
149        write!(f, " {} ", self.op)?;
150        write_child(f, self.right.as_ref(), precedence)
151    }
152}
153
154/// Invoke a boolean kernel on a pair of arrays
155#[inline]
156fn boolean_op(
157    left: &dyn Array,
158    right: &dyn Array,
159    op: impl FnOnce(&BooleanArray, &BooleanArray) -> Result<BooleanArray, ArrowError>,
160) -> Result<Arc<dyn Array + 'static>, ArrowError> {
161    let ll = as_boolean_array(left).expect("boolean_op failed to downcast left array");
162    let rr = as_boolean_array(right).expect("boolean_op failed to downcast right array");
163    op(ll, rr).map(|t| Arc::new(t) as _)
164}
165
166/// Returns true if both operands are Date types (Date32 or Date64)
167/// Used to detect Date - Date operations which should return Int64 (days difference)
168fn is_date_minus_date(lhs: &DataType, rhs: &DataType) -> bool {
169    matches!(
170        (lhs, rhs),
171        (DataType::Date32, DataType::Date32) | (DataType::Date64, DataType::Date64)
172    )
173}
174
175/// Computes the difference between two dates and returns the result as Int64 (days)
176/// This aligns with PostgreSQL, DuckDB, and MySQL behavior where date - date returns an integer
177///
178/// Implementation: Uses Arrow's sub_wrapping to get Duration, then converts to Int64 days
179fn apply_date_subtraction(
180    lhs: &ColumnarValue,
181    rhs: &ColumnarValue,
182) -> Result<ColumnarValue> {
183    use arrow::compute::kernels::numeric::sub_wrapping;
184
185    // Use Arrow's sub_wrapping to compute the Duration result
186    let duration_result = apply(lhs, rhs, sub_wrapping)?;
187
188    // Convert Duration to Int64 (days)
189    match duration_result {
190        ColumnarValue::Array(array) => {
191            let int64_array = duration_to_days(&array)?;
192            Ok(ColumnarValue::Array(int64_array))
193        }
194        ColumnarValue::Scalar(scalar) => {
195            // Convert scalar Duration to Int64 days
196            let array = scalar.to_array_of_size(1)?;
197            let int64_array = duration_to_days(&array)?;
198            let int64_scalar = ScalarValue::try_from_array(int64_array.as_ref(), 0)?;
199            Ok(ColumnarValue::Scalar(int64_scalar))
200        }
201    }
202}
203
204/// Converts a Duration array to Int64 days
205/// Handles different Duration time units (Second, Millisecond, Microsecond, Nanosecond)
206fn duration_to_days(array: &ArrayRef) -> Result<ArrayRef> {
207    use datafusion_common::cast::{
208        as_duration_microsecond_array, as_duration_millisecond_array,
209        as_duration_nanosecond_array, as_duration_second_array,
210    };
211
212    const SECONDS_PER_DAY: i64 = 86_400;
213    const MILLIS_PER_DAY: i64 = 86_400_000;
214    const MICROS_PER_DAY: i64 = 86_400_000_000;
215    const NANOS_PER_DAY: i64 = 86_400_000_000_000;
216
217    match array.data_type() {
218        DataType::Duration(TimeUnit::Second) => {
219            let duration_array = as_duration_second_array(array)?;
220            let result: Int64Array = duration_array
221                .iter()
222                .map(|v| v.map(|val| val / SECONDS_PER_DAY))
223                .collect();
224            Ok(Arc::new(result))
225        }
226        DataType::Duration(TimeUnit::Millisecond) => {
227            let duration_array = as_duration_millisecond_array(array)?;
228            let result: Int64Array = duration_array
229                .iter()
230                .map(|v| v.map(|val| val / MILLIS_PER_DAY))
231                .collect();
232            Ok(Arc::new(result))
233        }
234        DataType::Duration(TimeUnit::Microsecond) => {
235            let duration_array = as_duration_microsecond_array(array)?;
236            let result: Int64Array = duration_array
237                .iter()
238                .map(|v| v.map(|val| val / MICROS_PER_DAY))
239                .collect();
240            Ok(Arc::new(result))
241        }
242        DataType::Duration(TimeUnit::Nanosecond) => {
243            let duration_array = as_duration_nanosecond_array(array)?;
244            let result: Int64Array = duration_array
245                .iter()
246                .map(|v| v.map(|val| val / NANOS_PER_DAY))
247                .collect();
248            Ok(Arc::new(result))
249        }
250        other => internal_err!("duration_to_days expected Duration type, got: {}", other),
251    }
252}
253
impl PhysicalExpr for BinaryExpr {
    /// Return a reference to Any that can be used for downcasting
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Computes the result type by running [`BinaryTypeCoercer`] on the two
    /// operand types and this operator.
    fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
        BinaryTypeCoercer::new(
            &self.left.data_type(input_schema)?,
            &self.op,
            &self.right.data_type(input_schema)?,
        )
        .get_result_type()
    }

    /// Reports the result as nullable when either operand is nullable.
    fn nullable(&self, input_schema: &Schema) -> Result<bool> {
        Ok(self.left.nullable(input_schema)? || self.right.nullable(input_schema)?)
    }

    /// Evaluates `left op right` against `batch`.
    ///
    /// 1. The left side is evaluated first. For `AND`/`OR`, `check_short_circuit`
    ///    may return the left side directly, return the right side, or evaluate the
    ///    right side only on a pre-selected subset of rows.
    /// 2. `+`, `-` and `*` go to the Arrow numeric kernels via `apply`: checked
    ///    kernels when `fail_on_overflow` is set, wrapping ones otherwise.
    ///    `Date - Date` is special-cased to return a day count. `/` and `%` always use
    ///    `div`/`rem`. Comparison and LIKE operators go to `apply_cmp`.
    /// 3. Any other operator first tries an array/scalar kernel (when the right side
    ///    is a non-null scalar). Otherwise both sides are materialized as arrays and
    ///    handed to `evaluate_with_resolved_args`.
    fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
        use arrow::compute::kernels::numeric::*;

        // Evaluate left-hand side expression.
        let lhs = self.left.evaluate(batch)?;

        // Check if we can apply short-circuit evaluation.
        match check_short_circuit(&lhs, &self.op) {
            ShortCircuitStrategy::None => {}
            ShortCircuitStrategy::ReturnLeft => return Ok(lhs),
            ShortCircuitStrategy::ReturnRight => {
                let rhs = self.right.evaluate(batch)?;
                return Ok(rhs);
            }
            ShortCircuitStrategy::PreSelection(selection) => {
                // The function `evaluate_selection` was not called for filtering and calculation,
                // as it takes into account cases where the selection contains null values.
                let batch = filter_record_batch(batch, selection)?;
                let right_ret = self.right.evaluate(&batch)?;

                match &right_ret {
                    ColumnarValue::Array(array) => {
                        // When the array on the right is all true or all false, skip the scatter process
                        let boolean_array = array.as_boolean();
                        let true_count = boolean_array.true_count();
                        let length = boolean_array.len();
                        if true_count == length {
                            return Ok(lhs);
                        } else if true_count == 0 && boolean_array.null_count() == 0 {
                            // Returning the right-hand array here would be inconsistent, because it only
                            // covers the filtered rows; returning a scalar avoids this issue.
                            return Ok(ColumnarValue::Scalar(ScalarValue::Boolean(
                                Some(false),
                            )));
                        }

                        return pre_selection_scatter(selection, Some(boolean_array));
                    }
                    ColumnarValue::Scalar(scalar) => {
                        if let ScalarValue::Boolean(v) = scalar {
                            // When the scalar is true or false, skip the scatter process
                            if let Some(v) = v {
                                if *v {
                                    return Ok(lhs);
                                } else {
                                    return Ok(right_ret);
                                }
                            } else {
                                return pre_selection_scatter(selection, None);
                            }
                        } else {
                            return internal_err!(
                                "Expected boolean scalar value, found: {right_ret:?}"
                            );
                        }
                    }
                }
            }
        }

        let rhs = self.right.evaluate(batch)?;
        let left_data_type = lhs.data_type();
        let right_data_type = rhs.data_type();

        let schema = batch.schema();
        let input_schema = schema.as_ref();

        // Arm order matters: the Date - Date guard must precede the overflow-mode
        // arms for `Minus`, otherwise dates would fall through to plain subtraction.
        match self.op {
            Operator::Plus if self.fail_on_overflow => return apply(&lhs, &rhs, add),
            Operator::Plus => return apply(&lhs, &rhs, add_wrapping),
            // Special case: Date - Date returns Int64 (days difference)
            // This aligns with PostgreSQL, DuckDB, and MySQL behavior
            Operator::Minus if is_date_minus_date(&left_data_type, &right_data_type) => {
                return apply_date_subtraction(&lhs, &rhs);
            }
            Operator::Minus if self.fail_on_overflow => return apply(&lhs, &rhs, sub),
            Operator::Minus => return apply(&lhs, &rhs, sub_wrapping),
            Operator::Multiply if self.fail_on_overflow => return apply(&lhs, &rhs, mul),
            Operator::Multiply => return apply(&lhs, &rhs, mul_wrapping),
            Operator::Divide => return apply(&lhs, &rhs, div),
            Operator::Modulo => return apply(&lhs, &rhs, rem),

            Operator::Eq
            | Operator::NotEq
            | Operator::Lt
            | Operator::Gt
            | Operator::LtEq
            | Operator::GtEq
            | Operator::IsDistinctFrom
            | Operator::IsNotDistinctFrom
            | Operator::LikeMatch
            | Operator::ILikeMatch
            | Operator::NotLikeMatch
            | Operator::NotILikeMatch => {
                return apply_cmp(self.op, &lhs, &rhs);
            }
            _ => {}
        }

        let result_type = self.data_type(input_schema)?;

        // If the left-hand side is an array and the right-hand side is a non-null scalar, try the optimized kernel.
        if let (ColumnarValue::Array(array), ColumnarValue::Scalar(scalar)) = (&lhs, &rhs)
            && !scalar.is_null()
            && let Some(result_array) =
                self.evaluate_array_scalar(array, scalar.clone())?
        {
            let final_array = result_array
                .and_then(|a| to_result_type_array(&self.op, a, &result_type));
            return final_array.map(ColumnarValue::Array);
        }

        // if both arrays or both literals - extract arrays and continue execution
        let (left, right) = (
            lhs.into_array(batch.num_rows())?,
            rhs.into_array(batch.num_rows())?,
        );
        self.evaluate_with_resolved_args(left, &left_data_type, right, &right_data_type)
            .map(ColumnarValue::Array)
    }

    /// Returns the operands in `[left, right]` order.
    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
        vec![&self.left, &self.right]
    }

    /// Rebuilds the expression over new operands, keeping the operator and the
    /// overflow mode. `children` is expected to be `[left, right]`; indexing
    /// panics if fewer than two children are supplied.
    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn PhysicalExpr>>,
    ) -> Result<Arc<dyn PhysicalExpr>> {
        Ok(Arc::new(
            BinaryExpr::new(Arc::clone(&children[0]), self.op, Arc::clone(&children[1]))
                .with_fail_on_overflow(self.fail_on_overflow),
        ))
    }

    /// Computes this node's output interval from the operand intervals
    /// (`children` = `[left, right]`) via `apply_operator`.
    fn evaluate_bounds(&self, children: &[&Interval]) -> Result<Interval> {
        // Get children intervals:
        let left_interval = children[0];
        let right_interval = children[1];
        // Calculate current node's interval:
        apply_operator(&self.op, left_interval, right_interval)
    }

    /// Refines the operand intervals given that this node's result lies in `interval`.
    ///
    /// For `AND`/`OR`: returns `Ok(None)` when the constraint is proven infeasible,
    /// and `Ok(Some(vec![]))` when no refinement is possible. Otherwise it returns
    /// the refined `[left, right]` intervals. Comparison and arithmetic operators
    /// delegate to `propagate_comparison` / `propagate_arithmetic`, whose `Option`
    /// result is forwarded as-is.
    fn propagate_constraints(
        &self,
        interval: &Interval,
        children: &[&Interval],
    ) -> Result<Option<Vec<Interval>>> {
        // Get children intervals.
        let left_interval = children[0];
        let right_interval = children[1];

        if self.op.eq(&Operator::And) {
            if interval.eq(&Interval::TRUE) {
                // A certainly true logical conjunction can only derive from possibly
                // true operands. Otherwise, we prove infeasibility.
                Ok((!left_interval.eq(&Interval::FALSE)
                    && !right_interval.eq(&Interval::FALSE))
                .then(|| vec![Interval::TRUE, Interval::TRUE]))
            } else if interval.eq(&Interval::FALSE) {
                // If the logical conjunction is certainly false, one of the
                // operands must be false. However, it's not always possible to
                // determine which operand is false, leading to different scenarios.

                // If one operand is certainly true and the other one is uncertain,
                // then the latter must be certainly false.
                if left_interval.eq(&Interval::TRUE)
                    && right_interval.eq(&Interval::TRUE_OR_FALSE)
                {
                    Ok(Some(vec![Interval::TRUE, Interval::FALSE]))
                } else if right_interval.eq(&Interval::TRUE)
                    && left_interval.eq(&Interval::TRUE_OR_FALSE)
                {
                    Ok(Some(vec![Interval::FALSE, Interval::TRUE]))
                }
                // If both children are uncertain, or if one is certainly false,
                // we cannot conclusively refine their intervals. In this case,
                // propagation does not result in any interval changes.
                else {
                    Ok(Some(vec![]))
                }
            } else {
                // An uncertain logical conjunction result can not shrink the
                // end-points of its children.
                Ok(Some(vec![]))
            }
        } else if self.op.eq(&Operator::Or) {
            if interval.eq(&Interval::FALSE) {
                // A certainly false logical disjunction can only derive from certainly
                // false operands. Otherwise, we prove infeasibility.
                Ok((!left_interval.eq(&Interval::TRUE)
                    && !right_interval.eq(&Interval::TRUE))
                .then(|| vec![Interval::FALSE, Interval::FALSE]))
            } else if interval.eq(&Interval::TRUE) {
                // If the logical disjunction is certainly true, one of the
                // operands must be true. However, it's not always possible to
                // determine which operand is true, leading to different scenarios.

                // If one operand is certainly false and the other one is uncertain,
                // then the latter must be certainly true.
                if left_interval.eq(&Interval::FALSE)
                    && right_interval.eq(&Interval::TRUE_OR_FALSE)
                {
                    Ok(Some(vec![Interval::FALSE, Interval::TRUE]))
                } else if right_interval.eq(&Interval::FALSE)
                    && left_interval.eq(&Interval::TRUE_OR_FALSE)
                {
                    Ok(Some(vec![Interval::TRUE, Interval::FALSE]))
                }
                // If both children are uncertain, or if one is certainly true,
                // we cannot conclusively refine their intervals. In this case,
                // propagation does not result in any interval changes.
                else {
                    Ok(Some(vec![]))
                }
            } else {
                // An uncertain logical disjunction result can not shrink the
                // end-points of its children.
                Ok(Some(vec![]))
            }
        } else if self.op.supports_propagation() {
            Ok(
                propagate_comparison(&self.op, interval, left_interval, right_interval)?
                    .map(|(left, right)| vec![left, right]),
            )
        } else {
            Ok(
                propagate_arithmetic(&self.op, interval, left_interval, right_interval)?
                    .map(|(left, right)| vec![left, right]),
            )
        }
    }

    /// Derives the output distribution from the operand distributions
    /// (`children` = `[left, right]`).
    ///
    /// - Numerical operators combine two Gaussians when `combine_gaussians` can.
    /// - Logical operators require Bernoulli inputs and return an internal error otherwise.
    /// - Comparisons produce a Bernoulli distribution.
    /// - Everything else, and any case the above cannot handle, falls back to a
    ///   generic distribution.
    fn evaluate_statistics(&self, children: &[&Distribution]) -> Result<Distribution> {
        let (left, right) = (children[0], children[1]);

        if self.op.is_numerical_operators() {
            // We might be able to construct the output statistics more accurately,
            // without falling back to an unknown distribution, if we are dealing
            // with Gaussian distributions and numerical operations.
            if let (Gaussian(left), Gaussian(right)) = (left, right)
                && let Some(result) = combine_gaussians(&self.op, left, right)?
            {
                return Ok(Gaussian(result));
            }
        } else if self.op.is_logic_operator() {
            // If we are dealing with logical operators, we expect (and can only
            // operate on) Bernoulli distributions.
            return if let (Bernoulli(left), Bernoulli(right)) = (left, right) {
                combine_bernoullis(&self.op, left, right).map(Bernoulli)
            } else {
                internal_err!(
                    "Logical operators are only compatible with `Bernoulli` distributions"
                )
            };
        } else if self.op.supports_propagation() {
            // If we are handling comparison operators, we expect (and can only
            // operate on) numeric distributions.
            return create_bernoulli_from_comparison(&self.op, left, right);
        }
        // Fall back to an unknown distribution with only summary statistics:
        new_generic_from_binary_op(&self.op, left, right)
    }

    /// For each operator, [`BinaryExpr`] has distinct rules.
    /// TODO: There may be rules specific to some data types and expression ranges.
    fn get_properties(&self, children: &[ExprProperties]) -> Result<ExprProperties> {
        let (l_order, l_range) = (children[0].sort_properties, &children[0].range);
        let (r_order, r_range) = (children[1].sort_properties, &children[1].range);
        match self.op() {
            Operator::Plus => Ok(ExprProperties {
                sort_properties: l_order.add(&r_order),
                range: l_range.add(r_range)?,
                preserves_lex_ordering: false,
            }),
            Operator::Minus => Ok(ExprProperties {
                sort_properties: l_order.sub(&r_order),
                range: l_range.sub(r_range)?,
                preserves_lex_ordering: false,
            }),
            Operator::Gt => Ok(ExprProperties {
                sort_properties: l_order.gt_or_gteq(&r_order),
                range: l_range.gt(r_range)?,
                preserves_lex_ordering: false,
            }),
            Operator::GtEq => Ok(ExprProperties {
                sort_properties: l_order.gt_or_gteq(&r_order),
                range: l_range.gt_eq(r_range)?,
                preserves_lex_ordering: false,
            }),
            Operator::Lt => Ok(ExprProperties {
                sort_properties: r_order.gt_or_gteq(&l_order),
                range: l_range.lt(r_range)?,
                preserves_lex_ordering: false,
            }),
            Operator::LtEq => Ok(ExprProperties {
                sort_properties: r_order.gt_or_gteq(&l_order),
                range: l_range.lt_eq(r_range)?,
                preserves_lex_ordering: false,
            }),
            Operator::And => Ok(ExprProperties {
                sort_properties: r_order.and_or(&l_order),
                range: l_range.and(r_range)?,
                preserves_lex_ordering: false,
            }),
            Operator::Or => Ok(ExprProperties {
                sort_properties: r_order.and_or(&l_order),
                range: l_range.or(r_range)?,
                preserves_lex_ordering: false,
            }),
            _ => Ok(ExprProperties::new_unknown()),
        }
    }

    /// Writes the expression in SQL form. Child binary expressions are
    /// parenthesized only when their precedence is 0 or lower than this operator's.
    fn fmt_sql(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        fn write_child(
            f: &mut std::fmt::Formatter,
            expr: &dyn PhysicalExpr,
            precedence: u8,
        ) -> std::fmt::Result {
            if let Some(child) = expr.as_any().downcast_ref::<BinaryExpr>() {
                let p = child.op.precedence();
                if p == 0 || p < precedence {
                    write!(f, "(")?;
                    child.fmt_sql(f)?;
                    write!(f, ")")
                } else {
                    child.fmt_sql(f)
                }
            } else {
                expr.fmt_sql(f)
            }
        }

        let precedence = self.op.precedence();
        write_child(f, self.left.as_ref(), precedence)?;
        write!(f, " {} ", self.op)?;
        write_child(f, self.right.as_ref(), precedence)
    }
}
613
614/// Casts dictionary array to result type for binary numerical operators. Such operators
615/// between array and scalar produce a dictionary array other than primitive array of the
616/// same operators between array and array. This leads to inconsistent result types causing
617/// errors in the following query execution. For such operators between array and scalar,
618/// we cast the dictionary array to primitive array.
619fn to_result_type_array(
620    op: &Operator,
621    array: ArrayRef,
622    result_type: &DataType,
623) -> Result<ArrayRef> {
624    if array.data_type() == result_type {
625        Ok(array)
626    } else if op.is_numerical_operators() {
627        match array.data_type() {
628            DataType::Dictionary(_, value_type) => {
629                if value_type.as_ref() == result_type {
630                    Ok(cast(&array, result_type)?)
631                } else {
632                    internal_err!(
633                        "Incompatible Dictionary value type {value_type} with result type {result_type} of Binary operator {op:?}"
634                    )
635                }
636            }
637            _ => Ok(array),
638        }
639    } else {
640        Ok(array)
641    }
642}
643
644impl BinaryExpr {
645    /// Evaluate the expression of the left input is an array and
646    /// right is literal - use scalar operations
647    fn evaluate_array_scalar(
648        &self,
649        array: &dyn Array,
650        scalar: ScalarValue,
651    ) -> Result<Option<Result<ArrayRef>>> {
652        use Operator::*;
653        let scalar_result = match &self.op {
654            RegexMatch => regex_match_dyn_scalar(array, &scalar, false, false),
655            RegexIMatch => regex_match_dyn_scalar(array, &scalar, false, true),
656            RegexNotMatch => regex_match_dyn_scalar(array, &scalar, true, false),
657            RegexNotIMatch => regex_match_dyn_scalar(array, &scalar, true, true),
658            BitwiseAnd => bitwise_and_dyn_scalar(array, scalar),
659            BitwiseOr => bitwise_or_dyn_scalar(array, scalar),
660            BitwiseXor => bitwise_xor_dyn_scalar(array, scalar),
661            BitwiseShiftRight => bitwise_shift_right_dyn_scalar(array, scalar),
662            BitwiseShiftLeft => bitwise_shift_left_dyn_scalar(array, scalar),
663            // if scalar operation is not supported - fallback to array implementation
664            _ => None,
665        };
666
667        Ok(scalar_result)
668    }
669
670    fn evaluate_with_resolved_args(
671        &self,
672        left: Arc<dyn Array>,
673        left_data_type: &DataType,
674        right: Arc<dyn Array>,
675        right_data_type: &DataType,
676    ) -> Result<ArrayRef> {
677        use Operator::*;
678        match &self.op {
679            IsDistinctFrom | IsNotDistinctFrom | Lt | LtEq | Gt | GtEq | Eq | NotEq
680            | Plus | Minus | Multiply | Divide | Modulo | LikeMatch | ILikeMatch
681            | NotLikeMatch | NotILikeMatch => unreachable!(),
682            And => {
683                if left_data_type == &DataType::Boolean {
684                    Ok(boolean_op(&left, &right, and_kleene)?)
685                } else {
686                    internal_err!(
687                        "Cannot evaluate binary expression {:?} with types {:?} and {:?}",
688                        self.op,
689                        left.data_type(),
690                        right.data_type()
691                    )
692                }
693            }
694            Or => {
695                if left_data_type == &DataType::Boolean {
696                    Ok(boolean_op(&left, &right, or_kleene)?)
697                } else {
698                    internal_err!(
699                        "Cannot evaluate binary expression {:?} with types {:?} and {:?}",
700                        self.op,
701                        left_data_type,
702                        right_data_type
703                    )
704                }
705            }
706            RegexMatch => regex_match_dyn(&left, &right, false, false),
707            RegexIMatch => regex_match_dyn(&left, &right, false, true),
708            RegexNotMatch => regex_match_dyn(&left, &right, true, false),
709            RegexNotIMatch => regex_match_dyn(&left, &right, true, true),
710            BitwiseAnd => bitwise_and_dyn(left, right),
711            BitwiseOr => bitwise_or_dyn(left, right),
712            BitwiseXor => bitwise_xor_dyn(left, right),
713            BitwiseShiftRight => bitwise_shift_right_dyn(left, right),
714            BitwiseShiftLeft => bitwise_shift_left_dyn(left, right),
715            StringConcat => concat_elements(&left, &right),
716            AtArrow | ArrowAt | Arrow | LongArrow | HashArrow | HashLongArrow | AtAt
717            | HashMinus | AtQuestion | Question | QuestionAnd | QuestionPipe
718            | IntegerDivide | Colon => {
719                not_impl_err!(
720                    "Binary operator '{:?}' is not supported in the physical expr",
721                    self.op
722                )
723            }
724        }
725    }
726}
727
728enum ShortCircuitStrategy<'a> {
729    None,
730    ReturnLeft,
731    ReturnRight,
732    PreSelection(&'a BooleanArray),
733}
734
/// Pre-selection threshold for short-circuited `AND`.
///
/// The left side of an `AND` is evaluated first. If the fraction of `true` rows
/// is at most this value (`true_count / len <= 0.2`), the `RecordBatch` is
/// filtered in advance. The right side is then evaluated only on the selected rows.
const PRE_SELECTION_THRESHOLD: f32 = 0.2;
739
740/// Checks if a logical operator (`AND`/`OR`) can short-circuit evaluation based on the left-hand side (lhs) result.
741///
742/// Short-circuiting occurs under these circumstances:
743/// - For `AND`:
744///    - if LHS is all false => short-circuit → return LHS
745///    - if LHS is all true  => short-circuit → return RHS
746///    - if LHS is mixed and true_count/sum_count <= [`PRE_SELECTION_THRESHOLD`] -> pre-selection
747/// - For `OR`:
748///    - if LHS is all true  => short-circuit → return LHS
749///    - if LHS is all false => short-circuit → return RHS
750/// # Arguments
751/// * `lhs` - The left-hand side (lhs) columnar value (array or scalar)
752/// * `lhs` - The left-hand side (lhs) columnar value (array or scalar)
753/// * `op` - The logical operator (`AND` or `OR`)
754///
755/// # Implementation Notes
756/// 1. Only works with Boolean-typed arguments (other types automatically return `false`)
757/// 2. Handles both scalar values and array values
758/// 3. For arrays, uses optimized bit counting techniques for boolean arrays
759fn check_short_circuit<'a>(
760    lhs: &'a ColumnarValue,
761    op: &Operator,
762) -> ShortCircuitStrategy<'a> {
763    // Quick reject for non-logical operators,and quick judgment when op is and
764    let is_and = match op {
765        Operator::And => true,
766        Operator::Or => false,
767        _ => return ShortCircuitStrategy::None,
768    };
769
770    // Non-boolean types can't be short-circuited
771    if lhs.data_type() != DataType::Boolean {
772        return ShortCircuitStrategy::None;
773    }
774
775    match lhs {
776        ColumnarValue::Array(array) => {
777            // Fast path for arrays - try to downcast to boolean array
778            if let Ok(bool_array) = as_boolean_array(array) {
779                // Arrays with nulls can't be short-circuited
780                if bool_array.null_count() > 0 {
781                    return ShortCircuitStrategy::None;
782                }
783
784                let len = bool_array.len();
785                if len == 0 {
786                    return ShortCircuitStrategy::None;
787                }
788
789                let true_count = bool_array.values().count_set_bits();
790                if is_and {
791                    // For AND, prioritize checking for all-false (short circuit case)
792                    // Uses optimized false_count() method provided by Arrow
793
794                    // Short circuit if all values are false
795                    if true_count == 0 {
796                        return ShortCircuitStrategy::ReturnLeft;
797                    }
798
799                    // If no false values, then all must be true
800                    if true_count == len {
801                        return ShortCircuitStrategy::ReturnRight;
802                    }
803
804                    // determine if we can pre-selection
805                    if true_count as f32 / len as f32 <= PRE_SELECTION_THRESHOLD {
806                        return ShortCircuitStrategy::PreSelection(bool_array);
807                    }
808                } else {
809                    // For OR, prioritize checking for all-true (short circuit case)
810                    // Uses optimized true_count() method provided by Arrow
811
812                    // Short circuit if all values are true
813                    if true_count == len {
814                        return ShortCircuitStrategy::ReturnLeft;
815                    }
816
817                    // If no true values, then all must be false
818                    if true_count == 0 {
819                        return ShortCircuitStrategy::ReturnRight;
820                    }
821                }
822            }
823        }
824        ColumnarValue::Scalar(scalar) => {
825            // Fast path for scalar values
826            if let ScalarValue::Boolean(Some(is_true)) = scalar {
827                // Return Left for:
828                // - AND with false value
829                // - OR with true value
830                if (is_and && !is_true) || (!is_and && *is_true) {
831                    return ShortCircuitStrategy::ReturnLeft;
832                } else {
833                    return ShortCircuitStrategy::ReturnRight;
834                }
835            }
836        }
837    }
838
839    // If we can't short-circuit, indicate that normal evaluation should continue
840    ShortCircuitStrategy::None
841}
842
843/// Creates a new boolean array based on the evaluation of the right expression,
844/// but only for positions where the left_result is true.
845///
846/// This function is used for short-circuit evaluation optimization of logical AND operations:
847/// - When left_result has few true values, we only evaluate the right expression for those positions
848/// - Values are copied from right_array where left_result is true
849/// - All other positions are filled with false values
850///
851/// # Parameters
852/// - `left_result` Boolean array with selection mask (typically from left side of AND)
853/// - `right_result` Result of evaluating right side of expression (only for selected positions)
854///
855/// # Returns
856/// A combined ColumnarValue with values from right_result where left_result is true
857///
858/// # Example
859///  Initial Data: { 1, 2, 3, 4, 5 }
860///  Left Evaluation
861///     (Condition: Equal to 2 or 3)
862///          ↓
863///  Filtered Data: {2, 3}
864///    Left Bitmap: { 0, 1, 1, 0, 0 }
865///          ↓
866///   Right Evaluation
867///     (Condition: Even numbers)
868///          ↓
869///  Right Data: { 2 }
870///    Right Bitmap: { 1, 0 }
871///          ↓
872///   Combine Results
873///  Final Bitmap: { 0, 1, 0, 0, 0 }
874///
875/// # Note
876/// Perhaps it would be better to modify `left_result` directly without creating a copy?
877/// In practice, `left_result` should have only one owner, so making changes should be safe.
878/// However, this is difficult to achieve under the immutable constraints of [`Arc`] and [`BooleanArray`].
879fn pre_selection_scatter(
880    left_result: &BooleanArray,
881    right_result: Option<&BooleanArray>,
882) -> Result<ColumnarValue> {
883    let result_len = left_result.len();
884
885    let mut result_array_builder = BooleanArray::builder(result_len);
886
887    // keep track of current position we have in right boolean array
888    let mut right_array_pos = 0;
889
890    // keep track of how much is filled
891    let mut last_end = 0;
892    // reduce if condition in for_each
893    match right_result {
894        Some(right_result) => {
895            SlicesIterator::new(left_result).for_each(|(start, end)| {
896                // the gap needs to be filled with false
897                if start > last_end {
898                    result_array_builder.append_n(start - last_end, false);
899                }
900
901                // copy values from right array for this slice
902                let len = end - start;
903                right_result
904                    .slice(right_array_pos, len)
905                    .iter()
906                    .for_each(|v| result_array_builder.append_option(v));
907
908                right_array_pos += len;
909                last_end = end;
910            });
911        }
912        None => SlicesIterator::new(left_result).for_each(|(start, end)| {
913            // the gap needs to be filled with false
914            if start > last_end {
915                result_array_builder.append_n(start - last_end, false);
916            }
917
918            // append nulls for this slice derictly
919            let len = end - start;
920            result_array_builder.append_nulls(len);
921
922            last_end = end;
923        }),
924    }
925
926    // Fill any remaining positions with false
927    if last_end < result_len {
928        result_array_builder.append_n(result_len - last_end, false);
929    }
930    let boolean_result = result_array_builder.finish();
931
932    Ok(ColumnarValue::Array(Arc::new(boolean_result)))
933}
934
935fn concat_elements(left: &ArrayRef, right: &ArrayRef) -> Result<ArrayRef> {
936    Ok(match left.data_type() {
937        DataType::Utf8 => Arc::new(concat_elements_utf8(
938            left.as_string::<i32>(),
939            right.as_string::<i32>(),
940        )?),
941        DataType::LargeUtf8 => Arc::new(concat_elements_utf8(
942            left.as_string::<i64>(),
943            right.as_string::<i64>(),
944        )?),
945        DataType::Utf8View => Arc::new(concat_elements_utf8view(
946            left.as_string_view(),
947            right.as_string_view(),
948        )?),
949        other => {
950            return internal_err!(
951                "Data type {other:?} not supported for binary operation 'concat_elements' on string arrays"
952            );
953        }
954    })
955}
956
957/// Create a binary expression whose arguments are correctly coerced.
958/// This function errors if it is not possible to coerce the arguments
959/// to computational types supported by the operator.
960pub fn binary(
961    lhs: Arc<dyn PhysicalExpr>,
962    op: Operator,
963    rhs: Arc<dyn PhysicalExpr>,
964    _input_schema: &Schema,
965) -> Result<Arc<dyn PhysicalExpr>> {
966    Ok(Arc::new(BinaryExpr::new(lhs, op, rhs)))
967}
968
969/// Create a similar to expression
970pub fn similar_to(
971    negated: bool,
972    case_insensitive: bool,
973    expr: Arc<dyn PhysicalExpr>,
974    pattern: Arc<dyn PhysicalExpr>,
975) -> Result<Arc<dyn PhysicalExpr>> {
976    let binary_op = match (negated, case_insensitive) {
977        (false, false) => Operator::RegexMatch,
978        (false, true) => Operator::RegexIMatch,
979        (true, false) => Operator::RegexNotMatch,
980        (true, true) => Operator::RegexNotIMatch,
981    };
982    Ok(Arc::new(BinaryExpr::new(expr, binary_op, pattern)))
983}
984
985#[cfg(test)]
986mod tests {
987    use super::*;
988    use crate::expressions::{Column, Literal, col, lit, try_cast};
989    use datafusion_expr::lit as expr_lit;
990
991    use datafusion_common::plan_datafusion_err;
992    use datafusion_physical_expr_common::physical_expr::fmt_sql;
993
994    use crate::planner::logical2physical;
995    use arrow::array::BooleanArray;
996    use datafusion_expr::col as logical_col;
997    /// Performs a binary operation, applying any type coercion necessary
998    fn binary_op(
999        left: Arc<dyn PhysicalExpr>,
1000        op: Operator,
1001        right: Arc<dyn PhysicalExpr>,
1002        schema: &Schema,
1003    ) -> Result<Arc<dyn PhysicalExpr>> {
1004        let left_type = left.data_type(schema)?;
1005        let right_type = right.data_type(schema)?;
1006        let (lhs, rhs) =
1007            BinaryTypeCoercer::new(&left_type, &op, &right_type).get_input_types()?;
1008
1009        let left_expr = try_cast(left, schema, lhs)?;
1010        let right_expr = try_cast(right, schema, rhs)?;
1011        binary(left_expr, op, right_expr, schema)
1012    }
1013
1014    #[test]
1015    fn binary_comparison() -> Result<()> {
1016        let schema = Schema::new(vec![
1017            Field::new("a", DataType::Int32, false),
1018            Field::new("b", DataType::Int32, false),
1019        ]);
1020        let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
1021        let b = Int32Array::from(vec![1, 2, 4, 8, 16]);
1022
1023        // expression: "a < b"
1024        let lt = binary(
1025            col("a", &schema)?,
1026            Operator::Lt,
1027            col("b", &schema)?,
1028            &schema,
1029        )?;
1030        let batch =
1031            RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)])?;
1032
1033        let result = lt
1034            .evaluate(&batch)?
1035            .into_array(batch.num_rows())
1036            .expect("Failed to convert to array");
1037        assert_eq!(result.len(), 5);
1038
1039        let expected = [false, false, true, true, true];
1040        let result =
1041            as_boolean_array(&result).expect("failed to downcast to BooleanArray");
1042        for (i, &expected_item) in expected.iter().enumerate().take(5) {
1043            assert_eq!(result.value(i), expected_item);
1044        }
1045
1046        Ok(())
1047    }
1048
1049    #[test]
1050    fn binary_nested() -> Result<()> {
1051        let schema = Schema::new(vec![
1052            Field::new("a", DataType::Int32, false),
1053            Field::new("b", DataType::Int32, false),
1054        ]);
1055        let a = Int32Array::from(vec![2, 4, 6, 8, 10]);
1056        let b = Int32Array::from(vec![2, 5, 4, 8, 8]);
1057
1058        // expression: "a < b OR a == b"
1059        let expr = binary(
1060            binary(
1061                col("a", &schema)?,
1062                Operator::Lt,
1063                col("b", &schema)?,
1064                &schema,
1065            )?,
1066            Operator::Or,
1067            binary(
1068                col("a", &schema)?,
1069                Operator::Eq,
1070                col("b", &schema)?,
1071                &schema,
1072            )?,
1073            &schema,
1074        )?;
1075        let batch =
1076            RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)])?;
1077
1078        assert_eq!("a@0 < b@1 OR a@0 = b@1", format!("{expr}"));
1079
1080        let result = expr
1081            .evaluate(&batch)?
1082            .into_array(batch.num_rows())
1083            .expect("Failed to convert to array");
1084        assert_eq!(result.len(), 5);
1085
1086        let expected = [true, true, false, true, false];
1087        let result =
1088            as_boolean_array(&result).expect("failed to downcast to BooleanArray");
1089        for (i, &expected_item) in expected.iter().enumerate().take(5) {
1090            assert_eq!(result.value(i), expected_item);
1091        }
1092
1093        Ok(())
1094    }
1095
1096    // runs an end-to-end test of physical type coercion:
1097    // 1. construct a record batch with two columns of type A and B
1098    //  (*_ARRAY is the Rust Arrow array type, and *_TYPE is the DataType of the elements)
1099    // 2. construct a physical expression of A OP B
1100    // 3. evaluate the expression
1101    // 4. verify that the resulting expression is of type C
1102    // 5. verify that the results of evaluation are $VEC
1103    macro_rules! test_coercion {
1104        ($A_ARRAY:ident, $A_TYPE:expr, $A_VEC:expr, $B_ARRAY:ident, $B_TYPE:expr, $B_VEC:expr, $OP:expr, $C_ARRAY:ident, $C_TYPE:expr, $VEC:expr,) => {{
1105            let schema = Schema::new(vec![
1106                Field::new("a", $A_TYPE, false),
1107                Field::new("b", $B_TYPE, false),
1108            ]);
1109            let a = $A_ARRAY::from($A_VEC);
1110            let b = $B_ARRAY::from($B_VEC);
1111            let (lhs, rhs) =
1112                BinaryTypeCoercer::new(&$A_TYPE, &$OP, &$B_TYPE).get_input_types()?;
1113
1114            let left = try_cast(col("a", &schema)?, &schema, lhs)?;
1115            let right = try_cast(col("b", &schema)?, &schema, rhs)?;
1116
1117            // verify that we can construct the expression
1118            let expression = binary(left, $OP, right, &schema)?;
1119            let batch = RecordBatch::try_new(
1120                Arc::new(schema.clone()),
1121                vec![Arc::new(a), Arc::new(b)],
1122            )?;
1123
1124            // verify that the expression's type is correct
1125            assert_eq!(expression.data_type(&schema)?, $C_TYPE);
1126
1127            // compute
1128            let result = expression
1129                .evaluate(&batch)?
1130                .into_array(batch.num_rows())
1131                .expect("Failed to convert to array");
1132
1133            // verify that the array's data_type is correct
1134            assert_eq!(*result.data_type(), $C_TYPE);
1135
1136            // verify that the data itself is downcastable
1137            let result = result
1138                .as_any()
1139                .downcast_ref::<$C_ARRAY>()
1140                .expect("failed to downcast");
1141            // verify that the result itself is correct
1142            for (i, x) in $VEC.iter().enumerate() {
1143                let v = result.value(i);
1144                assert_eq!(
1145                    v, *x,
1146                    "Unexpected output at position {i}:\n\nActual:\n{v}\n\nExpected:\n{x}"
1147                );
1148            }
1149        }};
1150    }
1151
1152    #[test]
1153    fn test_type_coercion() -> Result<()> {
1154        test_coercion!(
1155            Int32Array,
1156            DataType::Int32,
1157            vec![1i32, 2i32],
1158            UInt32Array,
1159            DataType::UInt32,
1160            vec![1u32, 2u32],
1161            Operator::Plus,
1162            Int64Array,
1163            DataType::Int64,
1164            [2i64, 4i64],
1165        );
1166        test_coercion!(
1167            Int32Array,
1168            DataType::Int32,
1169            vec![1i32],
1170            UInt16Array,
1171            DataType::UInt16,
1172            vec![1u16],
1173            Operator::Plus,
1174            Int32Array,
1175            DataType::Int32,
1176            [2i32],
1177        );
1178        test_coercion!(
1179            Float32Array,
1180            DataType::Float32,
1181            vec![1f32],
1182            UInt16Array,
1183            DataType::UInt16,
1184            vec![1u16],
1185            Operator::Plus,
1186            Float32Array,
1187            DataType::Float32,
1188            [2f32],
1189        );
1190        test_coercion!(
1191            Float32Array,
1192            DataType::Float32,
1193            vec![2f32],
1194            UInt16Array,
1195            DataType::UInt16,
1196            vec![1u16],
1197            Operator::Multiply,
1198            Float32Array,
1199            DataType::Float32,
1200            [2f32],
1201        );
1202        test_coercion!(
1203            StringArray,
1204            DataType::Utf8,
1205            vec!["1994-12-13", "1995-01-26"],
1206            Date32Array,
1207            DataType::Date32,
1208            vec![9112, 9156],
1209            Operator::Eq,
1210            BooleanArray,
1211            DataType::Boolean,
1212            [true, true],
1213        );
1214        test_coercion!(
1215            StringArray,
1216            DataType::Utf8,
1217            vec!["1994-12-13", "1995-01-26"],
1218            Date32Array,
1219            DataType::Date32,
1220            vec![9113, 9154],
1221            Operator::Lt,
1222            BooleanArray,
1223            DataType::Boolean,
1224            [true, false],
1225        );
1226        test_coercion!(
1227            StringArray,
1228            DataType::Utf8,
1229            vec!["1994-12-13T12:34:56", "1995-01-26T01:23:45"],
1230            Date64Array,
1231            DataType::Date64,
1232            vec![787322096000, 791083425000],
1233            Operator::Eq,
1234            BooleanArray,
1235            DataType::Boolean,
1236            [true, true],
1237        );
1238        test_coercion!(
1239            StringArray,
1240            DataType::Utf8,
1241            vec!["1994-12-13T12:34:56", "1995-01-26T01:23:45"],
1242            Date64Array,
1243            DataType::Date64,
1244            vec![787322096001, 791083424999],
1245            Operator::Lt,
1246            BooleanArray,
1247            DataType::Boolean,
1248            [true, false],
1249        );
1250        test_coercion!(
1251            StringViewArray,
1252            DataType::Utf8View,
1253            vec!["abc"; 5],
1254            StringArray,
1255            DataType::Utf8,
1256            vec!["^a", "^A", "(b|d)", "(B|D)", "^(b|c)"],
1257            Operator::RegexMatch,
1258            BooleanArray,
1259            DataType::Boolean,
1260            [true, false, true, false, false],
1261        );
1262        test_coercion!(
1263            StringViewArray,
1264            DataType::Utf8View,
1265            vec!["abc"; 5],
1266            StringArray,
1267            DataType::Utf8,
1268            vec!["^a", "^A", "(b|d)", "(B|D)", "^(b|c)"],
1269            Operator::RegexIMatch,
1270            BooleanArray,
1271            DataType::Boolean,
1272            [true, true, true, true, false],
1273        );
1274        test_coercion!(
1275            StringArray,
1276            DataType::Utf8,
1277            vec!["abc"; 5],
1278            StringViewArray,
1279            DataType::Utf8View,
1280            vec!["^a", "^A", "(b|d)", "(B|D)", "^(b|c)"],
1281            Operator::RegexNotMatch,
1282            BooleanArray,
1283            DataType::Boolean,
1284            [false, true, false, true, true],
1285        );
1286        test_coercion!(
1287            StringArray,
1288            DataType::Utf8,
1289            vec!["abc"; 5],
1290            StringViewArray,
1291            DataType::Utf8View,
1292            vec!["^a", "^A", "(b|d)", "(B|D)", "^(b|c)"],
1293            Operator::RegexNotIMatch,
1294            BooleanArray,
1295            DataType::Boolean,
1296            [false, false, false, false, true],
1297        );
1298        test_coercion!(
1299            StringArray,
1300            DataType::Utf8,
1301            vec!["abc"; 5],
1302            StringArray,
1303            DataType::Utf8,
1304            vec!["^a", "^A", "(b|d)", "(B|D)", "^(b|c)"],
1305            Operator::RegexMatch,
1306            BooleanArray,
1307            DataType::Boolean,
1308            [true, false, true, false, false],
1309        );
1310        test_coercion!(
1311            StringArray,
1312            DataType::Utf8,
1313            vec!["abc"; 5],
1314            StringArray,
1315            DataType::Utf8,
1316            vec!["^a", "^A", "(b|d)", "(B|D)", "^(b|c)"],
1317            Operator::RegexIMatch,
1318            BooleanArray,
1319            DataType::Boolean,
1320            [true, true, true, true, false],
1321        );
1322        test_coercion!(
1323            StringArray,
1324            DataType::Utf8,
1325            vec!["abc"; 5],
1326            StringArray,
1327            DataType::Utf8,
1328            vec!["^a", "^A", "(b|d)", "(B|D)", "^(b|c)"],
1329            Operator::RegexNotMatch,
1330            BooleanArray,
1331            DataType::Boolean,
1332            [false, true, false, true, true],
1333        );
1334        test_coercion!(
1335            StringArray,
1336            DataType::Utf8,
1337            vec!["abc"; 5],
1338            StringArray,
1339            DataType::Utf8,
1340            vec!["^a", "^A", "(b|d)", "(B|D)", "^(b|c)"],
1341            Operator::RegexNotIMatch,
1342            BooleanArray,
1343            DataType::Boolean,
1344            [false, false, false, false, true],
1345        );
1346        test_coercion!(
1347            LargeStringArray,
1348            DataType::LargeUtf8,
1349            vec!["abc"; 5],
1350            LargeStringArray,
1351            DataType::LargeUtf8,
1352            vec!["^a", "^A", "(b|d)", "(B|D)", "^(b|c)"],
1353            Operator::RegexMatch,
1354            BooleanArray,
1355            DataType::Boolean,
1356            [true, false, true, false, false],
1357        );
1358        test_coercion!(
1359            LargeStringArray,
1360            DataType::LargeUtf8,
1361            vec!["abc"; 5],
1362            LargeStringArray,
1363            DataType::LargeUtf8,
1364            vec!["^a", "^A", "(b|d)", "(B|D)", "^(b|c)"],
1365            Operator::RegexIMatch,
1366            BooleanArray,
1367            DataType::Boolean,
1368            [true, true, true, true, false],
1369        );
1370        test_coercion!(
1371            LargeStringArray,
1372            DataType::LargeUtf8,
1373            vec!["abc"; 5],
1374            LargeStringArray,
1375            DataType::LargeUtf8,
1376            vec!["^a", "^A", "(b|d)", "(B|D)", "^(b|c)"],
1377            Operator::RegexNotMatch,
1378            BooleanArray,
1379            DataType::Boolean,
1380            [false, true, false, true, true],
1381        );
1382        test_coercion!(
1383            LargeStringArray,
1384            DataType::LargeUtf8,
1385            vec!["abc"; 5],
1386            LargeStringArray,
1387            DataType::LargeUtf8,
1388            vec!["^a", "^A", "(b|d)", "(B|D)", "^(b|c)"],
1389            Operator::RegexNotIMatch,
1390            BooleanArray,
1391            DataType::Boolean,
1392            [false, false, false, false, true],
1393        );
1394        test_coercion!(
1395            StringArray,
1396            DataType::Utf8,
1397            vec!["abc"; 5],
1398            StringArray,
1399            DataType::Utf8,
1400            vec!["a__", "A%BC", "A_BC", "abc", "a%C"],
1401            Operator::LikeMatch,
1402            BooleanArray,
1403            DataType::Boolean,
1404            [true, false, false, true, false],
1405        );
1406        test_coercion!(
1407            StringArray,
1408            DataType::Utf8,
1409            vec!["abc"; 5],
1410            StringArray,
1411            DataType::Utf8,
1412            vec!["a__", "A%BC", "A_BC", "abc", "a%C"],
1413            Operator::ILikeMatch,
1414            BooleanArray,
1415            DataType::Boolean,
1416            [true, true, false, true, true],
1417        );
1418        test_coercion!(
1419            StringArray,
1420            DataType::Utf8,
1421            vec!["abc"; 5],
1422            StringArray,
1423            DataType::Utf8,
1424            vec!["a__", "A%BC", "A_BC", "abc", "a%C"],
1425            Operator::NotLikeMatch,
1426            BooleanArray,
1427            DataType::Boolean,
1428            [false, true, true, false, true],
1429        );
1430        test_coercion!(
1431            StringArray,
1432            DataType::Utf8,
1433            vec!["abc"; 5],
1434            StringArray,
1435            DataType::Utf8,
1436            vec!["a__", "A%BC", "A_BC", "abc", "a%C"],
1437            Operator::NotILikeMatch,
1438            BooleanArray,
1439            DataType::Boolean,
1440            [false, false, true, false, false],
1441        );
1442        test_coercion!(
1443            LargeStringArray,
1444            DataType::LargeUtf8,
1445            vec!["abc"; 5],
1446            LargeStringArray,
1447            DataType::LargeUtf8,
1448            vec!["a__", "A%BC", "A_BC", "abc", "a%C"],
1449            Operator::LikeMatch,
1450            BooleanArray,
1451            DataType::Boolean,
1452            [true, false, false, true, false],
1453        );
1454        test_coercion!(
1455            LargeStringArray,
1456            DataType::LargeUtf8,
1457            vec!["abc"; 5],
1458            LargeStringArray,
1459            DataType::LargeUtf8,
1460            vec!["a__", "A%BC", "A_BC", "abc", "a%C"],
1461            Operator::ILikeMatch,
1462            BooleanArray,
1463            DataType::Boolean,
1464            [true, true, false, true, true],
1465        );
1466        test_coercion!(
1467            LargeStringArray,
1468            DataType::LargeUtf8,
1469            vec!["abc"; 5],
1470            LargeStringArray,
1471            DataType::LargeUtf8,
1472            vec!["a__", "A%BC", "A_BC", "abc", "a%C"],
1473            Operator::NotLikeMatch,
1474            BooleanArray,
1475            DataType::Boolean,
1476            [false, true, true, false, true],
1477        );
1478        test_coercion!(
1479            LargeStringArray,
1480            DataType::LargeUtf8,
1481            vec!["abc"; 5],
1482            LargeStringArray,
1483            DataType::LargeUtf8,
1484            vec!["a__", "A%BC", "A_BC", "abc", "a%C"],
1485            Operator::NotILikeMatch,
1486            BooleanArray,
1487            DataType::Boolean,
1488            [false, false, true, false, false],
1489        );
1490        test_coercion!(
1491            Int16Array,
1492            DataType::Int16,
1493            vec![1i16, 2i16, 3i16],
1494            Int64Array,
1495            DataType::Int64,
1496            vec![10i64, 4i64, 5i64],
1497            Operator::BitwiseAnd,
1498            Int64Array,
1499            DataType::Int64,
1500            [0i64, 0i64, 1i64],
1501        );
1502        test_coercion!(
1503            UInt16Array,
1504            DataType::UInt16,
1505            vec![1u16, 2u16, 3u16],
1506            UInt64Array,
1507            DataType::UInt64,
1508            vec![10u64, 4u64, 5u64],
1509            Operator::BitwiseAnd,
1510            UInt64Array,
1511            DataType::UInt64,
1512            [0u64, 0u64, 1u64],
1513        );
1514        test_coercion!(
1515            Int16Array,
1516            DataType::Int16,
1517            vec![3i16, 2i16, 3i16],
1518            Int64Array,
1519            DataType::Int64,
1520            vec![10i64, 6i64, 5i64],
1521            Operator::BitwiseOr,
1522            Int64Array,
1523            DataType::Int64,
1524            [11i64, 6i64, 7i64],
1525        );
1526        test_coercion!(
1527            UInt16Array,
1528            DataType::UInt16,
1529            vec![1u16, 2u16, 3u16],
1530            UInt64Array,
1531            DataType::UInt64,
1532            vec![10u64, 4u64, 5u64],
1533            Operator::BitwiseOr,
1534            UInt64Array,
1535            DataType::UInt64,
1536            [11u64, 6u64, 7u64],
1537        );
1538        test_coercion!(
1539            Int16Array,
1540            DataType::Int16,
1541            vec![3i16, 2i16, 3i16],
1542            Int64Array,
1543            DataType::Int64,
1544            vec![10i64, 6i64, 5i64],
1545            Operator::BitwiseXor,
1546            Int64Array,
1547            DataType::Int64,
1548            [9i64, 4i64, 6i64],
1549        );
1550        test_coercion!(
1551            UInt16Array,
1552            DataType::UInt16,
1553            vec![3u16, 2u16, 3u16],
1554            UInt64Array,
1555            DataType::UInt64,
1556            vec![10u64, 6u64, 5u64],
1557            Operator::BitwiseXor,
1558            UInt64Array,
1559            DataType::UInt64,
1560            [9u64, 4u64, 6u64],
1561        );
1562        test_coercion!(
1563            Int16Array,
1564            DataType::Int16,
1565            vec![4i16, 27i16, 35i16],
1566            Int64Array,
1567            DataType::Int64,
1568            vec![2i64, 3i64, 4i64],
1569            Operator::BitwiseShiftRight,
1570            Int64Array,
1571            DataType::Int64,
1572            [1i64, 3i64, 2i64],
1573        );
1574        test_coercion!(
1575            UInt16Array,
1576            DataType::UInt16,
1577            vec![4u16, 27u16, 35u16],
1578            UInt64Array,
1579            DataType::UInt64,
1580            vec![2u64, 3u64, 4u64],
1581            Operator::BitwiseShiftRight,
1582            UInt64Array,
1583            DataType::UInt64,
1584            [1u64, 3u64, 2u64],
1585        );
1586        test_coercion!(
1587            Int16Array,
1588            DataType::Int16,
1589            vec![2i16, 3i16, 4i16],
1590            Int64Array,
1591            DataType::Int64,
1592            vec![4i64, 12i64, 7i64],
1593            Operator::BitwiseShiftLeft,
1594            Int64Array,
1595            DataType::Int64,
1596            [32i64, 12288i64, 512i64],
1597        );
1598        test_coercion!(
1599            UInt16Array,
1600            DataType::UInt16,
1601            vec![2u16, 3u16, 4u16],
1602            UInt64Array,
1603            DataType::UInt64,
1604            vec![4u64, 12u64, 7u64],
1605            Operator::BitwiseShiftLeft,
1606            UInt64Array,
1607            DataType::UInt64,
1608            [32u64, 12288u64, 512u64],
1609        );
1610        Ok(())
1611    }
1612
// Ideally this test would reuse the `test_coercion!` macro above, but the
// Rust type of a `DictionaryArray` does not encode the type of its values,
// so at the time of writing a dictionary array cannot be built through the
// `From` trait that the macro relies on.
#[test]
fn test_dictionary_type_to_array_coercion() -> Result<()> {
    // Compare a plain Utf8 column against an Int32-keyed Utf8 dictionary column.
    let dict_type =
        DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8));
    let string_type = DataType::Utf8;

    // Dictionary column: ["one", NULL, "three", "four"]
    let mut builder = StringDictionaryBuilder::<Int32Type>::new();
    for entry in [Some("one"), None, Some("three"), Some("four")] {
        match entry {
            Some(text) => {
                builder.append(text)?;
            }
            None => builder.append_null(),
        }
    }
    let dict_array: ArrayRef = Arc::new(builder.finish());

    // Plain string column: ["not one", "two", NULL, "four"]
    let str_array: ArrayRef = Arc::new(StringArray::from(vec![
        Some("not one"),
        Some("two"),
        None,
        Some("four"),
    ]));

    // A NULL on either side yields NULL; otherwise the strings are compared.
    let expected = || BooleanArray::from(vec![Some(false), None, None, Some(true)]);

    // a = b, with the dictionary column on the left-hand side
    let dict_on_left = Arc::new(Schema::new(vec![
        Field::new("a", dict_type.clone(), true),
        Field::new("b", string_type.clone(), true),
    ]));
    apply_logic_op(&dict_on_left, &dict_array, &str_array, Operator::Eq, expected())?;

    // b = a, with the dictionary column on the right-hand side
    let dict_on_right = Arc::new(Schema::new(vec![
        Field::new("a", string_type, true),
        Field::new("b", dict_type, true),
    ]));
    apply_logic_op(&dict_on_right, &str_array, &dict_array, Operator::Eq, expected())?;

    Ok(())
}
1661
#[test]
fn plus_op() -> Result<()> {
    let fields = vec![
        Field::new("a", DataType::Int32, false),
        Field::new("b", DataType::Int32, false),
    ];
    let lhs: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));
    let rhs: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 4, 8, 16]));

    // Element-wise sums of the two columns.
    let sums = Int32Array::from(vec![2, 4, 7, 12, 21]);

    apply_arithmetic::<Int32Type>(
        Arc::new(Schema::new(fields)),
        vec![lhs, rhs],
        Operator::Plus,
        sums,
    )?;

    Ok(())
}
1680
#[test]
fn plus_op_dict() -> Result<()> {
    let dict_ty = DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Int32));
    let fields: Vec<Field> = ["a", "b"]
        .into_iter()
        .map(|name| Field::new(name, dict_ty.clone(), true))
        .collect();

    // a decodes to [1, NULL, 2, 4, NULL]
    let lhs_values: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));
    let lhs_keys = Int8Array::from(vec![Some(0), None, Some(1), Some(3), None]);
    let lhs = DictionaryArray::try_new(lhs_keys, lhs_values)?;

    // b decodes to [1, 2, 2, 4, 2]
    let rhs_values: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 4, 8, 16]));
    let rhs_keys = Int8Array::from(vec![0, 1, 1, 2, 1]);
    let rhs = DictionaryArray::try_new(rhs_keys, rhs_values)?;

    apply_arithmetic::<Int32Type>(
        Arc::new(Schema::new(fields)),
        vec![Arc::new(lhs), Arc::new(rhs)],
        Operator::Plus,
        Int32Array::from(vec![Some(2), None, Some(4), Some(8), None]),
    )?;

    Ok(())
}
1713
#[test]
fn plus_op_dict_decimal() -> Result<()> {
    let dict_ty = DataType::Dictionary(
        Box::new(DataType::Int8),
        Box::new(DataType::Decimal128(10, 0)),
    );
    let schema = Schema::new(vec![
        Field::new("a", dict_ty.clone(), true),
        Field::new("b", dict_ty, true),
    ]);

    let base = 123;

    // a decodes to [123, 122, NULL, 124, 123]
    let lhs_values = Arc::new(create_decimal_array(
        &[Some(base), Some(base + 2), Some(base - 1), Some(base + 1)],
        10,
        0,
    ));
    let lhs_keys = Int8Array::from(vec![Some(0), Some(2), None, Some(3), Some(0)]);
    let lhs = DictionaryArray::try_new(lhs_keys, lhs_values)?;

    // b decodes to [124, NULL, 125, 123, 123]
    let rhs_values = Arc::new(create_decimal_array(
        &[Some(base + 1), Some(base + 3), Some(base), Some(base + 2)],
        10,
        0,
    ));
    let rhs_keys = Int8Array::from(vec![Some(0), None, Some(3), Some(2), Some(2)]);
    let rhs = DictionaryArray::try_new(rhs_keys, rhs_values)?;

    // The expected sums are typed Decimal128(11, 0).
    let expected =
        create_decimal_array(&[Some(247), None, None, Some(247), Some(246)], 11, 0);

    apply_arithmetic(
        Arc::new(schema),
        vec![Arc::new(lhs), Arc::new(rhs)],
        Operator::Plus,
        expected,
    )?;

    Ok(())
}
1772
#[test]
fn plus_op_scalar() -> Result<()> {
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
    let input: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));

    // Adding the literal 1 shifts every element up by one.
    let expected: ArrayRef = Arc::new(Int32Array::from(vec![2, 3, 4, 5, 6]));

    apply_arithmetic_scalar(
        schema,
        vec![input],
        Operator::Plus,
        ScalarValue::Int32(Some(1)),
        expected,
    )?;

    Ok(())
}
1788
#[test]
fn plus_op_dict_scalar() -> Result<()> {
    let dict_ty = DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Int32));
    let schema = Schema::new(vec![Field::new("a", dict_ty, true)]);

    // Dictionary column: [1, NULL, 2, 5]
    let mut builder = PrimitiveDictionaryBuilder::<Int8Type, Int32Type>::new();
    for item in [Some(1), None, Some(2), Some(5)] {
        match item {
            Some(v) => {
                builder.append(v)?;
            }
            None => builder.append_null(),
        }
    }
    let column: ArrayRef = Arc::new(builder.finish());

    // The literal is itself dictionary-encoded with Int8 keys.
    let literal = ScalarValue::Dictionary(
        Box::new(DataType::Int8),
        Box::new(ScalarValue::Int32(Some(1))),
    );
    let expected: ArrayRef =
        Arc::new(Int32Array::from(vec![Some(2), None, Some(3), Some(6)]));

    apply_arithmetic_scalar(
        Arc::new(schema),
        vec![column],
        Operator::Plus,
        literal,
        expected,
    )?;

    Ok(())
}
1822
#[test]
fn plus_op_dict_scalar_decimal() -> Result<()> {
    let dict_ty = DataType::Dictionary(
        Box::new(DataType::Int8),
        Box::new(DataType::Decimal128(10, 0)),
    );
    let schema = Schema::new(vec![Field::new("a", dict_ty, true)]);

    // Values [123, NULL, 122, 124]; the keys decode the column to
    // [123, 122, NULL, 124, 123].
    let values = Arc::new(create_decimal_array(
        &[Some(123), None, Some(122), Some(124)],
        10,
        0,
    ));
    let column = DictionaryArray::try_new(Int8Array::from(vec![0, 2, 1, 3, 0]), values)?;

    let literal = ScalarValue::Dictionary(
        Box::new(DataType::Int8),
        Box::new(ScalarValue::Decimal128(Some(1), 10, 0)),
    );

    // Every non-null value incremented by one, typed Decimal128(11, 0).
    let expected = Arc::new(create_decimal_array(
        &[Some(124), Some(123), None, Some(125), Some(124)],
        11,
        0,
    ));

    apply_arithmetic_scalar(
        Arc::new(schema),
        vec![Arc::new(column)],
        Operator::Plus,
        literal,
        expected,
    )?;

    Ok(())
}
1869
#[test]
fn minus_op() -> Result<()> {
    let schema = Arc::new(Schema::new(vec![
        Field::new("a", DataType::Int32, false),
        Field::new("b", DataType::Int32, false),
    ]));
    let larger: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 4, 8, 16]));
    let smaller: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));

    // larger - smaller: all differences are non-negative.
    apply_arithmetic::<Int32Type>(
        Arc::clone(&schema),
        vec![Arc::clone(&larger), Arc::clone(&smaller)],
        Operator::Minus,
        Int32Array::from(vec![0, 0, 1, 4, 11]),
    )?;

    // smaller - larger: a signed result type must produce negative values.
    apply_arithmetic::<Int32Type>(
        schema,
        vec![smaller, larger],
        Operator::Minus,
        Int32Array::from(vec![0, 0, -1, -4, -11]),
    )?;

    Ok(())
}
1899
#[test]
fn minus_op_dict() -> Result<()> {
    let dict_ty = DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Int32));
    let fields: Vec<Field> = ["a", "b"]
        .into_iter()
        .map(|name| Field::new(name, dict_ty.clone(), true))
        .collect();

    // a decodes to [1, NULL, 2, 4, NULL]
    let lhs_values: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));
    let lhs_keys = Int8Array::from(vec![Some(0), None, Some(1), Some(3), None]);
    let lhs = DictionaryArray::try_new(lhs_keys, lhs_values)?;

    // b decodes to [1, 2, 2, 4, 2]
    let rhs_values: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 4, 8, 16]));
    let rhs_keys = Int8Array::from(vec![0, 1, 1, 2, 1]);
    let rhs = DictionaryArray::try_new(rhs_keys, rhs_values)?;

    apply_arithmetic::<Int32Type>(
        Arc::new(Schema::new(fields)),
        vec![Arc::new(lhs), Arc::new(rhs)],
        Operator::Minus,
        Int32Array::from(vec![Some(0), None, Some(0), Some(0), None]),
    )?;

    Ok(())
}
1932
#[test]
fn minus_op_dict_decimal() -> Result<()> {
    let dict_ty = DataType::Dictionary(
        Box::new(DataType::Int8),
        Box::new(DataType::Decimal128(10, 0)),
    );
    let fields: Vec<Field> = ["a", "b"]
        .into_iter()
        .map(|name| Field::new(name, dict_ty.clone(), true))
        .collect();

    // a decodes to [123, 122, NULL, 124, 123]
    let lhs = DictionaryArray::try_new(
        Int8Array::from(vec![Some(0), Some(2), None, Some(3), Some(0)]),
        Arc::new(create_decimal_array(
            &[Some(123), Some(125), Some(122), Some(124)],
            10,
            0,
        )),
    )?;

    // b decodes to [124, NULL, 125, 123, 123]
    let rhs = DictionaryArray::try_new(
        Int8Array::from(vec![Some(0), None, Some(3), Some(2), Some(2)]),
        Arc::new(create_decimal_array(
            &[Some(124), Some(126), Some(123), Some(125)],
            10,
            0,
        )),
    )?;

    // Differences, typed Decimal128(11, 0); a NULL on either side stays NULL.
    let expected = create_decimal_array(&[Some(-1), None, None, Some(1), Some(0)], 11, 0);

    apply_arithmetic(
        Arc::new(Schema::new(fields)),
        vec![Arc::new(lhs), Arc::new(rhs)],
        Operator::Minus,
        expected,
    )?;

    Ok(())
}
1991
#[test]
fn minus_op_scalar() -> Result<()> {
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
    let input: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));

    // Subtracting the literal 1 shifts every element down by one.
    let expected: ArrayRef = Arc::new(Int32Array::from(vec![0, 1, 2, 3, 4]));

    apply_arithmetic_scalar(
        schema,
        vec![input],
        Operator::Minus,
        ScalarValue::Int32(Some(1)),
        expected,
    )?;

    Ok(())
}
2007
#[test]
fn minus_op_dict_scalar() -> Result<()> {
    let dict_ty = DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Int32));
    let schema = Schema::new(vec![Field::new("a", dict_ty, true)]);

    // Dictionary column: [1, NULL, 2, 5]
    let mut builder = PrimitiveDictionaryBuilder::<Int8Type, Int32Type>::new();
    for item in [Some(1), None, Some(2), Some(5)] {
        match item {
            Some(v) => {
                builder.append(v)?;
            }
            None => builder.append_null(),
        }
    }
    let column: ArrayRef = Arc::new(builder.finish());

    // The literal is itself dictionary-encoded with Int8 keys.
    let literal = ScalarValue::Dictionary(
        Box::new(DataType::Int8),
        Box::new(ScalarValue::Int32(Some(1))),
    );
    let expected: ArrayRef =
        Arc::new(Int32Array::from(vec![Some(0), None, Some(1), Some(4)]));

    apply_arithmetic_scalar(
        Arc::new(schema),
        vec![column],
        Operator::Minus,
        literal,
        expected,
    )?;

    Ok(())
}
2041
#[test]
fn minus_op_dict_scalar_decimal() -> Result<()> {
    let dict_ty = DataType::Dictionary(
        Box::new(DataType::Int8),
        Box::new(DataType::Decimal128(10, 0)),
    );
    let schema = Schema::new(vec![Field::new("a", dict_ty, true)]);

    // Values [123, NULL, 122, 124]; the keys decode the column to
    // [123, 122, NULL, 124, 123].
    let values = Arc::new(create_decimal_array(
        &[Some(123), None, Some(122), Some(124)],
        10,
        0,
    ));
    let column = DictionaryArray::try_new(Int8Array::from(vec![0, 2, 1, 3, 0]), values)?;

    let literal = ScalarValue::Dictionary(
        Box::new(DataType::Int8),
        Box::new(ScalarValue::Decimal128(Some(1), 10, 0)),
    );

    // Every non-null value decremented by one, typed Decimal128(11, 0).
    let expected = Arc::new(create_decimal_array(
        &[Some(122), Some(121), None, Some(123), Some(122)],
        11,
        0,
    ));

    apply_arithmetic_scalar(
        Arc::new(schema),
        vec![Arc::new(column)],
        Operator::Minus,
        literal,
        expected,
    )?;

    Ok(())
}
2088
#[test]
fn multiply_op() -> Result<()> {
    let schema = Arc::new(Schema::new(vec![
        Field::new("a", DataType::Int32, false),
        Field::new("b", DataType::Int32, false),
    ]));
    let operands: Vec<ArrayRef> = vec![
        Arc::new(Int32Array::from(vec![4, 8, 16, 32, 64])),
        Arc::new(Int32Array::from(vec![2, 4, 8, 16, 32])),
    ];

    // Element-wise products of the two columns.
    let products = Int32Array::from(vec![8, 32, 128, 512, 2048]);

    apply_arithmetic::<Int32Type>(schema, operands, Operator::Multiply, products)?;

    Ok(())
}
2107
#[test]
fn multiply_op_dict() -> Result<()> {
    let dict_ty = DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Int32));
    let fields: Vec<Field> = ["a", "b"]
        .into_iter()
        .map(|name| Field::new(name, dict_ty.clone(), true))
        .collect();

    // a decodes to [1, NULL, 2, 4, NULL]
    let lhs_values: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));
    let lhs_keys = Int8Array::from(vec![Some(0), None, Some(1), Some(3), None]);
    let lhs = DictionaryArray::try_new(lhs_keys, lhs_values)?;

    // b decodes to [1, 2, 2, 4, 2]
    let rhs_values: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 4, 8, 16]));
    let rhs_keys = Int8Array::from(vec![0, 1, 1, 2, 1]);
    let rhs = DictionaryArray::try_new(rhs_keys, rhs_values)?;

    apply_arithmetic::<Int32Type>(
        Arc::new(Schema::new(fields)),
        vec![Arc::new(lhs), Arc::new(rhs)],
        Operator::Multiply,
        Int32Array::from(vec![Some(1), None, Some(4), Some(16), None]),
    )?;

    Ok(())
}
2140
#[test]
fn multiply_op_dict_decimal() -> Result<()> {
    let dict_ty = DataType::Dictionary(
        Box::new(DataType::Int8),
        Box::new(DataType::Decimal128(10, 0)),
    );
    let schema = Schema::new(vec![
        Field::new("a", dict_ty.clone(), true),
        Field::new("b", dict_ty, true),
    ]);

    let base = 123;

    // a decodes to [123, 122, NULL, 124, 123]
    let lhs_values: ArrayRef = Arc::new(create_decimal_array(
        &[Some(base), Some(base + 2), Some(base - 1), Some(base + 1)],
        10,
        0,
    ));
    let lhs_keys = Int8Array::from(vec![Some(0), Some(2), None, Some(3), Some(0)]);
    let lhs = DictionaryArray::try_new(lhs_keys, lhs_values)?;

    // b decodes to [124, NULL, 125, 123, 123]
    let rhs_values = Arc::new(create_decimal_array(
        &[Some(base + 1), Some(base + 3), Some(base), Some(base + 2)],
        10,
        0,
    ));
    let rhs_keys = Int8Array::from(vec![Some(0), None, Some(3), Some(2), Some(2)]);
    let rhs = DictionaryArray::try_new(rhs_keys, rhs_values)?;

    // Products (123 * 124, 124 * 123, 123 * 123), typed Decimal128(21, 0).
    let expected = create_decimal_array(
        &[Some(15252), None, None, Some(15252), Some(15129)],
        21,
        0,
    );

    apply_arithmetic(
        Arc::new(schema),
        vec![Arc::new(lhs), Arc::new(rhs)],
        Operator::Multiply,
        expected,
    )?;

    Ok(())
}
2203
#[test]
fn multiply_op_scalar() -> Result<()> {
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
    let input: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));

    // Multiplying by the literal 2 doubles every element.
    let expected: ArrayRef = Arc::new(Int32Array::from(vec![2, 4, 6, 8, 10]));

    apply_arithmetic_scalar(
        schema,
        vec![input],
        Operator::Multiply,
        ScalarValue::Int32(Some(2)),
        expected,
    )?;

    Ok(())
}
2219
#[test]
fn multiply_op_dict_scalar() -> Result<()> {
    let dict_ty = DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Int32));
    let schema = Schema::new(vec![Field::new("a", dict_ty, true)]);

    // Dictionary column: [1, NULL, 2, 5]
    let mut builder = PrimitiveDictionaryBuilder::<Int8Type, Int32Type>::new();
    for item in [Some(1), None, Some(2), Some(5)] {
        match item {
            Some(v) => {
                builder.append(v)?;
            }
            None => builder.append_null(),
        }
    }
    let column: ArrayRef = Arc::new(builder.finish());

    // The literal is itself dictionary-encoded with Int8 keys.
    let literal = ScalarValue::Dictionary(
        Box::new(DataType::Int8),
        Box::new(ScalarValue::Int32(Some(2))),
    );
    let expected: ArrayRef =
        Arc::new(Int32Array::from(vec![Some(2), None, Some(4), Some(10)]));

    apply_arithmetic_scalar(
        Arc::new(schema),
        vec![column],
        Operator::Multiply,
        literal,
        expected,
    )?;

    Ok(())
}
2253
#[test]
fn multiply_op_dict_scalar_decimal() -> Result<()> {
    let dict_ty = DataType::Dictionary(
        Box::new(DataType::Int8),
        Box::new(DataType::Decimal128(10, 0)),
    );
    let schema = Schema::new(vec![Field::new("a", dict_ty, true)]);

    // Values [123, NULL, 122, 124]; the keys decode the column to
    // [123, 122, NULL, 124, 123].
    let values = Arc::new(create_decimal_array(
        &[Some(123), None, Some(122), Some(124)],
        10,
        0,
    ));
    let column = DictionaryArray::try_new(Int8Array::from(vec![0, 2, 1, 3, 0]), values)?;

    let literal = ScalarValue::Dictionary(
        Box::new(DataType::Int8),
        Box::new(ScalarValue::Decimal128(Some(2), 10, 0)),
    );

    // Every non-null value doubled, typed Decimal128(21, 0).
    let expected = Arc::new(create_decimal_array(
        &[Some(246), Some(244), None, Some(248), Some(246)],
        21,
        0,
    ));

    apply_arithmetic_scalar(
        Arc::new(schema),
        vec![Arc::new(column)],
        Operator::Multiply,
        literal,
        expected,
    )?;

    Ok(())
}
2294
#[test]
fn divide_op() -> Result<()> {
    let schema = Arc::new(Schema::new(vec![
        Field::new("a", DataType::Int32, false),
        Field::new("b", DataType::Int32, false),
    ]));
    let operands: Vec<ArrayRef> = vec![
        Arc::new(Int32Array::from(vec![8, 32, 128, 512, 2048])),
        Arc::new(Int32Array::from(vec![2, 4, 8, 16, 32])),
    ];

    // Every pair divides exactly.
    let quotients = Int32Array::from(vec![4, 8, 16, 32, 64]);

    apply_arithmetic::<Int32Type>(schema, operands, Operator::Divide, quotients)?;

    Ok(())
}
2313
#[test]
fn divide_op_dict() -> Result<()> {
    let dict_ty = DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Int32));
    let schema = Schema::new(vec![
        Field::new("a", dict_ty.clone(), true),
        Field::new("b", dict_ty, true),
    ]);

    // a = [1, NULL, 2, 5, 0]
    let mut builder = PrimitiveDictionaryBuilder::<Int8Type, Int32Type>::new();
    for item in [Some(1), None, Some(2), Some(5), Some(0)] {
        match item {
            Some(v) => {
                builder.append(v)?;
            }
            None => builder.append_null(),
        }
    }
    let lhs = builder.finish();

    // b decodes to [1, 2, 2, 4, 2]
    let rhs_values: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 4, 8, 16]));
    let rhs = DictionaryArray::try_new(Int8Array::from(vec![0, 1, 1, 2, 1]), rhs_values)?;

    // Integer division discards the remainder (5 / 4 == 1).
    apply_arithmetic::<Int32Type>(
        Arc::new(schema),
        vec![Arc::new(lhs), Arc::new(rhs)],
        Operator::Divide,
        Int32Array::from(vec![Some(1), None, Some(1), Some(1), Some(0)]),
    )?;

    Ok(())
}
2352
#[test]
fn divide_op_dict_decimal() -> Result<()> {
    let dict_ty = DataType::Dictionary(
        Box::new(DataType::Int8),
        Box::new(DataType::Decimal128(10, 0)),
    );
    let fields: Vec<Field> = ["a", "b"]
        .into_iter()
        .map(|name| Field::new(name, dict_ty.clone(), true))
        .collect();

    // a decodes to [123, 122, NULL, 124, 123]
    let lhs = DictionaryArray::try_new(
        Int8Array::from(vec![Some(0), Some(2), None, Some(3), Some(0)]),
        Arc::new(create_decimal_array(
            &[Some(123), Some(125), Some(122), Some(124)],
            10,
            0,
        )),
    )?;

    // b decodes to [124, NULL, 125, 123, 123]
    let rhs = DictionaryArray::try_new(
        Int8Array::from(vec![Some(0), None, Some(3), Some(2), Some(2)]),
        Arc::new(create_decimal_array(
            &[Some(124), Some(126), Some(123), Some(125)],
            10,
            0,
        )),
    )?;

    // Quotients typed Decimal128(14, 4), i.e. four fractional digits.
    let expected = create_decimal_array(
        &[
            Some(9919), // 123 / 124 ~= 0.9919
            None,
            None,
            Some(10081), // 124 / 123 ~= 1.0081
            Some(10000), // 123 / 123 == 1.0
        ],
        14,
        4,
    );

    apply_arithmetic(
        Arc::new(Schema::new(fields)),
        vec![Arc::new(lhs), Arc::new(rhs)],
        Operator::Divide,
        expected,
    )?;

    Ok(())
}
2421
#[test]
fn divide_op_scalar() -> Result<()> {
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
    let input: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));

    // Integer division by the literal 2 discards the remainder.
    let expected: ArrayRef = Arc::new(Int32Array::from(vec![0, 1, 1, 2, 2]));

    apply_arithmetic_scalar(
        schema,
        vec![input],
        Operator::Divide,
        ScalarValue::Int32(Some(2)),
        expected,
    )?;

    Ok(())
}
2437
#[test]
fn divide_op_dict_scalar() -> Result<()> {
    let dict_ty = DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Int32));
    let schema = Schema::new(vec![Field::new("a", dict_ty, true)]);

    // Dictionary column: [1, NULL, 2, 5]
    let mut builder = PrimitiveDictionaryBuilder::<Int8Type, Int32Type>::new();
    for item in [Some(1), None, Some(2), Some(5)] {
        match item {
            Some(v) => {
                builder.append(v)?;
            }
            None => builder.append_null(),
        }
    }
    let column: ArrayRef = Arc::new(builder.finish());

    // The literal is itself dictionary-encoded with Int8 keys.
    let literal = ScalarValue::Dictionary(
        Box::new(DataType::Int8),
        Box::new(ScalarValue::Int32(Some(2))),
    );
    let expected: ArrayRef =
        Arc::new(Int32Array::from(vec![Some(0), None, Some(1), Some(2)]));

    apply_arithmetic_scalar(
        Arc::new(schema),
        vec![column],
        Operator::Divide,
        literal,
        expected,
    )?;

    Ok(())
}
2471
#[test]
fn divide_op_dict_scalar_decimal() -> Result<()> {
    let dict_ty = DataType::Dictionary(
        Box::new(DataType::Int8),
        Box::new(DataType::Decimal128(10, 0)),
    );
    let schema = Schema::new(vec![Field::new("a", dict_ty, true)]);

    // Values [123, NULL, 122, 124]; the keys decode the column to
    // [123, 122, NULL, 124, 123].
    let values = Arc::new(create_decimal_array(
        &[Some(123), None, Some(122), Some(124)],
        10,
        0,
    ));
    let column = DictionaryArray::try_new(Int8Array::from(vec![0, 2, 1, 3, 0]), values)?;

    let literal = ScalarValue::Dictionary(
        Box::new(DataType::Int8),
        Box::new(ScalarValue::Decimal128(Some(2), 10, 0)),
    );

    // Halves typed Decimal128(14, 4): 61.5, 61.0, NULL, 62.0, 61.5.
    let expected = Arc::new(create_decimal_array(
        &[Some(615000), Some(610000), None, Some(620000), Some(615000)],
        14,
        4,
    ));

    apply_arithmetic_scalar(
        Arc::new(schema),
        vec![Arc::new(column)],
        Operator::Divide,
        literal,
        expected,
    )?;

    Ok(())
}
2512
#[test]
fn modulus_op() -> Result<()> {
    let schema = Arc::new(Schema::new(vec![
        Field::new("a", DataType::Int32, false),
        Field::new("b", DataType::Int32, false),
    ]));
    let operands: Vec<ArrayRef> = vec![
        Arc::new(Int32Array::from(vec![8, 32, 128, 512, 2048])),
        Arc::new(Int32Array::from(vec![2, 4, 7, 14, 32])),
    ];

    // Remainders: 128 % 7 == 2 and 512 % 14 == 8; the other pairs divide evenly.
    let remainders = Int32Array::from(vec![0, 0, 2, 8, 0]);

    apply_arithmetic::<Int32Type>(schema, operands, Operator::Modulo, remainders)?;

    Ok(())
}
2531
#[test]
fn modulus_op_dict() -> Result<()> {
    let dict_ty = DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Int32));
    let schema = Schema::new(vec![
        Field::new("a", dict_ty.clone(), true),
        Field::new("b", dict_ty, true),
    ]);

    // a = [1, NULL, 2, 5, 0]
    let mut builder = PrimitiveDictionaryBuilder::<Int8Type, Int32Type>::new();
    for item in [Some(1), None, Some(2), Some(5), Some(0)] {
        match item {
            Some(v) => {
                builder.append(v)?;
            }
            None => builder.append_null(),
        }
    }
    let lhs = builder.finish();

    // b decodes to [1, 2, 2, 4, 2]
    let rhs_values: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 4, 8, 16]));
    let rhs = DictionaryArray::try_new(Int8Array::from(vec![0, 1, 1, 2, 1]), rhs_values)?;

    // Only 5 % 4 leaves a non-zero remainder.
    apply_arithmetic::<Int32Type>(
        Arc::new(schema),
        vec![Arc::new(lhs), Arc::new(rhs)],
        Operator::Modulo,
        Int32Array::from(vec![Some(0), None, Some(0), Some(1), Some(0)]),
    )?;

    Ok(())
}
2570
2571    #[test]
2572    fn modulus_op_dict_decimal() -> Result<()> {
2573        let schema = Schema::new(vec![
2574            Field::new(
2575                "a",
2576                DataType::Dictionary(
2577                    Box::new(DataType::Int8),
2578                    Box::new(DataType::Decimal128(10, 0)),
2579                ),
2580                true,
2581            ),
2582            Field::new(
2583                "b",
2584                DataType::Dictionary(
2585                    Box::new(DataType::Int8),
2586                    Box::new(DataType::Decimal128(10, 0)),
2587                ),
2588                true,
2589            ),
2590        ]);
2591
2592        let value = 123;
2593        let decimal_array = Arc::new(create_decimal_array(
2594            &[
2595                Some(value),
2596                Some(value + 2),
2597                Some(value - 1),
2598                Some(value + 1),
2599            ],
2600            10,
2601            0,
2602        ));
2603
2604        let keys = Int8Array::from(vec![Some(0), Some(2), None, Some(3), Some(0)]);
2605        let a = DictionaryArray::try_new(keys, decimal_array)?;
2606
2607        let keys = Int8Array::from(vec![Some(0), None, Some(3), Some(2), Some(2)]);
2608        let decimal_array = Arc::new(create_decimal_array(
2609            &[
2610                Some(value + 1),
2611                Some(value + 3),
2612                Some(value),
2613                Some(value + 2),
2614            ],
2615            10,
2616            0,
2617        ));
2618        let b = DictionaryArray::try_new(keys, decimal_array)?;
2619
2620        apply_arithmetic(
2621            Arc::new(schema),
2622            vec![Arc::new(a), Arc::new(b)],
2623            Operator::Modulo,
2624            create_decimal_array(&[Some(123), None, None, Some(1), Some(0)], 10, 0),
2625        )?;
2626
2627        Ok(())
2628    }
2629
2630    #[test]
2631    fn modulus_op_scalar() -> Result<()> {
2632        let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
2633        let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
2634
2635        apply_arithmetic_scalar(
2636            Arc::new(schema),
2637            vec![Arc::new(a)],
2638            Operator::Modulo,
2639            ScalarValue::Int32(Some(2)),
2640            Arc::new(Int32Array::from(vec![1, 0, 1, 0, 1])),
2641        )?;
2642
2643        Ok(())
2644    }
2645
2646    #[test]
2647    fn modules_op_dict_scalar() -> Result<()> {
2648        let schema = Schema::new(vec![Field::new(
2649            "a",
2650            DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Int32)),
2651            true,
2652        )]);
2653
2654        let mut dict_builder = PrimitiveDictionaryBuilder::<Int8Type, Int32Type>::new();
2655
2656        dict_builder.append(1)?;
2657        dict_builder.append_null();
2658        dict_builder.append(2)?;
2659        dict_builder.append(5)?;
2660
2661        let a = dict_builder.finish();
2662
2663        let expected: PrimitiveArray<Int32Type> =
2664            PrimitiveArray::from(vec![Some(1), None, Some(0), Some(1)]);
2665
2666        apply_arithmetic_scalar(
2667            Arc::new(schema),
2668            vec![Arc::new(a)],
2669            Operator::Modulo,
2670            ScalarValue::Dictionary(
2671                Box::new(DataType::Int8),
2672                Box::new(ScalarValue::Int32(Some(2))),
2673            ),
2674            Arc::new(expected),
2675        )?;
2676
2677        Ok(())
2678    }
2679
2680    #[test]
2681    fn modulus_op_dict_scalar_decimal() -> Result<()> {
2682        let schema = Schema::new(vec![Field::new(
2683            "a",
2684            DataType::Dictionary(
2685                Box::new(DataType::Int8),
2686                Box::new(DataType::Decimal128(10, 0)),
2687            ),
2688            true,
2689        )]);
2690
2691        let value = 123;
2692        let decimal_array = Arc::new(create_decimal_array(
2693            &[Some(value), None, Some(value - 1), Some(value + 1)],
2694            10,
2695            0,
2696        ));
2697
2698        let keys = Int8Array::from(vec![0, 2, 1, 3, 0]);
2699        let a = DictionaryArray::try_new(keys, decimal_array)?;
2700
2701        let decimal_array = Arc::new(create_decimal_array(
2702            &[Some(1), Some(0), None, Some(0), Some(1)],
2703            10,
2704            0,
2705        ));
2706
2707        apply_arithmetic_scalar(
2708            Arc::new(schema),
2709            vec![Arc::new(a)],
2710            Operator::Modulo,
2711            ScalarValue::Dictionary(
2712                Box::new(DataType::Int8),
2713                Box::new(ScalarValue::Decimal128(Some(2), 10, 0)),
2714            ),
2715            decimal_array,
2716        )?;
2717
2718        Ok(())
2719    }
2720
2721    fn apply_arithmetic<T: ArrowNumericType>(
2722        schema: SchemaRef,
2723        data: Vec<ArrayRef>,
2724        op: Operator,
2725        expected: PrimitiveArray<T>,
2726    ) -> Result<()> {
2727        let arithmetic_op =
2728            binary_op(col("a", &schema)?, op, col("b", &schema)?, &schema)?;
2729        let batch = RecordBatch::try_new(schema, data)?;
2730        let result = arithmetic_op
2731            .evaluate(&batch)?
2732            .into_array(batch.num_rows())
2733            .expect("Failed to convert to array");
2734
2735        assert_eq!(result.as_ref(), &expected);
2736        Ok(())
2737    }
2738
2739    fn apply_arithmetic_scalar(
2740        schema: SchemaRef,
2741        data: Vec<ArrayRef>,
2742        op: Operator,
2743        literal: ScalarValue,
2744        expected: ArrayRef,
2745    ) -> Result<()> {
2746        let lit = Arc::new(Literal::new(literal));
2747        let arithmetic_op = binary_op(col("a", &schema)?, op, lit, &schema)?;
2748        let batch = RecordBatch::try_new(schema, data)?;
2749        let result = arithmetic_op
2750            .evaluate(&batch)?
2751            .into_array(batch.num_rows())
2752            .expect("Failed to convert to array");
2753
2754        assert_eq!(&result, &expected);
2755        Ok(())
2756    }
2757
2758    fn apply_logic_op(
2759        schema: &SchemaRef,
2760        left: &ArrayRef,
2761        right: &ArrayRef,
2762        op: Operator,
2763        expected: BooleanArray,
2764    ) -> Result<()> {
2765        let op = binary_op(col("a", schema)?, op, col("b", schema)?, schema)?;
2766        let data: Vec<ArrayRef> = vec![Arc::clone(left), Arc::clone(right)];
2767        let batch = RecordBatch::try_new(Arc::clone(schema), data)?;
2768        let result = op
2769            .evaluate(&batch)?
2770            .into_array(batch.num_rows())
2771            .expect("Failed to convert to array");
2772
2773        assert_eq!(result.as_ref(), &expected);
2774        Ok(())
2775    }
2776
2777    // Test `scalar <op> arr` produces expected
2778    fn apply_logic_op_scalar_arr(
2779        schema: &SchemaRef,
2780        scalar: &ScalarValue,
2781        arr: &ArrayRef,
2782        op: Operator,
2783        expected: &BooleanArray,
2784    ) -> Result<()> {
2785        let scalar = lit(scalar.clone());
2786        let op = binary_op(scalar, op, col("a", schema)?, schema)?;
2787        let batch = RecordBatch::try_new(Arc::clone(schema), vec![Arc::clone(arr)])?;
2788        let result = op
2789            .evaluate(&batch)?
2790            .into_array(batch.num_rows())
2791            .expect("Failed to convert to array");
2792        assert_eq!(result.as_ref(), expected);
2793
2794        Ok(())
2795    }
2796
2797    // Test `arr <op> scalar` produces expected
2798    fn apply_logic_op_arr_scalar(
2799        schema: &SchemaRef,
2800        arr: &ArrayRef,
2801        scalar: &ScalarValue,
2802        op: Operator,
2803        expected: &BooleanArray,
2804    ) -> Result<()> {
2805        let scalar = lit(scalar.clone());
2806        let op = binary_op(col("a", schema)?, op, scalar, schema)?;
2807        let batch = RecordBatch::try_new(Arc::clone(schema), vec![Arc::clone(arr)])?;
2808        let result = op
2809            .evaluate(&batch)?
2810            .into_array(batch.num_rows())
2811            .expect("Failed to convert to array");
2812        assert_eq!(result.as_ref(), expected);
2813
2814        Ok(())
2815    }
2816
2817    #[test]
2818    fn and_with_nulls_op() -> Result<()> {
2819        let schema = Schema::new(vec![
2820            Field::new("a", DataType::Boolean, true),
2821            Field::new("b", DataType::Boolean, true),
2822        ]);
2823        let a = Arc::new(BooleanArray::from(vec![
2824            Some(true),
2825            Some(false),
2826            None,
2827            Some(true),
2828            Some(false),
2829            None,
2830            Some(true),
2831            Some(false),
2832            None,
2833        ])) as ArrayRef;
2834        let b = Arc::new(BooleanArray::from(vec![
2835            Some(true),
2836            Some(true),
2837            Some(true),
2838            Some(false),
2839            Some(false),
2840            Some(false),
2841            None,
2842            None,
2843            None,
2844        ])) as ArrayRef;
2845
2846        let expected = BooleanArray::from(vec![
2847            Some(true),
2848            Some(false),
2849            None,
2850            Some(false),
2851            Some(false),
2852            Some(false),
2853            None,
2854            Some(false),
2855            None,
2856        ]);
2857        apply_logic_op(&Arc::new(schema), &a, &b, Operator::And, expected)?;
2858
2859        Ok(())
2860    }
2861
2862    #[test]
2863    fn regex_with_nulls() -> Result<()> {
2864        let schema = Schema::new(vec![
2865            Field::new("a", DataType::Utf8, true),
2866            Field::new("b", DataType::Utf8, true),
2867        ]);
2868        let a = Arc::new(StringArray::from(vec![
2869            Some("abc"),
2870            None,
2871            Some("abc"),
2872            None,
2873            Some("abc"),
2874        ])) as ArrayRef;
2875        let b = Arc::new(StringArray::from(vec![
2876            Some("^a"),
2877            Some("^A"),
2878            None,
2879            None,
2880            Some("^(b|c)"),
2881        ])) as ArrayRef;
2882
2883        let regex_expected =
2884            BooleanArray::from(vec![Some(true), None, None, None, Some(false)]);
2885        let regex_not_expected =
2886            BooleanArray::from(vec![Some(false), None, None, None, Some(true)]);
2887        apply_logic_op(
2888            &Arc::new(schema.clone()),
2889            &a,
2890            &b,
2891            Operator::RegexMatch,
2892            regex_expected.clone(),
2893        )?;
2894        apply_logic_op(
2895            &Arc::new(schema.clone()),
2896            &a,
2897            &b,
2898            Operator::RegexIMatch,
2899            regex_expected.clone(),
2900        )?;
2901        apply_logic_op(
2902            &Arc::new(schema.clone()),
2903            &a,
2904            &b,
2905            Operator::RegexNotMatch,
2906            regex_not_expected.clone(),
2907        )?;
2908        apply_logic_op(
2909            &Arc::new(schema),
2910            &a,
2911            &b,
2912            Operator::RegexNotIMatch,
2913            regex_not_expected.clone(),
2914        )?;
2915
2916        let schema = Schema::new(vec![
2917            Field::new("a", DataType::LargeUtf8, true),
2918            Field::new("b", DataType::LargeUtf8, true),
2919        ]);
2920        let a = Arc::new(LargeStringArray::from(vec![
2921            Some("abc"),
2922            None,
2923            Some("abc"),
2924            None,
2925            Some("abc"),
2926        ])) as ArrayRef;
2927        let b = Arc::new(LargeStringArray::from(vec![
2928            Some("^a"),
2929            Some("^A"),
2930            None,
2931            None,
2932            Some("^(b|c)"),
2933        ])) as ArrayRef;
2934
2935        apply_logic_op(
2936            &Arc::new(schema.clone()),
2937            &a,
2938            &b,
2939            Operator::RegexMatch,
2940            regex_expected.clone(),
2941        )?;
2942        apply_logic_op(
2943            &Arc::new(schema.clone()),
2944            &a,
2945            &b,
2946            Operator::RegexIMatch,
2947            regex_expected,
2948        )?;
2949        apply_logic_op(
2950            &Arc::new(schema.clone()),
2951            &a,
2952            &b,
2953            Operator::RegexNotMatch,
2954            regex_not_expected.clone(),
2955        )?;
2956        apply_logic_op(
2957            &Arc::new(schema),
2958            &a,
2959            &b,
2960            Operator::RegexNotIMatch,
2961            regex_not_expected,
2962        )?;
2963
2964        Ok(())
2965    }
2966
2967    #[test]
2968    fn or_with_nulls_op() -> Result<()> {
2969        let schema = Schema::new(vec![
2970            Field::new("a", DataType::Boolean, true),
2971            Field::new("b", DataType::Boolean, true),
2972        ]);
2973        let a = Arc::new(BooleanArray::from(vec![
2974            Some(true),
2975            Some(false),
2976            None,
2977            Some(true),
2978            Some(false),
2979            None,
2980            Some(true),
2981            Some(false),
2982            None,
2983        ])) as ArrayRef;
2984        let b = Arc::new(BooleanArray::from(vec![
2985            Some(true),
2986            Some(true),
2987            Some(true),
2988            Some(false),
2989            Some(false),
2990            Some(false),
2991            None,
2992            None,
2993            None,
2994        ])) as ArrayRef;
2995
2996        let expected = BooleanArray::from(vec![
2997            Some(true),
2998            Some(true),
2999            Some(true),
3000            Some(true),
3001            Some(false),
3002            None,
3003            Some(true),
3004            None,
3005            None,
3006        ]);
3007        apply_logic_op(&Arc::new(schema), &a, &b, Operator::Or, expected)?;
3008
3009        Ok(())
3010    }
3011
3012    /// Returns (schema, a: BooleanArray, b: BooleanArray) with all possible inputs
3013    ///
3014    /// a: [true, true, true,  NULL, NULL, NULL,  false, false, false]
3015    /// b: [true, NULL, false, true, NULL, false, true,  NULL,  false]
3016    fn bool_test_arrays() -> (SchemaRef, ArrayRef, ArrayRef) {
3017        let schema = Schema::new(vec![
3018            Field::new("a", DataType::Boolean, true),
3019            Field::new("b", DataType::Boolean, true),
3020        ]);
3021        let a: BooleanArray = [
3022            Some(true),
3023            Some(true),
3024            Some(true),
3025            None,
3026            None,
3027            None,
3028            Some(false),
3029            Some(false),
3030            Some(false),
3031        ]
3032        .iter()
3033        .collect();
3034        let b: BooleanArray = [
3035            Some(true),
3036            None,
3037            Some(false),
3038            Some(true),
3039            None,
3040            Some(false),
3041            Some(true),
3042            None,
3043            Some(false),
3044        ]
3045        .iter()
3046        .collect();
3047        (Arc::new(schema), Arc::new(a), Arc::new(b))
3048    }
3049
3050    /// Returns (schema, BooleanArray) with [true, NULL, false]
3051    fn scalar_bool_test_array() -> (SchemaRef, ArrayRef) {
3052        let schema = Schema::new(vec![Field::new("a", DataType::Boolean, true)]);
3053        let a: BooleanArray = [Some(true), None, Some(false)].iter().collect();
3054        (Arc::new(schema), Arc::new(a))
3055    }
3056
3057    #[test]
3058    fn eq_op_bool() {
3059        let (schema, a, b) = bool_test_arrays();
3060        let expected = [
3061            Some(true),
3062            None,
3063            Some(false),
3064            None,
3065            None,
3066            None,
3067            Some(false),
3068            None,
3069            Some(true),
3070        ]
3071        .iter()
3072        .collect();
3073        apply_logic_op(&schema, &a, &b, Operator::Eq, expected).unwrap();
3074    }
3075
3076    #[test]
3077    fn eq_op_bool_scalar() {
3078        let (schema, a) = scalar_bool_test_array();
3079        let expected = [Some(true), None, Some(false)].iter().collect();
3080        apply_logic_op_scalar_arr(
3081            &schema,
3082            &ScalarValue::from(true),
3083            &a,
3084            Operator::Eq,
3085            &expected,
3086        )
3087        .unwrap();
3088        apply_logic_op_arr_scalar(
3089            &schema,
3090            &a,
3091            &ScalarValue::from(true),
3092            Operator::Eq,
3093            &expected,
3094        )
3095        .unwrap();
3096
3097        let expected = [Some(false), None, Some(true)].iter().collect();
3098        apply_logic_op_scalar_arr(
3099            &schema,
3100            &ScalarValue::from(false),
3101            &a,
3102            Operator::Eq,
3103            &expected,
3104        )
3105        .unwrap();
3106        apply_logic_op_arr_scalar(
3107            &schema,
3108            &a,
3109            &ScalarValue::from(false),
3110            Operator::Eq,
3111            &expected,
3112        )
3113        .unwrap();
3114    }
3115
3116    #[test]
3117    fn neq_op_bool() {
3118        let (schema, a, b) = bool_test_arrays();
3119        let expected = [
3120            Some(false),
3121            None,
3122            Some(true),
3123            None,
3124            None,
3125            None,
3126            Some(true),
3127            None,
3128            Some(false),
3129        ]
3130        .iter()
3131        .collect();
3132        apply_logic_op(&schema, &a, &b, Operator::NotEq, expected).unwrap();
3133    }
3134
3135    #[test]
3136    fn neq_op_bool_scalar() {
3137        let (schema, a) = scalar_bool_test_array();
3138        let expected = [Some(false), None, Some(true)].iter().collect();
3139        apply_logic_op_scalar_arr(
3140            &schema,
3141            &ScalarValue::from(true),
3142            &a,
3143            Operator::NotEq,
3144            &expected,
3145        )
3146        .unwrap();
3147        apply_logic_op_arr_scalar(
3148            &schema,
3149            &a,
3150            &ScalarValue::from(true),
3151            Operator::NotEq,
3152            &expected,
3153        )
3154        .unwrap();
3155
3156        let expected = [Some(true), None, Some(false)].iter().collect();
3157        apply_logic_op_scalar_arr(
3158            &schema,
3159            &ScalarValue::from(false),
3160            &a,
3161            Operator::NotEq,
3162            &expected,
3163        )
3164        .unwrap();
3165        apply_logic_op_arr_scalar(
3166            &schema,
3167            &a,
3168            &ScalarValue::from(false),
3169            Operator::NotEq,
3170            &expected,
3171        )
3172        .unwrap();
3173    }
3174
3175    #[test]
3176    fn lt_op_bool() {
3177        let (schema, a, b) = bool_test_arrays();
3178        let expected = [
3179            Some(false),
3180            None,
3181            Some(false),
3182            None,
3183            None,
3184            None,
3185            Some(true),
3186            None,
3187            Some(false),
3188        ]
3189        .iter()
3190        .collect();
3191        apply_logic_op(&schema, &a, &b, Operator::Lt, expected).unwrap();
3192    }
3193
3194    #[test]
3195    fn lt_op_bool_scalar() {
3196        let (schema, a) = scalar_bool_test_array();
3197        let expected = [Some(false), None, Some(false)].iter().collect();
3198        apply_logic_op_scalar_arr(
3199            &schema,
3200            &ScalarValue::from(true),
3201            &a,
3202            Operator::Lt,
3203            &expected,
3204        )
3205        .unwrap();
3206
3207        let expected = [Some(false), None, Some(true)].iter().collect();
3208        apply_logic_op_arr_scalar(
3209            &schema,
3210            &a,
3211            &ScalarValue::from(true),
3212            Operator::Lt,
3213            &expected,
3214        )
3215        .unwrap();
3216
3217        let expected = [Some(true), None, Some(false)].iter().collect();
3218        apply_logic_op_scalar_arr(
3219            &schema,
3220            &ScalarValue::from(false),
3221            &a,
3222            Operator::Lt,
3223            &expected,
3224        )
3225        .unwrap();
3226
3227        let expected = [Some(false), None, Some(false)].iter().collect();
3228        apply_logic_op_arr_scalar(
3229            &schema,
3230            &a,
3231            &ScalarValue::from(false),
3232            Operator::Lt,
3233            &expected,
3234        )
3235        .unwrap();
3236    }
3237
3238    #[test]
3239    fn lt_eq_op_bool() {
3240        let (schema, a, b) = bool_test_arrays();
3241        let expected = [
3242            Some(true),
3243            None,
3244            Some(false),
3245            None,
3246            None,
3247            None,
3248            Some(true),
3249            None,
3250            Some(true),
3251        ]
3252        .iter()
3253        .collect();
3254        apply_logic_op(&schema, &a, &b, Operator::LtEq, expected).unwrap();
3255    }
3256
3257    #[test]
3258    fn lt_eq_op_bool_scalar() {
3259        let (schema, a) = scalar_bool_test_array();
3260        let expected = [Some(true), None, Some(false)].iter().collect();
3261        apply_logic_op_scalar_arr(
3262            &schema,
3263            &ScalarValue::from(true),
3264            &a,
3265            Operator::LtEq,
3266            &expected,
3267        )
3268        .unwrap();
3269
3270        let expected = [Some(true), None, Some(true)].iter().collect();
3271        apply_logic_op_arr_scalar(
3272            &schema,
3273            &a,
3274            &ScalarValue::from(true),
3275            Operator::LtEq,
3276            &expected,
3277        )
3278        .unwrap();
3279
3280        let expected = [Some(true), None, Some(true)].iter().collect();
3281        apply_logic_op_scalar_arr(
3282            &schema,
3283            &ScalarValue::from(false),
3284            &a,
3285            Operator::LtEq,
3286            &expected,
3287        )
3288        .unwrap();
3289
3290        let expected = [Some(false), None, Some(true)].iter().collect();
3291        apply_logic_op_arr_scalar(
3292            &schema,
3293            &a,
3294            &ScalarValue::from(false),
3295            Operator::LtEq,
3296            &expected,
3297        )
3298        .unwrap();
3299    }
3300
3301    #[test]
3302    fn gt_op_bool() {
3303        let (schema, a, b) = bool_test_arrays();
3304        let expected = [
3305            Some(false),
3306            None,
3307            Some(true),
3308            None,
3309            None,
3310            None,
3311            Some(false),
3312            None,
3313            Some(false),
3314        ]
3315        .iter()
3316        .collect();
3317        apply_logic_op(&schema, &a, &b, Operator::Gt, expected).unwrap();
3318    }
3319
3320    #[test]
3321    fn gt_op_bool_scalar() {
3322        let (schema, a) = scalar_bool_test_array();
3323        let expected = [Some(false), None, Some(true)].iter().collect();
3324        apply_logic_op_scalar_arr(
3325            &schema,
3326            &ScalarValue::from(true),
3327            &a,
3328            Operator::Gt,
3329            &expected,
3330        )
3331        .unwrap();
3332
3333        let expected = [Some(false), None, Some(false)].iter().collect();
3334        apply_logic_op_arr_scalar(
3335            &schema,
3336            &a,
3337            &ScalarValue::from(true),
3338            Operator::Gt,
3339            &expected,
3340        )
3341        .unwrap();
3342
3343        let expected = [Some(false), None, Some(false)].iter().collect();
3344        apply_logic_op_scalar_arr(
3345            &schema,
3346            &ScalarValue::from(false),
3347            &a,
3348            Operator::Gt,
3349            &expected,
3350        )
3351        .unwrap();
3352
3353        let expected = [Some(true), None, Some(false)].iter().collect();
3354        apply_logic_op_arr_scalar(
3355            &schema,
3356            &a,
3357            &ScalarValue::from(false),
3358            Operator::Gt,
3359            &expected,
3360        )
3361        .unwrap();
3362    }
3363
3364    #[test]
3365    fn gt_eq_op_bool() {
3366        let (schema, a, b) = bool_test_arrays();
3367        let expected = [
3368            Some(true),
3369            None,
3370            Some(true),
3371            None,
3372            None,
3373            None,
3374            Some(false),
3375            None,
3376            Some(true),
3377        ]
3378        .iter()
3379        .collect();
3380        apply_logic_op(&schema, &a, &b, Operator::GtEq, expected).unwrap();
3381    }
3382
3383    #[test]
3384    fn gt_eq_op_bool_scalar() {
3385        let (schema, a) = scalar_bool_test_array();
3386        let expected = [Some(true), None, Some(true)].iter().collect();
3387        apply_logic_op_scalar_arr(
3388            &schema,
3389            &ScalarValue::from(true),
3390            &a,
3391            Operator::GtEq,
3392            &expected,
3393        )
3394        .unwrap();
3395
3396        let expected = [Some(true), None, Some(false)].iter().collect();
3397        apply_logic_op_arr_scalar(
3398            &schema,
3399            &a,
3400            &ScalarValue::from(true),
3401            Operator::GtEq,
3402            &expected,
3403        )
3404        .unwrap();
3405
3406        let expected = [Some(false), None, Some(true)].iter().collect();
3407        apply_logic_op_scalar_arr(
3408            &schema,
3409            &ScalarValue::from(false),
3410            &a,
3411            Operator::GtEq,
3412            &expected,
3413        )
3414        .unwrap();
3415
3416        let expected = [Some(true), None, Some(true)].iter().collect();
3417        apply_logic_op_arr_scalar(
3418            &schema,
3419            &a,
3420            &ScalarValue::from(false),
3421            Operator::GtEq,
3422            &expected,
3423        )
3424        .unwrap();
3425    }
3426
3427    #[test]
3428    fn is_distinct_from_op_bool() {
3429        let (schema, a, b) = bool_test_arrays();
3430        let expected = [
3431            Some(false),
3432            Some(true),
3433            Some(true),
3434            Some(true),
3435            Some(false),
3436            Some(true),
3437            Some(true),
3438            Some(true),
3439            Some(false),
3440        ]
3441        .iter()
3442        .collect();
3443        apply_logic_op(&schema, &a, &b, Operator::IsDistinctFrom, expected).unwrap();
3444    }
3445
3446    #[test]
3447    fn is_not_distinct_from_op_bool() {
3448        let (schema, a, b) = bool_test_arrays();
3449        let expected = [
3450            Some(true),
3451            Some(false),
3452            Some(false),
3453            Some(false),
3454            Some(true),
3455            Some(false),
3456            Some(false),
3457            Some(false),
3458            Some(true),
3459        ]
3460        .iter()
3461        .collect();
3462        apply_logic_op(&schema, &a, &b, Operator::IsNotDistinctFrom, expected).unwrap();
3463    }
3464
3465    #[test]
3466    fn relatively_deeply_nested() {
3467        // Reproducer for https://github.com/apache/datafusion/issues/419
3468
3469        // where even relatively shallow binary expressions overflowed
3470        // the stack in debug builds
3471
3472        let input: Vec<_> = vec![1, 2, 3, 4, 5].into_iter().map(Some).collect();
3473        let a: Int32Array = input.iter().collect();
3474
3475        let batch = RecordBatch::try_from_iter(vec![("a", Arc::new(a) as _)]).unwrap();
3476        let schema = batch.schema();
3477
3478        // build a left deep tree ((((a + a) + a) + a ....
3479        let tree_depth: i32 = 100;
3480        let expr = (0..tree_depth)
3481            .map(|_| col("a", schema.as_ref()).unwrap())
3482            .reduce(|l, r| binary(l, Operator::Plus, r, &schema).unwrap())
3483            .unwrap();
3484
3485        let result = expr
3486            .evaluate(&batch)
3487            .expect("evaluation")
3488            .into_array(batch.num_rows())
3489            .expect("Failed to convert to array");
3490
3491        let expected: Int32Array = input
3492            .into_iter()
3493            .map(|i| i.map(|i| i * tree_depth))
3494            .collect();
3495        assert_eq!(result.as_ref(), &expected);
3496    }
3497
3498    fn create_decimal_array(
3499        array: &[Option<i128>],
3500        precision: u8,
3501        scale: i8,
3502    ) -> Decimal128Array {
3503        let mut decimal_builder = Decimal128Builder::with_capacity(array.len());
3504        for value in array.iter().copied() {
3505            decimal_builder.append_option(value)
3506        }
3507        decimal_builder
3508            .finish()
3509            .with_precision_and_scale(precision, scale)
3510            .unwrap()
3511    }
3512
3513    #[test]
3514    fn comparison_dict_decimal_scalar_expr_test() -> Result<()> {
3515        // scalar of decimal compare with dictionary decimal array
3516        let value_i128 = 123;
3517        let decimal_scalar = ScalarValue::Dictionary(
3518            Box::new(DataType::Int8),
3519            Box::new(ScalarValue::Decimal128(Some(value_i128), 25, 3)),
3520        );
3521        let schema = Arc::new(Schema::new(vec![Field::new(
3522            "a",
3523            DataType::Dictionary(
3524                Box::new(DataType::Int8),
3525                Box::new(DataType::Decimal128(25, 3)),
3526            ),
3527            true,
3528        )]));
3529        let decimal_array = Arc::new(create_decimal_array(
3530            &[
3531                Some(value_i128),
3532                None,
3533                Some(value_i128 - 1),
3534                Some(value_i128 + 1),
3535            ],
3536            25,
3537            3,
3538        ));
3539
3540        let keys = Int8Array::from(vec![Some(0), None, Some(2), Some(3)]);
3541        let dictionary =
3542            Arc::new(DictionaryArray::try_new(keys, decimal_array)?) as ArrayRef;
3543
3544        // array = scalar
3545        apply_logic_op_arr_scalar(
3546            &schema,
3547            &dictionary,
3548            &decimal_scalar,
3549            Operator::Eq,
3550            &BooleanArray::from(vec![Some(true), None, Some(false), Some(false)]),
3551        )
3552        .unwrap();
3553        // array != scalar
3554        apply_logic_op_arr_scalar(
3555            &schema,
3556            &dictionary,
3557            &decimal_scalar,
3558            Operator::NotEq,
3559            &BooleanArray::from(vec![Some(false), None, Some(true), Some(true)]),
3560        )
3561        .unwrap();
3562        //  array < scalar
3563        apply_logic_op_arr_scalar(
3564            &schema,
3565            &dictionary,
3566            &decimal_scalar,
3567            Operator::Lt,
3568            &BooleanArray::from(vec![Some(false), None, Some(true), Some(false)]),
3569        )
3570        .unwrap();
3571
3572        //  array <= scalar
3573        apply_logic_op_arr_scalar(
3574            &schema,
3575            &dictionary,
3576            &decimal_scalar,
3577            Operator::LtEq,
3578            &BooleanArray::from(vec![Some(true), None, Some(true), Some(false)]),
3579        )
3580        .unwrap();
3581        // array > scalar
3582        apply_logic_op_arr_scalar(
3583            &schema,
3584            &dictionary,
3585            &decimal_scalar,
3586            Operator::Gt,
3587            &BooleanArray::from(vec![Some(false), None, Some(false), Some(true)]),
3588        )
3589        .unwrap();
3590
3591        // array >= scalar
3592        apply_logic_op_arr_scalar(
3593            &schema,
3594            &dictionary,
3595            &decimal_scalar,
3596            Operator::GtEq,
3597            &BooleanArray::from(vec![Some(true), None, Some(false), Some(true)]),
3598        )
3599        .unwrap();
3600
3601        Ok(())
3602    }
3603
3604    #[test]
3605    fn comparison_decimal_expr_test() -> Result<()> {
3606        // scalar of decimal compare with decimal array
3607        let value_i128 = 123;
3608        let decimal_scalar = ScalarValue::Decimal128(Some(value_i128), 25, 3);
3609        let schema = Arc::new(Schema::new(vec![Field::new(
3610            "a",
3611            DataType::Decimal128(25, 3),
3612            true,
3613        )]));
3614        let decimal_array = Arc::new(create_decimal_array(
3615            &[
3616                Some(value_i128),
3617                None,
3618                Some(value_i128 - 1),
3619                Some(value_i128 + 1),
3620            ],
3621            25,
3622            3,
3623        )) as ArrayRef;
3624        // array = scalar
3625        apply_logic_op_arr_scalar(
3626            &schema,
3627            &decimal_array,
3628            &decimal_scalar,
3629            Operator::Eq,
3630            &BooleanArray::from(vec![Some(true), None, Some(false), Some(false)]),
3631        )
3632        .unwrap();
3633        // array != scalar
3634        apply_logic_op_arr_scalar(
3635            &schema,
3636            &decimal_array,
3637            &decimal_scalar,
3638            Operator::NotEq,
3639            &BooleanArray::from(vec![Some(false), None, Some(true), Some(true)]),
3640        )
3641        .unwrap();
3642        //  array < scalar
3643        apply_logic_op_arr_scalar(
3644            &schema,
3645            &decimal_array,
3646            &decimal_scalar,
3647            Operator::Lt,
3648            &BooleanArray::from(vec![Some(false), None, Some(true), Some(false)]),
3649        )
3650        .unwrap();
3651
3652        //  array <= scalar
3653        apply_logic_op_arr_scalar(
3654            &schema,
3655            &decimal_array,
3656            &decimal_scalar,
3657            Operator::LtEq,
3658            &BooleanArray::from(vec![Some(true), None, Some(true), Some(false)]),
3659        )
3660        .unwrap();
3661        // array > scalar
3662        apply_logic_op_arr_scalar(
3663            &schema,
3664            &decimal_array,
3665            &decimal_scalar,
3666            Operator::Gt,
3667            &BooleanArray::from(vec![Some(false), None, Some(false), Some(true)]),
3668        )
3669        .unwrap();
3670
3671        // array >= scalar
3672        apply_logic_op_arr_scalar(
3673            &schema,
3674            &decimal_array,
3675            &decimal_scalar,
3676            Operator::GtEq,
3677            &BooleanArray::from(vec![Some(true), None, Some(false), Some(true)]),
3678        )
3679        .unwrap();
3680
3681        // scalar of different data type with decimal array
3682        let decimal_scalar = ScalarValue::Decimal128(Some(123_456), 10, 3);
3683        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, true)]));
3684        // scalar == array
3685        apply_logic_op_scalar_arr(
3686            &schema,
3687            &decimal_scalar,
3688            &(Arc::new(Int64Array::from(vec![Some(124), None])) as ArrayRef),
3689            Operator::Eq,
3690            &BooleanArray::from(vec![Some(false), None]),
3691        )
3692        .unwrap();
3693
3694        // array != scalar
3695        apply_logic_op_arr_scalar(
3696            &schema,
3697            &(Arc::new(Int64Array::from(vec![Some(123), None, Some(1)])) as ArrayRef),
3698            &decimal_scalar,
3699            Operator::NotEq,
3700            &BooleanArray::from(vec![Some(true), None, Some(true)]),
3701        )
3702        .unwrap();
3703
3704        // array < scalar
3705        apply_logic_op_arr_scalar(
3706            &schema,
3707            &(Arc::new(Int64Array::from(vec![Some(123), None, Some(124)])) as ArrayRef),
3708            &decimal_scalar,
3709            Operator::Lt,
3710            &BooleanArray::from(vec![Some(true), None, Some(false)]),
3711        )
3712        .unwrap();
3713
3714        // array > scalar
3715        apply_logic_op_arr_scalar(
3716            &schema,
3717            &(Arc::new(Int64Array::from(vec![Some(123), None, Some(124)])) as ArrayRef),
3718            &decimal_scalar,
3719            Operator::Gt,
3720            &BooleanArray::from(vec![Some(false), None, Some(true)]),
3721        )
3722        .unwrap();
3723
3724        let schema =
3725            Arc::new(Schema::new(vec![Field::new("a", DataType::Float64, true)]));
3726        // array == scalar
3727        apply_logic_op_arr_scalar(
3728            &schema,
3729            &(Arc::new(Float64Array::from(vec![Some(123.456), None, Some(123.457)]))
3730                as ArrayRef),
3731            &decimal_scalar,
3732            Operator::Eq,
3733            &BooleanArray::from(vec![Some(true), None, Some(false)]),
3734        )
3735        .unwrap();
3736
3737        // array <= scalar
3738        apply_logic_op_arr_scalar(
3739            &schema,
3740            &(Arc::new(Float64Array::from(vec![
3741                Some(123.456),
3742                None,
3743                Some(123.457),
3744                Some(123.45),
3745            ])) as ArrayRef),
3746            &decimal_scalar,
3747            Operator::LtEq,
3748            &BooleanArray::from(vec![Some(true), None, Some(false), Some(true)]),
3749        )
3750        .unwrap();
3751        // array >= scalar
3752        apply_logic_op_arr_scalar(
3753            &schema,
3754            &(Arc::new(Float64Array::from(vec![
3755                Some(123.456),
3756                None,
3757                Some(123.457),
3758                Some(123.45),
3759            ])) as ArrayRef),
3760            &decimal_scalar,
3761            Operator::GtEq,
3762            &BooleanArray::from(vec![Some(true), None, Some(true), Some(false)]),
3763        )
3764        .unwrap();
3765
3766        let value: i128 = 123;
3767        let decimal_array = Arc::new(create_decimal_array(
3768            &[Some(value), None, Some(value - 1), Some(value + 1)],
3769            10,
3770            0,
3771        )) as ArrayRef;
3772
3773        // comparison array op for decimal array
3774        let schema = Arc::new(Schema::new(vec![
3775            Field::new("a", DataType::Decimal128(10, 0), true),
3776            Field::new("b", DataType::Decimal128(10, 0), true),
3777        ]));
3778        let right_decimal_array = Arc::new(create_decimal_array(
3779            &[
3780                Some(value - 1),
3781                Some(value),
3782                Some(value + 1),
3783                Some(value + 1),
3784            ],
3785            10,
3786            0,
3787        )) as ArrayRef;
3788
3789        apply_logic_op(
3790            &schema,
3791            &decimal_array,
3792            &right_decimal_array,
3793            Operator::Eq,
3794            BooleanArray::from(vec![Some(false), None, Some(false), Some(true)]),
3795        )
3796        .unwrap();
3797
3798        apply_logic_op(
3799            &schema,
3800            &decimal_array,
3801            &right_decimal_array,
3802            Operator::NotEq,
3803            BooleanArray::from(vec![Some(true), None, Some(true), Some(false)]),
3804        )
3805        .unwrap();
3806
3807        apply_logic_op(
3808            &schema,
3809            &decimal_array,
3810            &right_decimal_array,
3811            Operator::Lt,
3812            BooleanArray::from(vec![Some(false), None, Some(true), Some(false)]),
3813        )
3814        .unwrap();
3815
3816        apply_logic_op(
3817            &schema,
3818            &decimal_array,
3819            &right_decimal_array,
3820            Operator::LtEq,
3821            BooleanArray::from(vec![Some(false), None, Some(true), Some(true)]),
3822        )
3823        .unwrap();
3824
3825        apply_logic_op(
3826            &schema,
3827            &decimal_array,
3828            &right_decimal_array,
3829            Operator::Gt,
3830            BooleanArray::from(vec![Some(true), None, Some(false), Some(false)]),
3831        )
3832        .unwrap();
3833
3834        apply_logic_op(
3835            &schema,
3836            &decimal_array,
3837            &right_decimal_array,
3838            Operator::GtEq,
3839            BooleanArray::from(vec![Some(true), None, Some(false), Some(true)]),
3840        )
3841        .unwrap();
3842
3843        // compare decimal array with other array type
3844        let value: i64 = 123;
3845        let schema = Arc::new(Schema::new(vec![
3846            Field::new("a", DataType::Int64, true),
3847            Field::new("b", DataType::Decimal128(10, 0), true),
3848        ]));
3849
3850        let int64_array = Arc::new(Int64Array::from(vec![
3851            Some(value),
3852            Some(value - 1),
3853            Some(value),
3854            Some(value + 1),
3855        ])) as ArrayRef;
3856
3857        // eq: int64array == decimal array
3858        apply_logic_op(
3859            &schema,
3860            &int64_array,
3861            &decimal_array,
3862            Operator::Eq,
3863            BooleanArray::from(vec![Some(true), None, Some(false), Some(true)]),
3864        )
3865        .unwrap();
3866        // neq: int64array != decimal array
3867        apply_logic_op(
3868            &schema,
3869            &int64_array,
3870            &decimal_array,
3871            Operator::NotEq,
3872            BooleanArray::from(vec![Some(false), None, Some(true), Some(false)]),
3873        )
3874        .unwrap();
3875
3876        let schema = Arc::new(Schema::new(vec![
3877            Field::new("a", DataType::Float64, true),
3878            Field::new("b", DataType::Decimal128(10, 2), true),
3879        ]));
3880
3881        let value: i128 = 123;
3882        let decimal_array = Arc::new(create_decimal_array(
3883            &[
3884                Some(value), // 1.23
3885                None,
3886                Some(value - 1), // 1.22
3887                Some(value + 1), // 1.24
3888            ],
3889            10,
3890            2,
3891        )) as ArrayRef;
3892        let float64_array = Arc::new(Float64Array::from(vec![
3893            Some(1.23),
3894            Some(1.22),
3895            Some(1.23),
3896            Some(1.24),
3897        ])) as ArrayRef;
3898        // lt: float64array < decimal array
3899        apply_logic_op(
3900            &schema,
3901            &float64_array,
3902            &decimal_array,
3903            Operator::Lt,
3904            BooleanArray::from(vec![Some(false), None, Some(false), Some(false)]),
3905        )
3906        .unwrap();
3907        // lt_eq: float64array <= decimal array
3908        apply_logic_op(
3909            &schema,
3910            &float64_array,
3911            &decimal_array,
3912            Operator::LtEq,
3913            BooleanArray::from(vec![Some(true), None, Some(false), Some(true)]),
3914        )
3915        .unwrap();
3916        // gt: float64array > decimal array
3917        apply_logic_op(
3918            &schema,
3919            &float64_array,
3920            &decimal_array,
3921            Operator::Gt,
3922            BooleanArray::from(vec![Some(false), None, Some(true), Some(false)]),
3923        )
3924        .unwrap();
3925        apply_logic_op(
3926            &schema,
3927            &float64_array,
3928            &decimal_array,
3929            Operator::GtEq,
3930            BooleanArray::from(vec![Some(true), None, Some(true), Some(true)]),
3931        )
3932        .unwrap();
3933        // is distinct: float64array is distinct decimal array
3934        // TODO: now we do not refactor the `is distinct or is not distinct` rule of coercion.
3935        // traced by https://github.com/apache/datafusion/issues/1590
3936        // the decimal array will be casted to float64array
3937        apply_logic_op(
3938            &schema,
3939            &float64_array,
3940            &decimal_array,
3941            Operator::IsDistinctFrom,
3942            BooleanArray::from(vec![Some(false), Some(true), Some(true), Some(false)]),
3943        )
3944        .unwrap();
3945        // is not distinct
3946        apply_logic_op(
3947            &schema,
3948            &float64_array,
3949            &decimal_array,
3950            Operator::IsNotDistinctFrom,
3951            BooleanArray::from(vec![Some(true), Some(false), Some(false), Some(true)]),
3952        )
3953        .unwrap();
3954
3955        Ok(())
3956    }
3957
3958    fn apply_decimal_arithmetic_op(
3959        schema: &SchemaRef,
3960        left: &ArrayRef,
3961        right: &ArrayRef,
3962        op: Operator,
3963        expected: ArrayRef,
3964    ) -> Result<()> {
3965        let arithmetic_op = binary_op(col("a", schema)?, op, col("b", schema)?, schema)?;
3966        let data: Vec<ArrayRef> = vec![Arc::clone(left), Arc::clone(right)];
3967        let batch = RecordBatch::try_new(Arc::clone(schema), data)?;
3968        let result = arithmetic_op
3969            .evaluate(&batch)?
3970            .into_array(batch.num_rows())
3971            .expect("Failed to convert to array");
3972
3973        assert_eq!(result.as_ref(), expected.as_ref());
3974        Ok(())
3975    }
3976
3977    #[test]
3978    fn arithmetic_decimal_expr_test() -> Result<()> {
3979        let schema = Arc::new(Schema::new(vec![
3980            Field::new("a", DataType::Int32, true),
3981            Field::new("b", DataType::Decimal128(10, 2), true),
3982        ]));
3983        let value: i128 = 123;
3984        let decimal_array = Arc::new(create_decimal_array(
3985            &[
3986                Some(value), // 1.23
3987                None,
3988                Some(value - 1), // 1.22
3989                Some(value + 1), // 1.24
3990            ],
3991            10,
3992            2,
3993        )) as ArrayRef;
3994        let int32_array = Arc::new(Int32Array::from(vec![
3995            Some(123),
3996            Some(122),
3997            Some(123),
3998            Some(124),
3999        ])) as ArrayRef;
4000
4001        // add: Int32array add decimal array
4002        let expect = Arc::new(create_decimal_array(
4003            &[Some(12423), None, Some(12422), Some(12524)],
4004            13,
4005            2,
4006        )) as ArrayRef;
4007        apply_decimal_arithmetic_op(
4008            &schema,
4009            &int32_array,
4010            &decimal_array,
4011            Operator::Plus,
4012            expect,
4013        )
4014        .unwrap();
4015
4016        // subtract: decimal array subtract int32 array
4017        let schema = Arc::new(Schema::new(vec![
4018            Field::new("a", DataType::Decimal128(10, 2), true),
4019            Field::new("b", DataType::Int32, true),
4020        ]));
4021        let expect = Arc::new(create_decimal_array(
4022            &[Some(-12177), None, Some(-12178), Some(-12276)],
4023            13,
4024            2,
4025        )) as ArrayRef;
4026        apply_decimal_arithmetic_op(
4027            &schema,
4028            &decimal_array,
4029            &int32_array,
4030            Operator::Minus,
4031            expect,
4032        )
4033        .unwrap();
4034
4035        // multiply: decimal array multiply int32 array
4036        let expect = Arc::new(create_decimal_array(
4037            &[Some(15129), None, Some(15006), Some(15376)],
4038            21,
4039            2,
4040        )) as ArrayRef;
4041        apply_decimal_arithmetic_op(
4042            &schema,
4043            &decimal_array,
4044            &int32_array,
4045            Operator::Multiply,
4046            expect,
4047        )
4048        .unwrap();
4049
4050        // divide: int32 array divide decimal array
4051        let schema = Arc::new(Schema::new(vec![
4052            Field::new("a", DataType::Int32, true),
4053            Field::new("b", DataType::Decimal128(10, 2), true),
4054        ]));
4055        let expect = Arc::new(create_decimal_array(
4056            &[Some(1000000), None, Some(1008196), Some(1000000)],
4057            16,
4058            4,
4059        )) as ArrayRef;
4060        apply_decimal_arithmetic_op(
4061            &schema,
4062            &int32_array,
4063            &decimal_array,
4064            Operator::Divide,
4065            expect,
4066        )
4067        .unwrap();
4068
4069        // modulus: int32 array modulus decimal array
4070        let schema = Arc::new(Schema::new(vec![
4071            Field::new("a", DataType::Int32, true),
4072            Field::new("b", DataType::Decimal128(10, 2), true),
4073        ]));
4074        let expect = Arc::new(create_decimal_array(
4075            &[Some(000), None, Some(100), Some(000)],
4076            10,
4077            2,
4078        )) as ArrayRef;
4079        apply_decimal_arithmetic_op(
4080            &schema,
4081            &int32_array,
4082            &decimal_array,
4083            Operator::Modulo,
4084            expect,
4085        )
4086        .unwrap();
4087
4088        Ok(())
4089    }
4090
4091    #[test]
4092    fn arithmetic_decimal_float_expr_test() -> Result<()> {
4093        let schema = Arc::new(Schema::new(vec![
4094            Field::new("a", DataType::Float64, true),
4095            Field::new("b", DataType::Decimal128(10, 2), true),
4096        ]));
4097        let value: i128 = 123;
4098        let decimal_array = Arc::new(create_decimal_array(
4099            &[
4100                Some(value), // 1.23
4101                None,
4102                Some(value - 1), // 1.22
4103                Some(value + 1), // 1.24
4104            ],
4105            10,
4106            2,
4107        )) as ArrayRef;
4108        let float64_array = Arc::new(Float64Array::from(vec![
4109            Some(123.0),
4110            Some(122.0),
4111            Some(123.0),
4112            Some(124.0),
4113        ])) as ArrayRef;
4114
4115        // add: float64 array add decimal array
4116        let expect = Arc::new(Float64Array::from(vec![
4117            Some(124.23),
4118            None,
4119            Some(124.22),
4120            Some(125.24),
4121        ])) as ArrayRef;
4122        apply_decimal_arithmetic_op(
4123            &schema,
4124            &float64_array,
4125            &decimal_array,
4126            Operator::Plus,
4127            expect,
4128        )
4129        .unwrap();
4130
4131        // subtract: decimal array subtract float64 array
4132        let schema = Arc::new(Schema::new(vec![
4133            Field::new("a", DataType::Float64, true),
4134            Field::new("b", DataType::Decimal128(10, 2), true),
4135        ]));
4136        let expect = Arc::new(Float64Array::from(vec![
4137            Some(121.77),
4138            None,
4139            Some(121.78),
4140            Some(122.76),
4141        ])) as ArrayRef;
4142        apply_decimal_arithmetic_op(
4143            &schema,
4144            &float64_array,
4145            &decimal_array,
4146            Operator::Minus,
4147            expect,
4148        )
4149        .unwrap();
4150
4151        // multiply: decimal array multiply float64 array
4152        let expect = Arc::new(Float64Array::from(vec![
4153            Some(151.29),
4154            None,
4155            Some(150.06),
4156            Some(153.76),
4157        ])) as ArrayRef;
4158        apply_decimal_arithmetic_op(
4159            &schema,
4160            &float64_array,
4161            &decimal_array,
4162            Operator::Multiply,
4163            expect,
4164        )
4165        .unwrap();
4166
4167        // divide: float64 array divide decimal array
4168        let schema = Arc::new(Schema::new(vec![
4169            Field::new("a", DataType::Float64, true),
4170            Field::new("b", DataType::Decimal128(10, 2), true),
4171        ]));
4172        let expect = Arc::new(Float64Array::from(vec![
4173            Some(100.0),
4174            None,
4175            Some(100.81967213114754),
4176            Some(100.0),
4177        ])) as ArrayRef;
4178        apply_decimal_arithmetic_op(
4179            &schema,
4180            &float64_array,
4181            &decimal_array,
4182            Operator::Divide,
4183            expect,
4184        )
4185        .unwrap();
4186
4187        // modulus: float64 array modulus decimal array
4188        let schema = Arc::new(Schema::new(vec![
4189            Field::new("a", DataType::Float64, true),
4190            Field::new("b", DataType::Decimal128(10, 2), true),
4191        ]));
4192        let expect = Arc::new(Float64Array::from(vec![
4193            Some(1.7763568394002505e-15),
4194            None,
4195            Some(1.0000000000000027),
4196            Some(8.881784197001252e-16),
4197        ])) as ArrayRef;
4198        apply_decimal_arithmetic_op(
4199            &schema,
4200            &float64_array,
4201            &decimal_array,
4202            Operator::Modulo,
4203            expect,
4204        )
4205        .unwrap();
4206
4207        Ok(())
4208    }
4209
4210    #[test]
4211    fn arithmetic_divide_zero() -> Result<()> {
4212        // other data type
4213        let schema = Arc::new(Schema::new(vec![
4214            Field::new("a", DataType::Int32, true),
4215            Field::new("b", DataType::Int32, true),
4216        ]));
4217        let a = Arc::new(Int32Array::from(vec![100]));
4218        let b = Arc::new(Int32Array::from(vec![0]));
4219
4220        let err = apply_arithmetic::<Int32Type>(
4221            schema,
4222            vec![a, b],
4223            Operator::Divide,
4224            Int32Array::from(vec![Some(4), Some(8), Some(16), Some(32), Some(64)]),
4225        )
4226        .unwrap_err();
4227
4228        let _expected = plan_datafusion_err!("Divide by zero");
4229
4230        assert!(matches!(err, ref _expected), "{err}");
4231
4232        // decimal
4233        let schema = Arc::new(Schema::new(vec![
4234            Field::new("a", DataType::Decimal128(25, 3), true),
4235            Field::new("b", DataType::Decimal128(25, 3), true),
4236        ]));
4237        let left_decimal_array = Arc::new(create_decimal_array(&[Some(1234567)], 25, 3));
4238        let right_decimal_array = Arc::new(create_decimal_array(&[Some(0)], 25, 3));
4239
4240        let err = apply_arithmetic::<Decimal128Type>(
4241            schema,
4242            vec![left_decimal_array, right_decimal_array],
4243            Operator::Divide,
4244            create_decimal_array(
4245                &[Some(12345670000000000000000000000000000), None],
4246                38,
4247                29,
4248            ),
4249        )
4250        .unwrap_err();
4251
4252        assert!(matches!(err, ref _expected), "{err}");
4253
4254        Ok(())
4255    }
4256
4257    #[test]
4258    fn bitwise_array_test() -> Result<()> {
4259        let left = Arc::new(Int32Array::from(vec![Some(12), None, Some(11)])) as ArrayRef;
4260        let right =
4261            Arc::new(Int32Array::from(vec![Some(1), Some(3), Some(7)])) as ArrayRef;
4262        let mut result = bitwise_and_dyn(Arc::clone(&left), Arc::clone(&right))?;
4263        let expected = Int32Array::from(vec![Some(0), None, Some(3)]);
4264        assert_eq!(result.as_ref(), &expected);
4265
4266        result = bitwise_or_dyn(Arc::clone(&left), Arc::clone(&right))?;
4267        let expected = Int32Array::from(vec![Some(13), None, Some(15)]);
4268        assert_eq!(result.as_ref(), &expected);
4269
4270        result = bitwise_xor_dyn(Arc::clone(&left), Arc::clone(&right))?;
4271        let expected = Int32Array::from(vec![Some(13), None, Some(12)]);
4272        assert_eq!(result.as_ref(), &expected);
4273
4274        let left =
4275            Arc::new(UInt32Array::from(vec![Some(12), None, Some(11)])) as ArrayRef;
4276        let right =
4277            Arc::new(UInt32Array::from(vec![Some(1), Some(3), Some(7)])) as ArrayRef;
4278        let mut result = bitwise_and_dyn(Arc::clone(&left), Arc::clone(&right))?;
4279        let expected = UInt32Array::from(vec![Some(0), None, Some(3)]);
4280        assert_eq!(result.as_ref(), &expected);
4281
4282        result = bitwise_or_dyn(Arc::clone(&left), Arc::clone(&right))?;
4283        let expected = UInt32Array::from(vec![Some(13), None, Some(15)]);
4284        assert_eq!(result.as_ref(), &expected);
4285
4286        result = bitwise_xor_dyn(Arc::clone(&left), Arc::clone(&right))?;
4287        let expected = UInt32Array::from(vec![Some(13), None, Some(12)]);
4288        assert_eq!(result.as_ref(), &expected);
4289
4290        Ok(())
4291    }
4292
4293    #[test]
4294    fn bitwise_shift_array_test() -> Result<()> {
4295        let input = Arc::new(Int32Array::from(vec![Some(2), None, Some(10)])) as ArrayRef;
4296        let modules =
4297            Arc::new(Int32Array::from(vec![Some(2), Some(4), Some(8)])) as ArrayRef;
4298        let mut result =
4299            bitwise_shift_left_dyn(Arc::clone(&input), Arc::clone(&modules))?;
4300
4301        let expected = Int32Array::from(vec![Some(8), None, Some(2560)]);
4302        assert_eq!(result.as_ref(), &expected);
4303
4304        result = bitwise_shift_right_dyn(Arc::clone(&result), Arc::clone(&modules))?;
4305        assert_eq!(result.as_ref(), &input);
4306
4307        let input =
4308            Arc::new(UInt32Array::from(vec![Some(2), None, Some(10)])) as ArrayRef;
4309        let modules =
4310            Arc::new(UInt32Array::from(vec![Some(2), Some(4), Some(8)])) as ArrayRef;
4311        let mut result =
4312            bitwise_shift_left_dyn(Arc::clone(&input), Arc::clone(&modules))?;
4313
4314        let expected = UInt32Array::from(vec![Some(8), None, Some(2560)]);
4315        assert_eq!(result.as_ref(), &expected);
4316
4317        result = bitwise_shift_right_dyn(Arc::clone(&result), Arc::clone(&modules))?;
4318        assert_eq!(result.as_ref(), &input);
4319        Ok(())
4320    }
4321
4322    #[test]
4323    fn bitwise_shift_array_overflow_test() -> Result<()> {
4324        let input = Arc::new(Int32Array::from(vec![Some(2)])) as ArrayRef;
4325        let modules = Arc::new(Int32Array::from(vec![Some(100)])) as ArrayRef;
4326        let result = bitwise_shift_left_dyn(Arc::clone(&input), Arc::clone(&modules))?;
4327
4328        let expected = Int32Array::from(vec![Some(32)]);
4329        assert_eq!(result.as_ref(), &expected);
4330
4331        let input = Arc::new(UInt32Array::from(vec![Some(2)])) as ArrayRef;
4332        let modules = Arc::new(UInt32Array::from(vec![Some(100)])) as ArrayRef;
4333        let result = bitwise_shift_left_dyn(Arc::clone(&input), Arc::clone(&modules))?;
4334
4335        let expected = UInt32Array::from(vec![Some(32)]);
4336        assert_eq!(result.as_ref(), &expected);
4337        Ok(())
4338    }
4339
4340    #[test]
4341    fn bitwise_scalar_test() -> Result<()> {
4342        let left = Arc::new(Int32Array::from(vec![Some(12), None, Some(11)])) as ArrayRef;
4343        let right = ScalarValue::from(3i32);
4344        let mut result = bitwise_and_dyn_scalar(&left, right.clone()).unwrap()?;
4345        let expected = Int32Array::from(vec![Some(0), None, Some(3)]);
4346        assert_eq!(result.as_ref(), &expected);
4347
4348        result = bitwise_or_dyn_scalar(&left, right.clone()).unwrap()?;
4349        let expected = Int32Array::from(vec![Some(15), None, Some(11)]);
4350        assert_eq!(result.as_ref(), &expected);
4351
4352        result = bitwise_xor_dyn_scalar(&left, right).unwrap()?;
4353        let expected = Int32Array::from(vec![Some(15), None, Some(8)]);
4354        assert_eq!(result.as_ref(), &expected);
4355
4356        let left =
4357            Arc::new(UInt32Array::from(vec![Some(12), None, Some(11)])) as ArrayRef;
4358        let right = ScalarValue::from(3u32);
4359        let mut result = bitwise_and_dyn_scalar(&left, right.clone()).unwrap()?;
4360        let expected = UInt32Array::from(vec![Some(0), None, Some(3)]);
4361        assert_eq!(result.as_ref(), &expected);
4362
4363        result = bitwise_or_dyn_scalar(&left, right.clone()).unwrap()?;
4364        let expected = UInt32Array::from(vec![Some(15), None, Some(11)]);
4365        assert_eq!(result.as_ref(), &expected);
4366
4367        result = bitwise_xor_dyn_scalar(&left, right).unwrap()?;
4368        let expected = UInt32Array::from(vec![Some(15), None, Some(8)]);
4369        assert_eq!(result.as_ref(), &expected);
4370        Ok(())
4371    }
4372
4373    #[test]
4374    fn bitwise_shift_scalar_test() -> Result<()> {
4375        let input = Arc::new(Int32Array::from(vec![Some(2), None, Some(4)])) as ArrayRef;
4376        let module = ScalarValue::from(10i32);
4377        let mut result =
4378            bitwise_shift_left_dyn_scalar(&input, module.clone()).unwrap()?;
4379
4380        let expected = Int32Array::from(vec![Some(2048), None, Some(4096)]);
4381        assert_eq!(result.as_ref(), &expected);
4382
4383        result = bitwise_shift_right_dyn_scalar(&result, module).unwrap()?;
4384        assert_eq!(result.as_ref(), &input);
4385
4386        let input = Arc::new(UInt32Array::from(vec![Some(2), None, Some(4)])) as ArrayRef;
4387        let module = ScalarValue::from(10u32);
4388        let mut result =
4389            bitwise_shift_left_dyn_scalar(&input, module.clone()).unwrap()?;
4390
4391        let expected = UInt32Array::from(vec![Some(2048), None, Some(4096)]);
4392        assert_eq!(result.as_ref(), &expected);
4393
4394        result = bitwise_shift_right_dyn_scalar(&result, module).unwrap()?;
4395        assert_eq!(result.as_ref(), &input);
4396        Ok(())
4397    }
4398
4399    #[test]
4400    fn test_display_and_or_combo() {
4401        let expr = BinaryExpr::new(
4402            Arc::new(BinaryExpr::new(
4403                lit(ScalarValue::from(1)),
4404                Operator::And,
4405                lit(ScalarValue::from(2)),
4406            )),
4407            Operator::And,
4408            Arc::new(BinaryExpr::new(
4409                lit(ScalarValue::from(3)),
4410                Operator::And,
4411                lit(ScalarValue::from(4)),
4412            )),
4413        );
4414        assert_eq!(expr.to_string(), "1 AND 2 AND 3 AND 4");
4415
4416        let expr = BinaryExpr::new(
4417            Arc::new(BinaryExpr::new(
4418                lit(ScalarValue::from(1)),
4419                Operator::Or,
4420                lit(ScalarValue::from(2)),
4421            )),
4422            Operator::Or,
4423            Arc::new(BinaryExpr::new(
4424                lit(ScalarValue::from(3)),
4425                Operator::Or,
4426                lit(ScalarValue::from(4)),
4427            )),
4428        );
4429        assert_eq!(expr.to_string(), "1 OR 2 OR 3 OR 4");
4430
4431        let expr = BinaryExpr::new(
4432            Arc::new(BinaryExpr::new(
4433                lit(ScalarValue::from(1)),
4434                Operator::And,
4435                lit(ScalarValue::from(2)),
4436            )),
4437            Operator::Or,
4438            Arc::new(BinaryExpr::new(
4439                lit(ScalarValue::from(3)),
4440                Operator::And,
4441                lit(ScalarValue::from(4)),
4442            )),
4443        );
4444        assert_eq!(expr.to_string(), "1 AND 2 OR 3 AND 4");
4445
4446        let expr = BinaryExpr::new(
4447            Arc::new(BinaryExpr::new(
4448                lit(ScalarValue::from(1)),
4449                Operator::Or,
4450                lit(ScalarValue::from(2)),
4451            )),
4452            Operator::And,
4453            Arc::new(BinaryExpr::new(
4454                lit(ScalarValue::from(3)),
4455                Operator::Or,
4456                lit(ScalarValue::from(4)),
4457            )),
4458        );
4459        assert_eq!(expr.to_string(), "(1 OR 2) AND (3 OR 4)");
4460    }
4461
4462    #[test]
4463    fn test_to_result_type_array() {
4464        let values = Arc::new(Int32Array::from(vec![1, 2, 3, 4]));
4465        let keys = Int8Array::from(vec![Some(0), None, Some(2), Some(3)]);
4466        let dictionary =
4467            Arc::new(DictionaryArray::try_new(keys, values).unwrap()) as ArrayRef;
4468
4469        // Casting Dictionary to Int32
4470        let casted = to_result_type_array(
4471            &Operator::Plus,
4472            Arc::clone(&dictionary),
4473            &DataType::Int32,
4474        )
4475        .unwrap();
4476        assert_eq!(
4477            &casted,
4478            &(Arc::new(Int32Array::from(vec![Some(1), None, Some(3), Some(4)]))
4479                as ArrayRef)
4480        );
4481
4482        // Array has same datatype as result type, no casting
4483        let casted = to_result_type_array(
4484            &Operator::Plus,
4485            Arc::clone(&dictionary),
4486            dictionary.data_type(),
4487        )
4488        .unwrap();
4489        assert_eq!(&casted, &dictionary);
4490
4491        // Not numerical operator, no casting
4492        let casted = to_result_type_array(
4493            &Operator::Eq,
4494            Arc::clone(&dictionary),
4495            &DataType::Int32,
4496        )
4497        .unwrap();
4498        assert_eq!(&casted, &dictionary);
4499    }
4500
4501    #[test]
4502    fn test_add_with_overflow() -> Result<()> {
4503        // create test data
4504        let l = Arc::new(Int32Array::from(vec![1, i32::MAX]));
4505        let r = Arc::new(Int32Array::from(vec![2, 1]));
4506        let schema = Arc::new(Schema::new(vec![
4507            Field::new("l", DataType::Int32, false),
4508            Field::new("r", DataType::Int32, false),
4509        ]));
4510        let batch = RecordBatch::try_new(schema, vec![l, r])?;
4511
4512        // create expression
4513        let expr = BinaryExpr::new(
4514            Arc::new(Column::new("l", 0)),
4515            Operator::Plus,
4516            Arc::new(Column::new("r", 1)),
4517        )
4518        .with_fail_on_overflow(true);
4519
4520        // evaluate expression
4521        let result = expr.evaluate(&batch);
4522        assert!(
4523            result
4524                .err()
4525                .unwrap()
4526                .to_string()
4527                .contains("Overflow happened on: 2147483647 + 1")
4528        );
4529        Ok(())
4530    }
4531
4532    #[test]
4533    fn test_subtract_with_overflow() -> Result<()> {
4534        // create test data
4535        let l = Arc::new(Int32Array::from(vec![1, i32::MIN]));
4536        let r = Arc::new(Int32Array::from(vec![2, 1]));
4537        let schema = Arc::new(Schema::new(vec![
4538            Field::new("l", DataType::Int32, false),
4539            Field::new("r", DataType::Int32, false),
4540        ]));
4541        let batch = RecordBatch::try_new(schema, vec![l, r])?;
4542
4543        // create expression
4544        let expr = BinaryExpr::new(
4545            Arc::new(Column::new("l", 0)),
4546            Operator::Minus,
4547            Arc::new(Column::new("r", 1)),
4548        )
4549        .with_fail_on_overflow(true);
4550
4551        // evaluate expression
4552        let result = expr.evaluate(&batch);
4553        assert!(
4554            result
4555                .err()
4556                .unwrap()
4557                .to_string()
4558                .contains("Overflow happened on: -2147483648 - 1")
4559        );
4560        Ok(())
4561    }
4562
4563    #[test]
4564    fn test_mul_with_overflow() -> Result<()> {
4565        // create test data
4566        let l = Arc::new(Int32Array::from(vec![1, i32::MAX]));
4567        let r = Arc::new(Int32Array::from(vec![2, 2]));
4568        let schema = Arc::new(Schema::new(vec![
4569            Field::new("l", DataType::Int32, false),
4570            Field::new("r", DataType::Int32, false),
4571        ]));
4572        let batch = RecordBatch::try_new(schema, vec![l, r])?;
4573
4574        // create expression
4575        let expr = BinaryExpr::new(
4576            Arc::new(Column::new("l", 0)),
4577            Operator::Multiply,
4578            Arc::new(Column::new("r", 1)),
4579        )
4580        .with_fail_on_overflow(true);
4581
4582        // evaluate expression
4583        let result = expr.evaluate(&batch);
4584        assert!(
4585            result
4586                .err()
4587                .unwrap()
4588                .to_string()
4589                .contains("Overflow happened on: 2147483647 * 2")
4590        );
4591        Ok(())
4592    }
4593
4594    /// Test helper for SIMILAR TO binary operation
4595    fn apply_similar_to(
4596        schema: &SchemaRef,
4597        va: Vec<&str>,
4598        vb: Vec<&str>,
4599        negated: bool,
4600        case_insensitive: bool,
4601        expected: &BooleanArray,
4602    ) -> Result<()> {
4603        let a = StringArray::from(va);
4604        let b = StringArray::from(vb);
4605        let op = similar_to(
4606            negated,
4607            case_insensitive,
4608            col("a", schema)?,
4609            col("b", schema)?,
4610        )?;
4611        let batch =
4612            RecordBatch::try_new(Arc::clone(schema), vec![Arc::new(a), Arc::new(b)])?;
4613        let result = op
4614            .evaluate(&batch)?
4615            .into_array(batch.num_rows())
4616            .expect("Failed to convert to array");
4617        assert_eq!(result.as_ref(), expected);
4618
4619        Ok(())
4620    }
4621
4622    #[test]
4623    fn test_similar_to() {
4624        let schema = Arc::new(Schema::new(vec![
4625            Field::new("a", DataType::Utf8, false),
4626            Field::new("b", DataType::Utf8, false),
4627        ]));
4628
4629        let expected = [Some(true), Some(false)].iter().collect();
4630        // case-sensitive
4631        apply_similar_to(
4632            &schema,
4633            vec!["hello world", "Hello World"],
4634            vec!["hello.*", "hello.*"],
4635            false,
4636            false,
4637            &expected,
4638        )
4639        .unwrap();
4640        // case-insensitive
4641        apply_similar_to(
4642            &schema,
4643            vec!["hello world", "bye"],
4644            vec!["hello.*", "hello.*"],
4645            false,
4646            true,
4647            &expected,
4648        )
4649        .unwrap();
4650    }
4651
4652    pub fn binary_expr(
4653        left: Arc<dyn PhysicalExpr>,
4654        op: Operator,
4655        right: Arc<dyn PhysicalExpr>,
4656        schema: &Schema,
4657    ) -> Result<BinaryExpr> {
4658        Ok(binary_op(left, op, right, schema)?
4659            .as_any()
4660            .downcast_ref::<BinaryExpr>()
4661            .unwrap()
4662            .clone())
4663    }
4664
4665    /// Test for Uniform-Uniform, Unknown-Uniform, Uniform-Unknown and Unknown-Unknown evaluation.
4666    #[test]
4667    fn test_evaluate_statistics_combination_of_range_holders() -> Result<()> {
4668        let schema = &Schema::new(vec![Field::new("a", DataType::Float64, false)]);
4669        let a = Arc::new(Column::new("a", 0)) as _;
4670        let b = lit(ScalarValue::from(12.0));
4671
4672        let left_interval = Interval::make(Some(0.0), Some(12.0))?;
4673        let right_interval = Interval::make(Some(12.0), Some(36.0))?;
4674        let (left_mean, right_mean) = (ScalarValue::from(6.0), ScalarValue::from(24.0));
4675        let (left_med, right_med) = (ScalarValue::from(6.0), ScalarValue::from(24.0));
4676
4677        for children in [
4678            vec![
4679                &Distribution::new_uniform(left_interval.clone())?,
4680                &Distribution::new_uniform(right_interval.clone())?,
4681            ],
4682            vec![
4683                &Distribution::new_generic(
4684                    left_mean.clone(),
4685                    left_med.clone(),
4686                    ScalarValue::Float64(None),
4687                    left_interval.clone(),
4688                )?,
4689                &Distribution::new_uniform(right_interval.clone())?,
4690            ],
4691            vec![
4692                &Distribution::new_uniform(right_interval.clone())?,
4693                &Distribution::new_generic(
4694                    right_mean.clone(),
4695                    right_med.clone(),
4696                    ScalarValue::Float64(None),
4697                    right_interval.clone(),
4698                )?,
4699            ],
4700            vec![
4701                &Distribution::new_generic(
4702                    left_mean.clone(),
4703                    left_med.clone(),
4704                    ScalarValue::Float64(None),
4705                    left_interval.clone(),
4706                )?,
4707                &Distribution::new_generic(
4708                    right_mean.clone(),
4709                    right_med.clone(),
4710                    ScalarValue::Float64(None),
4711                    right_interval.clone(),
4712                )?,
4713            ],
4714        ] {
4715            let ops = vec![
4716                Operator::Plus,
4717                Operator::Minus,
4718                Operator::Multiply,
4719                Operator::Divide,
4720            ];
4721
4722            for op in ops {
4723                let expr = binary_expr(Arc::clone(&a), op, Arc::clone(&b), schema)?;
4724                assert_eq!(
4725                    expr.evaluate_statistics(&children)?,
4726                    new_generic_from_binary_op(&op, children[0], children[1])?
4727                );
4728            }
4729        }
4730        Ok(())
4731    }
4732
4733    #[test]
4734    fn test_evaluate_statistics_bernoulli() -> Result<()> {
4735        let schema = &Schema::new(vec![
4736            Field::new("a", DataType::Int64, false),
4737            Field::new("b", DataType::Int64, false),
4738        ]);
4739        let a = Arc::new(Column::new("a", 0)) as _;
4740        let b = Arc::new(Column::new("b", 1)) as _;
4741        let eq = Arc::new(binary_expr(
4742            Arc::clone(&a),
4743            Operator::Eq,
4744            Arc::clone(&b),
4745            schema,
4746        )?);
4747        let neq = Arc::new(binary_expr(a, Operator::NotEq, b, schema)?);
4748
4749        let left_stat = &Distribution::new_uniform(Interval::make(Some(0), Some(7))?)?;
4750        let right_stat = &Distribution::new_uniform(Interval::make(Some(4), Some(11))?)?;
4751
4752        // Intervals: [0, 7], [4, 11].
4753        // The intersection is [4, 7], so the probability of equality is 4 / 64 = 1 / 16.
4754        assert_eq!(
4755            eq.evaluate_statistics(&[left_stat, right_stat])?,
4756            Distribution::new_bernoulli(ScalarValue::from(1.0 / 16.0))?
4757        );
4758
4759        // The probability of being distinct is 1 - 1 / 16 = 15 / 16.
4760        assert_eq!(
4761            neq.evaluate_statistics(&[left_stat, right_stat])?,
4762            Distribution::new_bernoulli(ScalarValue::from(15.0 / 16.0))?
4763        );
4764
4765        Ok(())
4766    }
4767
4768    #[test]
4769    fn test_propagate_statistics_combination_of_range_holders_arithmetic() -> Result<()> {
4770        let schema = &Schema::new(vec![Field::new("a", DataType::Float64, false)]);
4771        let a = Arc::new(Column::new("a", 0)) as _;
4772        let b = lit(ScalarValue::from(12.0));
4773
4774        let left_interval = Interval::make(Some(0.0), Some(12.0))?;
4775        let right_interval = Interval::make(Some(12.0), Some(36.0))?;
4776
4777        let parent = Distribution::new_uniform(Interval::make(Some(-432.), Some(432.))?)?;
4778        let children = vec![
4779            vec![
4780                Distribution::new_uniform(left_interval.clone())?,
4781                Distribution::new_uniform(right_interval.clone())?,
4782            ],
4783            vec![
4784                Distribution::new_generic(
4785                    ScalarValue::from(6.),
4786                    ScalarValue::from(6.),
4787                    ScalarValue::Float64(None),
4788                    left_interval.clone(),
4789                )?,
4790                Distribution::new_uniform(right_interval.clone())?,
4791            ],
4792            vec![
4793                Distribution::new_uniform(left_interval.clone())?,
4794                Distribution::new_generic(
4795                    ScalarValue::from(12.),
4796                    ScalarValue::from(12.),
4797                    ScalarValue::Float64(None),
4798                    right_interval.clone(),
4799                )?,
4800            ],
4801            vec![
4802                Distribution::new_generic(
4803                    ScalarValue::from(6.),
4804                    ScalarValue::from(6.),
4805                    ScalarValue::Float64(None),
4806                    left_interval.clone(),
4807                )?,
4808                Distribution::new_generic(
4809                    ScalarValue::from(12.),
4810                    ScalarValue::from(12.),
4811                    ScalarValue::Float64(None),
4812                    right_interval.clone(),
4813                )?,
4814            ],
4815        ];
4816
4817        let ops = vec![
4818            Operator::Plus,
4819            Operator::Minus,
4820            Operator::Multiply,
4821            Operator::Divide,
4822        ];
4823
4824        for child_view in children {
4825            let child_refs = child_view.iter().collect::<Vec<_>>();
4826            for op in &ops {
4827                let expr = binary_expr(Arc::clone(&a), *op, Arc::clone(&b), schema)?;
4828                assert_eq!(
4829                    expr.propagate_statistics(&parent, child_refs.as_slice())?,
4830                    Some(child_view.clone())
4831                );
4832            }
4833        }
4834        Ok(())
4835    }
4836
4837    #[test]
4838    fn test_propagate_statistics_combination_of_range_holders_comparison() -> Result<()> {
4839        let schema = &Schema::new(vec![Field::new("a", DataType::Float64, false)]);
4840        let a = Arc::new(Column::new("a", 0)) as _;
4841        let b = lit(ScalarValue::from(12.0));
4842
4843        let left_interval = Interval::make(Some(0.0), Some(12.0))?;
4844        let right_interval = Interval::make(Some(6.0), Some(18.0))?;
4845
4846        let one = ScalarValue::from(1.0);
4847        let parent = Distribution::new_bernoulli(one)?;
4848        let children = vec![
4849            vec![
4850                Distribution::new_uniform(left_interval.clone())?,
4851                Distribution::new_uniform(right_interval.clone())?,
4852            ],
4853            vec![
4854                Distribution::new_generic(
4855                    ScalarValue::from(6.),
4856                    ScalarValue::from(6.),
4857                    ScalarValue::Float64(None),
4858                    left_interval.clone(),
4859                )?,
4860                Distribution::new_uniform(right_interval.clone())?,
4861            ],
4862            vec![
4863                Distribution::new_uniform(left_interval.clone())?,
4864                Distribution::new_generic(
4865                    ScalarValue::from(12.),
4866                    ScalarValue::from(12.),
4867                    ScalarValue::Float64(None),
4868                    right_interval.clone(),
4869                )?,
4870            ],
4871            vec![
4872                Distribution::new_generic(
4873                    ScalarValue::from(6.),
4874                    ScalarValue::from(6.),
4875                    ScalarValue::Float64(None),
4876                    left_interval.clone(),
4877                )?,
4878                Distribution::new_generic(
4879                    ScalarValue::from(12.),
4880                    ScalarValue::from(12.),
4881                    ScalarValue::Float64(None),
4882                    right_interval.clone(),
4883                )?,
4884            ],
4885        ];
4886
4887        let ops = vec![
4888            Operator::Eq,
4889            Operator::Gt,
4890            Operator::GtEq,
4891            Operator::Lt,
4892            Operator::LtEq,
4893        ];
4894
4895        for child_view in children {
4896            let child_refs = child_view.iter().collect::<Vec<_>>();
4897            for op in &ops {
4898                let expr = binary_expr(Arc::clone(&a), *op, Arc::clone(&b), schema)?;
4899                assert!(
4900                    expr.propagate_statistics(&parent, child_refs.as_slice())?
4901                        .is_some()
4902                );
4903            }
4904        }
4905
4906        Ok(())
4907    }
4908
4909    #[test]
4910    fn test_fmt_sql() -> Result<()> {
4911        let schema = Schema::new(vec![
4912            Field::new("a", DataType::Int32, false),
4913            Field::new("b", DataType::Int32, false),
4914        ]);
4915
4916        // Test basic binary expressions
4917        let simple_expr = binary_expr(
4918            col("a", &schema)?,
4919            Operator::Plus,
4920            col("b", &schema)?,
4921            &schema,
4922        )?;
4923        let display_string = simple_expr.to_string();
4924        assert_eq!(display_string, "a@0 + b@1");
4925        let sql_string = fmt_sql(&simple_expr).to_string();
4926        assert_eq!(sql_string, "a + b");
4927
4928        // Test nested expressions with different operator precedence
4929        let nested_expr = binary_expr(
4930            Arc::new(binary_expr(
4931                col("a", &schema)?,
4932                Operator::Plus,
4933                col("b", &schema)?,
4934                &schema,
4935            )?),
4936            Operator::Multiply,
4937            col("b", &schema)?,
4938            &schema,
4939        )?;
4940        let display_string = nested_expr.to_string();
4941        assert_eq!(display_string, "(a@0 + b@1) * b@1");
4942        let sql_string = fmt_sql(&nested_expr).to_string();
4943        assert_eq!(sql_string, "(a + b) * b");
4944
4945        // Test nested expressions with same operator precedence
4946        let nested_same_prec = binary_expr(
4947            Arc::new(binary_expr(
4948                col("a", &schema)?,
4949                Operator::Plus,
4950                col("b", &schema)?,
4951                &schema,
4952            )?),
4953            Operator::Plus,
4954            col("b", &schema)?,
4955            &schema,
4956        )?;
4957        let display_string = nested_same_prec.to_string();
4958        assert_eq!(display_string, "a@0 + b@1 + b@1");
4959        let sql_string = fmt_sql(&nested_same_prec).to_string();
4960        assert_eq!(sql_string, "a + b + b");
4961
4962        // Test with literals
4963        let lit_expr = binary_expr(
4964            col("a", &schema)?,
4965            Operator::Eq,
4966            lit(ScalarValue::Int32(Some(42))),
4967            &schema,
4968        )?;
4969        let display_string = lit_expr.to_string();
4970        assert_eq!(display_string, "a@0 = 42");
4971        let sql_string = fmt_sql(&lit_expr).to_string();
4972        assert_eq!(sql_string, "a = 42");
4973
4974        Ok(())
4975    }
4976
4977    #[test]
4978    fn test_check_short_circuit() {
4979        // Test with non-nullable arrays
4980        let schema = Arc::new(Schema::new(vec![
4981            Field::new("a", DataType::Int32, false),
4982            Field::new("b", DataType::Int32, false),
4983        ]));
4984        let a_array = Int32Array::from(vec![1, 3, 4, 5, 6]);
4985        let b_array = Int32Array::from(vec![1, 2, 3, 4, 5]);
4986        let batch = RecordBatch::try_new(
4987            Arc::clone(&schema),
4988            vec![Arc::new(a_array), Arc::new(b_array)],
4989        )
4990        .unwrap();
4991
4992        // op: AND left: all false
4993        let left_expr = logical2physical(&logical_col("a").eq(expr_lit(2)), &schema);
4994        let left_value = left_expr.evaluate(&batch).unwrap();
4995        assert!(matches!(
4996            check_short_circuit(&left_value, &Operator::And),
4997            ShortCircuitStrategy::ReturnLeft
4998        ));
4999
5000        // op: AND left: not all false
5001        let left_expr = logical2physical(&logical_col("a").eq(expr_lit(3)), &schema);
5002        let left_value = left_expr.evaluate(&batch).unwrap();
5003        let ColumnarValue::Array(array) = &left_value else {
5004            panic!("Expected ColumnarValue::Array");
5005        };
5006        let ShortCircuitStrategy::PreSelection(value) =
5007            check_short_circuit(&left_value, &Operator::And)
5008        else {
5009            panic!("Expected ShortCircuitStrategy::PreSelection");
5010        };
5011        let expected_boolean_arr: Vec<_> =
5012            as_boolean_array(array).unwrap().iter().collect();
5013        let boolean_arr: Vec<_> = value.iter().collect();
5014        assert_eq!(expected_boolean_arr, boolean_arr);
5015
5016        // op: OR left: all true
5017        let left_expr = logical2physical(&logical_col("a").gt(expr_lit(0)), &schema);
5018        let left_value = left_expr.evaluate(&batch).unwrap();
5019        assert!(matches!(
5020            check_short_circuit(&left_value, &Operator::Or),
5021            ShortCircuitStrategy::ReturnLeft
5022        ));
5023
5024        // op: OR left: not all true
5025        let left_expr: Arc<dyn PhysicalExpr> =
5026            logical2physical(&logical_col("a").gt(expr_lit(2)), &schema);
5027        let left_value = left_expr.evaluate(&batch).unwrap();
5028        assert!(matches!(
5029            check_short_circuit(&left_value, &Operator::Or),
5030            ShortCircuitStrategy::None
5031        ));
5032
5033        // Test with nullable arrays and null values
5034        let schema_nullable = Arc::new(Schema::new(vec![
5035            Field::new("c", DataType::Boolean, true),
5036            Field::new("d", DataType::Boolean, true),
5037        ]));
5038
5039        // Create arrays with null values
5040        let c_array = Arc::new(BooleanArray::from(vec![
5041            Some(true),
5042            Some(false),
5043            None,
5044            Some(true),
5045            None,
5046        ])) as ArrayRef;
5047        let d_array = Arc::new(BooleanArray::from(vec![
5048            Some(false),
5049            Some(true),
5050            Some(false),
5051            None,
5052            Some(true),
5053        ])) as ArrayRef;
5054
5055        let batch_nullable = RecordBatch::try_new(
5056            Arc::clone(&schema_nullable),
5057            vec![Arc::clone(&c_array), Arc::clone(&d_array)],
5058        )
5059        .unwrap();
5060
5061        // Case: Mixed values with nulls - shouldn't short-circuit for AND
5062        let mixed_nulls = logical2physical(&logical_col("c"), &schema_nullable);
5063        let mixed_nulls_value = mixed_nulls.evaluate(&batch_nullable).unwrap();
5064        assert!(matches!(
5065            check_short_circuit(&mixed_nulls_value, &Operator::And),
5066            ShortCircuitStrategy::None
5067        ));
5068
5069        // Case: Mixed values with nulls - shouldn't short-circuit for OR
5070        assert!(matches!(
5071            check_short_circuit(&mixed_nulls_value, &Operator::Or),
5072            ShortCircuitStrategy::None
5073        ));
5074
5075        // Test with all nulls
5076        let all_nulls = Arc::new(BooleanArray::from(vec![None, None, None])) as ArrayRef;
5077        let null_batch = RecordBatch::try_new(
5078            Arc::new(Schema::new(vec![Field::new("e", DataType::Boolean, true)])),
5079            vec![all_nulls],
5080        )
5081        .unwrap();
5082
5083        let null_expr = logical2physical(&logical_col("e"), &null_batch.schema());
5084        let null_value = null_expr.evaluate(&null_batch).unwrap();
5085
5086        // All nulls shouldn't short-circuit for AND or OR
5087        assert!(matches!(
5088            check_short_circuit(&null_value, &Operator::And),
5089            ShortCircuitStrategy::None
5090        ));
5091        assert!(matches!(
5092            check_short_circuit(&null_value, &Operator::Or),
5093            ShortCircuitStrategy::None
5094        ));
5095
5096        // Test with scalar values
5097        // Scalar true
5098        let scalar_true = ColumnarValue::Scalar(ScalarValue::Boolean(Some(true)));
5099        assert!(matches!(
5100            check_short_circuit(&scalar_true, &Operator::Or),
5101            ShortCircuitStrategy::ReturnLeft
5102        )); // Should short-circuit OR
5103        assert!(matches!(
5104            check_short_circuit(&scalar_true, &Operator::And),
5105            ShortCircuitStrategy::ReturnRight
5106        )); // Should return the RHS for AND
5107
5108        // Scalar false
5109        let scalar_false = ColumnarValue::Scalar(ScalarValue::Boolean(Some(false)));
5110        assert!(matches!(
5111            check_short_circuit(&scalar_false, &Operator::And),
5112            ShortCircuitStrategy::ReturnLeft
5113        )); // Should short-circuit AND
5114        assert!(matches!(
5115            check_short_circuit(&scalar_false, &Operator::Or),
5116            ShortCircuitStrategy::ReturnRight
5117        )); // Should return the RHS for OR
5118
5119        // Scalar null
5120        let scalar_null = ColumnarValue::Scalar(ScalarValue::Boolean(None));
5121        assert!(matches!(
5122            check_short_circuit(&scalar_null, &Operator::And),
5123            ShortCircuitStrategy::None
5124        ));
5125        assert!(matches!(
5126            check_short_circuit(&scalar_null, &Operator::Or),
5127            ShortCircuitStrategy::None
5128        ));
5129    }
5130
5131    /// Test for [pre_selection_scatter]
5132    /// Since [check_short_circuit] ensures that the left side does not contain null and is neither all_true nor all_false, as well as not being empty,
5133    /// the following tests have been designed:
5134    /// 1. Test sparse left with interleaved true/false
5135    /// 2. Test multiple consecutive true blocks
5136    /// 3. Test multiple consecutive true blocks
5137    /// 4. Test single true at first position
5138    /// 5. Test single true at last position
5139    /// 6. Test nulls in right array
5140    #[test]
5141    fn test_pre_selection_scatter() {
5142        fn create_bool_array(bools: Vec<bool>) -> BooleanArray {
5143            BooleanArray::from(bools.into_iter().map(Some).collect::<Vec<_>>())
5144        }
5145        // Test sparse left with interleaved true/false
5146        {
5147            // Left: [T, F, T, F, T]
5148            // Right: [F, T, F] (values for 3 true positions)
5149            let left = create_bool_array(vec![true, false, true, false, true]);
5150            let right = create_bool_array(vec![false, true, false]);
5151
5152            let result = pre_selection_scatter(&left, Some(&right)).unwrap();
5153            let result_arr = result.into_array(left.len()).unwrap();
5154
5155            let expected = create_bool_array(vec![false, false, true, false, false]);
5156            assert_eq!(&expected, result_arr.as_boolean());
5157        }
5158        // Test multiple consecutive true blocks
5159        {
5160            // Left: [F, T, T, F, T, T, T]
5161            // Right: [T, F, F, T, F]
5162            let left =
5163                create_bool_array(vec![false, true, true, false, true, true, true]);
5164            let right = create_bool_array(vec![true, false, false, true, false]);
5165
5166            let result = pre_selection_scatter(&left, Some(&right)).unwrap();
5167            let result_arr = result.into_array(left.len()).unwrap();
5168
5169            let expected =
5170                create_bool_array(vec![false, true, false, false, false, true, false]);
5171            assert_eq!(&expected, result_arr.as_boolean());
5172        }
5173        // Test single true at first position
5174        {
5175            // Left: [T, F, F]
5176            // Right: [F]
5177            let left = create_bool_array(vec![true, false, false]);
5178            let right = create_bool_array(vec![false]);
5179
5180            let result = pre_selection_scatter(&left, Some(&right)).unwrap();
5181            let result_arr = result.into_array(left.len()).unwrap();
5182
5183            let expected = create_bool_array(vec![false, false, false]);
5184            assert_eq!(&expected, result_arr.as_boolean());
5185        }
5186        // Test single true at last position
5187        {
5188            // Left: [F, F, T]
5189            // Right: [F]
5190            let left = create_bool_array(vec![false, false, true]);
5191            let right = create_bool_array(vec![false]);
5192
5193            let result = pre_selection_scatter(&left, Some(&right)).unwrap();
5194            let result_arr = result.into_array(left.len()).unwrap();
5195
5196            let expected = create_bool_array(vec![false, false, false]);
5197            assert_eq!(&expected, result_arr.as_boolean());
5198        }
5199        // Test nulls in right array
5200        {
5201            // Left: [F, T, F, T]
5202            // Right: [None, Some(false)] (with null at first position)
5203            let left = create_bool_array(vec![false, true, false, true]);
5204            let right = BooleanArray::from(vec![None, Some(false)]);
5205
5206            let result = pre_selection_scatter(&left, Some(&right)).unwrap();
5207            let result_arr = result.into_array(left.len()).unwrap();
5208
5209            let expected = BooleanArray::from(vec![
5210                Some(false),
5211                None, // null from right
5212                Some(false),
5213                Some(false),
5214            ]);
5215            assert_eq!(&expected, result_arr.as_boolean());
5216        }
5217    }
5218
5219    #[test]
5220    fn test_and_true_preselection_returns_lhs() {
5221        let schema =
5222            Arc::new(Schema::new(vec![Field::new("c", DataType::Boolean, false)]));
5223        let c_array = Arc::new(BooleanArray::from(vec![false, true, false, false, false]))
5224            as ArrayRef;
5225        let batch = RecordBatch::try_new(Arc::clone(&schema), vec![Arc::clone(&c_array)])
5226            .unwrap();
5227
5228        let expr = logical2physical(&logical_col("c").and(expr_lit(true)), &schema);
5229
5230        let result = expr.evaluate(&batch).unwrap();
5231        let ColumnarValue::Array(result_arr) = result else {
5232            panic!("Expected ColumnarValue::Array");
5233        };
5234
5235        let expected: Vec<_> = c_array.as_boolean().iter().collect();
5236        let actual: Vec<_> = result_arr.as_boolean().iter().collect();
5237        assert_eq!(
5238            expected, actual,
5239            "AND with TRUE must equal LHS even with PreSelection"
5240        );
5241    }
5242
5243    #[test]
5244    fn test_evaluate_bounds_int32() {
5245        let schema = Schema::new(vec![
5246            Field::new("a", DataType::Int32, false),
5247            Field::new("b", DataType::Int32, false),
5248        ]);
5249
5250        let a = Arc::new(Column::new("a", 0)) as _;
5251        let b = Arc::new(Column::new("b", 1)) as _;
5252
5253        // Test addition bounds
5254        let add_expr =
5255            binary_expr(Arc::clone(&a), Operator::Plus, Arc::clone(&b), &schema).unwrap();
5256        let add_bounds = add_expr
5257            .evaluate_bounds(&[
5258                &Interval::make(Some(1), Some(10)).unwrap(),
5259                &Interval::make(Some(5), Some(15)).unwrap(),
5260            ])
5261            .unwrap();
5262        assert_eq!(add_bounds, Interval::make(Some(6), Some(25)).unwrap());
5263
5264        // Test subtraction bounds
5265        let sub_expr =
5266            binary_expr(Arc::clone(&a), Operator::Minus, Arc::clone(&b), &schema)
5267                .unwrap();
5268        let sub_bounds = sub_expr
5269            .evaluate_bounds(&[
5270                &Interval::make(Some(1), Some(10)).unwrap(),
5271                &Interval::make(Some(5), Some(15)).unwrap(),
5272            ])
5273            .unwrap();
5274        assert_eq!(sub_bounds, Interval::make(Some(-14), Some(5)).unwrap());
5275
5276        // Test multiplication bounds
5277        let mul_expr =
5278            binary_expr(Arc::clone(&a), Operator::Multiply, Arc::clone(&b), &schema)
5279                .unwrap();
5280        let mul_bounds = mul_expr
5281            .evaluate_bounds(&[
5282                &Interval::make(Some(1), Some(10)).unwrap(),
5283                &Interval::make(Some(5), Some(15)).unwrap(),
5284            ])
5285            .unwrap();
5286        assert_eq!(mul_bounds, Interval::make(Some(5), Some(150)).unwrap());
5287
5288        // Test division bounds
5289        let div_expr =
5290            binary_expr(Arc::clone(&a), Operator::Divide, Arc::clone(&b), &schema)
5291                .unwrap();
5292        let div_bounds = div_expr
5293            .evaluate_bounds(&[
5294                &Interval::make(Some(10), Some(20)).unwrap(),
5295                &Interval::make(Some(2), Some(5)).unwrap(),
5296            ])
5297            .unwrap();
5298        assert_eq!(div_bounds, Interval::make(Some(2), Some(10)).unwrap());
5299    }
5300
5301    #[test]
5302    fn test_evaluate_bounds_bool() {
5303        let schema = Schema::new(vec![
5304            Field::new("a", DataType::Boolean, false),
5305            Field::new("b", DataType::Boolean, false),
5306        ]);
5307
5308        let a = Arc::new(Column::new("a", 0)) as _;
5309        let b = Arc::new(Column::new("b", 1)) as _;
5310
5311        // Test OR bounds
5312        let or_expr =
5313            binary_expr(Arc::clone(&a), Operator::Or, Arc::clone(&b), &schema).unwrap();
5314        let or_bounds = or_expr
5315            .evaluate_bounds(&[
5316                &Interval::make(Some(true), Some(true)).unwrap(),
5317                &Interval::make(Some(false), Some(false)).unwrap(),
5318            ])
5319            .unwrap();
5320        assert_eq!(or_bounds, Interval::make(Some(true), Some(true)).unwrap());
5321
5322        // Test AND bounds
5323        let and_expr =
5324            binary_expr(Arc::clone(&a), Operator::And, Arc::clone(&b), &schema).unwrap();
5325        let and_bounds = and_expr
5326            .evaluate_bounds(&[
5327                &Interval::make(Some(true), Some(true)).unwrap(),
5328                &Interval::make(Some(false), Some(false)).unwrap(),
5329            ])
5330            .unwrap();
5331        assert_eq!(
5332            and_bounds,
5333            Interval::make(Some(false), Some(false)).unwrap()
5334        );
5335    }
5336
5337    #[test]
5338    fn test_evaluate_nested_type() {
5339        let batch_schema = Arc::new(Schema::new(vec![
5340            Field::new(
5341                "a",
5342                DataType::List(Arc::new(Field::new_list_field(DataType::Int32, true))),
5343                true,
5344            ),
5345            Field::new(
5346                "b",
5347                DataType::List(Arc::new(Field::new_list_field(DataType::Int32, true))),
5348                true,
5349            ),
5350        ]));
5351
5352        let mut list_builder_a = ListBuilder::new(Int32Builder::new());
5353
5354        list_builder_a.append_value([Some(1)]);
5355        list_builder_a.append_value([Some(2)]);
5356        list_builder_a.append_value([]);
5357        list_builder_a.append_value([None]);
5358
5359        let list_array_a: ArrayRef = Arc::new(list_builder_a.finish());
5360
5361        let mut list_builder_b = ListBuilder::new(Int32Builder::new());
5362
5363        list_builder_b.append_value([Some(1)]);
5364        list_builder_b.append_value([Some(2)]);
5365        list_builder_b.append_value([]);
5366        list_builder_b.append_value([None]);
5367
5368        let list_array_b: ArrayRef = Arc::new(list_builder_b.finish());
5369
5370        let batch =
5371            RecordBatch::try_new(batch_schema, vec![list_array_a, list_array_b]).unwrap();
5372
5373        let schema = Arc::new(Schema::new(vec![
5374            Field::new(
5375                "a",
5376                DataType::List(Arc::new(Field::new("foo", DataType::Int32, true))),
5377                true,
5378            ),
5379            Field::new(
5380                "b",
5381                DataType::List(Arc::new(Field::new("bar", DataType::Int32, true))),
5382                true,
5383            ),
5384        ]));
5385
5386        let a = Arc::new(Column::new("a", 0)) as _;
5387        let b = Arc::new(Column::new("b", 1)) as _;
5388
5389        let eq_expr =
5390            binary_expr(Arc::clone(&a), Operator::Eq, Arc::clone(&b), &schema).unwrap();
5391
5392        let eq_result = eq_expr.evaluate(&batch).unwrap();
5393        let expected =
5394            BooleanArray::from_iter(vec![Some(true), Some(true), Some(true), Some(true)]);
5395        assert_eq!(eq_result.into_array(4).unwrap().as_boolean(), &expected);
5396    }
5397}