datafusion_functions_nested/
flatten.rs1use crate::utils::make_scalar_function;
21use arrow::array::{Array, ArrayRef, GenericListArray, OffsetSizeTrait};
22use arrow::buffer::OffsetBuffer;
23use arrow::datatypes::{
24 DataType,
25 DataType::{FixedSizeList, LargeList, List, Null},
26};
27use datafusion_common::cast::{as_large_list_array, as_list_array};
28use datafusion_common::utils::ListCoercion;
29use datafusion_common::{exec_err, utils::take_function_args, Result};
30use datafusion_expr::{
31 ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, Documentation,
32 ScalarUDFImpl, Signature, TypeSignature, Volatility,
33};
34use datafusion_macros::user_doc;
35use std::any::Any;
36use std::sync::Arc;
37
38make_udf_expr_and_func!(
39 Flatten,
40 flatten,
41 array,
42 "flattens an array of arrays into a single array.",
43 flatten_udf
44);
45
46#[user_doc(
47 doc_section(label = "Array Functions"),
48 description = "Converts an array of arrays to a flat array.\n\n- Applies to any depth of nested arrays\n- Does not change arrays that are already flat\n\nThe flattened array contains all the elements from all source arrays.",
49 syntax_example = "flatten(array)",
50 sql_example = r#"```sql
51> select flatten([[1, 2], [3, 4]]);
52+------------------------------+
53| flatten(List([1,2], [3,4])) |
54+------------------------------+
55| [1, 2, 3, 4] |
56+------------------------------+
57```"#,
58 argument(
59 name = "array",
60 description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
61 )
62)]
63#[derive(Debug)]
64pub struct Flatten {
65 signature: Signature,
66 aliases: Vec<String>,
67}
68
69impl Default for Flatten {
70 fn default() -> Self {
71 Self::new()
72 }
73}
74
75impl Flatten {
76 pub fn new() -> Self {
77 Self {
78 signature: Signature {
79 type_signature: TypeSignature::ArraySignature(
80 ArrayFunctionSignature::Array {
81 arguments: vec![ArrayFunctionArgument::Array],
82 array_coercion: Some(ListCoercion::FixedSizedListToList),
83 },
84 ),
85 volatility: Volatility::Immutable,
86 },
87 aliases: vec![],
88 }
89 }
90}
91
92impl ScalarUDFImpl for Flatten {
93 fn as_any(&self) -> &dyn Any {
94 self
95 }
96
97 fn name(&self) -> &str {
98 "flatten"
99 }
100
101 fn signature(&self) -> &Signature {
102 &self.signature
103 }
104
105 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
106 let data_type = match &arg_types[0] {
107 List(field) | FixedSizeList(field, _) => match field.data_type() {
108 List(field) | FixedSizeList(field, _) => List(Arc::clone(field)),
109 _ => arg_types[0].clone(),
110 },
111 LargeList(field) => match field.data_type() {
112 List(field) | LargeList(field) | FixedSizeList(field, _) => {
113 LargeList(Arc::clone(field))
114 }
115 _ => arg_types[0].clone(),
116 },
117 Null => Null,
118 _ => exec_err!(
119 "Not reachable, data_type should be List, LargeList or FixedSizeList"
120 )?,
121 };
122
123 Ok(data_type)
124 }
125
126 fn invoke_with_args(
127 &self,
128 args: datafusion_expr::ScalarFunctionArgs,
129 ) -> Result<ColumnarValue> {
130 make_scalar_function(flatten_inner)(&args.args)
131 }
132
133 fn aliases(&self) -> &[String] {
134 &self.aliases
135 }
136
137 fn documentation(&self) -> Option<&Documentation> {
138 self.doc()
139 }
140}
141
142pub fn flatten_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
144 let [array] = take_function_args("flatten", args)?;
145
146 match array.data_type() {
147 List(_) => {
148 let (_field, offsets, values, nulls) =
149 as_list_array(&array)?.clone().into_parts();
150 let values = cast_fsl_to_list(values)?;
151
152 match values.data_type() {
153 List(_) => {
154 let (inner_field, inner_offsets, inner_values, _) =
155 as_list_array(&values)?.clone().into_parts();
156 let offsets = get_offsets_for_flatten::<i32>(inner_offsets, offsets);
157 let flattened_array = GenericListArray::<i32>::new(
158 inner_field,
159 offsets,
160 inner_values,
161 nulls,
162 );
163
164 Ok(Arc::new(flattened_array) as ArrayRef)
165 }
166 LargeList(_) => {
167 exec_err!("flatten does not support type '{:?}'", array.data_type())?
168 }
169 _ => Ok(Arc::clone(array) as ArrayRef),
170 }
171 }
172 LargeList(_) => {
173 let (_field, offsets, values, nulls) =
174 as_large_list_array(&array)?.clone().into_parts();
175 let values = cast_fsl_to_list(values)?;
176
177 match values.data_type() {
178 List(_) => {
179 let (inner_field, inner_offsets, inner_values, _) =
180 as_list_array(&values)?.clone().into_parts();
181 let offsets = get_large_offsets_for_flatten(inner_offsets, offsets);
182 let flattened_array = GenericListArray::<i64>::new(
183 inner_field,
184 offsets,
185 inner_values,
186 nulls,
187 );
188
189 Ok(Arc::new(flattened_array) as ArrayRef)
190 }
191 LargeList(_) => {
192 let (inner_field, inner_offsets, inner_values, nulls) =
193 as_large_list_array(&values)?.clone().into_parts();
194 let offsets = get_offsets_for_flatten::<i64>(inner_offsets, offsets);
195 let flattened_array = GenericListArray::<i64>::new(
196 inner_field,
197 offsets,
198 inner_values,
199 nulls,
200 );
201
202 Ok(Arc::new(flattened_array) as ArrayRef)
203 }
204 _ => Ok(Arc::clone(array) as ArrayRef),
205 }
206 }
207 Null => Ok(Arc::clone(array)),
208 _ => {
209 exec_err!("flatten does not support type '{:?}'", array.data_type())
210 }
211 }
212}
213
214fn get_offsets_for_flatten<O: OffsetSizeTrait>(
216 offsets: OffsetBuffer<O>,
217 indexes: OffsetBuffer<O>,
218) -> OffsetBuffer<O> {
219 let buffer = offsets.into_inner();
220 let offsets: Vec<O> = indexes
221 .iter()
222 .map(|i| buffer[i.to_usize().unwrap()])
223 .collect();
224 OffsetBuffer::new(offsets.into())
225}
226
227fn get_large_offsets_for_flatten<O: OffsetSizeTrait, P: OffsetSizeTrait>(
229 offsets: OffsetBuffer<O>,
230 indexes: OffsetBuffer<P>,
231) -> OffsetBuffer<i64> {
232 let buffer = offsets.into_inner();
233 let offsets: Vec<i64> = indexes
234 .iter()
235 .map(|i| buffer[i.to_usize().unwrap()].to_i64().unwrap())
236 .collect();
237 OffsetBuffer::new(offsets.into())
238}
239
240fn cast_fsl_to_list(array: ArrayRef) -> Result<ArrayRef> {
241 match array.data_type() {
242 FixedSizeList(field, _) => {
243 Ok(arrow::compute::cast(&array, &List(Arc::clone(field)))?)
244 }
245 _ => Ok(array),
246 }
247}