// datafusion_functions/core/greatest_least_utils.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

18use arrow::array::{Array, ArrayRef, BooleanArray};
19use arrow::compute::kernels::zip::zip;
20use arrow::datatypes::DataType;
21use datafusion_common::{internal_err, plan_err, Result, ScalarValue};
22use datafusion_expr_common::columnar_value::ColumnarValue;
23use datafusion_expr_common::type_coercion::binary::type_union_resolution;
24use std::sync::Arc;
25
26pub(super) trait GreatestLeastOperator {
27    const NAME: &'static str;
28
29    fn keep_scalar<'a>(
30        lhs: &'a ScalarValue,
31        rhs: &'a ScalarValue,
32    ) -> Result<&'a ScalarValue>;
33
34    /// Return array with true for values that we should keep from the lhs array
35    fn get_indexes_to_keep(lhs: &dyn Array, rhs: &dyn Array) -> Result<BooleanArray>;
36}
37
38fn keep_array<Op: GreatestLeastOperator>(
39    lhs: ArrayRef,
40    rhs: ArrayRef,
41) -> Result<ArrayRef> {
42    // True for values that we should keep from the left array
43    let keep_lhs = Op::get_indexes_to_keep(lhs.as_ref(), rhs.as_ref())?;
44
45    let result = zip(&keep_lhs, &lhs, &rhs)?;
46
47    Ok(result)
48}
49
50pub(super) fn execute_conditional<Op: GreatestLeastOperator>(
51    args: &[ColumnarValue],
52) -> Result<ColumnarValue> {
53    if args.is_empty() {
54        return internal_err!(
55            "{} was called with no arguments. It requires at least 1.",
56            Op::NAME
57        );
58    }
59
60    // Some engines (e.g. SQL Server) allow greatest/least with single arg, it's a noop
61    if args.len() == 1 {
62        return Ok(args[0].clone());
63    }
64
65    // Split to scalars and arrays for later optimization
66    let (scalars, arrays): (Vec<_>, Vec<_>) = args.iter().partition(|x| match x {
67        ColumnarValue::Scalar(_) => true,
68        ColumnarValue::Array(_) => false,
69    });
70
71    let mut arrays_iter = arrays.iter().map(|x| match x {
72        ColumnarValue::Array(a) => a,
73        _ => unreachable!(),
74    });
75
76    let first_array = arrays_iter.next();
77
78    let mut result: ArrayRef;
79
80    // Optimization: merge all scalars into one to avoid recomputing (constant folding)
81    if !scalars.is_empty() {
82        let mut scalars_iter = scalars.iter().map(|x| match x {
83            ColumnarValue::Scalar(s) => s,
84            _ => unreachable!(),
85        });
86
87        // We have at least one scalar
88        let mut result_scalar = scalars_iter.next().unwrap();
89
90        for scalar in scalars_iter {
91            result_scalar = Op::keep_scalar(result_scalar, scalar)?;
92        }
93
94        // If we only have scalars, return the one that we should keep (largest/least)
95        if arrays.is_empty() {
96            return Ok(ColumnarValue::Scalar(result_scalar.clone()));
97        }
98
99        // We have at least one array
100        let first_array = first_array.unwrap();
101
102        // Start with the result value
103        result = keep_array::<Op>(
104            Arc::clone(first_array),
105            result_scalar.to_array_of_size(first_array.len())?,
106        )?;
107    } else {
108        // If we only have arrays, start with the first array
109        // (We must have at least one array)
110        result = Arc::clone(first_array.unwrap());
111    }
112
113    for array in arrays_iter {
114        result = keep_array::<Op>(Arc::clone(array), result)?;
115    }
116
117    Ok(ColumnarValue::Array(result))
118}
119
120pub(super) fn find_coerced_type<Op: GreatestLeastOperator>(
121    data_types: &[DataType],
122) -> Result<DataType> {
123    if data_types.is_empty() {
124        plan_err!(
125            "{} was called without any arguments. It requires at least 1.",
126            Op::NAME
127        )
128    } else if let Some(coerced_type) = type_union_resolution(data_types) {
129        Ok(coerced_type)
130    } else {
131        plan_err!("Cannot find a common type for arguments")
132    }
133}