Skip to main content

datafusion_functions/unicode/
substrindex.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::any::Any;
19use std::sync::Arc;
20
21use arrow::array::{
22    ArrayAccessor, ArrayIter, ArrayRef, ArrowPrimitiveType, AsArray,
23    GenericStringBuilder, OffsetSizeTrait, PrimitiveArray,
24};
25use arrow::datatypes::{DataType, Int32Type, Int64Type};
26
27use crate::utils::{make_scalar_function, utf8_to_str_type};
28use datafusion_common::{Result, exec_err, utils::take_function_args};
29use datafusion_expr::TypeSignature::Exact;
30use datafusion_expr::{
31    ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
32};
33use datafusion_macros::user_doc;
34
/// Scalar UDF `substr_index` (also reachable through the alias
/// `substring_index`): returns the part of a string that lies before (positive
/// count) or after (negative count) the `count`-th occurrence of a delimiter.
///
/// The user-facing SQL documentation is supplied by the `user_doc` attribute
/// below and exposed through `ScalarUDFImpl::documentation`.
#[user_doc(
    doc_section(label = "String Functions"),
    description = r#"Returns the substring from str before count occurrences of the delimiter delim.
If count is positive, everything to the left of the final delimiter (counting from the left) is returned.
If count is negative, everything to the right of the final delimiter (counting from the right) is returned."#,
    syntax_example = "substr_index(str, delim, count)",
    sql_example = r#"```sql
> select substr_index('www.apache.org', '.', 1);
+---------------------------------------------------------+
| substr_index(Utf8("www.apache.org"),Utf8("."),Int64(1)) |
+---------------------------------------------------------+
| www                                                     |
+---------------------------------------------------------+
> select substr_index('www.apache.org', '.', -1);
+----------------------------------------------------------+
| substr_index(Utf8("www.apache.org"),Utf8("."),Int64(-1)) |
+----------------------------------------------------------+
| org                                                      |
+----------------------------------------------------------+
```"#,
    standard_argument(name = "str", prefix = "String"),
    argument(
        name = "delim",
        description = "The string to find in str to split str."
    ),
    argument(
        name = "count",
        description = "The number of times to search for the delimiter. Can be either a positive or negative number."
    )
)]
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct SubstrIndexFunc {
    // Accepted argument type combinations (built in `SubstrIndexFunc::new`).
    signature: Signature,
    // Alternative names the function is registered under (`substring_index`).
    aliases: Vec<String>,
}
70
71impl Default for SubstrIndexFunc {
72    fn default() -> Self {
73        Self::new()
74    }
75}
76
77impl SubstrIndexFunc {
78    pub fn new() -> Self {
79        use DataType::*;
80        Self {
81            signature: Signature::one_of(
82                vec![
83                    Exact(vec![Utf8View, Utf8View, Int64]),
84                    Exact(vec![Utf8, Utf8, Int64]),
85                    Exact(vec![LargeUtf8, LargeUtf8, Int64]),
86                ],
87                Volatility::Immutable,
88            ),
89            aliases: vec![String::from("substring_index")],
90        }
91    }
92}
93
94impl ScalarUDFImpl for SubstrIndexFunc {
95    fn as_any(&self) -> &dyn Any {
96        self
97    }
98
99    fn name(&self) -> &str {
100        "substr_index"
101    }
102
103    fn signature(&self) -> &Signature {
104        &self.signature
105    }
106
107    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
108        utf8_to_str_type(&arg_types[0], "substr_index")
109    }
110
111    fn invoke_with_args(
112        &self,
113        args: datafusion_expr::ScalarFunctionArgs,
114    ) -> Result<ColumnarValue> {
115        make_scalar_function(substr_index, vec![])(&args.args)
116    }
117
118    fn aliases(&self) -> &[String] {
119        &self.aliases
120    }
121
122    fn documentation(&self) -> Option<&Documentation> {
123        self.doc()
124    }
125}
126
127/// Returns the substring from str before count occurrences of the delimiter delim. If count is positive, everything to the left of the final delimiter (counting from the left) is returned. If count is negative, everything to the right of the final delimiter (counting from the right) is returned.
128/// SUBSTRING_INDEX('www.apache.org', '.', 1) = www
129/// SUBSTRING_INDEX('www.apache.org', '.', 2) = www.apache
130/// SUBSTRING_INDEX('www.apache.org', '.', -2) = apache.org
131/// SUBSTRING_INDEX('www.apache.org', '.', -1) = org
132fn substr_index(args: &[ArrayRef]) -> Result<ArrayRef> {
133    let [str, delim, count] = take_function_args("substr_index", args)?;
134
135    match str.data_type() {
136        DataType::Utf8 => {
137            let string_array = str.as_string::<i32>();
138            let delimiter_array = delim.as_string::<i32>();
139            let count_array: &PrimitiveArray<Int64Type> = count.as_primitive();
140            substr_index_general::<Int32Type, _, _>(
141                string_array,
142                delimiter_array,
143                count_array,
144            )
145        }
146        DataType::LargeUtf8 => {
147            let string_array = str.as_string::<i64>();
148            let delimiter_array = delim.as_string::<i64>();
149            let count_array: &PrimitiveArray<Int64Type> = count.as_primitive();
150            substr_index_general::<Int64Type, _, _>(
151                string_array,
152                delimiter_array,
153                count_array,
154            )
155        }
156        DataType::Utf8View => {
157            let string_array = str.as_string_view();
158            let delimiter_array = delim.as_string_view();
159            let count_array: &PrimitiveArray<Int64Type> = count.as_primitive();
160            substr_index_general::<Int32Type, _, _>(
161                string_array,
162                delimiter_array,
163                count_array,
164            )
165        }
166        other => {
167            exec_err!("Unsupported data type {other:?} for function substr_index")
168        }
169    }
170}
171
172fn substr_index_general<
173    'a,
174    T: ArrowPrimitiveType,
175    V: ArrayAccessor<Item = &'a str>,
176    P: ArrayAccessor<Item = i64>,
177>(
178    string_array: V,
179    delimiter_array: V,
180    count_array: P,
181) -> Result<ArrayRef>
182where
183    T::Native: OffsetSizeTrait,
184{
185    let num_rows = string_array.len();
186    let mut builder = GenericStringBuilder::<T::Native>::with_capacity(num_rows, 0);
187    let string_iter = ArrayIter::new(string_array);
188    let delimiter_array_iter = ArrayIter::new(delimiter_array);
189    let count_array_iter = ArrayIter::new(count_array);
190    string_iter
191        .zip(delimiter_array_iter)
192        .zip(count_array_iter)
193        .for_each(|((string, delimiter), n)| match (string, delimiter, n) {
194            (Some(string), Some(delimiter), Some(n)) => {
195                // In MySQL, these cases will return an empty string.
196                if n == 0 || string.is_empty() || delimiter.is_empty() {
197                    builder.append_value("");
198                    return;
199                }
200
201                let occurrences = usize::try_from(n.unsigned_abs()).unwrap_or(usize::MAX);
202                let result_idx = if delimiter.len() == 1 {
203                    // Fast path: use byte-level search for single-character delimiters
204                    let d_byte = delimiter.as_bytes()[0];
205                    let bytes = string.as_bytes();
206
207                    if n > 0 {
208                        bytes
209                            .iter()
210                            .enumerate()
211                            .filter(|&(_, &b)| b == d_byte)
212                            .nth(occurrences - 1)
213                            .map(|(idx, _)| idx)
214                    } else {
215                        bytes
216                            .iter()
217                            .enumerate()
218                            .rev()
219                            .filter(|&(_, &b)| b == d_byte)
220                            .nth(occurrences - 1)
221                            .map(|(idx, _)| idx + 1)
222                    }
223                } else if n > 0 {
224                    // Multi-byte path: forward search for n-th occurrence
225                    string
226                        .match_indices(delimiter)
227                        .nth(occurrences - 1)
228                        .map(|(idx, _)| idx)
229                } else {
230                    // Multi-byte path: backward search for n-th occurrence from the right
231                    string
232                        .rmatch_indices(delimiter)
233                        .nth(occurrences - 1)
234                        .map(|(idx, _)| idx + delimiter.len())
235                };
236                match result_idx {
237                    Some(idx) => {
238                        if n > 0 {
239                            builder.append_value(&string[..idx]);
240                        } else {
241                            builder.append_value(&string[idx..]);
242                        }
243                    }
244                    None => builder.append_value(string),
245                }
246            }
247            _ => builder.append_null(),
248        });
249
250    Ok(Arc::new(builder.finish()) as ArrayRef)
251}
252
#[cfg(test)]
mod tests {
    use arrow::array::{Array, LargeStringArray, StringArray};
    use arrow::datatypes::DataType::{LargeUtf8, Utf8};

    use datafusion_common::{Result, ScalarValue};
    use datafusion_expr::{ColumnarValue, ScalarUDFImpl};

    use crate::unicode::substrindex::SubstrIndexFunc;
    use crate::utils::test::test_function;

    #[test]
    fn test_functions() -> Result<()> {
        test_function!(
            SubstrIndexFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("www.apache.org")),
                ColumnarValue::Scalar(ScalarValue::from(".")),
                ColumnarValue::Scalar(ScalarValue::from(1i64)),
            ],
            Ok(Some("www")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            SubstrIndexFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("www.apache.org")),
                ColumnarValue::Scalar(ScalarValue::from(".")),
                ColumnarValue::Scalar(ScalarValue::from(2i64)),
            ],
            Ok(Some("www.apache")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            SubstrIndexFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("www.apache.org")),
                ColumnarValue::Scalar(ScalarValue::from(".")),
                ColumnarValue::Scalar(ScalarValue::from(-2i64)),
            ],
            Ok(Some("apache.org")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            SubstrIndexFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("www.apache.org")),
                ColumnarValue::Scalar(ScalarValue::from(".")),
                ColumnarValue::Scalar(ScalarValue::from(-1i64)),
            ],
            Ok(Some("org")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            SubstrIndexFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("www.apache.org")),
                ColumnarValue::Scalar(ScalarValue::from(".")),
                ColumnarValue::Scalar(ScalarValue::from(0i64)),
            ],
            Ok(Some("")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            SubstrIndexFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("")),
                ColumnarValue::Scalar(ScalarValue::from(".")),
                ColumnarValue::Scalar(ScalarValue::from(1i64)),
            ],
            Ok(Some("")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            SubstrIndexFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("www.apache.org")),
                ColumnarValue::Scalar(ScalarValue::from("")),
                ColumnarValue::Scalar(ScalarValue::from(1i64)),
            ],
            Ok(Some("")),
            &str,
            Utf8,
            StringArray
        );
        Ok(())
    }

    /// Delimiters longer than one byte take the `match_indices` /
    /// `rmatch_indices` path rather than the single-byte scan.
    #[test]
    fn test_multi_byte_delimiter() -> Result<()> {
        test_function!(
            SubstrIndexFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("a::b::c")),
                ColumnarValue::Scalar(ScalarValue::from("::")),
                ColumnarValue::Scalar(ScalarValue::from(2i64)),
            ],
            Ok(Some("a::b")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            SubstrIndexFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("a::b::c")),
                ColumnarValue::Scalar(ScalarValue::from("::")),
                ColumnarValue::Scalar(ScalarValue::from(-2i64)),
            ],
            Ok(Some("b::c")),
            &str,
            Utf8,
            StringArray
        );
        // A non-ASCII delimiter must slice on character boundaries.
        test_function!(
            SubstrIndexFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("a→b→c")),
                ColumnarValue::Scalar(ScalarValue::from("→")),
                ColumnarValue::Scalar(ScalarValue::from(1i64)),
            ],
            Ok(Some("a")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            SubstrIndexFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("a→b→c")),
                ColumnarValue::Scalar(ScalarValue::from("→")),
                ColumnarValue::Scalar(ScalarValue::from(-1i64)),
            ],
            Ok(Some("c")),
            &str,
            Utf8,
            StringArray
        );
        Ok(())
    }

    /// When the delimiter occurs fewer than `|count|` times (or not at all),
    /// the whole string is returned; delimiters at either edge yield "".
    #[test]
    fn test_count_out_of_range_and_edges() -> Result<()> {
        test_function!(
            SubstrIndexFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("www.apache.org")),
                ColumnarValue::Scalar(ScalarValue::from(".")),
                ColumnarValue::Scalar(ScalarValue::from(3i64)),
            ],
            Ok(Some("www.apache.org")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            SubstrIndexFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("www.apache.org")),
                ColumnarValue::Scalar(ScalarValue::from(".")),
                ColumnarValue::Scalar(ScalarValue::from(-3i64)),
            ],
            Ok(Some("www.apache.org")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            SubstrIndexFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("www.apache.org")),
                ColumnarValue::Scalar(ScalarValue::from(".")),
                ColumnarValue::Scalar(ScalarValue::from(i64::MIN)),
            ],
            Ok(Some("www.apache.org")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            SubstrIndexFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("www.apache.org")),
                ColumnarValue::Scalar(ScalarValue::from("#")),
                ColumnarValue::Scalar(ScalarValue::from(1i64)),
            ],
            Ok(Some("www.apache.org")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            SubstrIndexFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("www.apache.org.")),
                ColumnarValue::Scalar(ScalarValue::from(".")),
                ColumnarValue::Scalar(ScalarValue::from(-1i64)),
            ],
            Ok(Some("")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            SubstrIndexFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from(".www")),
                ColumnarValue::Scalar(ScalarValue::from(".")),
                ColumnarValue::Scalar(ScalarValue::from(1i64)),
            ],
            Ok(Some("")),
            &str,
            Utf8,
            StringArray
        );
        Ok(())
    }

    /// A null in any argument produces a null result.
    #[test]
    fn test_null_input() -> Result<()> {
        test_function!(
            SubstrIndexFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(None)),
                ColumnarValue::Scalar(ScalarValue::from(".")),
                ColumnarValue::Scalar(ScalarValue::from(1i64)),
            ],
            Ok(None),
            &str,
            Utf8,
            StringArray
        );
        Ok(())
    }

    /// `LargeUtf8` keeps its large offsets; `Utf8View` is returned as `Utf8`.
    #[test]
    fn test_other_string_types() -> Result<()> {
        test_function!(
            SubstrIndexFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(String::from(
                    "www.apache.org"
                )))),
                ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(String::from(".")))),
                ColumnarValue::Scalar(ScalarValue::from(2i64)),
            ],
            Ok(Some("www.apache")),
            &str,
            LargeUtf8,
            LargeStringArray
        );
        test_function!(
            SubstrIndexFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from(
                    "www.apache.org"
                )))),
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from(".")))),
                ColumnarValue::Scalar(ScalarValue::from(-2i64)),
            ],
            Ok(Some("apache.org")),
            &str,
            Utf8,
            StringArray
        );
        Ok(())
    }
}