datafusion_spark/function/string/
length.rs1use arrow::array::{
19 Array, ArrayRef, AsArray, BinaryArrayType, PrimitiveArray, StringArrayType,
20};
21use arrow::datatypes::{DataType, Field, FieldRef, Int32Type};
22use datafusion_common::exec_err;
23use datafusion_expr::{
24 ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature,
25 Volatility,
26};
27use datafusion_functions::utils::make_scalar_function;
28use std::sync::Arc;
29
30#[derive(Debug, PartialEq, Eq, Hash)]
33pub struct SparkLengthFunc {
34 signature: Signature,
35 aliases: Vec<String>,
36}
37
38impl Default for SparkLengthFunc {
39 fn default() -> Self {
40 Self::new()
41 }
42}
43
44impl SparkLengthFunc {
45 pub fn new() -> Self {
46 Self {
47 signature: Signature::uniform(
48 1,
49 vec![
50 DataType::Utf8View,
51 DataType::Utf8,
52 DataType::LargeUtf8,
53 DataType::Binary,
54 DataType::LargeBinary,
55 DataType::BinaryView,
56 ],
57 Volatility::Immutable,
58 ),
59 aliases: vec![
60 String::from("character_length"),
61 String::from("char_length"),
62 String::from("len"),
63 ],
64 }
65 }
66}
67
68impl ScalarUDFImpl for SparkLengthFunc {
69 fn name(&self) -> &str {
70 "length"
71 }
72
73 fn signature(&self) -> &Signature {
74 &self.signature
75 }
76
77 fn return_type(&self, _args: &[DataType]) -> datafusion_common::Result<DataType> {
78 datafusion_common::internal_err!(
79 "return_type should not be called, use return_field_from_args instead"
80 )
81 }
82
83 fn invoke_with_args(
84 &self,
85 args: ScalarFunctionArgs,
86 ) -> datafusion_common::Result<ColumnarValue> {
87 make_scalar_function(spark_length, vec![])(&args.args)
88 }
89
90 fn aliases(&self) -> &[String] {
91 &self.aliases
92 }
93
94 fn return_field_from_args(
95 &self,
96 args: ReturnFieldArgs,
97 ) -> datafusion_common::Result<FieldRef> {
98 let nullable = args.arg_fields.iter().any(|f| f.is_nullable());
99 Ok(Arc::new(Field::new(self.name(), DataType::Int32, nullable)))
101 }
102}
103
104fn spark_length(args: &[ArrayRef]) -> datafusion_common::Result<ArrayRef> {
105 match args[0].data_type() {
106 DataType::Utf8 => {
107 let string_array = args[0].as_string::<i32>();
108 character_length::<_>(&string_array)
109 }
110 DataType::LargeUtf8 => {
111 let string_array = args[0].as_string::<i64>();
112 character_length::<_>(&string_array)
113 }
114 DataType::Utf8View => {
115 let string_array = args[0].as_string_view();
116 character_length::<_>(&string_array)
117 }
118 DataType::Binary => {
119 let binary_array = args[0].as_binary::<i32>();
120 byte_length::<_>(&binary_array)
121 }
122 DataType::LargeBinary => {
123 let binary_array = args[0].as_binary::<i64>();
124 byte_length::<_>(&binary_array)
125 }
126 DataType::BinaryView => {
127 let binary_array = args[0].as_binary_view();
128 byte_length::<_>(&binary_array)
129 }
130 other => exec_err!("Unsupported data type {other:?} for function `length`"),
131 }
132}
133
134fn character_length<'a, V>(array: &V) -> datafusion_common::Result<ArrayRef>
135where
136 V: StringArrayType<'a>,
137{
138 let is_array_ascii_only = array.is_ascii();
143 let nulls = array.nulls().cloned();
144 let array = {
145 if is_array_ascii_only {
146 let values: Vec<_> = (0..array.len())
147 .map(|i| {
148 let value = unsafe { array.value_unchecked(i) };
150 value.len() as i32
151 })
152 .collect();
153 PrimitiveArray::<Int32Type>::new(values.into(), nulls)
154 } else {
155 let values: Vec<_> = (0..array.len())
156 .map(|i| {
157 if array.is_null(i) {
159 i32::default()
160 } else {
161 let value = unsafe { array.value_unchecked(i) };
162 if value.is_empty() {
163 i32::default()
164 } else if value.is_ascii() {
165 value.len() as i32
166 } else {
167 value.chars().count() as i32
168 }
169 }
170 })
171 .collect();
172 PrimitiveArray::<Int32Type>::new(values.into(), nulls)
173 }
174 };
175
176 Ok(Arc::new(array))
177}
178
179fn byte_length<'a, V>(array: &V) -> datafusion_common::Result<ArrayRef>
180where
181 V: BinaryArrayType<'a>,
182{
183 let nulls = array.nulls().cloned();
184 let values: Vec<_> = (0..array.len())
185 .map(|i| {
186 let value = unsafe { array.value_unchecked(i) };
188 value.len() as i32
189 })
190 .collect();
191 Ok(Arc::new(PrimitiveArray::<Int32Type>::new(
192 values.into(),
193 nulls,
194 )))
195}
196
#[cfg(test)]
mod tests {
    use super::*;
    use crate::function::utils::test::test_scalar_function;
    use arrow::array::{
        BinaryArray, Int32Array, LargeBinaryArray, LargeStringArray, StringArray,
        StringViewArray,
    };
    use arrow::datatypes::DataType::Int32;
    use datafusion_common::{Result, ScalarValue};

    macro_rules! test_spark_length_string {
        ($INPUT:expr, $EXPECTED:expr) => {
            test_scalar_function!(
                SparkLengthFunc::new(),
                vec![ColumnarValue::Scalar(ScalarValue::Utf8($INPUT))],
                $EXPECTED,
                i32,
                Int32,
                Int32Array
            );

            test_scalar_function!(
                SparkLengthFunc::new(),
                vec![ColumnarValue::Scalar(ScalarValue::LargeUtf8($INPUT))],
                $EXPECTED,
                i32,
                Int32,
                Int32Array
            );

            test_scalar_function!(
                SparkLengthFunc::new(),
                vec![ColumnarValue::Scalar(ScalarValue::Utf8View($INPUT))],
                $EXPECTED,
                i32,
                Int32,
                Int32Array
            );
        };
    }

    macro_rules! test_spark_length_binary {
        ($INPUT:expr, $EXPECTED:expr) => {
            test_scalar_function!(
                SparkLengthFunc::new(),
                vec![ColumnarValue::Scalar(ScalarValue::Binary($INPUT))],
                $EXPECTED,
                i32,
                Int32,
                Int32Array
            );

            test_scalar_function!(
                SparkLengthFunc::new(),
                vec![ColumnarValue::Scalar(ScalarValue::LargeBinary($INPUT))],
                $EXPECTED,
                i32,
                Int32,
                Int32Array
            );

            test_scalar_function!(
                SparkLengthFunc::new(),
                vec![ColumnarValue::Scalar(ScalarValue::BinaryView($INPUT))],
                $EXPECTED,
                i32,
                Int32,
                Int32Array
            );
        };
    }

    /// Runs `spark_length` directly on one array argument and downcasts the
    /// `Int32` result. This reaches the multi-row array code paths (nulls mixed
    /// with values, the ASCII fast path, slices) that the scalar tests do not.
    fn length_of(array: ArrayRef) -> Result<Int32Array> {
        let result = spark_length(&[array])?;
        Ok(result.as_primitive::<Int32Type>().clone())
    }

    #[test]
    fn test_functions() -> Result<()> {
        test_spark_length_string!(Some(String::from("chars")), Ok(Some(5)));
        test_spark_length_string!(Some(String::from("josé")), Ok(Some(4)));
        test_spark_length_string!(Some(String::from("joséjoséjoséjosé")), Ok(Some(16)));
        test_spark_length_string!(Some(String::from("")), Ok(Some(0)));
        test_spark_length_string!(None, Ok(None));

        test_spark_length_binary!(Some(String::from("chars").into_bytes()), Ok(Some(5)));
        test_spark_length_binary!(Some(String::from("josé").into_bytes()), Ok(Some(5)));
        test_spark_length_binary!(
            Some(String::from("joséjoséjoséjosé").into_bytes()),
            Ok(Some(20))
        );
        test_spark_length_binary!(Some(String::from("").into_bytes()), Ok(Some(0)));
        test_spark_length_binary!(None, Ok(None));

        Ok(())
    }

    #[test]
    fn test_spark_length_string_arrays() -> Result<()> {
        // Contains a non-ASCII value, so the per-value counting path is used.
        let mixed = vec![Some("chars"), None, Some(""), Some("josé")];
        let expected = Int32Array::from(vec![Some(5), None, Some(0), Some(4)]);
        assert_eq!(
            length_of(Arc::new(StringArray::from(mixed.clone())))?,
            expected
        );
        assert_eq!(
            length_of(Arc::new(LargeStringArray::from(mixed.clone())))?,
            expected
        );
        assert_eq!(length_of(Arc::new(StringViewArray::from(mixed)))?, expected);

        // ASCII-only input takes the byte-length fast path; nulls must stay null.
        let ascii = vec![Some("abc"), None, Some("")];
        let expected = Int32Array::from(vec![Some(3), None, Some(0)]);
        assert_eq!(
            length_of(Arc::new(StringArray::from(ascii.clone())))?,
            expected
        );
        assert_eq!(length_of(Arc::new(StringViewArray::from(ascii)))?, expected);

        Ok(())
    }

    #[test]
    fn test_spark_length_sliced_string_array() -> Result<()> {
        let array = StringArray::from(vec![Some("josé"), Some("ab"), None, Some("c")]);

        // The visible window is ASCII-only even though the parent is not.
        assert_eq!(
            length_of(Arc::new(array.slice(1, 3)))?,
            Int32Array::from(vec![Some(2), None, Some(1)])
        );
        // The visible window contains a multi-byte character.
        assert_eq!(
            length_of(Arc::new(array.slice(0, 2)))?,
            Int32Array::from(vec![Some(4), Some(2)])
        );

        Ok(())
    }

    #[test]
    fn test_spark_length_binary_arrays() -> Result<()> {
        let values: Vec<Option<&[u8]>> = vec![
            Some("chars".as_bytes()),
            None,
            Some("".as_bytes()),
            Some("josé".as_bytes()),
        ];
        // Binary length counts bytes, so "josé" is 5, not 4.
        let expected = Int32Array::from(vec![Some(5), None, Some(0), Some(5)]);
        assert_eq!(
            length_of(Arc::new(BinaryArray::from(values.clone())))?,
            expected
        );
        assert_eq!(
            length_of(Arc::new(LargeBinaryArray::from(values)))?,
            expected
        );

        Ok(())
    }

    #[test]
    fn test_spark_length_nullability() -> Result<()> {
        let func = SparkLengthFunc::new();

        let nullable_field: FieldRef = Arc::new(Field::new("col", DataType::Utf8, true));

        let out_nullable = func.return_field_from_args(ReturnFieldArgs {
            arg_fields: &[nullable_field],
            scalar_arguments: &[None],
        })?;

        assert!(
            out_nullable.is_nullable(),
            "length(col) should be nullable when child is nullable"
        );

        let non_nullable_field: FieldRef =
            Arc::new(Field::new("col", DataType::Utf8, false));

        let out_non_nullable = func.return_field_from_args(ReturnFieldArgs {
            arg_fields: &[non_nullable_field],
            scalar_arguments: &[None],
        })?;

        assert!(
            !out_non_nullable.is_nullable(),
            "length(col) should NOT be nullable when child is NOT nullable"
        );

        Ok(())
    }
}