datafusion_functions/string/
common.rs1use std::sync::Arc;
21
22use crate::strings::make_and_append_view;
23use arrow::array::{
24 Array, ArrayRef, GenericStringArray, GenericStringBuilder, NullBufferBuilder,
25 OffsetSizeTrait, StringBuilder, StringViewArray, new_null_array,
26};
27use arrow::buffer::{Buffer, ScalarBuffer};
28use arrow::datatypes::DataType;
29use datafusion_common::Result;
30use datafusion_common::cast::{as_generic_string_array, as_string_view_array};
31use datafusion_common::{ScalarValue, exec_err};
32use datafusion_expr::ColumnarValue;
33
/// Strategy for a single trim operation, chosen at compile time through a
/// type parameter rather than by matching on a mode at runtime.
///
/// `trim` returns the trimmed slice of `input` together with the byte offset
/// (relative to the start of `input`) at which that slice begins.
/// Implementations that never remove leading characters report an offset of
/// `0`.
pub(crate) trait Trimmer {
    /// Strips the characters listed in `pattern` from `input` and returns
    /// `(trimmed, start_offset)`.
    fn trim<'a>(input: &'a str, pattern: &[char]) -> (&'a str, u32);
}
42
43pub(crate) struct TrimLeft;
45
46impl Trimmer for TrimLeft {
47 #[inline]
48 fn trim<'a>(input: &'a str, pattern: &[char]) -> (&'a str, u32) {
49 let trimmed = input.trim_start_matches(pattern);
50 let offset = (input.len() - trimmed.len()) as u32;
51 (trimmed, offset)
52 }
53}
54
55pub(crate) struct TrimRight;
57
58impl Trimmer for TrimRight {
59 #[inline]
60 fn trim<'a>(input: &'a str, pattern: &[char]) -> (&'a str, u32) {
61 let trimmed = input.trim_end_matches(pattern);
62 (trimmed, 0)
63 }
64}
65
66pub(crate) struct TrimBoth;
68
69impl Trimmer for TrimBoth {
70 #[inline]
71 fn trim<'a>(input: &'a str, pattern: &[char]) -> (&'a str, u32) {
72 let left_trimmed = input.trim_start_matches(pattern);
73 let offset = (input.len() - left_trimmed.len()) as u32;
74 let trimmed = left_trimmed.trim_end_matches(pattern);
75 (trimmed, offset)
76 }
77}
78
79pub(crate) fn general_trim<T: OffsetSizeTrait, Tr: Trimmer>(
80 args: &[ArrayRef],
81 use_string_view: bool,
82) -> Result<ArrayRef> {
83 if use_string_view {
84 string_view_trim::<Tr>(args)
85 } else {
86 string_trim::<T, Tr>(args)
87 }
88}
89
90fn string_view_trim<Tr: Trimmer>(args: &[ArrayRef]) -> Result<ArrayRef> {
96 let string_view_array = as_string_view_array(&args[0])?;
97 let mut views_buf = Vec::with_capacity(string_view_array.len());
98 let mut null_builder = NullBufferBuilder::new(string_view_array.len());
99
100 match args.len() {
101 1 => {
102 let pattern = [' '];
104 for (src_str_opt, raw_view) in string_view_array
105 .iter()
106 .zip(string_view_array.views().iter())
107 {
108 trim_and_append_view::<Tr>(
109 src_str_opt,
110 &pattern,
111 &mut views_buf,
112 &mut null_builder,
113 raw_view,
114 );
115 }
116 }
117 2 => {
118 let characters_array = as_string_view_array(&args[1])?;
119
120 if characters_array.len() == 1 {
121 if characters_array.is_null(0) {
123 return Ok(new_null_array(
124 &DataType::Utf8View,
125 string_view_array.len(),
126 ));
127 }
128
129 let pattern: Vec<char> = characters_array.value(0).chars().collect();
130 for (src_str_opt, raw_view) in string_view_array
131 .iter()
132 .zip(string_view_array.views().iter())
133 {
134 trim_and_append_view::<Tr>(
135 src_str_opt,
136 &pattern,
137 &mut views_buf,
138 &mut null_builder,
139 raw_view,
140 );
141 }
142 } else {
143 for ((src_str_opt, raw_view), characters_opt) in string_view_array
145 .iter()
146 .zip(string_view_array.views().iter())
147 .zip(characters_array.iter())
148 {
149 if let (Some(src_str), Some(characters)) =
150 (src_str_opt, characters_opt)
151 {
152 let pattern: Vec<char> = characters.chars().collect();
153 let (trimmed, offset) = Tr::trim(src_str, &pattern);
154 make_and_append_view(
155 &mut views_buf,
156 &mut null_builder,
157 raw_view,
158 trimmed,
159 offset,
160 );
161 } else {
162 null_builder.append_null();
163 views_buf.push(0);
164 }
165 }
166 }
167 }
168 other => {
169 return exec_err!(
170 "Function TRIM was called with {other} arguments. It requires at least 1 and at most 2."
171 );
172 }
173 }
174
175 let views_buf = ScalarBuffer::from(views_buf);
176 let nulls_buf = null_builder.finish();
177
178 unsafe {
183 let array = StringViewArray::new_unchecked(
184 views_buf,
185 string_view_array.data_buffers().to_vec(),
186 nulls_buf,
187 );
188 Ok(Arc::new(array) as ArrayRef)
189 }
190}
191
192#[inline]
202fn trim_and_append_view<Tr: Trimmer>(
203 src_str_opt: Option<&str>,
204 pattern: &[char],
205 views_buf: &mut Vec<u128>,
206 null_builder: &mut NullBufferBuilder,
207 original_view: &u128,
208) {
209 if let Some(src_str) = src_str_opt {
210 let (trimmed, offset) = Tr::trim(src_str, pattern);
211 make_and_append_view(views_buf, null_builder, original_view, trimmed, offset);
212 } else {
213 null_builder.append_null();
214 views_buf.push(0);
215 }
216}
217
218fn string_trim<T: OffsetSizeTrait, Tr: Trimmer>(args: &[ArrayRef]) -> Result<ArrayRef> {
224 let string_array = as_generic_string_array::<T>(&args[0])?;
225
226 match args.len() {
227 1 => {
228 let pattern = [' '];
230 let result = string_array
231 .iter()
232 .map(|string| string.map(|s| Tr::trim(s, &pattern).0))
233 .collect::<GenericStringArray<T>>();
234
235 Ok(Arc::new(result) as ArrayRef)
236 }
237 2 => {
238 let characters_array = as_generic_string_array::<T>(&args[1])?;
239
240 if characters_array.len() == 1 {
241 if characters_array.is_null(0) {
243 return Ok(new_null_array(
244 string_array.data_type(),
245 string_array.len(),
246 ));
247 }
248
249 let pattern: Vec<char> = characters_array.value(0).chars().collect();
250 let result = string_array
251 .iter()
252 .map(|item| item.map(|s| Tr::trim(s, &pattern).0))
253 .collect::<GenericStringArray<T>>();
254 return Ok(Arc::new(result) as ArrayRef);
255 }
256
257 let result = string_array
259 .iter()
260 .zip(characters_array.iter())
261 .map(|(string, characters)| match (string, characters) {
262 (Some(s), Some(c)) => {
263 let pattern: Vec<char> = c.chars().collect();
264 Some(Tr::trim(s, &pattern).0)
265 }
266 _ => None,
267 })
268 .collect::<GenericStringArray<T>>();
269
270 Ok(Arc::new(result) as ArrayRef)
271 }
272 other => {
273 exec_err!(
274 "Function TRIM was called with {other} arguments. It requires at least 1 and at most 2."
275 )
276 }
277 }
278}
279
280pub(crate) fn to_lower(args: &[ColumnarValue], name: &str) -> Result<ColumnarValue> {
281 case_conversion(args, |string| string.to_lowercase(), name)
282}
283
284pub(crate) fn to_upper(args: &[ColumnarValue], name: &str) -> Result<ColumnarValue> {
285 case_conversion(args, |string| string.to_uppercase(), name)
286}
287
288fn case_conversion<'a, F>(
289 args: &'a [ColumnarValue],
290 op: F,
291 name: &str,
292) -> Result<ColumnarValue>
293where
294 F: Fn(&'a str) -> String,
295{
296 match &args[0] {
297 ColumnarValue::Array(array) => match array.data_type() {
298 DataType::Utf8 => Ok(ColumnarValue::Array(case_conversion_array::<i32, _>(
299 array, op,
300 )?)),
301 DataType::LargeUtf8 => Ok(ColumnarValue::Array(case_conversion_array::<
302 i64,
303 _,
304 >(array, op)?)),
305 DataType::Utf8View => {
306 let string_array = as_string_view_array(array)?;
307 let mut string_builder = StringBuilder::with_capacity(
308 string_array.len(),
309 string_array.get_array_memory_size(),
310 );
311
312 for str in string_array.iter() {
313 if let Some(str) = str {
314 string_builder.append_value(op(str));
315 } else {
316 string_builder.append_null();
317 }
318 }
319
320 Ok(ColumnarValue::Array(Arc::new(string_builder.finish())))
321 }
322 other => exec_err!("Unsupported data type {other:?} for function {name}"),
323 },
324 ColumnarValue::Scalar(scalar) => match scalar {
325 ScalarValue::Utf8(a) => {
326 let result = a.as_ref().map(|x| op(x));
327 Ok(ColumnarValue::Scalar(ScalarValue::Utf8(result)))
328 }
329 ScalarValue::LargeUtf8(a) => {
330 let result = a.as_ref().map(|x| op(x));
331 Ok(ColumnarValue::Scalar(ScalarValue::LargeUtf8(result)))
332 }
333 ScalarValue::Utf8View(a) => {
334 let result = a.as_ref().map(|x| op(x));
335 Ok(ColumnarValue::Scalar(ScalarValue::Utf8(result)))
336 }
337 other => exec_err!("Unsupported data type {other:?} for function {name}"),
338 },
339 }
340}
341
342fn case_conversion_array<'a, O, F>(array: &'a ArrayRef, op: F) -> Result<ArrayRef>
343where
344 O: OffsetSizeTrait,
345 F: Fn(&'a str) -> String,
346{
347 const PRE_ALLOC_BYTES: usize = 8;
348
349 let string_array = as_generic_string_array::<O>(array)?;
350 let value_data = string_array.value_data();
351
352 if value_data.is_ascii() {
354 return case_conversion_ascii_array::<O, _>(string_array, op);
355 }
356
357 let item_len = string_array.len();
359 let capacity = string_array.value_data().len() + PRE_ALLOC_BYTES;
360 let mut builder = GenericStringBuilder::<O>::with_capacity(item_len, capacity);
361
362 if string_array.null_count() == 0 {
363 let iter =
364 (0..item_len).map(|i| Some(op(unsafe { string_array.value_unchecked(i) })));
365 builder.extend(iter);
366 } else {
367 let iter = string_array.iter().map(|string| string.map(&op));
368 builder.extend(iter);
369 }
370 Ok(Arc::new(builder.finish()))
371}
372
373fn case_conversion_ascii_array<'a, O, F>(
377 string_array: &'a GenericStringArray<O>,
378 op: F,
379) -> Result<ArrayRef>
380where
381 O: OffsetSizeTrait,
382 F: Fn(&'a str) -> String,
383{
384 let value_data = string_array.value_data();
385 let str_values = unsafe { std::str::from_utf8_unchecked(value_data) };
388
389 let converted_values = op(str_values);
391 assert_eq!(converted_values.len(), str_values.len());
392 let bytes = converted_values.into_bytes();
393
394 let values = Buffer::from_vec(bytes);
396 let offsets = string_array.offsets().clone();
397 let nulls = string_array.nulls().cloned();
398 Ok(Arc::new(unsafe {
400 GenericStringArray::<O>::new_unchecked(offsets, values, nulls)
401 }))
402}