datafusion_functions_nested/
sort.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`ScalarUDFImpl`] definitions for array_sort function.
19
20use crate::utils::make_scalar_function;
21use arrow::array::{
22    Array, ArrayRef, GenericListArray, NullBufferBuilder, OffsetSizeTrait, new_null_array,
23};
24use arrow::buffer::OffsetBuffer;
25use arrow::compute::SortColumn;
26use arrow::datatypes::{DataType, FieldRef};
27use arrow::{compute, compute::SortOptions};
28use datafusion_common::cast::{as_large_list_array, as_list_array, as_string_array};
29use datafusion_common::utils::ListCoercion;
30use datafusion_common::{Result, exec_err, plan_err};
31use datafusion_expr::{
32    ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, Documentation,
33    ScalarUDFImpl, Signature, TypeSignature, Volatility,
34};
35use datafusion_macros::user_doc;
36use std::any::Any;
37use std::sync::Arc;
38
39make_udf_expr_and_func!(
40    ArraySort,
41    array_sort,
42    array desc null_first,
43    "returns sorted array.",
44    array_sort_udf
45);
46
47/// Implementation of `array_sort` function
48///
49/// `array_sort` sorts the elements of an array
50///
51/// # Example
52///
53/// `array_sort([3, 1, 2])` returns `[1, 2, 3]`
54#[user_doc(
55    doc_section(label = "Array Functions"),
56    description = "Sort array.",
57    syntax_example = "array_sort(array, desc, nulls_first)",
58    sql_example = r#"```sql
59> select array_sort([3, 1, 2]);
60+-----------------------------+
61| array_sort(List([3,1,2]))   |
62+-----------------------------+
63| [1, 2, 3]                   |
64+-----------------------------+
65```"#,
66    argument(
67        name = "array",
68        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
69    ),
70    argument(
71        name = "desc",
72        description = "Whether to sort in descending order(`ASC` or `DESC`)."
73    ),
74    argument(
75        name = "nulls_first",
76        description = "Whether to sort nulls first(`NULLS FIRST` or `NULLS LAST`)."
77    )
78)]
79#[derive(Debug, PartialEq, Eq, Hash)]
80pub struct ArraySort {
81    signature: Signature,
82    aliases: Vec<String>,
83}
84
85impl Default for ArraySort {
86    fn default() -> Self {
87        Self::new()
88    }
89}
90
91impl ArraySort {
92    pub fn new() -> Self {
93        Self {
94            signature: Signature::one_of(
95                vec![
96                    TypeSignature::ArraySignature(ArrayFunctionSignature::Array {
97                        arguments: vec![ArrayFunctionArgument::Array],
98                        array_coercion: Some(ListCoercion::FixedSizedListToList),
99                    }),
100                    TypeSignature::ArraySignature(ArrayFunctionSignature::Array {
101                        arguments: vec![
102                            ArrayFunctionArgument::Array,
103                            ArrayFunctionArgument::String,
104                        ],
105                        array_coercion: Some(ListCoercion::FixedSizedListToList),
106                    }),
107                    TypeSignature::ArraySignature(ArrayFunctionSignature::Array {
108                        arguments: vec![
109                            ArrayFunctionArgument::Array,
110                            ArrayFunctionArgument::String,
111                            ArrayFunctionArgument::String,
112                        ],
113                        array_coercion: Some(ListCoercion::FixedSizedListToList),
114                    }),
115                ],
116                Volatility::Immutable,
117            ),
118            aliases: vec!["list_sort".to_string()],
119        }
120    }
121}
122
123impl ScalarUDFImpl for ArraySort {
124    fn as_any(&self) -> &dyn Any {
125        self
126    }
127
128    fn name(&self) -> &str {
129        "array_sort"
130    }
131
132    fn signature(&self) -> &Signature {
133        &self.signature
134    }
135
136    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
137        match &arg_types[0] {
138            DataType::Null => Ok(DataType::Null),
139            DataType::List(field) => {
140                Ok(DataType::new_list(field.data_type().clone(), true))
141            }
142            DataType::LargeList(field) => {
143                Ok(DataType::new_large_list(field.data_type().clone(), true))
144            }
145            arg_type => {
146                plan_err!("{} does not support type {arg_type}", self.name())
147            }
148        }
149    }
150
151    fn invoke_with_args(
152        &self,
153        args: datafusion_expr::ScalarFunctionArgs,
154    ) -> Result<ColumnarValue> {
155        make_scalar_function(array_sort_inner)(&args.args)
156    }
157
158    fn aliases(&self) -> &[String] {
159        &self.aliases
160    }
161
162    fn documentation(&self) -> Option<&Documentation> {
163        self.doc()
164    }
165}
166
167fn array_sort_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
168    if args.is_empty() || args.len() > 3 {
169        return exec_err!("array_sort expects one to three arguments");
170    }
171
172    if args[0].is_empty() || args[0].data_type().is_null() {
173        return Ok(Arc::clone(&args[0]));
174    }
175
176    if args[1..].iter().any(|array| array.is_null(0)) {
177        return Ok(new_null_array(args[0].data_type(), args[0].len()));
178    }
179
180    let sort_options = match args.len() {
181        1 => None,
182        2 => {
183            let sort = as_string_array(&args[1])?.value(0);
184            Some(SortOptions {
185                descending: order_desc(sort)?,
186                nulls_first: true,
187            })
188        }
189        3 => {
190            let sort = as_string_array(&args[1])?.value(0);
191            let nulls_first = as_string_array(&args[2])?.value(0);
192            Some(SortOptions {
193                descending: order_desc(sort)?,
194                nulls_first: order_nulls_first(nulls_first)?,
195            })
196        }
197        // We guard at the top
198        _ => unreachable!(),
199    };
200
201    match args[0].data_type() {
202        DataType::List(field) | DataType::LargeList(field)
203            if field.data_type().is_null() =>
204        {
205            Ok(Arc::clone(&args[0]))
206        }
207        DataType::List(field) => {
208            let array = as_list_array(&args[0])?;
209            array_sort_generic(array, field, sort_options)
210        }
211        DataType::LargeList(field) => {
212            let array = as_large_list_array(&args[0])?;
213            array_sort_generic(array, field, sort_options)
214        }
215        // Signature should prevent this arm ever occurring
216        _ => exec_err!("array_sort expects list for first argument"),
217    }
218}
219
220fn array_sort_generic<OffsetSize: OffsetSizeTrait>(
221    list_array: &GenericListArray<OffsetSize>,
222    field: &FieldRef,
223    sort_options: Option<SortOptions>,
224) -> Result<ArrayRef> {
225    let row_count = list_array.len();
226
227    let mut array_lengths = vec![];
228    let mut arrays = vec![];
229    let mut valid = NullBufferBuilder::new(row_count);
230    for i in 0..row_count {
231        if list_array.is_null(i) {
232            array_lengths.push(0);
233            valid.append_null();
234        } else {
235            let arr_ref = list_array.value(i);
236
237            // arrow sort kernel does not support Structs, so use
238            // lexsort_to_indices instead:
239            // https://github.com/apache/arrow-rs/issues/6911#issuecomment-2562928843
240            let sorted_array = match arr_ref.data_type() {
241                DataType::Struct(_) => {
242                    let sort_columns: Vec<SortColumn> = vec![SortColumn {
243                        values: Arc::clone(&arr_ref),
244                        options: sort_options,
245                    }];
246                    let indices = compute::lexsort_to_indices(&sort_columns, None)?;
247                    compute::take(arr_ref.as_ref(), &indices, None)?
248                }
249                _ => {
250                    let arr_ref = arr_ref.as_ref();
251                    compute::sort(arr_ref, sort_options)?
252                }
253            };
254            array_lengths.push(sorted_array.len());
255            arrays.push(sorted_array);
256            valid.append_non_null();
257        }
258    }
259
260    let buffer = valid.finish();
261
262    let elements = arrays
263        .iter()
264        .map(|a| a.as_ref())
265        .collect::<Vec<&dyn Array>>();
266
267    let list_arr = if elements.is_empty() {
268        GenericListArray::<OffsetSize>::new_null(Arc::clone(field), row_count)
269    } else {
270        GenericListArray::<OffsetSize>::new(
271            Arc::clone(field),
272            OffsetBuffer::from_lengths(array_lengths),
273            Arc::new(compute::concat(elements.as_slice())?),
274            buffer,
275        )
276    };
277    Ok(Arc::new(list_arr))
278}
279
280fn order_desc(modifier: &str) -> Result<bool> {
281    match modifier.to_uppercase().as_str() {
282        "DESC" => Ok(true),
283        "ASC" => Ok(false),
284        _ => exec_err!("the second parameter of array_sort expects DESC or ASC"),
285    }
286}
287
288fn order_nulls_first(modifier: &str) -> Result<bool> {
289    match modifier.to_uppercase().as_str() {
290        "NULLS FIRST" => Ok(true),
291        "NULLS LAST" => Ok(false),
292        _ => exec_err!(
293            "the third parameter of array_sort expects NULLS FIRST or NULLS LAST"
294        ),
295    }
296}