datafusion_functions/string/
common.rs1use std::sync::Arc;
21
22use crate::strings::make_and_append_view;
23use arrow::array::{
24 Array, ArrayRef, GenericStringArray, GenericStringBuilder, NullBufferBuilder,
25 OffsetSizeTrait, StringBuilder, StringViewArray, new_null_array,
26};
27use arrow::buffer::{Buffer, ScalarBuffer};
28use arrow::datatypes::DataType;
29use datafusion_common::Result;
30use datafusion_common::cast::{as_generic_string_array, as_string_view_array};
31use datafusion_common::{ScalarValue, exec_err};
32use datafusion_expr::ColumnarValue;
33
/// Compile-time strategy describing one flavour of trimming (left, right or
/// both sides).
///
/// The trim functions in this file take an implementor as a type parameter,
/// so the concrete trim routine is chosen statically rather than by matching
/// on a flag for every row.
///
/// Both methods return `(trimmed, start_offset)`, where `trimmed` is a
/// sub-slice of `input` and `start_offset` is the number of bytes removed from
/// the front of `input` (it stays `0` when nothing is removed from the start).
pub(crate) trait Trimmer {
    /// Trims every character contained in `pattern` from `input`.
    fn trim<'a>(input: &'a str, pattern: &[char]) -> (&'a str, u32);

    /// Trims a single byte value from `input` by scanning bytes rather than
    /// decoding characters.
    ///
    /// The implementors in this file slice `input` at the positions found by
    /// the byte scan, so `byte` is expected to be ASCII; otherwise the slice
    /// boundary could fall inside a multi-byte character and panic.
    fn trim_ascii_char(input: &str, byte: u8) -> (&str, u32);
}
46
/// Returns how many bytes at the start of `bytes` are equal to `byte`.
///
/// The result is `bytes.len()` when every byte matches (including the empty
/// slice, for which it is `0`).
#[inline]
fn leading_bytes(bytes: &[u8], byte: u8) -> usize {
    // The index of the first non-matching byte equals the length of the
    // matching prefix.
    bytes
        .iter()
        .position(|&candidate| candidate != byte)
        .unwrap_or(bytes.len())
}
52
/// Returns how many bytes at the end of `bytes` are equal to `byte`.
///
/// The result is `bytes.len()` when every byte matches (including the empty
/// slice, for which it is `0`).
#[inline]
fn trailing_bytes(bytes: &[u8], byte: u8) -> usize {
    // `rposition` yields the front-based index of the last non-matching byte;
    // everything after it belongs to the matching suffix.
    bytes
        .iter()
        .rposition(|&candidate| candidate != byte)
        .map_or(bytes.len(), |last_kept| bytes.len() - 1 - last_kept)
}
58
59pub(crate) struct TrimLeft;
61
62impl Trimmer for TrimLeft {
63 #[inline]
64 fn trim<'a>(input: &'a str, pattern: &[char]) -> (&'a str, u32) {
65 if pattern.len() == 1 && pattern[0].is_ascii() {
66 return Self::trim_ascii_char(input, pattern[0] as u8);
67 }
68 let trimmed = input.trim_start_matches(pattern);
69 let offset = (input.len() - trimmed.len()) as u32;
70 (trimmed, offset)
71 }
72
73 #[inline]
74 fn trim_ascii_char(input: &str, byte: u8) -> (&str, u32) {
75 let start = leading_bytes(input.as_bytes(), byte);
76 (&input[start..], start as u32)
77 }
78}
79
80pub(crate) struct TrimRight;
82
83impl Trimmer for TrimRight {
84 #[inline]
85 fn trim<'a>(input: &'a str, pattern: &[char]) -> (&'a str, u32) {
86 if pattern.len() == 1 && pattern[0].is_ascii() {
87 return Self::trim_ascii_char(input, pattern[0] as u8);
88 }
89 let trimmed = input.trim_end_matches(pattern);
90 (trimmed, 0)
91 }
92
93 #[inline]
94 fn trim_ascii_char(input: &str, byte: u8) -> (&str, u32) {
95 let bytes = input.as_bytes();
96 let end = bytes.len() - trailing_bytes(bytes, byte);
97 (&input[..end], 0)
98 }
99}
100
101pub(crate) struct TrimBoth;
103
104impl Trimmer for TrimBoth {
105 #[inline]
106 fn trim<'a>(input: &'a str, pattern: &[char]) -> (&'a str, u32) {
107 if pattern.len() == 1 && pattern[0].is_ascii() {
108 return Self::trim_ascii_char(input, pattern[0] as u8);
109 }
110 let left_trimmed = input.trim_start_matches(pattern);
111 let offset = (input.len() - left_trimmed.len()) as u32;
112 let trimmed = left_trimmed.trim_end_matches(pattern);
113 (trimmed, offset)
114 }
115
116 #[inline]
117 fn trim_ascii_char(input: &str, byte: u8) -> (&str, u32) {
118 let bytes = input.as_bytes();
119 let start = leading_bytes(bytes, byte);
120 let end = bytes.len() - trailing_bytes(&bytes[start..], byte);
121 (&input[start..end], start as u32)
122 }
123}
124
125pub(crate) fn general_trim<T: OffsetSizeTrait, Tr: Trimmer>(
126 args: &[ArrayRef],
127 use_string_view: bool,
128) -> Result<ArrayRef> {
129 if use_string_view {
130 string_view_trim::<Tr>(args)
131 } else {
132 string_trim::<T, Tr>(args)
133 }
134}
135
136fn string_view_trim<Tr: Trimmer>(args: &[ArrayRef]) -> Result<ArrayRef> {
142 let string_view_array = as_string_view_array(&args[0])?;
143 let mut views_buf = Vec::with_capacity(string_view_array.len());
144 let mut null_builder = NullBufferBuilder::new(string_view_array.len());
145
146 match args.len() {
147 1 => {
148 for (src_str_opt, raw_view) in string_view_array
150 .iter()
151 .zip(string_view_array.views().iter())
152 {
153 if let Some(src_str) = src_str_opt {
154 let (trimmed, offset) = Tr::trim_ascii_char(src_str, b' ');
155 make_and_append_view(
156 &mut views_buf,
157 &mut null_builder,
158 raw_view,
159 trimmed,
160 offset,
161 );
162 } else {
163 null_builder.append_null();
164 views_buf.push(0);
165 }
166 }
167 }
168 2 => {
169 let characters_array = as_string_view_array(&args[1])?;
170
171 if characters_array.len() == 1 {
172 if characters_array.is_null(0) {
174 return Ok(new_null_array(
175 &DataType::Utf8View,
176 string_view_array.len(),
177 ));
178 }
179
180 let pattern: Vec<char> = characters_array.value(0).chars().collect();
181 for (src_str_opt, raw_view) in string_view_array
182 .iter()
183 .zip(string_view_array.views().iter())
184 {
185 trim_and_append_view::<Tr>(
186 src_str_opt,
187 &pattern,
188 &mut views_buf,
189 &mut null_builder,
190 raw_view,
191 );
192 }
193 } else {
194 let mut pattern: Vec<char> = Vec::new();
196 for ((src_str_opt, raw_view), characters_opt) in string_view_array
197 .iter()
198 .zip(string_view_array.views().iter())
199 .zip(characters_array.iter())
200 {
201 if let (Some(src_str), Some(characters)) =
202 (src_str_opt, characters_opt)
203 {
204 pattern.clear();
205 pattern.extend(characters.chars());
206 let (trimmed, offset) = Tr::trim(src_str, &pattern);
207 make_and_append_view(
208 &mut views_buf,
209 &mut null_builder,
210 raw_view,
211 trimmed,
212 offset,
213 );
214 } else {
215 null_builder.append_null();
216 views_buf.push(0);
217 }
218 }
219 }
220 }
221 other => {
222 return exec_err!(
223 "Function TRIM was called with {other} arguments. It requires at least 1 and at most 2."
224 );
225 }
226 }
227
228 let views_buf = ScalarBuffer::from(views_buf);
229 let nulls_buf = null_builder.finish();
230
231 unsafe {
236 let array = StringViewArray::new_unchecked(
237 views_buf,
238 string_view_array.data_buffers().to_vec(),
239 nulls_buf,
240 );
241 Ok(Arc::new(array) as ArrayRef)
242 }
243}
244
245#[inline]
255fn trim_and_append_view<Tr: Trimmer>(
256 src_str_opt: Option<&str>,
257 pattern: &[char],
258 views_buf: &mut Vec<u128>,
259 null_builder: &mut NullBufferBuilder,
260 original_view: &u128,
261) {
262 if let Some(src_str) = src_str_opt {
263 let (trimmed, offset) = Tr::trim(src_str, pattern);
264 make_and_append_view(views_buf, null_builder, original_view, trimmed, offset);
265 } else {
266 null_builder.append_null();
267 views_buf.push(0);
268 }
269}
270
271fn string_trim<T: OffsetSizeTrait, Tr: Trimmer>(args: &[ArrayRef]) -> Result<ArrayRef> {
277 let string_array = as_generic_string_array::<T>(&args[0])?;
278
279 match args.len() {
280 1 => {
281 let result = string_array
283 .iter()
284 .map(|string| string.map(|s| Tr::trim_ascii_char(s, b' ').0))
285 .collect::<GenericStringArray<T>>();
286
287 Ok(Arc::new(result) as ArrayRef)
288 }
289 2 => {
290 let characters_array = as_generic_string_array::<T>(&args[1])?;
291
292 if characters_array.len() == 1 {
293 if characters_array.is_null(0) {
295 return Ok(new_null_array(
296 string_array.data_type(),
297 string_array.len(),
298 ));
299 }
300
301 let pattern: Vec<char> = characters_array.value(0).chars().collect();
302 let result = string_array
303 .iter()
304 .map(|item| item.map(|s| Tr::trim(s, &pattern).0))
305 .collect::<GenericStringArray<T>>();
306 return Ok(Arc::new(result) as ArrayRef);
307 }
308
309 let mut pattern: Vec<char> = Vec::new();
311 let result = string_array
312 .iter()
313 .zip(characters_array.iter())
314 .map(|(string, characters)| match (string, characters) {
315 (Some(s), Some(c)) => {
316 pattern.clear();
317 pattern.extend(c.chars());
318 Some(Tr::trim(s, &pattern).0)
319 }
320 _ => None,
321 })
322 .collect::<GenericStringArray<T>>();
323
324 Ok(Arc::new(result) as ArrayRef)
325 }
326 other => {
327 exec_err!(
328 "Function TRIM was called with {other} arguments. It requires at least 1 and at most 2."
329 )
330 }
331 }
332}
333
334pub(crate) fn to_lower(args: &[ColumnarValue], name: &str) -> Result<ColumnarValue> {
335 case_conversion(args, |string| string.to_lowercase(), name)
336}
337
338pub(crate) fn to_upper(args: &[ColumnarValue], name: &str) -> Result<ColumnarValue> {
339 case_conversion(args, |string| string.to_uppercase(), name)
340}
341
342fn case_conversion<'a, F>(
343 args: &'a [ColumnarValue],
344 op: F,
345 name: &str,
346) -> Result<ColumnarValue>
347where
348 F: Fn(&'a str) -> String,
349{
350 match &args[0] {
351 ColumnarValue::Array(array) => match array.data_type() {
352 DataType::Utf8 => Ok(ColumnarValue::Array(case_conversion_array::<i32, _>(
353 array, op,
354 )?)),
355 DataType::LargeUtf8 => Ok(ColumnarValue::Array(case_conversion_array::<
356 i64,
357 _,
358 >(array, op)?)),
359 DataType::Utf8View => {
360 let string_array = as_string_view_array(array)?;
361 let mut string_builder = StringBuilder::with_capacity(
362 string_array.len(),
363 string_array.get_array_memory_size(),
364 );
365
366 for str in string_array.iter() {
367 if let Some(str) = str {
368 string_builder.append_value(op(str));
369 } else {
370 string_builder.append_null();
371 }
372 }
373
374 Ok(ColumnarValue::Array(Arc::new(string_builder.finish())))
375 }
376 other => exec_err!("Unsupported data type {other:?} for function {name}"),
377 },
378 ColumnarValue::Scalar(scalar) => match scalar {
379 ScalarValue::Utf8(a) => {
380 let result = a.as_ref().map(|x| op(x));
381 Ok(ColumnarValue::Scalar(ScalarValue::Utf8(result)))
382 }
383 ScalarValue::LargeUtf8(a) => {
384 let result = a.as_ref().map(|x| op(x));
385 Ok(ColumnarValue::Scalar(ScalarValue::LargeUtf8(result)))
386 }
387 ScalarValue::Utf8View(a) => {
388 let result = a.as_ref().map(|x| op(x));
389 Ok(ColumnarValue::Scalar(ScalarValue::Utf8(result)))
390 }
391 other => exec_err!("Unsupported data type {other:?} for function {name}"),
392 },
393 }
394}
395
396fn case_conversion_array<'a, O, F>(array: &'a ArrayRef, op: F) -> Result<ArrayRef>
397where
398 O: OffsetSizeTrait,
399 F: Fn(&'a str) -> String,
400{
401 const PRE_ALLOC_BYTES: usize = 8;
402
403 let string_array = as_generic_string_array::<O>(array)?;
404 let value_data = string_array.value_data();
405
406 if value_data.is_ascii() {
408 return case_conversion_ascii_array::<O, _>(string_array, op);
409 }
410
411 let item_len = string_array.len();
413 let capacity = string_array.value_data().len() + PRE_ALLOC_BYTES;
414 let mut builder = GenericStringBuilder::<O>::with_capacity(item_len, capacity);
415
416 if string_array.null_count() == 0 {
417 let iter =
418 (0..item_len).map(|i| Some(op(unsafe { string_array.value_unchecked(i) })));
419 builder.extend(iter);
420 } else {
421 let iter = string_array.iter().map(|string| string.map(&op));
422 builder.extend(iter);
423 }
424 Ok(Arc::new(builder.finish()))
425}
426
427fn case_conversion_ascii_array<'a, O, F>(
431 string_array: &'a GenericStringArray<O>,
432 op: F,
433) -> Result<ArrayRef>
434where
435 O: OffsetSizeTrait,
436 F: Fn(&'a str) -> String,
437{
438 let value_data = string_array.value_data();
439 let str_values = unsafe { std::str::from_utf8_unchecked(value_data) };
442
443 let converted_values = op(str_values);
445 assert_eq!(converted_values.len(), str_values.len());
446 let bytes = converted_values.into_bytes();
447
448 let values = Buffer::from_vec(bytes);
450 let offsets = string_array.offsets().clone();
451 let nulls = string_array.nulls().cloned();
452 Ok(Arc::new(unsafe {
454 GenericStringArray::<O>::new_unchecked(offsets, values, nulls)
455 }))
456}