Skip to main content

datafusion_functions_nested/
sort.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`ScalarUDFImpl`] definitions for array_sort function.
19
20use crate::utils::make_scalar_function;
21use arrow::array::BooleanBufferBuilder;
22use arrow::array::{
23    Array, ArrayRef, ArrowPrimitiveType, GenericListArray, OffsetSizeTrait,
24    PrimitiveArray, UInt32Array, UInt64Array, new_empty_array, new_null_array,
25};
26use arrow::buffer::{NullBuffer, OffsetBuffer};
27use arrow::datatypes::{ArrowNativeTypeOp, DataType, FieldRef};
28use arrow::row::{RowConverter, SortField};
29use arrow::{compute, compute::SortOptions, downcast_primitive_array};
30use datafusion_common::cast::{as_large_list_array, as_list_array, as_string_array};
31use datafusion_common::utils::ListCoercion;
32use datafusion_common::{Result, exec_err, internal_datafusion_err};
33use datafusion_expr::{
34    ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, Documentation,
35    ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility,
36};
37use datafusion_macros::user_doc;
38use std::sync::Arc;
39
// Invokes the `make_udf_expr_and_func!` macro for `ArraySort`. The macro is
// defined outside this file; judging by its arguments it presumably generates
// an `array_sort(array, desc, null_first)` expression builder (documented as
// "returns sorted array.") and an `array_sort_udf()` accessor — confirm
// against the macro definition.
make_udf_expr_and_func!(
    ArraySort,
    array_sort,
    array desc null_first,
    "returns sorted array.",
    array_sort_udf
);
47
/// Implementation of `array_sort` function
///
/// `array_sort` sorts the elements of each list in a list-typed array.
/// It takes the array plus two optional string arguments: the sort direction
/// (`ASC` or `DESC`) and the null placement (`NULLS FIRST` or `NULLS LAST`).
/// Both modifiers are upper-cased before matching, so their case does not
/// matter. The function is also reachable under the alias `list_sort`.
///
/// # Example
///
/// `array_sort([3, 1, 2])` returns `[1, 2, 3]`
#[user_doc(
    doc_section(label = "Array Functions"),
    description = "Sort array.",
    syntax_example = "array_sort(array, desc, nulls_first)",
    sql_example = r#"```sql
> select array_sort([3, 1, 2]);
+-----------------------------+
| array_sort(List([3,1,2]))   |
+-----------------------------+
| [1, 2, 3]                   |
+-----------------------------+
```"#,
    argument(
        name = "array",
        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
    ),
    argument(
        name = "desc",
        description = "Whether to sort in ascending (`ASC`) or descending (`DESC`) order. The default is `ASC`."
    ),
    argument(
        name = "nulls_first",
        description = "Whether to sort nulls first (`NULLS FIRST`) or last (`NULLS LAST`). The default is `NULLS FIRST`."
    )
)]
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct ArraySort {
    /// Accepted argument shapes; exposed through `ScalarUDFImpl::signature`.
    signature: Signature,
    /// Alternative function names; exposed through `ScalarUDFImpl::aliases`.
    aliases: Vec<String>,
}
85
86impl Default for ArraySort {
87    fn default() -> Self {
88        Self::new()
89    }
90}
91
92impl ArraySort {
93    pub fn new() -> Self {
94        Self {
95            signature: Signature::one_of(
96                vec![
97                    TypeSignature::ArraySignature(ArrayFunctionSignature::Array {
98                        arguments: vec![ArrayFunctionArgument::Array],
99                        array_coercion: Some(ListCoercion::FixedSizedListToList),
100                    }),
101                    TypeSignature::ArraySignature(ArrayFunctionSignature::Array {
102                        arguments: vec![
103                            ArrayFunctionArgument::Array,
104                            ArrayFunctionArgument::String,
105                        ],
106                        array_coercion: Some(ListCoercion::FixedSizedListToList),
107                    }),
108                    TypeSignature::ArraySignature(ArrayFunctionSignature::Array {
109                        arguments: vec![
110                            ArrayFunctionArgument::Array,
111                            ArrayFunctionArgument::String,
112                            ArrayFunctionArgument::String,
113                        ],
114                        array_coercion: Some(ListCoercion::FixedSizedListToList),
115                    }),
116                ],
117                Volatility::Immutable,
118            ),
119            aliases: vec!["list_sort".to_string()],
120        }
121    }
122}
123
124impl ScalarUDFImpl for ArraySort {
125    fn name(&self) -> &str {
126        "array_sort"
127    }
128
129    fn signature(&self) -> &Signature {
130        &self.signature
131    }
132
133    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
134        Ok(arg_types[0].clone())
135    }
136
137    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
138        make_scalar_function(array_sort_inner)(&args.args)
139    }
140
141    fn aliases(&self) -> &[String] {
142        &self.aliases
143    }
144
145    fn documentation(&self) -> Option<&Documentation> {
146        self.doc()
147    }
148}
149
150fn array_sort_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
151    if args.is_empty() || args.len() > 3 {
152        return exec_err!("array_sort expects one to three arguments");
153    }
154
155    if args[0].is_empty() || args[0].data_type().is_null() {
156        return Ok(Arc::clone(&args[0]));
157    }
158
159    if args[1..].iter().any(|array| array.is_null(0)) {
160        return Ok(new_null_array(args[0].data_type(), args[0].len()));
161    }
162
163    let sort_options = if args.len() >= 2 {
164        let order = as_string_array(&args[1])?.value(0);
165        let descending = order_desc(order)?;
166        let nulls_first = if args.len() >= 3 {
167            order_nulls_first(as_string_array(&args[2])?.value(0))?
168        } else {
169            true
170        };
171        Some(SortOptions {
172            descending,
173            nulls_first,
174        })
175    } else {
176        None
177    };
178
179    match args[0].data_type() {
180        DataType::List(field) | DataType::LargeList(field)
181            if field.data_type().is_null() =>
182        {
183            Ok(Arc::clone(&args[0]))
184        }
185        DataType::List(field) => {
186            let array = as_list_array(&args[0])?;
187            array_sort_generic(array, Arc::clone(field), sort_options)
188        }
189        DataType::LargeList(field) => {
190            let array = as_large_list_array(&args[0])?;
191            array_sort_generic(array, Arc::clone(field), sort_options)
192        }
193        // Signature should prevent this arm ever occurring
194        _ => exec_err!("array_sort expects list for first argument"),
195    }
196}
197
198fn array_sort_generic<OffsetSize: OffsetSizeTrait>(
199    list_array: &GenericListArray<OffsetSize>,
200    field: FieldRef,
201    sort_options: Option<SortOptions>,
202) -> Result<ArrayRef> {
203    let values = list_array.values();
204
205    if values.data_type().is_primitive() {
206        array_sort_primitive(list_array, field, sort_options)
207    } else {
208        array_sort_non_primitive(list_array, field, sort_options)
209    }
210}
211
212/// Sort each row of a primitive-typed ListArray using a custom in-place sort
213/// kernel.
214fn array_sort_primitive<OffsetSize: OffsetSizeTrait>(
215    list_array: &GenericListArray<OffsetSize>,
216    field: FieldRef,
217    sort_options: Option<SortOptions>,
218) -> Result<ArrayRef> {
219    let values = list_array.values().as_ref();
220    downcast_primitive_array! {
221        values => sort_primitive_list(values, list_array, field, sort_options),
222        _ => exec_err!("array_sort: unsupported primitive type")
223    }
224}
225
226fn sort_primitive_list<T: ArrowPrimitiveType, OffsetSize: OffsetSizeTrait>(
227    prim_values: &PrimitiveArray<T>,
228    list_array: &GenericListArray<OffsetSize>,
229    field: FieldRef,
230    sort_options: Option<SortOptions>,
231) -> Result<ArrayRef>
232where
233    T::Native: ArrowNativeTypeOp,
234{
235    if prim_values.null_count() > 0 {
236        sort_list_with_nulls(prim_values, list_array, field, sort_options)
237    } else {
238        sort_list_no_nulls(prim_values, list_array, field, sort_options)
239    }
240}
241
242/// Fast path for primitive values with no element-level nulls. Copies all
243/// values into a single `Vec` and sorts each row's slice in-place.
244fn sort_list_no_nulls<T: ArrowPrimitiveType, OffsetSize: OffsetSizeTrait>(
245    prim_values: &PrimitiveArray<T>,
246    list_array: &GenericListArray<OffsetSize>,
247    field: FieldRef,
248    sort_options: Option<SortOptions>,
249) -> Result<ArrayRef>
250where
251    T::Native: ArrowNativeTypeOp,
252{
253    let row_count = list_array.len();
254    let offsets = list_array.offsets();
255    let values_start = offsets[0].as_usize();
256    let values_end = offsets[row_count].as_usize();
257
258    let descending = sort_options.is_some_and(|o| o.descending);
259
260    // Copy all values into a mutable buffer
261    let mut values: Vec<T::Native> =
262        prim_values.values()[values_start..values_end].to_vec();
263
264    for (row_index, window) in offsets.windows(2).enumerate() {
265        if list_array.is_null(row_index) {
266            continue;
267        }
268        let start = window[0].as_usize() - values_start;
269        let end = window[1].as_usize() - values_start;
270        let slice = &mut values[start..end];
271        if descending {
272            slice.sort_unstable_by(|a, b| b.compare(*a));
273        } else {
274            slice.sort_unstable_by(|a, b| a.compare(*b));
275        }
276    }
277
278    let new_offsets = rebase_offsets(offsets);
279    let sorted_values = Arc::new(
280        PrimitiveArray::<T>::new(values.into(), None)
281            .with_data_type(prim_values.data_type().clone()),
282    );
283
284    Ok(Arc::new(GenericListArray::<OffsetSize>::try_new(
285        field,
286        new_offsets,
287        sorted_values,
288        list_array.nulls().cloned(),
289    )?))
290}
291
292/// Slow path for primitive values with element-level nulls.
293fn sort_list_with_nulls<T: ArrowPrimitiveType, OffsetSize: OffsetSizeTrait>(
294    prim_values: &PrimitiveArray<T>,
295    list_array: &GenericListArray<OffsetSize>,
296    field: FieldRef,
297    sort_options: Option<SortOptions>,
298) -> Result<ArrayRef>
299where
300    T::Native: ArrowNativeTypeOp,
301{
302    let row_count = list_array.len();
303    let offsets = list_array.offsets();
304    let values_start = offsets[0].as_usize();
305    let values_end = offsets[row_count].as_usize();
306    let total_values = values_end - values_start;
307
308    let descending = sort_options.is_some_and(|o| o.descending);
309    let nulls_first = sort_options.is_none_or(|o| o.nulls_first);
310
311    let mut out_values: Vec<T::Native> = vec![T::Native::default(); total_values];
312    let mut validity = BooleanBufferBuilder::new(total_values);
313
314    let src_nulls = prim_values.nulls().ok_or_else(|| {
315        internal_datafusion_err!(
316            "sort_list_with_nulls called but values have no null buffer"
317        )
318    })?;
319    let src_values = prim_values.values();
320
321    for (row_index, window) in offsets.windows(2).enumerate() {
322        let start = window[0].as_usize();
323        let end = window[1].as_usize();
324        let row_len = end - start;
325        let out_start = start - values_start;
326
327        if list_array.is_null(row_index) || row_len == 0 {
328            validity.append_n(row_len, false);
329            continue;
330        }
331
332        let null_count = src_nulls.slice(start, row_len).null_count();
333        let valid_count = row_len - null_count;
334
335        // Compact valid values directly into the target region of the output
336        // buffer: after nulls (if nulls_first) or at the start (if nulls_last).
337        let valid_offset = if nulls_first { null_count } else { 0 };
338        let mut write_pos = out_start + valid_offset;
339        for i in start..end {
340            if src_nulls.is_valid(i) {
341                out_values[write_pos] = src_values[i];
342                write_pos += 1;
343            }
344        }
345
346        let valid_slice = &mut out_values
347            [out_start + valid_offset..out_start + valid_offset + valid_count];
348        if descending {
349            valid_slice.sort_unstable_by(|a, b| b.compare(*a));
350        } else {
351            valid_slice.sort_unstable_by(|a, b| a.compare(*b));
352        }
353
354        // Build validity bits
355        if nulls_first {
356            validity.append_n(null_count, false);
357            validity.append_n(valid_count, true);
358        } else {
359            validity.append_n(valid_count, true);
360            validity.append_n(null_count, false);
361        }
362    }
363
364    let new_offsets = rebase_offsets(offsets);
365
366    let null_buffer = NullBuffer::from(validity.finish());
367    let sorted_values = Arc::new(
368        PrimitiveArray::<T>::new(out_values.into(), Some(null_buffer))
369            .with_data_type(prim_values.data_type().clone()),
370    );
371
372    Ok(Arc::new(GenericListArray::<OffsetSize>::try_new(
373        field,
374        new_offsets,
375        sorted_values,
376        list_array.nulls().cloned(),
377    )?))
378}
379
380/// Sort a non-pritive-typed ListArray by converting all rows at once using
381/// `RowConverter`, and then sort row indices by comparing encoded bytes (sort
382/// direction and null ordering are baked into the encoding), and materialize
383/// the result with a single `take()`.
384fn array_sort_non_primitive<OffsetSize: OffsetSizeTrait>(
385    list_array: &GenericListArray<OffsetSize>,
386    field: FieldRef,
387    sort_options: Option<SortOptions>,
388) -> Result<ArrayRef> {
389    let row_count = list_array.len();
390    let values = list_array.values();
391    let offsets = list_array.offsets();
392    let values_start = offsets[0].as_usize();
393    let total_values = offsets[row_count].as_usize() - values_start;
394
395    let converter = RowConverter::new(vec![SortField::new_with_options(
396        values.data_type().clone(),
397        sort_options.unwrap_or_default(),
398    )])?;
399    let values_sliced = values.slice(values_start, total_values);
400    let rows = converter.convert_columns(&[Arc::clone(&values_sliced)])?;
401
402    let mut indices: Vec<OffsetSize> = Vec::with_capacity(total_values);
403    let mut new_offsets = Vec::with_capacity(row_count + 1);
404    new_offsets.push(OffsetSize::usize_as(0));
405
406    let mut sort_scratch: Vec<usize> = Vec::new();
407
408    for (row_index, window) in offsets.windows(2).enumerate() {
409        let start = window[0];
410        let end = window[1];
411
412        if list_array.is_null(row_index) {
413            new_offsets.push(new_offsets[row_index]);
414            continue;
415        }
416
417        let len = (end - start).as_usize();
418        let local_start = start.as_usize() - values_start;
419
420        if len <= 1 {
421            indices.extend((local_start..local_start + len).map(OffsetSize::usize_as));
422        } else {
423            sort_scratch.clear();
424            sort_scratch.extend(local_start..local_start + len);
425            sort_scratch.sort_unstable_by(|&a, &b| rows.row(a).cmp(&rows.row(b)));
426            indices.extend(sort_scratch.iter().map(|&i| OffsetSize::usize_as(i)));
427        }
428
429        new_offsets.push(new_offsets[row_index] + (end - start));
430    }
431
432    let sorted_values = if indices.is_empty() {
433        new_empty_array(values.data_type())
434    } else {
435        take_by_indices(&values_sliced, indices)?
436    };
437
438    Ok(Arc::new(GenericListArray::<OffsetSize>::try_new(
439        field,
440        OffsetBuffer::<OffsetSize>::new(new_offsets.into()),
441        sorted_values,
442        list_array.nulls().cloned(),
443    )?))
444}
445
446/// Select elements from `values` at the given `indices` using `compute::take`.
447/// We consume `indices` in order to avoid an intermediate copy.
448fn take_by_indices<OffsetSize: OffsetSizeTrait>(
449    values: &ArrayRef,
450    indices: Vec<OffsetSize>,
451) -> Result<ArrayRef> {
452    let len = indices.len();
453    let buffer = arrow::buffer::Buffer::from_vec(indices);
454    let indices_array: ArrayRef = if OffsetSize::IS_LARGE {
455        Arc::new(UInt64Array::new(
456            arrow::buffer::ScalarBuffer::new(buffer, 0, len),
457            None,
458        ))
459    } else {
460        Arc::new(UInt32Array::new(
461            arrow::buffer::ScalarBuffer::new(buffer, 0, len),
462            None,
463        ))
464    };
465    Ok(compute::take(values.as_ref(), &indices_array, None)?)
466}
467
468/// Rebase offsets so they start at 0. For non-sliced ListArrays (the common
469/// case) offsets already start at 0 and we can clone the Arc-backed buffer
470/// cheaply instead of allocating a new Vec.
471fn rebase_offsets<OffsetSize: OffsetSizeTrait>(
472    offsets: &OffsetBuffer<OffsetSize>,
473) -> OffsetBuffer<OffsetSize> {
474    if offsets[0].as_usize() == 0 {
475        offsets.clone()
476    } else {
477        let rebased: Vec<OffsetSize> = offsets.iter().map(|o| *o - offsets[0]).collect();
478        OffsetBuffer::new(rebased.into())
479    }
480}
481
482fn order_desc(modifier: &str) -> Result<bool> {
483    match modifier.to_uppercase().as_str() {
484        "DESC" => Ok(true),
485        "ASC" => Ok(false),
486        _ => exec_err!("the second parameter of array_sort expects DESC or ASC"),
487    }
488}
489
490fn order_nulls_first(modifier: &str) -> Result<bool> {
491    match modifier.to_uppercase().as_str() {
492        "NULLS FIRST" => Ok(true),
493        "NULLS LAST" => Ok(false),
494        _ => exec_err!(
495            "the third parameter of array_sort expects NULLS FIRST or NULLS LAST"
496        ),
497    }
498}