datafusion_functions_nested/
sort.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`ScalarUDFImpl`] definitions for array_sort function.
19
20use crate::utils::make_scalar_function;
21use arrow::array::{Array, ArrayRef, ListArray, NullBufferBuilder};
22use arrow::buffer::OffsetBuffer;
23use arrow::datatypes::DataType::{FixedSizeList, LargeList, List};
24use arrow::datatypes::{DataType, Field};
25use arrow::{compute, compute::SortOptions};
26use datafusion_common::cast::{as_list_array, as_string_array};
27use datafusion_common::{exec_err, Result};
28use datafusion_expr::{
29    ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
30};
31use datafusion_macros::user_doc;
32use std::any::Any;
33use std::sync::Arc;
34
// Invokes the `make_udf_expr_and_func!` macro (defined outside this file) for
// `ArraySort`, passing the names `array_sort` and `array_sort_udf`, the
// argument list `array desc null_first`, and a one-line description.
// Presumably this generates the expression-builder function and the UDF
// accessor — confirm against the macro definition.
// NOTE(review): the third argument is spelled `null_first` here but
// `nulls_first` in the `user_doc` attribute on `ArraySort` — confirm whether
// the two names are meant to match.
make_udf_expr_and_func!(
    ArraySort,
    array_sort,
    array desc null_first,
    "returns sorted array.",
    array_sort_udf
);
42
/// Implementation of the `array_sort` scalar function.
///
/// `array_sort` sorts the elements of each array value independently. Two
/// optional string arguments choose the direction (`ASC` / `DESC`) and the
/// placement of null elements (`NULLS FIRST` / `NULLS LAST`).
///
/// The function is also registered under the alias `list_sort`.
///
/// # Example
///
/// `array_sort([3, 1, 2])` returns `[1, 2, 3]`
#[user_doc(
    doc_section(label = "Array Functions"),
    description = "Sort array.",
    syntax_example = "array_sort(array, desc, nulls_first)",
    sql_example = r#"```sql
> select array_sort([3, 1, 2]);
+-----------------------------+
| array_sort(List([3,1,2]))   |
+-----------------------------+
| [1, 2, 3]                   |
+-----------------------------+
```"#,
    argument(
        name = "array",
        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
    ),
    argument(
        name = "desc",
        description = "Whether to sort in descending order(`ASC` or `DESC`)."
    ),
    argument(
        name = "nulls_first",
        description = "Whether to sort nulls first(`NULLS FIRST` or `NULLS LAST`)."
    )
)]
#[derive(Debug)]
pub struct ArraySort {
    // Argument signature handed out by `ScalarUDFImpl::signature`.
    signature: Signature,
    // Alternative function names handed out by `ScalarUDFImpl::aliases`.
    aliases: Vec<String>,
}
80
81impl Default for ArraySort {
82    fn default() -> Self {
83        Self::new()
84    }
85}
86
87impl ArraySort {
88    pub fn new() -> Self {
89        Self {
90            signature: Signature::variadic_any(Volatility::Immutable),
91            aliases: vec!["list_sort".to_string()],
92        }
93    }
94}
95
96impl ScalarUDFImpl for ArraySort {
97    fn as_any(&self) -> &dyn Any {
98        self
99    }
100
101    fn name(&self) -> &str {
102        "array_sort"
103    }
104
105    fn signature(&self) -> &Signature {
106        &self.signature
107    }
108
109    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
110        match &arg_types[0] {
111            List(field) | FixedSizeList(field, _) => Ok(List(Arc::new(
112                Field::new_list_field(field.data_type().clone(), true),
113            ))),
114            LargeList(field) => Ok(LargeList(Arc::new(Field::new_list_field(
115                field.data_type().clone(),
116                true,
117            )))),
118            _ => exec_err!(
119                "Not reachable, data_type should be List, LargeList or FixedSizeList"
120            ),
121        }
122    }
123
124    fn invoke_with_args(
125        &self,
126        args: datafusion_expr::ScalarFunctionArgs,
127    ) -> Result<ColumnarValue> {
128        make_scalar_function(array_sort_inner)(&args.args)
129    }
130
131    fn aliases(&self) -> &[String] {
132        &self.aliases
133    }
134
135    fn documentation(&self) -> Option<&Documentation> {
136        self.doc()
137    }
138}
139
140/// Array_sort SQL function
141pub fn array_sort_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
142    if args.is_empty() || args.len() > 3 {
143        return exec_err!("array_sort expects one to three arguments");
144    }
145
146    let sort_option = match args.len() {
147        1 => None,
148        2 => {
149            let sort = as_string_array(&args[1])?.value(0);
150            Some(SortOptions {
151                descending: order_desc(sort)?,
152                nulls_first: true,
153            })
154        }
155        3 => {
156            let sort = as_string_array(&args[1])?.value(0);
157            let nulls_first = as_string_array(&args[2])?.value(0);
158            Some(SortOptions {
159                descending: order_desc(sort)?,
160                nulls_first: order_nulls_first(nulls_first)?,
161            })
162        }
163        _ => return exec_err!("array_sort expects 1 to 3 arguments"),
164    };
165
166    let list_array = as_list_array(&args[0])?;
167    let row_count = list_array.len();
168    if row_count == 0 {
169        return Ok(Arc::clone(&args[0]));
170    }
171
172    let mut array_lengths = vec![];
173    let mut arrays = vec![];
174    let mut valid = NullBufferBuilder::new(row_count);
175    for i in 0..row_count {
176        if list_array.is_null(i) {
177            array_lengths.push(0);
178            valid.append_null();
179        } else {
180            let arr_ref = list_array.value(i);
181            let arr_ref = arr_ref.as_ref();
182
183            let sorted_array = compute::sort(arr_ref, sort_option)?;
184            array_lengths.push(sorted_array.len());
185            arrays.push(sorted_array);
186            valid.append_non_null();
187        }
188    }
189
190    // Assume all arrays have the same data type
191    let data_type = list_array.value_type();
192    let buffer = valid.finish();
193
194    let elements = arrays
195        .iter()
196        .map(|a| a.as_ref())
197        .collect::<Vec<&dyn Array>>();
198
199    let list_arr = ListArray::new(
200        Arc::new(Field::new_list_field(data_type, true)),
201        OffsetBuffer::from_lengths(array_lengths),
202        Arc::new(compute::concat(elements.as_slice())?),
203        buffer,
204    );
205    Ok(Arc::new(list_arr))
206}
207
208fn order_desc(modifier: &str) -> Result<bool> {
209    match modifier.to_uppercase().as_str() {
210        "DESC" => Ok(true),
211        "ASC" => Ok(false),
212        _ => exec_err!("the second parameter of array_sort expects DESC or ASC"),
213    }
214}
215
216fn order_nulls_first(modifier: &str) -> Result<bool> {
217    match modifier.to_uppercase().as_str() {
218        "NULLS FIRST" => Ok(true),
219        "NULLS LAST" => Ok(false),
220        _ => exec_err!(
221            "the third parameter of array_sort expects NULLS FIRST or NULLS LAST"
222        ),
223    }
224}