1use arrow::array::{
19 Array, ArrayRef, AsArray, Datum, Int64Array, PrimitiveArray, StringArrayType,
20};
21use arrow::datatypes::{DataType, Int64Type};
22use arrow::datatypes::{
23 DataType::Int64, DataType::LargeUtf8, DataType::Utf8, DataType::Utf8View,
24};
25use arrow::error::ArrowError;
26use datafusion_common::{exec_err, internal_err, Result, ScalarValue};
27use datafusion_expr::{
28 ColumnarValue, Documentation, ScalarUDFImpl, Signature, TypeSignature::Exact,
29 TypeSignature::Uniform, Volatility,
30};
31use datafusion_macros::user_doc;
32use itertools::izip;
33use regex::Regex;
34use std::collections::HashMap;
35use std::sync::Arc;
36
37use crate::regex::compile_and_cache_regex;
38
39#[user_doc(
40 doc_section(label = "Regular Expression Functions"),
41 description = "Returns the position in a string where the specified occurrence of a POSIX regular expression is located.",
42 syntax_example = "regexp_instr(str, regexp[, start[, N[, flags[, subexpr]]]])",
43 sql_example = r#"```sql
44> SELECT regexp_instr('ABCDEF', 'C(.)(..)');
45+---------------------------------------------------------------+
46| regexp_instr(Utf8("ABCDEF"),Utf8("C(.)(..)")) |
47+---------------------------------------------------------------+
48| 3 |
49+---------------------------------------------------------------+
50```"#,
51 standard_argument(name = "str", prefix = "String"),
52 standard_argument(name = "regexp", prefix = "Regular"),
53 argument(
54 name = "start",
55 description = "- **start**: Optional start position (the first position is 1) to search for the regular expression. Can be a constant, column, or function. Defaults to 1"
56 ),
57 argument(
58 name = "N",
59 description = "- **N**: Optional The N-th occurrence of pattern to find. Defaults to 1 (first match). Can be a constant, column, or function."
60 ),
61 argument(
62 name = "flags",
63 description = r#"Optional regular expression flags that control the behavior of the regular expression. The following flags are supported:
64 - **i**: case-insensitive: letters match both upper and lower case
65 - **m**: multi-line mode: ^ and $ match begin/end of line
66 - **s**: allow . to match \n
67 - **R**: enables CRLF mode: when multi-line mode is enabled, \r\n is used
68 - **U**: swap the meaning of x* and x*?"#
69 ),
70 argument(
71 name = "subexpr",
72 description = "Optional Specifies which capture group (subexpression) to return the position for. Defaults to 0, which returns the position of the entire match."
73 )
74)]
// Scalar UDF implementation backing the SQL function `regexp_instr`.
#[derive(Debug)]
pub struct RegexpInstrFunc {
    // Argument-type combinations accepted by the function; built in `new()`
    // and handed out through `ScalarUDFImpl::signature`.
    signature: Signature,
}
79
80impl Default for RegexpInstrFunc {
81 fn default() -> Self {
82 Self::new()
83 }
84}
85
86impl RegexpInstrFunc {
87 pub fn new() -> Self {
88 Self {
89 signature: Signature::one_of(
90 vec![
91 Uniform(2, vec![Utf8View, LargeUtf8, Utf8]),
92 Exact(vec![Utf8View, Utf8View, Int64]),
93 Exact(vec![LargeUtf8, LargeUtf8, Int64]),
94 Exact(vec![Utf8, Utf8, Int64]),
95 Exact(vec![Utf8View, Utf8View, Int64, Int64]),
96 Exact(vec![LargeUtf8, LargeUtf8, Int64, Int64]),
97 Exact(vec![Utf8, Utf8, Int64, Int64]),
98 Exact(vec![Utf8View, Utf8View, Int64, Int64, Utf8View]),
99 Exact(vec![LargeUtf8, LargeUtf8, Int64, Int64, LargeUtf8]),
100 Exact(vec![Utf8, Utf8, Int64, Int64, Utf8]),
101 Exact(vec![Utf8View, Utf8View, Int64, Int64, Utf8View, Int64]),
102 Exact(vec![LargeUtf8, LargeUtf8, Int64, Int64, LargeUtf8, Int64]),
103 Exact(vec![Utf8, Utf8, Int64, Int64, Utf8, Int64]),
104 ],
105 Volatility::Immutable,
106 ),
107 }
108 }
109}
110
111impl ScalarUDFImpl for RegexpInstrFunc {
112 fn as_any(&self) -> &dyn std::any::Any {
113 self
114 }
115
116 fn name(&self) -> &str {
117 "regexp_instr"
118 }
119
120 fn signature(&self) -> &Signature {
121 &self.signature
122 }
123
124 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
125 Ok(Int64)
126 }
127
128 fn invoke_with_args(
129 &self,
130 args: datafusion_expr::ScalarFunctionArgs,
131 ) -> Result<ColumnarValue> {
132 let args = &args.args;
133
134 let len = args
135 .iter()
136 .fold(Option::<usize>::None, |acc, arg| match arg {
137 ColumnarValue::Scalar(_) => acc,
138 ColumnarValue::Array(a) => Some(a.len()),
139 });
140
141 let is_scalar = len.is_none();
142 let inferred_length = len.unwrap_or(1);
143 let args = args
144 .iter()
145 .map(|arg| arg.to_array(inferred_length))
146 .collect::<Result<Vec<_>>>()?;
147
148 let result = regexp_instr_func(&args);
149 if is_scalar {
150 let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0));
152 result.map(ColumnarValue::Scalar)
153 } else {
154 result.map(ColumnarValue::Array)
155 }
156 }
157
158 fn documentation(&self) -> Option<&Documentation> {
159 self.doc()
160 }
161}
162
163pub fn regexp_instr_func(args: &[ArrayRef]) -> Result<ArrayRef> {
164 let args_len = args.len();
165 if !(2..=6).contains(&args_len) {
166 return exec_err!("regexp_instr was called with {args_len} arguments. It requires at least 2 and at most 6.");
167 }
168
169 let values = &args[0];
170 match values.data_type() {
171 Utf8 | LargeUtf8 | Utf8View => (),
172 other => {
173 return internal_err!(
174 "Unsupported data type {other:?} for function regexp_instr"
175 );
176 }
177 }
178
179 regexp_instr(
180 values,
181 &args[1],
182 if args_len > 2 { Some(&args[2]) } else { None },
183 if args_len > 3 { Some(&args[3]) } else { None },
184 if args_len > 4 { Some(&args[4]) } else { None },
185 if args_len > 5 { Some(&args[5]) } else { None },
186 )
187 .map_err(|e| e.into())
188}
189
190pub fn regexp_instr(
209 values: &dyn Array,
210 regex_array: &dyn Datum,
211 start_array: Option<&dyn Datum>,
212 nth_array: Option<&dyn Datum>,
213 flags_array: Option<&dyn Datum>,
214 subexpr_array: Option<&dyn Datum>,
215) -> Result<ArrayRef, ArrowError> {
216 let (regex_array, _) = regex_array.get();
217 let start_array = start_array.map(|start| {
218 let (start, _) = start.get();
219 start
220 });
221 let nth_array = nth_array.map(|nth| {
222 let (nth, _) = nth.get();
223 nth
224 });
225 let flags_array = flags_array.map(|flags| {
226 let (flags, _) = flags.get();
227 flags
228 });
229 let subexpr_array = subexpr_array.map(|subexpr| {
230 let (subexpr, _) = subexpr.get();
231 subexpr
232 });
233
234 match (values.data_type(), regex_array.data_type(), flags_array) {
235 (Utf8, Utf8, None) => regexp_instr_inner(
236 values.as_string::<i32>(),
237 regex_array.as_string::<i32>(),
238 start_array.map(|start| start.as_primitive::<Int64Type>()),
239 nth_array.map(|nth| nth.as_primitive::<Int64Type>()),
240 None,
241 subexpr_array.map(|subexpr| subexpr.as_primitive::<Int64Type>()),
242 ),
243 (Utf8, Utf8, Some(flags_array)) if *flags_array.data_type() == Utf8 => regexp_instr_inner(
244 values.as_string::<i32>(),
245 regex_array.as_string::<i32>(),
246 start_array.map(|start| start.as_primitive::<Int64Type>()),
247 nth_array.map(|nth| nth.as_primitive::<Int64Type>()),
248 Some(flags_array.as_string::<i32>()),
249 subexpr_array.map(|subexpr| subexpr.as_primitive::<Int64Type>()),
250 ),
251 (LargeUtf8, LargeUtf8, None) => regexp_instr_inner(
252 values.as_string::<i64>(),
253 regex_array.as_string::<i64>(),
254 start_array.map(|start| start.as_primitive::<Int64Type>()),
255 nth_array.map(|nth| nth.as_primitive::<Int64Type>()),
256 None,
257 subexpr_array.map(|subexpr| subexpr.as_primitive::<Int64Type>()),
258 ),
259 (LargeUtf8, LargeUtf8, Some(flags_array)) if *flags_array.data_type() == LargeUtf8 => regexp_instr_inner(
260 values.as_string::<i64>(),
261 regex_array.as_string::<i64>(),
262 start_array.map(|start| start.as_primitive::<Int64Type>()),
263 nth_array.map(|nth| nth.as_primitive::<Int64Type>()),
264 Some(flags_array.as_string::<i64>()),
265 subexpr_array.map(|subexpr| subexpr.as_primitive::<Int64Type>()),
266 ),
267 (Utf8View, Utf8View, None) => regexp_instr_inner(
268 values.as_string_view(),
269 regex_array.as_string_view(),
270 start_array.map(|start| start.as_primitive::<Int64Type>()),
271 nth_array.map(|nth| nth.as_primitive::<Int64Type>()),
272 None,
273 subexpr_array.map(|subexpr| subexpr.as_primitive::<Int64Type>()),
274 ),
275 (Utf8View, Utf8View, Some(flags_array)) if *flags_array.data_type() == Utf8View => regexp_instr_inner(
276 values.as_string_view(),
277 regex_array.as_string_view(),
278 start_array.map(|start| start.as_primitive::<Int64Type>()),
279 nth_array.map(|nth| nth.as_primitive::<Int64Type>()),
280 Some(flags_array.as_string_view()),
281 subexpr_array.map(|subexpr| subexpr.as_primitive::<Int64Type>()),
282 ),
283 _ => Err(ArrowError::ComputeError(
284 "regexp_instr() expected the input arrays to be of type Utf8, LargeUtf8, or Utf8View and the data types of the values, regex_array, and flags_array to match".to_string(),
285 )),
286 }
287}
288
289#[allow(clippy::too_many_arguments)]
290pub fn regexp_instr_inner<'a, S>(
291 values: S,
292 regex_array: S,
293 start_array: Option<&Int64Array>,
294 nth_array: Option<&Int64Array>,
295 flags_array: Option<S>,
296 subexp_array: Option<&Int64Array>,
297) -> Result<ArrayRef, ArrowError>
298where
299 S: StringArrayType<'a>,
300{
301 let len = values.len();
302
303 let default_start_array = PrimitiveArray::<Int64Type>::from(vec![1; len]);
304 let start_array = start_array.unwrap_or(&default_start_array);
305 let start_input: Vec<i64> = (0..start_array.len())
306 .map(|i| start_array.value(i)) .collect();
308
309 let default_nth_array = PrimitiveArray::<Int64Type>::from(vec![1; len]);
310 let nth_array = nth_array.unwrap_or(&default_nth_array);
311 let nth_input: Vec<i64> = (0..nth_array.len())
312 .map(|i| nth_array.value(i)) .collect();
314
315 let flags_input = match flags_array {
316 Some(flags) => flags.iter().collect(),
317 None => vec![None; len],
318 };
319
320 let default_subexp_array = PrimitiveArray::<Int64Type>::from(vec![0; len]);
321 let subexp_array = subexp_array.unwrap_or(&default_subexp_array);
322 let subexp_input: Vec<i64> = (0..subexp_array.len())
323 .map(|i| subexp_array.value(i)) .collect();
325
326 let mut regex_cache = HashMap::new();
327
328 let result: Result<Vec<Option<i64>>, ArrowError> = izip!(
329 values.iter(),
330 regex_array.iter(),
331 start_input.iter(),
332 nth_input.iter(),
333 flags_input.iter(),
334 subexp_input.iter()
335 )
336 .map(|(value, regex, start, nth, flags, subexp)| match regex {
337 None => Ok(None),
338 Some("") => Ok(Some(0)),
339 Some(regex) => get_index(
340 value,
341 regex,
342 *start,
343 *nth,
344 *subexp,
345 *flags,
346 &mut regex_cache,
347 ),
348 })
349 .collect();
350 Ok(Arc::new(Int64Array::from(result?)))
351}
352
353fn handle_subexp(
354 pattern: &Regex,
355 search_slice: &str,
356 subexpr: i64,
357 value: &str,
358 byte_start_offset: usize,
359) -> Result<Option<i64>, ArrowError> {
360 if let Some(captures) = pattern.captures(search_slice) {
361 if let Some(matched) = captures.get(subexpr as usize) {
362 let start_char_offset =
365 value[..byte_start_offset + matched.start()].chars().count() as i64 + 1;
366 return Ok(Some(start_char_offset));
367 }
368 }
369 Ok(Some(0)) }
371
372fn get_nth_match(
373 pattern: &Regex,
374 search_slice: &str,
375 n: i64,
376 byte_start_offset: usize,
377 value: &str,
378) -> Result<Option<i64>, ArrowError> {
379 if let Some(mat) = pattern.find_iter(search_slice).nth((n - 1) as usize) {
380 let match_start_byte_offset = byte_start_offset + mat.start();
383 let match_start_char_offset =
384 value[..match_start_byte_offset].chars().count() as i64 + 1;
385 Ok(Some(match_start_char_offset))
386 } else {
387 Ok(Some(0)) }
389}
390fn get_index<'strings, 'cache>(
391 value: Option<&str>,
392 pattern: &'strings str,
393 start: i64,
394 n: i64,
395 subexpr: i64,
396 flags: Option<&'strings str>,
397 regex_cache: &'cache mut HashMap<(&'strings str, Option<&'strings str>), Regex>,
398) -> Result<Option<i64>, ArrowError>
399where
400 'strings: 'cache,
401{
402 let value = match value {
403 None => return Ok(None),
404 Some("") => return Ok(Some(0)),
405 Some(value) => value,
406 };
407 let pattern: &Regex = compile_and_cache_regex(pattern, flags, regex_cache)?;
408 if start < 1 {
410 return Err(ArrowError::ComputeError(
411 "regexp_instr() requires start to be 1-based".to_string(),
412 ));
413 }
414
415 if n < 1 {
416 return Err(ArrowError::ComputeError(
417 "N must be 1 or greater".to_string(),
418 ));
419 }
420
421 let total_chars = value.chars().count() as i64;
423 let byte_start_offset: usize = if start > total_chars {
424 return Ok(Some(0));
427 } else {
428 value
430 .char_indices()
431 .nth((start - 1) as usize)
432 .map(|(idx, _)| idx)
433 .unwrap_or(0) };
435 let search_slice = &value[byte_start_offset..];
438
439 if subexpr > 0 {
441 return handle_subexp(pattern, search_slice, subexpr, value, byte_start_offset);
442 }
443
444 get_nth_match(pattern, search_slice, n, byte_start_offset, value)
446}
447
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::Int64Array;
    use arrow::array::{GenericStringArray, StringViewArray};
    use arrow::datatypes::Field;
    use datafusion_expr::ScalarFunctionArgs;

    /// Constructor for one flavour of string scalar.
    type StringScalarCtor = fn(Option<String>) -> ScalarValue;

    /// The string scalar flavours every scalar case is exercised with, in the
    /// order Utf8, LargeUtf8, Utf8View.
    const STRING_SCALAR_CTORS: [StringScalarCtor; 3] = [
        ScalarValue::Utf8,
        ScalarValue::LargeUtf8,
        ScalarValue::Utf8View,
    ];

    #[test]
    fn test_regexp_instr() {
        test_case_sensitive_regexp_instr_nulls();
        test_case_sensitive_regexp_instr_scalar();
        test_case_sensitive_regexp_instr_scalar_start();
        test_case_sensitive_regexp_instr_scalar_nth();
        test_case_sensitive_regexp_instr_scalar_subexp();

        test_case_sensitive_regexp_instr_array::<GenericStringArray<i32>>();
        test_case_sensitive_regexp_instr_array::<GenericStringArray<i64>>();
        test_case_sensitive_regexp_instr_array::<StringViewArray>();

        test_case_sensitive_regexp_instr_array_start::<GenericStringArray<i32>>();
        test_case_sensitive_regexp_instr_array_start::<GenericStringArray<i64>>();
        test_case_sensitive_regexp_instr_array_start::<StringViewArray>();

        test_case_sensitive_regexp_instr_array_nth::<GenericStringArray<i32>>();
        test_case_sensitive_regexp_instr_array_nth::<GenericStringArray<i64>>();
        test_case_sensitive_regexp_instr_array_nth::<StringViewArray>();
    }

    /// Invokes the UDF with every argument passed as a scalar.
    fn regexp_instr_with_scalar_values(args: &[ScalarValue]) -> Result<ColumnarValue> {
        let arg_fields = args
            .iter()
            .enumerate()
            .map(|(idx, a)| {
                Arc::new(Field::new(format!("arg_{idx}"), a.data_type(), true))
            })
            .collect::<Vec<_>>();

        let call = ScalarFunctionArgs {
            args: args.iter().cloned().map(ColumnarValue::Scalar).collect(),
            arg_fields,
            number_rows: args.len(),
            return_field: Arc::new(Field::new("f", Int64, true)),
        };

        RegexpInstrFunc::new().invoke_with_args(call)
    }

    /// Runs the UDF on scalar `args` and checks that it yields the Int64
    /// scalar `expected`.
    fn assert_scalar_result(args: &[ScalarValue], expected: Option<i64>) {
        match regexp_instr_with_scalar_values(args) {
            Ok(ColumnarValue::Scalar(ScalarValue::Int64(actual))) => {
                assert_eq!(actual, expected, "regexp_instr scalar test failed");
            }
            _ => panic!("Unexpected result"),
        }
    }

    /// Empty value and empty pattern yield position 0.
    fn test_case_sensitive_regexp_instr_nulls() {
        let value_sv: ScalarValue = "".to_string().into();
        let regex_sv = ScalarValue::Utf8(Some("".to_string()));
        assert_scalar_result(&[value_sv, regex_sv], Some(0));
    }

    fn test_case_sensitive_regexp_instr_scalar() {
        // (value, pattern, expected position)
        let cases = [
            ("hello world", "o", 5),
            ("abcdefg", "d", 4),
            ("xyz123xyz", "123", 4),
            ("no match here", "z", 0),
            ("abc", "gg", 0),
            ("ДатаФусион数据融合📊🔥", "📊", 15),
        ];

        for (value, pattern, expected) in cases {
            for string_sv in STRING_SCALAR_CTORS {
                assert_scalar_result(
                    &[
                        string_sv(Some(value.to_string())),
                        string_sv(Some(pattern.to_string())),
                    ],
                    Some(expected),
                );
            }
        }
    }

    fn test_case_sensitive_regexp_instr_scalar_start() {
        // (value, pattern, start, expected position)
        let cases = [
            ("abcabcabc", "abc", 4, 4),
            ("abcabcabc", "abc", 5, 7),
            ("", "gg", 5, 0),
        ];

        for (value, pattern, start, expected) in cases {
            for string_sv in STRING_SCALAR_CTORS {
                assert_scalar_result(
                    &[
                        string_sv(Some(value.to_string())),
                        string_sv(Some(pattern.to_string())),
                        ScalarValue::Int64(Some(start)),
                    ],
                    Some(expected),
                );
            }
        }
    }

    fn test_case_sensitive_regexp_instr_scalar_nth() {
        // (value, pattern, start, nth, expected position)
        let cases = [
            ("abcabcabc", "abc", 1, 1, 1),
            ("abcabcabc", "abc", 1, 2, 4),
            ("abcabcabc", "abc", 1, 3, 7),
            ("abcabcabc", "abc", 1, 4, 0),
        ];

        for (value, pattern, start, nth, expected) in cases {
            for string_sv in STRING_SCALAR_CTORS {
                assert_scalar_result(
                    &[
                        string_sv(Some(value.to_string())),
                        string_sv(Some(pattern.to_string())),
                        ScalarValue::Int64(Some(start)),
                        ScalarValue::Int64(Some(nth)),
                    ],
                    Some(expected),
                );
            }
        }
    }

    fn test_case_sensitive_regexp_instr_scalar_subexp() {
        // (value, pattern, start, nth, flags, subexpr, expected position)
        let cases = [("12 abc def ghi 34", "(abc) (def) (ghi)", 1, 1, "i", 2, 8)];

        for (value, pattern, start, nth, flags, subexp, expected) in cases {
            for string_sv in STRING_SCALAR_CTORS {
                assert_scalar_result(
                    &[
                        string_sv(Some(value.to_string())),
                        string_sv(Some(pattern.to_string())),
                        ScalarValue::Int64(Some(start)),
                        ScalarValue::Int64(Some(nth)),
                        string_sv(Some(flags.to_string())),
                        ScalarValue::Int64(Some(subexp)),
                    ],
                    Some(expected),
                );
            }
        }
    }

    fn test_case_sensitive_regexp_instr_array<A>()
    where
        A: From<Vec<&'static str>> + Array + 'static,
    {
        let haystacks: ArrayRef = Arc::new(A::from(vec![
            "hello world",
            "abcdefg",
            "xyz123xyz",
            "no match here",
            "",
        ]));
        let patterns: ArrayRef = Arc::new(A::from(vec!["o", "d", "123", "z", "gg"]));
        let expected = Int64Array::from(vec![5, 4, 4, 0, 0]);

        let actual = regexp_instr_func(&[haystacks, patterns]).unwrap();
        assert_eq!(actual.as_ref(), &expected);
    }

    fn test_case_sensitive_regexp_instr_array_start<A>()
    where
        A: From<Vec<&'static str>> + Array + 'static,
    {
        let haystacks: ArrayRef = Arc::new(A::from(vec!["abcabcabc", "abcabcabc", ""]));
        let patterns: ArrayRef = Arc::new(A::from(vec!["abc", "abc", "gg"]));
        let starts: ArrayRef = Arc::new(Int64Array::from(vec![4, 5, 5]));
        let expected = Int64Array::from(vec![4, 7, 0]);

        let actual = regexp_instr_func(&[haystacks, patterns, starts]).unwrap();
        assert_eq!(actual.as_ref(), &expected);
    }

    fn test_case_sensitive_regexp_instr_array_nth<A>()
    where
        A: From<Vec<&'static str>> + Array + 'static,
    {
        let haystacks: ArrayRef = Arc::new(A::from(vec![
            "abcabcabc",
            "abcabcabc",
            "abcabcabc",
            "abcabcabc",
        ]));
        let patterns: ArrayRef = Arc::new(A::from(vec!["abc", "abc", "abc", "abc"]));
        let starts: ArrayRef = Arc::new(Int64Array::from(vec![1, 1, 1, 1]));
        let nths: ArrayRef = Arc::new(Int64Array::from(vec![1, 2, 3, 4]));
        let expected = Int64Array::from(vec![1, 4, 7, 0]);

        let actual = regexp_instr_func(&[haystacks, patterns, starts, nths]).unwrap();
        assert_eq!(actual.as_ref(), &expected);
    }
}