Skip to main content

datafusion_functions/unicode/
initcap.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::sync::Arc;
19
20use arrow::array::{Array, ArrayRef, GenericStringArray, OffsetSizeTrait};
21use arrow::buffer::{Buffer, OffsetBuffer};
22use arrow::datatypes::DataType;
23
24use crate::strings::{GenericStringArrayBuilder, StringViewArrayBuilder};
25use crate::utils::{make_scalar_function, utf8_to_str_type};
26use datafusion_common::cast::{as_generic_string_array, as_string_view_array};
27use datafusion_common::types::logical_string;
28use datafusion_common::{Result, ScalarValue, exec_err};
29use datafusion_expr::{
30    Coercion, ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
31    TypeSignatureClass, Volatility,
32};
33use datafusion_macros::user_doc;
34
35#[user_doc(
36    doc_section(label = "String Functions"),
37    description = "Capitalizes the first character in each word in the input string. \
38            Words are delimited by non-alphanumeric characters.",
39    syntax_example = "initcap(str)",
40    sql_example = r#"```sql
41> select initcap('apache datafusion');
42+------------------------------------+
43| initcap(Utf8("apache datafusion")) |
44+------------------------------------+
45| Apache Datafusion                  |
46+------------------------------------+
47```"#,
48    standard_argument(name = "str", prefix = "String"),
49    related_udf(name = "lower"),
50    related_udf(name = "upper")
51)]
52#[derive(Debug, PartialEq, Eq, Hash)]
53pub struct InitcapFunc {
54    signature: Signature,
55}
56
57impl Default for InitcapFunc {
58    fn default() -> Self {
59        InitcapFunc::new()
60    }
61}
62
63impl InitcapFunc {
64    pub fn new() -> Self {
65        Self {
66            signature: Signature::coercible(
67                vec![Coercion::new_exact(TypeSignatureClass::Native(
68                    logical_string(),
69                ))],
70                Volatility::Immutable,
71            ),
72        }
73    }
74}
75
76impl ScalarUDFImpl for InitcapFunc {
77    fn name(&self) -> &str {
78        "initcap"
79    }
80
81    fn signature(&self) -> &Signature {
82        &self.signature
83    }
84
85    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
86        if let DataType::Utf8View = arg_types[0] {
87            Ok(DataType::Utf8View)
88        } else {
89            utf8_to_str_type(&arg_types[0], "initcap")
90        }
91    }
92
93    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
94        let arg = &args.args[0];
95
96        // Scalar fast path - handle directly without array conversion
97        if let ColumnarValue::Scalar(scalar) = arg {
98            return match scalar {
99                ScalarValue::Utf8(None)
100                | ScalarValue::LargeUtf8(None)
101                | ScalarValue::Utf8View(None) => Ok(arg.clone()),
102                ScalarValue::Utf8(Some(s)) => {
103                    let mut result = String::new();
104                    initcap_string(s, &mut result);
105                    Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(result))))
106                }
107                ScalarValue::LargeUtf8(Some(s)) => {
108                    let mut result = String::new();
109                    initcap_string(s, &mut result);
110                    Ok(ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(result))))
111                }
112                ScalarValue::Utf8View(Some(s)) => {
113                    let mut result = String::new();
114                    initcap_string(s, &mut result);
115                    Ok(ColumnarValue::Scalar(ScalarValue::Utf8View(Some(result))))
116                }
117                other => {
118                    exec_err!(
119                        "Unsupported data type {:?} for function `initcap`",
120                        other.data_type()
121                    )
122                }
123            };
124        }
125
126        // Array path
127        let args = &args.args;
128        match args[0].data_type() {
129            DataType::Utf8 => make_scalar_function(initcap::<i32>, vec![])(args),
130            DataType::LargeUtf8 => make_scalar_function(initcap::<i64>, vec![])(args),
131            DataType::Utf8View => make_scalar_function(initcap_utf8view, vec![])(args),
132            other => {
133                exec_err!("Unsupported data type {other:?} for function `initcap`")
134            }
135        }
136    }
137
138    fn documentation(&self) -> Option<&Documentation> {
139        self.doc()
140    }
141}
142
143/// Converts the first letter of each word to uppercase and the rest to
144/// lowercase. Words are sequences of alphanumeric characters separated by
145/// non-alphanumeric characters.
146///
147/// Example:
148/// ```sql
149/// initcap('hi THOMAS') = 'Hi Thomas'
150/// ```
151fn initcap<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
152    let string_array = as_generic_string_array::<T>(&args[0])?;
153
154    if string_array.is_ascii() {
155        return Ok(initcap_ascii_array(string_array));
156    }
157
158    let len = string_array.len();
159    let mut builder = GenericStringArrayBuilder::<T>::with_capacity(
160        len,
161        string_array.value_data().len(),
162    );
163
164    let mut container = String::new();
165    let nulls = string_array.nulls().cloned();
166    if let Some(ref n) = nulls {
167        for i in 0..len {
168            if n.is_null(i) {
169                builder.append_placeholder();
170            } else {
171                // SAFETY: not null per check above.
172                let s = unsafe { string_array.value_unchecked(i) };
173                initcap_string(s, &mut container);
174                builder.append_value(&container);
175            }
176        }
177    } else {
178        for i in 0..len {
179            // SAFETY: no null buffer means every index is valid.
180            let s = unsafe { string_array.value_unchecked(i) };
181            initcap_string(s, &mut container);
182            builder.append_value(&container);
183        }
184    }
185
186    Ok(Arc::new(builder.finish(nulls)?) as ArrayRef)
187}
188
189/// Fast path for `Utf8` or `LargeUtf8` arrays that are ASCII-only. We can use a
190/// single pass over the buffer and operate directly on bytes.
191fn initcap_ascii_array<T: OffsetSizeTrait>(
192    string_array: &GenericStringArray<T>,
193) -> ArrayRef {
194    let offsets = string_array.offsets();
195    let src = string_array.value_data();
196    let first_offset = offsets.first().unwrap().as_usize();
197    let last_offset = offsets.last().unwrap().as_usize();
198
199    // For sliced arrays, only convert the visible bytes, not the entire input
200    // buffer.
201    let mut out = Vec::with_capacity(last_offset - first_offset);
202
203    for window in offsets.windows(2) {
204        let start = window[0].as_usize();
205        let end = window[1].as_usize();
206
207        let mut prev_is_alnum = false;
208        for &b in &src[start..end] {
209            let converted = if prev_is_alnum {
210                b.to_ascii_lowercase()
211            } else {
212                b.to_ascii_uppercase()
213            };
214            out.push(converted);
215            prev_is_alnum = b.is_ascii_alphanumeric();
216        }
217    }
218
219    let values = Buffer::from_vec(out);
220    let out_offsets = if first_offset == 0 {
221        offsets.clone()
222    } else {
223        // For sliced arrays, we need to rebase the offsets to reflect that the
224        // output only contains the bytes in the visible slice.
225        let rebased_offsets = offsets
226            .iter()
227            .map(|offset| T::usize_as(offset.as_usize() - first_offset))
228            .collect::<Vec<_>>();
229        OffsetBuffer::<T>::new(rebased_offsets.into())
230    };
231
232    // SAFETY: ASCII case conversion preserves byte length, so the original
233    // string boundaries are preserved. `out_offsets` is either identical to
234    // the input offsets or a rebased version relative to the compacted values
235    // buffer.
236    Arc::new(unsafe {
237        GenericStringArray::<T>::new_unchecked(
238            out_offsets,
239            values,
240            string_array.nulls().cloned(),
241        )
242    })
243}
244
245fn initcap_utf8view(args: &[ArrayRef]) -> Result<ArrayRef> {
246    let string_view_array = as_string_view_array(&args[0])?;
247    let len = string_view_array.len();
248    let mut builder = StringViewArrayBuilder::with_capacity(len);
249    let mut container = String::new();
250
251    let nulls = string_view_array.nulls().cloned();
252    if let Some(ref n) = nulls {
253        for i in 0..len {
254            if n.is_null(i) {
255                builder.append_placeholder();
256            } else {
257                // SAFETY: not null per check above.
258                let s = unsafe { string_view_array.value_unchecked(i) };
259                initcap_string(s, &mut container);
260                builder.append_value(&container);
261            }
262        }
263    } else {
264        for i in 0..len {
265            // SAFETY: no null buffer means every index is valid.
266            let s = unsafe { string_view_array.value_unchecked(i) };
267            initcap_string(s, &mut container);
268            builder.append_value(&container);
269        }
270    }
271
272    Ok(Arc::new(builder.finish(nulls)?) as ArrayRef)
273}
274
/// Writes the `initcap` form of `input` into `container`, replacing whatever
/// `container` held before.
///
/// A character is upper-cased when it does not directly follow an
/// alphanumeric character, and lower-cased otherwise. All-ASCII input is
/// mapped byte by byte. Anything else goes through Unicode `char` case
/// mapping, where one character may expand to several (e.g. `'ß'`
/// upper-cases to `"SS"`).
fn initcap_string(input: &str, container: &mut String) {
    container.clear();

    if !input.is_ascii() {
        let mut after_alnum = false;
        for ch in input.chars() {
            if after_alnum {
                container.extend(ch.to_lowercase());
            } else {
                container.extend(ch.to_uppercase());
            }
            after_alnum = ch.is_alphanumeric();
        }
        return;
    }

    // ASCII case mapping is one byte in, one byte out, so the final length is
    // known up front.
    container.reserve(input.len());
    let mut after_alnum = false;
    for byte in input.bytes() {
        let mapped = if after_alnum {
            byte.to_ascii_lowercase()
        } else {
            byte.to_ascii_uppercase()
        };
        container.push(char::from(mapped));
        after_alnum = byte.is_ascii_alphanumeric();
    }
}
302
#[cfg(test)]
mod tests {
    use crate::unicode::initcap::InitcapFunc;
    use crate::utils::test::test_function;
    use arrow::array::{Array, ArrayRef, LargeStringArray, StringArray, StringViewArray};
    use arrow::datatypes::DataType::{LargeUtf8, Utf8, Utf8View};
    use datafusion_common::{Result, ScalarValue};
    use datafusion_expr::{ColumnarValue, ScalarUDFImpl};
    use std::sync::Arc;

    #[test]
    fn test_functions() -> Result<()> {
        test_function!(
            InitcapFunc::new(),
            vec![ColumnarValue::Scalar(ScalarValue::from("hi THOMAS"))],
            Ok(Some("Hi Thomas")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            InitcapFunc::new(),
            vec![ColumnarValue::Scalar(ScalarValue::Utf8(Some(
                "êM ả ñAnDÚ ÁrBOL ОлЕГ ИвАНОВИч ÍslENsku ÞjóðaRiNNaR εΛλΗΝΙκΉ"
                    .to_string()
            )))],
            Ok(Some(
                "Êm Ả Ñandú Árbol Олег Иванович Íslensku Þjóðarinnar Ελληνική"
            )),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            InitcapFunc::new(),
            vec![ColumnarValue::Scalar(ScalarValue::from(""))],
            Ok(Some("")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            InitcapFunc::new(),
            vec![ColumnarValue::Scalar(ScalarValue::Utf8(None))],
            Ok(None),
            &str,
            Utf8,
            StringArray
        );

        // LargeUtf8 scalars take their own arm of the scalar fast path.
        test_function!(
            InitcapFunc::new(),
            vec![ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(
                "hi THOMAS".to_string()
            )))],
            Ok(Some("Hi Thomas")),
            &str,
            LargeUtf8,
            LargeStringArray
        );
        test_function!(
            InitcapFunc::new(),
            vec![ColumnarValue::Scalar(ScalarValue::LargeUtf8(None))],
            Ok(None),
            &str,
            LargeUtf8,
            LargeStringArray
        );

        test_function!(
            InitcapFunc::new(),
            vec![ColumnarValue::Scalar(ScalarValue::Utf8View(Some(
                "hi THOMAS".to_string()
            )))],
            Ok(Some("Hi Thomas")),
            &str,
            Utf8View,
            StringViewArray
        );
        test_function!(
            InitcapFunc::new(),
            vec![ColumnarValue::Scalar(ScalarValue::Utf8View(Some(
                "hi THOMAS wIth M0re ThAN 12 ChaRs".to_string()
            )))],
            Ok(Some("Hi Thomas With M0re Than 12 Chars")),
            &str,
            Utf8View,
            StringViewArray
        );
        test_function!(
            InitcapFunc::new(),
            vec![ColumnarValue::Scalar(ScalarValue::Utf8View(Some(
                "đẸp đẼ êM ả ñAnDÚ ÁrBOL ОлЕГ ИвАНОВИч ÍslENsku ÞjóðaRiNNaR εΛλΗΝΙκΉ"
                    .to_string()
            )))],
            Ok(Some(
                "Đẹp Đẽ Êm Ả Ñandú Árbol Олег Иванович Íslensku Þjóðarinnar Ελληνική"
            )),
            &str,
            Utf8View,
            StringViewArray
        );
        test_function!(
            InitcapFunc::new(),
            vec![ColumnarValue::Scalar(ScalarValue::Utf8View(Some(
                "".to_string()
            )))],
            Ok(Some("")),
            &str,
            Utf8View,
            StringViewArray
        );
        test_function!(
            InitcapFunc::new(),
            vec![ColumnarValue::Scalar(ScalarValue::Utf8View(None))],
            Ok(None),
            &str,
            Utf8View,
            StringViewArray
        );

        Ok(())
    }

    #[test]
    fn test_initcap_ascii_array() -> Result<()> {
        let array = StringArray::from(vec![
            Some("hello world"),
            None,
            Some("foo-bar_baz/baX"),
            Some(""),
            Some("123 abc 456DEF"),
            Some("ALL CAPS"),
            Some("already correct"),
        ]);
        let args: Vec<ArrayRef> = vec![Arc::new(array)];
        let result = super::initcap::<i32>(&args)?;
        let result = result.as_any().downcast_ref::<StringArray>().unwrap();

        assert_eq!(result.len(), 7);
        assert_eq!(result.value(0), "Hello World");
        assert!(result.is_null(1));
        assert_eq!(result.value(2), "Foo-Bar_Baz/Bax");
        assert_eq!(result.value(3), "");
        assert_eq!(result.value(4), "123 Abc 456def");
        assert_eq!(result.value(5), "All Caps");
        assert_eq!(result.value(6), "Already Correct");
        Ok(())
    }

    #[test]
    fn test_initcap_ascii_large_array() -> Result<()> {
        let array = LargeStringArray::from(vec![
            Some("hello world"),
            None,
            Some("foo-bar_baz/baX"),
            Some(""),
            Some("123 abc 456DEF"),
            Some("ALL CAPS"),
            Some("already correct"),
        ]);
        let args: Vec<ArrayRef> = vec![Arc::new(array)];
        let result = super::initcap::<i64>(&args)?;
        let result = result.as_any().downcast_ref::<LargeStringArray>().unwrap();

        assert_eq!(result.len(), 7);
        assert_eq!(result.value(0), "Hello World");
        assert!(result.is_null(1));
        assert_eq!(result.value(2), "Foo-Bar_Baz/Bax");
        assert_eq!(result.value(3), "");
        assert_eq!(result.value(4), "123 Abc 456def");
        assert_eq!(result.value(5), "All Caps");
        assert_eq!(result.value(6), "Already Correct");
        Ok(())
    }

    /// Test that initcap works correctly on a sliced ASCII StringArray.
    #[test]
    fn test_initcap_sliced_ascii_array() -> Result<()> {
        let array = StringArray::from(vec![
            Some("hello world"),
            Some("foo bar"),
            Some("baz qux"),
        ]);
        // Slice to get only the last two elements. The resulting array's
        // offsets are [11, 18, 25] (non-zero start), but value_data still
        // contains the full original buffer.
        let sliced = array.slice(1, 2);
        let args: Vec<ArrayRef> = vec![Arc::new(sliced)];
        let result = super::initcap::<i32>(&args)?;
        let result = result.as_any().downcast_ref::<StringArray>().unwrap();

        assert_eq!(result.len(), 2);
        assert_eq!(result.value(0), "Foo Bar");
        assert_eq!(result.value(1), "Baz Qux");

        // The output values buffer should be compact
        assert_eq!(*result.offsets().first().unwrap(), 0);
        assert_eq!(
            result.value_data().len(),
            *result.offsets().last().unwrap() as usize
        );
        Ok(())
    }

    /// Test that initcap works correctly on a sliced ASCII LargeStringArray.
    #[test]
    fn test_initcap_sliced_ascii_large_array() -> Result<()> {
        let array = LargeStringArray::from(vec![
            Some("hello world"),
            Some("foo bar"),
            Some("baz qux"),
        ]);
        // Slice to get only the last two elements. The resulting array's
        // offsets are [11, 18, 25] (non-zero start), but value_data still
        // contains the full original buffer.
        let sliced = array.slice(1, 2);
        let args: Vec<ArrayRef> = vec![Arc::new(sliced)];
        let result = super::initcap::<i64>(&args)?;
        let result = result.as_any().downcast_ref::<LargeStringArray>().unwrap();

        assert_eq!(result.len(), 2);
        assert_eq!(result.value(0), "Foo Bar");
        assert_eq!(result.value(1), "Baz Qux");

        // The output values buffer should be compact
        assert_eq!(*result.offsets().first().unwrap(), 0);
        assert_eq!(
            result.value_data().len(),
            *result.offsets().last().unwrap() as usize
        );
        Ok(())
    }

    /// Exercises the row-by-row (non-ASCII) path of `initcap`, including a
    /// null row and an empty string.
    #[test]
    fn test_initcap_non_ascii_array() -> Result<()> {
        let array = StringArray::from(vec![
            Some("êM ả ñAnDÚ"),
            None,
            Some("hello WORLD"),
            Some(""),
        ]);
        let args: Vec<ArrayRef> = vec![Arc::new(array)];
        let result = super::initcap::<i32>(&args)?;
        let result = result.as_any().downcast_ref::<StringArray>().unwrap();

        assert_eq!(result.len(), 4);
        assert_eq!(result.value(0), "Êm Ả Ñandú");
        assert!(result.is_null(1));
        assert_eq!(result.value(2), "Hello World");
        assert_eq!(result.value(3), "");
        Ok(())
    }

    /// Exercises the non-ASCII path on a sliced array (non-zero first offset
    /// and an offset null buffer).
    #[test]
    fn test_initcap_sliced_non_ascii_array() -> Result<()> {
        let array = StringArray::from(vec![
            Some("hello world"),
            Some("ñAnDÚ ÁrBOL"),
            None,
            Some("ОлЕГ ИвАНОВИч"),
        ]);
        let sliced = array.slice(1, 3);
        let args: Vec<ArrayRef> = vec![Arc::new(sliced)];
        let result = super::initcap::<i32>(&args)?;
        let result = result.as_any().downcast_ref::<StringArray>().unwrap();

        assert_eq!(result.len(), 3);
        assert_eq!(result.value(0), "Ñandú Árbol");
        assert!(result.is_null(1));
        assert_eq!(result.value(2), "Олег Иванович");
        Ok(())
    }

    /// Exercises the `Utf8View` array kernel with inline, out-of-line and
    /// null rows.
    #[test]
    fn test_initcap_utf8view_array() -> Result<()> {
        let array = StringViewArray::from(vec![
            Some("hi THOMAS wIth M0re ThAN 12 ChaRs"),
            None,
            Some("đẸp đẼ"),
        ]);
        let args: Vec<ArrayRef> = vec![Arc::new(array)];
        let result = super::initcap_utf8view(&args)?;
        let result = result.as_any().downcast_ref::<StringViewArray>().unwrap();

        assert_eq!(result.len(), 3);
        assert_eq!(result.value(0), "Hi Thomas With M0re Than 12 Chars");
        assert!(result.is_null(1));
        assert_eq!(result.value(2), "Đẹp Đẽ");
        Ok(())
    }
}