1use arrow::array::{
19 Array, ArrayRef, AsArray, Datum, Int64Array, PrimitiveArray, StringArrayType,
20};
21use arrow::datatypes::{DataType, Int64Type};
22use arrow::datatypes::{
23 DataType::Int64, DataType::LargeUtf8, DataType::Utf8, DataType::Utf8View,
24};
25use arrow::error::ArrowError;
26use datafusion_common::{Result, ScalarValue, exec_err, internal_err};
27use datafusion_expr::{
28 ColumnarValue, Documentation, ScalarUDFImpl, Signature, TypeSignature::Exact,
29 TypeSignature::Uniform, Volatility,
30};
31use datafusion_macros::user_doc;
32use itertools::izip;
33use regex::Regex;
34use std::collections::HashMap;
35use std::sync::Arc;
36
37use crate::regex::compile_and_cache_regex;
38
39#[user_doc(
40 doc_section(label = "Regular Expression Functions"),
41 description = "Returns the position in a string where the specified occurrence of a POSIX regular expression is located.",
42 syntax_example = "regexp_instr(str, regexp[, start[, N[, flags[, subexpr]]]])",
43 sql_example = r#"```sql
44> SELECT regexp_instr('ABCDEF', 'C(.)(..)');
45+---------------------------------------------------------------+
46| regexp_instr(Utf8("ABCDEF"),Utf8("C(.)(..)")) |
47+---------------------------------------------------------------+
48| 3 |
49+---------------------------------------------------------------+
50```"#,
51 standard_argument(name = "str", prefix = "String"),
52 standard_argument(name = "regexp", prefix = "Regular"),
53 argument(
54 name = "start",
55 description = "- **start**: Optional start position (the first position is 1) to search for the regular expression. Can be a constant, column, or function. Defaults to 1"
56 ),
57 argument(
58 name = "N",
59 description = "- **N**: Optional The N-th occurrence of pattern to find. Defaults to 1 (first match). Can be a constant, column, or function."
60 ),
61 argument(
62 name = "flags",
63 description = r#"Optional regular expression flags that control the behavior of the regular expression. The following flags are supported:
64 - **i**: case-insensitive: letters match both upper and lower case
65 - **m**: multi-line mode: ^ and $ match begin/end of line
66 - **s**: allow . to match \n
67 - **R**: enables CRLF mode: when multi-line mode is enabled, \r\n is used
68 - **U**: swap the meaning of x* and x*?"#
69 ),
70 argument(
71 name = "subexpr",
72 description = "Optional Specifies which capture group (subexpression) to return the position for. Defaults to 0, which returns the position of the entire match."
73 )
74)]
75#[derive(Debug, PartialEq, Eq, Hash)]
76pub struct RegexpInstrFunc {
77 signature: Signature,
78}
79
80impl Default for RegexpInstrFunc {
81 fn default() -> Self {
82 Self::new()
83 }
84}
85
86impl RegexpInstrFunc {
87 pub fn new() -> Self {
88 Self {
89 signature: Signature::one_of(
90 vec![
91 Uniform(2, vec![Utf8View, LargeUtf8, Utf8]),
92 Exact(vec![Utf8View, Utf8View, Int64]),
93 Exact(vec![LargeUtf8, LargeUtf8, Int64]),
94 Exact(vec![Utf8, Utf8, Int64]),
95 Exact(vec![Utf8View, Utf8View, Int64, Int64]),
96 Exact(vec![LargeUtf8, LargeUtf8, Int64, Int64]),
97 Exact(vec![Utf8, Utf8, Int64, Int64]),
98 Exact(vec![Utf8View, Utf8View, Int64, Int64, Utf8View]),
99 Exact(vec![LargeUtf8, LargeUtf8, Int64, Int64, LargeUtf8]),
100 Exact(vec![Utf8, Utf8, Int64, Int64, Utf8]),
101 Exact(vec![Utf8View, Utf8View, Int64, Int64, Utf8View, Int64]),
102 Exact(vec![LargeUtf8, LargeUtf8, Int64, Int64, LargeUtf8, Int64]),
103 Exact(vec![Utf8, Utf8, Int64, Int64, Utf8, Int64]),
104 ],
105 Volatility::Immutable,
106 ),
107 }
108 }
109}
110
111impl ScalarUDFImpl for RegexpInstrFunc {
112 fn as_any(&self) -> &dyn std::any::Any {
113 self
114 }
115
116 fn name(&self) -> &str {
117 "regexp_instr"
118 }
119
120 fn signature(&self) -> &Signature {
121 &self.signature
122 }
123
124 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
125 Ok(Int64)
126 }
127
128 fn invoke_with_args(
129 &self,
130 args: datafusion_expr::ScalarFunctionArgs,
131 ) -> Result<ColumnarValue> {
132 let args = &args.args;
133
134 let len = args
135 .iter()
136 .fold(Option::<usize>::None, |acc, arg| match arg {
137 ColumnarValue::Scalar(_) => acc,
138 ColumnarValue::Array(a) => Some(a.len()),
139 });
140
141 let is_scalar = len.is_none();
142 let inferred_length = len.unwrap_or(1);
143 let args = args
144 .iter()
145 .map(|arg| arg.to_array(inferred_length))
146 .collect::<Result<Vec<_>>>()?;
147
148 let result = regexp_instr_func(&args);
149 if is_scalar {
150 let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0));
152 result.map(ColumnarValue::Scalar)
153 } else {
154 result.map(ColumnarValue::Array)
155 }
156 }
157
158 fn documentation(&self) -> Option<&Documentation> {
159 self.doc()
160 }
161}
162
163pub fn regexp_instr_func(args: &[ArrayRef]) -> Result<ArrayRef> {
164 let args_len = args.len();
165 if !(2..=6).contains(&args_len) {
166 return exec_err!(
167 "regexp_instr was called with {args_len} arguments. It requires at least 2 and at most 6."
168 );
169 }
170
171 let values = &args[0];
172 match values.data_type() {
173 Utf8 | LargeUtf8 | Utf8View => (),
174 other => {
175 return internal_err!(
176 "Unsupported data type {other:?} for function regexp_instr"
177 );
178 }
179 }
180
181 regexp_instr(
182 values,
183 &args[1],
184 if args_len > 2 { Some(&args[2]) } else { None },
185 if args_len > 3 { Some(&args[3]) } else { None },
186 if args_len > 4 { Some(&args[4]) } else { None },
187 if args_len > 5 { Some(&args[5]) } else { None },
188 )
189 .map_err(|e| e.into())
190}
191
192fn regexp_instr(
211 values: &dyn Array,
212 regex_array: &dyn Datum,
213 start_array: Option<&dyn Datum>,
214 nth_array: Option<&dyn Datum>,
215 flags_array: Option<&dyn Datum>,
216 subexpr_array: Option<&dyn Datum>,
217) -> Result<ArrayRef, ArrowError> {
218 let (regex_array, _) = regex_array.get();
219 let start_array = start_array.map(|start| {
220 let (start, _) = start.get();
221 start
222 });
223 let nth_array = nth_array.map(|nth| {
224 let (nth, _) = nth.get();
225 nth
226 });
227 let flags_array = flags_array.map(|flags| {
228 let (flags, _) = flags.get();
229 flags
230 });
231 let subexpr_array = subexpr_array.map(|subexpr| {
232 let (subexpr, _) = subexpr.get();
233 subexpr
234 });
235
236 match (values.data_type(), regex_array.data_type(), flags_array) {
237 (Utf8, Utf8, None) => regexp_instr_inner(
238 &values.as_string::<i32>(),
239 ®ex_array.as_string::<i32>(),
240 start_array.map(|start| start.as_primitive::<Int64Type>()),
241 nth_array.map(|nth| nth.as_primitive::<Int64Type>()),
242 None,
243 subexpr_array.map(|subexpr| subexpr.as_primitive::<Int64Type>()),
244 ),
245 (Utf8, Utf8, Some(flags_array)) if *flags_array.data_type() == Utf8 => regexp_instr_inner(
246 &values.as_string::<i32>(),
247 ®ex_array.as_string::<i32>(),
248 start_array.map(|start| start.as_primitive::<Int64Type>()),
249 nth_array.map(|nth| nth.as_primitive::<Int64Type>()),
250 Some(flags_array.as_string::<i32>()),
251 subexpr_array.map(|subexpr| subexpr.as_primitive::<Int64Type>()),
252 ),
253 (LargeUtf8, LargeUtf8, None) => regexp_instr_inner(
254 &values.as_string::<i64>(),
255 ®ex_array.as_string::<i64>(),
256 start_array.map(|start| start.as_primitive::<Int64Type>()),
257 nth_array.map(|nth| nth.as_primitive::<Int64Type>()),
258 None,
259 subexpr_array.map(|subexpr| subexpr.as_primitive::<Int64Type>()),
260 ),
261 (LargeUtf8, LargeUtf8, Some(flags_array)) if *flags_array.data_type() == LargeUtf8 => regexp_instr_inner(
262 &values.as_string::<i64>(),
263 ®ex_array.as_string::<i64>(),
264 start_array.map(|start| start.as_primitive::<Int64Type>()),
265 nth_array.map(|nth| nth.as_primitive::<Int64Type>()),
266 Some(flags_array.as_string::<i64>()),
267 subexpr_array.map(|subexpr| subexpr.as_primitive::<Int64Type>()),
268 ),
269 (Utf8View, Utf8View, None) => regexp_instr_inner(
270 &values.as_string_view(),
271 ®ex_array.as_string_view(),
272 start_array.map(|start| start.as_primitive::<Int64Type>()),
273 nth_array.map(|nth| nth.as_primitive::<Int64Type>()),
274 None,
275 subexpr_array.map(|subexpr| subexpr.as_primitive::<Int64Type>()),
276 ),
277 (Utf8View, Utf8View, Some(flags_array)) if *flags_array.data_type() == Utf8View => regexp_instr_inner(
278 &values.as_string_view(),
279 ®ex_array.as_string_view(),
280 start_array.map(|start| start.as_primitive::<Int64Type>()),
281 nth_array.map(|nth| nth.as_primitive::<Int64Type>()),
282 Some(flags_array.as_string_view()),
283 subexpr_array.map(|subexpr| subexpr.as_primitive::<Int64Type>()),
284 ),
285 _ => Err(ArrowError::ComputeError(
286 "regexp_instr() expected the input arrays to be of type Utf8, LargeUtf8, or Utf8View and the data types of the values, regex_array, and flags_array to match".to_string(),
287 )),
288 }
289}
290
291fn regexp_instr_inner<'a, S>(
292 values: &S,
293 regex_array: &S,
294 start_array: Option<&Int64Array>,
295 nth_array: Option<&Int64Array>,
296 flags_array: Option<S>,
297 subexp_array: Option<&Int64Array>,
298) -> Result<ArrayRef, ArrowError>
299where
300 S: StringArrayType<'a>,
301{
302 let len = values.len();
303
304 let default_start_array = PrimitiveArray::<Int64Type>::from(vec![1; len]);
305 let start_array = start_array.unwrap_or(&default_start_array);
306 let start_input: Vec<i64> = (0..start_array.len())
307 .map(|i| start_array.value(i)) .collect();
309
310 let default_nth_array = PrimitiveArray::<Int64Type>::from(vec![1; len]);
311 let nth_array = nth_array.unwrap_or(&default_nth_array);
312 let nth_input: Vec<i64> = (0..nth_array.len())
313 .map(|i| nth_array.value(i)) .collect();
315
316 let flags_input = match flags_array {
317 Some(flags) => flags.iter().collect(),
318 None => vec![None; len],
319 };
320
321 let default_subexp_array = PrimitiveArray::<Int64Type>::from(vec![0; len]);
322 let subexp_array = subexp_array.unwrap_or(&default_subexp_array);
323 let subexp_input: Vec<i64> = (0..subexp_array.len())
324 .map(|i| subexp_array.value(i)) .collect();
326
327 let mut regex_cache = HashMap::new();
328
329 let result: Result<Vec<Option<i64>>, ArrowError> = izip!(
330 values.iter(),
331 regex_array.iter(),
332 start_input.iter(),
333 nth_input.iter(),
334 flags_input.iter(),
335 subexp_input.iter()
336 )
337 .map(|(value, regex, start, nth, flags, subexp)| match regex {
338 None => Ok(None),
339 Some("") => Ok(Some(0)),
340 Some(regex) => get_index(
341 value,
342 regex,
343 *start,
344 *nth,
345 *subexp,
346 *flags,
347 &mut regex_cache,
348 ),
349 })
350 .collect();
351 Ok(Arc::new(Int64Array::from(result?)))
352}
353
354fn handle_subexp(
355 pattern: &Regex,
356 search_slice: &str,
357 subexpr: i64,
358 value: &str,
359 byte_start_offset: usize,
360) -> Result<Option<i64>, ArrowError> {
361 if let Some(captures) = pattern.captures(search_slice)
362 && let Some(matched) = captures.get(subexpr as usize)
363 {
364 let start_char_offset =
367 value[..byte_start_offset + matched.start()].chars().count() as i64 + 1;
368 return Ok(Some(start_char_offset));
369 }
370 Ok(Some(0)) }
372
373fn get_nth_match(
374 pattern: &Regex,
375 search_slice: &str,
376 n: i64,
377 byte_start_offset: usize,
378 value: &str,
379) -> Result<Option<i64>, ArrowError> {
380 if let Some(mat) = pattern.find_iter(search_slice).nth((n - 1) as usize) {
381 let match_start_byte_offset = byte_start_offset + mat.start();
384 let match_start_char_offset =
385 value[..match_start_byte_offset].chars().count() as i64 + 1;
386 Ok(Some(match_start_char_offset))
387 } else {
388 Ok(Some(0)) }
390}
391fn get_index<'strings, 'cache>(
392 value: Option<&str>,
393 pattern: &'strings str,
394 start: i64,
395 n: i64,
396 subexpr: i64,
397 flags: Option<&'strings str>,
398 regex_cache: &'cache mut HashMap<(&'strings str, Option<&'strings str>), Regex>,
399) -> Result<Option<i64>, ArrowError>
400where
401 'strings: 'cache,
402{
403 let value = match value {
404 None => return Ok(None),
405 Some("") => return Ok(Some(0)),
406 Some(value) => value,
407 };
408 let pattern: &Regex = compile_and_cache_regex(pattern, flags, regex_cache)?;
409 if start < 1 {
411 return Err(ArrowError::ComputeError(
412 "regexp_instr() requires start to be 1-based".to_string(),
413 ));
414 }
415
416 if n < 1 {
417 return Err(ArrowError::ComputeError(
418 "N must be 1 or greater".to_string(),
419 ));
420 }
421
422 let total_chars = value.chars().count() as i64;
424 let byte_start_offset: usize = if start > total_chars {
425 return Ok(Some(0));
428 } else {
429 value
431 .char_indices()
432 .nth((start - 1) as usize)
433 .map(|(idx, _)| idx)
434 .unwrap_or(0) };
436 let search_slice = &value[byte_start_offset..];
439
440 if subexpr > 0 {
442 return handle_subexp(pattern, search_slice, subexpr, value, byte_start_offset);
443 }
444
445 get_nth_match(pattern, search_slice, n, byte_start_offset, value)
447}
448
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::Int64Array;
    use arrow::array::{GenericStringArray, StringViewArray};
    use arrow::datatypes::Field;
    use datafusion_common::config::ConfigOptions;
    use datafusion_expr::ScalarFunctionArgs;

    /// Constructor for one flavour of string scalar.
    type StringScalar = fn(Option<String>) -> ScalarValue;

    /// The string scalar flavours every scalar case is run with, in order.
    const STRING_SCALARS: [StringScalar; 3] = [
        ScalarValue::Utf8,
        ScalarValue::LargeUtf8,
        ScalarValue::Utf8View,
    ];

    /// Single entry point that drives all scenarios below.
    #[test]
    fn test_regexp_instr() {
        test_case_sensitive_regexp_instr_nulls();
        test_case_sensitive_regexp_instr_scalar();
        test_case_sensitive_regexp_instr_scalar_start();
        test_case_sensitive_regexp_instr_scalar_nth();
        test_case_sensitive_regexp_instr_scalar_subexp();

        test_case_sensitive_regexp_instr_array::<GenericStringArray<i32>>();
        test_case_sensitive_regexp_instr_array::<GenericStringArray<i64>>();
        test_case_sensitive_regexp_instr_array::<StringViewArray>();

        test_case_sensitive_regexp_instr_array_start::<GenericStringArray<i32>>();
        test_case_sensitive_regexp_instr_array_start::<GenericStringArray<i64>>();
        test_case_sensitive_regexp_instr_array_start::<StringViewArray>();

        test_case_sensitive_regexp_instr_array_nth::<GenericStringArray<i32>>();
        test_case_sensitive_regexp_instr_array_nth::<GenericStringArray<i64>>();
        test_case_sensitive_regexp_instr_array_nth::<StringViewArray>();
    }

    /// Invokes the UDF with all-scalar arguments.
    fn regexp_instr_with_scalar_values(args: &[ScalarValue]) -> Result<ColumnarValue> {
        let args_values: Vec<ColumnarValue> =
            args.iter().cloned().map(ColumnarValue::Scalar).collect();

        let arg_fields: Vec<_> = args
            .iter()
            .enumerate()
            .map(|(idx, arg)| {
                Arc::new(Field::new(format!("arg_{idx}"), arg.data_type(), true))
            })
            .collect();

        let call = ScalarFunctionArgs {
            args: args_values,
            arg_fields,
            number_rows: args.len(),
            return_field: Arc::new(Field::new("f", Int64, true)),
            config_options: Arc::new(ConfigOptions::default()),
        };
        RegexpInstrFunc::new().invoke_with_args(call)
    }

    /// Asserts that a scalar invocation yields `Int64(Some(expected))`.
    fn assert_scalar_result(args: &[ScalarValue], expected: i64) {
        match regexp_instr_with_scalar_values(args) {
            Ok(ColumnarValue::Scalar(ScalarValue::Int64(actual))) => {
                assert_eq!(actual, Some(expected), "regexp_instr scalar test failed");
            }
            _ => panic!("Unexpected result"),
        }
    }

    /// Asserts that an array invocation yields `expected`.
    fn assert_array_result(args: &[ArrayRef], expected: &Int64Array) {
        let actual = regexp_instr_func(args).unwrap();
        assert_eq!(actual.as_ref(), expected);
    }

    /// Wraps `text` in the given string scalar flavour.
    fn text(make: StringScalar, text: &str) -> ScalarValue {
        make(Some(text.to_string()))
    }

    /// Wraps `value` in an `Int64` scalar.
    fn int(value: i64) -> ScalarValue {
        ScalarValue::Int64(Some(value))
    }

    // Despite its name this scenario exercises empty strings, not SQL NULLs.
    fn test_case_sensitive_regexp_instr_nulls() {
        let value_sv = ScalarValue::from("".to_string());
        let regex_sv = ScalarValue::Utf8(Some("".to_string()));
        assert_scalar_result(&[value_sv, regex_sv], 0);
    }

    fn test_case_sensitive_regexp_instr_scalar() {
        // (value, regex, expected position)
        let cases: [(&str, &str, i64); 6] = [
            ("hello world", "o", 5),
            ("abcdefg", "d", 4),
            ("xyz123xyz", "123", 4),
            ("no match here", "z", 0),
            ("abc", "gg", 0),
            ("ДатаФусион数据融合📊🔥", "📊", 15),
        ];

        for (value, regex, expected) in cases {
            for make in STRING_SCALARS {
                assert_scalar_result(&[text(make, value), text(make, regex)], expected);
            }
        }
    }

    fn test_case_sensitive_regexp_instr_scalar_start() {
        // (value, regex, start, expected position)
        let cases: [(&str, &str, i64, i64); 3] = [
            ("abcabcabc", "abc", 4, 4),
            ("abcabcabc", "abc", 5, 7),
            ("", "gg", 5, 0),
        ];

        for (value, regex, start, expected) in cases {
            for make in STRING_SCALARS {
                assert_scalar_result(
                    &[text(make, value), text(make, regex), int(start)],
                    expected,
                );
            }
        }
    }

    fn test_case_sensitive_regexp_instr_scalar_nth() {
        // (value, regex, start, nth, expected position)
        let cases: [(&str, &str, i64, i64, i64); 4] = [
            ("abcabcabc", "abc", 1, 1, 1),
            ("abcabcabc", "abc", 1, 2, 4),
            ("abcabcabc", "abc", 1, 3, 7),
            ("abcabcabc", "abc", 1, 4, 0),
        ];

        for (value, regex, start, nth, expected) in cases {
            for make in STRING_SCALARS {
                assert_scalar_result(
                    &[text(make, value), text(make, regex), int(start), int(nth)],
                    expected,
                );
            }
        }
    }

    fn test_case_sensitive_regexp_instr_scalar_subexp() {
        // (value, regex, start, nth, flags, subexp, expected position)
        let cases: [(&str, &str, i64, i64, &str, i64, i64); 1] =
            [("12 abc def ghi 34", "(abc) (def) (ghi)", 1, 1, "i", 2, 8)];

        for (value, regex, start, nth, flags, subexp, expected) in cases {
            for make in STRING_SCALARS {
                assert_scalar_result(
                    &[
                        text(make, value),
                        text(make, regex),
                        int(start),
                        int(nth),
                        text(make, flags),
                        int(subexp),
                    ],
                    expected,
                );
            }
        }
    }

    fn test_case_sensitive_regexp_instr_array<A>()
    where
        A: From<Vec<&'static str>> + Array + 'static,
    {
        let values = A::from(vec![
            "hello world",
            "abcdefg",
            "xyz123xyz",
            "no match here",
            "",
        ]);
        let regex = A::from(vec!["o", "d", "123", "z", "gg"]);

        assert_array_result(
            &[Arc::new(values), Arc::new(regex)],
            &Int64Array::from(vec![5, 4, 4, 0, 0]),
        );
    }

    fn test_case_sensitive_regexp_instr_array_start<A>()
    where
        A: From<Vec<&'static str>> + Array + 'static,
    {
        let values = A::from(vec!["abcabcabc", "abcabcabc", ""]);
        let regex = A::from(vec!["abc", "abc", "gg"]);
        let start = Int64Array::from(vec![4, 5, 5]);

        assert_array_result(
            &[Arc::new(values), Arc::new(regex), Arc::new(start)],
            &Int64Array::from(vec![4, 7, 0]),
        );
    }

    fn test_case_sensitive_regexp_instr_array_nth<A>()
    where
        A: From<Vec<&'static str>> + Array + 'static,
    {
        let values = A::from(vec!["abcabcabc", "abcabcabc", "abcabcabc", "abcabcabc"]);
        let regex = A::from(vec!["abc", "abc", "abc", "abc"]);
        let start = Int64Array::from(vec![1, 1, 1, 1]);
        let nth = Int64Array::from(vec![1, 2, 3, 4]);

        assert_array_result(
            &[
                Arc::new(values),
                Arc::new(regex),
                Arc::new(start),
                Arc::new(nth),
            ],
            &Int64Array::from(vec![1, 4, 7, 0]),
        );
    }
}