Skip to main content

datafusion_spark/function/string/
length.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use arrow::array::{
19    Array, ArrayRef, AsArray, BinaryArrayType, PrimitiveArray, StringArrayType,
20};
21use arrow::datatypes::{DataType, Field, FieldRef, Int32Type};
22use datafusion_common::exec_err;
23use datafusion_expr::{
24    ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature,
25    Volatility,
26};
27use datafusion_functions::utils::make_scalar_function;
28use std::sync::Arc;
29
30/// Spark-compatible `length` expression
31/// <https://spark.apache.org/docs/latest/api/sql/index.html#length>
32#[derive(Debug, PartialEq, Eq, Hash)]
33pub struct SparkLengthFunc {
34    signature: Signature,
35    aliases: Vec<String>,
36}
37
38impl Default for SparkLengthFunc {
39    fn default() -> Self {
40        Self::new()
41    }
42}
43
44impl SparkLengthFunc {
45    pub fn new() -> Self {
46        Self {
47            signature: Signature::uniform(
48                1,
49                vec![
50                    DataType::Utf8View,
51                    DataType::Utf8,
52                    DataType::LargeUtf8,
53                    DataType::Binary,
54                    DataType::LargeBinary,
55                    DataType::BinaryView,
56                ],
57                Volatility::Immutable,
58            ),
59            aliases: vec![
60                String::from("character_length"),
61                String::from("char_length"),
62                String::from("len"),
63            ],
64        }
65    }
66}
67
68impl ScalarUDFImpl for SparkLengthFunc {
69    fn name(&self) -> &str {
70        "length"
71    }
72
73    fn signature(&self) -> &Signature {
74        &self.signature
75    }
76
77    fn return_type(&self, _args: &[DataType]) -> datafusion_common::Result<DataType> {
78        datafusion_common::internal_err!(
79            "return_type should not be called, use return_field_from_args instead"
80        )
81    }
82
83    fn invoke_with_args(
84        &self,
85        args: ScalarFunctionArgs,
86    ) -> datafusion_common::Result<ColumnarValue> {
87        make_scalar_function(spark_length, vec![])(&args.args)
88    }
89
90    fn aliases(&self) -> &[String] {
91        &self.aliases
92    }
93
94    fn return_field_from_args(
95        &self,
96        args: ReturnFieldArgs,
97    ) -> datafusion_common::Result<FieldRef> {
98        let nullable = args.arg_fields.iter().any(|f| f.is_nullable());
99        // spark length always returns Int32
100        Ok(Arc::new(Field::new(self.name(), DataType::Int32, nullable)))
101    }
102}
103
104fn spark_length(args: &[ArrayRef]) -> datafusion_common::Result<ArrayRef> {
105    match args[0].data_type() {
106        DataType::Utf8 => {
107            let string_array = args[0].as_string::<i32>();
108            character_length::<_>(&string_array)
109        }
110        DataType::LargeUtf8 => {
111            let string_array = args[0].as_string::<i64>();
112            character_length::<_>(&string_array)
113        }
114        DataType::Utf8View => {
115            let string_array = args[0].as_string_view();
116            character_length::<_>(&string_array)
117        }
118        DataType::Binary => {
119            let binary_array = args[0].as_binary::<i32>();
120            byte_length::<_>(&binary_array)
121        }
122        DataType::LargeBinary => {
123            let binary_array = args[0].as_binary::<i64>();
124            byte_length::<_>(&binary_array)
125        }
126        DataType::BinaryView => {
127            let binary_array = args[0].as_binary_view();
128            byte_length::<_>(&binary_array)
129        }
130        other => exec_err!("Unsupported data type {other:?} for function `length`"),
131    }
132}
133
134fn character_length<'a, V>(array: &V) -> datafusion_common::Result<ArrayRef>
135where
136    V: StringArrayType<'a>,
137{
138    // String characters are variable length encoded in UTF-8, counting the
139    // number of chars requires expensive decoding, however checking if the
140    // string is ASCII only is relatively cheap.
141    // If strings are ASCII only, count bytes instead.
142    let is_array_ascii_only = array.is_ascii();
143    let nulls = array.nulls().cloned();
144    let array = {
145        if is_array_ascii_only {
146            let values: Vec<_> = (0..array.len())
147                .map(|i| {
148                    // Safety: we are iterating with array.len() so the index is always valid
149                    let value = unsafe { array.value_unchecked(i) };
150                    value.len() as i32
151                })
152                .collect();
153            PrimitiveArray::<Int32Type>::new(values.into(), nulls)
154        } else {
155            let values: Vec<_> = (0..array.len())
156                .map(|i| {
157                    // Safety: we are iterating with array.len() so the index is always valid
158                    if array.is_null(i) {
159                        i32::default()
160                    } else {
161                        let value = unsafe { array.value_unchecked(i) };
162                        if value.is_empty() {
163                            i32::default()
164                        } else if value.is_ascii() {
165                            value.len() as i32
166                        } else {
167                            value.chars().count() as i32
168                        }
169                    }
170                })
171                .collect();
172            PrimitiveArray::<Int32Type>::new(values.into(), nulls)
173        }
174    };
175
176    Ok(Arc::new(array))
177}
178
179fn byte_length<'a, V>(array: &V) -> datafusion_common::Result<ArrayRef>
180where
181    V: BinaryArrayType<'a>,
182{
183    let nulls = array.nulls().cloned();
184    let values: Vec<_> = (0..array.len())
185        .map(|i| {
186            // Safety: we are iterating with array.len() so the index is always valid
187            let value = unsafe { array.value_unchecked(i) };
188            value.len() as i32
189        })
190        .collect();
191    Ok(Arc::new(PrimitiveArray::<Int32Type>::new(
192        values.into(),
193        nulls,
194    )))
195}
196
#[cfg(test)]
mod tests {
    use super::*;
    use crate::function::utils::test::test_scalar_function;
    use arrow::array::Int32Array;
    use arrow::datatypes::DataType::Int32;
    use datafusion_common::{Result, ScalarValue};

    // Evaluates `length` on one scalar of the given `ScalarValue` variant and
    // compares the outcome with the expected `Int32` result.
    macro_rules! check_length {
        ($VARIANT:ident, $INPUT:expr, $EXPECTED:expr) => {
            test_scalar_function!(
                SparkLengthFunc::new(),
                vec![ColumnarValue::Scalar(ScalarValue::$VARIANT($INPUT))],
                $EXPECTED,
                i32,
                Int32,
                Int32Array
            );
        };
    }

    // Runs the same case against every string representation.
    macro_rules! test_spark_length_string {
        ($INPUT:expr, $EXPECTED:expr) => {
            check_length!(Utf8, $INPUT, $EXPECTED);
            check_length!(LargeUtf8, $INPUT, $EXPECTED);
            check_length!(Utf8View, $INPUT, $EXPECTED);
        };
    }

    // Runs the same case against every binary representation.
    macro_rules! test_spark_length_binary {
        ($INPUT:expr, $EXPECTED:expr) => {
            check_length!(Binary, $INPUT, $EXPECTED);
            check_length!(LargeBinary, $INPUT, $EXPECTED);
            check_length!(BinaryView, $INPUT, $EXPECTED);
        };
    }

    #[test]
    fn test_functions() -> Result<()> {
        // Strings are measured in characters.
        test_spark_length_string!(Some(String::from("chars")), Ok(Some(5)));
        test_spark_length_string!(Some(String::from("josé")), Ok(Some(4)));
        // Long values (more than 12 bytes) are not stored inline in a StringView.
        test_spark_length_string!(Some(String::from("joséjoséjoséjosé")), Ok(Some(16)));
        test_spark_length_string!(Some(String::from("")), Ok(Some(0)));
        test_spark_length_string!(None, Ok(None));

        // Binary values are measured in bytes ("é" takes two bytes in UTF-8).
        test_spark_length_binary!(Some(String::from("chars").into_bytes()), Ok(Some(5)));
        test_spark_length_binary!(Some(String::from("josé").into_bytes()), Ok(Some(5)));
        // Long values (more than 12 bytes) are not stored inline in a BinaryView.
        test_spark_length_binary!(
            Some(String::from("joséjoséjoséjosé").into_bytes()),
            Ok(Some(20))
        );
        test_spark_length_binary!(Some(String::from("").into_bytes()), Ok(Some(0)));
        test_spark_length_binary!(None, Ok(None));

        Ok(())
    }

    #[test]
    fn test_spark_length_nullability() -> Result<()> {
        let func = SparkLengthFunc::new();

        // Output field produced for a single `Utf8` input column with the
        // requested nullability.
        let output_for = |input_nullable: bool| -> Result<FieldRef> {
            let input: FieldRef =
                Arc::new(Field::new("col", DataType::Utf8, input_nullable));
            func.return_field_from_args(ReturnFieldArgs {
                arg_fields: &[input],
                scalar_arguments: &[None],
            })
        };

        let out_nullable = output_for(true)?;
        assert!(
            out_nullable.is_nullable(),
            "length(col) should be nullable when child is nullable"
        );

        let out_non_nullable = output_for(false)?;
        assert!(
            !out_non_nullable.is_nullable(),
            "length(col) should NOT be nullable when child is NOT nullable"
        );

        Ok(())
    }
}