Skip to main content

datafusion_functions/string/
common.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Common utilities for implementing string functions
19
20use std::sync::Arc;
21
22use crate::strings::{
23    GenericStringArrayBuilder, STRING_VIEW_INIT_BLOCK_SIZE, STRING_VIEW_MAX_BLOCK_SIZE,
24    StringViewArrayBuilder, append_view,
25};
26use arrow::array::{
27    Array, ArrayRef, GenericStringArray, NullBufferBuilder, OffsetSizeTrait,
28    StringViewArray, new_null_array,
29};
30use arrow::buffer::{Buffer, OffsetBuffer, ScalarBuffer};
31use arrow::datatypes::DataType;
32use datafusion_common::Result;
33use datafusion_common::cast::{as_generic_string_array, as_string_view_array};
34use datafusion_common::{ScalarValue, exec_err};
35use datafusion_expr::ColumnarValue;
36
/// Compile-time strategy for a trim operation.
///
/// The concrete trimmer is chosen through a generic parameter, so the trim
/// direction is resolved statically instead of being matched at runtime.
///
/// Both methods return `(trimmed_str, start_offset)`: `trimmed_str` is a
/// sub-slice of `input`, and `start_offset` is the byte offset from the
/// beginning of `input` at which that sub-slice starts.
pub(crate) trait Trimmer {
    /// Trims the characters listed in `pattern` from `input`.
    fn trim<'a>(input: &'a str, pattern: &[char]) -> (&'a str, u32);

    /// Trim specialised for a single ASCII byte.
    ///
    /// Scans bytes directly instead of iterating over `char`s.
    fn trim_ascii_char(input: &str, byte: u8) -> (&str, u32);
}
49
/// Counts how many bytes at the front of `bytes` are equal to `byte`.
#[inline]
fn leading_bytes(bytes: &[u8], byte: u8) -> usize {
    // The index of the first non-matching byte equals the length of the
    // matching prefix; if every byte matches, the whole slice is the prefix.
    bytes
        .iter()
        .position(|&candidate| candidate != byte)
        .unwrap_or(bytes.len())
}
55
/// Counts how many bytes at the back of `bytes` are equal to `byte`.
#[inline]
fn trailing_bytes(bytes: &[u8], byte: u8) -> usize {
    match bytes.iter().rposition(|&candidate| candidate != byte) {
        // Everything after the last non-matching byte is the matching suffix.
        Some(last_other) => bytes.len() - 1 - last_other,
        // No non-matching byte at all: the whole slice matches.
        None => bytes.len(),
    }
}
61
62/// Left trim - removes leading characters
63pub(crate) struct TrimLeft;
64
65impl Trimmer for TrimLeft {
66    #[inline]
67    fn trim<'a>(input: &'a str, pattern: &[char]) -> (&'a str, u32) {
68        if pattern.len() == 1 && pattern[0].is_ascii() {
69            return Self::trim_ascii_char(input, pattern[0] as u8);
70        }
71        let trimmed = input.trim_start_matches(pattern);
72        let offset = (input.len() - trimmed.len()) as u32;
73        (trimmed, offset)
74    }
75
76    #[inline]
77    fn trim_ascii_char(input: &str, byte: u8) -> (&str, u32) {
78        let start = leading_bytes(input.as_bytes(), byte);
79        (&input[start..], start as u32)
80    }
81}
82
83/// Right trim - removes trailing characters
84pub(crate) struct TrimRight;
85
86impl Trimmer for TrimRight {
87    #[inline]
88    fn trim<'a>(input: &'a str, pattern: &[char]) -> (&'a str, u32) {
89        if pattern.len() == 1 && pattern[0].is_ascii() {
90            return Self::trim_ascii_char(input, pattern[0] as u8);
91        }
92        let trimmed = input.trim_end_matches(pattern);
93        (trimmed, 0)
94    }
95
96    #[inline]
97    fn trim_ascii_char(input: &str, byte: u8) -> (&str, u32) {
98        let bytes = input.as_bytes();
99        let end = bytes.len() - trailing_bytes(bytes, byte);
100        (&input[..end], 0)
101    }
102}
103
104/// Both trim - removes both leading and trailing characters
105pub(crate) struct TrimBoth;
106
107impl Trimmer for TrimBoth {
108    #[inline]
109    fn trim<'a>(input: &'a str, pattern: &[char]) -> (&'a str, u32) {
110        if pattern.len() == 1 && pattern[0].is_ascii() {
111            return Self::trim_ascii_char(input, pattern[0] as u8);
112        }
113        let left_trimmed = input.trim_start_matches(pattern);
114        let offset = (input.len() - left_trimmed.len()) as u32;
115        let trimmed = left_trimmed.trim_end_matches(pattern);
116        (trimmed, offset)
117    }
118
119    #[inline]
120    fn trim_ascii_char(input: &str, byte: u8) -> (&str, u32) {
121        let bytes = input.as_bytes();
122        let start = leading_bytes(bytes, byte);
123        let end = bytes.len() - trailing_bytes(&bytes[start..], byte);
124        (&input[start..end], start as u32)
125    }
126}
127
128pub(crate) fn general_trim<T: OffsetSizeTrait, Tr: Trimmer>(
129    args: &[ArrayRef],
130    use_string_view: bool,
131) -> Result<ArrayRef> {
132    if use_string_view {
133        string_view_trim::<Tr>(args)
134    } else {
135        string_trim::<T, Tr>(args)
136    }
137}
138
139/// Applies the trim function to the given string view array(s)
140/// and returns a new string view array with the trimmed values.
141///
142/// Pre-computes the pattern characters once for scalar patterns to avoid
143/// repeated allocations per row.
144fn string_view_trim<Tr: Trimmer>(args: &[ArrayRef]) -> Result<ArrayRef> {
145    let string_view_array = as_string_view_array(&args[0])?;
146    let mut views_buf = Vec::with_capacity(string_view_array.len());
147    let mut null_builder = NullBufferBuilder::new(string_view_array.len());
148
149    match args.len() {
150        1 => {
151            // Trim spaces by default
152            for (src_str_opt, raw_view) in string_view_array
153                .iter()
154                .zip(string_view_array.views().iter())
155            {
156                if let Some(src_str) = src_str_opt {
157                    let (trimmed, offset) = Tr::trim_ascii_char(src_str, b' ');
158                    append_view(&mut views_buf, raw_view, trimmed, offset);
159                    null_builder.append_non_null();
160                } else {
161                    null_builder.append_null();
162                    views_buf.push(0);
163                }
164            }
165        }
166        2 => {
167            let characters_array = as_string_view_array(&args[1])?;
168
169            if characters_array.len() == 1 {
170                // Scalar pattern - pre-compute pattern chars once
171                if characters_array.is_null(0) {
172                    return Ok(new_null_array(
173                        &DataType::Utf8View,
174                        string_view_array.len(),
175                    ));
176                }
177
178                let pattern: Vec<char> = characters_array.value(0).chars().collect();
179                for (src_str_opt, raw_view) in string_view_array
180                    .iter()
181                    .zip(string_view_array.views().iter())
182                {
183                    trim_and_append_view::<Tr>(
184                        src_str_opt,
185                        &pattern,
186                        &mut views_buf,
187                        &mut null_builder,
188                        raw_view,
189                    );
190                }
191            } else {
192                // Per-row pattern - must compute pattern chars for each row
193                let mut pattern: Vec<char> = Vec::new();
194                for ((src_str_opt, raw_view), characters_opt) in string_view_array
195                    .iter()
196                    .zip(string_view_array.views().iter())
197                    .zip(characters_array.iter())
198                {
199                    if let (Some(src_str), Some(characters)) =
200                        (src_str_opt, characters_opt)
201                    {
202                        pattern.clear();
203                        pattern.extend(characters.chars());
204                        let (trimmed, offset) = Tr::trim(src_str, &pattern);
205                        append_view(&mut views_buf, raw_view, trimmed, offset);
206                        null_builder.append_non_null();
207                    } else {
208                        null_builder.append_null();
209                        views_buf.push(0);
210                    }
211                }
212            }
213        }
214        other => {
215            return exec_err!(
216                "Function TRIM was called with {other} arguments. It requires at least 1 and at most 2."
217            );
218        }
219    }
220
221    let views_buf = ScalarBuffer::from(views_buf);
222    let nulls_buf = null_builder.finish();
223
224    // Safety:
225    // (1) The blocks of the given views are all provided
226    // (2) Each of the range `view.offset+start..end` of view in views_buf is within
227    // the bounds of each of the blocks
228    unsafe {
229        let array = StringViewArray::new_unchecked(
230            views_buf,
231            string_view_array.data_buffers().to_vec(),
232            nulls_buf,
233        );
234        Ok(Arc::new(array) as ArrayRef)
235    }
236}
237
238/// Trims the given string and appends the trimmed string to the views buffer
239/// and the null buffer.
240///
241/// Arguments
242/// - `src_str_opt`: The original string value (represented by the view)
243/// - `pattern`: Pre-computed character pattern to trim
244/// - `views_buf`: The buffer to append the updated views to
245/// - `null_builder`: The buffer to append the null values to
246/// - `original_view`: The original view value (that contains src_str_opt)
247#[inline]
248fn trim_and_append_view<Tr: Trimmer>(
249    src_str_opt: Option<&str>,
250    pattern: &[char],
251    views_buf: &mut Vec<u128>,
252    null_builder: &mut NullBufferBuilder,
253    original_view: &u128,
254) {
255    if let Some(src_str) = src_str_opt {
256        let (trimmed, offset) = Tr::trim(src_str, pattern);
257        append_view(views_buf, original_view, trimmed, offset);
258        null_builder.append_non_null();
259    } else {
260        null_builder.append_null();
261        views_buf.push(0);
262    }
263}
264
265/// Applies the trim function to the given string array(s)
266/// and returns a new string array with the trimmed values.
267///
268/// Pre-computes the pattern characters once for scalar patterns to avoid
269/// repeated allocations per row.
270fn string_trim<T: OffsetSizeTrait, Tr: Trimmer>(args: &[ArrayRef]) -> Result<ArrayRef> {
271    let string_array = as_generic_string_array::<T>(&args[0])?;
272
273    match args.len() {
274        1 => {
275            // Trim spaces by default
276            let result = string_array
277                .iter()
278                .map(|string| string.map(|s| Tr::trim_ascii_char(s, b' ').0))
279                .collect::<GenericStringArray<T>>();
280
281            Ok(Arc::new(result) as ArrayRef)
282        }
283        2 => {
284            let characters_array = as_generic_string_array::<T>(&args[1])?;
285
286            if characters_array.len() == 1 {
287                // Scalar pattern - pre-compute pattern chars once
288                if characters_array.is_null(0) {
289                    return Ok(new_null_array(
290                        string_array.data_type(),
291                        string_array.len(),
292                    ));
293                }
294
295                let pattern: Vec<char> = characters_array.value(0).chars().collect();
296                let result = string_array
297                    .iter()
298                    .map(|item| item.map(|s| Tr::trim(s, &pattern).0))
299                    .collect::<GenericStringArray<T>>();
300                return Ok(Arc::new(result) as ArrayRef);
301            }
302
303            // Per-row pattern - must compute pattern chars for each row
304            let mut pattern: Vec<char> = Vec::new();
305            let result = string_array
306                .iter()
307                .zip(characters_array.iter())
308                .map(|(string, characters)| match (string, characters) {
309                    (Some(s), Some(c)) => {
310                        pattern.clear();
311                        pattern.extend(c.chars());
312                        Some(Tr::trim(s, &pattern).0)
313                    }
314                    _ => None,
315                })
316                .collect::<GenericStringArray<T>>();
317
318            Ok(Arc::new(result) as ArrayRef)
319        }
320        other => {
321            exec_err!(
322                "Function TRIM was called with {other} arguments. It requires at least 1 and at most 2."
323            )
324        }
325    }
326}
327
328pub(crate) fn to_lower(args: &[ColumnarValue], name: &str) -> Result<ColumnarValue> {
329    case_conversion(args, true, name)
330}
331
332pub(crate) fn to_upper(args: &[ColumnarValue], name: &str) -> Result<ColumnarValue> {
333    case_conversion(args, false, name)
334}
335
/// Returns a Unicode-aware lowercase (`lower == true`) or uppercase
/// (`lower == false`) copy of `s`.
#[inline]
fn unicode_case(s: &str, lower: bool) -> String {
    match lower {
        true => s.to_lowercase(),
        false => s.to_uppercase(),
    }
}
344
345fn case_conversion(
346    args: &[ColumnarValue],
347    lower: bool,
348    name: &str,
349) -> Result<ColumnarValue> {
350    match &args[0] {
351        ColumnarValue::Array(array) => match array.data_type() {
352            DataType::Utf8 => Ok(ColumnarValue::Array(case_conversion_array::<i32>(
353                array, lower,
354            )?)),
355            DataType::LargeUtf8 => Ok(ColumnarValue::Array(
356                case_conversion_array::<i64>(array, lower)?,
357            )),
358            DataType::Utf8View => {
359                let string_array = as_string_view_array(array)?;
360                if string_array.is_ascii() {
361                    return Ok(ColumnarValue::Array(Arc::new(
362                        case_conversion_utf8view_ascii(string_array, lower),
363                    )));
364                }
365                let item_len = string_array.len();
366                // Null-preserving: reuse the input null buffer as the output null buffer.
367                let nulls = string_array.nulls().cloned();
368                let mut builder = StringViewArrayBuilder::with_capacity(item_len);
369
370                if let Some(ref n) = nulls {
371                    for i in 0..item_len {
372                        if n.is_null(i) {
373                            builder.append_placeholder();
374                        } else {
375                            // SAFETY: `n.is_null(i)` was false in the branch above.
376                            let s = unsafe { string_array.value_unchecked(i) };
377                            builder.append_value(&unicode_case(s, lower));
378                        }
379                    }
380                } else {
381                    for i in 0..item_len {
382                        // SAFETY: no null buffer means every index is valid.
383                        let s = unsafe { string_array.value_unchecked(i) };
384                        builder.append_value(&unicode_case(s, lower));
385                    }
386                }
387
388                Ok(ColumnarValue::Array(Arc::new(builder.finish(nulls)?)))
389            }
390            other => exec_err!("Unsupported data type {other:?} for function {name}"),
391        },
392        ColumnarValue::Scalar(scalar) => match scalar {
393            ScalarValue::Utf8(a) => {
394                let result = a.as_ref().map(|x| unicode_case(x, lower));
395                Ok(ColumnarValue::Scalar(ScalarValue::Utf8(result)))
396            }
397            ScalarValue::LargeUtf8(a) => {
398                let result = a.as_ref().map(|x| unicode_case(x, lower));
399                Ok(ColumnarValue::Scalar(ScalarValue::LargeUtf8(result)))
400            }
401            ScalarValue::Utf8View(a) => {
402                let result = a.as_ref().map(|x| unicode_case(x, lower));
403                Ok(ColumnarValue::Scalar(ScalarValue::Utf8View(result)))
404            }
405            other => exec_err!("Unsupported data type {other:?} for function {name}"),
406        },
407    }
408}
409
410fn case_conversion_array<O: OffsetSizeTrait>(
411    array: &ArrayRef,
412    lower: bool,
413) -> Result<ArrayRef> {
414    const PRE_ALLOC_BYTES: usize = 8;
415
416    let string_array = as_generic_string_array::<O>(array)?;
417    if string_array.is_ascii() {
418        return case_conversion_ascii_array::<O>(string_array, lower);
419    }
420
421    // Values contain non-ASCII.
422    let item_len = string_array.len();
423    let offsets = string_array.value_offsets();
424    let start = offsets.first().unwrap().as_usize();
425    let end = offsets.last().unwrap().as_usize();
426    let capacity = (end - start) + PRE_ALLOC_BYTES;
427    // Null-preserving: reuse the input null buffer as the output null buffer.
428    let nulls = string_array.nulls().cloned();
429    let mut builder = GenericStringArrayBuilder::<O>::with_capacity(item_len, capacity);
430
431    if let Some(ref n) = nulls {
432        for i in 0..item_len {
433            if n.is_null(i) {
434                builder.append_placeholder();
435            } else {
436                // SAFETY: `n.is_null(i)` was false in the branch above.
437                let s = unsafe { string_array.value_unchecked(i) };
438                builder.append_value(&unicode_case(s, lower));
439            }
440        }
441    } else {
442        for i in 0..item_len {
443            // SAFETY: no null buffer means every index is valid.
444            let s = unsafe { string_array.value_unchecked(i) };
445            builder.append_value(&unicode_case(s, lower));
446        }
447    }
448    Ok(Arc::new(builder.finish(nulls)?))
449}
450
451/// Fast path for case conversion on an all-ASCII `StringViewArray`.
452fn case_conversion_utf8view_ascii(
453    array: &StringViewArray,
454    lower: bool,
455) -> StringViewArray {
456    // Specialize per conversion so the byte call inlines in the hot loops below.
457    if lower {
458        case_conversion_utf8view_ascii_inner(array, u8::to_ascii_lowercase)
459    } else {
460        case_conversion_utf8view_ascii_inner(array, u8::to_ascii_uppercase)
461    }
462}
463
464/// Walks the views once and produces a new `StringViewArray` with
465/// case-converted bytes. Inline strings (<= 12 bytes) are converted in-place;
466/// long strings copy-and-convert into output buffers and have their view fields
467/// rewritten to address the new bytes. ASCII case conversion preserves is byte
468/// length, so no row migrates between the inline and long layouts.
469fn case_conversion_utf8view_ascii_inner<F: Fn(&u8) -> u8>(
470    array: &StringViewArray,
471    convert: F,
472) -> StringViewArray {
473    let item_len = array.len();
474    let views = array.views();
475    let data_buffers = array.data_buffers();
476    let nulls = array.nulls();
477
478    let mut new_views: Vec<u128> = Vec::with_capacity(item_len);
479    // Long values are packed into `in_progress`; when full it is sealed into
480    // `completed` and a new, larger block is started — same block-doubling
481    // scheme as Arrow's `GenericByteViewBuilder`.
482    let mut in_progress: Vec<u8> = Vec::new();
483    let mut completed: Vec<Buffer> = Vec::new();
484    let mut block_size: u32 = STRING_VIEW_INIT_BLOCK_SIZE;
485
486    for i in 0..item_len {
487        if nulls.is_some_and(|n| n.is_null(i)) {
488            // Zero view = empty, no buffer reference; the null buffer is what
489            // marks the row null, so the view's value is irrelevant.
490            new_views.push(0);
491            continue;
492        }
493        let view = views[i];
494        // Length is the low 32 bits; `as u32` discards the rest of the view.
495        let len = view as u32 as usize;
496        if len == 0 {
497            new_views.push(0);
498            continue;
499        }
500        let mut bytes = view.to_le_bytes();
501        if len <= 12 {
502            // Inline: value is in bytes[4..4+len], no buffer reference. Convert
503            // in place; nothing else in the view needs to change.
504            for b in &mut bytes[4..4 + len] {
505                *b = convert(b);
506            }
507            new_views.push(u128::from_le_bytes(bytes));
508        } else {
509            // Long: input view points into shared `data_buffers` we can't
510            // mutate, so copy-convert into our own buffer and rewrite the
511            // view's prefix/buffer_index/offset (length is preserved).
512
513            // Ensure the current block has room; otherwise flush and grow.
514            let required_cap = in_progress.len() + len;
515            if in_progress.capacity() < required_cap {
516                if !in_progress.is_empty() {
517                    completed.push(Buffer::from_vec(std::mem::take(&mut in_progress)));
518                }
519                if block_size < STRING_VIEW_MAX_BLOCK_SIZE {
520                    block_size = block_size.saturating_mul(2);
521                }
522                let to_reserve = len.max(block_size as usize);
523                in_progress.reserve(to_reserve);
524            }
525
526            // The in-progress block will be sealed at index `completed.len()`,
527            // and our value starts at the current write position within it.
528            let buffer_index: u32 = i32::try_from(completed.len())
529                .expect("buffer count exceeds i32::MAX")
530                as u32;
531            let new_offset: u32 =
532                i32::try_from(in_progress.len()).expect("offset exceeds i32::MAX") as u32;
533
534            // Source location from the input view: bytes 8..12 are buffer
535            // index, bytes 12..16 are the offset within it.
536            let src_buffer_index =
537                u32::from_le_bytes(bytes[8..12].try_into().unwrap()) as usize;
538            let src_offset =
539                u32::from_le_bytes(bytes[12..16].try_into().unwrap()) as usize;
540            let src =
541                &data_buffers[src_buffer_index].as_slice()[src_offset..src_offset + len];
542
543            let prefix_start = in_progress.len();
544            in_progress.extend(src.iter().map(&convert));
545
546            // Rewrite the three long-view fields; bytes[0..4] (length) is
547            // left untouched. The prefix is read back from the bytes we just
548            // wrote so the converted value has a single source of truth.
549            let prefix: [u8; 4] = in_progress[prefix_start..prefix_start + 4]
550                .try_into()
551                .unwrap();
552            bytes[4..8].copy_from_slice(&prefix);
553            bytes[8..12].copy_from_slice(&buffer_index.to_le_bytes());
554            bytes[12..16].copy_from_slice(&new_offset.to_le_bytes());
555            new_views.push(u128::from_le_bytes(bytes));
556        }
557    }
558
559    if !in_progress.is_empty() {
560        completed.push(Buffer::from_vec(in_progress));
561    }
562
563    // SAFETY: each long view's buffer_index addresses a buffer we wrote, and
564    // its offset addresses bytes within that buffer; prefixes were copied from
565    // those same bytes; inline views were rewritten from valid inline bytes;
566    // null/empty rows are zero views with no buffer reference; row count is
567    // unchanged.
568    unsafe {
569        StringViewArray::new_unchecked(
570            ScalarBuffer::from(new_views),
571            completed,
572            array.nulls().cloned(),
573        )
574    }
575}
576
577/// Fast path for case conversion on an all-ASCII string array. ASCII case
578/// conversion is byte-length-preserving, so we can convert the entire addressed
579/// byte range in one pass over the value buffer and reuse the offsets and nulls
580/// buffers — rebasing the offsets when the input is a sliced array.
581fn case_conversion_ascii_array<O: OffsetSizeTrait>(
582    string_array: &GenericStringArray<O>,
583    lower: bool,
584) -> Result<ArrayRef> {
585    let value_offsets = string_array.value_offsets();
586    let start = value_offsets.first().unwrap().as_usize();
587    let end = value_offsets.last().unwrap().as_usize();
588    let relevant = &string_array.value_data()[start..end];
589
590    let converted: Vec<u8> = if lower {
591        relevant.iter().map(u8::to_ascii_lowercase).collect()
592    } else {
593        relevant.iter().map(u8::to_ascii_uppercase).collect()
594    };
595    let values = Buffer::from_vec(converted);
596
597    // Shift offsets from `start`-based to 0-based so they index into `values`.
598    let offsets = if start == 0 {
599        string_array.offsets().clone()
600    } else {
601        let s = O::usize_as(start);
602        let rebased: Vec<O> = value_offsets.iter().map(|&o| o - s).collect();
603        // SAFETY: subtracting a constant from monotonic offsets preserves
604        // monotonicity, and `start` is the minimum offset, so no underflow.
605        unsafe { OffsetBuffer::new_unchecked(ScalarBuffer::from(rebased)) }
606    };
607
608    let nulls = string_array.nulls().cloned();
609    // SAFETY: offsets are monotonic and in-bounds for `values`; nulls
610    // (if any) match the slice length.
611    Ok(Arc::new(unsafe {
612        GenericStringArray::<O>::new_unchecked(offsets, values, nulls)
613    }))
614}