// datafusion_spark/function/string/length.rs
use arrow::array::{
19 Array, ArrayRef, AsArray, BinaryArrayType, PrimitiveArray, StringArrayType,
20};
21use arrow::datatypes::{DataType, Field, FieldRef, Int32Type};
22use datafusion_common::exec_err;
23use datafusion_expr::{
24 ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature,
25 Volatility,
26};
27use datafusion_functions::utils::make_scalar_function;
28use std::sync::Arc;
29
/// Spark-compatible `length` scalar function.
///
/// For string inputs this returns the number of characters; for binary
/// inputs it returns the number of bytes (see `spark_length` below). The
/// result type is always `Int32`.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct SparkLengthFunc {
    // Exactly one argument of any supported string/binary type.
    signature: Signature,
    // Alternate names this function is registered under.
    aliases: Vec<String>,
}
37
impl Default for SparkLengthFunc {
    // Delegates to `new()`; there is no alternate configuration.
    fn default() -> Self {
        Self::new()
    }
}
43
44impl SparkLengthFunc {
45 pub fn new() -> Self {
46 Self {
47 signature: Signature::uniform(
48 1,
49 vec![
50 DataType::Utf8View,
51 DataType::Utf8,
52 DataType::LargeUtf8,
53 DataType::Binary,
54 DataType::LargeBinary,
55 DataType::BinaryView,
56 ],
57 Volatility::Immutable,
58 ),
59 aliases: vec![
60 String::from("character_length"),
61 String::from("char_length"),
62 String::from("len"),
63 ],
64 }
65 }
66}
67
68impl ScalarUDFImpl for SparkLengthFunc {
69 fn as_any(&self) -> &dyn std::any::Any {
70 self
71 }
72
73 fn name(&self) -> &str {
74 "length"
75 }
76
77 fn signature(&self) -> &Signature {
78 &self.signature
79 }
80
81 fn return_type(&self, _args: &[DataType]) -> datafusion_common::Result<DataType> {
82 datafusion_common::internal_err!(
83 "return_type should not be called, use return_field_from_args instead"
84 )
85 }
86
87 fn invoke_with_args(
88 &self,
89 args: ScalarFunctionArgs,
90 ) -> datafusion_common::Result<ColumnarValue> {
91 make_scalar_function(spark_length, vec![])(&args.args)
92 }
93
94 fn aliases(&self) -> &[String] {
95 &self.aliases
96 }
97
98 fn return_field_from_args(
99 &self,
100 args: ReturnFieldArgs,
101 ) -> datafusion_common::Result<FieldRef> {
102 let nullable = args.arg_fields.iter().any(|f| f.is_nullable());
103 Ok(Arc::new(Field::new(self.name(), DataType::Int32, nullable)))
105 }
106}
107
108fn spark_length(args: &[ArrayRef]) -> datafusion_common::Result<ArrayRef> {
109 match args[0].data_type() {
110 DataType::Utf8 => {
111 let string_array = args[0].as_string::<i32>();
112 character_length::<_>(&string_array)
113 }
114 DataType::LargeUtf8 => {
115 let string_array = args[0].as_string::<i64>();
116 character_length::<_>(&string_array)
117 }
118 DataType::Utf8View => {
119 let string_array = args[0].as_string_view();
120 character_length::<_>(&string_array)
121 }
122 DataType::Binary => {
123 let binary_array = args[0].as_binary::<i32>();
124 byte_length::<_>(&binary_array)
125 }
126 DataType::LargeBinary => {
127 let binary_array = args[0].as_binary::<i64>();
128 byte_length::<_>(&binary_array)
129 }
130 DataType::BinaryView => {
131 let binary_array = args[0].as_binary_view();
132 byte_length::<_>(&binary_array)
133 }
134 other => exec_err!("Unsupported data type {other:?} for function `length`"),
135 }
136}
137
138fn character_length<'a, V>(array: &V) -> datafusion_common::Result<ArrayRef>
139where
140 V: StringArrayType<'a>,
141{
142 let is_array_ascii_only = array.is_ascii();
147 let nulls = array.nulls().cloned();
148 let array = {
149 if is_array_ascii_only {
150 let values: Vec<_> = (0..array.len())
151 .map(|i| {
152 let value = unsafe { array.value_unchecked(i) };
154 value.len() as i32
155 })
156 .collect();
157 PrimitiveArray::<Int32Type>::new(values.into(), nulls)
158 } else {
159 let values: Vec<_> = (0..array.len())
160 .map(|i| {
161 if array.is_null(i) {
163 i32::default()
164 } else {
165 let value = unsafe { array.value_unchecked(i) };
166 if value.is_empty() {
167 i32::default()
168 } else if value.is_ascii() {
169 value.len() as i32
170 } else {
171 value.chars().count() as i32
172 }
173 }
174 })
175 .collect();
176 PrimitiveArray::<Int32Type>::new(values.into(), nulls)
177 }
178 };
179
180 Ok(Arc::new(array))
181}
182
183fn byte_length<'a, V>(array: &V) -> datafusion_common::Result<ArrayRef>
184where
185 V: BinaryArrayType<'a>,
186{
187 let nulls = array.nulls().cloned();
188 let values: Vec<_> = (0..array.len())
189 .map(|i| {
190 let value = unsafe { array.value_unchecked(i) };
192 value.len() as i32
193 })
194 .collect();
195 Ok(Arc::new(PrimitiveArray::<Int32Type>::new(
196 values.into(),
197 nulls,
198 )))
199}
200
#[cfg(test)]
mod tests {
    use super::*;
    use crate::function::utils::test::test_scalar_function;
    use arrow::array::{Array, Int32Array};
    use arrow::datatypes::DataType::Int32;
    use arrow::datatypes::{Field, FieldRef};
    use datafusion_common::{Result, ScalarValue};
    use datafusion_expr::{ColumnarValue, ReturnFieldArgs, ScalarUDFImpl};

    // Runs one input/expectation pair through all three string encodings
    // (Utf8, LargeUtf8, Utf8View); `length` must behave identically on each.
    macro_rules! test_spark_length_string {
        ($INPUT:expr, $EXPECTED:expr) => {
            test_scalar_function!(
                SparkLengthFunc::new(),
                vec![ColumnarValue::Scalar(ScalarValue::Utf8($INPUT))],
                $EXPECTED,
                i32,
                Int32,
                Int32Array
            );

            test_scalar_function!(
                SparkLengthFunc::new(),
                vec![ColumnarValue::Scalar(ScalarValue::LargeUtf8($INPUT))],
                $EXPECTED,
                i32,
                Int32,
                Int32Array
            );

            test_scalar_function!(
                SparkLengthFunc::new(),
                vec![ColumnarValue::Scalar(ScalarValue::Utf8View($INPUT))],
                $EXPECTED,
                i32,
                Int32,
                Int32Array
            );
        };
    }

    // Same as above for the three binary encodings (Binary, LargeBinary,
    // BinaryView), where `length` counts bytes rather than characters.
    macro_rules! test_spark_length_binary {
        ($INPUT:expr, $EXPECTED:expr) => {
            test_scalar_function!(
                SparkLengthFunc::new(),
                vec![ColumnarValue::Scalar(ScalarValue::Binary($INPUT))],
                $EXPECTED,
                i32,
                Int32,
                Int32Array
            );

            test_scalar_function!(
                SparkLengthFunc::new(),
                vec![ColumnarValue::Scalar(ScalarValue::LargeBinary($INPUT))],
                $EXPECTED,
                i32,
                Int32,
                Int32Array
            );

            test_scalar_function!(
                SparkLengthFunc::new(),
                vec![ColumnarValue::Scalar(ScalarValue::BinaryView($INPUT))],
                $EXPECTED,
                i32,
                Int32,
                Int32Array
            );
        };
    }

    #[test]
    fn test_functions() -> Result<()> {
        test_spark_length_string!(Some(String::from("chars")), Ok(Some(5)));
        // "josé" is 4 characters but 5 bytes in UTF-8 — the string variant
        // counts characters, exercising the non-ASCII path.
        test_spark_length_string!(Some(String::from("josé")), Ok(Some(4)));
        test_spark_length_string!(Some(String::from("joséjoséjoséjosé")), Ok(Some(16)));
        test_spark_length_string!(Some(String::from("")), Ok(Some(0)));
        // A null input must yield a null output.
        test_spark_length_string!(None, Ok(None));

        test_spark_length_binary!(Some(String::from("chars").into_bytes()), Ok(Some(5)));
        // The binary variant counts bytes, so "josé" is 5 here, not 4.
        test_spark_length_binary!(Some(String::from("josé").into_bytes()), Ok(Some(5)));
        test_spark_length_binary!(
            Some(String::from("joséjoséjoséjosé").into_bytes()),
            Ok(Some(20))
        );
        test_spark_length_binary!(Some(String::from("").into_bytes()), Ok(Some(0)));
        test_spark_length_binary!(None, Ok(None));

        Ok(())
    }

    // Output nullability must track input nullability exactly (see
    // `return_field_from_args`).
    #[test]
    fn test_spark_length_nullability() -> Result<()> {
        let func = SparkLengthFunc::new();

        let nullable_field: FieldRef = Arc::new(Field::new("col", DataType::Utf8, true));

        let out_nullable = func.return_field_from_args(ReturnFieldArgs {
            arg_fields: &[nullable_field],
            scalar_arguments: &[None],
        })?;

        assert!(
            out_nullable.is_nullable(),
            "length(col) should be nullable when child is nullable"
        );

        let non_nullable_field: FieldRef =
            Arc::new(Field::new("col", DataType::Utf8, false));

        let out_non_nullable = func.return_field_from_args(ReturnFieldArgs {
            arg_fields: &[non_nullable_field],
            scalar_arguments: &[None],
        })?;

        assert!(
            !out_non_nullable.is_nullable(),
            "length(col) should NOT be nullable when child is NOT nullable"
        );

        Ok(())
    }
}