1use std::any::Any;
19use std::sync::Arc;
20
21use crate::utils::{make_scalar_function, utf8_to_int_type};
22use arrow::array::{
23 ArrayRef, ArrowPrimitiveType, AsArray, PrimitiveArray, StringArrayType,
24};
25use arrow::datatypes::{
26 ArrowNativeType, DataType, Field, FieldRef, Int32Type, Int64Type,
27};
28use datafusion_common::types::logical_string;
29use datafusion_common::{Result, exec_err, internal_err};
30use datafusion_expr::{
31 Coercion, ColumnarValue, Documentation, ScalarUDFImpl, Signature, TypeSignatureClass,
32 Volatility,
33};
34use datafusion_macros::user_doc;
35use memchr::memchr;
36
37#[user_doc(
38 doc_section(label = "String Functions"),
39 description = "Returns the starting position of a specified substring in a string. Positions begin at 1. If the substring does not exist in the string, the function returns 0.",
40 syntax_example = "strpos(str, substr)",
41 alternative_syntax = "position(substr in origstr)",
42 sql_example = r#"```sql
43> select strpos('datafusion', 'fus');
44+----------------------------------------+
45| strpos(Utf8("datafusion"),Utf8("fus")) |
46+----------------------------------------+
47| 5 |
48+----------------------------------------+
49```"#,
50 standard_argument(name = "str", prefix = "String"),
51 argument(name = "substr", description = "Substring expression to search for.")
52)]
/// Scalar UDF that reports the 1-based position of a substring within a
/// string, registered under the name `strpos` with the aliases `instr` and
/// `position`.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct StrposFunc {
    /// Argument signature: two string arguments, immutable volatility.
    signature: Signature,
    /// Alternative names this function can be invoked by.
    aliases: Vec<String>,
}
58
59impl Default for StrposFunc {
60 fn default() -> Self {
61 Self::new()
62 }
63}
64
65impl StrposFunc {
66 pub fn new() -> Self {
67 Self {
68 signature: Signature::coercible(
69 vec![
70 Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
71 Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
72 ],
73 Volatility::Immutable,
74 ),
75 aliases: vec![String::from("instr"), String::from("position")],
76 }
77 }
78}
79
80impl ScalarUDFImpl for StrposFunc {
81 fn as_any(&self) -> &dyn Any {
82 self
83 }
84
85 fn name(&self) -> &str {
86 "strpos"
87 }
88
89 fn signature(&self) -> &Signature {
90 &self.signature
91 }
92
93 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
94 internal_err!("return_field_from_args should be used instead")
95 }
96
97 fn return_field_from_args(
98 &self,
99 args: datafusion_expr::ReturnFieldArgs,
100 ) -> Result<FieldRef> {
101 utf8_to_int_type(args.arg_fields[0].data_type(), "strpos/instr/position").map(
102 |data_type| {
103 Field::new(
104 self.name(),
105 data_type,
106 args.arg_fields.iter().any(|x| x.is_nullable()),
107 )
108 .into()
109 },
110 )
111 }
112
113 fn invoke_with_args(
114 &self,
115 args: datafusion_expr::ScalarFunctionArgs,
116 ) -> Result<ColumnarValue> {
117 make_scalar_function(strpos, vec![])(&args.args)
118 }
119
120 fn aliases(&self) -> &[String] {
121 &self.aliases
122 }
123
124 fn documentation(&self) -> Option<&Documentation> {
125 self.doc()
126 }
127}
128
129fn strpos(args: &[ArrayRef]) -> Result<ArrayRef> {
130 match (args[0].data_type(), args[1].data_type()) {
131 (DataType::Utf8, DataType::Utf8) => {
132 let string_array = args[0].as_string::<i32>();
133 let substring_array = args[1].as_string::<i32>();
134 calculate_strpos::<_, _, Int32Type>(&string_array, &substring_array)
135 }
136 (DataType::Utf8, DataType::Utf8View) => {
137 let string_array = args[0].as_string::<i32>();
138 let substring_array = args[1].as_string_view();
139 calculate_strpos::<_, _, Int32Type>(&string_array, &substring_array)
140 }
141 (DataType::Utf8, DataType::LargeUtf8) => {
142 let string_array = args[0].as_string::<i32>();
143 let substring_array = args[1].as_string::<i64>();
144 calculate_strpos::<_, _, Int32Type>(&string_array, &substring_array)
145 }
146 (DataType::LargeUtf8, DataType::Utf8) => {
147 let string_array = args[0].as_string::<i64>();
148 let substring_array = args[1].as_string::<i32>();
149 calculate_strpos::<_, _, Int64Type>(&string_array, &substring_array)
150 }
151 (DataType::LargeUtf8, DataType::Utf8View) => {
152 let string_array = args[0].as_string::<i64>();
153 let substring_array = args[1].as_string_view();
154 calculate_strpos::<_, _, Int64Type>(&string_array, &substring_array)
155 }
156 (DataType::LargeUtf8, DataType::LargeUtf8) => {
157 let string_array = args[0].as_string::<i64>();
158 let substring_array = args[1].as_string::<i64>();
159 calculate_strpos::<_, _, Int64Type>(&string_array, &substring_array)
160 }
161 (DataType::Utf8View, DataType::Utf8View) => {
162 let string_array = args[0].as_string_view();
163 let substring_array = args[1].as_string_view();
164 calculate_strpos::<_, _, Int32Type>(&string_array, &substring_array)
165 }
166 (DataType::Utf8View, DataType::Utf8) => {
167 let string_array = args[0].as_string_view();
168 let substring_array = args[1].as_string::<i32>();
169 calculate_strpos::<_, _, Int32Type>(&string_array, &substring_array)
170 }
171 (DataType::Utf8View, DataType::LargeUtf8) => {
172 let string_array = args[0].as_string_view();
173 let substring_array = args[1].as_string::<i64>();
174 calculate_strpos::<_, _, Int32Type>(&string_array, &substring_array)
175 }
176
177 other => {
178 exec_err!("Unsupported data type combination {other:?} for function strpos")
179 }
180 }
181}
182
183fn find_ascii_substring(haystack: &[u8], needle: &[u8]) -> usize {
190 let needle_len = needle.len();
191 let first_byte = needle[0];
192 let mut offset = 0;
193
194 while let Some(pos) = memchr(first_byte, &haystack[offset..]) {
195 let start = offset + pos;
196 if start + needle_len > haystack.len() {
197 return 0;
198 }
199 if haystack[start..start + needle_len] == *needle {
200 return start + 1;
201 }
202 offset = start + 1;
203 }
204
205 0
206}
207
208fn calculate_strpos<'a, V1, V2, T: ArrowPrimitiveType>(
212 string_array: &V1,
213 substring_array: &V2,
214) -> Result<ArrayRef>
215where
216 V1: StringArrayType<'a, Item = &'a str>,
217 V2: StringArrayType<'a, Item = &'a str>,
218{
219 let ascii_only = substring_array.is_ascii() && string_array.is_ascii();
220 let string_iter = string_array.iter();
221 let substring_iter = substring_array.iter();
222
223 let result = string_iter
224 .zip(substring_iter)
225 .map(|(string, substring)| match (string, substring) {
226 (Some(string), Some(substring)) => {
227 if substring.is_empty() {
228 return T::Native::from_usize(1);
229 }
230
231 let substring_bytes = substring.as_bytes();
232 let string_bytes = string.as_bytes();
233
234 if substring_bytes.len() > string_bytes.len() {
235 return T::Native::from_usize(0);
236 }
237
238 if ascii_only {
239 T::Native::from_usize(find_ascii_substring(
240 string_bytes,
241 substring_bytes,
242 ))
243 } else {
244 let mut char_pos = 0;
247 for (byte_idx, _) in string.char_indices() {
248 char_pos += 1;
249 if byte_idx + substring_bytes.len() <= string_bytes.len() {
250 let slice = unsafe {
252 string_bytes.get_unchecked(
253 byte_idx..byte_idx + substring_bytes.len(),
254 )
255 };
256 if slice == substring_bytes {
257 return T::Native::from_usize(char_pos);
258 }
259 }
260 }
261
262 T::Native::from_usize(0)
263 }
264 }
265 _ => None,
266 })
267 .collect::<PrimitiveArray<T>>();
268
269 Ok(Arc::new(result) as ArrayRef)
270}
271
272#[cfg(test)]
273mod tests {
274 use arrow::array::{Array, Int32Array, Int64Array};
275 use arrow::datatypes::DataType::{Int32, Int64};
276
277 use arrow::datatypes::{DataType, Field};
278 use datafusion_common::{Result, ScalarValue};
279 use datafusion_expr::{ColumnarValue, ScalarUDFImpl};
280
281 use crate::unicode::strpos::StrposFunc;
282 use crate::utils::test::test_function;
283
284 macro_rules! test_strpos {
285 ($lhs:literal, $rhs:literal -> $result:literal; $t1:ident $t2:ident $t3:ident $t4:ident $t5:ident) => {
286 test_function!(
287 StrposFunc::new(),
288 vec![
289 ColumnarValue::Scalar(ScalarValue::$t1(Some($lhs.to_owned()))),
290 ColumnarValue::Scalar(ScalarValue::$t2(Some($rhs.to_owned()))),
291 ],
292 Ok(Some($result)),
293 $t3,
294 $t4,
295 $t5
296 )
297 };
298 }
299
300 #[test]
301 fn test_strpos_functions() {
302 test_strpos!("alphabet", "ph" -> 3; Utf8 Utf8 i32 Int32 Int32Array);
304 test_strpos!("alphabet", "a" -> 1; Utf8 Utf8 i32 Int32 Int32Array);
305 test_strpos!("alphabet", "z" -> 0; Utf8 Utf8 i32 Int32 Int32Array);
306 test_strpos!("alphabet", "" -> 1; Utf8 Utf8 i32 Int32 Int32Array);
307 test_strpos!("", "a" -> 0; Utf8 Utf8 i32 Int32 Int32Array);
308 test_strpos!("", "" -> 1; Utf8 Utf8 i32 Int32 Int32Array);
309 test_strpos!("ДатаФусион数据融合📊🔥", "📊" -> 15; Utf8 Utf8 i32 Int32 Int32Array);
310
311 test_strpos!("alphabet", "ph" -> 3; LargeUtf8 LargeUtf8 i64 Int64 Int64Array);
313 test_strpos!("alphabet", "a" -> 1; LargeUtf8 LargeUtf8 i64 Int64 Int64Array);
314 test_strpos!("alphabet", "z" -> 0; LargeUtf8 LargeUtf8 i64 Int64 Int64Array);
315 test_strpos!("alphabet", "" -> 1; LargeUtf8 LargeUtf8 i64 Int64 Int64Array);
316 test_strpos!("", "a" -> 0; LargeUtf8 LargeUtf8 i64 Int64 Int64Array);
317 test_strpos!("", "" -> 1; LargeUtf8 LargeUtf8 i64 Int64 Int64Array);
318 test_strpos!("ДатаФусион数据融合📊🔥", "📊" -> 15; LargeUtf8 LargeUtf8 i64 Int64 Int64Array);
319
320 test_strpos!("alphabet", "ph" -> 3; Utf8 LargeUtf8 i32 Int32 Int32Array);
322 test_strpos!("alphabet", "a" -> 1; Utf8 LargeUtf8 i32 Int32 Int32Array);
323 test_strpos!("alphabet", "z" -> 0; Utf8 LargeUtf8 i32 Int32 Int32Array);
324 test_strpos!("alphabet", "" -> 1; Utf8 LargeUtf8 i32 Int32 Int32Array);
325 test_strpos!("", "a" -> 0; Utf8 LargeUtf8 i32 Int32 Int32Array);
326 test_strpos!("", "" -> 1; Utf8 LargeUtf8 i32 Int32 Int32Array);
327 test_strpos!("ДатаФусион数据融合📊🔥", "📊" -> 15; Utf8 LargeUtf8 i32 Int32 Int32Array);
328
329 test_strpos!("alphabet", "ph" -> 3; LargeUtf8 Utf8 i64 Int64 Int64Array);
331 test_strpos!("alphabet", "a" -> 1; LargeUtf8 Utf8 i64 Int64 Int64Array);
332 test_strpos!("alphabet", "z" -> 0; LargeUtf8 Utf8 i64 Int64 Int64Array);
333 test_strpos!("alphabet", "" -> 1; LargeUtf8 Utf8 i64 Int64 Int64Array);
334 test_strpos!("", "a" -> 0; LargeUtf8 Utf8 i64 Int64 Int64Array);
335 test_strpos!("", "" -> 1; LargeUtf8 Utf8 i64 Int64 Int64Array);
336 test_strpos!("ДатаФусион数据融合📊🔥", "📊" -> 15; LargeUtf8 Utf8 i64 Int64 Int64Array);
337
338 test_strpos!("alphabet", "ph" -> 3; Utf8View Utf8View i32 Int32 Int32Array);
340 test_strpos!("alphabet", "a" -> 1; Utf8View Utf8View i32 Int32 Int32Array);
341 test_strpos!("alphabet", "z" -> 0; Utf8View Utf8View i32 Int32 Int32Array);
342 test_strpos!("alphabet", "" -> 1; Utf8View Utf8View i32 Int32 Int32Array);
343 test_strpos!("", "a" -> 0; Utf8View Utf8View i32 Int32 Int32Array);
344 test_strpos!("", "" -> 1; Utf8View Utf8View i32 Int32 Int32Array);
345 test_strpos!("ДатаФусион数据融合📊🔥", "📊" -> 15; Utf8View Utf8View i32 Int32 Int32Array);
346
347 test_strpos!("alphabet", "ph" -> 3; Utf8View Utf8 i32 Int32 Int32Array);
349 test_strpos!("alphabet", "a" -> 1; Utf8View Utf8 i32 Int32 Int32Array);
350 test_strpos!("alphabet", "z" -> 0; Utf8View Utf8 i32 Int32 Int32Array);
351 test_strpos!("alphabet", "" -> 1; Utf8View Utf8 i32 Int32 Int32Array);
352 test_strpos!("", "a" -> 0; Utf8View Utf8 i32 Int32 Int32Array);
353 test_strpos!("", "" -> 1; Utf8View Utf8 i32 Int32 Int32Array);
354 test_strpos!("ДатаФусион数据融合📊🔥", "📊" -> 15; Utf8View Utf8 i32 Int32 Int32Array);
355
356 test_strpos!("alphabet", "ph" -> 3; Utf8View LargeUtf8 i32 Int32 Int32Array);
358 test_strpos!("alphabet", "a" -> 1; Utf8View LargeUtf8 i32 Int32 Int32Array);
359 test_strpos!("alphabet", "z" -> 0; Utf8View LargeUtf8 i32 Int32 Int32Array);
360 test_strpos!("alphabet", "" -> 1; Utf8View LargeUtf8 i32 Int32 Int32Array);
361 test_strpos!("", "a" -> 0; Utf8View LargeUtf8 i32 Int32 Int32Array);
362 test_strpos!("", "" -> 1; Utf8View LargeUtf8 i32 Int32 Int32Array);
363 test_strpos!("ДатаФусион数据融合📊🔥", "📊" -> 15; Utf8View LargeUtf8 i32 Int32 Int32Array);
364 }
365
366 #[test]
367 fn nullable_return_type() {
368 fn get_nullable(string_array_nullable: bool, substring_nullable: bool) -> bool {
369 let strpos = StrposFunc::new();
370 let args = datafusion_expr::ReturnFieldArgs {
371 arg_fields: &[
372 Field::new("f1", DataType::Utf8, string_array_nullable).into(),
373 Field::new("f2", DataType::Utf8, substring_nullable).into(),
374 ],
375 scalar_arguments: &[None::<&ScalarValue>, None::<&ScalarValue>],
376 };
377
378 strpos.return_field_from_args(args).unwrap().is_nullable()
379 }
380
381 assert!(!get_nullable(false, false));
382
383 assert!(get_nullable(true, false));
385 assert!(get_nullable(false, true));
386 assert!(get_nullable(true, true));
387 }
388}