1use crate::utils::make_scalar_function;
21use arrow::array::BooleanBufferBuilder;
22use arrow::array::{
23 Array, ArrayRef, ArrowPrimitiveType, GenericListArray, OffsetSizeTrait,
24 PrimitiveArray, UInt32Array, UInt64Array, new_empty_array, new_null_array,
25};
26use arrow::buffer::{NullBuffer, OffsetBuffer};
27use arrow::datatypes::{ArrowNativeTypeOp, DataType, FieldRef};
28use arrow::row::{RowConverter, SortField};
29use arrow::{compute, compute::SortOptions, downcast_primitive_array};
30use datafusion_common::cast::{as_large_list_array, as_list_array, as_string_array};
31use datafusion_common::utils::ListCoercion;
32use datafusion_common::{Result, exec_err, internal_datafusion_err};
33use datafusion_expr::{
34 ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, Documentation,
35 ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility,
36};
37use datafusion_macros::user_doc;
38use std::sync::Arc;
39
// Expression-API glue for `ArraySort`. Judging by its arguments, this presumably
// generates an `array_sort(array, desc, null_first)` expression builder and an
// `array_sort_udf()` accessor — confirm against the macro definition.
make_udf_expr_and_func!(
    ArraySort,
    array_sort,
    array desc null_first,
    "returns sorted array.",
    array_sort_udf
);
47
48#[user_doc(
56 doc_section(label = "Array Functions"),
57 description = "Sort array.",
58 syntax_example = "array_sort(array, desc, nulls_first)",
59 sql_example = r#"```sql
60> select array_sort([3, 1, 2]);
61+-----------------------------+
62| array_sort(List([3,1,2])) |
63+-----------------------------+
64| [1, 2, 3] |
65+-----------------------------+
66```"#,
67 argument(
68 name = "array",
69 description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
70 ),
71 argument(
72 name = "desc",
73 description = "Whether to sort in ascending (`ASC`) or descending (`DESC`) order. The default is `ASC`."
74 ),
75 argument(
76 name = "nulls_first",
77 description = "Whether to sort nulls first (`NULLS FIRST`) or last (`NULLS LAST`). The default is `NULLS FIRST`."
78 )
79)]
80#[derive(Debug, PartialEq, Eq, Hash)]
81pub struct ArraySort {
82 signature: Signature,
83 aliases: Vec<String>,
84}
85
86impl Default for ArraySort {
87 fn default() -> Self {
88 Self::new()
89 }
90}
91
92impl ArraySort {
93 pub fn new() -> Self {
94 Self {
95 signature: Signature::one_of(
96 vec![
97 TypeSignature::ArraySignature(ArrayFunctionSignature::Array {
98 arguments: vec![ArrayFunctionArgument::Array],
99 array_coercion: Some(ListCoercion::FixedSizedListToList),
100 }),
101 TypeSignature::ArraySignature(ArrayFunctionSignature::Array {
102 arguments: vec![
103 ArrayFunctionArgument::Array,
104 ArrayFunctionArgument::String,
105 ],
106 array_coercion: Some(ListCoercion::FixedSizedListToList),
107 }),
108 TypeSignature::ArraySignature(ArrayFunctionSignature::Array {
109 arguments: vec![
110 ArrayFunctionArgument::Array,
111 ArrayFunctionArgument::String,
112 ArrayFunctionArgument::String,
113 ],
114 array_coercion: Some(ListCoercion::FixedSizedListToList),
115 }),
116 ],
117 Volatility::Immutable,
118 ),
119 aliases: vec!["list_sort".to_string()],
120 }
121 }
122}
123
124impl ScalarUDFImpl for ArraySort {
125 fn name(&self) -> &str {
126 "array_sort"
127 }
128
129 fn signature(&self) -> &Signature {
130 &self.signature
131 }
132
133 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
134 Ok(arg_types[0].clone())
135 }
136
137 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
138 make_scalar_function(array_sort_inner)(&args.args)
139 }
140
141 fn aliases(&self) -> &[String] {
142 &self.aliases
143 }
144
145 fn documentation(&self) -> Option<&Documentation> {
146 self.doc()
147 }
148}
149
150fn array_sort_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
151 if args.is_empty() || args.len() > 3 {
152 return exec_err!("array_sort expects one to three arguments");
153 }
154
155 if args[0].is_empty() || args[0].data_type().is_null() {
156 return Ok(Arc::clone(&args[0]));
157 }
158
159 if args[1..].iter().any(|array| array.is_null(0)) {
160 return Ok(new_null_array(args[0].data_type(), args[0].len()));
161 }
162
163 let sort_options = if args.len() >= 2 {
164 let order = as_string_array(&args[1])?.value(0);
165 let descending = order_desc(order)?;
166 let nulls_first = if args.len() >= 3 {
167 order_nulls_first(as_string_array(&args[2])?.value(0))?
168 } else {
169 true
170 };
171 Some(SortOptions {
172 descending,
173 nulls_first,
174 })
175 } else {
176 None
177 };
178
179 match args[0].data_type() {
180 DataType::List(field) | DataType::LargeList(field)
181 if field.data_type().is_null() =>
182 {
183 Ok(Arc::clone(&args[0]))
184 }
185 DataType::List(field) => {
186 let array = as_list_array(&args[0])?;
187 array_sort_generic(array, Arc::clone(field), sort_options)
188 }
189 DataType::LargeList(field) => {
190 let array = as_large_list_array(&args[0])?;
191 array_sort_generic(array, Arc::clone(field), sort_options)
192 }
193 _ => exec_err!("array_sort expects list for first argument"),
195 }
196}
197
198fn array_sort_generic<OffsetSize: OffsetSizeTrait>(
199 list_array: &GenericListArray<OffsetSize>,
200 field: FieldRef,
201 sort_options: Option<SortOptions>,
202) -> Result<ArrayRef> {
203 let values = list_array.values();
204
205 if values.data_type().is_primitive() {
206 array_sort_primitive(list_array, field, sort_options)
207 } else {
208 array_sort_non_primitive(list_array, field, sort_options)
209 }
210}
211
212fn array_sort_primitive<OffsetSize: OffsetSizeTrait>(
215 list_array: &GenericListArray<OffsetSize>,
216 field: FieldRef,
217 sort_options: Option<SortOptions>,
218) -> Result<ArrayRef> {
219 let values = list_array.values().as_ref();
220 downcast_primitive_array! {
221 values => sort_primitive_list(values, list_array, field, sort_options),
222 _ => exec_err!("array_sort: unsupported primitive type")
223 }
224}
225
226fn sort_primitive_list<T: ArrowPrimitiveType, OffsetSize: OffsetSizeTrait>(
227 prim_values: &PrimitiveArray<T>,
228 list_array: &GenericListArray<OffsetSize>,
229 field: FieldRef,
230 sort_options: Option<SortOptions>,
231) -> Result<ArrayRef>
232where
233 T::Native: ArrowNativeTypeOp,
234{
235 if prim_values.null_count() > 0 {
236 sort_list_with_nulls(prim_values, list_array, field, sort_options)
237 } else {
238 sort_list_no_nulls(prim_values, list_array, field, sort_options)
239 }
240}
241
242fn sort_list_no_nulls<T: ArrowPrimitiveType, OffsetSize: OffsetSizeTrait>(
245 prim_values: &PrimitiveArray<T>,
246 list_array: &GenericListArray<OffsetSize>,
247 field: FieldRef,
248 sort_options: Option<SortOptions>,
249) -> Result<ArrayRef>
250where
251 T::Native: ArrowNativeTypeOp,
252{
253 let row_count = list_array.len();
254 let offsets = list_array.offsets();
255 let values_start = offsets[0].as_usize();
256 let values_end = offsets[row_count].as_usize();
257
258 let descending = sort_options.is_some_and(|o| o.descending);
259
260 let mut values: Vec<T::Native> =
262 prim_values.values()[values_start..values_end].to_vec();
263
264 for (row_index, window) in offsets.windows(2).enumerate() {
265 if list_array.is_null(row_index) {
266 continue;
267 }
268 let start = window[0].as_usize() - values_start;
269 let end = window[1].as_usize() - values_start;
270 let slice = &mut values[start..end];
271 if descending {
272 slice.sort_unstable_by(|a, b| b.compare(*a));
273 } else {
274 slice.sort_unstable_by(|a, b| a.compare(*b));
275 }
276 }
277
278 let new_offsets = rebase_offsets(offsets);
279 let sorted_values = Arc::new(
280 PrimitiveArray::<T>::new(values.into(), None)
281 .with_data_type(prim_values.data_type().clone()),
282 );
283
284 Ok(Arc::new(GenericListArray::<OffsetSize>::try_new(
285 field,
286 new_offsets,
287 sorted_values,
288 list_array.nulls().cloned(),
289 )?))
290}
291
292fn sort_list_with_nulls<T: ArrowPrimitiveType, OffsetSize: OffsetSizeTrait>(
294 prim_values: &PrimitiveArray<T>,
295 list_array: &GenericListArray<OffsetSize>,
296 field: FieldRef,
297 sort_options: Option<SortOptions>,
298) -> Result<ArrayRef>
299where
300 T::Native: ArrowNativeTypeOp,
301{
302 let row_count = list_array.len();
303 let offsets = list_array.offsets();
304 let values_start = offsets[0].as_usize();
305 let values_end = offsets[row_count].as_usize();
306 let total_values = values_end - values_start;
307
308 let descending = sort_options.is_some_and(|o| o.descending);
309 let nulls_first = sort_options.is_none_or(|o| o.nulls_first);
310
311 let mut out_values: Vec<T::Native> = vec![T::Native::default(); total_values];
312 let mut validity = BooleanBufferBuilder::new(total_values);
313
314 let src_nulls = prim_values.nulls().ok_or_else(|| {
315 internal_datafusion_err!(
316 "sort_list_with_nulls called but values have no null buffer"
317 )
318 })?;
319 let src_values = prim_values.values();
320
321 for (row_index, window) in offsets.windows(2).enumerate() {
322 let start = window[0].as_usize();
323 let end = window[1].as_usize();
324 let row_len = end - start;
325 let out_start = start - values_start;
326
327 if list_array.is_null(row_index) || row_len == 0 {
328 validity.append_n(row_len, false);
329 continue;
330 }
331
332 let null_count = src_nulls.slice(start, row_len).null_count();
333 let valid_count = row_len - null_count;
334
335 let valid_offset = if nulls_first { null_count } else { 0 };
338 let mut write_pos = out_start + valid_offset;
339 for i in start..end {
340 if src_nulls.is_valid(i) {
341 out_values[write_pos] = src_values[i];
342 write_pos += 1;
343 }
344 }
345
346 let valid_slice = &mut out_values
347 [out_start + valid_offset..out_start + valid_offset + valid_count];
348 if descending {
349 valid_slice.sort_unstable_by(|a, b| b.compare(*a));
350 } else {
351 valid_slice.sort_unstable_by(|a, b| a.compare(*b));
352 }
353
354 if nulls_first {
356 validity.append_n(null_count, false);
357 validity.append_n(valid_count, true);
358 } else {
359 validity.append_n(valid_count, true);
360 validity.append_n(null_count, false);
361 }
362 }
363
364 let new_offsets = rebase_offsets(offsets);
365
366 let null_buffer = NullBuffer::from(validity.finish());
367 let sorted_values = Arc::new(
368 PrimitiveArray::<T>::new(out_values.into(), Some(null_buffer))
369 .with_data_type(prim_values.data_type().clone()),
370 );
371
372 Ok(Arc::new(GenericListArray::<OffsetSize>::try_new(
373 field,
374 new_offsets,
375 sorted_values,
376 list_array.nulls().cloned(),
377 )?))
378}
379
380fn array_sort_non_primitive<OffsetSize: OffsetSizeTrait>(
385 list_array: &GenericListArray<OffsetSize>,
386 field: FieldRef,
387 sort_options: Option<SortOptions>,
388) -> Result<ArrayRef> {
389 let row_count = list_array.len();
390 let values = list_array.values();
391 let offsets = list_array.offsets();
392 let values_start = offsets[0].as_usize();
393 let total_values = offsets[row_count].as_usize() - values_start;
394
395 let converter = RowConverter::new(vec![SortField::new_with_options(
396 values.data_type().clone(),
397 sort_options.unwrap_or_default(),
398 )])?;
399 let values_sliced = values.slice(values_start, total_values);
400 let rows = converter.convert_columns(&[Arc::clone(&values_sliced)])?;
401
402 let mut indices: Vec<OffsetSize> = Vec::with_capacity(total_values);
403 let mut new_offsets = Vec::with_capacity(row_count + 1);
404 new_offsets.push(OffsetSize::usize_as(0));
405
406 let mut sort_scratch: Vec<usize> = Vec::new();
407
408 for (row_index, window) in offsets.windows(2).enumerate() {
409 let start = window[0];
410 let end = window[1];
411
412 if list_array.is_null(row_index) {
413 new_offsets.push(new_offsets[row_index]);
414 continue;
415 }
416
417 let len = (end - start).as_usize();
418 let local_start = start.as_usize() - values_start;
419
420 if len <= 1 {
421 indices.extend((local_start..local_start + len).map(OffsetSize::usize_as));
422 } else {
423 sort_scratch.clear();
424 sort_scratch.extend(local_start..local_start + len);
425 sort_scratch.sort_unstable_by(|&a, &b| rows.row(a).cmp(&rows.row(b)));
426 indices.extend(sort_scratch.iter().map(|&i| OffsetSize::usize_as(i)));
427 }
428
429 new_offsets.push(new_offsets[row_index] + (end - start));
430 }
431
432 let sorted_values = if indices.is_empty() {
433 new_empty_array(values.data_type())
434 } else {
435 take_by_indices(&values_sliced, indices)?
436 };
437
438 Ok(Arc::new(GenericListArray::<OffsetSize>::try_new(
439 field,
440 OffsetBuffer::<OffsetSize>::new(new_offsets.into()),
441 sorted_values,
442 list_array.nulls().cloned(),
443 )?))
444}
445
446fn take_by_indices<OffsetSize: OffsetSizeTrait>(
449 values: &ArrayRef,
450 indices: Vec<OffsetSize>,
451) -> Result<ArrayRef> {
452 let len = indices.len();
453 let buffer = arrow::buffer::Buffer::from_vec(indices);
454 let indices_array: ArrayRef = if OffsetSize::IS_LARGE {
455 Arc::new(UInt64Array::new(
456 arrow::buffer::ScalarBuffer::new(buffer, 0, len),
457 None,
458 ))
459 } else {
460 Arc::new(UInt32Array::new(
461 arrow::buffer::ScalarBuffer::new(buffer, 0, len),
462 None,
463 ))
464 };
465 Ok(compute::take(values.as_ref(), &indices_array, None)?)
466}
467
468fn rebase_offsets<OffsetSize: OffsetSizeTrait>(
472 offsets: &OffsetBuffer<OffsetSize>,
473) -> OffsetBuffer<OffsetSize> {
474 if offsets[0].as_usize() == 0 {
475 offsets.clone()
476 } else {
477 let rebased: Vec<OffsetSize> = offsets.iter().map(|o| *o - offsets[0]).collect();
478 OffsetBuffer::new(rebased.into())
479 }
480}
481
482fn order_desc(modifier: &str) -> Result<bool> {
483 match modifier.to_uppercase().as_str() {
484 "DESC" => Ok(true),
485 "ASC" => Ok(false),
486 _ => exec_err!("the second parameter of array_sort expects DESC or ASC"),
487 }
488}
489
490fn order_nulls_first(modifier: &str) -> Result<bool> {
491 match modifier.to_uppercase().as_str() {
492 "NULLS FIRST" => Ok(true),
493 "NULLS LAST" => Ok(false),
494 _ => exec_err!(
495 "the third parameter of array_sort expects NULLS FIRST or NULLS LAST"
496 ),
497 }
498}