Skip to main content

datafusion_spark/function/string/
substring.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use arrow::array::{
19    Array, ArrayAccessor, ArrayBuilder, ArrayRef, AsArray, BinaryViewBuilder,
20    GenericBinaryBuilder, GenericStringBuilder, Int64Array, OffsetSizeTrait,
21    StringViewBuilder,
22};
23use arrow::datatypes::DataType;
24use datafusion_common::arrow::datatypes::{Field, FieldRef};
25use datafusion_common::cast::as_int64_array;
26use datafusion_common::types::{
27    NativeType, logical_int32, logical_int64, logical_string,
28};
29use datafusion_common::{Result, exec_err};
30use datafusion_expr::{Coercion, ReturnFieldArgs, TypeSignatureClass};
31use datafusion_expr::{
32    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature,
33    Volatility,
34};
35use datafusion_functions::unicode::substr::{enable_ascii_fast_path, get_true_start_end};
36use datafusion_functions::utils::make_scalar_function;
37use std::sync::Arc;
38
39/// Spark-compatible `substring` expression
40/// <https://spark.apache.org/docs/latest/api/sql/index.html#substring>
41///
42/// Returns the substring from string starting at position pos with length len.
43/// Position is 1-indexed. If pos is negative, it counts from the end of the string.
44/// Returns NULL if any input is NULL.
45#[derive(Debug, PartialEq, Eq, Hash)]
46pub struct SparkSubstring {
47    signature: Signature,
48    aliases: Vec<String>,
49}
50
51impl Default for SparkSubstring {
52    fn default() -> Self {
53        Self::new()
54    }
55}
56
57impl SparkSubstring {
58    pub fn new() -> Self {
59        let string = Coercion::new_exact(TypeSignatureClass::Native(logical_string()));
60        let binary = Coercion::new_exact(TypeSignatureClass::Binary);
61        let int64 = Coercion::new_implicit(
62            TypeSignatureClass::Native(logical_int64()),
63            vec![TypeSignatureClass::Native(logical_int32())],
64            NativeType::Int64,
65        );
66        Self {
67            signature: Signature::one_of(
68                vec![
69                    TypeSignature::Coercible(vec![string.clone(), int64.clone()]),
70                    TypeSignature::Coercible(vec![
71                        string.clone(),
72                        int64.clone(),
73                        int64.clone(),
74                    ]),
75                    TypeSignature::Coercible(vec![binary.clone(), int64.clone()]),
76                    TypeSignature::Coercible(vec![
77                        binary.clone(),
78                        int64.clone(),
79                        int64.clone(),
80                    ]),
81                ],
82                Volatility::Immutable,
83            )
84            .with_parameter_names(vec![
85                "str".to_string(),
86                "pos".to_string(),
87                "length".to_string(),
88            ])
89            .expect("valid parameter names"),
90            aliases: vec![String::from("substr")],
91        }
92    }
93}
94
95impl ScalarUDFImpl for SparkSubstring {
96    fn name(&self) -> &str {
97        "substring"
98    }
99
100    fn signature(&self) -> &Signature {
101        &self.signature
102    }
103
104    fn aliases(&self) -> &[String] {
105        &self.aliases
106    }
107
108    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
109        make_scalar_function(spark_substring, vec![])(&args.args)
110    }
111
112    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
113        datafusion_common::internal_err!(
114            "return_type should not be called for Spark substring"
115        )
116    }
117
118    fn return_field_from_args(&self, args: ReturnFieldArgs<'_>) -> Result<FieldRef> {
119        // Spark semantics: substring returns NULL if ANY input is NULL
120        let nullable = args.arg_fields.iter().any(|f| f.is_nullable());
121
122        Ok(Arc::new(Field::new(
123            "substring",
124            args.arg_fields[0].data_type().clone(),
125            nullable,
126        )))
127    }
128}
129
130fn spark_substring(args: &[ArrayRef]) -> Result<ArrayRef> {
131    let start_array = as_int64_array(&args[1])?;
132    let length_array = if args.len() > 2 {
133        Some(as_int64_array(&args[2])?)
134    } else {
135        None
136    };
137
138    match args[0].data_type() {
139        DataType::Utf8 => {
140            let array = args[0].as_string::<i32>();
141            let is_ascii = enable_ascii_fast_path(&array, start_array, length_array);
142            spark_substring_generic(
143                &array,
144                start_array,
145                length_array,
146                GenericStringBuilder::<i32>::new(),
147                is_ascii,
148            )
149        }
150        DataType::LargeUtf8 => {
151            let array = args[0].as_string::<i64>();
152            let is_ascii = enable_ascii_fast_path(&array, start_array, length_array);
153            spark_substring_generic(
154                &array,
155                start_array,
156                length_array,
157                GenericStringBuilder::<i64>::new(),
158                is_ascii,
159            )
160        }
161        DataType::Utf8View => {
162            let array = args[0].as_string_view();
163            let is_ascii = enable_ascii_fast_path(&array, start_array, length_array);
164            spark_substring_generic(
165                &array,
166                start_array,
167                length_array,
168                StringViewBuilder::new(),
169                is_ascii,
170            )
171        }
172        // Binary paths always use byte-level indexing, so `is_ascii` is irrelevant
173        // and set to `true` (its value is ignored by the `[u8]` impl of
174        // `SubstringItem`).
175        DataType::Binary => spark_substring_generic(
176            &args[0].as_binary::<i32>(),
177            start_array,
178            length_array,
179            GenericBinaryBuilder::<i32>::new(),
180            true,
181        ),
182        DataType::LargeBinary => spark_substring_generic(
183            &args[0].as_binary::<i64>(),
184            start_array,
185            length_array,
186            GenericBinaryBuilder::<i64>::new(),
187            true,
188        ),
189        DataType::BinaryView => spark_substring_generic(
190            &args[0].as_binary_view(),
191            start_array,
192            length_array,
193            BinaryViewBuilder::new(),
194            true,
195        ),
196        other => exec_err!(
197            "Unsupported data type {other:?} for function spark_substring, expected Utf8View, Utf8, LargeUtf8, Binary, LargeBinary or BinaryView."
198        ),
199    }
200}
201
/// Translates a Spark `substring` position into a 1-based start position.
///
/// * `start > 0` is returned unchanged (already 1-based from the beginning).
/// * `start == 0` is treated as `1`.
/// * `start < 0` counts from the end of a value of `len` positions and maps
///   to `start + len + 1` (e.g. `-1` on a 5-position value gives `5`).
///
/// A negative start that reaches before the beginning of the value produces
/// a result `<= 0` (e.g. `start = -10` with `len = 3` gives `-6`). That value
/// is returned as-is; clamping it to the beginning of the value is left to
/// the code that turns it into a byte range.
///
/// `len` saturates to `i64::MAX` if it does not fit into an `i64`.
#[inline]
fn spark_start_to_datafusion_start(start: i64, len: usize) -> i64 {
    match start {
        i64::MIN..=-1 => {
            let total = i64::try_from(len).unwrap_or(i64::MAX);
            start + total + 1
        }
        0 => 1,
        _ => start,
    }
}
222
223trait SubstringItem {
224    /// Length used for Spark's negative-position adjustment.
225    /// For `str` this is characters (or bytes in ASCII mode); for `[u8]` it is
226    /// always byte count.
227    fn positional_len(&self, is_ascii: bool) -> usize;
228
229    /// Converts Spark's 1-indexed adjusted start + optional length into a
230    /// byte range clamped to `[0, byte_len]`.
231    fn byte_range(
232        &self,
233        adjusted_start: i64,
234        len: Option<i64>,
235        is_ascii: bool,
236    ) -> Result<(usize, usize)>;
237
238    fn byte_slice(&self, start: usize, end: usize) -> &Self;
239}
240
241impl SubstringItem for str {
242    fn positional_len(&self, is_ascii: bool) -> usize {
243        if is_ascii {
244            self.len()
245        } else {
246            self.chars().count()
247        }
248    }
249
250    fn byte_range(
251        &self,
252        adjusted_start: i64,
253        len: Option<i64>,
254        is_ascii: bool,
255    ) -> Result<(usize, usize)> {
256        get_true_start_end(self, adjusted_start, len, is_ascii)
257    }
258
259    fn byte_slice(&self, start: usize, end: usize) -> &Self {
260        &self[start..end]
261    }
262}
263
264impl SubstringItem for [u8] {
265    fn positional_len(&self, _is_ascii: bool) -> usize {
266        self.len()
267    }
268
269    fn byte_range(
270        &self,
271        adjusted_start: i64,
272        len: Option<i64>,
273        _is_ascii: bool,
274    ) -> Result<(usize, usize)> {
275        let byte_len = self.len();
276        let start0 = adjusted_start.saturating_sub(1);
277        let end0 = match len {
278            Some(l) => start0.saturating_add(l),
279            None => byte_len as i64,
280        };
281        let byte_len_i64 = byte_len as i64;
282        Ok((
283            start0.clamp(0, byte_len_i64) as usize,
284            end0.clamp(0, byte_len_i64) as usize,
285        ))
286    }
287
288    fn byte_slice(&self, start: usize, end: usize) -> &Self {
289        &self[start..end]
290    }
291}
292
293trait SubstringBuilder: ArrayBuilder {
294    type Item: SubstringItem + ?Sized;
295    fn append_value(&mut self, val: &Self::Item);
296    fn append_null(&mut self);
297    /// Spark's semantic "empty" for this builder's item type, used for the
298    /// negative-length short-circuit.
299    fn append_empty(&mut self);
300}
301
302impl<O: OffsetSizeTrait> SubstringBuilder for GenericStringBuilder<O> {
303    type Item = str;
304    fn append_value(&mut self, val: &str) {
305        GenericStringBuilder::append_value(self, val);
306    }
307    fn append_null(&mut self) {
308        GenericStringBuilder::append_null(self);
309    }
310    fn append_empty(&mut self) {
311        GenericStringBuilder::append_value(self, "");
312    }
313}
314
315impl SubstringBuilder for StringViewBuilder {
316    type Item = str;
317    fn append_value(&mut self, val: &str) {
318        StringViewBuilder::append_value(self, val);
319    }
320    fn append_null(&mut self) {
321        StringViewBuilder::append_null(self);
322    }
323    fn append_empty(&mut self) {
324        StringViewBuilder::append_value(self, "");
325    }
326}
327
328impl<O: OffsetSizeTrait> SubstringBuilder for GenericBinaryBuilder<O> {
329    type Item = [u8];
330    fn append_value(&mut self, val: &[u8]) {
331        GenericBinaryBuilder::append_value(self, val);
332    }
333    fn append_null(&mut self) {
334        GenericBinaryBuilder::append_null(self);
335    }
336    fn append_empty(&mut self) {
337        GenericBinaryBuilder::append_value(self, &[]);
338    }
339}
340
341impl SubstringBuilder for BinaryViewBuilder {
342    type Item = [u8];
343    fn append_value(&mut self, val: &[u8]) {
344        BinaryViewBuilder::append_value(self, val);
345    }
346    fn append_null(&mut self) {
347        BinaryViewBuilder::append_null(self);
348    }
349    fn append_empty(&mut self) {
350        BinaryViewBuilder::append_value(self, []);
351    }
352}
353
354/// Unified implementation of Spark's `substring`, generic over the source
355/// array (`StringArrayType`/`BinaryArrayType` via `ArrayAccessor`) and its
356/// corresponding builder. Per-row indexing semantics are delegated to
357/// [`SubstringItem`], which differs between `str` (char-aware when
358/// `is_ascii` is false) and `[u8]` (always byte-level).
359fn spark_substring_generic<'a, Source, Item, Builder>(
360    array: &Source,
361    start_array: &Int64Array,
362    length_array: Option<&Int64Array>,
363    mut builder: Builder,
364    is_ascii: bool,
365) -> Result<ArrayRef>
366where
367    Source: ArrayAccessor<Item = &'a Item>,
368    Item: SubstringItem + ?Sized + 'a,
369    Builder: SubstringBuilder<Item = Item>,
370{
371    for i in 0..array.len() {
372        if array.is_null(i) || start_array.is_null(i) {
373            builder.append_null();
374            continue;
375        }
376
377        if let Some(len_arr) = length_array
378            && len_arr.is_null(i)
379        {
380            builder.append_null();
381            continue;
382        }
383
384        let value = array.value(i);
385        let start = start_array.value(i);
386        let len_opt = length_array.map(|arr| arr.value(i));
387
388        // Spark: negative length yields an empty value
389        if let Some(len) = len_opt
390            && len < 0
391        {
392            builder.append_empty();
393            continue;
394        }
395
396        let positional_len = value.positional_len(is_ascii);
397        let adjusted_start = spark_start_to_datafusion_start(start, positional_len);
398        let (byte_start, byte_end) =
399            value.byte_range(adjusted_start, len_opt, is_ascii)?;
400        builder.append_value(value.byte_slice(byte_start, byte_end));
401    }
402
403    Ok(builder.finish())
404}