1use std::any::Any;
19use std::sync::Arc;
20
21use crate::utils::{make_scalar_function, utf8_to_int_type};
22use arrow::array::{
23 ArrayRef, ArrowPrimitiveType, AsArray, PrimitiveArray, StringArrayType,
24};
25use arrow::datatypes::{
26 ArrowNativeType, DataType, Field, FieldRef, Int32Type, Int64Type,
27};
28use datafusion_common::types::logical_string;
29use datafusion_common::{exec_err, internal_err, Result};
30use datafusion_expr::{
31 Coercion, ColumnarValue, Documentation, ScalarUDFImpl, Signature, TypeSignatureClass,
32 Volatility,
33};
34use datafusion_macros::user_doc;
35
36#[user_doc(
37 doc_section(label = "String Functions"),
38 description = "Returns the starting position of a specified substring in a string. Positions begin at 1. If the substring does not exist in the string, the function returns 0.",
39 syntax_example = "strpos(str, substr)",
40 alternative_syntax = "position(substr in origstr)",
41 sql_example = r#"```sql
42> select strpos('datafusion', 'fus');
43+----------------------------------------+
44| strpos(Utf8("datafusion"),Utf8("fus")) |
45+----------------------------------------+
46| 5 |
47+----------------------------------------+
48```"#,
49 standard_argument(name = "str", prefix = "String"),
50 argument(name = "substr", description = "Substring expression to search for.")
51)]
52#[derive(Debug)]
53pub struct StrposFunc {
54 signature: Signature,
55 aliases: Vec<String>,
56}
57
58impl Default for StrposFunc {
59 fn default() -> Self {
60 Self::new()
61 }
62}
63
64impl StrposFunc {
65 pub fn new() -> Self {
66 Self {
67 signature: Signature::coercible(
68 vec![
69 Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
70 Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
71 ],
72 Volatility::Immutable,
73 ),
74 aliases: vec![String::from("instr"), String::from("position")],
75 }
76 }
77}
78
79impl ScalarUDFImpl for StrposFunc {
80 fn as_any(&self) -> &dyn Any {
81 self
82 }
83
84 fn name(&self) -> &str {
85 "strpos"
86 }
87
88 fn signature(&self) -> &Signature {
89 &self.signature
90 }
91
92 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
93 internal_err!("return_field_from_args should be used instead")
94 }
95
96 fn return_field_from_args(
97 &self,
98 args: datafusion_expr::ReturnFieldArgs,
99 ) -> Result<FieldRef> {
100 utf8_to_int_type(args.arg_fields[0].data_type(), "strpos/instr/position").map(
101 |data_type| {
102 Field::new(
103 self.name(),
104 data_type,
105 args.arg_fields.iter().any(|x| x.is_nullable()),
106 )
107 .into()
108 },
109 )
110 }
111
112 fn invoke_with_args(
113 &self,
114 args: datafusion_expr::ScalarFunctionArgs,
115 ) -> Result<ColumnarValue> {
116 make_scalar_function(strpos, vec![])(&args.args)
117 }
118
119 fn aliases(&self) -> &[String] {
120 &self.aliases
121 }
122
123 fn documentation(&self) -> Option<&Documentation> {
124 self.doc()
125 }
126}
127
128fn strpos(args: &[ArrayRef]) -> Result<ArrayRef> {
129 match (args[0].data_type(), args[1].data_type()) {
130 (DataType::Utf8, DataType::Utf8) => {
131 let string_array = args[0].as_string::<i32>();
132 let substring_array = args[1].as_string::<i32>();
133 calculate_strpos::<_, _, Int32Type>(string_array, substring_array)
134 }
135 (DataType::Utf8, DataType::Utf8View) => {
136 let string_array = args[0].as_string::<i32>();
137 let substring_array = args[1].as_string_view();
138 calculate_strpos::<_, _, Int32Type>(string_array, substring_array)
139 }
140 (DataType::Utf8, DataType::LargeUtf8) => {
141 let string_array = args[0].as_string::<i32>();
142 let substring_array = args[1].as_string::<i64>();
143 calculate_strpos::<_, _, Int32Type>(string_array, substring_array)
144 }
145 (DataType::LargeUtf8, DataType::Utf8) => {
146 let string_array = args[0].as_string::<i64>();
147 let substring_array = args[1].as_string::<i32>();
148 calculate_strpos::<_, _, Int64Type>(string_array, substring_array)
149 }
150 (DataType::LargeUtf8, DataType::Utf8View) => {
151 let string_array = args[0].as_string::<i64>();
152 let substring_array = args[1].as_string_view();
153 calculate_strpos::<_, _, Int64Type>(string_array, substring_array)
154 }
155 (DataType::LargeUtf8, DataType::LargeUtf8) => {
156 let string_array = args[0].as_string::<i64>();
157 let substring_array = args[1].as_string::<i64>();
158 calculate_strpos::<_, _, Int64Type>(string_array, substring_array)
159 }
160 (DataType::Utf8View, DataType::Utf8View) => {
161 let string_array = args[0].as_string_view();
162 let substring_array = args[1].as_string_view();
163 calculate_strpos::<_, _, Int32Type>(string_array, substring_array)
164 }
165 (DataType::Utf8View, DataType::Utf8) => {
166 let string_array = args[0].as_string_view();
167 let substring_array = args[1].as_string::<i32>();
168 calculate_strpos::<_, _, Int32Type>(string_array, substring_array)
169 }
170 (DataType::Utf8View, DataType::LargeUtf8) => {
171 let string_array = args[0].as_string_view();
172 let substring_array = args[1].as_string::<i64>();
173 calculate_strpos::<_, _, Int32Type>(string_array, substring_array)
174 }
175
176 other => {
177 exec_err!("Unsupported data type combination {other:?} for function strpos")
178 }
179 }
180}
181
182fn calculate_strpos<'a, V1, V2, T: ArrowPrimitiveType>(
186 string_array: V1,
187 substring_array: V2,
188) -> Result<ArrayRef>
189where
190 V1: StringArrayType<'a, Item = &'a str>,
191 V2: StringArrayType<'a, Item = &'a str>,
192{
193 let ascii_only = substring_array.is_ascii() && string_array.is_ascii();
194 let string_iter = string_array.iter();
195 let substring_iter = substring_array.iter();
196
197 let result = string_iter
198 .zip(substring_iter)
199 .map(|(string, substring)| match (string, substring) {
200 (Some(string), Some(substring)) => {
201 if ascii_only {
204 if substring.is_empty() {
206 T::Native::from_usize(1)
207 } else {
208 T::Native::from_usize(
209 string
210 .as_bytes()
211 .windows(substring.len())
212 .position(|w| w == substring.as_bytes())
213 .map(|x| x + 1)
214 .unwrap_or(0),
215 )
216 }
217 } else {
218 T::Native::from_usize(
221 string
222 .find(substring)
223 .map(|x| string[..x].chars().count() + 1)
224 .unwrap_or(0),
225 )
226 }
227 }
228 _ => None,
229 })
230 .collect::<PrimitiveArray<T>>();
231
232 Ok(Arc::new(result) as ArrayRef)
233}
234
#[cfg(test)]
mod tests {
    use arrow::array::{Array, Int32Array, Int64Array};
    use arrow::datatypes::DataType::{Int32, Int64};

    use arrow::datatypes::{DataType, Field};
    use datafusion_common::{Result, ScalarValue};
    use datafusion_expr::{ColumnarValue, ScalarUDFImpl};

    use crate::unicode::strpos::StrposFunc;
    use crate::utils::test::test_function;

    /// Runs `strpos` on one pair of scalar arguments and checks the result.
    ///
    /// `$t1` / `$t2` are the `ScalarValue` variants used for the string and
    /// the substring; `$t3`, `$t4` and `$t5` are the expected native type,
    /// data type and array type of the result.
    macro_rules! test_strpos {
        ($lhs:literal, $rhs:literal -> $result:literal; $t1:ident $t2:ident $t3:ident $t4:ident $t5:ident) => {
            test_function!(
                StrposFunc::new(),
                vec![
                    ColumnarValue::Scalar(ScalarValue::$t1(Some($lhs.to_owned()))),
                    ColumnarValue::Scalar(ScalarValue::$t2(Some($rhs.to_owned()))),
                ],
                Ok(Some($result)),
                $t3,
                $t4,
                $t5
            )
        };
    }

    #[test]
    fn test_strpos_functions() {
        // Utf8 string, Utf8 substring
        test_strpos!("alphabet", "ph" -> 3; Utf8 Utf8 i32 Int32 Int32Array);
        test_strpos!("alphabet", "a" -> 1; Utf8 Utf8 i32 Int32 Int32Array);
        test_strpos!("alphabet", "z" -> 0; Utf8 Utf8 i32 Int32 Int32Array);
        test_strpos!("alphabet", "" -> 1; Utf8 Utf8 i32 Int32 Int32Array);
        test_strpos!("", "a" -> 0; Utf8 Utf8 i32 Int32 Int32Array);
        test_strpos!("", "" -> 1; Utf8 Utf8 i32 Int32 Int32Array);
        test_strpos!("ДатаФусион数据融合📊🔥", "📊" -> 15; Utf8 Utf8 i32 Int32 Int32Array);

        // LargeUtf8 string, LargeUtf8 substring
        test_strpos!("alphabet", "ph" -> 3; LargeUtf8 LargeUtf8 i64 Int64 Int64Array);
        test_strpos!("alphabet", "a" -> 1; LargeUtf8 LargeUtf8 i64 Int64 Int64Array);
        test_strpos!("alphabet", "z" -> 0; LargeUtf8 LargeUtf8 i64 Int64 Int64Array);
        test_strpos!("alphabet", "" -> 1; LargeUtf8 LargeUtf8 i64 Int64 Int64Array);
        test_strpos!("", "a" -> 0; LargeUtf8 LargeUtf8 i64 Int64 Int64Array);
        test_strpos!("", "" -> 1; LargeUtf8 LargeUtf8 i64 Int64 Int64Array);
        test_strpos!("ДатаФусион数据融合📊🔥", "📊" -> 15; LargeUtf8 LargeUtf8 i64 Int64 Int64Array);

        // Utf8 string, LargeUtf8 substring
        test_strpos!("alphabet", "ph" -> 3; Utf8 LargeUtf8 i32 Int32 Int32Array);
        test_strpos!("alphabet", "a" -> 1; Utf8 LargeUtf8 i32 Int32 Int32Array);
        test_strpos!("alphabet", "z" -> 0; Utf8 LargeUtf8 i32 Int32 Int32Array);
        test_strpos!("alphabet", "" -> 1; Utf8 LargeUtf8 i32 Int32 Int32Array);
        test_strpos!("", "a" -> 0; Utf8 LargeUtf8 i32 Int32 Int32Array);
        test_strpos!("", "" -> 1; Utf8 LargeUtf8 i32 Int32 Int32Array);
        test_strpos!("ДатаФусион数据融合📊🔥", "📊" -> 15; Utf8 LargeUtf8 i32 Int32 Int32Array);

        // Utf8 string, Utf8View substring
        test_strpos!("alphabet", "ph" -> 3; Utf8 Utf8View i32 Int32 Int32Array);
        test_strpos!("alphabet", "a" -> 1; Utf8 Utf8View i32 Int32 Int32Array);
        test_strpos!("alphabet", "z" -> 0; Utf8 Utf8View i32 Int32 Int32Array);
        test_strpos!("alphabet", "" -> 1; Utf8 Utf8View i32 Int32 Int32Array);
        test_strpos!("", "a" -> 0; Utf8 Utf8View i32 Int32 Int32Array);
        test_strpos!("", "" -> 1; Utf8 Utf8View i32 Int32 Int32Array);
        test_strpos!("ДатаФусион数据融合📊🔥", "📊" -> 15; Utf8 Utf8View i32 Int32 Int32Array);

        // LargeUtf8 string, Utf8 substring
        test_strpos!("alphabet", "ph" -> 3; LargeUtf8 Utf8 i64 Int64 Int64Array);
        test_strpos!("alphabet", "a" -> 1; LargeUtf8 Utf8 i64 Int64 Int64Array);
        test_strpos!("alphabet", "z" -> 0; LargeUtf8 Utf8 i64 Int64 Int64Array);
        test_strpos!("alphabet", "" -> 1; LargeUtf8 Utf8 i64 Int64 Int64Array);
        test_strpos!("", "a" -> 0; LargeUtf8 Utf8 i64 Int64 Int64Array);
        test_strpos!("", "" -> 1; LargeUtf8 Utf8 i64 Int64 Int64Array);
        test_strpos!("ДатаФусион数据融合📊🔥", "📊" -> 15; LargeUtf8 Utf8 i64 Int64 Int64Array);

        // LargeUtf8 string, Utf8View substring
        test_strpos!("alphabet", "ph" -> 3; LargeUtf8 Utf8View i64 Int64 Int64Array);
        test_strpos!("alphabet", "a" -> 1; LargeUtf8 Utf8View i64 Int64 Int64Array);
        test_strpos!("alphabet", "z" -> 0; LargeUtf8 Utf8View i64 Int64 Int64Array);
        test_strpos!("alphabet", "" -> 1; LargeUtf8 Utf8View i64 Int64 Int64Array);
        test_strpos!("", "a" -> 0; LargeUtf8 Utf8View i64 Int64 Int64Array);
        test_strpos!("", "" -> 1; LargeUtf8 Utf8View i64 Int64 Int64Array);
        test_strpos!("ДатаФусион数据融合📊🔥", "📊" -> 15; LargeUtf8 Utf8View i64 Int64 Int64Array);

        // Utf8View string, Utf8View substring
        test_strpos!("alphabet", "ph" -> 3; Utf8View Utf8View i32 Int32 Int32Array);
        test_strpos!("alphabet", "a" -> 1; Utf8View Utf8View i32 Int32 Int32Array);
        test_strpos!("alphabet", "z" -> 0; Utf8View Utf8View i32 Int32 Int32Array);
        test_strpos!("alphabet", "" -> 1; Utf8View Utf8View i32 Int32 Int32Array);
        test_strpos!("", "a" -> 0; Utf8View Utf8View i32 Int32 Int32Array);
        test_strpos!("", "" -> 1; Utf8View Utf8View i32 Int32 Int32Array);
        test_strpos!("ДатаФусион数据融合📊🔥", "📊" -> 15; Utf8View Utf8View i32 Int32 Int32Array);

        // Utf8View string, Utf8 substring
        test_strpos!("alphabet", "ph" -> 3; Utf8View Utf8 i32 Int32 Int32Array);
        test_strpos!("alphabet", "a" -> 1; Utf8View Utf8 i32 Int32 Int32Array);
        test_strpos!("alphabet", "z" -> 0; Utf8View Utf8 i32 Int32 Int32Array);
        test_strpos!("alphabet", "" -> 1; Utf8View Utf8 i32 Int32 Int32Array);
        test_strpos!("", "a" -> 0; Utf8View Utf8 i32 Int32 Int32Array);
        test_strpos!("", "" -> 1; Utf8View Utf8 i32 Int32 Int32Array);
        test_strpos!("ДатаФусион数据融合📊🔥", "📊" -> 15; Utf8View Utf8 i32 Int32 Int32Array);

        // Utf8View string, LargeUtf8 substring
        test_strpos!("alphabet", "ph" -> 3; Utf8View LargeUtf8 i32 Int32 Int32Array);
        test_strpos!("alphabet", "a" -> 1; Utf8View LargeUtf8 i32 Int32 Int32Array);
        test_strpos!("alphabet", "z" -> 0; Utf8View LargeUtf8 i32 Int32 Int32Array);
        test_strpos!("alphabet", "" -> 1; Utf8View LargeUtf8 i32 Int32 Int32Array);
        test_strpos!("", "a" -> 0; Utf8View LargeUtf8 i32 Int32 Int32Array);
        test_strpos!("", "" -> 1; Utf8View LargeUtf8 i32 Int32 Int32Array);
        test_strpos!("ДатаФусион数据融合📊🔥", "📊" -> 15; Utf8View LargeUtf8 i32 Int32 Int32Array);
    }

    #[test]
    fn nullable_return_type() {
        /// Returns the nullability of the output field for the given
        /// nullability of the two `Utf8` input fields.
        fn get_nullable(string_array_nullable: bool, substring_nullable: bool) -> bool {
            let strpos = StrposFunc::new();
            let args = datafusion_expr::ReturnFieldArgs {
                arg_fields: &[
                    Field::new("f1", DataType::Utf8, string_array_nullable).into(),
                    Field::new("f2", DataType::Utf8, substring_nullable).into(),
                ],
                scalar_arguments: &[None::<&ScalarValue>, None::<&ScalarValue>],
            };

            strpos.return_field_from_args(args).unwrap().is_nullable()
        }

        // Non-nullable only when both inputs are non-nullable.
        assert!(!get_nullable(false, false));

        // Nullable as soon as either input is nullable.
        assert!(get_nullable(true, false));
        assert!(get_nullable(false, true));
        assert!(get_nullable(true, true));
    }
}