1use arrow::array::{
19 Array, ArrayRef, AsArray, Datum, Int64Array, PrimitiveArray, StringArrayType,
20};
21use arrow::datatypes::{DataType, Int64Type};
22use arrow::datatypes::{
23 DataType::Int64, DataType::LargeUtf8, DataType::Utf8, DataType::Utf8View,
24};
25use arrow::error::ArrowError;
26use datafusion_common::{exec_err, internal_err, Result, ScalarValue};
27use datafusion_expr::{
28 ColumnarValue, Documentation, ScalarUDFImpl, Signature, TypeSignature::Exact,
29 TypeSignature::Uniform, Volatility,
30};
31use datafusion_macros::user_doc;
32use itertools::izip;
33use regex::Regex;
34use std::collections::HashMap;
35use std::sync::Arc;
36
37use crate::regex::compile_and_cache_regex;
38
#[user_doc(
    doc_section(label = "Regular Expression Functions"),
    description = "Returns the position in a string where the specified occurrence of a POSIX regular expression is located.",
    syntax_example = "regexp_instr(str, regexp[, start[, N[, flags[, subexpr]]]])",
    sql_example = r#"```sql
> SELECT regexp_instr('ABCDEF', 'C(.)(..)');
+---------------------------------------------------------------+
| regexp_instr(Utf8("ABCDEF"),Utf8("C(.)(..)")) |
+---------------------------------------------------------------+
| 3 |
+---------------------------------------------------------------+
```"#,
    standard_argument(name = "str", prefix = "String"),
    standard_argument(name = "regexp", prefix = "Regular"),
    argument(
        name = "start",
        description = "- **start**: Optional start position (the first position is 1) to search for the regular expression. Can be a constant, column, or function. Defaults to 1"
    ),
    argument(
        name = "N",
        description = "- **N**: Optional The N-th occurrence of pattern to find. Defaults to 1 (first match). Can be a constant, column, or function."
    ),
    argument(
        name = "flags",
        description = r#"Optional regular expression flags that control the behavior of the regular expression. The following flags are supported:
- **i**: case-insensitive: letters match both upper and lower case
- **m**: multi-line mode: ^ and $ match begin/end of line
- **s**: allow . to match \n
- **R**: enables CRLF mode: when multi-line mode is enabled, \r\n is used
- **U**: swap the meaning of x* and x*?"#
    ),
    argument(
        name = "subexpr",
        description = "Optional Specifies which capture group (subexpression) to return the position for. Defaults to 0, which returns the position of the entire match."
    )
)]
/// Scalar UDF implementing `regexp_instr`.
///
/// The user-facing SQL documentation is generated from the `user_doc` attribute above;
/// evaluation is delegated to [`regexp_instr_func`].
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct RegexpInstrFunc {
    /// Accepted argument type combinations, built in `RegexpInstrFunc::new`.
    signature: Signature,
}
79
80impl Default for RegexpInstrFunc {
81 fn default() -> Self {
82 Self::new()
83 }
84}
85
86impl RegexpInstrFunc {
87 pub fn new() -> Self {
88 Self {
89 signature: Signature::one_of(
90 vec![
91 Uniform(2, vec![Utf8View, LargeUtf8, Utf8]),
92 Exact(vec![Utf8View, Utf8View, Int64]),
93 Exact(vec![LargeUtf8, LargeUtf8, Int64]),
94 Exact(vec![Utf8, Utf8, Int64]),
95 Exact(vec![Utf8View, Utf8View, Int64, Int64]),
96 Exact(vec![LargeUtf8, LargeUtf8, Int64, Int64]),
97 Exact(vec![Utf8, Utf8, Int64, Int64]),
98 Exact(vec![Utf8View, Utf8View, Int64, Int64, Utf8View]),
99 Exact(vec![LargeUtf8, LargeUtf8, Int64, Int64, LargeUtf8]),
100 Exact(vec![Utf8, Utf8, Int64, Int64, Utf8]),
101 Exact(vec![Utf8View, Utf8View, Int64, Int64, Utf8View, Int64]),
102 Exact(vec![LargeUtf8, LargeUtf8, Int64, Int64, LargeUtf8, Int64]),
103 Exact(vec![Utf8, Utf8, Int64, Int64, Utf8, Int64]),
104 ],
105 Volatility::Immutable,
106 ),
107 }
108 }
109}
110
111impl ScalarUDFImpl for RegexpInstrFunc {
112 fn as_any(&self) -> &dyn std::any::Any {
113 self
114 }
115
116 fn name(&self) -> &str {
117 "regexp_instr"
118 }
119
120 fn signature(&self) -> &Signature {
121 &self.signature
122 }
123
124 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
125 Ok(Int64)
126 }
127
128 fn invoke_with_args(
129 &self,
130 args: datafusion_expr::ScalarFunctionArgs,
131 ) -> Result<ColumnarValue> {
132 let args = &args.args;
133
134 let len = args
135 .iter()
136 .fold(Option::<usize>::None, |acc, arg| match arg {
137 ColumnarValue::Scalar(_) => acc,
138 ColumnarValue::Array(a) => Some(a.len()),
139 });
140
141 let is_scalar = len.is_none();
142 let inferred_length = len.unwrap_or(1);
143 let args = args
144 .iter()
145 .map(|arg| arg.to_array(inferred_length))
146 .collect::<Result<Vec<_>>>()?;
147
148 let result = regexp_instr_func(&args);
149 if is_scalar {
150 let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0));
152 result.map(ColumnarValue::Scalar)
153 } else {
154 result.map(ColumnarValue::Array)
155 }
156 }
157
158 fn documentation(&self) -> Option<&Documentation> {
159 self.doc()
160 }
161}
162
163pub fn regexp_instr_func(args: &[ArrayRef]) -> Result<ArrayRef> {
164 let args_len = args.len();
165 if !(2..=6).contains(&args_len) {
166 return exec_err!("regexp_instr was called with {args_len} arguments. It requires at least 2 and at most 6.");
167 }
168
169 let values = &args[0];
170 match values.data_type() {
171 Utf8 | LargeUtf8 | Utf8View => (),
172 other => {
173 return internal_err!(
174 "Unsupported data type {other:?} for function regexp_instr"
175 );
176 }
177 }
178
179 regexp_instr(
180 values,
181 &args[1],
182 if args_len > 2 { Some(&args[2]) } else { None },
183 if args_len > 3 { Some(&args[3]) } else { None },
184 if args_len > 4 { Some(&args[4]) } else { None },
185 if args_len > 5 { Some(&args[5]) } else { None },
186 )
187 .map_err(|e| e.into())
188}
189
190pub fn regexp_instr(
209 values: &dyn Array,
210 regex_array: &dyn Datum,
211 start_array: Option<&dyn Datum>,
212 nth_array: Option<&dyn Datum>,
213 flags_array: Option<&dyn Datum>,
214 subexpr_array: Option<&dyn Datum>,
215) -> Result<ArrayRef, ArrowError> {
216 let (regex_array, _) = regex_array.get();
217 let start_array = start_array.map(|start| {
218 let (start, _) = start.get();
219 start
220 });
221 let nth_array = nth_array.map(|nth| {
222 let (nth, _) = nth.get();
223 nth
224 });
225 let flags_array = flags_array.map(|flags| {
226 let (flags, _) = flags.get();
227 flags
228 });
229 let subexpr_array = subexpr_array.map(|subexpr| {
230 let (subexpr, _) = subexpr.get();
231 subexpr
232 });
233
234 match (values.data_type(), regex_array.data_type(), flags_array) {
235 (Utf8, Utf8, None) => regexp_instr_inner(
236 values.as_string::<i32>(),
237 regex_array.as_string::<i32>(),
238 start_array.map(|start| start.as_primitive::<Int64Type>()),
239 nth_array.map(|nth| nth.as_primitive::<Int64Type>()),
240 None,
241 subexpr_array.map(|subexpr| subexpr.as_primitive::<Int64Type>()),
242 ),
243 (Utf8, Utf8, Some(flags_array)) if *flags_array.data_type() == Utf8 => regexp_instr_inner(
244 values.as_string::<i32>(),
245 regex_array.as_string::<i32>(),
246 start_array.map(|start| start.as_primitive::<Int64Type>()),
247 nth_array.map(|nth| nth.as_primitive::<Int64Type>()),
248 Some(flags_array.as_string::<i32>()),
249 subexpr_array.map(|subexpr| subexpr.as_primitive::<Int64Type>()),
250 ),
251 (LargeUtf8, LargeUtf8, None) => regexp_instr_inner(
252 values.as_string::<i64>(),
253 regex_array.as_string::<i64>(),
254 start_array.map(|start| start.as_primitive::<Int64Type>()),
255 nth_array.map(|nth| nth.as_primitive::<Int64Type>()),
256 None,
257 subexpr_array.map(|subexpr| subexpr.as_primitive::<Int64Type>()),
258 ),
259 (LargeUtf8, LargeUtf8, Some(flags_array)) if *flags_array.data_type() == LargeUtf8 => regexp_instr_inner(
260 values.as_string::<i64>(),
261 regex_array.as_string::<i64>(),
262 start_array.map(|start| start.as_primitive::<Int64Type>()),
263 nth_array.map(|nth| nth.as_primitive::<Int64Type>()),
264 Some(flags_array.as_string::<i64>()),
265 subexpr_array.map(|subexpr| subexpr.as_primitive::<Int64Type>()),
266 ),
267 (Utf8View, Utf8View, None) => regexp_instr_inner(
268 values.as_string_view(),
269 regex_array.as_string_view(),
270 start_array.map(|start| start.as_primitive::<Int64Type>()),
271 nth_array.map(|nth| nth.as_primitive::<Int64Type>()),
272 None,
273 subexpr_array.map(|subexpr| subexpr.as_primitive::<Int64Type>()),
274 ),
275 (Utf8View, Utf8View, Some(flags_array)) if *flags_array.data_type() == Utf8View => regexp_instr_inner(
276 values.as_string_view(),
277 regex_array.as_string_view(),
278 start_array.map(|start| start.as_primitive::<Int64Type>()),
279 nth_array.map(|nth| nth.as_primitive::<Int64Type>()),
280 Some(flags_array.as_string_view()),
281 subexpr_array.map(|subexpr| subexpr.as_primitive::<Int64Type>()),
282 ),
283 _ => Err(ArrowError::ComputeError(
284 "regexp_instr() expected the input arrays to be of type Utf8, LargeUtf8, or Utf8View and the data types of the values, regex_array, and flags_array to match".to_string(),
285 )),
286 }
287}
288
289#[allow(clippy::too_many_arguments)]
290pub fn regexp_instr_inner<'a, S>(
291 values: S,
292 regex_array: S,
293 start_array: Option<&Int64Array>,
294 nth_array: Option<&Int64Array>,
295 flags_array: Option<S>,
296 subexp_array: Option<&Int64Array>,
297) -> Result<ArrayRef, ArrowError>
298where
299 S: StringArrayType<'a>,
300{
301 let len = values.len();
302
303 let default_start_array = PrimitiveArray::<Int64Type>::from(vec![1; len]);
304 let start_array = start_array.unwrap_or(&default_start_array);
305 let start_input: Vec<i64> = (0..start_array.len())
306 .map(|i| start_array.value(i)) .collect();
308
309 let default_nth_array = PrimitiveArray::<Int64Type>::from(vec![1; len]);
310 let nth_array = nth_array.unwrap_or(&default_nth_array);
311 let nth_input: Vec<i64> = (0..nth_array.len())
312 .map(|i| nth_array.value(i)) .collect();
314
315 let flags_input = match flags_array {
316 Some(flags) => flags.iter().collect(),
317 None => vec![None; len],
318 };
319
320 let default_subexp_array = PrimitiveArray::<Int64Type>::from(vec![0; len]);
321 let subexp_array = subexp_array.unwrap_or(&default_subexp_array);
322 let subexp_input: Vec<i64> = (0..subexp_array.len())
323 .map(|i| subexp_array.value(i)) .collect();
325
326 let mut regex_cache = HashMap::new();
327
328 let result: Result<Vec<Option<i64>>, ArrowError> = izip!(
329 values.iter(),
330 regex_array.iter(),
331 start_input.iter(),
332 nth_input.iter(),
333 flags_input.iter(),
334 subexp_input.iter()
335 )
336 .map(|(value, regex, start, nth, flags, subexp)| match regex {
337 None => Ok(None),
338 Some("") => Ok(Some(0)),
339 Some(regex) => get_index(
340 value,
341 regex,
342 *start,
343 *nth,
344 *subexp,
345 *flags,
346 &mut regex_cache,
347 ),
348 })
349 .collect();
350 Ok(Arc::new(Int64Array::from(result?)))
351}
352
353fn handle_subexp(
354 pattern: &Regex,
355 search_slice: &str,
356 subexpr: i64,
357 value: &str,
358 byte_start_offset: usize,
359) -> Result<Option<i64>, ArrowError> {
360 if let Some(captures) = pattern.captures(search_slice) {
361 if let Some(matched) = captures.get(subexpr as usize) {
362 let start_char_offset =
365 value[..byte_start_offset + matched.start()].chars().count() as i64 + 1;
366 return Ok(Some(start_char_offset));
367 }
368 }
369 Ok(Some(0)) }
371
372fn get_nth_match(
373 pattern: &Regex,
374 search_slice: &str,
375 n: i64,
376 byte_start_offset: usize,
377 value: &str,
378) -> Result<Option<i64>, ArrowError> {
379 if let Some(mat) = pattern.find_iter(search_slice).nth((n - 1) as usize) {
380 let match_start_byte_offset = byte_start_offset + mat.start();
383 let match_start_char_offset =
384 value[..match_start_byte_offset].chars().count() as i64 + 1;
385 Ok(Some(match_start_char_offset))
386 } else {
387 Ok(Some(0)) }
389}
390fn get_index<'strings, 'cache>(
391 value: Option<&str>,
392 pattern: &'strings str,
393 start: i64,
394 n: i64,
395 subexpr: i64,
396 flags: Option<&'strings str>,
397 regex_cache: &'cache mut HashMap<(&'strings str, Option<&'strings str>), Regex>,
398) -> Result<Option<i64>, ArrowError>
399where
400 'strings: 'cache,
401{
402 let value = match value {
403 None => return Ok(None),
404 Some("") => return Ok(Some(0)),
405 Some(value) => value,
406 };
407 let pattern: &Regex = compile_and_cache_regex(pattern, flags, regex_cache)?;
408 if start < 1 {
410 return Err(ArrowError::ComputeError(
411 "regexp_instr() requires start to be 1-based".to_string(),
412 ));
413 }
414
415 if n < 1 {
416 return Err(ArrowError::ComputeError(
417 "N must be 1 or greater".to_string(),
418 ));
419 }
420
421 let total_chars = value.chars().count() as i64;
423 let byte_start_offset: usize = if start > total_chars {
424 return Ok(Some(0));
427 } else {
428 value
430 .char_indices()
431 .nth((start - 1) as usize)
432 .map(|(idx, _)| idx)
433 .unwrap_or(0) };
435 let search_slice = &value[byte_start_offset..];
438
439 if subexpr > 0 {
441 return handle_subexp(pattern, search_slice, subexpr, value, byte_start_offset);
442 }
443
444 get_nth_match(pattern, search_slice, n, byte_start_offset, value)
446}
447
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::Int64Array;
    use arrow::array::{GenericStringArray, StringViewArray};
    use arrow::datatypes::Field;
    use datafusion_common::config::ConfigOptions;
    use datafusion_expr::ScalarFunctionArgs;

    #[test]
    fn test_regexp_instr() {
        test_case_sensitive_regexp_instr_nulls();
        test_case_sensitive_regexp_instr_scalar();
        test_case_sensitive_regexp_instr_scalar_start();
        test_case_sensitive_regexp_instr_scalar_nth();
        test_case_sensitive_regexp_instr_scalar_subexp();

        test_case_sensitive_regexp_instr_array::<GenericStringArray<i32>>();
        test_case_sensitive_regexp_instr_array::<GenericStringArray<i64>>();
        test_case_sensitive_regexp_instr_array::<StringViewArray>();

        test_case_sensitive_regexp_instr_array_start::<GenericStringArray<i32>>();
        test_case_sensitive_regexp_instr_array_start::<GenericStringArray<i64>>();
        test_case_sensitive_regexp_instr_array_start::<StringViewArray>();

        test_case_sensitive_regexp_instr_array_nth::<GenericStringArray<i32>>();
        test_case_sensitive_regexp_instr_array_nth::<GenericStringArray<i64>>();
        test_case_sensitive_regexp_instr_array_nth::<StringViewArray>();
    }

    /// Invokes the UDF with all-scalar arguments, i.e. as a single-row call.
    fn regexp_instr_with_scalar_values(args: &[ScalarValue]) -> Result<ColumnarValue> {
        let args_values: Vec<ColumnarValue> = args
            .iter()
            .map(|sv| ColumnarValue::Scalar(sv.clone()))
            .collect();

        let arg_fields = args
            .iter()
            .enumerate()
            .map(|(idx, a)| {
                Arc::new(Field::new(format!("arg_{idx}"), a.data_type(), true))
            })
            .collect::<Vec<_>>();

        RegexpInstrFunc::new().invoke_with_args(ScalarFunctionArgs {
            args: args_values,
            arg_fields,
            // Every argument is a scalar, so the call evaluates exactly one row.
            number_rows: 1,
            return_field: Arc::new(Field::new("f", Int64, true)),
            config_options: Arc::new(ConfigOptions::default()),
        })
    }

    /// Wraps `value` in a scalar of the given string `data_type`.
    fn string_scalar(data_type: &DataType, value: Option<&str>) -> ScalarValue {
        let value = value.map(str::to_string);
        match data_type {
            DataType::Utf8 => ScalarValue::Utf8(value),
            DataType::LargeUtf8 => ScalarValue::LargeUtf8(value),
            DataType::Utf8View => ScalarValue::Utf8View(value),
            other => panic!("{other:?} is not a string type"),
        }
    }

    /// Builds the arguments with `make_args` once per string type and checks that each
    /// scalar invocation returns `expected`.
    fn assert_scalar_for_all_string_types(
        make_args: impl Fn(&DataType) -> Vec<ScalarValue>,
        expected: Option<i64>,
    ) {
        for data_type in [DataType::Utf8, DataType::LargeUtf8, DataType::Utf8View] {
            match regexp_instr_with_scalar_values(&make_args(&data_type)) {
                Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => {
                    assert_eq!(
                        v, expected,
                        "regexp_instr scalar test failed for {data_type:?}"
                    );
                }
                other => panic!("Unexpected result for {data_type:?}: {other:?}"),
            }
        }
    }

    fn test_case_sensitive_regexp_instr_nulls() {
        // An empty input with an empty pattern yields 0, not NULL.
        assert_scalar_for_all_string_types(
            |t| vec![string_scalar(t, Some("")), string_scalar(t, Some(""))],
            Some(0),
        );
        // A NULL input string yields NULL.
        assert_scalar_for_all_string_types(
            |t| vec![string_scalar(t, None), string_scalar(t, Some("a"))],
            None,
        );
        // A NULL pattern yields NULL.
        assert_scalar_for_all_string_types(
            |t| vec![string_scalar(t, Some("abc")), string_scalar(t, None)],
            None,
        );
    }

    fn test_case_sensitive_regexp_instr_scalar() {
        let cases = [
            ("hello world", "o", 5),
            ("abcdefg", "d", 4),
            ("xyz123xyz", "123", 4),
            ("no match here", "z", 0),
            ("abc", "gg", 0),
            // Positions are counted in characters, not bytes.
            ("ДатаФусион数据融合📊🔥", "📊", 15),
        ];

        for (value, regex, expected) in cases {
            assert_scalar_for_all_string_types(
                |t| vec![string_scalar(t, Some(value)), string_scalar(t, Some(regex))],
                Some(expected),
            );
        }
    }

    fn test_case_sensitive_regexp_instr_scalar_start() {
        let cases = [
            ("abcabcabc", "abc", 4, 4),
            ("abcabcabc", "abc", 5, 7),
            ("", "gg", 5, 0),
        ];

        for (value, regex, start, expected) in cases {
            assert_scalar_for_all_string_types(
                |t| {
                    vec![
                        string_scalar(t, Some(value)),
                        string_scalar(t, Some(regex)),
                        ScalarValue::Int64(Some(start)),
                    ]
                },
                Some(expected),
            );
        }
    }

    fn test_case_sensitive_regexp_instr_scalar_nth() {
        let cases = [
            ("abcabcabc", "abc", 1, 1, 1),
            ("abcabcabc", "abc", 1, 2, 4),
            ("abcabcabc", "abc", 1, 3, 7),
            ("abcabcabc", "abc", 1, 4, 0),
        ];

        for (value, regex, start, nth, expected) in cases {
            assert_scalar_for_all_string_types(
                |t| {
                    vec![
                        string_scalar(t, Some(value)),
                        string_scalar(t, Some(regex)),
                        ScalarValue::Int64(Some(start)),
                        ScalarValue::Int64(Some(nth)),
                    ]
                },
                Some(expected),
            );
        }
    }

    fn test_case_sensitive_regexp_instr_scalar_subexp() {
        let cases = [("12 abc def ghi 34", "(abc) (def) (ghi)", 1, 1, "i", 2, 8)];

        for (value, regex, start, nth, flags, subexp, expected) in cases {
            assert_scalar_for_all_string_types(
                |t| {
                    vec![
                        string_scalar(t, Some(value)),
                        string_scalar(t, Some(regex)),
                        ScalarValue::Int64(Some(start)),
                        ScalarValue::Int64(Some(nth)),
                        string_scalar(t, Some(flags)),
                        ScalarValue::Int64(Some(subexp)),
                    ]
                },
                Some(expected),
            );
        }
    }

    fn test_case_sensitive_regexp_instr_array<A>()
    where
        A: From<Vec<&'static str>> + Array + 'static,
    {
        let values = A::from(vec![
            "hello world",
            "abcdefg",
            "xyz123xyz",
            "no match here",
            "",
        ]);
        let regex = A::from(vec!["o", "d", "123", "z", "gg"]);

        let expected = Int64Array::from(vec![5, 4, 4, 0, 0]);
        let re = regexp_instr_func(&[Arc::new(values), Arc::new(regex)]).unwrap();
        assert_eq!(re.as_ref(), &expected);
    }

    fn test_case_sensitive_regexp_instr_array_start<A>()
    where
        A: From<Vec<&'static str>> + Array + 'static,
    {
        let values = A::from(vec!["abcabcabc", "abcabcabc", ""]);
        let regex = A::from(vec!["abc", "abc", "gg"]);
        let start = Int64Array::from(vec![4, 5, 5]);
        let expected = Int64Array::from(vec![4, 7, 0]);

        let re = regexp_instr_func(&[Arc::new(values), Arc::new(regex), Arc::new(start)])
            .unwrap();
        assert_eq!(re.as_ref(), &expected);
    }

    fn test_case_sensitive_regexp_instr_array_nth<A>()
    where
        A: From<Vec<&'static str>> + Array + 'static,
    {
        let values = A::from(vec!["abcabcabc", "abcabcabc", "abcabcabc", "abcabcabc"]);
        let regex = A::from(vec!["abc", "abc", "abc", "abc"]);
        let start = Int64Array::from(vec![1, 1, 1, 1]);
        let nth = Int64Array::from(vec![1, 2, 3, 4]);
        let expected = Int64Array::from(vec![1, 4, 7, 0]);

        let re = regexp_instr_func(&[
            Arc::new(values),
            Arc::new(regex),
            Arc::new(start),
            Arc::new(nth),
        ])
        .unwrap();
        assert_eq!(re.as_ref(), &expected);
    }
}