// datafusion_physical_expr_common/datum.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
17
18use arrow::array::BooleanArray;
19use arrow::array::{make_comparator, ArrayRef, Datum};
20use arrow::buffer::NullBuffer;
21use arrow::compute::SortOptions;
22use arrow::error::ArrowError;
23use datafusion_common::DataFusionError;
24use datafusion_common::{arrow_datafusion_err, internal_err};
25use datafusion_common::{Result, ScalarValue};
26use datafusion_expr_common::columnar_value::ColumnarValue;
27use datafusion_expr_common::operator::Operator;
28use std::sync::Arc;
29
30/// Applies a binary [`Datum`] kernel `f` to `lhs` and `rhs`
31///
32/// This maps arrow-rs' [`Datum`] kernels to DataFusion's [`ColumnarValue`] abstraction
33pub fn apply(
34    lhs: &ColumnarValue,
35    rhs: &ColumnarValue,
36    f: impl Fn(&dyn Datum, &dyn Datum) -> Result<ArrayRef, ArrowError>,
37) -> Result<ColumnarValue> {
38    match (&lhs, &rhs) {
39        (ColumnarValue::Array(left), ColumnarValue::Array(right)) => {
40            Ok(ColumnarValue::Array(f(&left.as_ref(), &right.as_ref())?))
41        }
42        (ColumnarValue::Scalar(left), ColumnarValue::Array(right)) => Ok(
43            ColumnarValue::Array(f(&left.to_scalar()?, &right.as_ref())?),
44        ),
45        (ColumnarValue::Array(left), ColumnarValue::Scalar(right)) => Ok(
46            ColumnarValue::Array(f(&left.as_ref(), &right.to_scalar()?)?),
47        ),
48        (ColumnarValue::Scalar(left), ColumnarValue::Scalar(right)) => {
49            let array = f(&left.to_scalar()?, &right.to_scalar()?)?;
50            let scalar = ScalarValue::try_from_array(array.as_ref(), 0)?;
51            Ok(ColumnarValue::Scalar(scalar))
52        }
53    }
54}
55
56/// Applies a binary [`Datum`] comparison kernel `f` to `lhs` and `rhs`
57pub fn apply_cmp(
58    lhs: &ColumnarValue,
59    rhs: &ColumnarValue,
60    f: impl Fn(&dyn Datum, &dyn Datum) -> Result<BooleanArray, ArrowError>,
61) -> Result<ColumnarValue> {
62    apply(lhs, rhs, |l, r| Ok(Arc::new(f(l, r)?)))
63}
64
65/// Applies a binary [`Datum`] comparison kernel `f` to `lhs` and `rhs` for nested type like
66/// List, FixedSizeList, LargeList, Struct, Union, Map, or a dictionary of a nested type
67pub fn apply_cmp_for_nested(
68    op: Operator,
69    lhs: &ColumnarValue,
70    rhs: &ColumnarValue,
71) -> Result<ColumnarValue> {
72    if matches!(
73        op,
74        Operator::Eq
75            | Operator::NotEq
76            | Operator::Lt
77            | Operator::Gt
78            | Operator::LtEq
79            | Operator::GtEq
80            | Operator::IsDistinctFrom
81            | Operator::IsNotDistinctFrom
82    ) {
83        apply(lhs, rhs, |l, r| {
84            Ok(Arc::new(compare_op_for_nested(op, l, r)?))
85        })
86    } else {
87        internal_err!("invalid operator for nested")
88    }
89}
90
91/// Compare with eq with either nested or non-nested
92pub fn compare_with_eq(
93    lhs: &dyn Datum,
94    rhs: &dyn Datum,
95    is_nested: bool,
96) -> Result<BooleanArray> {
97    if is_nested {
98        compare_op_for_nested(Operator::Eq, lhs, rhs)
99    } else {
100        arrow::compute::kernels::cmp::eq(lhs, rhs).map_err(|e| arrow_datafusion_err!(e))
101    }
102}
103
104/// Compare on nested type List, Struct, and so on
105pub fn compare_op_for_nested(
106    op: Operator,
107    lhs: &dyn Datum,
108    rhs: &dyn Datum,
109) -> Result<BooleanArray> {
110    let (l, is_l_scalar) = lhs.get();
111    let (r, is_r_scalar) = rhs.get();
112    let l_len = l.len();
113    let r_len = r.len();
114
115    if l_len != r_len && !is_l_scalar && !is_r_scalar {
116        return internal_err!("len mismatch");
117    }
118
119    let len = match is_l_scalar {
120        true => r_len,
121        false => l_len,
122    };
123
124    // fast path, if compare with one null and operator is not 'distinct', then we can return null array directly
125    if !matches!(op, Operator::IsDistinctFrom | Operator::IsNotDistinctFrom)
126        && (is_l_scalar && l.null_count() == 1 || is_r_scalar && r.null_count() == 1)
127    {
128        return Ok(BooleanArray::new_null(len));
129    }
130
131    // TODO: make SortOptions configurable
132    // we choose the default behaviour from arrow-rs which has null-first that follow spark's behaviour
133    let cmp = make_comparator(l, r, SortOptions::default())?;
134
135    let cmp_with_op = |i, j| match op {
136        Operator::Eq | Operator::IsNotDistinctFrom => cmp(i, j).is_eq(),
137        Operator::Lt => cmp(i, j).is_lt(),
138        Operator::Gt => cmp(i, j).is_gt(),
139        Operator::LtEq => !cmp(i, j).is_gt(),
140        Operator::GtEq => !cmp(i, j).is_lt(),
141        Operator::NotEq | Operator::IsDistinctFrom => !cmp(i, j).is_eq(),
142        _ => unreachable!("unexpected operator found"),
143    };
144
145    let values = match (is_l_scalar, is_r_scalar) {
146        (false, false) => (0..len).map(|i| cmp_with_op(i, i)).collect(),
147        (true, false) => (0..len).map(|i| cmp_with_op(0, i)).collect(),
148        (false, true) => (0..len).map(|i| cmp_with_op(i, 0)).collect(),
149        (true, true) => std::iter::once(cmp_with_op(0, 0)).collect(),
150    };
151
152    // Distinct understand how to compare with NULL
153    // i.e NULL is distinct from NULL -> false
154    if matches!(op, Operator::IsDistinctFrom | Operator::IsNotDistinctFrom) {
155        Ok(BooleanArray::new(values, None))
156    } else {
157        // If one of the side is NULL, we returns NULL
158        // i.e. NULL eq NULL -> NULL
159        let nulls = NullBuffer::union(l.nulls(), r.nulls());
160        Ok(BooleanArray::new(values, nulls))
161    }
162}