// datafusion_spark/function/string/substring.rs
1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use arrow::array::{
19    Array, ArrayBuilder, ArrayRef, AsArray, GenericStringBuilder, Int64Array,
20    OffsetSizeTrait, StringArrayType, StringViewBuilder,
21};
22use arrow::datatypes::DataType;
23use datafusion_common::arrow::datatypes::{Field, FieldRef};
24use datafusion_common::cast::as_int64_array;
25use datafusion_common::types::{
26    NativeType, logical_int32, logical_int64, logical_string,
27};
28use datafusion_common::{Result, exec_err};
29use datafusion_expr::{Coercion, ReturnFieldArgs, TypeSignatureClass};
30use datafusion_expr::{
31    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature,
32    Volatility,
33};
34use datafusion_functions::unicode::substr::{enable_ascii_fast_path, get_true_start_end};
35use datafusion_functions::utils::make_scalar_function;
36use std::any::Any;
37use std::sync::Arc;
38
39/// Spark-compatible `substring` expression
40/// <https://spark.apache.org/docs/latest/api/sql/index.html#substring>
41///
42/// Returns the substring from string starting at position pos with length len.
43/// Position is 1-indexed. If pos is negative, it counts from the end of the string.
44/// Returns NULL if any input is NULL.
45#[derive(Debug, PartialEq, Eq, Hash)]
46pub struct SparkSubstring {
47    signature: Signature,
48    aliases: Vec<String>,
49}
50
51impl Default for SparkSubstring {
52    fn default() -> Self {
53        Self::new()
54    }
55}
56
57impl SparkSubstring {
58    pub fn new() -> Self {
59        let string = Coercion::new_exact(TypeSignatureClass::Native(logical_string()));
60        let int64 = Coercion::new_implicit(
61            TypeSignatureClass::Native(logical_int64()),
62            vec![TypeSignatureClass::Native(logical_int32())],
63            NativeType::Int64,
64        );
65        Self {
66            signature: Signature::one_of(
67                vec![
68                    TypeSignature::Coercible(vec![string.clone(), int64.clone()]),
69                    TypeSignature::Coercible(vec![
70                        string.clone(),
71                        int64.clone(),
72                        int64.clone(),
73                    ]),
74                ],
75                Volatility::Immutable,
76            )
77            .with_parameter_names(vec![
78                "str".to_string(),
79                "pos".to_string(),
80                "length".to_string(),
81            ])
82            .expect("valid parameter names"),
83            aliases: vec![String::from("substr")],
84        }
85    }
86}
87
88impl ScalarUDFImpl for SparkSubstring {
89    fn as_any(&self) -> &dyn Any {
90        self
91    }
92
93    fn name(&self) -> &str {
94        "substring"
95    }
96
97    fn signature(&self) -> &Signature {
98        &self.signature
99    }
100
101    fn aliases(&self) -> &[String] {
102        &self.aliases
103    }
104
105    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
106        make_scalar_function(spark_substring, vec![])(&args.args)
107    }
108
109    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
110        datafusion_common::internal_err!(
111            "return_type should not be called for Spark substring"
112        )
113    }
114
115    fn return_field_from_args(&self, args: ReturnFieldArgs<'_>) -> Result<FieldRef> {
116        // Spark semantics: substring returns NULL if ANY input is NULL
117        let nullable = args.arg_fields.iter().any(|f| f.is_nullable());
118
119        Ok(Arc::new(Field::new(
120            "substring",
121            args.arg_fields[0].data_type().clone(),
122            nullable,
123        )))
124    }
125}
126
127fn spark_substring(args: &[ArrayRef]) -> Result<ArrayRef> {
128    let start_array = as_int64_array(&args[1])?;
129    let length_array = if args.len() > 2 {
130        Some(as_int64_array(&args[2])?)
131    } else {
132        None
133    };
134
135    match args[0].data_type() {
136        DataType::Utf8 => spark_substring_impl(
137            &args[0].as_string::<i32>(),
138            start_array,
139            length_array,
140            GenericStringBuilder::<i32>::new(),
141        ),
142        DataType::LargeUtf8 => spark_substring_impl(
143            &args[0].as_string::<i64>(),
144            start_array,
145            length_array,
146            GenericStringBuilder::<i64>::new(),
147        ),
148        DataType::Utf8View => spark_substring_impl(
149            &args[0].as_string_view(),
150            start_array,
151            length_array,
152            StringViewBuilder::new(),
153        ),
154        other => exec_err!(
155            "Unsupported data type {other:?} for function spark_substring, expected Utf8View, Utf8 or LargeUtf8."
156        ),
157    }
158}
159
/// Translate a Spark start position into DataFusion's 1-based start position.
///
/// Spark semantics:
/// - `start > 0`: 1-based index from the beginning (passed through)
/// - `start == 0`: treated as position 1
/// - `start < 0`: offset from the end of the string (`len + start + 1`)
///
/// The result is clamped to at least 1, ready for `get_true_start_end`.
#[inline]
fn spark_start_to_datafusion_start(start: i64, len: usize) -> i64 {
    let resolved = if start < 0 {
        // Anchor the negative offset at the end of the string; saturate
        // rather than overflow for extreme inputs.
        i64::try_from(len)
            .unwrap_or(i64::MAX)
            .saturating_add(start)
            .saturating_add(1)
    } else {
        start
    };
    resolved.max(1)
}
178
179trait StringArrayBuilder: ArrayBuilder {
180    fn append_value(&mut self, val: &str);
181    fn append_null(&mut self);
182}
183
184impl<O: OffsetSizeTrait> StringArrayBuilder for GenericStringBuilder<O> {
185    fn append_value(&mut self, val: &str) {
186        GenericStringBuilder::append_value(self, val);
187    }
188    fn append_null(&mut self) {
189        GenericStringBuilder::append_null(self);
190    }
191}
192
193impl StringArrayBuilder for StringViewBuilder {
194    fn append_value(&mut self, val: &str) {
195        StringViewBuilder::append_value(self, val);
196    }
197    fn append_null(&mut self) {
198        StringViewBuilder::append_null(self);
199    }
200}
201
202fn spark_substring_impl<'a, V, B>(
203    string_array: &V,
204    start_array: &Int64Array,
205    length_array: Option<&Int64Array>,
206    mut builder: B,
207) -> Result<ArrayRef>
208where
209    V: StringArrayType<'a>,
210    B: StringArrayBuilder,
211{
212    let is_ascii = enable_ascii_fast_path(string_array, start_array, length_array);
213
214    for i in 0..string_array.len() {
215        if string_array.is_null(i) || start_array.is_null(i) {
216            builder.append_null();
217            continue;
218        }
219
220        if let Some(len_arr) = length_array
221            && len_arr.is_null(i)
222        {
223            builder.append_null();
224            continue;
225        }
226
227        let string = string_array.value(i);
228        let start = start_array.value(i);
229        let len_opt = length_array.map(|arr| arr.value(i));
230
231        // Spark: negative length returns empty string
232        if let Some(len) = len_opt
233            && len < 0
234        {
235            builder.append_value("");
236            continue;
237        }
238
239        let string_len = if is_ascii {
240            string.len()
241        } else {
242            string.chars().count()
243        };
244
245        let adjusted_start = spark_start_to_datafusion_start(start, string_len);
246
247        let (byte_start, byte_end) = get_true_start_end(
248            string,
249            adjusted_start,
250            len_opt.map(|l| l as u64),
251            is_ascii,
252        );
253        let substr = &string[byte_start..byte_end];
254        builder.append_value(substr);
255    }
256
257    Ok(builder.finish())
258}