datafusion_functions_nested/flatten.rs
1use crate::utils::make_scalar_function;
21use arrow::array::{ArrayRef, GenericListArray, OffsetSizeTrait};
22use arrow::buffer::OffsetBuffer;
23use arrow::datatypes::{
24 DataType,
25 DataType::{FixedSizeList, LargeList, List, Null},
26};
27use datafusion_common::cast::{
28 as_generic_list_array, as_large_list_array, as_list_array,
29};
30use datafusion_common::{exec_err, utils::take_function_args, Result};
31use datafusion_expr::{
32 ArrayFunctionSignature, ColumnarValue, Documentation, ScalarUDFImpl, Signature,
33 TypeSignature, Volatility,
34};
35use datafusion_macros::user_doc;
36use std::any::Any;
37use std::sync::Arc;
38
// Registers the `Flatten` UDF with the crate's helper macro. Judging by the
// arguments, this presumably generates an expression builder `flatten(array)`
// documented with the string below, plus a `flatten_udf` accessor for the UDF
// instance — confirm against the `make_udf_expr_and_func!` definition.
make_udf_expr_and_func!(
    Flatten,
    flatten,
    array,
    "flattens an array of arrays into a single array.",
    flatten_udf
);
46
// `Flatten` implements the `flatten` SQL array function. The `#[user_doc]`
// attribute carries its user-facing documentation; the `documentation`
// method of the `ScalarUDFImpl` impl returns it through `self.doc()`, which
// this attribute macro presumably generates.
#[user_doc(
    doc_section(label = "Array Functions"),
    description = "Converts an array of arrays to a flat array.\n\n- Applies to any depth of nested arrays\n- Does not change arrays that are already flat\n\nThe flattened array contains all the elements from all source arrays.",
    syntax_example = "flatten(array)",
    sql_example = r#"```sql
> select flatten([[1, 2], [3, 4]]);
+------------------------------+
| flatten(List([1,2], [3,4])) |
+------------------------------+
| [1, 2, 3, 4] |
+------------------------------+
```"#,
    argument(
        name = "array",
        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
    )
)]
#[derive(Debug)]
pub struct Flatten {
    // Accepted argument types and volatility; built in `Flatten::new`.
    signature: Signature,
    // Alternative SQL names for the function; `Flatten::new` leaves it empty.
    aliases: Vec<String>,
}
69
70impl Default for Flatten {
71 fn default() -> Self {
72 Self::new()
73 }
74}
75
76impl Flatten {
77 pub fn new() -> Self {
78 Self {
79 signature: Signature {
80 type_signature: TypeSignature::ArraySignature(
82 ArrayFunctionSignature::RecursiveArray,
83 ),
84 volatility: Volatility::Immutable,
85 },
86 aliases: vec![],
87 }
88 }
89}
90
91impl ScalarUDFImpl for Flatten {
92 fn as_any(&self) -> &dyn Any {
93 self
94 }
95
96 fn name(&self) -> &str {
97 "flatten"
98 }
99
100 fn signature(&self) -> &Signature {
101 &self.signature
102 }
103
104 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
105 fn get_base_type(data_type: &DataType) -> Result<DataType> {
106 match data_type {
107 List(field) | FixedSizeList(field, _)
108 if matches!(field.data_type(), List(_) | FixedSizeList(_, _)) =>
109 {
110 get_base_type(field.data_type())
111 }
112 LargeList(field) if matches!(field.data_type(), LargeList(_)) => {
113 get_base_type(field.data_type())
114 }
115 Null | List(_) | LargeList(_) => Ok(data_type.to_owned()),
116 FixedSizeList(field, _) => Ok(List(Arc::clone(field))),
117 _ => exec_err!(
118 "Not reachable, data_type should be List, LargeList or FixedSizeList"
119 ),
120 }
121 }
122
123 let data_type = get_base_type(&arg_types[0])?;
124 Ok(data_type)
125 }
126
127 fn invoke_with_args(
128 &self,
129 args: datafusion_expr::ScalarFunctionArgs,
130 ) -> Result<ColumnarValue> {
131 make_scalar_function(flatten_inner)(&args.args)
132 }
133
134 fn aliases(&self) -> &[String] {
135 &self.aliases
136 }
137
138 fn documentation(&self) -> Option<&Documentation> {
139 self.doc()
140 }
141}
142
143pub fn flatten_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
145 let [array] = take_function_args("flatten", args)?;
146
147 match array.data_type() {
148 List(_) => {
149 let list_arr = as_list_array(&array)?;
150 let flattened_array = flatten_internal::<i32>(list_arr.clone(), None)?;
151 Ok(Arc::new(flattened_array) as ArrayRef)
152 }
153 LargeList(_) => {
154 let list_arr = as_large_list_array(&array)?;
155 let flattened_array = flatten_internal::<i64>(list_arr.clone(), None)?;
156 Ok(Arc::new(flattened_array) as ArrayRef)
157 }
158 Null => Ok(Arc::clone(array)),
159 _ => {
160 exec_err!("flatten does not support type '{:?}'", array.data_type())
161 }
162 }
163}
164
165fn flatten_internal<O: OffsetSizeTrait>(
166 list_arr: GenericListArray<O>,
167 indexes: Option<OffsetBuffer<O>>,
168) -> Result<GenericListArray<O>> {
169 let (field, offsets, values, _) = list_arr.clone().into_parts();
170 let data_type = field.data_type();
171
172 match data_type {
173 List(_) | LargeList(_) => {
175 let sub_list = as_generic_list_array::<O>(&values)?;
176 if let Some(indexes) = indexes {
177 let offsets = get_offsets_for_flatten(offsets, indexes);
178 flatten_internal::<O>(sub_list.clone(), Some(offsets))
179 } else {
180 flatten_internal::<O>(sub_list.clone(), Some(offsets))
181 }
182 }
183 _ => {
185 if let Some(indexes) = indexes {
186 let offsets = get_offsets_for_flatten(offsets, indexes);
187 let list_arr = GenericListArray::<O>::new(field, offsets, values, None);
188 Ok(list_arr)
189 } else {
190 Ok(list_arr)
191 }
192 }
193 }
194}
195
196fn get_offsets_for_flatten<O: OffsetSizeTrait>(
198 offsets: OffsetBuffer<O>,
199 indexes: OffsetBuffer<O>,
200) -> OffsetBuffer<O> {
201 let buffer = offsets.into_inner();
202 let offsets: Vec<O> = indexes
203 .iter()
204 .map(|i| buffer[i.to_usize().unwrap()])
205 .collect();
206 OffsetBuffer::new(offsets.into())
207}