datafusion_functions_nested/
flatten.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`ScalarUDFImpl`] definitions for flatten function.
19
20use crate::utils::make_scalar_function;
21use arrow::array::{ArrayRef, GenericListArray, OffsetSizeTrait};
22use arrow::buffer::OffsetBuffer;
23use arrow::datatypes::{
24    DataType,
25    DataType::{FixedSizeList, LargeList, List, Null},
26};
27use datafusion_common::cast::{
28    as_generic_list_array, as_large_list_array, as_list_array,
29};
30use datafusion_common::{exec_err, utils::take_function_args, Result};
31use datafusion_expr::{
32    ArrayFunctionSignature, ColumnarValue, Documentation, ScalarUDFImpl, Signature,
33    TypeSignature, Volatility,
34};
35use datafusion_macros::user_doc;
36use std::any::Any;
37use std::sync::Arc;
38
// Registers the `Flatten` UDF with the crate's expression-function macro.
// NOTE(review): presumably this generates a `flatten(array)` expr builder and a
// `flatten_udf()` accessor; confirm against the macro definition elsewhere in
// this crate.
make_udf_expr_and_func!(
    Flatten,
    flatten,
    array,
    "flattens an array of arrays into a single array.",
    flatten_udf
);
46
// Scalar UDF implementing SQL `flatten(array)`. The `user_doc` attribute below
// supplies the user-facing documentation returned by `documentation()`.
#[user_doc(
    doc_section(label = "Array Functions"),
    description = "Converts an array of arrays to a flat array.\n\n- Applies to any depth of nested arrays\n- Does not change arrays that are already flat\n\nThe flattened array contains all the elements from all source arrays.",
    syntax_example = "flatten(array)",
    sql_example = r#"```sql
> select flatten([[1, 2], [3, 4]]);
+------------------------------+
| flatten(List([1,2], [3,4]))  |
+------------------------------+
| [1, 2, 3, 4]                 |
+------------------------------+
```"#,
    argument(
        name = "array",
        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
    )
)]
#[derive(Debug)]
pub struct Flatten {
    // Accepted argument types and volatility, built in `Flatten::new`.
    signature: Signature,
    // Alternative SQL names for this function; `Flatten::new` leaves it empty.
    aliases: Vec<String>,
}
69
70impl Default for Flatten {
71    fn default() -> Self {
72        Self::new()
73    }
74}
75
76impl Flatten {
77    pub fn new() -> Self {
78        Self {
79            signature: Signature {
80                // TODO (https://github.com/apache/datafusion/issues/13757) flatten should be single-step, not recursive
81                type_signature: TypeSignature::ArraySignature(
82                    ArrayFunctionSignature::RecursiveArray,
83                ),
84                volatility: Volatility::Immutable,
85            },
86            aliases: vec![],
87        }
88    }
89}
90
91impl ScalarUDFImpl for Flatten {
92    fn as_any(&self) -> &dyn Any {
93        self
94    }
95
96    fn name(&self) -> &str {
97        "flatten"
98    }
99
100    fn signature(&self) -> &Signature {
101        &self.signature
102    }
103
104    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
105        fn get_base_type(data_type: &DataType) -> Result<DataType> {
106            match data_type {
107                List(field) | FixedSizeList(field, _)
108                    if matches!(field.data_type(), List(_) | FixedSizeList(_, _)) =>
109                {
110                    get_base_type(field.data_type())
111                }
112                LargeList(field) if matches!(field.data_type(), LargeList(_)) => {
113                    get_base_type(field.data_type())
114                }
115                Null | List(_) | LargeList(_) => Ok(data_type.to_owned()),
116                FixedSizeList(field, _) => Ok(List(Arc::clone(field))),
117                _ => exec_err!(
118                    "Not reachable, data_type should be List, LargeList or FixedSizeList"
119                ),
120            }
121        }
122
123        let data_type = get_base_type(&arg_types[0])?;
124        Ok(data_type)
125    }
126
127    fn invoke_with_args(
128        &self,
129        args: datafusion_expr::ScalarFunctionArgs,
130    ) -> Result<ColumnarValue> {
131        make_scalar_function(flatten_inner)(&args.args)
132    }
133
134    fn aliases(&self) -> &[String] {
135        &self.aliases
136    }
137
138    fn documentation(&self) -> Option<&Documentation> {
139        self.doc()
140    }
141}
142
143/// Flatten SQL function
144pub fn flatten_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
145    let [array] = take_function_args("flatten", args)?;
146
147    match array.data_type() {
148        List(_) => {
149            let list_arr = as_list_array(&array)?;
150            let flattened_array = flatten_internal::<i32>(list_arr.clone(), None)?;
151            Ok(Arc::new(flattened_array) as ArrayRef)
152        }
153        LargeList(_) => {
154            let list_arr = as_large_list_array(&array)?;
155            let flattened_array = flatten_internal::<i64>(list_arr.clone(), None)?;
156            Ok(Arc::new(flattened_array) as ArrayRef)
157        }
158        Null => Ok(Arc::clone(array)),
159        _ => {
160            exec_err!("flatten does not support type '{:?}'", array.data_type())
161        }
162    }
163}
164
165fn flatten_internal<O: OffsetSizeTrait>(
166    list_arr: GenericListArray<O>,
167    indexes: Option<OffsetBuffer<O>>,
168) -> Result<GenericListArray<O>> {
169    let (field, offsets, values, _) = list_arr.clone().into_parts();
170    let data_type = field.data_type();
171
172    match data_type {
173        // Recursively get the base offsets for flattened array
174        List(_) | LargeList(_) => {
175            let sub_list = as_generic_list_array::<O>(&values)?;
176            if let Some(indexes) = indexes {
177                let offsets = get_offsets_for_flatten(offsets, indexes);
178                flatten_internal::<O>(sub_list.clone(), Some(offsets))
179            } else {
180                flatten_internal::<O>(sub_list.clone(), Some(offsets))
181            }
182        }
183        // Reach the base level, create a new list array
184        _ => {
185            if let Some(indexes) = indexes {
186                let offsets = get_offsets_for_flatten(offsets, indexes);
187                let list_arr = GenericListArray::<O>::new(field, offsets, values, None);
188                Ok(list_arr)
189            } else {
190                Ok(list_arr)
191            }
192        }
193    }
194}
195
196// Create new offsets that are equivalent to `flatten` the array.
197fn get_offsets_for_flatten<O: OffsetSizeTrait>(
198    offsets: OffsetBuffer<O>,
199    indexes: OffsetBuffer<O>,
200) -> OffsetBuffer<O> {
201    let buffer = offsets.into_inner();
202    let offsets: Vec<O> = indexes
203        .iter()
204        .map(|i| buffer[i.to_usize().unwrap()])
205        .collect();
206    OffsetBuffer::new(offsets.into())
207}