Skip to main content

datafusion_functions/unicode/
strpos.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
17
18use std::sync::Arc;
19
20use crate::utils::{make_scalar_function, utf8_to_int_type};
21use arrow::array::{
22    ArrayRef, ArrowPrimitiveType, AsArray, PrimitiveArray, StringArrayType,
23};
24use arrow::datatypes::{
25    ArrowNativeType, DataType, Field, FieldRef, Int32Type, Int64Type,
26};
27use datafusion_common::types::logical_string;
28use datafusion_common::{Result, ScalarValue, exec_err, internal_err};
29use datafusion_expr::{
30    Coercion, ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
31    TypeSignatureClass, Volatility,
32};
33use datafusion_macros::user_doc;
34use memchr::{memchr, memmem};
35
36#[user_doc(
37    doc_section(label = "String Functions"),
38    description = "Returns the starting position of a specified substring in a string. Positions begin at 1. If the substring does not exist in the string, the function returns 0.",
39    syntax_example = "strpos(str, substr)",
40    alternative_syntax = "position(substr in origstr)",
41    sql_example = r#"```sql
42> select strpos('datafusion', 'fus');
43+----------------------------------------+
44| strpos(Utf8("datafusion"),Utf8("fus")) |
45+----------------------------------------+
46| 5                                      |
47+----------------------------------------+ 
48```"#,
49    standard_argument(name = "str", prefix = "String"),
50    argument(name = "substr", description = "Substring expression to search for.")
51)]
52#[derive(Debug, PartialEq, Eq, Hash)]
53pub struct StrposFunc {
54    signature: Signature,
55    aliases: Vec<String>,
56}
57
58impl Default for StrposFunc {
59    fn default() -> Self {
60        Self::new()
61    }
62}
63
64impl StrposFunc {
65    pub fn new() -> Self {
66        Self {
67            signature: Signature::coercible(
68                vec![
69                    Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
70                    Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
71                ],
72                Volatility::Immutable,
73            ),
74            aliases: vec![String::from("instr"), String::from("position")],
75        }
76    }
77}
78
79impl ScalarUDFImpl for StrposFunc {
80    fn name(&self) -> &str {
81        "strpos"
82    }
83
84    fn signature(&self) -> &Signature {
85        &self.signature
86    }
87
88    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
89        internal_err!("return_field_from_args should be used instead")
90    }
91
92    fn return_field_from_args(
93        &self,
94        args: datafusion_expr::ReturnFieldArgs,
95    ) -> Result<FieldRef> {
96        utf8_to_int_type(args.arg_fields[0].data_type(), "strpos/instr/position").map(
97            |data_type| {
98                Field::new(
99                    self.name(),
100                    data_type,
101                    args.arg_fields.iter().any(|x| x.is_nullable()),
102                )
103                .into()
104            },
105        )
106    }
107
108    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
109        // Fast path for array haystack and scalar needle
110        if let (
111            ColumnarValue::Array(haystack_array),
112            ColumnarValue::Scalar(needle_scalar),
113        ) = (&args.args[0], &args.args[1])
114        {
115            return strpos_scalar_needle(haystack_array, needle_scalar);
116        }
117        make_scalar_function(strpos, vec![])(&args.args)
118    }
119
120    fn aliases(&self) -> &[String] {
121        &self.aliases
122    }
123
124    fn documentation(&self) -> Option<&Documentation> {
125        self.doc()
126    }
127}
128
129fn strpos(args: &[ArrayRef]) -> Result<ArrayRef> {
130    /// Dispatches the needle array to the correct string type and calls
131    /// `strpos_general` with the given haystack and result type.
132    macro_rules! dispatch_needle {
133        ($haystack:expr, $result_type:ty, $args:expr) => {
134            match $args[1].data_type() {
135                DataType::Utf8 => strpos_general::<_, _, $result_type>(
136                    $haystack,
137                    $args[1].as_string::<i32>(),
138                ),
139                DataType::LargeUtf8 => strpos_general::<_, _, $result_type>(
140                    $haystack,
141                    $args[1].as_string::<i64>(),
142                ),
143                DataType::Utf8View => strpos_general::<_, _, $result_type>(
144                    $haystack,
145                    $args[1].as_string_view(),
146                ),
147                other => exec_err!("Unsupported data type {other:?} for strpos needle"),
148            }
149        };
150    }
151
152    match args[0].data_type() {
153        DataType::Utf8 => dispatch_needle!(args[0].as_string::<i32>(), Int32Type, args),
154        DataType::LargeUtf8 => {
155            dispatch_needle!(args[0].as_string::<i64>(), Int64Type, args)
156        }
157        DataType::Utf8View => dispatch_needle!(args[0].as_string_view(), Int32Type, args),
158        other => {
159            exec_err!("Unsupported data type {other:?} for strpos haystack")
160        }
161    }
162}
163
164/// Find `needle` in `haystack` using `memchr` to quickly skip to positions
165/// where the first byte matches, then verify the remaining bytes. Returns
166/// the 0-based byte offset of the match, or `None` if not found. An empty
167/// `needle` matches at offset 0.
168fn find_substring_bytes(haystack: &[u8], needle: &[u8]) -> Option<usize> {
169    let needle_len = needle.len();
170    let haystack_len = haystack.len();
171
172    if needle_len == 0 {
173        return Some(0);
174    }
175    if needle_len > haystack_len {
176        return None;
177    }
178
179    let first_byte = needle[0];
180    let mut offset = 0;
181
182    while let Some(pos) = memchr(first_byte, &haystack[offset..]) {
183        let start = offset + pos;
184        if start + needle_len > haystack.len() {
185            return None;
186        }
187        if haystack[start..start + needle_len] == *needle {
188            return Some(start);
189        }
190        offset = start + 1;
191    }
192
193    None
194}
195
196/// Converts a byte offset within a haystack to a 1-based character position.
197/// For ASCII data, byte offset == char offset so we just add 1. For non-ASCII,
198/// we count UTF-8 characters in the prefix before the match.
199#[inline]
200fn byte_offset_to_char_pos<T: ArrowPrimitiveType>(
201    haystack: &str,
202    byte_offset: usize,
203    ascii_only: bool,
204) -> Option<T::Native> {
205    if ascii_only {
206        return T::Native::from_usize(byte_offset + 1);
207    }
208    // SAFETY: byte_offset is at a UTF-8 char boundary because both haystack
209    // and needle are valid UTF-8, and UTF-8 is self-synchronizing: a valid
210    // needle byte sequence can only match starting at a char boundary in a
211    // valid haystack.
212    debug_assert!(haystack.is_char_boundary(byte_offset));
213    let prefix =
214        unsafe { std::str::from_utf8_unchecked(&haystack.as_bytes()[..byte_offset]) };
215    T::Native::from_usize(prefix.chars().count() + 1)
216}
217
218/// Fallback strpos implementation for when both haystack and needle are arrays.
219/// Building a new `memmem::Finder` for every row is too expensive; it is faster
220/// to use `memchr::memchr`.
221fn strpos_general<'a, V1, V2, T: ArrowPrimitiveType>(
222    haystack_array: V1,
223    needle_array: V2,
224) -> Result<ArrayRef>
225where
226    V1: StringArrayType<'a, Item = &'a str> + Copy,
227    V2: StringArrayType<'a, Item = &'a str> + Copy,
228{
229    let ascii_only = needle_array.is_ascii() && haystack_array.is_ascii();
230    let haystack_iter = haystack_array.iter();
231    let needle_iter = needle_array.iter();
232
233    let result = haystack_iter
234        .zip(needle_iter)
235        .map(|(haystack, needle)| match (haystack, needle) {
236            (Some(haystack), Some(needle)) => {
237                let haystack_bytes = haystack.as_bytes();
238                let needle_bytes = needle.as_bytes();
239
240                match find_substring_bytes(haystack_bytes, needle_bytes) {
241                    None => T::Native::from_usize(0),
242                    Some(byte_offset) => {
243                        byte_offset_to_char_pos::<T>(haystack, byte_offset, ascii_only)
244                    }
245                }
246            }
247            _ => None,
248        })
249        .collect::<PrimitiveArray<T>>();
250
251    Ok(Arc::new(result) as ArrayRef)
252}
253
254/// Fast-path strpos implementation for when the haystack is an array and the
255/// needle is a scalar.  We can pre-build a `memmem::Finder` once and reuse it
256/// for every haystack row.
257fn strpos_scalar_needle(
258    haystack_array: &ArrayRef,
259    needle_scalar: &ScalarValue,
260) -> Result<ColumnarValue> {
261    let Some(needle_str) = needle_scalar.try_as_str() else {
262        return exec_err!(
263            "Unsupported data type {:?} for strpos needle",
264            needle_scalar.data_type()
265        );
266    };
267
268    // Null needle => null result for every row
269    let Some(needle_str) = needle_str else {
270        return match haystack_array.data_type() {
271            DataType::LargeUtf8 => {
272                Ok(ColumnarValue::Array(Arc::new(
273                    PrimitiveArray::<Int64Type>::new_null(haystack_array.len()),
274                )))
275            }
276            DataType::Utf8 | DataType::Utf8View => Ok(ColumnarValue::Array(Arc::new(
277                PrimitiveArray::<Int32Type>::new_null(haystack_array.len()),
278            ))),
279            other => exec_err!("Unsupported data type {other:?} for strpos haystack"),
280        };
281    };
282
283    let result = match haystack_array.data_type() {
284        DataType::Utf8 => strpos_with_finder::<_, Int32Type>(
285            haystack_array.as_string::<i32>(),
286            needle_str,
287        ),
288        DataType::LargeUtf8 => strpos_with_finder::<_, Int64Type>(
289            haystack_array.as_string::<i64>(),
290            needle_str,
291        ),
292        DataType::Utf8View => strpos_with_finder::<_, Int32Type>(
293            haystack_array.as_string_view(),
294            needle_str,
295        ),
296        other => {
297            exec_err!("Unsupported data type {other:?} for strpos haystack")
298        }
299    }?;
300    Ok(ColumnarValue::Array(result))
301}
302
303fn strpos_with_finder<'a, V, T: ArrowPrimitiveType>(
304    haystack_array: V,
305    needle: &str,
306) -> Result<ArrayRef>
307where
308    V: StringArrayType<'a, Item = &'a str> + Copy,
309{
310    let needle_bytes = needle.as_bytes();
311    let ascii_haystack = haystack_array.is_ascii();
312    let finder = memmem::Finder::new(needle_bytes);
313
314    let result = haystack_array
315        .iter()
316        .map(|string| match string {
317            Some(string) => {
318                let haystack_bytes = string.as_bytes();
319                match finder.find(haystack_bytes) {
320                    None => T::Native::from_usize(0),
321                    Some(byte_offset) => {
322                        byte_offset_to_char_pos::<T>(string, byte_offset, ascii_haystack)
323                    }
324                }
325            }
326            None => None,
327        })
328        .collect::<PrimitiveArray<T>>();
329
330    Ok(Arc::new(result) as ArrayRef)
331}
332
#[cfg(test)]
mod tests {
    use arrow::array::{Array, Int32Array, Int64Array};
    use arrow::datatypes::DataType::{Int32, Int64};

    use arrow::datatypes::{DataType, Field};
    use datafusion_common::{Result, ScalarValue};
    use datafusion_expr::{ColumnarValue, ScalarUDFImpl};

    use crate::unicode::strpos::StrposFunc;
    use crate::utils::test::test_function;

    // Runs `strpos($lhs, $rhs)` with both arguments as scalars of the given
    // `ScalarValue` variants (`$t1`, `$t2`) and checks the result against
    // `$result`, using `$t3`/`$t4`/`$t5` as the expected native type, data
    // type and array type.
    macro_rules! test_strpos {
        ($lhs:literal, $rhs:literal -> $result:literal; $t1:ident $t2:ident $t3:ident $t4:ident $t5:ident) => {
            test_function!(
                StrposFunc::new(),
                vec![
                    ColumnarValue::Scalar(ScalarValue::$t1(Some($lhs.to_owned()))),
                    ColumnarValue::Scalar(ScalarValue::$t2(Some($rhs.to_owned()))),
                ],
                Ok(Some($result)),
                $t3,
                $t4,
                $t5
            )
        };
    }

    #[test]
    fn test_strpos_functions() {
        // Utf8 and Utf8 combinations
        test_strpos!("alphabet", "ph" -> 3; Utf8 Utf8 i32 Int32 Int32Array);
        test_strpos!("alphabet", "a" -> 1; Utf8 Utf8 i32 Int32 Int32Array);
        test_strpos!("alphabet", "z" -> 0; Utf8 Utf8 i32 Int32 Int32Array);
        test_strpos!("alphabet", "" -> 1; Utf8 Utf8 i32 Int32 Int32Array);
        test_strpos!("", "a" -> 0; Utf8 Utf8 i32 Int32 Int32Array);
        test_strpos!("", "" -> 1; Utf8 Utf8 i32 Int32 Int32Array);
        test_strpos!("ДатаФусион数据融合📊🔥", "📊" -> 15; Utf8 Utf8 i32 Int32 Int32Array);

        // LargeUtf8 and LargeUtf8 combinations
        test_strpos!("alphabet", "ph" -> 3; LargeUtf8 LargeUtf8 i64 Int64 Int64Array);
        test_strpos!("alphabet", "a" -> 1; LargeUtf8 LargeUtf8 i64 Int64 Int64Array);
        test_strpos!("alphabet", "z" -> 0; LargeUtf8 LargeUtf8 i64 Int64 Int64Array);
        test_strpos!("alphabet", "" -> 1; LargeUtf8 LargeUtf8 i64 Int64 Int64Array);
        test_strpos!("", "a" -> 0; LargeUtf8 LargeUtf8 i64 Int64 Int64Array);
        test_strpos!("", "" -> 1; LargeUtf8 LargeUtf8 i64 Int64 Int64Array);
        test_strpos!("ДатаФусион数据融合📊🔥", "📊" -> 15; LargeUtf8 LargeUtf8 i64 Int64 Int64Array);

        // Utf8 and LargeUtf8 combinations
        test_strpos!("alphabet", "ph" -> 3; Utf8 LargeUtf8 i32 Int32 Int32Array);
        test_strpos!("alphabet", "a" -> 1; Utf8 LargeUtf8 i32 Int32 Int32Array);
        test_strpos!("alphabet", "z" -> 0; Utf8 LargeUtf8 i32 Int32 Int32Array);
        test_strpos!("alphabet", "" -> 1; Utf8 LargeUtf8 i32 Int32 Int32Array);
        test_strpos!("", "a" -> 0; Utf8 LargeUtf8 i32 Int32 Int32Array);
        test_strpos!("", "" -> 1; Utf8 LargeUtf8 i32 Int32 Int32Array);
        test_strpos!("ДатаФусион数据融合📊🔥", "📊" -> 15; Utf8 LargeUtf8 i32 Int32 Int32Array);

        // LargeUtf8 and Utf8 combinations
        test_strpos!("alphabet", "ph" -> 3; LargeUtf8 Utf8 i64 Int64 Int64Array);
        test_strpos!("alphabet", "a" -> 1; LargeUtf8 Utf8 i64 Int64 Int64Array);
        test_strpos!("alphabet", "z" -> 0; LargeUtf8 Utf8 i64 Int64 Int64Array);
        test_strpos!("alphabet", "" -> 1; LargeUtf8 Utf8 i64 Int64 Int64Array);
        test_strpos!("", "a" -> 0; LargeUtf8 Utf8 i64 Int64 Int64Array);
        test_strpos!("", "" -> 1; LargeUtf8 Utf8 i64 Int64 Int64Array);
        test_strpos!("ДатаФусион数据融合📊🔥", "📊" -> 15; LargeUtf8 Utf8 i64 Int64 Int64Array);

        // Utf8View and Utf8View combinations
        test_strpos!("alphabet", "ph" -> 3; Utf8View Utf8View i32 Int32 Int32Array);
        test_strpos!("alphabet", "a" -> 1; Utf8View Utf8View i32 Int32 Int32Array);
        test_strpos!("alphabet", "z" -> 0; Utf8View Utf8View i32 Int32 Int32Array);
        test_strpos!("alphabet", "" -> 1; Utf8View Utf8View i32 Int32 Int32Array);
        test_strpos!("", "a" -> 0; Utf8View Utf8View i32 Int32 Int32Array);
        test_strpos!("", "" -> 1; Utf8View Utf8View i32 Int32 Int32Array);
        test_strpos!("ДатаФусион数据融合📊🔥", "📊" -> 15; Utf8View Utf8View i32 Int32 Int32Array);

        // Utf8View and Utf8 combinations
        test_strpos!("alphabet", "ph" -> 3; Utf8View Utf8 i32 Int32 Int32Array);
        test_strpos!("alphabet", "a" -> 1; Utf8View Utf8 i32 Int32 Int32Array);
        test_strpos!("alphabet", "z" -> 0; Utf8View Utf8 i32 Int32 Int32Array);
        test_strpos!("alphabet", "" -> 1; Utf8View Utf8 i32 Int32 Int32Array);
        test_strpos!("", "a" -> 0; Utf8View Utf8 i32 Int32 Int32Array);
        test_strpos!("", "" -> 1; Utf8View Utf8 i32 Int32 Int32Array);
        test_strpos!("ДатаФусион数据融合📊🔥", "📊" -> 15; Utf8View Utf8 i32 Int32 Int32Array);

        // Utf8View and LargeUtf8 combinations
        test_strpos!("alphabet", "ph" -> 3; Utf8View LargeUtf8 i32 Int32 Int32Array);
        test_strpos!("alphabet", "a" -> 1; Utf8View LargeUtf8 i32 Int32 Int32Array);
        test_strpos!("alphabet", "z" -> 0; Utf8View LargeUtf8 i32 Int32 Int32Array);
        test_strpos!("alphabet", "" -> 1; Utf8View LargeUtf8 i32 Int32 Int32Array);
        test_strpos!("", "a" -> 0; Utf8View LargeUtf8 i32 Int32 Int32Array);
        test_strpos!("", "" -> 1; Utf8View LargeUtf8 i32 Int32 Int32Array);
        test_strpos!("ДатаФусион数据融合📊🔥", "📊" -> 15; Utf8View LargeUtf8 i32 Int32 Int32Array);
    }

    #[test]
    fn nullable_return_type() {
        // Returns the nullability of the output field for the given
        // nullability of the two `Utf8` argument fields.
        fn get_nullable(string_array_nullable: bool, substring_nullable: bool) -> bool {
            let strpos = StrposFunc::new();
            let args = datafusion_expr::ReturnFieldArgs {
                arg_fields: &[
                    Field::new("f1", DataType::Utf8, string_array_nullable).into(),
                    Field::new("f2", DataType::Utf8, substring_nullable).into(),
                ],
                scalar_arguments: &[None::<&ScalarValue>, None::<&ScalarValue>],
            };

            strpos.return_field_from_args(args).unwrap().is_nullable()
        }

        assert!(!get_nullable(false, false));

        // If any of the arguments is nullable, the result is nullable
        assert!(get_nullable(true, false));
        assert!(get_nullable(false, true));
        assert!(get_nullable(true, true));
    }
}