Skip to main content

datafusion_functions/unicode/
initcap.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::any::Any;
19use std::sync::Arc;
20
21use arrow::array::{
22    Array, ArrayRef, GenericStringArray, GenericStringBuilder, OffsetSizeTrait,
23    StringViewBuilder,
24};
25use arrow::buffer::{Buffer, OffsetBuffer};
26use arrow::datatypes::DataType;
27
28use crate::utils::{make_scalar_function, utf8_to_str_type};
29use datafusion_common::cast::{as_generic_string_array, as_string_view_array};
30use datafusion_common::types::logical_string;
31use datafusion_common::{Result, ScalarValue, exec_err};
32use datafusion_expr::{
33    Coercion, ColumnarValue, Documentation, ScalarUDFImpl, Signature, TypeSignatureClass,
34    Volatility,
35};
36use datafusion_macros::user_doc;
37
38#[user_doc(
39    doc_section(label = "String Functions"),
40    description = "Capitalizes the first character in each word in the input string. \
41            Words are delimited by non-alphanumeric characters.",
42    syntax_example = "initcap(str)",
43    sql_example = r#"```sql
44> select initcap('apache datafusion');
45+------------------------------------+
46| initcap(Utf8("apache datafusion")) |
47+------------------------------------+
48| Apache Datafusion                  |
49+------------------------------------+
50```"#,
51    standard_argument(name = "str", prefix = "String"),
52    related_udf(name = "lower"),
53    related_udf(name = "upper")
54)]
55#[derive(Debug, PartialEq, Eq, Hash)]
56pub struct InitcapFunc {
57    signature: Signature,
58}
59
60impl Default for InitcapFunc {
61    fn default() -> Self {
62        InitcapFunc::new()
63    }
64}
65
66impl InitcapFunc {
67    pub fn new() -> Self {
68        Self {
69            signature: Signature::coercible(
70                vec![Coercion::new_exact(TypeSignatureClass::Native(
71                    logical_string(),
72                ))],
73                Volatility::Immutable,
74            ),
75        }
76    }
77}
78
79impl ScalarUDFImpl for InitcapFunc {
80    fn as_any(&self) -> &dyn Any {
81        self
82    }
83
84    fn name(&self) -> &str {
85        "initcap"
86    }
87
88    fn signature(&self) -> &Signature {
89        &self.signature
90    }
91
92    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
93        if let DataType::Utf8View = arg_types[0] {
94            Ok(DataType::Utf8View)
95        } else {
96            utf8_to_str_type(&arg_types[0], "initcap")
97        }
98    }
99
100    fn invoke_with_args(
101        &self,
102        args: datafusion_expr::ScalarFunctionArgs,
103    ) -> Result<ColumnarValue> {
104        let arg = &args.args[0];
105
106        // Scalar fast path - handle directly without array conversion
107        if let ColumnarValue::Scalar(scalar) = arg {
108            return match scalar {
109                ScalarValue::Utf8(None)
110                | ScalarValue::LargeUtf8(None)
111                | ScalarValue::Utf8View(None) => Ok(arg.clone()),
112                ScalarValue::Utf8(Some(s)) => {
113                    let mut result = String::new();
114                    initcap_string(s, &mut result);
115                    Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(result))))
116                }
117                ScalarValue::LargeUtf8(Some(s)) => {
118                    let mut result = String::new();
119                    initcap_string(s, &mut result);
120                    Ok(ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(result))))
121                }
122                ScalarValue::Utf8View(Some(s)) => {
123                    let mut result = String::new();
124                    initcap_string(s, &mut result);
125                    Ok(ColumnarValue::Scalar(ScalarValue::Utf8View(Some(result))))
126                }
127                other => {
128                    exec_err!(
129                        "Unsupported data type {:?} for function `initcap`",
130                        other.data_type()
131                    )
132                }
133            };
134        }
135
136        // Array path
137        let args = &args.args;
138        match args[0].data_type() {
139            DataType::Utf8 => make_scalar_function(initcap::<i32>, vec![])(args),
140            DataType::LargeUtf8 => make_scalar_function(initcap::<i64>, vec![])(args),
141            DataType::Utf8View => make_scalar_function(initcap_utf8view, vec![])(args),
142            other => {
143                exec_err!("Unsupported data type {other:?} for function `initcap`")
144            }
145        }
146    }
147
148    fn documentation(&self) -> Option<&Documentation> {
149        self.doc()
150    }
151}
152
153/// Converts the first letter of each word to uppercase and the rest to
154/// lowercase. Words are sequences of alphanumeric characters separated by
155/// non-alphanumeric characters.
156///
157/// Example:
158/// ```sql
159/// initcap('hi THOMAS') = 'Hi Thomas'
160/// ```
161fn initcap<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
162    let string_array = as_generic_string_array::<T>(&args[0])?;
163
164    if string_array.is_ascii() {
165        return Ok(initcap_ascii_array(string_array));
166    }
167
168    let mut builder = GenericStringBuilder::<T>::with_capacity(
169        string_array.len(),
170        string_array.value_data().len(),
171    );
172
173    let mut container = String::new();
174    string_array.iter().for_each(|str| match str {
175        Some(s) => {
176            initcap_string(s, &mut container);
177            builder.append_value(&container);
178        }
179        None => builder.append_null(),
180    });
181
182    Ok(Arc::new(builder.finish()) as ArrayRef)
183}
184
185/// Fast path for `Utf8` or `LargeUtf8` arrays that are ASCII-only. We can use a
186/// single pass over the buffer and operate directly on bytes.
187fn initcap_ascii_array<T: OffsetSizeTrait>(
188    string_array: &GenericStringArray<T>,
189) -> ArrayRef {
190    let offsets = string_array.offsets();
191    let src = string_array.value_data();
192    let first_offset = offsets.first().unwrap().as_usize();
193    let last_offset = offsets.last().unwrap().as_usize();
194
195    // For sliced arrays, only convert the visible bytes, not the entire input
196    // buffer.
197    let mut out = Vec::with_capacity(last_offset - first_offset);
198
199    for window in offsets.windows(2) {
200        let start = window[0].as_usize();
201        let end = window[1].as_usize();
202
203        let mut prev_is_alnum = false;
204        for &b in &src[start..end] {
205            let converted = if prev_is_alnum {
206                b.to_ascii_lowercase()
207            } else {
208                b.to_ascii_uppercase()
209            };
210            out.push(converted);
211            prev_is_alnum = b.is_ascii_alphanumeric();
212        }
213    }
214
215    let values = Buffer::from_vec(out);
216    let out_offsets = if first_offset == 0 {
217        offsets.clone()
218    } else {
219        // For sliced arrays, we need to rebase the offsets to reflect that the
220        // output only contains the bytes in the visible slice.
221        let rebased_offsets = offsets
222            .iter()
223            .map(|offset| T::usize_as(offset.as_usize() - first_offset))
224            .collect::<Vec<_>>();
225        OffsetBuffer::<T>::new(rebased_offsets.into())
226    };
227
228    // SAFETY: ASCII case conversion preserves byte length, so the original
229    // string boundaries are preserved. `out_offsets` is either identical to
230    // the input offsets or a rebased version relative to the compacted values
231    // buffer.
232    Arc::new(unsafe {
233        GenericStringArray::<T>::new_unchecked(
234            out_offsets,
235            values,
236            string_array.nulls().cloned(),
237        )
238    })
239}
240
241fn initcap_utf8view(args: &[ArrayRef]) -> Result<ArrayRef> {
242    let string_view_array = as_string_view_array(&args[0])?;
243    let mut builder = StringViewBuilder::with_capacity(string_view_array.len());
244    let mut container = String::new();
245
246    string_view_array.iter().for_each(|str| match str {
247        Some(s) => {
248            initcap_string(s, &mut container);
249            builder.append_value(&container);
250        }
251        None => builder.append_null(),
252    });
253
254    Ok(Arc::new(builder.finish()) as ArrayRef)
255}
256
/// Writes the initcap form of `input` into `container`, replacing whatever it
/// held before.
///
/// A character is upper-cased when it is the first character or follows a
/// non-alphanumeric character; otherwise it is lower-cased. All-ASCII input is
/// handled byte-wise. Other input is processed per `char` with Unicode case
/// mappings (which may expand one char into several).
fn initcap_string(input: &str, container: &mut String) {
    container.clear();

    if !input.is_ascii() {
        let mut in_word = false;
        for ch in input.chars() {
            if in_word {
                container.extend(ch.to_lowercase());
            } else {
                container.extend(ch.to_uppercase());
            }
            in_word = ch.is_alphanumeric();
        }
        return;
    }

    container.reserve(input.len());
    // SAFETY: `input` is entirely ASCII and ASCII case mapping yields only
    // ASCII bytes, so `container` remains valid UTF-8.
    let bytes = unsafe { container.as_mut_vec() };
    let mut in_word = false;
    bytes.extend(input.bytes().map(|b| {
        let cased = if in_word {
            b.to_ascii_lowercase()
        } else {
            b.to_ascii_uppercase()
        };
        in_word = b.is_ascii_alphanumeric();
        cased
    }));
}
284
#[cfg(test)]
mod tests {
    use crate::unicode::initcap::InitcapFunc;
    use crate::utils::test::test_function;
    use arrow::array::{Array, ArrayRef, LargeStringArray, StringArray, StringViewArray};
    use arrow::datatypes::DataType::{LargeUtf8, Utf8, Utf8View};
    use datafusion_common::{Result, ScalarValue};
    use datafusion_expr::{ColumnarValue, ScalarUDFImpl};
    use std::sync::Arc;

    #[test]
    fn test_functions() -> Result<()> {
        test_function!(
            InitcapFunc::new(),
            vec![ColumnarValue::Scalar(ScalarValue::from("hi THOMAS"))],
            Ok(Some("Hi Thomas")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            InitcapFunc::new(),
            vec![ColumnarValue::Scalar(ScalarValue::Utf8(Some(
                "êM ả ñAnDÚ ÁrBOL ОлЕГ ИвАНОВИч ÍslENsku ÞjóðaRiNNaR εΛλΗΝΙκΉ"
                    .to_string()
            )))],
            Ok(Some(
                "Êm Ả Ñandú Árbol Олег Иванович Íslensku Þjóðarinnar Ελληνική"
            )),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            InitcapFunc::new(),
            vec![ColumnarValue::Scalar(ScalarValue::from(""))],
            Ok(Some("")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            InitcapFunc::new(),
            vec![ColumnarValue::Scalar(ScalarValue::Utf8(None))],
            Ok(None),
            &str,
            Utf8,
            StringArray
        );

        // LargeUtf8 scalars take their own branch of the scalar fast path.
        test_function!(
            InitcapFunc::new(),
            vec![ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(
                "hi THOMAS".to_string()
            )))],
            Ok(Some("Hi Thomas")),
            &str,
            LargeUtf8,
            LargeStringArray
        );
        test_function!(
            InitcapFunc::new(),
            vec![ColumnarValue::Scalar(ScalarValue::LargeUtf8(None))],
            Ok(None),
            &str,
            LargeUtf8,
            LargeStringArray
        );

        test_function!(
            InitcapFunc::new(),
            vec![ColumnarValue::Scalar(ScalarValue::Utf8View(Some(
                "hi THOMAS".to_string()
            )))],
            Ok(Some("Hi Thomas")),
            &str,
            Utf8View,
            StringViewArray
        );
        test_function!(
            InitcapFunc::new(),
            vec![ColumnarValue::Scalar(ScalarValue::Utf8View(Some(
                "hi THOMAS wIth M0re ThAN 12 ChaRs".to_string()
            )))],
            Ok(Some("Hi Thomas With M0re Than 12 Chars")),
            &str,
            Utf8View,
            StringViewArray
        );
        test_function!(
            InitcapFunc::new(),
            vec![ColumnarValue::Scalar(ScalarValue::Utf8View(Some(
                "đẸp đẼ êM ả ñAnDÚ ÁrBOL ОлЕГ ИвАНОВИч ÍslENsku ÞjóðaRiNNaR εΛλΗΝΙκΉ"
                    .to_string()
            )))],
            Ok(Some(
                "Đẹp Đẽ Êm Ả Ñandú Árbol Олег Иванович Íslensku Þjóðarinnar Ελληνική"
            )),
            &str,
            Utf8View,
            StringViewArray
        );
        test_function!(
            InitcapFunc::new(),
            vec![ColumnarValue::Scalar(ScalarValue::Utf8View(Some(
                "".to_string()
            )))],
            Ok(Some("")),
            &str,
            Utf8View,
            StringViewArray
        );
        test_function!(
            InitcapFunc::new(),
            vec![ColumnarValue::Scalar(ScalarValue::Utf8View(None))],
            Ok(None),
            &str,
            Utf8View,
            StringViewArray
        );

        Ok(())
    }

    #[test]
    fn test_initcap_ascii_array() -> Result<()> {
        let array = StringArray::from(vec![
            Some("hello world"),
            None,
            Some("foo-bar_baz/baX"),
            Some(""),
            Some("123 abc 456DEF"),
            Some("ALL CAPS"),
            Some("already correct"),
        ]);
        let args: Vec<ArrayRef> = vec![Arc::new(array)];
        let result = super::initcap::<i32>(&args)?;
        let result = result.as_any().downcast_ref::<StringArray>().unwrap();

        assert_eq!(result.len(), 7);
        assert_eq!(result.value(0), "Hello World");
        assert!(result.is_null(1));
        assert_eq!(result.value(2), "Foo-Bar_Baz/Bax");
        assert_eq!(result.value(3), "");
        assert_eq!(result.value(4), "123 Abc 456def");
        assert_eq!(result.value(5), "All Caps");
        assert_eq!(result.value(6), "Already Correct");
        Ok(())
    }

    #[test]
    fn test_initcap_ascii_large_array() -> Result<()> {
        let array = LargeStringArray::from(vec![
            Some("hello world"),
            None,
            Some("foo-bar_baz/baX"),
            Some(""),
            Some("123 abc 456DEF"),
            Some("ALL CAPS"),
            Some("already correct"),
        ]);
        let args: Vec<ArrayRef> = vec![Arc::new(array)];
        let result = super::initcap::<i64>(&args)?;
        let result = result.as_any().downcast_ref::<LargeStringArray>().unwrap();

        assert_eq!(result.len(), 7);
        assert_eq!(result.value(0), "Hello World");
        assert!(result.is_null(1));
        assert_eq!(result.value(2), "Foo-Bar_Baz/Bax");
        assert_eq!(result.value(3), "");
        assert_eq!(result.value(4), "123 Abc 456def");
        assert_eq!(result.value(5), "All Caps");
        assert_eq!(result.value(6), "Already Correct");
        Ok(())
    }

    /// Test that initcap works correctly on a sliced ASCII StringArray.
    #[test]
    fn test_initcap_sliced_ascii_array() -> Result<()> {
        let array = StringArray::from(vec![
            Some("hello world"),
            Some("foo bar"),
            Some("baz qux"),
        ]);
        // Slice to get only the last two elements. The resulting array's
        // offsets are [11, 18, 25] (non-zero start), but value_data still
        // contains the full original buffer.
        let sliced = array.slice(1, 2);
        let args: Vec<ArrayRef> = vec![Arc::new(sliced)];
        let result = super::initcap::<i32>(&args)?;
        let result = result.as_any().downcast_ref::<StringArray>().unwrap();

        assert_eq!(result.len(), 2);
        assert_eq!(result.value(0), "Foo Bar");
        assert_eq!(result.value(1), "Baz Qux");

        // The output values buffer should be compact
        assert_eq!(*result.offsets().first().unwrap(), 0);
        assert_eq!(
            result.value_data().len(),
            *result.offsets().last().unwrap() as usize
        );
        Ok(())
    }

    /// Test that initcap works correctly on a sliced ASCII LargeStringArray.
    #[test]
    fn test_initcap_sliced_ascii_large_array() -> Result<()> {
        let array = LargeStringArray::from(vec![
            Some("hello world"),
            Some("foo bar"),
            Some("baz qux"),
        ]);
        // Slice to get only the last two elements. The resulting array's
        // offsets are [11, 18, 25] (non-zero start), but value_data still
        // contains the full original buffer.
        let sliced = array.slice(1, 2);
        let args: Vec<ArrayRef> = vec![Arc::new(sliced)];
        let result = super::initcap::<i64>(&args)?;
        let result = result.as_any().downcast_ref::<LargeStringArray>().unwrap();

        assert_eq!(result.len(), 2);
        assert_eq!(result.value(0), "Foo Bar");
        assert_eq!(result.value(1), "Baz Qux");

        // The output values buffer should be compact
        assert_eq!(*result.offsets().first().unwrap(), 0);
        assert_eq!(
            result.value_data().len(),
            *result.offsets().last().unwrap() as usize
        );
        Ok(())
    }

    /// Test the non-ASCII (per-row builder) path on a StringArray with nulls.
    #[test]
    fn test_initcap_non_ascii_array() -> Result<()> {
        let array = StringArray::from(vec![
            Some("hello WÖRLD"),
            None,
            Some("ñAnDÚ ÁrBOL"),
            Some(""),
        ]);
        let args: Vec<ArrayRef> = vec![Arc::new(array)];
        let result = super::initcap::<i32>(&args)?;
        let result = result.as_any().downcast_ref::<StringArray>().unwrap();

        assert_eq!(result.len(), 4);
        assert_eq!(result.value(0), "Hello Wörld");
        assert!(result.is_null(1));
        assert_eq!(result.value(2), "Ñandú Árbol");
        assert_eq!(result.value(3), "");
        Ok(())
    }

    /// Test the non-ASCII path on a sliced StringArray.
    #[test]
    fn test_initcap_sliced_non_ascii_array() -> Result<()> {
        let array = StringArray::from(vec![Some("ábc"), Some("déF ghÍ"), None]);
        let sliced = array.slice(1, 2);
        let args: Vec<ArrayRef> = vec![Arc::new(sliced)];
        let result = super::initcap::<i32>(&args)?;
        let result = result.as_any().downcast_ref::<StringArray>().unwrap();

        assert_eq!(result.len(), 2);
        assert_eq!(result.value(0), "Déf Ghí");
        assert!(result.is_null(1));
        Ok(())
    }

    /// Test the Utf8View array path, including nulls and long (non-inlined)
    /// values.
    #[test]
    fn test_initcap_utf8view_array() -> Result<()> {
        let array = StringViewArray::from(vec![
            Some("hi THOMAS"),
            None,
            Some("a string LONGER than twelve bytes"),
            Some("êM ả"),
        ]);
        let args: Vec<ArrayRef> = vec![Arc::new(array)];
        let result = super::initcap_utf8view(&args)?;
        let result = result.as_any().downcast_ref::<StringViewArray>().unwrap();

        assert_eq!(result.len(), 4);
        assert_eq!(result.value(0), "Hi Thomas");
        assert!(result.is_null(1));
        assert_eq!(result.value(2), "A String Longer Than Twelve Bytes");
        assert_eq!(result.value(3), "Êm Ả");
        Ok(())
    }
}