Skip to main content

datafusion_functions/string/
common.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Common utilities for implementing string functions
19
20use std::sync::Arc;
21
22use crate::strings::make_and_append_view;
23use arrow::array::{
24    Array, ArrayRef, GenericStringArray, GenericStringBuilder, NullBufferBuilder,
25    OffsetSizeTrait, StringBuilder, StringViewArray, new_null_array,
26};
27use arrow::buffer::{Buffer, ScalarBuffer};
28use arrow::datatypes::DataType;
29use datafusion_common::Result;
30use datafusion_common::cast::{as_generic_string_array, as_string_view_array};
31use datafusion_common::{ScalarValue, exec_err};
32use datafusion_expr::ColumnarValue;
33
/// Strategy trait for the trim variants (left, right, both).
///
/// The variant is selected through a type parameter, so the choice is made
/// at compile time rather than by matching on a runtime value for each row.
///
/// Both methods return a `(trimmed, start_offset)` pair: `trimmed` is a
/// sub-slice of `input`, and `start_offset` is the byte offset from the
/// beginning of `input` at which `trimmed` starts.
pub(crate) trait Trimmer {
    /// Trims every character contained in `pattern` from `input`.
    fn trim<'a>(input: &'a str, pattern: &[char]) -> (&'a str, u32);

    /// Fast path for trimming one single ASCII byte.
    ///
    /// Works on the raw bytes of `input` instead of iterating over chars.
    fn trim_ascii_char(input: &str, byte: u8) -> (&str, u32);
}
46
/// Counts how many bytes at the start of `bytes` are equal to `byte`.
///
/// Returns `bytes.len()` when every byte matches (including the empty slice).
#[inline]
fn leading_bytes(bytes: &[u8], byte: u8) -> usize {
    // The index of the first non-matching byte is the length of the matching run.
    bytes
        .iter()
        .position(|&b| b != byte)
        .unwrap_or(bytes.len())
}
52
/// Counts how many bytes at the end of `bytes` are equal to `byte`.
///
/// Returns `bytes.len()` when every byte matches (including the empty slice).
#[inline]
fn trailing_bytes(bytes: &[u8], byte: u8) -> usize {
    // Scanning from the back, the position of the first non-matching byte is
    // the length of the matching suffix.
    bytes
        .iter()
        .rev()
        .position(|&b| b != byte)
        .unwrap_or(bytes.len())
}
58
59/// Left trim - removes leading characters
60pub(crate) struct TrimLeft;
61
62impl Trimmer for TrimLeft {
63    #[inline]
64    fn trim<'a>(input: &'a str, pattern: &[char]) -> (&'a str, u32) {
65        if pattern.len() == 1 && pattern[0].is_ascii() {
66            return Self::trim_ascii_char(input, pattern[0] as u8);
67        }
68        let trimmed = input.trim_start_matches(pattern);
69        let offset = (input.len() - trimmed.len()) as u32;
70        (trimmed, offset)
71    }
72
73    #[inline]
74    fn trim_ascii_char(input: &str, byte: u8) -> (&str, u32) {
75        let start = leading_bytes(input.as_bytes(), byte);
76        (&input[start..], start as u32)
77    }
78}
79
80/// Right trim - removes trailing characters
81pub(crate) struct TrimRight;
82
83impl Trimmer for TrimRight {
84    #[inline]
85    fn trim<'a>(input: &'a str, pattern: &[char]) -> (&'a str, u32) {
86        if pattern.len() == 1 && pattern[0].is_ascii() {
87            return Self::trim_ascii_char(input, pattern[0] as u8);
88        }
89        let trimmed = input.trim_end_matches(pattern);
90        (trimmed, 0)
91    }
92
93    #[inline]
94    fn trim_ascii_char(input: &str, byte: u8) -> (&str, u32) {
95        let bytes = input.as_bytes();
96        let end = bytes.len() - trailing_bytes(bytes, byte);
97        (&input[..end], 0)
98    }
99}
100
101/// Both trim - removes both leading and trailing characters
102pub(crate) struct TrimBoth;
103
104impl Trimmer for TrimBoth {
105    #[inline]
106    fn trim<'a>(input: &'a str, pattern: &[char]) -> (&'a str, u32) {
107        if pattern.len() == 1 && pattern[0].is_ascii() {
108            return Self::trim_ascii_char(input, pattern[0] as u8);
109        }
110        let left_trimmed = input.trim_start_matches(pattern);
111        let offset = (input.len() - left_trimmed.len()) as u32;
112        let trimmed = left_trimmed.trim_end_matches(pattern);
113        (trimmed, offset)
114    }
115
116    #[inline]
117    fn trim_ascii_char(input: &str, byte: u8) -> (&str, u32) {
118        let bytes = input.as_bytes();
119        let start = leading_bytes(bytes, byte);
120        let end = bytes.len() - trailing_bytes(&bytes[start..], byte);
121        (&input[start..end], start as u32)
122    }
123}
124
125pub(crate) fn general_trim<T: OffsetSizeTrait, Tr: Trimmer>(
126    args: &[ArrayRef],
127    use_string_view: bool,
128) -> Result<ArrayRef> {
129    if use_string_view {
130        string_view_trim::<Tr>(args)
131    } else {
132        string_trim::<T, Tr>(args)
133    }
134}
135
136/// Applies the trim function to the given string view array(s)
137/// and returns a new string view array with the trimmed values.
138///
139/// Pre-computes the pattern characters once for scalar patterns to avoid
140/// repeated allocations per row.
141fn string_view_trim<Tr: Trimmer>(args: &[ArrayRef]) -> Result<ArrayRef> {
142    let string_view_array = as_string_view_array(&args[0])?;
143    let mut views_buf = Vec::with_capacity(string_view_array.len());
144    let mut null_builder = NullBufferBuilder::new(string_view_array.len());
145
146    match args.len() {
147        1 => {
148            // Trim spaces by default
149            for (src_str_opt, raw_view) in string_view_array
150                .iter()
151                .zip(string_view_array.views().iter())
152            {
153                if let Some(src_str) = src_str_opt {
154                    let (trimmed, offset) = Tr::trim_ascii_char(src_str, b' ');
155                    make_and_append_view(
156                        &mut views_buf,
157                        &mut null_builder,
158                        raw_view,
159                        trimmed,
160                        offset,
161                    );
162                } else {
163                    null_builder.append_null();
164                    views_buf.push(0);
165                }
166            }
167        }
168        2 => {
169            let characters_array = as_string_view_array(&args[1])?;
170
171            if characters_array.len() == 1 {
172                // Scalar pattern - pre-compute pattern chars once
173                if characters_array.is_null(0) {
174                    return Ok(new_null_array(
175                        &DataType::Utf8View,
176                        string_view_array.len(),
177                    ));
178                }
179
180                let pattern: Vec<char> = characters_array.value(0).chars().collect();
181                for (src_str_opt, raw_view) in string_view_array
182                    .iter()
183                    .zip(string_view_array.views().iter())
184                {
185                    trim_and_append_view::<Tr>(
186                        src_str_opt,
187                        &pattern,
188                        &mut views_buf,
189                        &mut null_builder,
190                        raw_view,
191                    );
192                }
193            } else {
194                // Per-row pattern - must compute pattern chars for each row
195                let mut pattern: Vec<char> = Vec::new();
196                for ((src_str_opt, raw_view), characters_opt) in string_view_array
197                    .iter()
198                    .zip(string_view_array.views().iter())
199                    .zip(characters_array.iter())
200                {
201                    if let (Some(src_str), Some(characters)) =
202                        (src_str_opt, characters_opt)
203                    {
204                        pattern.clear();
205                        pattern.extend(characters.chars());
206                        let (trimmed, offset) = Tr::trim(src_str, &pattern);
207                        make_and_append_view(
208                            &mut views_buf,
209                            &mut null_builder,
210                            raw_view,
211                            trimmed,
212                            offset,
213                        );
214                    } else {
215                        null_builder.append_null();
216                        views_buf.push(0);
217                    }
218                }
219            }
220        }
221        other => {
222            return exec_err!(
223                "Function TRIM was called with {other} arguments. It requires at least 1 and at most 2."
224            );
225        }
226    }
227
228    let views_buf = ScalarBuffer::from(views_buf);
229    let nulls_buf = null_builder.finish();
230
231    // Safety:
232    // (1) The blocks of the given views are all provided
233    // (2) Each of the range `view.offset+start..end` of view in views_buf is within
234    // the bounds of each of the blocks
235    unsafe {
236        let array = StringViewArray::new_unchecked(
237            views_buf,
238            string_view_array.data_buffers().to_vec(),
239            nulls_buf,
240        );
241        Ok(Arc::new(array) as ArrayRef)
242    }
243}
244
245/// Trims the given string and appends the trimmed string to the views buffer
246/// and the null buffer.
247///
248/// Arguments
249/// - `src_str_opt`: The original string value (represented by the view)
250/// - `pattern`: Pre-computed character pattern to trim
251/// - `views_buf`: The buffer to append the updated views to
252/// - `null_builder`: The buffer to append the null values to
253/// - `original_view`: The original view value (that contains src_str_opt)
254#[inline]
255fn trim_and_append_view<Tr: Trimmer>(
256    src_str_opt: Option<&str>,
257    pattern: &[char],
258    views_buf: &mut Vec<u128>,
259    null_builder: &mut NullBufferBuilder,
260    original_view: &u128,
261) {
262    if let Some(src_str) = src_str_opt {
263        let (trimmed, offset) = Tr::trim(src_str, pattern);
264        make_and_append_view(views_buf, null_builder, original_view, trimmed, offset);
265    } else {
266        null_builder.append_null();
267        views_buf.push(0);
268    }
269}
270
271/// Applies the trim function to the given string array(s)
272/// and returns a new string array with the trimmed values.
273///
274/// Pre-computes the pattern characters once for scalar patterns to avoid
275/// repeated allocations per row.
276fn string_trim<T: OffsetSizeTrait, Tr: Trimmer>(args: &[ArrayRef]) -> Result<ArrayRef> {
277    let string_array = as_generic_string_array::<T>(&args[0])?;
278
279    match args.len() {
280        1 => {
281            // Trim spaces by default
282            let result = string_array
283                .iter()
284                .map(|string| string.map(|s| Tr::trim_ascii_char(s, b' ').0))
285                .collect::<GenericStringArray<T>>();
286
287            Ok(Arc::new(result) as ArrayRef)
288        }
289        2 => {
290            let characters_array = as_generic_string_array::<T>(&args[1])?;
291
292            if characters_array.len() == 1 {
293                // Scalar pattern - pre-compute pattern chars once
294                if characters_array.is_null(0) {
295                    return Ok(new_null_array(
296                        string_array.data_type(),
297                        string_array.len(),
298                    ));
299                }
300
301                let pattern: Vec<char> = characters_array.value(0).chars().collect();
302                let result = string_array
303                    .iter()
304                    .map(|item| item.map(|s| Tr::trim(s, &pattern).0))
305                    .collect::<GenericStringArray<T>>();
306                return Ok(Arc::new(result) as ArrayRef);
307            }
308
309            // Per-row pattern - must compute pattern chars for each row
310            let mut pattern: Vec<char> = Vec::new();
311            let result = string_array
312                .iter()
313                .zip(characters_array.iter())
314                .map(|(string, characters)| match (string, characters) {
315                    (Some(s), Some(c)) => {
316                        pattern.clear();
317                        pattern.extend(c.chars());
318                        Some(Tr::trim(s, &pattern).0)
319                    }
320                    _ => None,
321                })
322                .collect::<GenericStringArray<T>>();
323
324            Ok(Arc::new(result) as ArrayRef)
325        }
326        other => {
327            exec_err!(
328                "Function TRIM was called with {other} arguments. It requires at least 1 and at most 2."
329            )
330        }
331    }
332}
333
334pub(crate) fn to_lower(args: &[ColumnarValue], name: &str) -> Result<ColumnarValue> {
335    case_conversion(args, |string| string.to_lowercase(), name)
336}
337
338pub(crate) fn to_upper(args: &[ColumnarValue], name: &str) -> Result<ColumnarValue> {
339    case_conversion(args, |string| string.to_uppercase(), name)
340}
341
342fn case_conversion<'a, F>(
343    args: &'a [ColumnarValue],
344    op: F,
345    name: &str,
346) -> Result<ColumnarValue>
347where
348    F: Fn(&'a str) -> String,
349{
350    match &args[0] {
351        ColumnarValue::Array(array) => match array.data_type() {
352            DataType::Utf8 => Ok(ColumnarValue::Array(case_conversion_array::<i32, _>(
353                array, op,
354            )?)),
355            DataType::LargeUtf8 => Ok(ColumnarValue::Array(case_conversion_array::<
356                i64,
357                _,
358            >(array, op)?)),
359            DataType::Utf8View => {
360                let string_array = as_string_view_array(array)?;
361                let mut string_builder = StringBuilder::with_capacity(
362                    string_array.len(),
363                    string_array.get_array_memory_size(),
364                );
365
366                for str in string_array.iter() {
367                    if let Some(str) = str {
368                        string_builder.append_value(op(str));
369                    } else {
370                        string_builder.append_null();
371                    }
372                }
373
374                Ok(ColumnarValue::Array(Arc::new(string_builder.finish())))
375            }
376            other => exec_err!("Unsupported data type {other:?} for function {name}"),
377        },
378        ColumnarValue::Scalar(scalar) => match scalar {
379            ScalarValue::Utf8(a) => {
380                let result = a.as_ref().map(|x| op(x));
381                Ok(ColumnarValue::Scalar(ScalarValue::Utf8(result)))
382            }
383            ScalarValue::LargeUtf8(a) => {
384                let result = a.as_ref().map(|x| op(x));
385                Ok(ColumnarValue::Scalar(ScalarValue::LargeUtf8(result)))
386            }
387            ScalarValue::Utf8View(a) => {
388                let result = a.as_ref().map(|x| op(x));
389                Ok(ColumnarValue::Scalar(ScalarValue::Utf8(result)))
390            }
391            other => exec_err!("Unsupported data type {other:?} for function {name}"),
392        },
393    }
394}
395
396fn case_conversion_array<'a, O, F>(array: &'a ArrayRef, op: F) -> Result<ArrayRef>
397where
398    O: OffsetSizeTrait,
399    F: Fn(&'a str) -> String,
400{
401    const PRE_ALLOC_BYTES: usize = 8;
402
403    let string_array = as_generic_string_array::<O>(array)?;
404    let value_data = string_array.value_data();
405
406    // All values are ASCII.
407    if value_data.is_ascii() {
408        return case_conversion_ascii_array::<O, _>(string_array, op);
409    }
410
411    // Values contain non-ASCII.
412    let item_len = string_array.len();
413    let capacity = string_array.value_data().len() + PRE_ALLOC_BYTES;
414    let mut builder = GenericStringBuilder::<O>::with_capacity(item_len, capacity);
415
416    if string_array.null_count() == 0 {
417        let iter =
418            (0..item_len).map(|i| Some(op(unsafe { string_array.value_unchecked(i) })));
419        builder.extend(iter);
420    } else {
421        let iter = string_array.iter().map(|string| string.map(&op));
422        builder.extend(iter);
423    }
424    Ok(Arc::new(builder.finish()))
425}
426
427/// All values of string_array are ASCII, and when converting case, there is no changes in the byte
428/// array length. Therefore, the StringArray can be treated as a complete ASCII string for
429/// case conversion, and we can reuse the offsets buffer and the nulls buffer.
430fn case_conversion_ascii_array<'a, O, F>(
431    string_array: &'a GenericStringArray<O>,
432    op: F,
433) -> Result<ArrayRef>
434where
435    O: OffsetSizeTrait,
436    F: Fn(&'a str) -> String,
437{
438    let value_data = string_array.value_data();
439    // SAFETY: all items stored in value_data satisfy UTF8.
440    // ref: impl ByteArrayNativeType for str {...}
441    let str_values = unsafe { std::str::from_utf8_unchecked(value_data) };
442
443    // conversion
444    let converted_values = op(str_values);
445    assert_eq!(converted_values.len(), str_values.len());
446    let bytes = converted_values.into_bytes();
447
448    // build result
449    let values = Buffer::from_vec(bytes);
450    let offsets = string_array.offsets().clone();
451    let nulls = string_array.nulls().cloned();
452    // SAFETY: offsets and nulls are consistent with the input array.
453    Ok(Arc::new(unsafe {
454        GenericStringArray::<O>::new_unchecked(offsets, values, nulls)
455    }))
456}