// datafusion_functions/unicode/character_length.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
17
18use crate::utils::{make_scalar_function, utf8_to_int_type};
19use arrow::array::{
20    Array, ArrayRef, ArrowPrimitiveType, AsArray, OffsetSizeTrait, PrimitiveArray,
21    StringArrayType,
22};
23use arrow::datatypes::{ArrowNativeType, DataType, Int32Type, Int64Type};
24use datafusion_common::Result;
25use datafusion_expr::{
26    ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
27};
28use datafusion_macros::user_doc;
29use std::any::Any;
30use std::sync::Arc;
31
32#[user_doc(
33    doc_section(label = "String Functions"),
34    description = "Returns the number of characters in a string.",
35    syntax_example = "character_length(str)",
36    sql_example = r#"```sql
37> select character_length('Ångström');
38+------------------------------------+
39| character_length(Utf8("Ångström")) |
40+------------------------------------+
41| 8                                  |
42+------------------------------------+
43```"#,
44    standard_argument(name = "str", prefix = "String"),
45    related_udf(name = "bit_length"),
46    related_udf(name = "octet_length")
47)]
48#[derive(Debug, PartialEq, Eq, Hash)]
49pub struct CharacterLengthFunc {
50    signature: Signature,
51    aliases: Vec<String>,
52}
53
54impl Default for CharacterLengthFunc {
55    fn default() -> Self {
56        Self::new()
57    }
58}
59
60impl CharacterLengthFunc {
61    pub fn new() -> Self {
62        use DataType::*;
63        Self {
64            signature: Signature::uniform(
65                1,
66                vec![Utf8, LargeUtf8, Utf8View],
67                Volatility::Immutable,
68            ),
69            aliases: vec![String::from("length"), String::from("char_length")],
70        }
71    }
72}
73
74impl ScalarUDFImpl for CharacterLengthFunc {
75    fn as_any(&self) -> &dyn Any {
76        self
77    }
78
79    fn name(&self) -> &str {
80        "character_length"
81    }
82
83    fn signature(&self) -> &Signature {
84        &self.signature
85    }
86
87    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
88        utf8_to_int_type(&arg_types[0], "character_length")
89    }
90
91    fn invoke_with_args(
92        &self,
93        args: datafusion_expr::ScalarFunctionArgs,
94    ) -> Result<ColumnarValue> {
95        make_scalar_function(character_length, vec![])(&args.args)
96    }
97
98    fn aliases(&self) -> &[String] {
99        &self.aliases
100    }
101
102    fn documentation(&self) -> Option<&Documentation> {
103        self.doc()
104    }
105}
106
107/// Returns number of characters in the string.
108/// character_length('josé') = 4
109/// The implementation counts UTF-8 code points to count the number of characters
110fn character_length(args: &[ArrayRef]) -> Result<ArrayRef> {
111    match args[0].data_type() {
112        DataType::Utf8 => {
113            let string_array = args[0].as_string::<i32>();
114            character_length_general::<Int32Type, _>(string_array)
115        }
116        DataType::LargeUtf8 => {
117            let string_array = args[0].as_string::<i64>();
118            character_length_general::<Int64Type, _>(string_array)
119        }
120        DataType::Utf8View => {
121            let string_array = args[0].as_string_view();
122            character_length_general::<Int32Type, _>(string_array)
123        }
124        _ => unreachable!("CharacterLengthFunc"),
125    }
126}
127
128fn character_length_general<'a, T, V>(array: V) -> Result<ArrayRef>
129where
130    T: ArrowPrimitiveType,
131    T::Native: OffsetSizeTrait,
132    V: StringArrayType<'a>,
133{
134    // String characters are variable length encoded in UTF-8, counting the
135    // number of chars requires expensive decoding, however checking if the
136    // string is ASCII only is relatively cheap.
137    // If strings are ASCII only, count bytes instead.
138    let is_array_ascii_only = array.is_ascii();
139    let nulls = array.nulls().cloned();
140    let array = {
141        if is_array_ascii_only {
142            let values: Vec<_> = (0..array.len())
143                .map(|i| {
144                    // Safety: we are iterating with array.len() so the index is always valid
145                    let value = unsafe { array.value_unchecked(i) };
146                    T::Native::usize_as(value.len())
147                })
148                .collect();
149            PrimitiveArray::<T>::new(values.into(), nulls)
150        } else {
151            let values: Vec<_> = (0..array.len())
152                .map(|i| {
153                    // Safety: we are iterating with array.len() so the index is always valid
154                    if array.is_null(i) {
155                        T::default_value()
156                    } else {
157                        let value = unsafe { array.value_unchecked(i) };
158                        if value.is_empty() {
159                            T::default_value()
160                        } else if value.is_ascii() {
161                            T::Native::usize_as(value.len())
162                        } else {
163                            T::Native::usize_as(value.chars().count())
164                        }
165                    }
166                })
167                .collect();
168            PrimitiveArray::<T>::new(values.into(), nulls)
169        }
170    };
171
172    Ok(Arc::new(array))
173}
174
#[cfg(test)]
mod tests {
    use crate::unicode::character_length::CharacterLengthFunc;
    use crate::utils::test::test_function;
    use arrow::array::{Array, Int32Array, Int64Array};
    use arrow::datatypes::DataType::{Int32, Int64};
    use datafusion_common::{Result, ScalarValue};
    use datafusion_expr::{ColumnarValue, ScalarUDFImpl};

    /// Runs `test_function!` for one scalar input against all three supported
    /// string types: `Utf8` and `Utf8View` expect an `Int32` result,
    /// `LargeUtf8` expects an `Int64` result.
    macro_rules! test_character_length {
        ($INPUT:expr, $EXPECTED:expr) => {
            test_function!(
                CharacterLengthFunc::new(),
                vec![ColumnarValue::Scalar(ScalarValue::Utf8($INPUT))],
                $EXPECTED,
                i32,
                Int32,
                Int32Array
            );

            test_function!(
                CharacterLengthFunc::new(),
                vec![ColumnarValue::Scalar(ScalarValue::LargeUtf8($INPUT))],
                $EXPECTED,
                i64,
                Int64,
                Int64Array
            );

            test_function!(
                CharacterLengthFunc::new(),
                vec![ColumnarValue::Scalar(ScalarValue::Utf8View($INPUT))],
                $EXPECTED,
                i32,
                Int32,
                Int32Array
            );
        };
    }

    /// Covers ASCII, non-ASCII, long (more than 12 bytes), empty and null
    /// inputs.
    #[test]
    fn test_functions() -> Result<()> {
        #[cfg(feature = "unicode_expressions")]
        {
            test_character_length!(Some(String::from("chars")), Ok(Some(5)));
            test_character_length!(Some(String::from("josé")), Ok(Some(4)));
            // test long strings (more than 12 bytes for StringView)
            test_character_length!(Some(String::from("joséjoséjoséjosé")), Ok(Some(16)));
            test_character_length!(Some(String::from("")), Ok(Some(0)));
            test_character_length!(None, Ok(None));
        }

        // NOTE(review): `internal_err!` is not imported in this module, and the
        // arguments are passed as `&[..]` here but as `vec![..]` above.
        // Presumably this branch is never compiled because the module itself
        // is gated on `unicode_expressions` — confirm, then fix or remove it.
        #[cfg(not(feature = "unicode_expressions"))]
        test_function!(
            CharacterLengthFunc::new(),
            &[ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("josé"))))],
            internal_err!(
                "function character_length requires compilation with feature flag: unicode_expressions."
            ),
            i32,
            Int32,
            Int32Array
        );

        Ok(())
    }
}