1use arrow::array::{
19 Array, ArrayRef, AsArray, Datum, Int64Array, PrimitiveArray, StringArrayType,
20};
21use arrow::datatypes::{DataType, Int64Type};
22use arrow::datatypes::{
23 DataType::Int64, DataType::LargeUtf8, DataType::Utf8, DataType::Utf8View,
24};
25use arrow::error::ArrowError;
26use datafusion_common::{Result, ScalarValue, exec_err, internal_err};
27use datafusion_expr::{
28 ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
29 TypeSignature::Exact, TypeSignature::Uniform, Volatility,
30};
31use datafusion_macros::user_doc;
32use itertools::izip;
33use regex::Regex;
34use std::collections::HashMap;
35use std::sync::Arc;
36
37use crate::regex::compile_and_cache_regex;
38
// The attribute below carries the user-facing documentation (section, description, syntax,
// SQL example and per-argument text) for this function.
// NOTE(review): the `start` and `N` descriptions embed their own "- **name**:" prefix while
// `flags` and `subexpr` do not — confirm which form the documentation generator expects.
#[user_doc(
    doc_section(label = "Regular Expression Functions"),
    description = "Returns the position in a string where the specified occurrence of a POSIX regular expression is located.",
    syntax_example = "regexp_instr(str, regexp[, start[, N[, flags[, subexpr]]]])",
    sql_example = r#"```sql
> SELECT regexp_instr('ABCDEF', 'C(.)(..)');
+---------------------------------------------------------------+
| regexp_instr(Utf8("ABCDEF"),Utf8("C(.)(..)")) |
+---------------------------------------------------------------+
| 3 |
+---------------------------------------------------------------+
```"#,
    standard_argument(name = "str", prefix = "String"),
    standard_argument(name = "regexp", prefix = "Regular"),
    argument(
        name = "start",
        description = "- **start**: Optional start position (the first position is 1) to search for the regular expression. Can be a constant, column, or function. Defaults to 1"
    ),
    argument(
        name = "N",
        description = "- **N**: Optional The N-th occurrence of pattern to find. Defaults to 1 (first match). Can be a constant, column, or function."
    ),
    argument(
        name = "flags",
        description = r#"Optional regular expression flags that control the behavior of the regular expression. The following flags are supported:
 - **i**: case-insensitive: letters match both upper and lower case
 - **m**: multi-line mode: ^ and $ match begin/end of line
 - **s**: allow . to match \n
 - **R**: enables CRLF mode: when multi-line mode is enabled, \r\n is used
 - **U**: swap the meaning of x* and x*?"#
    ),
    argument(
        name = "subexpr",
        description = "Optional Specifies which capture group (subexpression) to return the position for. Defaults to 0, which returns the position of the entire match."
    )
)]
#[derive(Debug, PartialEq, Eq, Hash)]
/// Scalar UDF implementation of `regexp_instr`.
///
/// The only state is the [`Signature`] listing the accepted argument type combinations;
/// it is built once in `RegexpInstrFunc::new`.
pub struct RegexpInstrFunc {
    /// Accepted argument type combinations and volatility of the function.
    signature: Signature,
}
79
80impl Default for RegexpInstrFunc {
81 fn default() -> Self {
82 Self::new()
83 }
84}
85
86impl RegexpInstrFunc {
87 pub fn new() -> Self {
88 Self {
89 signature: Signature::one_of(
90 vec![
91 Uniform(2, vec![Utf8View, LargeUtf8, Utf8]),
92 Exact(vec![Utf8View, Utf8View, Int64]),
93 Exact(vec![LargeUtf8, LargeUtf8, Int64]),
94 Exact(vec![Utf8, Utf8, Int64]),
95 Exact(vec![Utf8View, Utf8View, Int64, Int64]),
96 Exact(vec![LargeUtf8, LargeUtf8, Int64, Int64]),
97 Exact(vec![Utf8, Utf8, Int64, Int64]),
98 Exact(vec![Utf8View, Utf8View, Int64, Int64, Utf8View]),
99 Exact(vec![LargeUtf8, LargeUtf8, Int64, Int64, LargeUtf8]),
100 Exact(vec![Utf8, Utf8, Int64, Int64, Utf8]),
101 Exact(vec![Utf8View, Utf8View, Int64, Int64, Utf8View, Int64]),
102 Exact(vec![LargeUtf8, LargeUtf8, Int64, Int64, LargeUtf8, Int64]),
103 Exact(vec![Utf8, Utf8, Int64, Int64, Utf8, Int64]),
104 ],
105 Volatility::Immutable,
106 ),
107 }
108 }
109}
110
111impl ScalarUDFImpl for RegexpInstrFunc {
112 fn name(&self) -> &str {
113 "regexp_instr"
114 }
115
116 fn signature(&self) -> &Signature {
117 &self.signature
118 }
119
120 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
121 Ok(Int64)
122 }
123
124 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
125 let args = &args.args;
126
127 let len = args
128 .iter()
129 .fold(Option::<usize>::None, |acc, arg| match arg {
130 ColumnarValue::Scalar(_) => acc,
131 ColumnarValue::Array(a) => Some(a.len()),
132 });
133
134 let is_scalar = len.is_none();
135 let inferred_length = len.unwrap_or(1);
136 let args = args
137 .iter()
138 .map(|arg| arg.to_array(inferred_length))
139 .collect::<Result<Vec<_>>>()?;
140
141 let result = regexp_instr_func(&args);
142 if is_scalar {
143 let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0));
145 result.map(ColumnarValue::Scalar)
146 } else {
147 result.map(ColumnarValue::Array)
148 }
149 }
150
151 fn documentation(&self) -> Option<&Documentation> {
152 self.doc()
153 }
154}
155
156pub fn regexp_instr_func(args: &[ArrayRef]) -> Result<ArrayRef> {
157 let args_len = args.len();
158 if !(2..=6).contains(&args_len) {
159 return exec_err!(
160 "regexp_instr was called with {args_len} arguments. It requires at least 2 and at most 6."
161 );
162 }
163
164 let values = &args[0];
165 match values.data_type() {
166 Utf8 | LargeUtf8 | Utf8View => (),
167 other => {
168 return internal_err!(
169 "Unsupported data type {other:?} for function regexp_instr"
170 );
171 }
172 }
173
174 regexp_instr(
175 values,
176 &args[1],
177 if args_len > 2 { Some(&args[2]) } else { None },
178 if args_len > 3 { Some(&args[3]) } else { None },
179 if args_len > 4 { Some(&args[4]) } else { None },
180 if args_len > 5 { Some(&args[5]) } else { None },
181 )
182 .map_err(|e| e.into())
183}
184
185fn regexp_instr(
204 values: &dyn Array,
205 regex_array: &dyn Datum,
206 start_array: Option<&dyn Datum>,
207 nth_array: Option<&dyn Datum>,
208 flags_array: Option<&dyn Datum>,
209 subexpr_array: Option<&dyn Datum>,
210) -> Result<ArrayRef, ArrowError> {
211 let (regex_array, _) = regex_array.get();
212 let start_array = start_array.map(|start| {
213 let (start, _) = start.get();
214 start
215 });
216 let nth_array = nth_array.map(|nth| {
217 let (nth, _) = nth.get();
218 nth
219 });
220 let flags_array = flags_array.map(|flags| {
221 let (flags, _) = flags.get();
222 flags
223 });
224 let subexpr_array = subexpr_array.map(|subexpr| {
225 let (subexpr, _) = subexpr.get();
226 subexpr
227 });
228
229 match (values.data_type(), regex_array.data_type(), flags_array) {
230 (Utf8, Utf8, None) => regexp_instr_inner(
231 &values.as_string::<i32>(),
232 ®ex_array.as_string::<i32>(),
233 start_array.map(|start| start.as_primitive::<Int64Type>()),
234 nth_array.map(|nth| nth.as_primitive::<Int64Type>()),
235 None,
236 subexpr_array.map(|subexpr| subexpr.as_primitive::<Int64Type>()),
237 ),
238 (Utf8, Utf8, Some(flags_array)) if *flags_array.data_type() == Utf8 => regexp_instr_inner(
239 &values.as_string::<i32>(),
240 ®ex_array.as_string::<i32>(),
241 start_array.map(|start| start.as_primitive::<Int64Type>()),
242 nth_array.map(|nth| nth.as_primitive::<Int64Type>()),
243 Some(flags_array.as_string::<i32>()),
244 subexpr_array.map(|subexpr| subexpr.as_primitive::<Int64Type>()),
245 ),
246 (LargeUtf8, LargeUtf8, None) => regexp_instr_inner(
247 &values.as_string::<i64>(),
248 ®ex_array.as_string::<i64>(),
249 start_array.map(|start| start.as_primitive::<Int64Type>()),
250 nth_array.map(|nth| nth.as_primitive::<Int64Type>()),
251 None,
252 subexpr_array.map(|subexpr| subexpr.as_primitive::<Int64Type>()),
253 ),
254 (LargeUtf8, LargeUtf8, Some(flags_array)) if *flags_array.data_type() == LargeUtf8 => regexp_instr_inner(
255 &values.as_string::<i64>(),
256 ®ex_array.as_string::<i64>(),
257 start_array.map(|start| start.as_primitive::<Int64Type>()),
258 nth_array.map(|nth| nth.as_primitive::<Int64Type>()),
259 Some(flags_array.as_string::<i64>()),
260 subexpr_array.map(|subexpr| subexpr.as_primitive::<Int64Type>()),
261 ),
262 (Utf8View, Utf8View, None) => regexp_instr_inner(
263 &values.as_string_view(),
264 ®ex_array.as_string_view(),
265 start_array.map(|start| start.as_primitive::<Int64Type>()),
266 nth_array.map(|nth| nth.as_primitive::<Int64Type>()),
267 None,
268 subexpr_array.map(|subexpr| subexpr.as_primitive::<Int64Type>()),
269 ),
270 (Utf8View, Utf8View, Some(flags_array)) if *flags_array.data_type() == Utf8View => regexp_instr_inner(
271 &values.as_string_view(),
272 ®ex_array.as_string_view(),
273 start_array.map(|start| start.as_primitive::<Int64Type>()),
274 nth_array.map(|nth| nth.as_primitive::<Int64Type>()),
275 Some(flags_array.as_string_view()),
276 subexpr_array.map(|subexpr| subexpr.as_primitive::<Int64Type>()),
277 ),
278 _ => Err(ArrowError::ComputeError(
279 "regexp_instr() expected the input arrays to be of type Utf8, LargeUtf8, or Utf8View and the data types of the values, regex_array, and flags_array to match".to_string(),
280 )),
281 }
282}
283
284fn regexp_instr_inner<'a, S>(
285 values: &S,
286 regex_array: &S,
287 start_array: Option<&Int64Array>,
288 nth_array: Option<&Int64Array>,
289 flags_array: Option<S>,
290 subexp_array: Option<&Int64Array>,
291) -> Result<ArrayRef, ArrowError>
292where
293 S: StringArrayType<'a>,
294{
295 let len = values.len();
296
297 let default_start_array = PrimitiveArray::<Int64Type>::from(vec![1; len]);
298 let start_array = start_array.unwrap_or(&default_start_array);
299 let start_input: Vec<i64> = (0..start_array.len())
300 .map(|i| start_array.value(i)) .collect();
302
303 let default_nth_array = PrimitiveArray::<Int64Type>::from(vec![1; len]);
304 let nth_array = nth_array.unwrap_or(&default_nth_array);
305 let nth_input: Vec<i64> = (0..nth_array.len())
306 .map(|i| nth_array.value(i)) .collect();
308
309 let flags_input = match flags_array {
310 Some(flags) => flags.iter().collect(),
311 None => vec![None; len],
312 };
313
314 let default_subexp_array = PrimitiveArray::<Int64Type>::from(vec![0; len]);
315 let subexp_array = subexp_array.unwrap_or(&default_subexp_array);
316 let subexp_input: Vec<i64> = (0..subexp_array.len())
317 .map(|i| subexp_array.value(i)) .collect();
319
320 let mut regex_cache = HashMap::new();
321
322 let result: Result<Vec<Option<i64>>, ArrowError> = izip!(
323 values.iter(),
324 regex_array.iter(),
325 start_input.iter(),
326 nth_input.iter(),
327 flags_input.iter(),
328 subexp_input.iter()
329 )
330 .map(|(value, regex, start, nth, flags, subexp)| match regex {
331 None => Ok(None),
332 Some("") => Ok(Some(0)),
333 Some(regex) => get_index(
334 value,
335 regex,
336 *start,
337 *nth,
338 *subexp,
339 *flags,
340 &mut regex_cache,
341 ),
342 })
343 .collect();
344 Ok(Arc::new(Int64Array::from(result?)))
345}
346
347fn handle_subexp(
348 pattern: &Regex,
349 search_slice: &str,
350 subexpr: i64,
351 value: &str,
352 byte_start_offset: usize,
353) -> Result<Option<i64>, ArrowError> {
354 if let Some(captures) = pattern.captures(search_slice)
355 && let Some(matched) = captures.get(subexpr as usize)
356 {
357 let start_char_offset =
360 value[..byte_start_offset + matched.start()].chars().count() as i64 + 1;
361 return Ok(Some(start_char_offset));
362 }
363 Ok(Some(0)) }
365
366fn get_nth_match(
367 pattern: &Regex,
368 search_slice: &str,
369 n: i64,
370 byte_start_offset: usize,
371 value: &str,
372) -> Result<Option<i64>, ArrowError> {
373 if let Some(mat) = pattern.find_iter(search_slice).nth((n - 1) as usize) {
374 let match_start_byte_offset = byte_start_offset + mat.start();
377 let match_start_char_offset =
378 value[..match_start_byte_offset].chars().count() as i64 + 1;
379 Ok(Some(match_start_char_offset))
380 } else {
381 Ok(Some(0)) }
383}
384fn get_index<'strings, 'cache>(
385 value: Option<&str>,
386 pattern: &'strings str,
387 start: i64,
388 n: i64,
389 subexpr: i64,
390 flags: Option<&'strings str>,
391 regex_cache: &'cache mut HashMap<(&'strings str, Option<&'strings str>), Regex>,
392) -> Result<Option<i64>, ArrowError>
393where
394 'strings: 'cache,
395{
396 let value = match value {
397 None => return Ok(None),
398 Some("") => return Ok(Some(0)),
399 Some(value) => value,
400 };
401 let pattern: &Regex = compile_and_cache_regex(pattern, flags, regex_cache)?;
402 if start < 1 {
404 return Err(ArrowError::ComputeError(
405 "regexp_instr() requires start to be 1-based".to_string(),
406 ));
407 }
408
409 if n < 1 {
410 return Err(ArrowError::ComputeError(
411 "N must be 1 or greater".to_string(),
412 ));
413 }
414
415 let total_chars = value.chars().count() as i64;
417 let byte_start_offset: usize = if start > total_chars {
418 return Ok(Some(0));
421 } else {
422 value
424 .char_indices()
425 .nth((start - 1) as usize)
426 .map(|(idx, _)| idx)
427 .unwrap_or(0) };
429 let search_slice = &value[byte_start_offset..];
432
433 if subexpr > 0 {
435 return handle_subexp(pattern, search_slice, subexpr, value, byte_start_offset);
436 }
437
438 get_nth_match(pattern, search_slice, n, byte_start_offset, value)
440}
441
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{GenericStringArray, StringViewArray};
    use arrow::datatypes::Field;
    use datafusion_common::config::ConfigOptions;

    /// Builds a string-typed [`ScalarValue`] from an optional string.
    type StringScalarCtor = fn(Option<String>) -> ScalarValue;

    /// One constructor per string type accepted by `regexp_instr`, so every scalar case
    /// is exercised for `Utf8`, `LargeUtf8` and `Utf8View`.
    fn string_scalar_ctors() -> [StringScalarCtor; 3] {
        [
            ScalarValue::Utf8,
            ScalarValue::LargeUtf8,
            ScalarValue::Utf8View,
        ]
    }

    /// Shorthand for a non-null `Int64` scalar.
    fn int64(value: i64) -> ScalarValue {
        ScalarValue::Int64(Some(value))
    }

    /// Single entry point running every scalar and array scenario.
    #[test]
    fn test_regexp_instr() {
        test_case_sensitive_regexp_instr_nulls();
        test_case_sensitive_regexp_instr_scalar();
        test_case_sensitive_regexp_instr_scalar_start();
        test_case_sensitive_regexp_instr_scalar_nth();
        test_case_sensitive_regexp_instr_scalar_subexp();

        test_case_sensitive_regexp_instr_array::<GenericStringArray<i32>>();
        test_case_sensitive_regexp_instr_array::<GenericStringArray<i64>>();
        test_case_sensitive_regexp_instr_array::<StringViewArray>();

        test_case_sensitive_regexp_instr_array_start::<GenericStringArray<i32>>();
        test_case_sensitive_regexp_instr_array_start::<GenericStringArray<i64>>();
        test_case_sensitive_regexp_instr_array_start::<StringViewArray>();

        test_case_sensitive_regexp_instr_array_nth::<GenericStringArray<i32>>();
        test_case_sensitive_regexp_instr_array_nth::<GenericStringArray<i64>>();
        test_case_sensitive_regexp_instr_array_nth::<StringViewArray>();
    }

    /// Invokes the UDF with all-scalar arguments.
    fn regexp_instr_with_scalar_values(args: &[ScalarValue]) -> Result<ColumnarValue> {
        let args_values: Vec<ColumnarValue> = args
            .iter()
            .map(|sv| ColumnarValue::Scalar(sv.clone()))
            .collect();

        let arg_fields = args
            .iter()
            .enumerate()
            .map(|(idx, a)| {
                Arc::new(Field::new(format!("arg_{idx}"), a.data_type(), true))
            })
            .collect::<Vec<_>>();

        RegexpInstrFunc::new().invoke_with_args(ScalarFunctionArgs {
            args: args_values,
            arg_fields,
            // All arguments are scalars, so a single row is evaluated.
            number_rows: 1,
            return_field: Arc::new(Field::new("f", Int64, true)),
            config_options: Arc::new(ConfigOptions::default()),
        })
    }

    /// Asserts that an all-scalar call returns the `Int64` scalar `expected`.
    fn assert_scalar_call(args: &[ScalarValue], expected: Option<i64>) {
        match regexp_instr_with_scalar_values(args) {
            Ok(ColumnarValue::Scalar(ScalarValue::Int64(actual))) => {
                assert_eq!(
                    actual, expected,
                    "regexp_instr scalar test failed for {args:?}"
                );
            }
            other => panic!("Unexpected result for {args:?}: {other:?}"),
        }
    }

    fn test_case_sensitive_regexp_instr_nulls() {
        // Empty value with an empty pattern reports "no match" (0), not NULL.
        assert_scalar_call(
            &[
                ScalarValue::Utf8(Some(String::new())),
                ScalarValue::Utf8(Some(String::new())),
            ],
            Some(0),
        );

        // A NULL value or a NULL pattern yields NULL.
        for make in string_scalar_ctors() {
            assert_scalar_call(&[make(None), make(Some("a".to_string()))], None);
            assert_scalar_call(&[make(Some("abc".to_string())), make(None)], None);
        }
    }

    fn test_case_sensitive_regexp_instr_scalar() {
        // (value, pattern, expected position)
        let cases = [
            ("hello world", "o", 5),
            ("abcdefg", "d", 4),
            ("xyz123xyz", "123", 4),
            ("no match here", "z", 0),
            ("abc", "gg", 0),
            // Positions are counted in characters, not bytes.
            ("ДатаФусион数据融合📊🔥", "📊", 15),
        ];

        for (value, regex, expected) in cases {
            for make in string_scalar_ctors() {
                assert_scalar_call(
                    &[make(Some(value.to_string())), make(Some(regex.to_string()))],
                    Some(expected),
                );
            }
        }
    }

    fn test_case_sensitive_regexp_instr_scalar_start() {
        // (value, pattern, start, expected position)
        let cases = [
            ("abcabcabc", "abc", 4, 4),
            ("abcabcabc", "abc", 5, 7),
            ("", "gg", 5, 0),
        ];

        for (value, regex, start, expected) in cases {
            for make in string_scalar_ctors() {
                assert_scalar_call(
                    &[
                        make(Some(value.to_string())),
                        make(Some(regex.to_string())),
                        int64(start),
                    ],
                    Some(expected),
                );
            }
        }
    }

    fn test_case_sensitive_regexp_instr_scalar_nth() {
        // (value, pattern, start, N, expected position)
        let cases = [
            ("abcabcabc", "abc", 1, 1, 1),
            ("abcabcabc", "abc", 1, 2, 4),
            ("abcabcabc", "abc", 1, 3, 7),
            ("abcabcabc", "abc", 1, 4, 0),
        ];

        for (value, regex, start, nth, expected) in cases {
            for make in string_scalar_ctors() {
                assert_scalar_call(
                    &[
                        make(Some(value.to_string())),
                        make(Some(regex.to_string())),
                        int64(start),
                        int64(nth),
                    ],
                    Some(expected),
                );
            }
        }
    }

    fn test_case_sensitive_regexp_instr_scalar_subexp() {
        // (value, pattern, start, N, flags, subexpr, expected position)
        let cases = [("12 abc def ghi 34", "(abc) (def) (ghi)", 1, 1, "i", 2, 8)];

        for (value, regex, start, nth, flags, subexp, expected) in cases {
            for make in string_scalar_ctors() {
                assert_scalar_call(
                    &[
                        make(Some(value.to_string())),
                        make(Some(regex.to_string())),
                        int64(start),
                        int64(nth),
                        make(Some(flags.to_string())),
                        int64(subexp),
                    ],
                    Some(expected),
                );
            }
        }
    }

    fn test_case_sensitive_regexp_instr_array<A>()
    where
        A: From<Vec<&'static str>> + Array + 'static,
    {
        let values = A::from(vec![
            "hello world",
            "abcdefg",
            "xyz123xyz",
            "no match here",
            "",
        ]);
        let regex = A::from(vec!["o", "d", "123", "z", "gg"]);

        let expected = Int64Array::from(vec![5, 4, 4, 0, 0]);
        let re = regexp_instr_func(&[Arc::new(values), Arc::new(regex)]).unwrap();
        assert_eq!(re.as_ref(), &expected);
    }

    fn test_case_sensitive_regexp_instr_array_start<A>()
    where
        A: From<Vec<&'static str>> + Array + 'static,
    {
        let values = A::from(vec!["abcabcabc", "abcabcabc", ""]);
        let regex = A::from(vec!["abc", "abc", "gg"]);
        let start = Int64Array::from(vec![4, 5, 5]);
        let expected = Int64Array::from(vec![4, 7, 0]);

        let re = regexp_instr_func(&[Arc::new(values), Arc::new(regex), Arc::new(start)])
            .unwrap();
        assert_eq!(re.as_ref(), &expected);
    }

    fn test_case_sensitive_regexp_instr_array_nth<A>()
    where
        A: From<Vec<&'static str>> + Array + 'static,
    {
        let values = A::from(vec!["abcabcabc", "abcabcabc", "abcabcabc", "abcabcabc"]);
        let regex = A::from(vec!["abc", "abc", "abc", "abc"]);
        let start = Int64Array::from(vec![1, 1, 1, 1]);
        let nth = Int64Array::from(vec![1, 2, 3, 4]);
        let expected = Int64Array::from(vec![1, 4, 7, 0]);

        let re = regexp_instr_func(&[
            Arc::new(values),
            Arc::new(regex),
            Arc::new(start),
            Arc::new(nth),
        ])
        .unwrap();
        assert_eq!(re.as_ref(), &expected);
    }
}