datafusion_functions_nested/
sort.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`ScalarUDFImpl`] definitions for array_sort function.
19
20use crate::utils::make_scalar_function;
21use arrow::array::{
22    new_null_array, Array, ArrayRef, GenericListArray, NullBufferBuilder, OffsetSizeTrait,
23};
24use arrow::buffer::OffsetBuffer;
25use arrow::compute::SortColumn;
26use arrow::datatypes::{DataType, FieldRef};
27use arrow::{compute, compute::SortOptions};
28use datafusion_common::cast::{as_large_list_array, as_list_array, as_string_array};
29use datafusion_common::utils::ListCoercion;
30use datafusion_common::{exec_err, plan_err, Result};
31use datafusion_expr::{
32    ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, Documentation,
33    ScalarUDFImpl, Signature, TypeSignature, Volatility,
34};
35use datafusion_macros::user_doc;
36use std::any::Any;
37use std::sync::Arc;
38
39make_udf_expr_and_func!(
40    ArraySort,
41    array_sort,
42    array desc null_first,
43    "returns sorted array.",
44    array_sort_udf
45);
46
47/// Implementation of `array_sort` function
48///
49/// `array_sort` sorts the elements of an array
50///
51/// # Example
52///
53/// `array_sort([3, 1, 2])` returns `[1, 2, 3]`
54#[user_doc(
55    doc_section(label = "Array Functions"),
56    description = "Sort array.",
57    syntax_example = "array_sort(array, desc, nulls_first)",
58    sql_example = r#"```sql
59> select array_sort([3, 1, 2]);
60+-----------------------------+
61| array_sort(List([3,1,2]))   |
62+-----------------------------+
63| [1, 2, 3]                   |
64+-----------------------------+
65```"#,
66    argument(
67        name = "array",
68        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
69    ),
70    argument(
71        name = "desc",
72        description = "Whether to sort in descending order(`ASC` or `DESC`)."
73    ),
74    argument(
75        name = "nulls_first",
76        description = "Whether to sort nulls first(`NULLS FIRST` or `NULLS LAST`)."
77    )
78)]
79#[derive(Debug, PartialEq, Eq, Hash)]
80pub struct ArraySort {
81    signature: Signature,
82    aliases: Vec<String>,
83}
84
85impl Default for ArraySort {
86    fn default() -> Self {
87        Self::new()
88    }
89}
90
91impl ArraySort {
92    pub fn new() -> Self {
93        Self {
94            signature: Signature::one_of(
95                vec![
96                    TypeSignature::ArraySignature(ArrayFunctionSignature::Array {
97                        arguments: vec![ArrayFunctionArgument::Array],
98                        array_coercion: Some(ListCoercion::FixedSizedListToList),
99                    }),
100                    TypeSignature::ArraySignature(ArrayFunctionSignature::Array {
101                        arguments: vec![
102                            ArrayFunctionArgument::Array,
103                            ArrayFunctionArgument::String,
104                        ],
105                        array_coercion: Some(ListCoercion::FixedSizedListToList),
106                    }),
107                    TypeSignature::ArraySignature(ArrayFunctionSignature::Array {
108                        arguments: vec![
109                            ArrayFunctionArgument::Array,
110                            ArrayFunctionArgument::String,
111                            ArrayFunctionArgument::String,
112                        ],
113                        array_coercion: Some(ListCoercion::FixedSizedListToList),
114                    }),
115                ],
116                Volatility::Immutable,
117            ),
118            aliases: vec!["list_sort".to_string()],
119        }
120    }
121}
122
123impl ScalarUDFImpl for ArraySort {
124    fn as_any(&self) -> &dyn Any {
125        self
126    }
127
128    fn name(&self) -> &str {
129        "array_sort"
130    }
131
132    fn signature(&self) -> &Signature {
133        &self.signature
134    }
135
136    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
137        match &arg_types[0] {
138            DataType::Null => Ok(DataType::Null),
139            DataType::List(field) => {
140                Ok(DataType::new_list(field.data_type().clone(), true))
141            }
142            DataType::LargeList(field) => {
143                Ok(DataType::new_large_list(field.data_type().clone(), true))
144            }
145            arg_type => {
146                plan_err!("{} does not support type {arg_type}", self.name())
147            }
148        }
149    }
150
151    fn invoke_with_args(
152        &self,
153        args: datafusion_expr::ScalarFunctionArgs,
154    ) -> Result<ColumnarValue> {
155        make_scalar_function(array_sort_inner)(&args.args)
156    }
157
158    fn aliases(&self) -> &[String] {
159        &self.aliases
160    }
161
162    fn documentation(&self) -> Option<&Documentation> {
163        self.doc()
164    }
165}
166
167/// Array_sort SQL function
168pub fn array_sort_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
169    if args.is_empty() || args.len() > 3 {
170        return exec_err!("array_sort expects one to three arguments");
171    }
172
173    if args[0].is_empty() || args[0].data_type().is_null() {
174        return Ok(Arc::clone(&args[0]));
175    }
176
177    if args[1..].iter().any(|array| array.is_null(0)) {
178        return Ok(new_null_array(args[0].data_type(), args[0].len()));
179    }
180
181    let sort_options = match args.len() {
182        1 => None,
183        2 => {
184            let sort = as_string_array(&args[1])?.value(0);
185            Some(SortOptions {
186                descending: order_desc(sort)?,
187                nulls_first: true,
188            })
189        }
190        3 => {
191            let sort = as_string_array(&args[1])?.value(0);
192            let nulls_first = as_string_array(&args[2])?.value(0);
193            Some(SortOptions {
194                descending: order_desc(sort)?,
195                nulls_first: order_nulls_first(nulls_first)?,
196            })
197        }
198        // We guard at the top
199        _ => unreachable!(),
200    };
201
202    match args[0].data_type() {
203        DataType::List(field) | DataType::LargeList(field)
204            if field.data_type().is_null() =>
205        {
206            Ok(Arc::clone(&args[0]))
207        }
208        DataType::List(field) => {
209            let array = as_list_array(&args[0])?;
210            array_sort_generic(array, field, sort_options)
211        }
212        DataType::LargeList(field) => {
213            let array = as_large_list_array(&args[0])?;
214            array_sort_generic(array, field, sort_options)
215        }
216        // Signature should prevent this arm ever occurring
217        _ => exec_err!("array_sort expects list for first argument"),
218    }
219}
220
221/// Array_sort SQL function
222pub fn array_sort_generic<OffsetSize: OffsetSizeTrait>(
223    list_array: &GenericListArray<OffsetSize>,
224    field: &FieldRef,
225    sort_options: Option<SortOptions>,
226) -> Result<ArrayRef> {
227    let row_count = list_array.len();
228
229    let mut array_lengths = vec![];
230    let mut arrays = vec![];
231    let mut valid = NullBufferBuilder::new(row_count);
232    for i in 0..row_count {
233        if list_array.is_null(i) {
234            array_lengths.push(0);
235            valid.append_null();
236        } else {
237            let arr_ref = list_array.value(i);
238
239            // arrow sort kernel does not support Structs, so use
240            // lexsort_to_indices instead:
241            // https://github.com/apache/arrow-rs/issues/6911#issuecomment-2562928843
242            let sorted_array = match arr_ref.data_type() {
243                DataType::Struct(_) => {
244                    let sort_columns: Vec<SortColumn> = vec![SortColumn {
245                        values: Arc::clone(&arr_ref),
246                        options: sort_options,
247                    }];
248                    let indices = compute::lexsort_to_indices(&sort_columns, None)?;
249                    compute::take(arr_ref.as_ref(), &indices, None)?
250                }
251                _ => {
252                    let arr_ref = arr_ref.as_ref();
253                    compute::sort(arr_ref, sort_options)?
254                }
255            };
256            array_lengths.push(sorted_array.len());
257            arrays.push(sorted_array);
258            valid.append_non_null();
259        }
260    }
261
262    let buffer = valid.finish();
263
264    let elements = arrays
265        .iter()
266        .map(|a| a.as_ref())
267        .collect::<Vec<&dyn Array>>();
268
269    let list_arr = if elements.is_empty() {
270        GenericListArray::<OffsetSize>::new_null(Arc::clone(field), row_count)
271    } else {
272        GenericListArray::<OffsetSize>::new(
273            Arc::clone(field),
274            OffsetBuffer::from_lengths(array_lengths),
275            Arc::new(compute::concat(elements.as_slice())?),
276            buffer,
277        )
278    };
279    Ok(Arc::new(list_arr))
280}
281
282fn order_desc(modifier: &str) -> Result<bool> {
283    match modifier.to_uppercase().as_str() {
284        "DESC" => Ok(true),
285        "ASC" => Ok(false),
286        _ => exec_err!("the second parameter of array_sort expects DESC or ASC"),
287    }
288}
289
290fn order_nulls_first(modifier: &str) -> Result<bool> {
291    match modifier.to_uppercase().as_str() {
292        "NULLS FIRST" => Ok(true),
293        "NULLS LAST" => Ok(false),
294        _ => exec_err!(
295            "the third parameter of array_sort expects NULLS FIRST or NULLS LAST"
296        ),
297    }
298}