datafusion_physical_expr_common/
datum.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use arrow::array::BooleanArray;
19use arrow::array::{ArrayRef, Datum, make_comparator};
20use arrow::buffer::NullBuffer;
21use arrow::compute::kernels::cmp::{
22    distinct, eq, gt, gt_eq, lt, lt_eq, neq, not_distinct,
23};
24use arrow::compute::{SortOptions, ilike, like, nilike, nlike};
25use arrow::error::ArrowError;
26use datafusion_common::{Result, ScalarValue};
27use datafusion_common::{arrow_datafusion_err, assert_or_internal_err, internal_err};
28use datafusion_expr_common::columnar_value::ColumnarValue;
29use datafusion_expr_common::operator::Operator;
30use std::sync::Arc;
31
32/// Applies a binary [`Datum`] kernel `f` to `lhs` and `rhs`
33///
34/// This maps arrow-rs' [`Datum`] kernels to DataFusion's [`ColumnarValue`] abstraction
35pub fn apply(
36    lhs: &ColumnarValue,
37    rhs: &ColumnarValue,
38    f: impl Fn(&dyn Datum, &dyn Datum) -> Result<ArrayRef, ArrowError>,
39) -> Result<ColumnarValue> {
40    match (&lhs, &rhs) {
41        (ColumnarValue::Array(left), ColumnarValue::Array(right)) => {
42            Ok(ColumnarValue::Array(f(&left.as_ref(), &right.as_ref())?))
43        }
44        (ColumnarValue::Scalar(left), ColumnarValue::Array(right)) => Ok(
45            ColumnarValue::Array(f(&left.to_scalar()?, &right.as_ref())?),
46        ),
47        (ColumnarValue::Array(left), ColumnarValue::Scalar(right)) => Ok(
48            ColumnarValue::Array(f(&left.as_ref(), &right.to_scalar()?)?),
49        ),
50        (ColumnarValue::Scalar(left), ColumnarValue::Scalar(right)) => {
51            let array = f(&left.to_scalar()?, &right.to_scalar()?)?;
52            let scalar = ScalarValue::try_from_array(array.as_ref(), 0)?;
53            Ok(ColumnarValue::Scalar(scalar))
54        }
55    }
56}
57
58/// Applies a binary [`Datum`] comparison operator `op` to `lhs` and `rhs`
59pub fn apply_cmp(
60    op: Operator,
61    lhs: &ColumnarValue,
62    rhs: &ColumnarValue,
63) -> Result<ColumnarValue> {
64    if lhs.data_type().is_nested() {
65        apply_cmp_for_nested(op, lhs, rhs)
66    } else {
67        let f = match op {
68            Operator::Eq => eq,
69            Operator::NotEq => neq,
70            Operator::Lt => lt,
71            Operator::LtEq => lt_eq,
72            Operator::Gt => gt,
73            Operator::GtEq => gt_eq,
74            Operator::IsDistinctFrom => distinct,
75            Operator::IsNotDistinctFrom => not_distinct,
76
77            Operator::LikeMatch => like,
78            Operator::ILikeMatch => ilike,
79            Operator::NotLikeMatch => nlike,
80            Operator::NotILikeMatch => nilike,
81
82            _ => {
83                return internal_err!("Invalid compare operator: {}", op);
84            }
85        };
86
87        apply(lhs, rhs, |l, r| Ok(Arc::new(f(l, r)?)))
88    }
89}
90
91/// Applies a binary [`Datum`] comparison operator `op` to `lhs` and `rhs` for nested type like
92/// List, FixedSizeList, LargeList, Struct, Union, Map, or a dictionary of a nested type
93pub fn apply_cmp_for_nested(
94    op: Operator,
95    lhs: &ColumnarValue,
96    rhs: &ColumnarValue,
97) -> Result<ColumnarValue> {
98    let left_data_type = lhs.data_type();
99    let right_data_type = rhs.data_type();
100
101    assert_or_internal_err!(
102        matches!(
103            op,
104            Operator::Eq
105                | Operator::NotEq
106                | Operator::Lt
107                | Operator::Gt
108                | Operator::LtEq
109                | Operator::GtEq
110                | Operator::IsDistinctFrom
111                | Operator::IsNotDistinctFrom
112        ) && left_data_type.equals_datatype(&right_data_type),
113        "invalid operator or data type mismatch for nested data, op {op} left {left_data_type}, right {right_data_type}",
114    );
115
116    apply(lhs, rhs, |l, r| {
117        Ok(Arc::new(compare_op_for_nested(op, l, r)?))
118    })
119}
120
121/// Compare with eq with either nested or non-nested
122pub fn compare_with_eq(
123    lhs: &dyn Datum,
124    rhs: &dyn Datum,
125    is_nested: bool,
126) -> Result<BooleanArray> {
127    if is_nested {
128        compare_op_for_nested(Operator::Eq, lhs, rhs)
129    } else {
130        eq(lhs, rhs).map_err(|e| arrow_datafusion_err!(e))
131    }
132}
133
134/// Compare on nested type List, Struct, and so on
135pub fn compare_op_for_nested(
136    op: Operator,
137    lhs: &dyn Datum,
138    rhs: &dyn Datum,
139) -> Result<BooleanArray> {
140    let (l, is_l_scalar) = lhs.get();
141    let (r, is_r_scalar) = rhs.get();
142    let l_len = l.len();
143    let r_len = r.len();
144
145    assert_or_internal_err!(l_len == r_len || is_l_scalar || is_r_scalar, "len mismatch");
146
147    let len = match is_l_scalar {
148        true => r_len,
149        false => l_len,
150    };
151
152    // fast path, if compare with one null and operator is not 'distinct', then we can return null array directly
153    if !matches!(op, Operator::IsDistinctFrom | Operator::IsNotDistinctFrom)
154        && (is_l_scalar && l.null_count() == 1 || is_r_scalar && r.null_count() == 1)
155    {
156        return Ok(BooleanArray::new_null(len));
157    }
158
159    // TODO: make SortOptions configurable
160    // we choose the default behaviour from arrow-rs which has null-first that follow spark's behaviour
161    let cmp = make_comparator(l, r, SortOptions::default())?;
162
163    let cmp_with_op = |i, j| match op {
164        Operator::Eq | Operator::IsNotDistinctFrom => cmp(i, j).is_eq(),
165        Operator::Lt => cmp(i, j).is_lt(),
166        Operator::Gt => cmp(i, j).is_gt(),
167        Operator::LtEq => !cmp(i, j).is_gt(),
168        Operator::GtEq => !cmp(i, j).is_lt(),
169        Operator::NotEq | Operator::IsDistinctFrom => !cmp(i, j).is_eq(),
170        _ => unreachable!("unexpected operator found"),
171    };
172
173    let values = match (is_l_scalar, is_r_scalar) {
174        (false, false) => (0..len).map(|i| cmp_with_op(i, i)).collect(),
175        (true, false) => (0..len).map(|i| cmp_with_op(0, i)).collect(),
176        (false, true) => (0..len).map(|i| cmp_with_op(i, 0)).collect(),
177        (true, true) => std::iter::once(cmp_with_op(0, 0)).collect(),
178    };
179
180    // Distinct understand how to compare with NULL
181    // i.e NULL is distinct from NULL -> false
182    if matches!(op, Operator::IsDistinctFrom | Operator::IsNotDistinctFrom) {
183        Ok(BooleanArray::new(values, None))
184    } else {
185        // If one of the side is NULL, we return NULL
186        // i.e. NULL eq NULL -> NULL
187        // For nested comparisons, we need to ensure the null buffer matches the result length
188        let nulls = match (is_l_scalar, is_r_scalar) {
189            (false, false) | (true, true) => NullBuffer::union(l.nulls(), r.nulls()),
190            (true, false) => {
191                // When left is null-scalar and right is array, expand left nulls to match result length
192                match l.nulls().filter(|nulls| !nulls.is_valid(0)) {
193                    Some(_) => Some(NullBuffer::new_null(len)), // Left scalar is null
194                    None => r.nulls().cloned(),                 // Left scalar is non-null
195                }
196            }
197            (false, true) => {
198                // When right is null-scalar and left is array, expand right nulls to match result length
199                match r.nulls().filter(|nulls| !nulls.is_valid(0)) {
200                    Some(_) => Some(NullBuffer::new_null(len)), // Right scalar is null
201                    None => l.nulls().cloned(), // Right scalar is non-null
202                }
203            }
204        };
205        Ok(BooleanArray::new(values, nulls))
206    }
207}