datafusion_functions/unicode/
initcap.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::any::Any;
19use std::sync::Arc;
20
21use arrow::array::{
22    Array, ArrayRef, GenericStringBuilder, OffsetSizeTrait, StringViewBuilder,
23};
24use arrow::datatypes::DataType;
25
26use crate::utils::{make_scalar_function, utf8_to_str_type};
27use datafusion_common::cast::{as_generic_string_array, as_string_view_array};
28use datafusion_common::{exec_err, Result};
29use datafusion_expr::{ColumnarValue, Documentation, Volatility};
30use datafusion_expr::{ScalarUDFImpl, Signature};
31use datafusion_macros::user_doc;
32
33#[user_doc(
34    doc_section(label = "String Functions"),
35    description = "Capitalizes the first character in each word in the input string. \
36            Words are delimited by non-alphanumeric characters.",
37    syntax_example = "initcap(str)",
38    sql_example = r#"```sql
39> select initcap('apache datafusion');
40+------------------------------------+
41| initcap(Utf8("apache datafusion")) |
42+------------------------------------+
43| Apache Datafusion                  |
44+------------------------------------+
45```"#,
46    standard_argument(name = "str", prefix = "String"),
47    related_udf(name = "lower"),
48    related_udf(name = "upper")
49)]
50#[derive(Debug)]
51pub struct InitcapFunc {
52    signature: Signature,
53}
54
55impl Default for InitcapFunc {
56    fn default() -> Self {
57        InitcapFunc::new()
58    }
59}
60
61impl InitcapFunc {
62    pub fn new() -> Self {
63        Self {
64            signature: Signature::string(1, Volatility::Immutable),
65        }
66    }
67}
68
69impl ScalarUDFImpl for InitcapFunc {
70    fn as_any(&self) -> &dyn Any {
71        self
72    }
73
74    fn name(&self) -> &str {
75        "initcap"
76    }
77
78    fn signature(&self) -> &Signature {
79        &self.signature
80    }
81
82    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
83        if let DataType::Utf8View = arg_types[0] {
84            Ok(DataType::Utf8View)
85        } else {
86            utf8_to_str_type(&arg_types[0], "initcap")
87        }
88    }
89
90    fn invoke_with_args(
91        &self,
92        args: datafusion_expr::ScalarFunctionArgs,
93    ) -> Result<ColumnarValue> {
94        let args = &args.args;
95        match args[0].data_type() {
96            DataType::Utf8 => make_scalar_function(initcap::<i32>, vec![])(args),
97            DataType::LargeUtf8 => make_scalar_function(initcap::<i64>, vec![])(args),
98            DataType::Utf8View => make_scalar_function(initcap_utf8view, vec![])(args),
99            other => {
100                exec_err!("Unsupported data type {other:?} for function `initcap`")
101            }
102        }
103    }
104
105    fn documentation(&self) -> Option<&Documentation> {
106        self.doc()
107    }
108}
109
110/// Converts the first letter of each word to upper case and the rest to lower
111/// case. Words are sequences of alphanumeric characters separated by
112/// non-alphanumeric characters.
113///
114/// Example:
115/// ```sql
116/// initcap('hi THOMAS') = 'Hi Thomas'
117/// ```
118fn initcap<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
119    let string_array = as_generic_string_array::<T>(&args[0])?;
120
121    let mut builder = GenericStringBuilder::<T>::with_capacity(
122        string_array.len(),
123        string_array.value_data().len(),
124    );
125
126    string_array.iter().for_each(|str| match str {
127        Some(s) => {
128            let initcap_str = initcap_string(s);
129            builder.append_value(initcap_str);
130        }
131        None => builder.append_null(),
132    });
133
134    Ok(Arc::new(builder.finish()) as ArrayRef)
135}
136
137fn initcap_utf8view(args: &[ArrayRef]) -> Result<ArrayRef> {
138    let string_view_array = as_string_view_array(&args[0])?;
139
140    let mut builder = StringViewBuilder::with_capacity(string_view_array.len());
141
142    string_view_array.iter().for_each(|str| match str {
143        Some(s) => {
144            let initcap_str = initcap_string(s);
145            builder.append_value(initcap_str);
146        }
147        None => builder.append_null(),
148    });
149
150    Ok(Arc::new(builder.finish()) as ArrayRef)
151}
152
/// Capitalizes one string. A character is upper-cased when the character before
/// it is not alphanumeric (or it is the first character); otherwise it is
/// lower-cased. Pure-ASCII input uses the ASCII case/alphanumeric predicates;
/// anything else uses the full Unicode mappings, which may expand one character
/// into several (e.g. `ß` -> `SS`).
fn initcap_string(input: &str) -> String {
    let mut out = String::with_capacity(input.len());
    let mut in_word = false;

    if input.is_ascii() {
        for ch in input.chars() {
            out.push(if in_word {
                ch.to_ascii_lowercase()
            } else {
                ch.to_ascii_uppercase()
            });
            in_word = ch.is_ascii_alphanumeric();
        }
        return out;
    }

    for ch in input.chars() {
        if in_word {
            out.extend(ch.to_lowercase());
        } else {
            out.extend(ch.to_uppercase());
        }
        in_word = ch.is_alphanumeric();
    }
    out
}
179
#[cfg(test)]
mod tests {
    use crate::unicode::initcap::InitcapFunc;
    use crate::utils::test::test_function;
    use arrow::array::{Array, LargeStringArray, StringArray, StringViewArray};
    use arrow::datatypes::DataType::{LargeUtf8, Utf8, Utf8View};
    use datafusion_common::{Result, ScalarValue};
    use datafusion_expr::{ColumnarValue, ScalarUDFImpl};

    #[test]
    fn test_functions() -> Result<()> {
        // Utf8
        test_function!(
            InitcapFunc::new(),
            vec![ColumnarValue::Scalar(ScalarValue::from("hi THOMAS"))],
            Ok(Some("Hi Thomas")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            InitcapFunc::new(),
            vec![ColumnarValue::Scalar(ScalarValue::Utf8(Some(
                "êM ả ñAnDÚ ÁrBOL ОлЕГ ИвАНОВИч ÍslENsku ÞjóðaRiNNaR εΛλΗΝΙκΉ"
                    .to_string()
            )))],
            Ok(Some(
                "Êm Ả Ñandú Árbol Олег Иванович Íslensku Þjóðarinnar Ελληνική"
            )),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            InitcapFunc::new(),
            vec![ColumnarValue::Scalar(ScalarValue::from(""))],
            Ok(Some("")),
            &str,
            Utf8,
            StringArray
        );
        // Any non-alphanumeric character (apostrophe, hyphen, underscore, space)
        // starts a new word; digits count as word characters.
        test_function!(
            InitcapFunc::new(),
            vec![ColumnarValue::Scalar(ScalarValue::from(
                "don't-stop_ME now 42nd STREET"
            ))],
            Ok(Some("Don'T-Stop_Me Now 42nd Street")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            InitcapFunc::new(),
            vec![ColumnarValue::Scalar(ScalarValue::Utf8(None))],
            Ok(None),
            &str,
            Utf8,
            StringArray
        );

        // LargeUtf8
        test_function!(
            InitcapFunc::new(),
            vec![ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(
                "hi THOMAS".to_string()
            )))],
            Ok(Some("Hi Thomas")),
            &str,
            LargeUtf8,
            LargeStringArray
        );
        test_function!(
            InitcapFunc::new(),
            vec![ColumnarValue::Scalar(ScalarValue::LargeUtf8(None))],
            Ok(None),
            &str,
            LargeUtf8,
            LargeStringArray
        );

        // Utf8View
        test_function!(
            InitcapFunc::new(),
            vec![ColumnarValue::Scalar(ScalarValue::Utf8View(Some(
                "hi THOMAS".to_string()
            )))],
            Ok(Some("Hi Thomas")),
            &str,
            Utf8View,
            StringViewArray
        );
        test_function!(
            InitcapFunc::new(),
            vec![ColumnarValue::Scalar(ScalarValue::Utf8View(Some(
                "hi THOMAS wIth M0re ThAN 12 ChaRs".to_string()
            )))],
            Ok(Some("Hi Thomas With M0re Than 12 Chars")),
            &str,
            Utf8View,
            StringViewArray
        );
        test_function!(
            InitcapFunc::new(),
            vec![ColumnarValue::Scalar(ScalarValue::Utf8View(Some(
                "đẸp đẼ êM ả ñAnDÚ ÁrBOL ОлЕГ ИвАНОВИч ÍslENsku ÞjóðaRiNNaR εΛλΗΝΙκΉ"
                    .to_string()
            )))],
            Ok(Some(
                "Đẹp Đẽ Êm Ả Ñandú Árbol Олег Иванович Íslensku Þjóðarinnar Ελληνική"
            )),
            &str,
            Utf8View,
            StringViewArray
        );
        test_function!(
            InitcapFunc::new(),
            vec![ColumnarValue::Scalar(ScalarValue::Utf8View(Some(
                "".to_string()
            )))],
            Ok(Some("")),
            &str,
            Utf8View,
            StringViewArray
        );
        test_function!(
            InitcapFunc::new(),
            vec![ColumnarValue::Scalar(ScalarValue::Utf8View(None))],
            Ok(None),
            &str,
            Utf8View,
            StringViewArray
        );

        Ok(())
    }
}