datafusion_functions_nested/
dimension.rs1use arrow::array::{
21 Array, ArrayRef, GenericListArray, ListArray, OffsetSizeTrait, UInt64Array,
22};
23use arrow::datatypes::{
24 DataType,
25 DataType::{FixedSizeList, LargeList, List, UInt64},
26 Field, UInt64Type,
27};
28use std::any::Any;
29
30use datafusion_common::cast::{as_large_list_array, as_list_array};
31use datafusion_common::{exec_err, plan_err, utils::take_function_args, Result};
32
33use crate::utils::{compute_array_dims, make_scalar_function};
34use datafusion_expr::{
35 ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
36};
37use datafusion_macros::user_doc;
38use std::sync::Arc;
39
// NOTE(review): `make_udf_expr_and_func!` is defined elsewhere in this crate.
// Judging from its arguments, it presumably generates an `array_dims(array)`
// expression helper (documented by the string below) and an `array_dims_udf()`
// accessor backed by `ArrayDims`. Confirm against the macro definition.
make_udf_expr_and_func!(
    ArrayDims,
    array_dims,
    array,
    "returns an array of the array's dimensions.",
    array_dims_udf
);
47
// User-facing documentation for `array_dims`. It is returned from
// `ScalarUDFImpl::documentation` through `self.doc()`, which is presumably
// generated by the `#[user_doc]` attribute macro.
#[user_doc(
    doc_section(label = "Array Functions"),
    description = "Returns an array of the array's dimensions.",
    syntax_example = "array_dims(array)",
    sql_example = r#"```sql
> select array_dims([[1, 2, 3], [4, 5, 6]]);
+---------------------------------+
| array_dims(List([1,2,3,4,5,6])) |
+---------------------------------+
| [2, 3] |
+---------------------------------+
```"#,
    argument(
        name = "array",
        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
    )
)]
#[derive(Debug)]
pub struct ArrayDims {
    // Built in `new()` as `Signature::array(Volatility::Immutable)`.
    signature: Signature,
    // Alternative names for the function; `new()` registers `list_dims`.
    aliases: Vec<String>,
}
70
71impl Default for ArrayDims {
72 fn default() -> Self {
73 Self::new()
74 }
75}
76
77impl ArrayDims {
78 pub fn new() -> Self {
79 Self {
80 signature: Signature::array(Volatility::Immutable),
81 aliases: vec!["list_dims".to_string()],
82 }
83 }
84}
85
86impl ScalarUDFImpl for ArrayDims {
87 fn as_any(&self) -> &dyn Any {
88 self
89 }
90 fn name(&self) -> &str {
91 "array_dims"
92 }
93
94 fn signature(&self) -> &Signature {
95 &self.signature
96 }
97
98 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
99 Ok(match arg_types[0] {
100 List(_) | LargeList(_) | FixedSizeList(_, _) => {
101 List(Arc::new(Field::new_list_field(UInt64, true)))
102 }
103 _ => {
104 return plan_err!("The array_dims function can only accept List/LargeList/FixedSizeList.");
105 }
106 })
107 }
108
109 fn invoke_with_args(
110 &self,
111 args: datafusion_expr::ScalarFunctionArgs,
112 ) -> Result<ColumnarValue> {
113 make_scalar_function(array_dims_inner)(&args.args)
114 }
115
116 fn aliases(&self) -> &[String] {
117 &self.aliases
118 }
119
120 fn documentation(&self) -> Option<&Documentation> {
121 self.doc()
122 }
123}
124
// NOTE(review): `make_udf_expr_and_func!` is defined elsewhere in this crate.
// Judging from its arguments, it presumably generates an `array_ndims(array)`
// expression helper (documented by the string below) and an
// `array_ndims_udf()` accessor backed by `ArrayNdims`. Confirm against the
// macro definition.
make_udf_expr_and_func!(
    ArrayNdims,
    array_ndims,
    array,
    "returns the number of dimensions of the array.",
    array_ndims_udf
);
132
133#[user_doc(
134 doc_section(label = "Array Functions"),
135 description = "Returns the number of dimensions of the array.",
136 syntax_example = "array_ndims(array, element)",
137 sql_example = r#"```sql
138> select array_ndims([[1, 2, 3], [4, 5, 6]]);
139+----------------------------------+
140| array_ndims(List([1,2,3,4,5,6])) |
141+----------------------------------+
142| 2 |
143+----------------------------------+
144```"#,
145 argument(
146 name = "array",
147 description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
148 ),
149 argument(name = "element", description = "Array element.")
150)]
151#[derive(Debug)]
152pub(super) struct ArrayNdims {
153 signature: Signature,
154 aliases: Vec<String>,
155}
156impl ArrayNdims {
157 pub fn new() -> Self {
158 Self {
159 signature: Signature::array(Volatility::Immutable),
160 aliases: vec![String::from("list_ndims")],
161 }
162 }
163}
164
165impl ScalarUDFImpl for ArrayNdims {
166 fn as_any(&self) -> &dyn Any {
167 self
168 }
169 fn name(&self) -> &str {
170 "array_ndims"
171 }
172
173 fn signature(&self) -> &Signature {
174 &self.signature
175 }
176
177 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
178 Ok(match arg_types[0] {
179 List(_) | LargeList(_) | FixedSizeList(_, _) => UInt64,
180 _ => {
181 return plan_err!("The array_ndims function can only accept List/LargeList/FixedSizeList.");
182 }
183 })
184 }
185
186 fn invoke_with_args(
187 &self,
188 args: datafusion_expr::ScalarFunctionArgs,
189 ) -> Result<ColumnarValue> {
190 make_scalar_function(array_ndims_inner)(&args.args)
191 }
192
193 fn aliases(&self) -> &[String] {
194 &self.aliases
195 }
196
197 fn documentation(&self) -> Option<&Documentation> {
198 self.doc()
199 }
200}
201
202pub fn array_dims_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
204 let [array] = take_function_args("array_dims", args)?;
205
206 let data = match array.data_type() {
207 List(_) => {
208 let array = as_list_array(&array)?;
209 array
210 .iter()
211 .map(compute_array_dims)
212 .collect::<Result<Vec<_>>>()?
213 }
214 LargeList(_) => {
215 let array = as_large_list_array(&array)?;
216 array
217 .iter()
218 .map(compute_array_dims)
219 .collect::<Result<Vec<_>>>()?
220 }
221 array_type => {
222 return exec_err!("array_dims does not support type '{array_type:?}'");
223 }
224 };
225
226 let result = ListArray::from_iter_primitive::<UInt64Type, _, _>(data);
227
228 Ok(Arc::new(result) as ArrayRef)
229}
230
231pub fn array_ndims_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
233 let [array_dim] = take_function_args("array_ndims", args)?;
234
235 fn general_list_ndims<O: OffsetSizeTrait>(
236 array: &GenericListArray<O>,
237 ) -> Result<ArrayRef> {
238 let mut data = Vec::new();
239 let ndims = datafusion_common::utils::list_ndims(array.data_type());
240
241 for arr in array.iter() {
242 if arr.is_some() {
243 data.push(Some(ndims))
244 } else {
245 data.push(None)
246 }
247 }
248
249 Ok(Arc::new(UInt64Array::from(data)) as ArrayRef)
250 }
251 match array_dim.data_type() {
252 List(_) => {
253 let array = as_list_array(&array_dim)?;
254 general_list_ndims::<i32>(array)
255 }
256 LargeList(_) => {
257 let array = as_large_list_array(&array_dim)?;
258 general_list_ndims::<i64>(array)
259 }
260 array_type => exec_err!("array_ndims does not support type {array_type:?}"),
261 }
262}