datafusion_functions_nested/
sort.rs1use crate::utils::make_scalar_function;
21use arrow::array::{Array, ArrayRef, ListArray, NullBufferBuilder};
22use arrow::buffer::OffsetBuffer;
23use arrow::datatypes::DataType::{FixedSizeList, LargeList, List};
24use arrow::datatypes::{DataType, Field};
25use arrow::{compute, compute::SortOptions};
26use datafusion_common::cast::{as_list_array, as_string_array};
27use datafusion_common::{exec_err, Result};
28use datafusion_expr::{
29 ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
30};
31use datafusion_macros::user_doc;
32use std::any::Any;
33use std::sync::Arc;
34
35make_udf_expr_and_func!(
36 ArraySort,
37 array_sort,
38 array desc null_first,
39 "returns sorted array.",
40 array_sort_udf
41);
42
43#[user_doc(
51 doc_section(label = "Array Functions"),
52 description = "Sort array.",
53 syntax_example = "array_sort(array, desc, nulls_first)",
54 sql_example = r#"```sql
55> select array_sort([3, 1, 2]);
56+-----------------------------+
57| array_sort(List([3,1,2])) |
58+-----------------------------+
59| [1, 2, 3] |
60+-----------------------------+
61```"#,
62 argument(
63 name = "array",
64 description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
65 ),
66 argument(
67 name = "desc",
68 description = "Whether to sort in descending order(`ASC` or `DESC`)."
69 ),
70 argument(
71 name = "nulls_first",
72 description = "Whether to sort nulls first(`NULLS FIRST` or `NULLS LAST`)."
73 )
74)]
75#[derive(Debug)]
76pub struct ArraySort {
77 signature: Signature,
78 aliases: Vec<String>,
79}
80
81impl Default for ArraySort {
82 fn default() -> Self {
83 Self::new()
84 }
85}
86
87impl ArraySort {
88 pub fn new() -> Self {
89 Self {
90 signature: Signature::variadic_any(Volatility::Immutable),
91 aliases: vec!["list_sort".to_string()],
92 }
93 }
94}
95
96impl ScalarUDFImpl for ArraySort {
97 fn as_any(&self) -> &dyn Any {
98 self
99 }
100
101 fn name(&self) -> &str {
102 "array_sort"
103 }
104
105 fn signature(&self) -> &Signature {
106 &self.signature
107 }
108
109 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
110 match &arg_types[0] {
111 List(field) | FixedSizeList(field, _) => Ok(List(Arc::new(
112 Field::new_list_field(field.data_type().clone(), true),
113 ))),
114 LargeList(field) => Ok(LargeList(Arc::new(Field::new_list_field(
115 field.data_type().clone(),
116 true,
117 )))),
118 _ => exec_err!(
119 "Not reachable, data_type should be List, LargeList or FixedSizeList"
120 ),
121 }
122 }
123
124 fn invoke_with_args(
125 &self,
126 args: datafusion_expr::ScalarFunctionArgs,
127 ) -> Result<ColumnarValue> {
128 make_scalar_function(array_sort_inner)(&args.args)
129 }
130
131 fn aliases(&self) -> &[String] {
132 &self.aliases
133 }
134
135 fn documentation(&self) -> Option<&Documentation> {
136 self.doc()
137 }
138}
139
140pub fn array_sort_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
142 if args.is_empty() || args.len() > 3 {
143 return exec_err!("array_sort expects one to three arguments");
144 }
145
146 let sort_option = match args.len() {
147 1 => None,
148 2 => {
149 let sort = as_string_array(&args[1])?.value(0);
150 Some(SortOptions {
151 descending: order_desc(sort)?,
152 nulls_first: true,
153 })
154 }
155 3 => {
156 let sort = as_string_array(&args[1])?.value(0);
157 let nulls_first = as_string_array(&args[2])?.value(0);
158 Some(SortOptions {
159 descending: order_desc(sort)?,
160 nulls_first: order_nulls_first(nulls_first)?,
161 })
162 }
163 _ => return exec_err!("array_sort expects 1 to 3 arguments"),
164 };
165
166 let list_array = as_list_array(&args[0])?;
167 let row_count = list_array.len();
168 if row_count == 0 {
169 return Ok(Arc::clone(&args[0]));
170 }
171
172 let mut array_lengths = vec![];
173 let mut arrays = vec![];
174 let mut valid = NullBufferBuilder::new(row_count);
175 for i in 0..row_count {
176 if list_array.is_null(i) {
177 array_lengths.push(0);
178 valid.append_null();
179 } else {
180 let arr_ref = list_array.value(i);
181 let arr_ref = arr_ref.as_ref();
182
183 let sorted_array = compute::sort(arr_ref, sort_option)?;
184 array_lengths.push(sorted_array.len());
185 arrays.push(sorted_array);
186 valid.append_non_null();
187 }
188 }
189
190 let data_type = list_array.value_type();
192 let buffer = valid.finish();
193
194 let elements = arrays
195 .iter()
196 .map(|a| a.as_ref())
197 .collect::<Vec<&dyn Array>>();
198
199 let list_arr = ListArray::new(
200 Arc::new(Field::new_list_field(data_type, true)),
201 OffsetBuffer::from_lengths(array_lengths),
202 Arc::new(compute::concat(elements.as_slice())?),
203 buffer,
204 );
205 Ok(Arc::new(list_arr))
206}
207
208fn order_desc(modifier: &str) -> Result<bool> {
209 match modifier.to_uppercase().as_str() {
210 "DESC" => Ok(true),
211 "ASC" => Ok(false),
212 _ => exec_err!("the second parameter of array_sort expects DESC or ASC"),
213 }
214}
215
216fn order_nulls_first(modifier: &str) -> Result<bool> {
217 match modifier.to_uppercase().as_str() {
218 "NULLS FIRST" => Ok(true),
219 "NULLS LAST" => Ok(false),
220 _ => exec_err!(
221 "the third parameter of array_sort expects NULLS FIRST or NULLS LAST"
222 ),
223 }
224}