datafusion_spark/function/string/
length.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
17
18use arrow::array::{
19    Array, ArrayRef, AsArray, BinaryArrayType, PrimitiveArray, StringArrayType,
20};
21use arrow::datatypes::{DataType, Int32Type};
22use datafusion_common::exec_err;
23use datafusion_expr::{
24    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
25};
26use datafusion_functions::utils::make_scalar_function;
27use std::sync::Arc;
28
29/// Spark-compatible `length` expression
30/// <https://spark.apache.org/docs/latest/api/sql/index.html#length>
31#[derive(Debug, PartialEq, Eq, Hash)]
32pub struct SparkLengthFunc {
33    signature: Signature,
34    aliases: Vec<String>,
35}
36
37impl Default for SparkLengthFunc {
38    fn default() -> Self {
39        Self::new()
40    }
41}
42
43impl SparkLengthFunc {
44    pub fn new() -> Self {
45        Self {
46            signature: Signature::uniform(
47                1,
48                vec![
49                    DataType::Utf8View,
50                    DataType::Utf8,
51                    DataType::LargeUtf8,
52                    DataType::Binary,
53                    DataType::LargeBinary,
54                    DataType::BinaryView,
55                ],
56                Volatility::Immutable,
57            ),
58            aliases: vec![
59                String::from("character_length"),
60                String::from("char_length"),
61                String::from("len"),
62            ],
63        }
64    }
65}
66
67impl ScalarUDFImpl for SparkLengthFunc {
68    fn as_any(&self) -> &dyn std::any::Any {
69        self
70    }
71
72    fn name(&self) -> &str {
73        "length"
74    }
75
76    fn signature(&self) -> &Signature {
77        &self.signature
78    }
79
80    fn return_type(&self, _args: &[DataType]) -> datafusion_common::Result<DataType> {
81        // spark length always returns Int32
82        Ok(DataType::Int32)
83    }
84
85    fn invoke_with_args(
86        &self,
87        args: ScalarFunctionArgs,
88    ) -> datafusion_common::Result<ColumnarValue> {
89        make_scalar_function(spark_length, vec![])(&args.args)
90    }
91
92    fn aliases(&self) -> &[String] {
93        &self.aliases
94    }
95}
96
97fn spark_length(args: &[ArrayRef]) -> datafusion_common::Result<ArrayRef> {
98    match args[0].data_type() {
99        DataType::Utf8 => {
100            let string_array = args[0].as_string::<i32>();
101            character_length::<_>(string_array)
102        }
103        DataType::LargeUtf8 => {
104            let string_array = args[0].as_string::<i64>();
105            character_length::<_>(string_array)
106        }
107        DataType::Utf8View => {
108            let string_array = args[0].as_string_view();
109            character_length::<_>(string_array)
110        }
111        DataType::Binary => {
112            let binary_array = args[0].as_binary::<i32>();
113            byte_length::<_>(binary_array)
114        }
115        DataType::LargeBinary => {
116            let binary_array = args[0].as_binary::<i64>();
117            byte_length::<_>(binary_array)
118        }
119        DataType::BinaryView => {
120            let binary_array = args[0].as_binary_view();
121            byte_length::<_>(binary_array)
122        }
123        other => exec_err!("Unsupported data type {other:?} for function `length`"),
124    }
125}
126
127fn character_length<'a, V>(array: V) -> datafusion_common::Result<ArrayRef>
128where
129    V: StringArrayType<'a>,
130{
131    // String characters are variable length encoded in UTF-8, counting the
132    // number of chars requires expensive decoding, however checking if the
133    // string is ASCII only is relatively cheap.
134    // If strings are ASCII only, count bytes instead.
135    let is_array_ascii_only = array.is_ascii();
136    let nulls = array.nulls().cloned();
137    let array = {
138        if is_array_ascii_only {
139            let values: Vec<_> = (0..array.len())
140                .map(|i| {
141                    // Safety: we are iterating with array.len() so the index is always valid
142                    let value = unsafe { array.value_unchecked(i) };
143                    value.len() as i32
144                })
145                .collect();
146            PrimitiveArray::<Int32Type>::new(values.into(), nulls)
147        } else {
148            let values: Vec<_> = (0..array.len())
149                .map(|i| {
150                    // Safety: we are iterating with array.len() so the index is always valid
151                    if array.is_null(i) {
152                        i32::default()
153                    } else {
154                        let value = unsafe { array.value_unchecked(i) };
155                        if value.is_empty() {
156                            i32::default()
157                        } else if value.is_ascii() {
158                            value.len() as i32
159                        } else {
160                            value.chars().count() as i32
161                        }
162                    }
163                })
164                .collect();
165            PrimitiveArray::<Int32Type>::new(values.into(), nulls)
166        }
167    };
168
169    Ok(Arc::new(array))
170}
171
172fn byte_length<'a, V>(array: V) -> datafusion_common::Result<ArrayRef>
173where
174    V: BinaryArrayType<'a>,
175{
176    let nulls = array.nulls().cloned();
177    let values: Vec<_> = (0..array.len())
178        .map(|i| {
179            // Safety: we are iterating with array.len() so the index is always valid
180            let value = unsafe { array.value_unchecked(i) };
181            value.len() as i32
182        })
183        .collect();
184    Ok(Arc::new(PrimitiveArray::<Int32Type>::new(
185        values.into(),
186        nulls,
187    )))
188}
189
#[cfg(test)]
mod tests {
    use super::*;
    use crate::function::utils::test::test_scalar_function;
    use arrow::array::{Array, BinaryArray, Int32Array, StringArray, StringViewArray};
    use arrow::datatypes::DataType::Int32;
    use datafusion_common::{Result, ScalarValue};
    use datafusion_expr::{ColumnarValue, ScalarUDFImpl};

    macro_rules! test_spark_length_string {
        ($INPUT:expr, $EXPECTED:expr) => {
            test_scalar_function!(
                SparkLengthFunc::new(),
                vec![ColumnarValue::Scalar(ScalarValue::Utf8($INPUT))],
                $EXPECTED,
                i32,
                Int32,
                Int32Array
            );

            test_scalar_function!(
                SparkLengthFunc::new(),
                vec![ColumnarValue::Scalar(ScalarValue::LargeUtf8($INPUT))],
                $EXPECTED,
                i32,
                Int32,
                Int32Array
            );

            test_scalar_function!(
                SparkLengthFunc::new(),
                vec![ColumnarValue::Scalar(ScalarValue::Utf8View($INPUT))],
                $EXPECTED,
                i32,
                Int32,
                Int32Array
            );
        };
    }

    macro_rules! test_spark_length_binary {
        ($INPUT:expr, $EXPECTED:expr) => {
            test_scalar_function!(
                SparkLengthFunc::new(),
                vec![ColumnarValue::Scalar(ScalarValue::Binary($INPUT))],
                $EXPECTED,
                i32,
                Int32,
                Int32Array
            );

            test_scalar_function!(
                SparkLengthFunc::new(),
                vec![ColumnarValue::Scalar(ScalarValue::LargeBinary($INPUT))],
                $EXPECTED,
                i32,
                Int32,
                Int32Array
            );

            test_scalar_function!(
                SparkLengthFunc::new(),
                vec![ColumnarValue::Scalar(ScalarValue::BinaryView($INPUT))],
                $EXPECTED,
                i32,
                Int32,
                Int32Array
            );
        };
    }

    #[test]
    fn test_functions() -> Result<()> {
        test_spark_length_string!(Some(String::from("chars")), Ok(Some(5)));
        test_spark_length_string!(Some(String::from("josé")), Ok(Some(4)));
        // test long strings (more than 12 bytes for StringView)
        test_spark_length_string!(Some(String::from("joséjoséjoséjosé")), Ok(Some(16)));
        test_spark_length_string!(Some(String::from("")), Ok(Some(0)));
        test_spark_length_string!(None, Ok(None));

        test_spark_length_binary!(Some(String::from("chars").into_bytes()), Ok(Some(5)));
        test_spark_length_binary!(Some(String::from("josé").into_bytes()), Ok(Some(5)));
        // test long strings (more than 12 bytes for BinaryView)
        test_spark_length_binary!(
            Some(String::from("joséjoséjoséjosé").into_bytes()),
            Ok(Some(20))
        );
        test_spark_length_binary!(Some(String::from("").into_bytes()), Ok(Some(0)));
        test_spark_length_binary!(None, Ok(None));

        Ok(())
    }

    /// Multi-row arrays with nulls. Single scalar inputs never reach the
    /// per-value (non-ASCII) path together with nulls, so cover it here.
    #[test]
    fn test_array_inputs() -> Result<()> {
        // Mixed ASCII / non-ASCII with a null: per-value counting path.
        let strings: ArrayRef = Arc::new(StringArray::from(vec![
            Some("josé"),
            None,
            Some("chars"),
            Some(""),
        ]));
        let result = spark_length(&[strings])?;
        assert_eq!(
            result.as_primitive::<Int32Type>(),
            &Int32Array::from(vec![Some(4), None, Some(5), Some(0)])
        );

        // All-ASCII with a null: byte-count fast path keeps the null mask.
        let ascii: ArrayRef = Arc::new(StringArray::from(vec![Some("abc"), None]));
        let result = spark_length(&[ascii])?;
        assert_eq!(
            result.as_primitive::<Int32Type>(),
            &Int32Array::from(vec![Some(3), None])
        );

        // Utf8View with a value longer than the 12-byte inline limit.
        let views: ArrayRef = Arc::new(StringViewArray::from(vec![
            Some("joséjoséjoséjosé"),
            None,
            Some("ab"),
        ]));
        let result = spark_length(&[views])?;
        assert_eq!(
            result.as_primitive::<Int32Type>(),
            &Int32Array::from(vec![Some(16), None, Some(2)])
        );

        // Binary is measured in bytes, not characters.
        let binary: ArrayRef = Arc::new(BinaryArray::from(vec![
            Some("josé".as_bytes()),
            None,
            Some("".as_bytes()),
        ]));
        let result = spark_length(&[binary])?;
        assert_eq!(
            result.as_primitive::<Int32Type>(),
            &Int32Array::from(vec![Some(5), None, Some(0)])
        );

        // Unsupported input types are reported as errors.
        let unsupported: ArrayRef = Arc::new(Int32Array::from(vec![1]));
        assert!(spark_length(&[unsupported]).is_err());

        Ok(())
    }
}