datafusion_functions/string/
common.rs1use std::sync::Arc;
21
22use crate::strings::{
23 GenericStringArrayBuilder, STRING_VIEW_INIT_BLOCK_SIZE, STRING_VIEW_MAX_BLOCK_SIZE,
24 StringViewArrayBuilder, append_view,
25};
26use arrow::array::{
27 Array, ArrayRef, GenericStringArray, NullBufferBuilder, OffsetSizeTrait,
28 StringViewArray, new_null_array,
29};
30use arrow::buffer::{Buffer, OffsetBuffer, ScalarBuffer};
31use arrow::datatypes::DataType;
32use datafusion_common::Result;
33use datafusion_common::cast::{as_generic_string_array, as_string_view_array};
34use datafusion_common::{ScalarValue, exec_err};
35use datafusion_expr::ColumnarValue;
36
/// Strategy for stripping unwanted characters from one or both ends of a string.
///
/// Every method returns the trimmed slice together with the number of bytes that
/// were removed from the *front* of `input`. The view-based trim passes that
/// count to `append_view` so the new view can point into the original data.
pub(crate) trait Trimmer {
    /// Single-byte fast path of `trim`.
    ///
    /// `byte` is expected to be ASCII. The implementations in this module slice
    /// `input` at the counted byte positions, so a non-ASCII byte could produce a
    /// boundary inside a multi-byte character and make the slicing panic.
    fn trim_ascii_char(input: &str, byte: u8) -> (&str, u32);

    /// Removes every character contained in `pattern` from the end(s) this
    /// strategy handles.
    fn trim<'a>(input: &'a str, pattern: &[char]) -> (&'a str, u32);
}
49
/// Returns how many consecutive bytes at the start of `bytes` equal `byte`.
#[inline]
fn leading_bytes(bytes: &[u8], byte: u8) -> usize {
    // The first mismatching position is exactly the length of the matching run;
    // if every byte matches, the run covers the whole slice.
    bytes
        .iter()
        .position(|&b| b != byte)
        .unwrap_or(bytes.len())
}
55
/// Returns how many consecutive bytes at the end of `bytes` equal `byte`.
#[inline]
fn trailing_bytes(bytes: &[u8], byte: u8) -> usize {
    // Everything after the last mismatching byte belongs to the trailing run.
    match bytes.iter().rposition(|&b| b != byte) {
        Some(last_kept) => bytes.len() - last_kept - 1,
        None => bytes.len(),
    }
}
61
62pub(crate) struct TrimLeft;
64
65impl Trimmer for TrimLeft {
66 #[inline]
67 fn trim<'a>(input: &'a str, pattern: &[char]) -> (&'a str, u32) {
68 if pattern.len() == 1 && pattern[0].is_ascii() {
69 return Self::trim_ascii_char(input, pattern[0] as u8);
70 }
71 let trimmed = input.trim_start_matches(pattern);
72 let offset = (input.len() - trimmed.len()) as u32;
73 (trimmed, offset)
74 }
75
76 #[inline]
77 fn trim_ascii_char(input: &str, byte: u8) -> (&str, u32) {
78 let start = leading_bytes(input.as_bytes(), byte);
79 (&input[start..], start as u32)
80 }
81}
82
83pub(crate) struct TrimRight;
85
86impl Trimmer for TrimRight {
87 #[inline]
88 fn trim<'a>(input: &'a str, pattern: &[char]) -> (&'a str, u32) {
89 if pattern.len() == 1 && pattern[0].is_ascii() {
90 return Self::trim_ascii_char(input, pattern[0] as u8);
91 }
92 let trimmed = input.trim_end_matches(pattern);
93 (trimmed, 0)
94 }
95
96 #[inline]
97 fn trim_ascii_char(input: &str, byte: u8) -> (&str, u32) {
98 let bytes = input.as_bytes();
99 let end = bytes.len() - trailing_bytes(bytes, byte);
100 (&input[..end], 0)
101 }
102}
103
104pub(crate) struct TrimBoth;
106
107impl Trimmer for TrimBoth {
108 #[inline]
109 fn trim<'a>(input: &'a str, pattern: &[char]) -> (&'a str, u32) {
110 if pattern.len() == 1 && pattern[0].is_ascii() {
111 return Self::trim_ascii_char(input, pattern[0] as u8);
112 }
113 let left_trimmed = input.trim_start_matches(pattern);
114 let offset = (input.len() - left_trimmed.len()) as u32;
115 let trimmed = left_trimmed.trim_end_matches(pattern);
116 (trimmed, offset)
117 }
118
119 #[inline]
120 fn trim_ascii_char(input: &str, byte: u8) -> (&str, u32) {
121 let bytes = input.as_bytes();
122 let start = leading_bytes(bytes, byte);
123 let end = bytes.len() - trailing_bytes(&bytes[start..], byte);
124 (&input[start..end], start as u32)
125 }
126}
127
128pub(crate) fn general_trim<T: OffsetSizeTrait, Tr: Trimmer>(
129 args: &[ArrayRef],
130 use_string_view: bool,
131) -> Result<ArrayRef> {
132 if use_string_view {
133 string_view_trim::<Tr>(args)
134 } else {
135 string_trim::<T, Tr>(args)
136 }
137}
138
139fn string_view_trim<Tr: Trimmer>(args: &[ArrayRef]) -> Result<ArrayRef> {
145 let string_view_array = as_string_view_array(&args[0])?;
146 let mut views_buf = Vec::with_capacity(string_view_array.len());
147 let mut null_builder = NullBufferBuilder::new(string_view_array.len());
148
149 match args.len() {
150 1 => {
151 for (src_str_opt, raw_view) in string_view_array
153 .iter()
154 .zip(string_view_array.views().iter())
155 {
156 if let Some(src_str) = src_str_opt {
157 let (trimmed, offset) = Tr::trim_ascii_char(src_str, b' ');
158 append_view(&mut views_buf, raw_view, trimmed, offset);
159 null_builder.append_non_null();
160 } else {
161 null_builder.append_null();
162 views_buf.push(0);
163 }
164 }
165 }
166 2 => {
167 let characters_array = as_string_view_array(&args[1])?;
168
169 if characters_array.len() == 1 {
170 if characters_array.is_null(0) {
172 return Ok(new_null_array(
173 &DataType::Utf8View,
174 string_view_array.len(),
175 ));
176 }
177
178 let pattern: Vec<char> = characters_array.value(0).chars().collect();
179 for (src_str_opt, raw_view) in string_view_array
180 .iter()
181 .zip(string_view_array.views().iter())
182 {
183 trim_and_append_view::<Tr>(
184 src_str_opt,
185 &pattern,
186 &mut views_buf,
187 &mut null_builder,
188 raw_view,
189 );
190 }
191 } else {
192 let mut pattern: Vec<char> = Vec::new();
194 for ((src_str_opt, raw_view), characters_opt) in string_view_array
195 .iter()
196 .zip(string_view_array.views().iter())
197 .zip(characters_array.iter())
198 {
199 if let (Some(src_str), Some(characters)) =
200 (src_str_opt, characters_opt)
201 {
202 pattern.clear();
203 pattern.extend(characters.chars());
204 let (trimmed, offset) = Tr::trim(src_str, &pattern);
205 append_view(&mut views_buf, raw_view, trimmed, offset);
206 null_builder.append_non_null();
207 } else {
208 null_builder.append_null();
209 views_buf.push(0);
210 }
211 }
212 }
213 }
214 other => {
215 return exec_err!(
216 "Function TRIM was called with {other} arguments. It requires at least 1 and at most 2."
217 );
218 }
219 }
220
221 let views_buf = ScalarBuffer::from(views_buf);
222 let nulls_buf = null_builder.finish();
223
224 unsafe {
229 let array = StringViewArray::new_unchecked(
230 views_buf,
231 string_view_array.data_buffers().to_vec(),
232 nulls_buf,
233 );
234 Ok(Arc::new(array) as ArrayRef)
235 }
236}
237
238#[inline]
248fn trim_and_append_view<Tr: Trimmer>(
249 src_str_opt: Option<&str>,
250 pattern: &[char],
251 views_buf: &mut Vec<u128>,
252 null_builder: &mut NullBufferBuilder,
253 original_view: &u128,
254) {
255 if let Some(src_str) = src_str_opt {
256 let (trimmed, offset) = Tr::trim(src_str, pattern);
257 append_view(views_buf, original_view, trimmed, offset);
258 null_builder.append_non_null();
259 } else {
260 null_builder.append_null();
261 views_buf.push(0);
262 }
263}
264
265fn string_trim<T: OffsetSizeTrait, Tr: Trimmer>(args: &[ArrayRef]) -> Result<ArrayRef> {
271 let string_array = as_generic_string_array::<T>(&args[0])?;
272
273 match args.len() {
274 1 => {
275 let result = string_array
277 .iter()
278 .map(|string| string.map(|s| Tr::trim_ascii_char(s, b' ').0))
279 .collect::<GenericStringArray<T>>();
280
281 Ok(Arc::new(result) as ArrayRef)
282 }
283 2 => {
284 let characters_array = as_generic_string_array::<T>(&args[1])?;
285
286 if characters_array.len() == 1 {
287 if characters_array.is_null(0) {
289 return Ok(new_null_array(
290 string_array.data_type(),
291 string_array.len(),
292 ));
293 }
294
295 let pattern: Vec<char> = characters_array.value(0).chars().collect();
296 let result = string_array
297 .iter()
298 .map(|item| item.map(|s| Tr::trim(s, &pattern).0))
299 .collect::<GenericStringArray<T>>();
300 return Ok(Arc::new(result) as ArrayRef);
301 }
302
303 let mut pattern: Vec<char> = Vec::new();
305 let result = string_array
306 .iter()
307 .zip(characters_array.iter())
308 .map(|(string, characters)| match (string, characters) {
309 (Some(s), Some(c)) => {
310 pattern.clear();
311 pattern.extend(c.chars());
312 Some(Tr::trim(s, &pattern).0)
313 }
314 _ => None,
315 })
316 .collect::<GenericStringArray<T>>();
317
318 Ok(Arc::new(result) as ArrayRef)
319 }
320 other => {
321 exec_err!(
322 "Function TRIM was called with {other} arguments. It requires at least 1 and at most 2."
323 )
324 }
325 }
326}
327
328pub(crate) fn to_lower(args: &[ColumnarValue], name: &str) -> Result<ColumnarValue> {
329 case_conversion(args, true, name)
330}
331
332pub(crate) fn to_upper(args: &[ColumnarValue], name: &str) -> Result<ColumnarValue> {
333 case_conversion(args, false, name)
334}
335
/// Converts `s` with the standard library's full Unicode case mapping.
///
/// The result may differ in byte length from the input (e.g. `ß` upper-cases
/// to `SS`).
#[inline]
fn unicode_case(s: &str, lower: bool) -> String {
    if !lower {
        return s.to_uppercase();
    }
    s.to_lowercase()
}
344
345fn case_conversion(
346 args: &[ColumnarValue],
347 lower: bool,
348 name: &str,
349) -> Result<ColumnarValue> {
350 match &args[0] {
351 ColumnarValue::Array(array) => match array.data_type() {
352 DataType::Utf8 => Ok(ColumnarValue::Array(case_conversion_array::<i32>(
353 array, lower,
354 )?)),
355 DataType::LargeUtf8 => Ok(ColumnarValue::Array(
356 case_conversion_array::<i64>(array, lower)?,
357 )),
358 DataType::Utf8View => {
359 let string_array = as_string_view_array(array)?;
360 if string_array.is_ascii() {
361 return Ok(ColumnarValue::Array(Arc::new(
362 case_conversion_utf8view_ascii(string_array, lower),
363 )));
364 }
365 let item_len = string_array.len();
366 let nulls = string_array.nulls().cloned();
368 let mut builder = StringViewArrayBuilder::with_capacity(item_len);
369
370 if let Some(ref n) = nulls {
371 for i in 0..item_len {
372 if n.is_null(i) {
373 builder.append_placeholder();
374 } else {
375 let s = unsafe { string_array.value_unchecked(i) };
377 builder.append_value(&unicode_case(s, lower));
378 }
379 }
380 } else {
381 for i in 0..item_len {
382 let s = unsafe { string_array.value_unchecked(i) };
384 builder.append_value(&unicode_case(s, lower));
385 }
386 }
387
388 Ok(ColumnarValue::Array(Arc::new(builder.finish(nulls)?)))
389 }
390 other => exec_err!("Unsupported data type {other:?} for function {name}"),
391 },
392 ColumnarValue::Scalar(scalar) => match scalar {
393 ScalarValue::Utf8(a) => {
394 let result = a.as_ref().map(|x| unicode_case(x, lower));
395 Ok(ColumnarValue::Scalar(ScalarValue::Utf8(result)))
396 }
397 ScalarValue::LargeUtf8(a) => {
398 let result = a.as_ref().map(|x| unicode_case(x, lower));
399 Ok(ColumnarValue::Scalar(ScalarValue::LargeUtf8(result)))
400 }
401 ScalarValue::Utf8View(a) => {
402 let result = a.as_ref().map(|x| unicode_case(x, lower));
403 Ok(ColumnarValue::Scalar(ScalarValue::Utf8View(result)))
404 }
405 other => exec_err!("Unsupported data type {other:?} for function {name}"),
406 },
407 }
408}
409
410fn case_conversion_array<O: OffsetSizeTrait>(
411 array: &ArrayRef,
412 lower: bool,
413) -> Result<ArrayRef> {
414 const PRE_ALLOC_BYTES: usize = 8;
415
416 let string_array = as_generic_string_array::<O>(array)?;
417 if string_array.is_ascii() {
418 return case_conversion_ascii_array::<O>(string_array, lower);
419 }
420
421 let item_len = string_array.len();
423 let offsets = string_array.value_offsets();
424 let start = offsets.first().unwrap().as_usize();
425 let end = offsets.last().unwrap().as_usize();
426 let capacity = (end - start) + PRE_ALLOC_BYTES;
427 let nulls = string_array.nulls().cloned();
429 let mut builder = GenericStringArrayBuilder::<O>::with_capacity(item_len, capacity);
430
431 if let Some(ref n) = nulls {
432 for i in 0..item_len {
433 if n.is_null(i) {
434 builder.append_placeholder();
435 } else {
436 let s = unsafe { string_array.value_unchecked(i) };
438 builder.append_value(&unicode_case(s, lower));
439 }
440 }
441 } else {
442 for i in 0..item_len {
443 let s = unsafe { string_array.value_unchecked(i) };
445 builder.append_value(&unicode_case(s, lower));
446 }
447 }
448 Ok(Arc::new(builder.finish(nulls)?))
449}
450
451fn case_conversion_utf8view_ascii(
453 array: &StringViewArray,
454 lower: bool,
455) -> StringViewArray {
456 if lower {
458 case_conversion_utf8view_ascii_inner(array, u8::to_ascii_lowercase)
459 } else {
460 case_conversion_utf8view_ascii_inner(array, u8::to_ascii_uppercase)
461 }
462}
463
464fn case_conversion_utf8view_ascii_inner<F: Fn(&u8) -> u8>(
470 array: &StringViewArray,
471 convert: F,
472) -> StringViewArray {
473 let item_len = array.len();
474 let views = array.views();
475 let data_buffers = array.data_buffers();
476 let nulls = array.nulls();
477
478 let mut new_views: Vec<u128> = Vec::with_capacity(item_len);
479 let mut in_progress: Vec<u8> = Vec::new();
483 let mut completed: Vec<Buffer> = Vec::new();
484 let mut block_size: u32 = STRING_VIEW_INIT_BLOCK_SIZE;
485
486 for i in 0..item_len {
487 if nulls.is_some_and(|n| n.is_null(i)) {
488 new_views.push(0);
491 continue;
492 }
493 let view = views[i];
494 let len = view as u32 as usize;
496 if len == 0 {
497 new_views.push(0);
498 continue;
499 }
500 let mut bytes = view.to_le_bytes();
501 if len <= 12 {
502 for b in &mut bytes[4..4 + len] {
505 *b = convert(b);
506 }
507 new_views.push(u128::from_le_bytes(bytes));
508 } else {
509 let required_cap = in_progress.len() + len;
515 if in_progress.capacity() < required_cap {
516 if !in_progress.is_empty() {
517 completed.push(Buffer::from_vec(std::mem::take(&mut in_progress)));
518 }
519 if block_size < STRING_VIEW_MAX_BLOCK_SIZE {
520 block_size = block_size.saturating_mul(2);
521 }
522 let to_reserve = len.max(block_size as usize);
523 in_progress.reserve(to_reserve);
524 }
525
526 let buffer_index: u32 = i32::try_from(completed.len())
529 .expect("buffer count exceeds i32::MAX")
530 as u32;
531 let new_offset: u32 =
532 i32::try_from(in_progress.len()).expect("offset exceeds i32::MAX") as u32;
533
534 let src_buffer_index =
537 u32::from_le_bytes(bytes[8..12].try_into().unwrap()) as usize;
538 let src_offset =
539 u32::from_le_bytes(bytes[12..16].try_into().unwrap()) as usize;
540 let src =
541 &data_buffers[src_buffer_index].as_slice()[src_offset..src_offset + len];
542
543 let prefix_start = in_progress.len();
544 in_progress.extend(src.iter().map(&convert));
545
546 let prefix: [u8; 4] = in_progress[prefix_start..prefix_start + 4]
550 .try_into()
551 .unwrap();
552 bytes[4..8].copy_from_slice(&prefix);
553 bytes[8..12].copy_from_slice(&buffer_index.to_le_bytes());
554 bytes[12..16].copy_from_slice(&new_offset.to_le_bytes());
555 new_views.push(u128::from_le_bytes(bytes));
556 }
557 }
558
559 if !in_progress.is_empty() {
560 completed.push(Buffer::from_vec(in_progress));
561 }
562
563 unsafe {
569 StringViewArray::new_unchecked(
570 ScalarBuffer::from(new_views),
571 completed,
572 array.nulls().cloned(),
573 )
574 }
575}
576
577fn case_conversion_ascii_array<O: OffsetSizeTrait>(
582 string_array: &GenericStringArray<O>,
583 lower: bool,
584) -> Result<ArrayRef> {
585 let value_offsets = string_array.value_offsets();
586 let start = value_offsets.first().unwrap().as_usize();
587 let end = value_offsets.last().unwrap().as_usize();
588 let relevant = &string_array.value_data()[start..end];
589
590 let converted: Vec<u8> = if lower {
591 relevant.iter().map(u8::to_ascii_lowercase).collect()
592 } else {
593 relevant.iter().map(u8::to_ascii_uppercase).collect()
594 };
595 let values = Buffer::from_vec(converted);
596
597 let offsets = if start == 0 {
599 string_array.offsets().clone()
600 } else {
601 let s = O::usize_as(start);
602 let rebased: Vec<O> = value_offsets.iter().map(|&o| o - s).collect();
603 unsafe { OffsetBuffer::new_unchecked(ScalarBuffer::from(rebased)) }
606 };
607
608 let nulls = string_array.nulls().cloned();
609 Ok(Arc::new(unsafe {
612 GenericStringArray::<O>::new_unchecked(offsets, values, nulls)
613 }))
614}