datafusion_functions/unicode/
initcap.rs1use std::any::Any;
19use std::sync::Arc;
20
21use arrow::array::{
22 Array, ArrayRef, GenericStringArray, GenericStringBuilder, OffsetSizeTrait,
23 StringViewBuilder,
24};
25use arrow::buffer::{Buffer, OffsetBuffer};
26use arrow::datatypes::DataType;
27
28use crate::utils::{make_scalar_function, utf8_to_str_type};
29use datafusion_common::cast::{as_generic_string_array, as_string_view_array};
30use datafusion_common::types::logical_string;
31use datafusion_common::{Result, ScalarValue, exec_err};
32use datafusion_expr::{
33 Coercion, ColumnarValue, Documentation, ScalarUDFImpl, Signature, TypeSignatureClass,
34 Volatility,
35};
36use datafusion_macros::user_doc;
37
38#[user_doc(
39 doc_section(label = "String Functions"),
40 description = "Capitalizes the first character in each word in the input string. \
41 Words are delimited by non-alphanumeric characters.",
42 syntax_example = "initcap(str)",
43 sql_example = r#"```sql
44> select initcap('apache datafusion');
45+------------------------------------+
46| initcap(Utf8("apache datafusion")) |
47+------------------------------------+
48| Apache Datafusion |
49+------------------------------------+
50```"#,
51 standard_argument(name = "str", prefix = "String"),
52 related_udf(name = "lower"),
53 related_udf(name = "upper")
54)]
55#[derive(Debug, PartialEq, Eq, Hash)]
56pub struct InitcapFunc {
57 signature: Signature,
58}
59
60impl Default for InitcapFunc {
61 fn default() -> Self {
62 InitcapFunc::new()
63 }
64}
65
66impl InitcapFunc {
67 pub fn new() -> Self {
68 Self {
69 signature: Signature::coercible(
70 vec![Coercion::new_exact(TypeSignatureClass::Native(
71 logical_string(),
72 ))],
73 Volatility::Immutable,
74 ),
75 }
76 }
77}
78
79impl ScalarUDFImpl for InitcapFunc {
80 fn as_any(&self) -> &dyn Any {
81 self
82 }
83
84 fn name(&self) -> &str {
85 "initcap"
86 }
87
88 fn signature(&self) -> &Signature {
89 &self.signature
90 }
91
92 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
93 if let DataType::Utf8View = arg_types[0] {
94 Ok(DataType::Utf8View)
95 } else {
96 utf8_to_str_type(&arg_types[0], "initcap")
97 }
98 }
99
100 fn invoke_with_args(
101 &self,
102 args: datafusion_expr::ScalarFunctionArgs,
103 ) -> Result<ColumnarValue> {
104 let arg = &args.args[0];
105
106 if let ColumnarValue::Scalar(scalar) = arg {
108 return match scalar {
109 ScalarValue::Utf8(None)
110 | ScalarValue::LargeUtf8(None)
111 | ScalarValue::Utf8View(None) => Ok(arg.clone()),
112 ScalarValue::Utf8(Some(s)) => {
113 let mut result = String::new();
114 initcap_string(s, &mut result);
115 Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(result))))
116 }
117 ScalarValue::LargeUtf8(Some(s)) => {
118 let mut result = String::new();
119 initcap_string(s, &mut result);
120 Ok(ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(result))))
121 }
122 ScalarValue::Utf8View(Some(s)) => {
123 let mut result = String::new();
124 initcap_string(s, &mut result);
125 Ok(ColumnarValue::Scalar(ScalarValue::Utf8View(Some(result))))
126 }
127 other => {
128 exec_err!(
129 "Unsupported data type {:?} for function `initcap`",
130 other.data_type()
131 )
132 }
133 };
134 }
135
136 let args = &args.args;
138 match args[0].data_type() {
139 DataType::Utf8 => make_scalar_function(initcap::<i32>, vec![])(args),
140 DataType::LargeUtf8 => make_scalar_function(initcap::<i64>, vec![])(args),
141 DataType::Utf8View => make_scalar_function(initcap_utf8view, vec![])(args),
142 other => {
143 exec_err!("Unsupported data type {other:?} for function `initcap`")
144 }
145 }
146 }
147
148 fn documentation(&self) -> Option<&Documentation> {
149 self.doc()
150 }
151}
152
153fn initcap<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
162 let string_array = as_generic_string_array::<T>(&args[0])?;
163
164 if string_array.is_ascii() {
165 return Ok(initcap_ascii_array(string_array));
166 }
167
168 let mut builder = GenericStringBuilder::<T>::with_capacity(
169 string_array.len(),
170 string_array.value_data().len(),
171 );
172
173 let mut container = String::new();
174 string_array.iter().for_each(|str| match str {
175 Some(s) => {
176 initcap_string(s, &mut container);
177 builder.append_value(&container);
178 }
179 None => builder.append_null(),
180 });
181
182 Ok(Arc::new(builder.finish()) as ArrayRef)
183}
184
185fn initcap_ascii_array<T: OffsetSizeTrait>(
188 string_array: &GenericStringArray<T>,
189) -> ArrayRef {
190 let offsets = string_array.offsets();
191 let src = string_array.value_data();
192 let first_offset = offsets.first().unwrap().as_usize();
193 let last_offset = offsets.last().unwrap().as_usize();
194
195 let mut out = Vec::with_capacity(last_offset - first_offset);
198
199 for window in offsets.windows(2) {
200 let start = window[0].as_usize();
201 let end = window[1].as_usize();
202
203 let mut prev_is_alnum = false;
204 for &b in &src[start..end] {
205 let converted = if prev_is_alnum {
206 b.to_ascii_lowercase()
207 } else {
208 b.to_ascii_uppercase()
209 };
210 out.push(converted);
211 prev_is_alnum = b.is_ascii_alphanumeric();
212 }
213 }
214
215 let values = Buffer::from_vec(out);
216 let out_offsets = if first_offset == 0 {
217 offsets.clone()
218 } else {
219 let rebased_offsets = offsets
222 .iter()
223 .map(|offset| T::usize_as(offset.as_usize() - first_offset))
224 .collect::<Vec<_>>();
225 OffsetBuffer::<T>::new(rebased_offsets.into())
226 };
227
228 Arc::new(unsafe {
233 GenericStringArray::<T>::new_unchecked(
234 out_offsets,
235 values,
236 string_array.nulls().cloned(),
237 )
238 })
239}
240
241fn initcap_utf8view(args: &[ArrayRef]) -> Result<ArrayRef> {
242 let string_view_array = as_string_view_array(&args[0])?;
243 let mut builder = StringViewBuilder::with_capacity(string_view_array.len());
244 let mut container = String::new();
245
246 string_view_array.iter().for_each(|str| match str {
247 Some(s) => {
248 initcap_string(s, &mut container);
249 builder.append_value(&container);
250 }
251 None => builder.append_null(),
252 });
253
254 Ok(Arc::new(builder.finish()) as ArrayRef)
255}
256
/// Writes the `initcap` form of `input` into `container`, replacing whatever
/// the buffer held before.
///
/// A character is upper-cased when it is the first character or follows a
/// non-alphanumeric character, and lower-cased otherwise. ASCII-only input is
/// processed byte by byte; any other input goes through Unicode case mapping,
/// where a single character may expand to several.
fn initcap_string(input: &str, container: &mut String) {
    container.clear();
    let mut after_alnum = false;

    if !input.is_ascii() {
        for ch in input.chars() {
            if after_alnum {
                container.extend(ch.to_lowercase());
            } else {
                container.extend(ch.to_uppercase());
            }
            after_alnum = ch.is_alphanumeric();
        }
        return;
    }

    container.reserve(input.len());
    // SAFETY: `input` is ASCII, and ASCII case mapping yields ASCII, so every
    // byte appended below is a complete UTF-8 code point and `container`
    // remains valid UTF-8 at all times.
    let bytes = unsafe { container.as_mut_vec() };
    bytes.extend(input.bytes().map(|byte| {
        let mapped = if after_alnum {
            byte.to_ascii_lowercase()
        } else {
            byte.to_ascii_uppercase()
        };
        after_alnum = byte.is_ascii_alphanumeric();
        mapped
    }));
}
284
#[cfg(test)]
mod tests {
    use crate::unicode::initcap::InitcapFunc;
    use crate::utils::test::test_function;
    use arrow::array::{
        Array, ArrayRef, GenericStringArray, OffsetSizeTrait, StringArray,
        StringViewArray,
    };
    use arrow::datatypes::DataType::{Utf8, Utf8View};
    use datafusion_common::{Result, ScalarValue};
    use datafusion_expr::{ColumnarValue, ScalarUDFImpl};
    use std::sync::Arc;

    /// Scalar inputs for `Utf8` and `Utf8View`: ASCII, non-ASCII, empty and null.
    #[test]
    fn test_functions() -> Result<()> {
        test_function!(
            InitcapFunc::new(),
            vec![ColumnarValue::Scalar(ScalarValue::from("hi THOMAS"))],
            Ok(Some("Hi Thomas")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            InitcapFunc::new(),
            vec![ColumnarValue::Scalar(ScalarValue::Utf8(Some(
                "êM ả ñAnDÚ ÁrBOL ОлЕГ ИвАНОВИч ÍslENsku ÞjóðaRiNNaR εΛλΗΝΙκΉ"
                    .to_string()
            )))],
            Ok(Some(
                "Êm Ả Ñandú Árbol Олег Иванович Íslensku Þjóðarinnar Ελληνική"
            )),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            InitcapFunc::new(),
            vec![ColumnarValue::Scalar(ScalarValue::from(""))],
            Ok(Some("")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            InitcapFunc::new(),
            vec![ColumnarValue::Scalar(ScalarValue::Utf8(None))],
            Ok(None),
            &str,
            Utf8,
            StringArray
        );

        test_function!(
            InitcapFunc::new(),
            vec![ColumnarValue::Scalar(ScalarValue::Utf8View(Some(
                "hi THOMAS".to_string()
            )))],
            Ok(Some("Hi Thomas")),
            &str,
            Utf8View,
            StringViewArray
        );
        test_function!(
            InitcapFunc::new(),
            vec![ColumnarValue::Scalar(ScalarValue::Utf8View(Some(
                "hi THOMAS wIth M0re ThAN 12 ChaRs".to_string()
            )))],
            Ok(Some("Hi Thomas With M0re Than 12 Chars")),
            &str,
            Utf8View,
            StringViewArray
        );
        test_function!(
            InitcapFunc::new(),
            vec![ColumnarValue::Scalar(ScalarValue::Utf8View(Some(
                "đẸp đẼ êM ả ñAnDÚ ÁrBOL ОлЕГ ИвАНОВИч ÍslENsku ÞjóðaRiNNaR εΛλΗΝΙκΉ"
                    .to_string()
            )))],
            Ok(Some(
                "Đẹp Đẽ Êm Ả Ñandú Árbol Олег Иванович Íslensku Þjóðarinnar Ελληνική"
            )),
            &str,
            Utf8View,
            StringViewArray
        );
        test_function!(
            InitcapFunc::new(),
            vec![ColumnarValue::Scalar(ScalarValue::Utf8View(Some(
                "".to_string()
            )))],
            Ok(Some("")),
            &str,
            Utf8View,
            StringViewArray
        );
        test_function!(
            InitcapFunc::new(),
            vec![ColumnarValue::Scalar(ScalarValue::Utf8View(None))],
            Ok(None),
            &str,
            Utf8View,
            StringViewArray
        );

        Ok(())
    }

    /// Runs the array kernel on `input` and returns the typed result array.
    fn run_initcap<O: OffsetSizeTrait>(
        input: GenericStringArray<O>,
    ) -> Result<GenericStringArray<O>> {
        let args: Vec<ArrayRef> = vec![Arc::new(input)];
        let output = super::initcap::<O>(&args)?;
        Ok(output
            .as_any()
            .downcast_ref::<GenericStringArray<O>>()
            .unwrap()
            .clone())
    }

    /// All-ASCII array (with a null and an empty value) for offset type `O`.
    fn check_ascii_array<O: OffsetSizeTrait>() -> Result<()> {
        let array = GenericStringArray::<O>::from(vec![
            Some("hello world"),
            None,
            Some("foo-bar_baz/baX"),
            Some(""),
            Some("123 abc 456DEF"),
            Some("ALL CAPS"),
            Some("already correct"),
        ]);
        let result = run_initcap(array)?;

        assert_eq!(result.len(), 7);
        assert_eq!(result.value(0), "Hello World");
        assert!(result.is_null(1));
        assert_eq!(result.value(2), "Foo-Bar_Baz/Bax");
        assert_eq!(result.value(3), "");
        assert_eq!(result.value(4), "123 Abc 456def");
        assert_eq!(result.value(5), "All Caps");
        assert_eq!(result.value(6), "Already Correct");
        Ok(())
    }

    /// Sliced all-ASCII array for offset type `O`: the result must be rebased
    /// to offset zero and carry no bytes beyond its last offset.
    fn check_sliced_ascii_array<O: OffsetSizeTrait>() -> Result<()> {
        let array = GenericStringArray::<O>::from(vec![
            Some("hello world"),
            Some("foo bar"),
            Some("baz qux"),
        ]);
        // Skipping the first row leaves the slice with a non-zero first offset.
        let result = run_initcap(array.slice(1, 2))?;

        assert_eq!(result.len(), 2);
        assert_eq!(result.value(0), "Foo Bar");
        assert_eq!(result.value(1), "Baz Qux");

        assert_eq!(result.offsets().first().unwrap().as_usize(), 0);
        assert_eq!(
            result.value_data().len(),
            result.offsets().last().unwrap().as_usize()
        );
        Ok(())
    }

    #[test]
    fn test_initcap_ascii_array() -> Result<()> {
        check_ascii_array::<i32>()
    }

    #[test]
    fn test_initcap_ascii_large_array() -> Result<()> {
        check_ascii_array::<i64>()
    }

    #[test]
    fn test_initcap_sliced_ascii_array() -> Result<()> {
        check_sliced_ascii_array::<i32>()
    }

    #[test]
    fn test_initcap_sliced_ascii_large_array() -> Result<()> {
        check_sliced_ascii_array::<i64>()
    }
}