// datafusion_spark/function/string/length.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use arrow::array::{
19    Array, ArrayRef, AsArray, BinaryArrayType, PrimitiveArray, StringArrayType,
20};
21use arrow::datatypes::{DataType, Field, FieldRef, Int32Type};
22use datafusion_common::exec_err;
23use datafusion_expr::{
24    ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature,
25    Volatility,
26};
27use datafusion_functions::utils::make_scalar_function;
28use std::sync::Arc;
29
/// Spark-compatible `length` expression
/// <https://spark.apache.org/docs/latest/api/sql/index.html#length>
///
/// Returns the number of characters for string inputs and the number of
/// bytes for binary inputs, always as `Int32` (Spark semantics).
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct SparkLengthFunc {
    // One-argument signature accepting the string and binary Arrow types
    // (see `new` for the exact list).
    signature: Signature,
    // Alternative SQL names this function is registered under.
    aliases: Vec<String>,
}
37
38impl Default for SparkLengthFunc {
39    fn default() -> Self {
40        Self::new()
41    }
42}
43
44impl SparkLengthFunc {
45    pub fn new() -> Self {
46        Self {
47            signature: Signature::uniform(
48                1,
49                vec![
50                    DataType::Utf8View,
51                    DataType::Utf8,
52                    DataType::LargeUtf8,
53                    DataType::Binary,
54                    DataType::LargeBinary,
55                    DataType::BinaryView,
56                ],
57                Volatility::Immutable,
58            ),
59            aliases: vec![
60                String::from("character_length"),
61                String::from("char_length"),
62                String::from("len"),
63            ],
64        }
65    }
66}
67
68impl ScalarUDFImpl for SparkLengthFunc {
69    fn as_any(&self) -> &dyn std::any::Any {
70        self
71    }
72
73    fn name(&self) -> &str {
74        "length"
75    }
76
77    fn signature(&self) -> &Signature {
78        &self.signature
79    }
80
81    fn return_type(&self, _args: &[DataType]) -> datafusion_common::Result<DataType> {
82        datafusion_common::internal_err!(
83            "return_type should not be called, use return_field_from_args instead"
84        )
85    }
86
87    fn invoke_with_args(
88        &self,
89        args: ScalarFunctionArgs,
90    ) -> datafusion_common::Result<ColumnarValue> {
91        make_scalar_function(spark_length, vec![])(&args.args)
92    }
93
94    fn aliases(&self) -> &[String] {
95        &self.aliases
96    }
97
98    fn return_field_from_args(
99        &self,
100        args: ReturnFieldArgs,
101    ) -> datafusion_common::Result<FieldRef> {
102        let nullable = args.arg_fields.iter().any(|f| f.is_nullable());
103        // spark length always returns Int32
104        Ok(Arc::new(Field::new(self.name(), DataType::Int32, nullable)))
105    }
106}
107
108fn spark_length(args: &[ArrayRef]) -> datafusion_common::Result<ArrayRef> {
109    match args[0].data_type() {
110        DataType::Utf8 => {
111            let string_array = args[0].as_string::<i32>();
112            character_length::<_>(&string_array)
113        }
114        DataType::LargeUtf8 => {
115            let string_array = args[0].as_string::<i64>();
116            character_length::<_>(&string_array)
117        }
118        DataType::Utf8View => {
119            let string_array = args[0].as_string_view();
120            character_length::<_>(&string_array)
121        }
122        DataType::Binary => {
123            let binary_array = args[0].as_binary::<i32>();
124            byte_length::<_>(&binary_array)
125        }
126        DataType::LargeBinary => {
127            let binary_array = args[0].as_binary::<i64>();
128            byte_length::<_>(&binary_array)
129        }
130        DataType::BinaryView => {
131            let binary_array = args[0].as_binary_view();
132            byte_length::<_>(&binary_array)
133        }
134        other => exec_err!("Unsupported data type {other:?} for function `length`"),
135    }
136}
137
138fn character_length<'a, V>(array: &V) -> datafusion_common::Result<ArrayRef>
139where
140    V: StringArrayType<'a>,
141{
142    // String characters are variable length encoded in UTF-8, counting the
143    // number of chars requires expensive decoding, however checking if the
144    // string is ASCII only is relatively cheap.
145    // If strings are ASCII only, count bytes instead.
146    let is_array_ascii_only = array.is_ascii();
147    let nulls = array.nulls().cloned();
148    let array = {
149        if is_array_ascii_only {
150            let values: Vec<_> = (0..array.len())
151                .map(|i| {
152                    // Safety: we are iterating with array.len() so the index is always valid
153                    let value = unsafe { array.value_unchecked(i) };
154                    value.len() as i32
155                })
156                .collect();
157            PrimitiveArray::<Int32Type>::new(values.into(), nulls)
158        } else {
159            let values: Vec<_> = (0..array.len())
160                .map(|i| {
161                    // Safety: we are iterating with array.len() so the index is always valid
162                    if array.is_null(i) {
163                        i32::default()
164                    } else {
165                        let value = unsafe { array.value_unchecked(i) };
166                        if value.is_empty() {
167                            i32::default()
168                        } else if value.is_ascii() {
169                            value.len() as i32
170                        } else {
171                            value.chars().count() as i32
172                        }
173                    }
174                })
175                .collect();
176            PrimitiveArray::<Int32Type>::new(values.into(), nulls)
177        }
178    };
179
180    Ok(Arc::new(array))
181}
182
183fn byte_length<'a, V>(array: &V) -> datafusion_common::Result<ArrayRef>
184where
185    V: BinaryArrayType<'a>,
186{
187    let nulls = array.nulls().cloned();
188    let values: Vec<_> = (0..array.len())
189        .map(|i| {
190            // Safety: we are iterating with array.len() so the index is always valid
191            let value = unsafe { array.value_unchecked(i) };
192            value.len() as i32
193        })
194        .collect();
195    Ok(Arc::new(PrimitiveArray::<Int32Type>::new(
196        values.into(),
197        nulls,
198    )))
199}
200
#[cfg(test)]
mod tests {
    use super::*;
    use crate::function::utils::test::test_scalar_function;
    use arrow::array::{Array, Int32Array};
    use arrow::datatypes::DataType::Int32;
    use arrow::datatypes::{Field, FieldRef};
    use datafusion_common::{Result, ScalarValue};
    use datafusion_expr::{ColumnarValue, ReturnFieldArgs, ScalarUDFImpl};

    // Runs one scalar input through all three string encodings
    // (Utf8, LargeUtf8, Utf8View) and asserts the same Int32 result.
    macro_rules! test_spark_length_string {
        ($INPUT:expr, $EXPECTED:expr) => {
            test_scalar_function!(
                SparkLengthFunc::new(),
                vec![ColumnarValue::Scalar(ScalarValue::Utf8($INPUT))],
                $EXPECTED,
                i32,
                Int32,
                Int32Array
            );

            test_scalar_function!(
                SparkLengthFunc::new(),
                vec![ColumnarValue::Scalar(ScalarValue::LargeUtf8($INPUT))],
                $EXPECTED,
                i32,
                Int32,
                Int32Array
            );

            test_scalar_function!(
                SparkLengthFunc::new(),
                vec![ColumnarValue::Scalar(ScalarValue::Utf8View($INPUT))],
                $EXPECTED,
                i32,
                Int32,
                Int32Array
            );
        };
    }

    // Runs one scalar input through all three binary encodings
    // (Binary, LargeBinary, BinaryView) and asserts the same Int32 result.
    macro_rules! test_spark_length_binary {
        ($INPUT:expr, $EXPECTED:expr) => {
            test_scalar_function!(
                SparkLengthFunc::new(),
                vec![ColumnarValue::Scalar(ScalarValue::Binary($INPUT))],
                $EXPECTED,
                i32,
                Int32,
                Int32Array
            );

            test_scalar_function!(
                SparkLengthFunc::new(),
                vec![ColumnarValue::Scalar(ScalarValue::LargeBinary($INPUT))],
                $EXPECTED,
                i32,
                Int32,
                Int32Array
            );

            test_scalar_function!(
                SparkLengthFunc::new(),
                vec![ColumnarValue::Scalar(ScalarValue::BinaryView($INPUT))],
                $EXPECTED,
                i32,
                Int32,
                Int32Array
            );
        };
    }

    #[test]
    fn test_functions() -> Result<()> {
        // String inputs: length is counted in characters ("josé" -> 4).
        test_spark_length_string!(Some(String::from("chars")), Ok(Some(5)));
        test_spark_length_string!(Some(String::from("josé")), Ok(Some(4)));
        // test long strings (more than 12 bytes for StringView)
        test_spark_length_string!(Some(String::from("joséjoséjoséjosé")), Ok(Some(16)));
        test_spark_length_string!(Some(String::from("")), Ok(Some(0)));
        test_spark_length_string!(None, Ok(None));

        // Binary inputs: length is counted in bytes, so the UTF-8 encoding
        // of "josé" is 5 bytes (é is two bytes).
        test_spark_length_binary!(Some(String::from("chars").into_bytes()), Ok(Some(5)));
        test_spark_length_binary!(Some(String::from("josé").into_bytes()), Ok(Some(5)));
        // test long strings (more than 12 bytes for BinaryView)
        test_spark_length_binary!(
            Some(String::from("joséjoséjoséjosé").into_bytes()),
            Ok(Some(20))
        );
        test_spark_length_binary!(Some(String::from("").into_bytes()), Ok(Some(0)));
        test_spark_length_binary!(None, Ok(None));

        Ok(())
    }

    // The output field must be nullable exactly when the input field is.
    #[test]
    fn test_spark_length_nullability() -> Result<()> {
        let func = SparkLengthFunc::new();

        let nullable_field: FieldRef = Arc::new(Field::new("col", DataType::Utf8, true));

        let out_nullable = func.return_field_from_args(ReturnFieldArgs {
            arg_fields: &[nullable_field],
            scalar_arguments: &[None],
        })?;

        assert!(
            out_nullable.is_nullable(),
            "length(col) should be nullable when child is nullable"
        );

        let non_nullable_field: FieldRef =
            Arc::new(Field::new("col", DataType::Utf8, false));

        let out_non_nullable = func.return_field_from_args(ReturnFieldArgs {
            arg_fields: &[non_nullable_field],
            scalar_arguments: &[None],
        })?;

        assert!(
            !out_non_nullable.is_nullable(),
            "length(col) should NOT be nullable when child is NOT nullable"
        );

        Ok(())
    }
}