datafusion_functions/unicode/
initcap.rs1use std::sync::Arc;
19
20use arrow::array::{Array, ArrayRef, GenericStringArray, OffsetSizeTrait};
21use arrow::buffer::{Buffer, OffsetBuffer};
22use arrow::datatypes::DataType;
23
24use crate::strings::{GenericStringArrayBuilder, StringViewArrayBuilder};
25use crate::utils::{make_scalar_function, utf8_to_str_type};
26use datafusion_common::cast::{as_generic_string_array, as_string_view_array};
27use datafusion_common::types::logical_string;
28use datafusion_common::{Result, ScalarValue, exec_err};
29use datafusion_expr::{
30 Coercion, ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
31 TypeSignatureClass, Volatility,
32};
33use datafusion_macros::user_doc;
34
35#[user_doc(
36 doc_section(label = "String Functions"),
37 description = "Capitalizes the first character in each word in the input string. \
38 Words are delimited by non-alphanumeric characters.",
39 syntax_example = "initcap(str)",
40 sql_example = r#"```sql
41> select initcap('apache datafusion');
42+------------------------------------+
43| initcap(Utf8("apache datafusion")) |
44+------------------------------------+
45| Apache Datafusion |
46+------------------------------------+
47```"#,
48 standard_argument(name = "str", prefix = "String"),
49 related_udf(name = "lower"),
50 related_udf(name = "upper")
51)]
52#[derive(Debug, PartialEq, Eq, Hash)]
53pub struct InitcapFunc {
54 signature: Signature,
55}
56
57impl Default for InitcapFunc {
58 fn default() -> Self {
59 InitcapFunc::new()
60 }
61}
62
63impl InitcapFunc {
64 pub fn new() -> Self {
65 Self {
66 signature: Signature::coercible(
67 vec![Coercion::new_exact(TypeSignatureClass::Native(
68 logical_string(),
69 ))],
70 Volatility::Immutable,
71 ),
72 }
73 }
74}
75
76impl ScalarUDFImpl for InitcapFunc {
77 fn name(&self) -> &str {
78 "initcap"
79 }
80
81 fn signature(&self) -> &Signature {
82 &self.signature
83 }
84
85 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
86 if let DataType::Utf8View = arg_types[0] {
87 Ok(DataType::Utf8View)
88 } else {
89 utf8_to_str_type(&arg_types[0], "initcap")
90 }
91 }
92
93 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
94 let arg = &args.args[0];
95
96 if let ColumnarValue::Scalar(scalar) = arg {
98 return match scalar {
99 ScalarValue::Utf8(None)
100 | ScalarValue::LargeUtf8(None)
101 | ScalarValue::Utf8View(None) => Ok(arg.clone()),
102 ScalarValue::Utf8(Some(s)) => {
103 let mut result = String::new();
104 initcap_string(s, &mut result);
105 Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(result))))
106 }
107 ScalarValue::LargeUtf8(Some(s)) => {
108 let mut result = String::new();
109 initcap_string(s, &mut result);
110 Ok(ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(result))))
111 }
112 ScalarValue::Utf8View(Some(s)) => {
113 let mut result = String::new();
114 initcap_string(s, &mut result);
115 Ok(ColumnarValue::Scalar(ScalarValue::Utf8View(Some(result))))
116 }
117 other => {
118 exec_err!(
119 "Unsupported data type {:?} for function `initcap`",
120 other.data_type()
121 )
122 }
123 };
124 }
125
126 let args = &args.args;
128 match args[0].data_type() {
129 DataType::Utf8 => make_scalar_function(initcap::<i32>, vec![])(args),
130 DataType::LargeUtf8 => make_scalar_function(initcap::<i64>, vec![])(args),
131 DataType::Utf8View => make_scalar_function(initcap_utf8view, vec![])(args),
132 other => {
133 exec_err!("Unsupported data type {other:?} for function `initcap`")
134 }
135 }
136 }
137
138 fn documentation(&self) -> Option<&Documentation> {
139 self.doc()
140 }
141}
142
143fn initcap<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
152 let string_array = as_generic_string_array::<T>(&args[0])?;
153
154 if string_array.is_ascii() {
155 return Ok(initcap_ascii_array(string_array));
156 }
157
158 let len = string_array.len();
159 let mut builder = GenericStringArrayBuilder::<T>::with_capacity(
160 len,
161 string_array.value_data().len(),
162 );
163
164 let mut container = String::new();
165 let nulls = string_array.nulls().cloned();
166 if let Some(ref n) = nulls {
167 for i in 0..len {
168 if n.is_null(i) {
169 builder.append_placeholder();
170 } else {
171 let s = unsafe { string_array.value_unchecked(i) };
173 initcap_string(s, &mut container);
174 builder.append_value(&container);
175 }
176 }
177 } else {
178 for i in 0..len {
179 let s = unsafe { string_array.value_unchecked(i) };
181 initcap_string(s, &mut container);
182 builder.append_value(&container);
183 }
184 }
185
186 Ok(Arc::new(builder.finish(nulls)?) as ArrayRef)
187}
188
189fn initcap_ascii_array<T: OffsetSizeTrait>(
192 string_array: &GenericStringArray<T>,
193) -> ArrayRef {
194 let offsets = string_array.offsets();
195 let src = string_array.value_data();
196 let first_offset = offsets.first().unwrap().as_usize();
197 let last_offset = offsets.last().unwrap().as_usize();
198
199 let mut out = Vec::with_capacity(last_offset - first_offset);
202
203 for window in offsets.windows(2) {
204 let start = window[0].as_usize();
205 let end = window[1].as_usize();
206
207 let mut prev_is_alnum = false;
208 for &b in &src[start..end] {
209 let converted = if prev_is_alnum {
210 b.to_ascii_lowercase()
211 } else {
212 b.to_ascii_uppercase()
213 };
214 out.push(converted);
215 prev_is_alnum = b.is_ascii_alphanumeric();
216 }
217 }
218
219 let values = Buffer::from_vec(out);
220 let out_offsets = if first_offset == 0 {
221 offsets.clone()
222 } else {
223 let rebased_offsets = offsets
226 .iter()
227 .map(|offset| T::usize_as(offset.as_usize() - first_offset))
228 .collect::<Vec<_>>();
229 OffsetBuffer::<T>::new(rebased_offsets.into())
230 };
231
232 Arc::new(unsafe {
237 GenericStringArray::<T>::new_unchecked(
238 out_offsets,
239 values,
240 string_array.nulls().cloned(),
241 )
242 })
243}
244
245fn initcap_utf8view(args: &[ArrayRef]) -> Result<ArrayRef> {
246 let string_view_array = as_string_view_array(&args[0])?;
247 let len = string_view_array.len();
248 let mut builder = StringViewArrayBuilder::with_capacity(len);
249 let mut container = String::new();
250
251 let nulls = string_view_array.nulls().cloned();
252 if let Some(ref n) = nulls {
253 for i in 0..len {
254 if n.is_null(i) {
255 builder.append_placeholder();
256 } else {
257 let s = unsafe { string_view_array.value_unchecked(i) };
259 initcap_string(s, &mut container);
260 builder.append_value(&container);
261 }
262 }
263 } else {
264 for i in 0..len {
265 let s = unsafe { string_view_array.value_unchecked(i) };
267 initcap_string(s, &mut container);
268 builder.append_value(&container);
269 }
270 }
271
272 Ok(Arc::new(builder.finish(nulls)?) as ArrayRef)
273}
274
/// Writes the `initcap` form of `input` into `container`.
///
/// `container` is cleared first, so callers can reuse one allocation across
/// many rows. A character is upper-cased when it starts a word (it is the first
/// character, or it follows a non-alphanumeric character); every other
/// character is lower-cased. For non-ASCII input, full Unicode case mapping is
/// used, so the output may differ in byte length (for example `ß` upper-cases
/// to `SS`).
fn initcap_string(input: &str, container: &mut String) {
    container.clear();
    // The output is as long as the input for ASCII and usually close to it
    // otherwise, so reserve once up front for both paths.
    container.reserve(input.len());
    let mut prev_is_alphanumeric = false;

    if input.is_ascii() {
        // SAFETY: `container` is valid UTF-8, and only ASCII bytes (the ASCII
        // case mapping of ASCII input bytes) are appended, so it stays valid
        // UTF-8.
        let out = unsafe { container.as_mut_vec() };
        for &b in input.as_bytes() {
            let mapped = if prev_is_alphanumeric {
                b.to_ascii_lowercase()
            } else {
                b.to_ascii_uppercase()
            };
            out.push(mapped);
            prev_is_alphanumeric = b.is_ascii_alphanumeric();
        }
    } else {
        for c in input.chars() {
            if prev_is_alphanumeric {
                container.extend(c.to_lowercase());
            } else {
                container.extend(c.to_uppercase());
            }
            prev_is_alphanumeric = c.is_alphanumeric();
        }
    }
}
302
#[cfg(test)]
mod tests {
    use crate::unicode::initcap::InitcapFunc;
    use crate::utils::test::test_function;
    use arrow::array::{Array, ArrayRef, LargeStringArray, StringArray, StringViewArray};
    use arrow::datatypes::DataType::{LargeUtf8, Utf8, Utf8View};
    use datafusion_common::{Result, ScalarValue};
    use datafusion_expr::{ColumnarValue, ScalarUDFImpl};
    use std::sync::Arc;

    #[test]
    fn test_functions() -> Result<()> {
        test_function!(
            InitcapFunc::new(),
            vec![ColumnarValue::Scalar(ScalarValue::from("hi THOMAS"))],
            Ok(Some("Hi Thomas")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            InitcapFunc::new(),
            vec![ColumnarValue::Scalar(ScalarValue::Utf8(Some(
                "êM ả ñAnDÚ ÁrBOL ОлЕГ ИвАНОВИч ÍslENsku ÞjóðaRiNNaR εΛλΗΝΙκΉ"
                    .to_string()
            )))],
            Ok(Some(
                "Êm Ả Ñandú Árbol Олег Иванович Íslensku Þjóðarinnar Ελληνική"
            )),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            InitcapFunc::new(),
            vec![ColumnarValue::Scalar(ScalarValue::from(""))],
            Ok(Some("")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            InitcapFunc::new(),
            vec![ColumnarValue::Scalar(ScalarValue::Utf8(None))],
            Ok(None),
            &str,
            Utf8,
            StringArray
        );

        // LargeUtf8 scalars must keep their type.
        test_function!(
            InitcapFunc::new(),
            vec![ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(
                "hi THOMAS".to_string()
            )))],
            Ok(Some("Hi Thomas")),
            &str,
            LargeUtf8,
            LargeStringArray
        );
        test_function!(
            InitcapFunc::new(),
            vec![ColumnarValue::Scalar(ScalarValue::LargeUtf8(None))],
            Ok(None),
            &str,
            LargeUtf8,
            LargeStringArray
        );

        test_function!(
            InitcapFunc::new(),
            vec![ColumnarValue::Scalar(ScalarValue::Utf8View(Some(
                "hi THOMAS".to_string()
            )))],
            Ok(Some("Hi Thomas")),
            &str,
            Utf8View,
            StringViewArray
        );
        test_function!(
            InitcapFunc::new(),
            vec![ColumnarValue::Scalar(ScalarValue::Utf8View(Some(
                "hi THOMAS wIth M0re ThAN 12 ChaRs".to_string()
            )))],
            Ok(Some("Hi Thomas With M0re Than 12 Chars")),
            &str,
            Utf8View,
            StringViewArray
        );
        test_function!(
            InitcapFunc::new(),
            vec![ColumnarValue::Scalar(ScalarValue::Utf8View(Some(
                "đẸp đẼ êM ả ñAnDÚ ÁrBOL ОлЕГ ИвАНОВИч ÍslENsku ÞjóðaRiNNaR εΛλΗΝΙκΉ"
                    .to_string()
            )))],
            Ok(Some(
                "Đẹp Đẽ Êm Ả Ñandú Árbol Олег Иванович Íslensku Þjóðarinnar Ελληνική"
            )),
            &str,
            Utf8View,
            StringViewArray
        );
        test_function!(
            InitcapFunc::new(),
            vec![ColumnarValue::Scalar(ScalarValue::Utf8View(Some(
                "".to_string()
            )))],
            Ok(Some("")),
            &str,
            Utf8View,
            StringViewArray
        );
        test_function!(
            InitcapFunc::new(),
            vec![ColumnarValue::Scalar(ScalarValue::Utf8View(None))],
            Ok(None),
            &str,
            Utf8View,
            StringViewArray
        );

        Ok(())
    }

    #[test]
    fn test_initcap_ascii_array() -> Result<()> {
        let array = StringArray::from(vec![
            Some("hello world"),
            None,
            Some("foo-bar_baz/baX"),
            Some(""),
            Some("123 abc 456DEF"),
            Some("ALL CAPS"),
            Some("already correct"),
        ]);
        let args: Vec<ArrayRef> = vec![Arc::new(array)];
        let result = super::initcap::<i32>(&args)?;
        let result = result.as_any().downcast_ref::<StringArray>().unwrap();

        assert_eq!(result.len(), 7);
        assert_eq!(result.value(0), "Hello World");
        assert!(result.is_null(1));
        assert_eq!(result.value(2), "Foo-Bar_Baz/Bax");
        assert_eq!(result.value(3), "");
        assert_eq!(result.value(4), "123 Abc 456def");
        assert_eq!(result.value(5), "All Caps");
        assert_eq!(result.value(6), "Already Correct");
        Ok(())
    }

    #[test]
    fn test_initcap_ascii_large_array() -> Result<()> {
        let array = LargeStringArray::from(vec![
            Some("hello world"),
            None,
            Some("foo-bar_baz/baX"),
            Some(""),
            Some("123 abc 456DEF"),
            Some("ALL CAPS"),
            Some("already correct"),
        ]);
        let args: Vec<ArrayRef> = vec![Arc::new(array)];
        let result = super::initcap::<i64>(&args)?;
        let result = result.as_any().downcast_ref::<LargeStringArray>().unwrap();

        assert_eq!(result.len(), 7);
        assert_eq!(result.value(0), "Hello World");
        assert!(result.is_null(1));
        assert_eq!(result.value(2), "Foo-Bar_Baz/Bax");
        assert_eq!(result.value(3), "");
        assert_eq!(result.value(4), "123 Abc 456def");
        assert_eq!(result.value(5), "All Caps");
        assert_eq!(result.value(6), "Already Correct");
        Ok(())
    }

    /// A slice has a non-zero first offset; the ASCII fast path must rebase it.
    #[test]
    fn test_initcap_sliced_ascii_array() -> Result<()> {
        let array = StringArray::from(vec![
            Some("hello world"),
            Some("foo bar"),
            Some("baz qux"),
        ]);
        let sliced = array.slice(1, 2);
        let args: Vec<ArrayRef> = vec![Arc::new(sliced)];
        let result = super::initcap::<i32>(&args)?;
        let result = result.as_any().downcast_ref::<StringArray>().unwrap();

        assert_eq!(result.len(), 2);
        assert_eq!(result.value(0), "Foo Bar");
        assert_eq!(result.value(1), "Baz Qux");

        // The output is compact: offsets start at 0 and span the value buffer.
        assert_eq!(*result.offsets().first().unwrap(), 0);
        assert_eq!(
            result.value_data().len(),
            *result.offsets().last().unwrap() as usize
        );
        Ok(())
    }

    /// Same as `test_initcap_sliced_ascii_array`, for `LargeUtf8`.
    #[test]
    fn test_initcap_sliced_ascii_large_array() -> Result<()> {
        let array = LargeStringArray::from(vec![
            Some("hello world"),
            Some("foo bar"),
            Some("baz qux"),
        ]);
        let sliced = array.slice(1, 2);
        let args: Vec<ArrayRef> = vec![Arc::new(sliced)];
        let result = super::initcap::<i64>(&args)?;
        let result = result.as_any().downcast_ref::<LargeStringArray>().unwrap();

        assert_eq!(result.len(), 2);
        assert_eq!(result.value(0), "Foo Bar");
        assert_eq!(result.value(1), "Baz Qux");

        // The output is compact: offsets start at 0 and span the value buffer.
        assert_eq!(*result.offsets().first().unwrap(), 0);
        assert_eq!(
            result.value_data().len(),
            *result.offsets().last().unwrap() as usize
        );
        Ok(())
    }

    /// Non-ASCII values force the row-by-row builder path of `initcap`.
    #[test]
    fn test_initcap_non_ascii_array() -> Result<()> {
        let array = StringArray::from(vec![
            Some("êM ả ñAnDÚ"),
            None,
            Some("hello WORLD"),
            Some(""),
        ]);
        let args: Vec<ArrayRef> = vec![Arc::new(array)];
        let result = super::initcap::<i32>(&args)?;
        let result = result.as_any().downcast_ref::<StringArray>().unwrap();

        assert_eq!(result.len(), 4);
        assert_eq!(result.value(0), "Êm Ả Ñandú");
        assert!(result.is_null(1));
        assert_eq!(result.value(2), "Hello World");
        assert_eq!(result.value(3), "");
        Ok(())
    }

    /// The builder path must also handle slices of a larger buffer.
    #[test]
    fn test_initcap_sliced_non_ascii_large_array() -> Result<()> {
        let array = LargeStringArray::from(vec![
            Some("ÁrBOL"),
            Some("ñAnDÚ êM"),
            Some("ÞjóðaRiNNaR"),
        ]);
        let sliced = array.slice(1, 2);
        let args: Vec<ArrayRef> = vec![Arc::new(sliced)];
        let result = super::initcap::<i64>(&args)?;
        let result = result.as_any().downcast_ref::<LargeStringArray>().unwrap();

        assert_eq!(result.len(), 2);
        assert_eq!(result.value(0), "Ñandú Êm");
        assert_eq!(result.value(1), "Þjóðarinnar");
        Ok(())
    }

    /// Covers the `Utf8View` array kernel, including nulls.
    #[test]
    fn test_initcap_utf8view_array() -> Result<()> {
        let array = StringViewArray::from(vec![
            Some("hi THOMAS wIth M0re ThAN 12 ChaRs"),
            None,
            Some("ОлЕГ ИвАНОВИч"),
            Some("short"),
        ]);
        let args: Vec<ArrayRef> = vec![Arc::new(array)];
        let result = super::initcap_utf8view(&args)?;
        let result = result.as_any().downcast_ref::<StringViewArray>().unwrap();

        assert_eq!(result.len(), 4);
        assert_eq!(result.value(0), "Hi Thomas With M0re Than 12 Chars");
        assert!(result.is_null(1));
        assert_eq!(result.value(2), "Олег Иванович");
        assert_eq!(result.value(3), "Short");
        Ok(())
    }
}