1use std::sync::Arc;
19
20use crate::utils::{make_scalar_function, utf8_to_int_type};
21use arrow::array::{
22 ArrayRef, ArrowPrimitiveType, AsArray, PrimitiveArray, StringArrayType,
23};
24use arrow::datatypes::{
25 ArrowNativeType, DataType, Field, FieldRef, Int32Type, Int64Type,
26};
27use datafusion_common::types::logical_string;
28use datafusion_common::{Result, ScalarValue, exec_err, internal_err};
29use datafusion_expr::{
30 Coercion, ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
31 TypeSignatureClass, Volatility,
32};
33use datafusion_macros::user_doc;
34use memchr::{memchr, memmem};
35
// The `user_doc` attribute below carries the user-facing documentation
// (description, syntax, SQL example, arguments) for this function; the
// `documentation()` method of the `ScalarUDFImpl` impl exposes it via `self.doc()`.
36#[user_doc(
37 doc_section(label = "String Functions"),
38 description = "Returns the starting position of a specified substring in a string. Positions begin at 1. If the substring does not exist in the string, the function returns 0.",
39 syntax_example = "strpos(str, substr)",
40 alternative_syntax = "position(substr in origstr)",
41 sql_example = r#"```sql
42> select strpos('datafusion', 'fus');
43+----------------------------------------+
44| strpos(Utf8("datafusion"),Utf8("fus")) |
45+----------------------------------------+
46| 5 |
47+----------------------------------------+
48```"#,
49 standard_argument(name = "str", prefix = "String"),
50 argument(name = "substr", description = "Substring expression to search for.")
51)]
52#[derive(Debug, PartialEq, Eq, Hash)]
/// Scalar UDF `strpos(str, substr)`, also registered under the aliases
/// `instr` and `position`.
53pub struct StrposFunc {
    /// Accepted argument types and volatility; built in `StrposFunc::new`.
54 signature: Signature,
    /// Alternative function names returned by `ScalarUDFImpl::aliases`.
55 aliases: Vec<String>,
56}
57
58impl Default for StrposFunc {
59 fn default() -> Self {
60 Self::new()
61 }
62}
63
64impl StrposFunc {
65 pub fn new() -> Self {
66 Self {
67 signature: Signature::coercible(
68 vec![
69 Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
70 Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
71 ],
72 Volatility::Immutable,
73 ),
74 aliases: vec![String::from("instr"), String::from("position")],
75 }
76 }
77}
78
79impl ScalarUDFImpl for StrposFunc {
80 fn name(&self) -> &str {
81 "strpos"
82 }
83
84 fn signature(&self) -> &Signature {
85 &self.signature
86 }
87
88 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
89 internal_err!("return_field_from_args should be used instead")
90 }
91
92 fn return_field_from_args(
93 &self,
94 args: datafusion_expr::ReturnFieldArgs,
95 ) -> Result<FieldRef> {
96 utf8_to_int_type(args.arg_fields[0].data_type(), "strpos/instr/position").map(
97 |data_type| {
98 Field::new(
99 self.name(),
100 data_type,
101 args.arg_fields.iter().any(|x| x.is_nullable()),
102 )
103 .into()
104 },
105 )
106 }
107
108 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
109 if let (
111 ColumnarValue::Array(haystack_array),
112 ColumnarValue::Scalar(needle_scalar),
113 ) = (&args.args[0], &args.args[1])
114 {
115 return strpos_scalar_needle(haystack_array, needle_scalar);
116 }
117 make_scalar_function(strpos, vec![])(&args.args)
118 }
119
120 fn aliases(&self) -> &[String] {
121 &self.aliases
122 }
123
124 fn documentation(&self) -> Option<&Documentation> {
125 self.doc()
126 }
127}
128
129fn strpos(args: &[ArrayRef]) -> Result<ArrayRef> {
130 macro_rules! dispatch_needle {
133 ($haystack:expr, $result_type:ty, $args:expr) => {
134 match $args[1].data_type() {
135 DataType::Utf8 => strpos_general::<_, _, $result_type>(
136 $haystack,
137 $args[1].as_string::<i32>(),
138 ),
139 DataType::LargeUtf8 => strpos_general::<_, _, $result_type>(
140 $haystack,
141 $args[1].as_string::<i64>(),
142 ),
143 DataType::Utf8View => strpos_general::<_, _, $result_type>(
144 $haystack,
145 $args[1].as_string_view(),
146 ),
147 other => exec_err!("Unsupported data type {other:?} for strpos needle"),
148 }
149 };
150 }
151
152 match args[0].data_type() {
153 DataType::Utf8 => dispatch_needle!(args[0].as_string::<i32>(), Int32Type, args),
154 DataType::LargeUtf8 => {
155 dispatch_needle!(args[0].as_string::<i64>(), Int64Type, args)
156 }
157 DataType::Utf8View => dispatch_needle!(args[0].as_string_view(), Int32Type, args),
158 other => {
159 exec_err!("Unsupported data type {other:?} for strpos haystack")
160 }
161 }
162}
163
164fn find_substring_bytes(haystack: &[u8], needle: &[u8]) -> Option<usize> {
169 let needle_len = needle.len();
170 let haystack_len = haystack.len();
171
172 if needle_len == 0 {
173 return Some(0);
174 }
175 if needle_len > haystack_len {
176 return None;
177 }
178
179 let first_byte = needle[0];
180 let mut offset = 0;
181
182 while let Some(pos) = memchr(first_byte, &haystack[offset..]) {
183 let start = offset + pos;
184 if start + needle_len > haystack.len() {
185 return None;
186 }
187 if haystack[start..start + needle_len] == *needle {
188 return Some(start);
189 }
190 offset = start + 1;
191 }
192
193 None
194}
195
196#[inline]
200fn byte_offset_to_char_pos<T: ArrowPrimitiveType>(
201 haystack: &str,
202 byte_offset: usize,
203 ascii_only: bool,
204) -> Option<T::Native> {
205 if ascii_only {
206 return T::Native::from_usize(byte_offset + 1);
207 }
208 debug_assert!(haystack.is_char_boundary(byte_offset));
213 let prefix =
214 unsafe { std::str::from_utf8_unchecked(&haystack.as_bytes()[..byte_offset]) };
215 T::Native::from_usize(prefix.chars().count() + 1)
216}
217
218fn strpos_general<'a, V1, V2, T: ArrowPrimitiveType>(
222 haystack_array: V1,
223 needle_array: V2,
224) -> Result<ArrayRef>
225where
226 V1: StringArrayType<'a, Item = &'a str> + Copy,
227 V2: StringArrayType<'a, Item = &'a str> + Copy,
228{
229 let ascii_only = needle_array.is_ascii() && haystack_array.is_ascii();
230 let haystack_iter = haystack_array.iter();
231 let needle_iter = needle_array.iter();
232
233 let result = haystack_iter
234 .zip(needle_iter)
235 .map(|(haystack, needle)| match (haystack, needle) {
236 (Some(haystack), Some(needle)) => {
237 let haystack_bytes = haystack.as_bytes();
238 let needle_bytes = needle.as_bytes();
239
240 match find_substring_bytes(haystack_bytes, needle_bytes) {
241 None => T::Native::from_usize(0),
242 Some(byte_offset) => {
243 byte_offset_to_char_pos::<T>(haystack, byte_offset, ascii_only)
244 }
245 }
246 }
247 _ => None,
248 })
249 .collect::<PrimitiveArray<T>>();
250
251 Ok(Arc::new(result) as ArrayRef)
252}
253
254fn strpos_scalar_needle(
258 haystack_array: &ArrayRef,
259 needle_scalar: &ScalarValue,
260) -> Result<ColumnarValue> {
261 let Some(needle_str) = needle_scalar.try_as_str() else {
262 return exec_err!(
263 "Unsupported data type {:?} for strpos needle",
264 needle_scalar.data_type()
265 );
266 };
267
268 let Some(needle_str) = needle_str else {
270 return match haystack_array.data_type() {
271 DataType::LargeUtf8 => {
272 Ok(ColumnarValue::Array(Arc::new(
273 PrimitiveArray::<Int64Type>::new_null(haystack_array.len()),
274 )))
275 }
276 DataType::Utf8 | DataType::Utf8View => Ok(ColumnarValue::Array(Arc::new(
277 PrimitiveArray::<Int32Type>::new_null(haystack_array.len()),
278 ))),
279 other => exec_err!("Unsupported data type {other:?} for strpos haystack"),
280 };
281 };
282
283 let result = match haystack_array.data_type() {
284 DataType::Utf8 => strpos_with_finder::<_, Int32Type>(
285 haystack_array.as_string::<i32>(),
286 needle_str,
287 ),
288 DataType::LargeUtf8 => strpos_with_finder::<_, Int64Type>(
289 haystack_array.as_string::<i64>(),
290 needle_str,
291 ),
292 DataType::Utf8View => strpos_with_finder::<_, Int32Type>(
293 haystack_array.as_string_view(),
294 needle_str,
295 ),
296 other => {
297 exec_err!("Unsupported data type {other:?} for strpos haystack")
298 }
299 }?;
300 Ok(ColumnarValue::Array(result))
301}
302
303fn strpos_with_finder<'a, V, T: ArrowPrimitiveType>(
304 haystack_array: V,
305 needle: &str,
306) -> Result<ArrayRef>
307where
308 V: StringArrayType<'a, Item = &'a str> + Copy,
309{
310 let needle_bytes = needle.as_bytes();
311 let ascii_haystack = haystack_array.is_ascii();
312 let finder = memmem::Finder::new(needle_bytes);
313
314 let result = haystack_array
315 .iter()
316 .map(|string| match string {
317 Some(string) => {
318 let haystack_bytes = string.as_bytes();
319 match finder.find(haystack_bytes) {
320 None => T::Native::from_usize(0),
321 Some(byte_offset) => {
322 byte_offset_to_char_pos::<T>(string, byte_offset, ascii_haystack)
323 }
324 }
325 }
326 None => None,
327 })
328 .collect::<PrimitiveArray<T>>();
329
330 Ok(Arc::new(result) as ArrayRef)
331}
332
#[cfg(test)]
mod tests {
    use arrow::array::{Array, Int32Array, Int64Array};
    use arrow::datatypes::DataType::{Int32, Int64};

    use arrow::datatypes::{DataType, Field};
    use datafusion_common::{Result, ScalarValue};
    use datafusion_expr::{ColumnarValue, ScalarUDFImpl};

    use crate::unicode::strpos::StrposFunc;
    use crate::utils::test::test_function;

    /// Runs `strpos(haystack, needle)` on two scalar arguments and checks the
    /// result. The trailing identifiers are, in order: the `ScalarValue`
    /// variant of the haystack, the `ScalarValue` variant of the needle, the
    /// native result type, the expected `DataType` and the result array type.
    macro_rules! test_strpos {
        ($haystack:literal, $needle:literal -> $expected:literal; $haystack_variant:ident $needle_variant:ident $native:ident $data_type:ident $array:ident) => {
            test_function!(
                StrposFunc::new(),
                vec![
                    ColumnarValue::Scalar(ScalarValue::$haystack_variant(Some(
                        $haystack.to_owned()
                    ))),
                    ColumnarValue::Scalar(ScalarValue::$needle_variant(Some(
                        $needle.to_owned()
                    ))),
                ],
                Ok(Some($expected)),
                $native,
                $data_type,
                $array
            )
        };
    }

    /// Runs the same seven haystack/needle cases for one combination of
    /// argument and result types (same identifier order as `test_strpos!`).
    macro_rules! test_strpos_cases {
        ($haystack_variant:ident $needle_variant:ident $native:ident $data_type:ident $array:ident) => {
            test_strpos!("alphabet", "ph" -> 3; $haystack_variant $needle_variant $native $data_type $array);
            test_strpos!("alphabet", "a" -> 1; $haystack_variant $needle_variant $native $data_type $array);
            test_strpos!("alphabet", "z" -> 0; $haystack_variant $needle_variant $native $data_type $array);
            test_strpos!("alphabet", "" -> 1; $haystack_variant $needle_variant $native $data_type $array);
            test_strpos!("", "a" -> 0; $haystack_variant $needle_variant $native $data_type $array);
            test_strpos!("", "" -> 1; $haystack_variant $needle_variant $native $data_type $array);
            test_strpos!("ДатаФусион数据融合📊🔥", "📊" -> 15; $haystack_variant $needle_variant $native $data_type $array);
        };
    }

    #[test]
    fn test_strpos_functions() {
        // The result type follows the haystack: Int64 for LargeUtf8,
        // Int32 for Utf8 and Utf8View.
        test_strpos_cases!(Utf8 Utf8 i32 Int32 Int32Array);
        test_strpos_cases!(LargeUtf8 LargeUtf8 i64 Int64 Int64Array);
        test_strpos_cases!(Utf8 LargeUtf8 i32 Int32 Int32Array);
        test_strpos_cases!(LargeUtf8 Utf8 i64 Int64 Int64Array);
        test_strpos_cases!(Utf8View Utf8View i32 Int32 Int32Array);
        test_strpos_cases!(Utf8View Utf8 i32 Int32 Int32Array);
        test_strpos_cases!(Utf8View LargeUtf8 i32 Int32 Int32Array);
    }

    #[test]
    fn nullable_return_type() {
        /// Nullability of the field `strpos` reports for two `Utf8` arguments
        /// with the given nullability flags.
        fn get_nullable(string_array_nullable: bool, substring_nullable: bool) -> bool {
            let args = datafusion_expr::ReturnFieldArgs {
                arg_fields: &[
                    Field::new("f1", DataType::Utf8, string_array_nullable).into(),
                    Field::new("f2", DataType::Utf8, substring_nullable).into(),
                ],
                scalar_arguments: &[None::<&ScalarValue>, None::<&ScalarValue>],
            };

            StrposFunc::new()
                .return_field_from_args(args)
                .unwrap()
                .is_nullable()
        }

        // (haystack nullable, needle nullable, expected result nullability):
        // the result is non-nullable only when both inputs are non-nullable.
        let cases = [
            (false, false, false),
            (true, false, true),
            (false, true, true),
            (true, true, true),
        ];
        for (string_array_nullable, substring_nullable, expected) in cases {
            assert_eq!(
                get_nullable(string_array_nullable, substring_nullable),
                expected
            );
        }
    }
}