datafusion_functions/core/
least.rs1use crate::core::greatest_least_utils::GreatestLeastOperator;
19use arrow::array::{Array, BooleanArray, make_comparator};
20use arrow::buffer::BooleanBuffer;
21use arrow::compute::SortOptions;
22use arrow::compute::kernels::cmp;
23use arrow::datatypes::DataType;
24use datafusion_common::{Result, ScalarValue, assert_eq_or_internal_err};
25use datafusion_doc::Documentation;
26use datafusion_expr::{ColumnarValue, ScalarFunctionArgs};
27use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
28use datafusion_macros::user_doc;
29
30const SORT_OPTIONS: SortOptions = SortOptions {
31 descending: false,
33
34 nulls_first: false,
36};
37
38#[user_doc(
39 doc_section(label = "Conditional Functions"),
40 description = "Returns the smallest value in a list of expressions. Returns _null_ if all expressions are _null_.",
41 syntax_example = "least(expression1[, ..., expression_n])",
42 sql_example = r#"```sql
43> select least(4, 7, 5);
44+---------------------------+
45| least(4,7,5) |
46+---------------------------+
47| 4 |
48+---------------------------+
49```"#,
50 argument(
51 name = "expression1, expression_n",
52 description = "Expressions to compare and return the smallest value. Can be a constant, column, or function, and any combination of arithmetic operators. Pass as many expression arguments as necessary."
53 )
54)]
55#[derive(Debug, PartialEq, Eq, Hash)]
56pub struct LeastFunc {
57 signature: Signature,
58}
59
60impl Default for LeastFunc {
61 fn default() -> Self {
62 LeastFunc::new()
63 }
64}
65
66impl LeastFunc {
67 pub fn new() -> Self {
68 Self {
69 signature: Signature::user_defined(Volatility::Immutable),
70 }
71 }
72}
73
74impl GreatestLeastOperator for LeastFunc {
75 const NAME: &'static str = "least";
76
77 fn keep_scalar<'a>(
78 lhs: &'a ScalarValue,
79 rhs: &'a ScalarValue,
80 ) -> Result<&'a ScalarValue> {
81 if lhs.is_null() {
85 return Ok(rhs);
86 }
87
88 if rhs.is_null() {
89 return Ok(lhs);
90 }
91
92 if !lhs.data_type().is_nested() {
93 return if lhs <= rhs { Ok(lhs) } else { Ok(rhs) };
94 }
95
96 let cmp = make_comparator(
100 lhs.to_array()?.as_ref(),
101 rhs.to_array()?.as_ref(),
102 SORT_OPTIONS,
103 )?;
104
105 if cmp(0, 0).is_le() { Ok(lhs) } else { Ok(rhs) }
106 }
107
108 fn get_indexes_to_keep(lhs: &dyn Array, rhs: &dyn Array) -> Result<BooleanArray> {
111 if !lhs.data_type().is_nested()
116 && lhs.logical_null_count() == 0
117 && rhs.logical_null_count() == 0
118 {
119 return cmp::lt_eq(&lhs, &rhs).map_err(|e| e.into());
120 }
121
122 let cmp = make_comparator(lhs, rhs, SORT_OPTIONS)?;
123
124 assert_eq_or_internal_err!(
125 lhs.len(),
126 rhs.len(),
127 "All arrays should have the same length for least comparison"
128 );
129
130 let values = BooleanBuffer::collect_bool(lhs.len(), |i| cmp(i, i).is_le());
131
132 Ok(BooleanArray::new(values, None))
134 }
135}
136
137impl ScalarUDFImpl for LeastFunc {
138 fn name(&self) -> &str {
139 "least"
140 }
141
142 fn signature(&self) -> &Signature {
143 &self.signature
144 }
145
146 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
147 Ok(arg_types[0].clone())
148 }
149
150 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
151 super::greatest_least_utils::execute_conditional::<Self>(&args.args)
152 }
153
154 fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
155 let coerced_type =
156 super::greatest_least_utils::find_coerced_type::<Self>(arg_types)?;
157
158 Ok(vec![coerced_type; arg_types.len()])
159 }
160
161 fn documentation(&self) -> Option<&Documentation> {
162 self.doc()
163 }
164}
165
#[cfg(test)]
mod test {
    use super::LeastFunc;
    use arrow::datatypes::DataType;
    use datafusion_expr::ScalarUDFImpl;

    /// Two decimals with the same precision but different scales must both be
    /// coerced to a decimal wide enough to hold either input.
    #[test]
    fn test_least_return_types_without_common_supertype_in_arg_type() {
        let input_types = [DataType::Decimal128(10, 3), DataType::Decimal128(10, 4)];
        let expected = vec![DataType::Decimal128(11, 4); input_types.len()];

        let coerced_types = LeastFunc::new().coerce_types(&input_types).unwrap();

        assert_eq!(coerced_types, expected);
    }
}