Skip to main content

datafusion_functions/core/
greatest.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
use crate::core::greatest_least_utils::GreatestLeastOperator;
use arrow::array::{Array, BooleanArray, make_comparator};
use arrow::buffer::BooleanBuffer;
use arrow::compute::SortOptions;
use arrow::compute::kernels::cmp;
use arrow::datatypes::DataType;
use datafusion_common::{Result, ScalarValue, assert_eq_or_internal_err, plan_err};
use datafusion_doc::Documentation;
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs};
use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
use datafusion_macros::user_doc;
29
30const SORT_OPTIONS: SortOptions = SortOptions {
31    // We want greatest first
32    descending: false,
33
34    // NULL will be less than any other value
35    nulls_first: true,
36};
37
38#[user_doc(
39    doc_section(label = "Conditional Functions"),
40    description = "Returns the greatest value in a list of expressions. Returns _null_ if all expressions are _null_.",
41    syntax_example = "greatest(expression1[, ..., expression_n])",
42    sql_example = r#"```sql
43> select greatest(4, 7, 5);
44+---------------------------+
45| greatest(4,7,5)           |
46+---------------------------+
47| 7                         |
48+---------------------------+
49```"#,
50    argument(
51        name = "expression1, expression_n",
52        description = "Expressions to compare and return the greatest value.. Can be a constant, column, or function, and any combination of arithmetic operators. Pass as many expression arguments as necessary."
53    )
54)]
55#[derive(Debug, PartialEq, Eq, Hash)]
56pub struct GreatestFunc {
57    signature: Signature,
58}
59
60impl Default for GreatestFunc {
61    fn default() -> Self {
62        GreatestFunc::new()
63    }
64}
65
66impl GreatestFunc {
67    pub fn new() -> Self {
68        Self {
69            signature: Signature::user_defined(Volatility::Immutable),
70        }
71    }
72}
73
74impl GreatestLeastOperator for GreatestFunc {
75    const NAME: &'static str = "greatest";
76
77    fn keep_scalar<'a>(
78        lhs: &'a ScalarValue,
79        rhs: &'a ScalarValue,
80    ) -> Result<&'a ScalarValue> {
81        if !lhs.data_type().is_nested() {
82            return if lhs >= rhs { Ok(lhs) } else { Ok(rhs) };
83        }
84
85        // If complex type we can't compare directly as we want null values to be smaller
86        let cmp = make_comparator(
87            lhs.to_array()?.as_ref(),
88            rhs.to_array()?.as_ref(),
89            SORT_OPTIONS,
90        )?;
91
92        if cmp(0, 0).is_ge() { Ok(lhs) } else { Ok(rhs) }
93    }
94
95    /// Return boolean array where `arr[i] = lhs[i] >= rhs[i]` for all i, where `arr` is the result array
96    /// Nulls are always considered smaller than any other value
97    fn get_indexes_to_keep(lhs: &dyn Array, rhs: &dyn Array) -> Result<BooleanArray> {
98        // Fast path:
99        // If both arrays are not nested, have the same length and no nulls, we can use the faster vectorized kernel
100        // - If both arrays are not nested: Nested types, such as lists, are not supported as the null semantics are not well-defined.
101        // - both array does not have any nulls: cmp::gt_eq will return null if any of the input is null while we want to return false in that case
102        if !lhs.data_type().is_nested()
103            && lhs.logical_null_count() == 0
104            && rhs.logical_null_count() == 0
105        {
106            return cmp::gt_eq(&lhs, &rhs).map_err(|e| e.into());
107        }
108
109        let cmp = make_comparator(lhs, rhs, SORT_OPTIONS)?;
110
111        assert_eq_or_internal_err!(
112            lhs.len(),
113            rhs.len(),
114            "All arrays should have the same length for greatest comparison"
115        );
116
117        let values = BooleanBuffer::collect_bool(lhs.len(), |i| cmp(i, i).is_ge());
118
119        // No nulls as we only want to keep the values that are larger, its either true or false
120        Ok(BooleanArray::new(values, None))
121    }
122}
123
124impl ScalarUDFImpl for GreatestFunc {
125    fn name(&self) -> &str {
126        "greatest"
127    }
128
129    fn signature(&self) -> &Signature {
130        &self.signature
131    }
132
133    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
134        Ok(arg_types[0].clone())
135    }
136
137    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
138        super::greatest_least_utils::execute_conditional::<Self>(&args.args)
139    }
140
141    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
142        let coerced_type =
143            super::greatest_least_utils::find_coerced_type::<Self>(arg_types)?;
144
145        Ok(vec![coerced_type; arg_types.len()])
146    }
147
148    fn documentation(&self) -> Option<&Documentation> {
149        self.doc()
150    }
151}
152
#[cfg(test)]
mod test {
    use super::GreatestFunc;
    use arrow::datatypes::DataType;
    use datafusion_expr::ScalarUDFImpl;

    #[test]
    fn test_greatest_return_types_without_common_supertype_in_arg_type() {
        // Neither decimal type can hold the other, so both must be widened to
        // one wide enough for 7 integer digits and 4 fractional digits.
        let inputs = [DataType::Decimal128(10, 3), DataType::Decimal128(10, 4)];

        let coerced = GreatestFunc::new()
            .coerce_types(&inputs)
            .expect("decimal arguments should be coercible");

        assert_eq!(coerced, vec![DataType::Decimal128(11, 4); 2]);
    }
}