datafusion_functions_nested/
sort.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`ScalarUDFImpl`] definitions for array_sort function.
19
20use crate::utils::make_scalar_function;
21use arrow::array::{new_null_array, Array, ArrayRef, ListArray, NullBufferBuilder};
22use arrow::buffer::OffsetBuffer;
23use arrow::compute::SortColumn;
24use arrow::datatypes::{DataType, Field};
25use arrow::{compute, compute::SortOptions};
26use datafusion_common::cast::{as_list_array, as_string_array};
27use datafusion_common::utils::ListCoercion;
28use datafusion_common::{exec_err, plan_err, Result};
29use datafusion_expr::{
30    ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, Documentation,
31    ScalarUDFImpl, Signature, TypeSignature, Volatility,
32};
33use datafusion_macros::user_doc;
34use std::any::Any;
35use std::sync::Arc;
36
// Declares the `array_sort` function for the `ArraySort` UDF via the crate's
// `make_udf_expr_and_func!` macro. The macro is defined elsewhere; presumably
// it generates an `array_sort(array, desc, null_first)` expression helper and
// an `array_sort_udf` accessor -- confirm against the macro definition.
// NOTE(review): the third argument is spelled `null_first` here, while the
// user-facing documentation on `ArraySort` calls it `nulls_first`.
make_udf_expr_and_func!(
    ArraySort,
    array_sort,
    array desc null_first,
    "returns sorted array.",
    array_sort_udf
);
44
/// Implementation of the `array_sort` function (also available under the
/// alias `list_sort`).
///
/// `array_sort` sorts the elements of each array. Two optional string
/// arguments choose the direction (`ASC` / `DESC`) and the placement of null
/// elements (`NULLS FIRST` / `NULLS LAST`).
///
/// # Example
///
/// `array_sort([3, 1, 2])` returns `[1, 2, 3]`
#[user_doc(
    doc_section(label = "Array Functions"),
    description = "Sort array.",
    syntax_example = "array_sort(array, desc, nulls_first)",
    sql_example = r#"```sql
> select array_sort([3, 1, 2]);
+-----------------------------+
| array_sort(List([3,1,2]))   |
+-----------------------------+
| [1, 2, 3]                   |
+-----------------------------+
```"#,
    argument(
        name = "array",
        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
    ),
    argument(
        name = "desc",
        description = "Whether to sort in descending order(`ASC` or `DESC`)."
    ),
    argument(
        name = "nulls_first",
        description = "Whether to sort nulls first(`NULLS FIRST` or `NULLS LAST`)."
    )
)]
#[derive(Debug)]
pub struct ArraySort {
    /// Accepted call shapes: the array alone, or the array followed by one
    /// or two string arguments.
    signature: Signature,
    /// Alternative names under which the function can be invoked.
    aliases: Vec<String>,
}
82
83impl Default for ArraySort {
84    fn default() -> Self {
85        Self::new()
86    }
87}
88
89impl ArraySort {
90    pub fn new() -> Self {
91        Self {
92            signature: Signature::one_of(
93                vec![
94                    TypeSignature::ArraySignature(ArrayFunctionSignature::Array {
95                        arguments: vec![ArrayFunctionArgument::Array],
96                        array_coercion: Some(ListCoercion::FixedSizedListToList),
97                    }),
98                    TypeSignature::ArraySignature(ArrayFunctionSignature::Array {
99                        arguments: vec![
100                            ArrayFunctionArgument::Array,
101                            ArrayFunctionArgument::String,
102                        ],
103                        array_coercion: Some(ListCoercion::FixedSizedListToList),
104                    }),
105                    TypeSignature::ArraySignature(ArrayFunctionSignature::Array {
106                        arguments: vec![
107                            ArrayFunctionArgument::Array,
108                            ArrayFunctionArgument::String,
109                            ArrayFunctionArgument::String,
110                        ],
111                        array_coercion: Some(ListCoercion::FixedSizedListToList),
112                    }),
113                ],
114                Volatility::Immutable,
115            ),
116            aliases: vec!["list_sort".to_string()],
117        }
118    }
119}
120
121impl ScalarUDFImpl for ArraySort {
122    fn as_any(&self) -> &dyn Any {
123        self
124    }
125
126    fn name(&self) -> &str {
127        "array_sort"
128    }
129
130    fn signature(&self) -> &Signature {
131        &self.signature
132    }
133
134    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
135        match &arg_types[0] {
136            DataType::Null => Ok(DataType::Null),
137            DataType::List(field) => {
138                Ok(DataType::new_list(field.data_type().clone(), true))
139            }
140            arg_type => {
141                plan_err!("{} does not support type {arg_type}", self.name())
142            }
143        }
144    }
145
146    fn invoke_with_args(
147        &self,
148        args: datafusion_expr::ScalarFunctionArgs,
149    ) -> Result<ColumnarValue> {
150        make_scalar_function(array_sort_inner)(&args.args)
151    }
152
153    fn aliases(&self) -> &[String] {
154        &self.aliases
155    }
156
157    fn documentation(&self) -> Option<&Documentation> {
158        self.doc()
159    }
160}
161
162/// Array_sort SQL function
163pub fn array_sort_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
164    if args.is_empty() || args.len() > 3 {
165        return exec_err!("array_sort expects one to three arguments");
166    }
167
168    if args[0].data_type().is_null() {
169        return Ok(Arc::clone(&args[0]));
170    }
171
172    let list_array = as_list_array(&args[0])?;
173    let row_count = list_array.len();
174    if row_count == 0 || list_array.value_type().is_null() {
175        return Ok(Arc::clone(&args[0]));
176    }
177
178    if args[1..].iter().any(|array| array.is_null(0)) {
179        return Ok(new_null_array(args[0].data_type(), args[0].len()));
180    }
181
182    let sort_option = match args.len() {
183        1 => None,
184        2 => {
185            let sort = as_string_array(&args[1])?.value(0);
186            Some(SortOptions {
187                descending: order_desc(sort)?,
188                nulls_first: true,
189            })
190        }
191        3 => {
192            let sort = as_string_array(&args[1])?.value(0);
193            let nulls_first = as_string_array(&args[2])?.value(0);
194            Some(SortOptions {
195                descending: order_desc(sort)?,
196                nulls_first: order_nulls_first(nulls_first)?,
197            })
198        }
199        _ => return exec_err!("array_sort expects 1 to 3 arguments"),
200    };
201
202    let mut array_lengths = vec![];
203    let mut arrays = vec![];
204    let mut valid = NullBufferBuilder::new(row_count);
205    for i in 0..row_count {
206        if list_array.is_null(i) {
207            array_lengths.push(0);
208            valid.append_null();
209        } else {
210            let arr_ref = list_array.value(i);
211
212            // arrow sort kernel does not support Structs, so use
213            // lexsort_to_indices instead:
214            // https://github.com/apache/arrow-rs/issues/6911#issuecomment-2562928843
215            let sorted_array = match arr_ref.data_type() {
216                DataType::Struct(_) => {
217                    let sort_columns: Vec<SortColumn> = vec![SortColumn {
218                        values: Arc::clone(&arr_ref),
219                        options: sort_option,
220                    }];
221                    let indices = compute::lexsort_to_indices(&sort_columns, None)?;
222                    compute::take(arr_ref.as_ref(), &indices, None)?
223                }
224                _ => {
225                    let arr_ref = arr_ref.as_ref();
226                    compute::sort(arr_ref, sort_option)?
227                }
228            };
229            array_lengths.push(sorted_array.len());
230            arrays.push(sorted_array);
231            valid.append_non_null();
232        }
233    }
234
235    // Assume all arrays have the same data type
236    let data_type = list_array.value_type();
237    let buffer = valid.finish();
238
239    let elements = arrays
240        .iter()
241        .map(|a| a.as_ref())
242        .collect::<Vec<&dyn Array>>();
243
244    let list_arr = if elements.is_empty() {
245        ListArray::new_null(Arc::new(Field::new_list_field(data_type, true)), row_count)
246    } else {
247        ListArray::new(
248            Arc::new(Field::new_list_field(data_type, true)),
249            OffsetBuffer::from_lengths(array_lengths),
250            Arc::new(compute::concat(elements.as_slice())?),
251            buffer,
252        )
253    };
254    Ok(Arc::new(list_arr))
255}
256
257fn order_desc(modifier: &str) -> Result<bool> {
258    match modifier.to_uppercase().as_str() {
259        "DESC" => Ok(true),
260        "ASC" => Ok(false),
261        _ => exec_err!("the second parameter of array_sort expects DESC or ASC"),
262    }
263}
264
265fn order_nulls_first(modifier: &str) -> Result<bool> {
266    match modifier.to_uppercase().as_str() {
267        "NULLS FIRST" => Ok(true),
268        "NULLS LAST" => Ok(false),
269        _ => exec_err!(
270            "the third parameter of array_sort expects NULLS FIRST or NULLS LAST"
271        ),
272    }
273}