datafusion_functions/core/
least.rs1use crate::core::greatest_least_utils::GreatestLeastOperator;
19use arrow::array::{Array, BooleanArray, make_comparator};
20use arrow::buffer::BooleanBuffer;
21use arrow::compute::SortOptions;
22use arrow::compute::kernels::cmp;
23use arrow::datatypes::DataType;
24use datafusion_common::{Result, ScalarValue, assert_eq_or_internal_err};
25use datafusion_doc::Documentation;
26use datafusion_expr::{ColumnarValue, ScalarFunctionArgs};
27use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
28use datafusion_macros::user_doc;
29use std::any::Any;
30
31const SORT_OPTIONS: SortOptions = SortOptions {
32 descending: false,
34
35 nulls_first: false,
37};
38
39#[user_doc(
40 doc_section(label = "Conditional Functions"),
41 description = "Returns the smallest value in a list of expressions. Returns _null_ if all expressions are _null_.",
42 syntax_example = "least(expression1[, ..., expression_n])",
43 sql_example = r#"```sql
44> select least(4, 7, 5);
45+---------------------------+
46| least(4,7,5) |
47+---------------------------+
48| 4 |
49+---------------------------+
50```"#,
51 argument(
52 name = "expression1, expression_n",
53 description = "Expressions to compare and return the smallest value. Can be a constant, column, or function, and any combination of arithmetic operators. Pass as many expression arguments as necessary."
54 )
55)]
56#[derive(Debug, PartialEq, Eq, Hash)]
57pub struct LeastFunc {
58 signature: Signature,
59}
60
61impl Default for LeastFunc {
62 fn default() -> Self {
63 LeastFunc::new()
64 }
65}
66
67impl LeastFunc {
68 pub fn new() -> Self {
69 Self {
70 signature: Signature::user_defined(Volatility::Immutable),
71 }
72 }
73}
74
75impl GreatestLeastOperator for LeastFunc {
76 const NAME: &'static str = "least";
77
78 fn keep_scalar<'a>(
79 lhs: &'a ScalarValue,
80 rhs: &'a ScalarValue,
81 ) -> Result<&'a ScalarValue> {
82 if lhs.is_null() {
86 return Ok(rhs);
87 }
88
89 if rhs.is_null() {
90 return Ok(lhs);
91 }
92
93 if !lhs.data_type().is_nested() {
94 return if lhs <= rhs { Ok(lhs) } else { Ok(rhs) };
95 }
96
97 let cmp = make_comparator(
101 lhs.to_array()?.as_ref(),
102 rhs.to_array()?.as_ref(),
103 SORT_OPTIONS,
104 )?;
105
106 if cmp(0, 0).is_le() { Ok(lhs) } else { Ok(rhs) }
107 }
108
109 fn get_indexes_to_keep(lhs: &dyn Array, rhs: &dyn Array) -> Result<BooleanArray> {
112 if !lhs.data_type().is_nested()
117 && lhs.logical_null_count() == 0
118 && rhs.logical_null_count() == 0
119 {
120 return cmp::lt_eq(&lhs, &rhs).map_err(|e| e.into());
121 }
122
123 let cmp = make_comparator(lhs, rhs, SORT_OPTIONS)?;
124
125 assert_eq_or_internal_err!(
126 lhs.len(),
127 rhs.len(),
128 "All arrays should have the same length for least comparison"
129 );
130
131 let values = BooleanBuffer::collect_bool(lhs.len(), |i| cmp(i, i).is_le());
132
133 Ok(BooleanArray::new(values, None))
135 }
136}
137
138impl ScalarUDFImpl for LeastFunc {
139 fn as_any(&self) -> &dyn Any {
140 self
141 }
142
143 fn name(&self) -> &str {
144 "least"
145 }
146
147 fn signature(&self) -> &Signature {
148 &self.signature
149 }
150
151 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
152 Ok(arg_types[0].clone())
153 }
154
155 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
156 super::greatest_least_utils::execute_conditional::<Self>(&args.args)
157 }
158
159 fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
160 let coerced_type =
161 super::greatest_least_utils::find_coerced_type::<Self>(arg_types)?;
162
163 Ok(vec![coerced_type; arg_types.len()])
164 }
165
166 fn documentation(&self) -> Option<&Documentation> {
167 self.doc()
168 }
169}
170
#[cfg(test)]
mod test {
    use crate::core::least::LeastFunc;
    use arrow::datatypes::DataType;
    use datafusion_expr::ScalarUDFImpl;

    /// Decimal128(10, 3) and Decimal128(10, 4) are not each other's supertype.
    /// Coercion must therefore widen both to Decimal128(11, 4), which keeps the
    /// larger scale and enough integer digits for either input.
    #[test]
    fn test_least_return_types_without_common_supertype_in_arg_type() {
        let inputs = [DataType::Decimal128(10, 3), DataType::Decimal128(10, 4)];
        let coerced = LeastFunc::new().coerce_types(&inputs).unwrap();

        let expected = vec![DataType::Decimal128(11, 4); 2];
        assert_eq!(coerced, expected);
    }
}