1use std::any::Any;
19use std::sync::Arc;
20
21use crate::utils::{make_scalar_function, utf8_to_int_type};
22use arrow::array::{
23 ArrayRef, ArrowPrimitiveType, AsArray, PrimitiveArray, StringArrayType,
24};
25use arrow::datatypes::{
26 ArrowNativeType, DataType, Field, FieldRef, Int32Type, Int64Type,
27};
28use datafusion_common::types::logical_string;
29use datafusion_common::{Result, exec_err, internal_err};
30use datafusion_expr::{
31 Coercion, ColumnarValue, Documentation, ScalarUDFImpl, Signature, TypeSignatureClass,
32 Volatility,
33};
34use datafusion_macros::user_doc;
35
36#[user_doc(
37 doc_section(label = "String Functions"),
38 description = "Returns the starting position of a specified substring in a string. Positions begin at 1. If the substring does not exist in the string, the function returns 0.",
39 syntax_example = "strpos(str, substr)",
40 alternative_syntax = "position(substr in origstr)",
41 sql_example = r#"```sql
42> select strpos('datafusion', 'fus');
43+----------------------------------------+
44| strpos(Utf8("datafusion"),Utf8("fus")) |
45+----------------------------------------+
46| 5 |
47+----------------------------------------+
48```"#,
49 standard_argument(name = "str", prefix = "String"),
50 argument(name = "substr", description = "Substring expression to search for.")
51)]
52#[derive(Debug, PartialEq, Eq, Hash)]
53pub struct StrposFunc {
54 signature: Signature,
55 aliases: Vec<String>,
56}
57
58impl Default for StrposFunc {
59 fn default() -> Self {
60 Self::new()
61 }
62}
63
64impl StrposFunc {
65 pub fn new() -> Self {
66 Self {
67 signature: Signature::coercible(
68 vec![
69 Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
70 Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
71 ],
72 Volatility::Immutable,
73 ),
74 aliases: vec![String::from("instr"), String::from("position")],
75 }
76 }
77}
78
79impl ScalarUDFImpl for StrposFunc {
80 fn as_any(&self) -> &dyn Any {
81 self
82 }
83
84 fn name(&self) -> &str {
85 "strpos"
86 }
87
88 fn signature(&self) -> &Signature {
89 &self.signature
90 }
91
92 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
93 internal_err!("return_field_from_args should be used instead")
94 }
95
96 fn return_field_from_args(
97 &self,
98 args: datafusion_expr::ReturnFieldArgs,
99 ) -> Result<FieldRef> {
100 utf8_to_int_type(args.arg_fields[0].data_type(), "strpos/instr/position").map(
101 |data_type| {
102 Field::new(
103 self.name(),
104 data_type,
105 args.arg_fields.iter().any(|x| x.is_nullable()),
106 )
107 .into()
108 },
109 )
110 }
111
112 fn invoke_with_args(
113 &self,
114 args: datafusion_expr::ScalarFunctionArgs,
115 ) -> Result<ColumnarValue> {
116 make_scalar_function(strpos, vec![])(&args.args)
117 }
118
119 fn aliases(&self) -> &[String] {
120 &self.aliases
121 }
122
123 fn documentation(&self) -> Option<&Documentation> {
124 self.doc()
125 }
126}
127
128fn strpos(args: &[ArrayRef]) -> Result<ArrayRef> {
129 match (args[0].data_type(), args[1].data_type()) {
130 (DataType::Utf8, DataType::Utf8) => {
131 let string_array = args[0].as_string::<i32>();
132 let substring_array = args[1].as_string::<i32>();
133 calculate_strpos::<_, _, Int32Type>(&string_array, &substring_array)
134 }
135 (DataType::Utf8, DataType::Utf8View) => {
136 let string_array = args[0].as_string::<i32>();
137 let substring_array = args[1].as_string_view();
138 calculate_strpos::<_, _, Int32Type>(&string_array, &substring_array)
139 }
140 (DataType::Utf8, DataType::LargeUtf8) => {
141 let string_array = args[0].as_string::<i32>();
142 let substring_array = args[1].as_string::<i64>();
143 calculate_strpos::<_, _, Int32Type>(&string_array, &substring_array)
144 }
145 (DataType::LargeUtf8, DataType::Utf8) => {
146 let string_array = args[0].as_string::<i64>();
147 let substring_array = args[1].as_string::<i32>();
148 calculate_strpos::<_, _, Int64Type>(&string_array, &substring_array)
149 }
150 (DataType::LargeUtf8, DataType::Utf8View) => {
151 let string_array = args[0].as_string::<i64>();
152 let substring_array = args[1].as_string_view();
153 calculate_strpos::<_, _, Int64Type>(&string_array, &substring_array)
154 }
155 (DataType::LargeUtf8, DataType::LargeUtf8) => {
156 let string_array = args[0].as_string::<i64>();
157 let substring_array = args[1].as_string::<i64>();
158 calculate_strpos::<_, _, Int64Type>(&string_array, &substring_array)
159 }
160 (DataType::Utf8View, DataType::Utf8View) => {
161 let string_array = args[0].as_string_view();
162 let substring_array = args[1].as_string_view();
163 calculate_strpos::<_, _, Int32Type>(&string_array, &substring_array)
164 }
165 (DataType::Utf8View, DataType::Utf8) => {
166 let string_array = args[0].as_string_view();
167 let substring_array = args[1].as_string::<i32>();
168 calculate_strpos::<_, _, Int32Type>(&string_array, &substring_array)
169 }
170 (DataType::Utf8View, DataType::LargeUtf8) => {
171 let string_array = args[0].as_string_view();
172 let substring_array = args[1].as_string::<i64>();
173 calculate_strpos::<_, _, Int32Type>(&string_array, &substring_array)
174 }
175
176 other => {
177 exec_err!("Unsupported data type combination {other:?} for function strpos")
178 }
179 }
180}
181
182fn calculate_strpos<'a, V1, V2, T: ArrowPrimitiveType>(
186 string_array: &V1,
187 substring_array: &V2,
188) -> Result<ArrayRef>
189where
190 V1: StringArrayType<'a, Item = &'a str>,
191 V2: StringArrayType<'a, Item = &'a str>,
192{
193 let ascii_only = substring_array.is_ascii() && string_array.is_ascii();
194 let string_iter = string_array.iter();
195 let substring_iter = substring_array.iter();
196
197 let result = string_iter
198 .zip(substring_iter)
199 .map(|(string, substring)| match (string, substring) {
200 (Some(string), Some(substring)) => {
201 if ascii_only {
204 if substring.is_empty() {
206 T::Native::from_usize(1)
207 } else {
208 T::Native::from_usize(
209 string
210 .as_bytes()
211 .windows(substring.len())
212 .position(|w| w == substring.as_bytes())
213 .map(|x| x + 1)
214 .unwrap_or(0),
215 )
216 }
217 } else {
218 if substring.is_empty() {
221 return T::Native::from_usize(1);
222 }
223
224 let substring_bytes = substring.as_bytes();
225 let string_bytes = string.as_bytes();
226
227 if substring_bytes.len() > string_bytes.len() {
228 return T::Native::from_usize(0);
229 }
230
231 let mut char_pos = 0;
233 for (byte_idx, _) in string.char_indices() {
234 char_pos += 1;
235 if byte_idx + substring_bytes.len() <= string_bytes.len() {
236 let slice = unsafe {
238 string_bytes.get_unchecked(
239 byte_idx..byte_idx + substring_bytes.len(),
240 )
241 };
242 if slice == substring_bytes {
243 return T::Native::from_usize(char_pos);
244 }
245 }
246 }
247
248 T::Native::from_usize(0)
249 }
250 }
251 _ => None,
252 })
253 .collect::<PrimitiveArray<T>>();
254
255 Ok(Arc::new(result) as ArrayRef)
256}
257
#[cfg(test)]
mod tests {
    use arrow::array::{Array, Int32Array, Int64Array};
    use arrow::datatypes::DataType::{Int32, Int64};

    use arrow::datatypes::{DataType, Field};
    use datafusion_common::{Result, ScalarValue};
    use datafusion_expr::{ColumnarValue, ScalarUDFImpl};

    use crate::unicode::strpos::StrposFunc;
    use crate::utils::test::test_function;

    /// Runs `strpos(haystack, needle)` on two scalar arguments and checks the
    /// result.
    ///
    /// The trailing identifiers are, in order: the `ScalarValue` variant of
    /// the haystack, the `ScalarValue` variant of the needle, the native
    /// result type, the expected result `DataType`, and the result array type.
    macro_rules! test_strpos {
        ($haystack:literal, $needle:literal -> $expected:literal; $haystack_variant:ident $needle_variant:ident $native:ident $data_type:ident $array:ident) => {
            test_function!(
                StrposFunc::new(),
                vec![
                    ColumnarValue::Scalar(ScalarValue::$haystack_variant(Some(
                        $haystack.to_owned()
                    ))),
                    ColumnarValue::Scalar(ScalarValue::$needle_variant(Some(
                        $needle.to_owned()
                    ))),
                ],
                Ok(Some($expected)),
                $native,
                $data_type,
                $array
            )
        };
    }

    /// Runs the shared set of `strpos` cases for one combination of argument
    /// types. The identifiers have the same meaning as in `test_strpos!`.
    macro_rules! test_strpos_cases {
        ($haystack_variant:ident $needle_variant:ident $native:ident $data_type:ident $array:ident) => {
            test_strpos!("alphabet", "ph" -> 3; $haystack_variant $needle_variant $native $data_type $array);
            test_strpos!("alphabet", "a" -> 1; $haystack_variant $needle_variant $native $data_type $array);
            test_strpos!("alphabet", "z" -> 0; $haystack_variant $needle_variant $native $data_type $array);
            test_strpos!("alphabet", "" -> 1; $haystack_variant $needle_variant $native $data_type $array);
            test_strpos!("", "a" -> 0; $haystack_variant $needle_variant $native $data_type $array);
            test_strpos!("", "" -> 1; $haystack_variant $needle_variant $native $data_type $array);
            test_strpos!("ДатаФусион数据融合📊🔥", "📊" -> 15; $haystack_variant $needle_variant $native $data_type $array);
        };
    }

    #[test]
    fn test_strpos_functions() {
        // Utf8 haystack, Utf8 needle
        test_strpos_cases!(Utf8 Utf8 i32 Int32 Int32Array);

        // LargeUtf8 haystack, LargeUtf8 needle
        test_strpos_cases!(LargeUtf8 LargeUtf8 i64 Int64 Int64Array);

        // Utf8 haystack, LargeUtf8 needle
        test_strpos_cases!(Utf8 LargeUtf8 i32 Int32 Int32Array);

        // LargeUtf8 haystack, Utf8 needle
        test_strpos_cases!(LargeUtf8 Utf8 i64 Int64 Int64Array);

        // Utf8View haystack, Utf8View needle
        test_strpos_cases!(Utf8View Utf8View i32 Int32 Int32Array);

        // Utf8View haystack, Utf8 needle
        test_strpos_cases!(Utf8View Utf8 i32 Int32 Int32Array);

        // Utf8View haystack, LargeUtf8 needle
        test_strpos_cases!(Utf8View LargeUtf8 i32 Int32 Int32Array);
    }

    #[test]
    fn nullable_return_type() {
        /// Nullability of the return field for the given argument
        /// nullabilities.
        fn get_nullable(string_array_nullable: bool, substring_nullable: bool) -> bool {
            let udf = StrposFunc::new();
            let args = datafusion_expr::ReturnFieldArgs {
                arg_fields: &[
                    Field::new("f1", DataType::Utf8, string_array_nullable).into(),
                    Field::new("f2", DataType::Utf8, substring_nullable).into(),
                ],
                scalar_arguments: &[None, None],
            };

            udf.return_field_from_args(args).unwrap().is_nullable()
        }

        // (string nullable, substring nullable, expected result nullable)
        let cases = [
            // Neither input nullable: result is not nullable.
            (false, false, false),
            // Any nullable input makes the result nullable.
            (true, false, true),
            (false, true, true),
            (true, true, true),
        ];

        for (string_nullable, substring_nullable, expected) in cases {
            assert_eq!(get_nullable(string_nullable, substring_nullable), expected);
        }
    }
}