datafusion_functions/string/
common.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Common utilities for implementing string functions
19
20use std::sync::Arc;
21
22use crate::strings::make_and_append_view;
23use arrow::array::{
24    Array, ArrayRef, GenericStringArray, GenericStringBuilder, NullBufferBuilder,
25    OffsetSizeTrait, StringBuilder, StringViewArray, new_null_array,
26};
27use arrow::buffer::{Buffer, ScalarBuffer};
28use arrow::datatypes::DataType;
29use datafusion_common::Result;
30use datafusion_common::cast::{as_generic_string_array, as_string_view_array};
31use datafusion_common::{ScalarValue, exec_err};
32use datafusion_expr::ColumnarValue;
33
/// Compile-time strategy for trimming a string against a set of characters.
///
/// Implementors are zero-sized marker types, so the trim direction is picked
/// through generics (monomorphisation) instead of a runtime `match`.
///
/// `trim` returns `(trimmed, start_offset)`. `start_offset` is the number of
/// bytes removed from the front of `input`, i.e. the byte position in `input`
/// at which `trimmed` begins.
pub(crate) trait Trimmer {
    fn trim<'a>(input: &'a str, pattern: &[char]) -> (&'a str, u32);
}

/// Strips leading characters contained in `pattern`.
pub(crate) struct TrimLeft;

impl Trimmer for TrimLeft {
    #[inline]
    fn trim<'a>(input: &'a str, pattern: &[char]) -> (&'a str, u32) {
        let rest = input.trim_start_matches(pattern);
        let removed = input.len() - rest.len();
        // NOTE(review): `as u32` assumes the removed prefix fits in u32. The
        // offset-based trim path discards this value. The string-view path
        // relies on view value lengths being u32-sized.
        (rest, removed as u32)
    }
}

/// Strips trailing characters contained in `pattern`.
///
/// Nothing is removed from the front, so the start offset is always 0.
pub(crate) struct TrimRight;

impl Trimmer for TrimRight {
    #[inline]
    fn trim<'a>(input: &'a str, pattern: &[char]) -> (&'a str, u32) {
        (input.trim_end_matches(pattern), 0)
    }
}

/// Strips both leading and trailing characters contained in `pattern`.
pub(crate) struct TrimBoth;

impl Trimmer for TrimBoth {
    #[inline]
    fn trim<'a>(input: &'a str, pattern: &[char]) -> (&'a str, u32) {
        // The start offset is determined entirely by the leading trim.
        let (without_prefix, start) = TrimLeft::trim(input, pattern);
        (without_prefix.trim_end_matches(pattern), start)
    }
}
78
79pub(crate) fn general_trim<T: OffsetSizeTrait, Tr: Trimmer>(
80    args: &[ArrayRef],
81    use_string_view: bool,
82) -> Result<ArrayRef> {
83    if use_string_view {
84        string_view_trim::<Tr>(args)
85    } else {
86        string_trim::<T, Tr>(args)
87    }
88}
89
90/// Applies the trim function to the given string view array(s)
91/// and returns a new string view array with the trimmed values.
92///
93/// Pre-computes the pattern characters once for scalar patterns to avoid
94/// repeated allocations per row.
95fn string_view_trim<Tr: Trimmer>(args: &[ArrayRef]) -> Result<ArrayRef> {
96    let string_view_array = as_string_view_array(&args[0])?;
97    let mut views_buf = Vec::with_capacity(string_view_array.len());
98    let mut null_builder = NullBufferBuilder::new(string_view_array.len());
99
100    match args.len() {
101        1 => {
102            // Default whitespace trim - pattern is just space
103            let pattern = [' '];
104            for (src_str_opt, raw_view) in string_view_array
105                .iter()
106                .zip(string_view_array.views().iter())
107            {
108                trim_and_append_view::<Tr>(
109                    src_str_opt,
110                    &pattern,
111                    &mut views_buf,
112                    &mut null_builder,
113                    raw_view,
114                );
115            }
116        }
117        2 => {
118            let characters_array = as_string_view_array(&args[1])?;
119
120            if characters_array.len() == 1 {
121                // Scalar pattern - pre-compute pattern chars once
122                if characters_array.is_null(0) {
123                    return Ok(new_null_array(
124                        &DataType::Utf8View,
125                        string_view_array.len(),
126                    ));
127                }
128
129                let pattern: Vec<char> = characters_array.value(0).chars().collect();
130                for (src_str_opt, raw_view) in string_view_array
131                    .iter()
132                    .zip(string_view_array.views().iter())
133                {
134                    trim_and_append_view::<Tr>(
135                        src_str_opt,
136                        &pattern,
137                        &mut views_buf,
138                        &mut null_builder,
139                        raw_view,
140                    );
141                }
142            } else {
143                // Per-row pattern - must compute pattern chars for each row
144                for ((src_str_opt, raw_view), characters_opt) in string_view_array
145                    .iter()
146                    .zip(string_view_array.views().iter())
147                    .zip(characters_array.iter())
148                {
149                    if let (Some(src_str), Some(characters)) =
150                        (src_str_opt, characters_opt)
151                    {
152                        let pattern: Vec<char> = characters.chars().collect();
153                        let (trimmed, offset) = Tr::trim(src_str, &pattern);
154                        make_and_append_view(
155                            &mut views_buf,
156                            &mut null_builder,
157                            raw_view,
158                            trimmed,
159                            offset,
160                        );
161                    } else {
162                        null_builder.append_null();
163                        views_buf.push(0);
164                    }
165                }
166            }
167        }
168        other => {
169            return exec_err!(
170                "Function TRIM was called with {other} arguments. It requires at least 1 and at most 2."
171            );
172        }
173    }
174
175    let views_buf = ScalarBuffer::from(views_buf);
176    let nulls_buf = null_builder.finish();
177
178    // Safety:
179    // (1) The blocks of the given views are all provided
180    // (2) Each of the range `view.offset+start..end` of view in views_buf is within
181    // the bounds of each of the blocks
182    unsafe {
183        let array = StringViewArray::new_unchecked(
184            views_buf,
185            string_view_array.data_buffers().to_vec(),
186            nulls_buf,
187        );
188        Ok(Arc::new(array) as ArrayRef)
189    }
190}
191
192/// Trims the given string and appends the trimmed string to the views buffer
193/// and the null buffer.
194///
195/// Arguments
196/// - `src_str_opt`: The original string value (represented by the view)
197/// - `pattern`: Pre-computed character pattern to trim
198/// - `views_buf`: The buffer to append the updated views to
199/// - `null_builder`: The buffer to append the null values to
200/// - `original_view`: The original view value (that contains src_str_opt)
201#[inline]
202fn trim_and_append_view<Tr: Trimmer>(
203    src_str_opt: Option<&str>,
204    pattern: &[char],
205    views_buf: &mut Vec<u128>,
206    null_builder: &mut NullBufferBuilder,
207    original_view: &u128,
208) {
209    if let Some(src_str) = src_str_opt {
210        let (trimmed, offset) = Tr::trim(src_str, pattern);
211        make_and_append_view(views_buf, null_builder, original_view, trimmed, offset);
212    } else {
213        null_builder.append_null();
214        views_buf.push(0);
215    }
216}
217
218/// Applies the trim function to the given string array(s)
219/// and returns a new string array with the trimmed values.
220///
221/// Pre-computes the pattern characters once for scalar patterns to avoid
222/// repeated allocations per row.
223fn string_trim<T: OffsetSizeTrait, Tr: Trimmer>(args: &[ArrayRef]) -> Result<ArrayRef> {
224    let string_array = as_generic_string_array::<T>(&args[0])?;
225
226    match args.len() {
227        1 => {
228            // Default whitespace trim - pattern is just space
229            let pattern = [' '];
230            let result = string_array
231                .iter()
232                .map(|string| string.map(|s| Tr::trim(s, &pattern).0))
233                .collect::<GenericStringArray<T>>();
234
235            Ok(Arc::new(result) as ArrayRef)
236        }
237        2 => {
238            let characters_array = as_generic_string_array::<T>(&args[1])?;
239
240            if characters_array.len() == 1 {
241                // Scalar pattern - pre-compute pattern chars once
242                if characters_array.is_null(0) {
243                    return Ok(new_null_array(
244                        string_array.data_type(),
245                        string_array.len(),
246                    ));
247                }
248
249                let pattern: Vec<char> = characters_array.value(0).chars().collect();
250                let result = string_array
251                    .iter()
252                    .map(|item| item.map(|s| Tr::trim(s, &pattern).0))
253                    .collect::<GenericStringArray<T>>();
254                return Ok(Arc::new(result) as ArrayRef);
255            }
256
257            // Per-row pattern - must compute pattern chars for each row
258            let result = string_array
259                .iter()
260                .zip(characters_array.iter())
261                .map(|(string, characters)| match (string, characters) {
262                    (Some(s), Some(c)) => {
263                        let pattern: Vec<char> = c.chars().collect();
264                        Some(Tr::trim(s, &pattern).0)
265                    }
266                    _ => None,
267                })
268                .collect::<GenericStringArray<T>>();
269
270            Ok(Arc::new(result) as ArrayRef)
271        }
272        other => {
273            exec_err!(
274                "Function TRIM was called with {other} arguments. It requires at least 1 and at most 2."
275            )
276        }
277    }
278}
279
280pub(crate) fn to_lower(args: &[ColumnarValue], name: &str) -> Result<ColumnarValue> {
281    case_conversion(args, |string| string.to_lowercase(), name)
282}
283
284pub(crate) fn to_upper(args: &[ColumnarValue], name: &str) -> Result<ColumnarValue> {
285    case_conversion(args, |string| string.to_uppercase(), name)
286}
287
288fn case_conversion<'a, F>(
289    args: &'a [ColumnarValue],
290    op: F,
291    name: &str,
292) -> Result<ColumnarValue>
293where
294    F: Fn(&'a str) -> String,
295{
296    match &args[0] {
297        ColumnarValue::Array(array) => match array.data_type() {
298            DataType::Utf8 => Ok(ColumnarValue::Array(case_conversion_array::<i32, _>(
299                array, op,
300            )?)),
301            DataType::LargeUtf8 => Ok(ColumnarValue::Array(case_conversion_array::<
302                i64,
303                _,
304            >(array, op)?)),
305            DataType::Utf8View => {
306                let string_array = as_string_view_array(array)?;
307                let mut string_builder = StringBuilder::with_capacity(
308                    string_array.len(),
309                    string_array.get_array_memory_size(),
310                );
311
312                for str in string_array.iter() {
313                    if let Some(str) = str {
314                        string_builder.append_value(op(str));
315                    } else {
316                        string_builder.append_null();
317                    }
318                }
319
320                Ok(ColumnarValue::Array(Arc::new(string_builder.finish())))
321            }
322            other => exec_err!("Unsupported data type {other:?} for function {name}"),
323        },
324        ColumnarValue::Scalar(scalar) => match scalar {
325            ScalarValue::Utf8(a) => {
326                let result = a.as_ref().map(|x| op(x));
327                Ok(ColumnarValue::Scalar(ScalarValue::Utf8(result)))
328            }
329            ScalarValue::LargeUtf8(a) => {
330                let result = a.as_ref().map(|x| op(x));
331                Ok(ColumnarValue::Scalar(ScalarValue::LargeUtf8(result)))
332            }
333            ScalarValue::Utf8View(a) => {
334                let result = a.as_ref().map(|x| op(x));
335                Ok(ColumnarValue::Scalar(ScalarValue::Utf8(result)))
336            }
337            other => exec_err!("Unsupported data type {other:?} for function {name}"),
338        },
339    }
340}
341
342fn case_conversion_array<'a, O, F>(array: &'a ArrayRef, op: F) -> Result<ArrayRef>
343where
344    O: OffsetSizeTrait,
345    F: Fn(&'a str) -> String,
346{
347    const PRE_ALLOC_BYTES: usize = 8;
348
349    let string_array = as_generic_string_array::<O>(array)?;
350    let value_data = string_array.value_data();
351
352    // All values are ASCII.
353    if value_data.is_ascii() {
354        return case_conversion_ascii_array::<O, _>(string_array, op);
355    }
356
357    // Values contain non-ASCII.
358    let item_len = string_array.len();
359    let capacity = string_array.value_data().len() + PRE_ALLOC_BYTES;
360    let mut builder = GenericStringBuilder::<O>::with_capacity(item_len, capacity);
361
362    if string_array.null_count() == 0 {
363        let iter =
364            (0..item_len).map(|i| Some(op(unsafe { string_array.value_unchecked(i) })));
365        builder.extend(iter);
366    } else {
367        let iter = string_array.iter().map(|string| string.map(&op));
368        builder.extend(iter);
369    }
370    Ok(Arc::new(builder.finish()))
371}
372
373/// All values of string_array are ASCII, and when converting case, there is no changes in the byte
374/// array length. Therefore, the StringArray can be treated as a complete ASCII string for
375/// case conversion, and we can reuse the offsets buffer and the nulls buffer.
376fn case_conversion_ascii_array<'a, O, F>(
377    string_array: &'a GenericStringArray<O>,
378    op: F,
379) -> Result<ArrayRef>
380where
381    O: OffsetSizeTrait,
382    F: Fn(&'a str) -> String,
383{
384    let value_data = string_array.value_data();
385    // SAFETY: all items stored in value_data satisfy UTF8.
386    // ref: impl ByteArrayNativeType for str {...}
387    let str_values = unsafe { std::str::from_utf8_unchecked(value_data) };
388
389    // conversion
390    let converted_values = op(str_values);
391    assert_eq!(converted_values.len(), str_values.len());
392    let bytes = converted_values.into_bytes();
393
394    // build result
395    let values = Buffer::from_vec(bytes);
396    let offsets = string_array.offsets().clone();
397    let nulls = string_array.nulls().cloned();
398    // SAFETY: offsets and nulls are consistent with the input array.
399    Ok(Arc::new(unsafe {
400        GenericStringArray::<O>::new_unchecked(offsets, values, nulls)
401    }))
402}