Skip to main content

datafusion_functions/unicode/
strpos.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::any::Any;
19use std::sync::Arc;
20
21use crate::utils::{make_scalar_function, utf8_to_int_type};
22use arrow::array::{
23    ArrayRef, ArrowPrimitiveType, AsArray, PrimitiveArray, StringArrayType,
24};
25use arrow::datatypes::{
26    ArrowNativeType, DataType, Field, FieldRef, Int32Type, Int64Type,
27};
28use datafusion_common::types::logical_string;
29use datafusion_common::{Result, exec_err, internal_err};
30use datafusion_expr::{
31    Coercion, ColumnarValue, Documentation, ScalarUDFImpl, Signature, TypeSignatureClass,
32    Volatility,
33};
34use datafusion_macros::user_doc;
35use memchr::memchr;
36
#[user_doc(
    doc_section(label = "String Functions"),
    description = "Returns the starting position of a specified substring in a string. Positions begin at 1. If the substring does not exist in the string, the function returns 0.",
    syntax_example = "strpos(str, substr)",
    alternative_syntax = "position(substr in origstr)",
    sql_example = r#"```sql
> select strpos('datafusion', 'fus');
+----------------------------------------+
| strpos(Utf8("datafusion"),Utf8("fus")) |
+----------------------------------------+
| 5                                      |
+----------------------------------------+ 
```"#,
    standard_argument(name = "str", prefix = "String"),
    argument(name = "substr", description = "Substring expression to search for.")
)]
/// Scalar UDF implementing `strpos(str, substr)`, which is also registered
/// under the aliases `instr` and `position`.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct StrposFunc {
    /// Two arguments (`str`, `substr`), each an exact coercion to the native
    /// logical string class.
    signature: Signature,
    /// Alternate names this function answers to (`instr`, `position`).
    aliases: Vec<String>,
}
58
59impl Default for StrposFunc {
60    fn default() -> Self {
61        Self::new()
62    }
63}
64
65impl StrposFunc {
66    pub fn new() -> Self {
67        Self {
68            signature: Signature::coercible(
69                vec![
70                    Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
71                    Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
72                ],
73                Volatility::Immutable,
74            ),
75            aliases: vec![String::from("instr"), String::from("position")],
76        }
77    }
78}
79
80impl ScalarUDFImpl for StrposFunc {
81    fn as_any(&self) -> &dyn Any {
82        self
83    }
84
85    fn name(&self) -> &str {
86        "strpos"
87    }
88
89    fn signature(&self) -> &Signature {
90        &self.signature
91    }
92
93    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
94        internal_err!("return_field_from_args should be used instead")
95    }
96
97    fn return_field_from_args(
98        &self,
99        args: datafusion_expr::ReturnFieldArgs,
100    ) -> Result<FieldRef> {
101        utf8_to_int_type(args.arg_fields[0].data_type(), "strpos/instr/position").map(
102            |data_type| {
103                Field::new(
104                    self.name(),
105                    data_type,
106                    args.arg_fields.iter().any(|x| x.is_nullable()),
107                )
108                .into()
109            },
110        )
111    }
112
113    fn invoke_with_args(
114        &self,
115        args: datafusion_expr::ScalarFunctionArgs,
116    ) -> Result<ColumnarValue> {
117        make_scalar_function(strpos, vec![])(&args.args)
118    }
119
120    fn aliases(&self) -> &[String] {
121        &self.aliases
122    }
123
124    fn documentation(&self) -> Option<&Documentation> {
125        self.doc()
126    }
127}
128
129fn strpos(args: &[ArrayRef]) -> Result<ArrayRef> {
130    match (args[0].data_type(), args[1].data_type()) {
131        (DataType::Utf8, DataType::Utf8) => {
132            let string_array = args[0].as_string::<i32>();
133            let substring_array = args[1].as_string::<i32>();
134            calculate_strpos::<_, _, Int32Type>(&string_array, &substring_array)
135        }
136        (DataType::Utf8, DataType::Utf8View) => {
137            let string_array = args[0].as_string::<i32>();
138            let substring_array = args[1].as_string_view();
139            calculate_strpos::<_, _, Int32Type>(&string_array, &substring_array)
140        }
141        (DataType::Utf8, DataType::LargeUtf8) => {
142            let string_array = args[0].as_string::<i32>();
143            let substring_array = args[1].as_string::<i64>();
144            calculate_strpos::<_, _, Int32Type>(&string_array, &substring_array)
145        }
146        (DataType::LargeUtf8, DataType::Utf8) => {
147            let string_array = args[0].as_string::<i64>();
148            let substring_array = args[1].as_string::<i32>();
149            calculate_strpos::<_, _, Int64Type>(&string_array, &substring_array)
150        }
151        (DataType::LargeUtf8, DataType::Utf8View) => {
152            let string_array = args[0].as_string::<i64>();
153            let substring_array = args[1].as_string_view();
154            calculate_strpos::<_, _, Int64Type>(&string_array, &substring_array)
155        }
156        (DataType::LargeUtf8, DataType::LargeUtf8) => {
157            let string_array = args[0].as_string::<i64>();
158            let substring_array = args[1].as_string::<i64>();
159            calculate_strpos::<_, _, Int64Type>(&string_array, &substring_array)
160        }
161        (DataType::Utf8View, DataType::Utf8View) => {
162            let string_array = args[0].as_string_view();
163            let substring_array = args[1].as_string_view();
164            calculate_strpos::<_, _, Int32Type>(&string_array, &substring_array)
165        }
166        (DataType::Utf8View, DataType::Utf8) => {
167            let string_array = args[0].as_string_view();
168            let substring_array = args[1].as_string::<i32>();
169            calculate_strpos::<_, _, Int32Type>(&string_array, &substring_array)
170        }
171        (DataType::Utf8View, DataType::LargeUtf8) => {
172            let string_array = args[0].as_string_view();
173            let substring_array = args[1].as_string::<i64>();
174            calculate_strpos::<_, _, Int32Type>(&string_array, &substring_array)
175        }
176
177        other => {
178            exec_err!("Unsupported data type combination {other:?} for function strpos")
179        }
180    }
181}
182
183/// Find `needle` in `haystack` using `memchr` to quickly skip to positions
184/// where the first byte matches, then verify the remaining bytes. Using
185/// string::find is slower because it has significant per-call overhead that
186/// `memchr` does not, and strpos is often invoked many times on short inputs.
187/// Returns a 1-based position, or 0 if not found.
188/// Both inputs must be ASCII-only.
189fn find_ascii_substring(haystack: &[u8], needle: &[u8]) -> usize {
190    let needle_len = needle.len();
191    let first_byte = needle[0];
192    let mut offset = 0;
193
194    while let Some(pos) = memchr(first_byte, &haystack[offset..]) {
195        let start = offset + pos;
196        if start + needle_len > haystack.len() {
197            return 0;
198        }
199        if haystack[start..start + needle_len] == *needle {
200            return start + 1;
201        }
202        offset = start + 1;
203    }
204
205    0
206}
207
208/// Returns starting index of specified substring within string, or zero if it's not present. (Same as position(substring in string), but note the reversed argument order.)
209/// strpos('high', 'ig') = 2
210/// The implementation uses UTF-8 code points as characters
211fn calculate_strpos<'a, V1, V2, T: ArrowPrimitiveType>(
212    string_array: &V1,
213    substring_array: &V2,
214) -> Result<ArrayRef>
215where
216    V1: StringArrayType<'a, Item = &'a str>,
217    V2: StringArrayType<'a, Item = &'a str>,
218{
219    let ascii_only = substring_array.is_ascii() && string_array.is_ascii();
220    let string_iter = string_array.iter();
221    let substring_iter = substring_array.iter();
222
223    let result = string_iter
224        .zip(substring_iter)
225        .map(|(string, substring)| match (string, substring) {
226            (Some(string), Some(substring)) => {
227                if substring.is_empty() {
228                    return T::Native::from_usize(1);
229                }
230
231                let substring_bytes = substring.as_bytes();
232                let string_bytes = string.as_bytes();
233
234                if substring_bytes.len() > string_bytes.len() {
235                    return T::Native::from_usize(0);
236                }
237
238                if ascii_only {
239                    T::Native::from_usize(find_ascii_substring(
240                        string_bytes,
241                        substring_bytes,
242                    ))
243                } else {
244                    // For non-ASCII, use a single-pass search that tracks both
245                    // byte position and character position simultaneously
246                    let mut char_pos = 0;
247                    for (byte_idx, _) in string.char_indices() {
248                        char_pos += 1;
249                        if byte_idx + substring_bytes.len() <= string_bytes.len() {
250                            // SAFETY: We just checked that byte_idx + substring_bytes.len() <= string_bytes.len()
251                            let slice = unsafe {
252                                string_bytes.get_unchecked(
253                                    byte_idx..byte_idx + substring_bytes.len(),
254                                )
255                            };
256                            if slice == substring_bytes {
257                                return T::Native::from_usize(char_pos);
258                            }
259                        }
260                    }
261
262                    T::Native::from_usize(0)
263                }
264            }
265            _ => None,
266        })
267        .collect::<PrimitiveArray<T>>();
268
269    Ok(Arc::new(result) as ArrayRef)
270}
271
#[cfg(test)]
mod tests {
    use arrow::array::{Array, Int32Array, Int64Array};
    use arrow::datatypes::DataType::{Int32, Int64};

    use arrow::datatypes::{DataType, Field};
    use datafusion_common::{Result, ScalarValue};
    use datafusion_expr::{ColumnarValue, ScalarUDFImpl};

    use crate::unicode::strpos::StrposFunc;
    use crate::utils::test::test_function;

    macro_rules! test_strpos {
        ($lhs:literal, $rhs:literal -> $result:literal; $t1:ident $t2:ident $t3:ident $t4:ident $t5:ident) => {
            test_function!(
                StrposFunc::new(),
                vec![
                    ColumnarValue::Scalar(ScalarValue::$t1(Some($lhs.to_owned()))),
                    ColumnarValue::Scalar(ScalarValue::$t2(Some($rhs.to_owned()))),
                ],
                Ok(Some($result)),
                $t3,
                $t4,
                $t5
            )
        };
    }

    #[test]
    fn test_strpos_functions() {
        // Utf8 and Utf8 combinations
        test_strpos!("alphabet", "ph" -> 3; Utf8 Utf8 i32 Int32 Int32Array);
        test_strpos!("alphabet", "a" -> 1; Utf8 Utf8 i32 Int32 Int32Array);
        test_strpos!("alphabet", "z" -> 0; Utf8 Utf8 i32 Int32 Int32Array);
        test_strpos!("alphabet", "" -> 1; Utf8 Utf8 i32 Int32 Int32Array);
        test_strpos!("", "a" -> 0; Utf8 Utf8 i32 Int32 Int32Array);
        test_strpos!("", "" -> 1; Utf8 Utf8 i32 Int32 Int32Array);
        test_strpos!("ДатаФусион数据融合📊🔥", "📊" -> 15; Utf8 Utf8 i32 Int32 Int32Array);

        // LargeUtf8 and LargeUtf8 combinations
        test_strpos!("alphabet", "ph" -> 3; LargeUtf8 LargeUtf8 i64 Int64 Int64Array);
        test_strpos!("alphabet", "a" -> 1; LargeUtf8 LargeUtf8 i64 Int64 Int64Array);
        test_strpos!("alphabet", "z" -> 0; LargeUtf8 LargeUtf8 i64 Int64 Int64Array);
        test_strpos!("alphabet", "" -> 1; LargeUtf8 LargeUtf8 i64 Int64 Int64Array);
        test_strpos!("", "a" -> 0; LargeUtf8 LargeUtf8 i64 Int64 Int64Array);
        test_strpos!("", "" -> 1; LargeUtf8 LargeUtf8 i64 Int64 Int64Array);
        test_strpos!("ДатаФусион数据融合📊🔥", "📊" -> 15; LargeUtf8 LargeUtf8 i64 Int64 Int64Array);

        // Utf8 and LargeUtf8 combinations
        test_strpos!("alphabet", "ph" -> 3; Utf8 LargeUtf8 i32 Int32 Int32Array);
        test_strpos!("alphabet", "a" -> 1; Utf8 LargeUtf8 i32 Int32 Int32Array);
        test_strpos!("alphabet", "z" -> 0; Utf8 LargeUtf8 i32 Int32 Int32Array);
        test_strpos!("alphabet", "" -> 1; Utf8 LargeUtf8 i32 Int32 Int32Array);
        test_strpos!("", "a" -> 0; Utf8 LargeUtf8 i32 Int32 Int32Array);
        test_strpos!("", "" -> 1; Utf8 LargeUtf8 i32 Int32 Int32Array);
        test_strpos!("ДатаФусион数据融合📊🔥", "📊" -> 15; Utf8 LargeUtf8 i32 Int32 Int32Array);

        // LargeUtf8 and Utf8 combinations
        test_strpos!("alphabet", "ph" -> 3; LargeUtf8 Utf8 i64 Int64 Int64Array);
        test_strpos!("alphabet", "a" -> 1; LargeUtf8 Utf8 i64 Int64 Int64Array);
        test_strpos!("alphabet", "z" -> 0; LargeUtf8 Utf8 i64 Int64 Int64Array);
        test_strpos!("alphabet", "" -> 1; LargeUtf8 Utf8 i64 Int64 Int64Array);
        test_strpos!("", "a" -> 0; LargeUtf8 Utf8 i64 Int64 Int64Array);
        test_strpos!("", "" -> 1; LargeUtf8 Utf8 i64 Int64 Int64Array);
        test_strpos!("ДатаФусион数据融合📊🔥", "📊" -> 15; LargeUtf8 Utf8 i64 Int64 Int64Array);

        // Utf8View and Utf8View combinations
        test_strpos!("alphabet", "ph" -> 3; Utf8View Utf8View i32 Int32 Int32Array);
        test_strpos!("alphabet", "a" -> 1; Utf8View Utf8View i32 Int32 Int32Array);
        test_strpos!("alphabet", "z" -> 0; Utf8View Utf8View i32 Int32 Int32Array);
        test_strpos!("alphabet", "" -> 1; Utf8View Utf8View i32 Int32 Int32Array);
        test_strpos!("", "a" -> 0; Utf8View Utf8View i32 Int32 Int32Array);
        test_strpos!("", "" -> 1; Utf8View Utf8View i32 Int32 Int32Array);
        test_strpos!("ДатаФусион数据融合📊🔥", "📊" -> 15; Utf8View Utf8View i32 Int32 Int32Array);

        // Utf8View and Utf8 combinations
        test_strpos!("alphabet", "ph" -> 3; Utf8View Utf8 i32 Int32 Int32Array);
        test_strpos!("alphabet", "a" -> 1; Utf8View Utf8 i32 Int32 Int32Array);
        test_strpos!("alphabet", "z" -> 0; Utf8View Utf8 i32 Int32 Int32Array);
        test_strpos!("alphabet", "" -> 1; Utf8View Utf8 i32 Int32 Int32Array);
        test_strpos!("", "a" -> 0; Utf8View Utf8 i32 Int32 Int32Array);
        test_strpos!("", "" -> 1; Utf8View Utf8 i32 Int32 Int32Array);
        test_strpos!("ДатаФусион数据融合📊🔥", "📊" -> 15; Utf8View Utf8 i32 Int32 Int32Array);

        // Utf8View and LargeUtf8 combinations
        test_strpos!("alphabet", "ph" -> 3; Utf8View LargeUtf8 i32 Int32 Int32Array);
        test_strpos!("alphabet", "a" -> 1; Utf8View LargeUtf8 i32 Int32 Int32Array);
        test_strpos!("alphabet", "z" -> 0; Utf8View LargeUtf8 i32 Int32 Int32Array);
        test_strpos!("alphabet", "" -> 1; Utf8View LargeUtf8 i32 Int32 Int32Array);
        test_strpos!("", "a" -> 0; Utf8View LargeUtf8 i32 Int32 Int32Array);
        test_strpos!("", "" -> 1; Utf8View LargeUtf8 i32 Int32 Int32Array);
        test_strpos!("ДатаФусион数据融合📊🔥", "📊" -> 15; Utf8View LargeUtf8 i32 Int32 Int32Array);
    }

    #[test]
    fn test_strpos_match_boundaries() {
        // Match that ends exactly at the end of an ASCII haystack
        test_strpos!("alphabet", "bet" -> 6; Utf8 Utf8 i32 Int32 Int32Array);
        test_strpos!("alphabet", "bet" -> 6; LargeUtf8 LargeUtf8 i64 Int64 Int64Array);
        test_strpos!("alphabet", "bet" -> 6; Utf8View Utf8View i32 Int32 Int32Array);

        // First-byte hit that is not a full match, followed by a real match
        test_strpos!("aab", "ab" -> 2; Utf8 Utf8 i32 Int32 Int32Array);

        // Needle longer than the haystack
        test_strpos!("abc", "abcd" -> 0; Utf8 Utf8 i32 Int32 Int32Array);

        // Non-ASCII haystack with a multi-character non-ASCII needle
        test_strpos!("ДатаФусион", "Фус" -> 5; Utf8 Utf8 i32 Int32 Int32Array);
        test_strpos!("数据融合", "融合" -> 3; LargeUtf8 LargeUtf8 i64 Int64 Int64Array);

        // Non-ASCII haystack with an ASCII needle
        test_strpos!("Датаfusion", "fus" -> 5; Utf8View Utf8View i32 Int32 Int32Array);

        // Needle that fits by byte length but runs past the end from any later start
        test_strpos!("Дата", "та📊" -> 0; Utf8 Utf8 i32 Int32 Int32Array);
    }

    #[test]
    fn nullable_return_type() {
        fn get_nullable(string_array_nullable: bool, substring_nullable: bool) -> bool {
            let strpos = StrposFunc::new();
            let args = datafusion_expr::ReturnFieldArgs {
                arg_fields: &[
                    Field::new("f1", DataType::Utf8, string_array_nullable).into(),
                    Field::new("f2", DataType::Utf8, substring_nullable).into(),
                ],
                scalar_arguments: &[None::<&ScalarValue>, None::<&ScalarValue>],
            };

            strpos.return_field_from_args(args).unwrap().is_nullable()
        }

        assert!(!get_nullable(false, false));

        // If any of the arguments is nullable, the result is nullable
        assert!(get_nullable(true, false));
        assert!(get_nullable(false, true));
        assert!(get_nullable(true, true));
    }
}