datafusion_spark/function/array/
spark_array.rs1use std::{any::Any, sync::Arc};
19
20use arrow::array::{Array, ArrayRef, new_null_array};
21use arrow::datatypes::{DataType, Field, FieldRef};
22use datafusion_common::utils::SingleRowListArrayBuilder;
23use datafusion_common::{Result, internal_err};
24use datafusion_expr::{
25 ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature,
26 TypeSignature, Volatility,
27};
28use datafusion_functions_nested::make_array::{array_array, coerce_types_inner};
29
30use crate::function::functions_nested_utils::make_scalar_function;
31
32const ARRAY_FIELD_DEFAULT_NAME: &str = "element";
33
34#[derive(Debug, PartialEq, Eq, Hash)]
35pub struct SparkArray {
36 signature: Signature,
37}
38
39impl Default for SparkArray {
40 fn default() -> Self {
41 Self::new()
42 }
43}
44
45impl SparkArray {
46 pub fn new() -> Self {
47 Self {
48 signature: Signature::one_of(
49 vec![TypeSignature::UserDefined, TypeSignature::Nullary],
50 Volatility::Immutable,
51 ),
52 }
53 }
54}
55
56impl ScalarUDFImpl for SparkArray {
57 fn as_any(&self) -> &dyn Any {
58 self
59 }
60
61 fn name(&self) -> &str {
62 "array"
63 }
64
65 fn signature(&self) -> &Signature {
66 &self.signature
67 }
68
69 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
70 internal_err!("return_field_from_args should be used instead")
71 }
72
73 fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
74 let data_types = args
75 .arg_fields
76 .iter()
77 .map(|f| f.data_type())
78 .cloned()
79 .collect::<Vec<_>>();
80
81 let mut expr_type = DataType::Null;
82 for arg_type in &data_types {
83 if !arg_type.equals_datatype(&DataType::Null) {
84 expr_type = arg_type.clone();
85 break;
86 }
87 }
88
89 let return_type = DataType::List(Arc::new(Field::new(
90 ARRAY_FIELD_DEFAULT_NAME,
91 expr_type,
92 true,
93 )));
94
95 Ok(Arc::new(Field::new(
96 "this_field_name_is_irrelevant",
97 return_type,
98 false,
99 )))
100 }
101
102 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
103 let ScalarFunctionArgs { args, .. } = args;
104 make_scalar_function(make_array_inner)(args.as_slice())
105 }
106
107 fn aliases(&self) -> &[String] {
108 &[]
109 }
110
111 fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
112 coerce_types_inner(arg_types, self.name())
113 }
114}
115
116pub fn make_array_inner(arrays: &[ArrayRef]) -> Result<ArrayRef> {
120 let mut data_type = DataType::Null;
121 for arg in arrays {
122 let arg_data_type = arg.data_type();
123 if !arg_data_type.equals_datatype(&DataType::Null) {
124 data_type = arg_data_type.clone();
125 break;
126 }
127 }
128
129 match data_type {
130 DataType::Null => {
132 let length = arrays.iter().map(|a| a.len()).sum();
133 let array = new_null_array(&DataType::Null, length);
135 Ok(Arc::new(
136 SingleRowListArrayBuilder::new(array)
137 .with_nullable(true)
138 .with_field_name(Some(ARRAY_FIELD_DEFAULT_NAME.to_string()))
139 .build_list_array(),
140 ))
141 }
142 _ => array_array::<i32>(arrays, data_type, ARRAY_FIELD_DEFAULT_NAME),
143 }
144}