1use crate::regex::{compile_and_cache_regex, compile_regex};
19use arrow::array::{Array, ArrayRef, AsArray, Datum, Int64Array, StringArrayType};
20use arrow::datatypes::{DataType, Int64Type};
21use arrow::datatypes::{
22 DataType::Int64, DataType::LargeUtf8, DataType::Utf8, DataType::Utf8View,
23};
24use arrow::error::ArrowError;
25use datafusion_common::{Result, ScalarValue, exec_err, internal_err};
26use datafusion_expr::{
27 ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
28 TypeSignature::Exact, TypeSignature::Uniform, Volatility,
29};
30use datafusion_macros::user_doc;
31use itertools::izip;
32use regex::Regex;
33use std::collections::HashMap;
34use std::sync::Arc;
35
// Scalar UDF implementing `regexp_count(str, regexp[, start, flags])`.
// The `user_doc` attribute carries the user-facing documentation; the
// `documentation()` method of the `ScalarUDFImpl` impl returns it through
// `self.doc()` (presumably generated by `user_doc` — confirm in datafusion_macros).
#[user_doc(
    doc_section(label = "Regular Expression Functions"),
    description = "Returns the number of matches that a [regular expression](https://docs.rs/regex/latest/regex/#syntax) has in a string.",
    syntax_example = "regexp_count(str, regexp[, start, flags])",
    sql_example = r#"```sql
> select regexp_count('abcAbAbc', 'abc', 2, 'i');
+---------------------------------------------------------------+
| regexp_count(Utf8("abcAbAbc"),Utf8("abc"),Int64(2),Utf8("i")) |
+---------------------------------------------------------------+
| 1 |
+---------------------------------------------------------------+
```"#,
    standard_argument(name = "str", prefix = "String"),
    standard_argument(name = "regexp", prefix = "Regular"),
    argument(
        name = "start",
        description = "- **start**: Optional start position (the first position is 1) to search for the regular expression. Can be a constant, column, or function."
    ),
    argument(
        name = "flags",
        description = r#"Optional regular expression flags that control the behavior of the regular expression. The following flags are supported:
 - **i**: case-insensitive: letters match both upper and lower case
 - **m**: multi-line mode: ^ and $ match begin/end of line
 - **s**: allow . to match \n
 - **R**: enables CRLF mode: when multi-line mode is enabled, \r\n is used
 - **U**: swap the meaning of x* and x*?"#
    )
)]
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct RegexpCountFunc {
    // Accepted argument type combinations; built in `RegexpCountFunc::new`.
    signature: Signature,
}
68
69impl Default for RegexpCountFunc {
70 fn default() -> Self {
71 Self::new()
72 }
73}
74
75impl RegexpCountFunc {
76 pub fn new() -> Self {
77 Self {
78 signature: Signature::one_of(
79 vec![
80 Uniform(2, vec![Utf8View, LargeUtf8, Utf8]),
81 Exact(vec![Utf8View, Utf8View, Int64]),
82 Exact(vec![LargeUtf8, LargeUtf8, Int64]),
83 Exact(vec![Utf8, Utf8, Int64]),
84 Exact(vec![Utf8View, Utf8View, Int64, Utf8View]),
85 Exact(vec![LargeUtf8, LargeUtf8, Int64, LargeUtf8]),
86 Exact(vec![Utf8, Utf8, Int64, Utf8]),
87 ],
88 Volatility::Immutable,
89 ),
90 }
91 }
92}
93
94impl ScalarUDFImpl for RegexpCountFunc {
95 fn name(&self) -> &str {
96 "regexp_count"
97 }
98
99 fn signature(&self) -> &Signature {
100 &self.signature
101 }
102
103 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
104 Ok(Int64)
105 }
106
107 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
108 let args = &args.args;
109
110 let len = args
111 .iter()
112 .fold(Option::<usize>::None, |acc, arg| match arg {
113 ColumnarValue::Scalar(_) => acc,
114 ColumnarValue::Array(a) => Some(a.len()),
115 });
116
117 let is_scalar = len.is_none();
118 let inferred_length = len.unwrap_or(1);
119 let args = args
120 .iter()
121 .map(|arg| arg.to_array(inferred_length))
122 .collect::<Result<Vec<_>>>()?;
123
124 let result = regexp_count_func(&args);
125 if is_scalar {
126 let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0));
128 result.map(ColumnarValue::Scalar)
129 } else {
130 result.map(ColumnarValue::Array)
131 }
132 }
133
134 fn documentation(&self) -> Option<&Documentation> {
135 self.doc()
136 }
137}
138
139pub fn regexp_count_func(args: &[ArrayRef]) -> Result<ArrayRef> {
140 let args_len = args.len();
141 if !(2..=4).contains(&args_len) {
142 return exec_err!(
143 "regexp_count was called with {args_len} arguments. It requires at least 2 and at most 4."
144 );
145 }
146
147 let values = &args[0];
148 match values.data_type() {
149 Utf8 | LargeUtf8 | Utf8View => (),
150 other => {
151 return internal_err!(
152 "Unsupported data type {other:?} for function regexp_count"
153 );
154 }
155 }
156
157 regexp_count(
158 values,
159 &args[1],
160 if args_len > 2 { Some(&args[2]) } else { None },
161 if args_len > 3 { Some(&args[3]) } else { None },
162 )
163 .map_err(|e| e.into())
164}
165
166fn regexp_count(
182 values: &dyn Array,
183 regex_array: &dyn Datum,
184 start_array: Option<&dyn Datum>,
185 flags_array: Option<&dyn Datum>,
186) -> Result<ArrayRef, ArrowError> {
187 let (regex_array, is_regex_scalar) = regex_array.get();
188 let (start_array, is_start_scalar) = start_array.map_or((None, true), |start| {
189 let (start, is_start_scalar) = start.get();
190 (Some(start), is_start_scalar)
191 });
192 let (flags_array, is_flags_scalar) = flags_array.map_or((None, true), |flags| {
193 let (flags, is_flags_scalar) = flags.get();
194 (Some(flags), is_flags_scalar)
195 });
196
197 match (values.data_type(), regex_array.data_type(), flags_array) {
198 (Utf8, Utf8, None) => regexp_count_inner(
199 &values.as_string::<i32>(),
200 ®ex_array.as_string::<i32>(),
201 is_regex_scalar,
202 start_array.map(|start| start.as_primitive::<Int64Type>()),
203 is_start_scalar,
204 None,
205 is_flags_scalar,
206 ),
207 (Utf8, Utf8, Some(flags_array)) if *flags_array.data_type() == Utf8 => regexp_count_inner(
208 &values.as_string::<i32>(),
209 ®ex_array.as_string::<i32>(),
210 is_regex_scalar,
211 start_array.map(|start| start.as_primitive::<Int64Type>()),
212 is_start_scalar,
213 Some(&flags_array.as_string::<i32>()),
214 is_flags_scalar,
215 ),
216 (LargeUtf8, LargeUtf8, None) => regexp_count_inner(
217 &values.as_string::<i64>(),
218 ®ex_array.as_string::<i64>(),
219 is_regex_scalar,
220 start_array.map(|start| start.as_primitive::<Int64Type>()),
221 is_start_scalar,
222 None,
223 is_flags_scalar,
224 ),
225 (LargeUtf8, LargeUtf8, Some(flags_array)) if *flags_array.data_type() == LargeUtf8 => regexp_count_inner(
226 &values.as_string::<i64>(),
227 ®ex_array.as_string::<i64>(),
228 is_regex_scalar,
229 start_array.map(|start| start.as_primitive::<Int64Type>()),
230 is_start_scalar,
231 Some(&flags_array.as_string::<i64>()),
232 is_flags_scalar,
233 ),
234 (Utf8View, Utf8View, None) => regexp_count_inner(
235 &values.as_string_view(),
236 ®ex_array.as_string_view(),
237 is_regex_scalar,
238 start_array.map(|start| start.as_primitive::<Int64Type>()),
239 is_start_scalar,
240 None,
241 is_flags_scalar,
242 ),
243 (Utf8View, Utf8View, Some(flags_array)) if *flags_array.data_type() == Utf8View => regexp_count_inner(
244 &values.as_string_view(),
245 ®ex_array.as_string_view(),
246 is_regex_scalar,
247 start_array.map(|start| start.as_primitive::<Int64Type>()),
248 is_start_scalar,
249 Some(&flags_array.as_string_view()),
250 is_flags_scalar,
251 ),
252 _ => Err(ArrowError::ComputeError(
253 "regexp_count() expected the input arrays to be of type Utf8, LargeUtf8, or Utf8View and the data types of the values, regex_array, and flags_array to match".to_string(),
254 )),
255 }
256}
257
258fn regexp_count_inner<'a, S>(
259 values: &S,
260 regex_array: &S,
261 is_regex_scalar: bool,
262 start_array: Option<&Int64Array>,
263 is_start_scalar: bool,
264 flags_array: Option<&S>,
265 is_flags_scalar: bool,
266) -> Result<ArrayRef, ArrowError>
267where
268 S: StringArrayType<'a>,
269{
270 let (regex_scalar, is_regex_scalar) = if is_regex_scalar || regex_array.len() == 1 {
271 (
272 (!regex_array.is_null(0)).then(|| regex_array.value(0)),
273 true,
274 )
275 } else {
276 (None, false)
277 };
278
279 let (start_array, start_scalar, is_start_scalar) =
280 if let Some(start_array) = start_array {
281 if is_start_scalar || start_array.len() == 1 {
282 (None, Some(start_array.value(0)), true)
283 } else {
284 (Some(start_array), None, false)
285 }
286 } else {
287 (None, Some(1), true)
288 };
289
290 let (flags_array, flags_scalar, is_flags_scalar) =
291 if let Some(flags_array) = flags_array {
292 if is_flags_scalar || flags_array.len() == 1 {
293 (None, Some(flags_array.value(0)), true)
294 } else {
295 (Some(flags_array), None, false)
296 }
297 } else {
298 (None, None, true)
299 };
300
301 let mut regex_cache = HashMap::new();
302
303 match (is_regex_scalar, is_start_scalar, is_flags_scalar) {
304 (true, true, true) => {
305 let regex = match regex_scalar {
306 None => {
307 return Ok(Arc::new(Int64Array::from(vec![0; values.len()])));
308 }
309 Some(regex) => regex,
310 };
311
312 let pattern = compile_regex(regex, flags_scalar)?;
313
314 Ok(Arc::new(
315 values
316 .iter()
317 .map(|value| count_matches(value, &pattern, start_scalar))
318 .collect::<Result<Int64Array, ArrowError>>()?,
319 ))
320 }
321 (true, true, false) => {
322 let regex = match regex_scalar {
323 None => {
324 return Ok(Arc::new(Int64Array::from(vec![0; values.len()])));
325 }
326 Some(regex) => regex,
327 };
328
329 let flags_array = flags_array.unwrap();
330 if values.len() != flags_array.len() {
331 return Err(ArrowError::ComputeError(format!(
332 "flags_array must be the same length as values array; got {} and {}",
333 flags_array.len(),
334 values.len(),
335 )));
336 }
337
338 Ok(Arc::new(
339 values
340 .iter()
341 .zip(flags_array.iter())
342 .map(|(value, flags)| {
343 let pattern =
344 compile_and_cache_regex(regex, flags, &mut regex_cache)?;
345 count_matches(value, pattern, start_scalar)
346 })
347 .collect::<Result<Int64Array, ArrowError>>()?,
348 ))
349 }
350 (true, false, true) => {
351 let regex = match regex_scalar {
352 None => {
353 return Ok(Arc::new(Int64Array::from(vec![0; values.len()])));
354 }
355 Some(regex) => regex,
356 };
357
358 let pattern = compile_regex(regex, flags_scalar)?;
359
360 let start_array = start_array.unwrap();
361
362 Ok(Arc::new(
363 values
364 .iter()
365 .zip(start_array.iter())
366 .map(|(value, start)| count_matches(value, &pattern, start))
367 .collect::<Result<Int64Array, ArrowError>>()?,
368 ))
369 }
370 (true, false, false) => {
371 let regex = match regex_scalar {
372 None => {
373 return Ok(Arc::new(Int64Array::from(vec![0; values.len()])));
374 }
375 Some(regex) => regex,
376 };
377
378 let flags_array = flags_array.unwrap();
379 if values.len() != flags_array.len() {
380 return Err(ArrowError::ComputeError(format!(
381 "flags_array must be the same length as values array; got {} and {}",
382 flags_array.len(),
383 values.len(),
384 )));
385 }
386
387 Ok(Arc::new(
388 izip!(
389 values.iter(),
390 start_array.unwrap().iter(),
391 flags_array.iter()
392 )
393 .map(|(value, start, flags)| {
394 let pattern =
395 compile_and_cache_regex(regex, flags, &mut regex_cache)?;
396
397 count_matches(value, pattern, start)
398 })
399 .collect::<Result<Int64Array, ArrowError>>()?,
400 ))
401 }
402 (false, true, true) => {
403 if values.len() != regex_array.len() {
404 return Err(ArrowError::ComputeError(format!(
405 "regex_array must be the same length as values array; got {} and {}",
406 regex_array.len(),
407 values.len(),
408 )));
409 }
410
411 Ok(Arc::new(
412 values
413 .iter()
414 .zip(regex_array.iter())
415 .map(|(value, regex)| {
416 let regex = match regex {
417 None => return Ok(0),
418 Some(regex) => regex,
419 };
420
421 let pattern = compile_and_cache_regex(
422 regex,
423 flags_scalar,
424 &mut regex_cache,
425 )?;
426 count_matches(value, pattern, start_scalar)
427 })
428 .collect::<Result<Int64Array, ArrowError>>()?,
429 ))
430 }
431 (false, true, false) => {
432 if values.len() != regex_array.len() {
433 return Err(ArrowError::ComputeError(format!(
434 "regex_array must be the same length as values array; got {} and {}",
435 regex_array.len(),
436 values.len(),
437 )));
438 }
439
440 let flags_array = flags_array.unwrap();
441 if values.len() != flags_array.len() {
442 return Err(ArrowError::ComputeError(format!(
443 "flags_array must be the same length as values array; got {} and {}",
444 flags_array.len(),
445 values.len(),
446 )));
447 }
448
449 Ok(Arc::new(
450 izip!(values.iter(), regex_array.iter(), flags_array.iter())
451 .map(|(value, regex, flags)| {
452 let regex = match regex {
453 None => return Ok(0),
454 Some(regex) => regex,
455 };
456
457 let pattern =
458 compile_and_cache_regex(regex, flags, &mut regex_cache)?;
459
460 count_matches(value, pattern, start_scalar)
461 })
462 .collect::<Result<Int64Array, ArrowError>>()?,
463 ))
464 }
465 (false, false, true) => {
466 if values.len() != regex_array.len() {
467 return Err(ArrowError::ComputeError(format!(
468 "regex_array must be the same length as values array; got {} and {}",
469 regex_array.len(),
470 values.len(),
471 )));
472 }
473
474 let start_array = start_array.unwrap();
475 if values.len() != start_array.len() {
476 return Err(ArrowError::ComputeError(format!(
477 "start_array must be the same length as values array; got {} and {}",
478 start_array.len(),
479 values.len(),
480 )));
481 }
482
483 Ok(Arc::new(
484 izip!(values.iter(), regex_array.iter(), start_array.iter())
485 .map(|(value, regex, start)| {
486 let regex = match regex {
487 None => return Ok(0),
488 Some(regex) => regex,
489 };
490
491 let pattern = compile_and_cache_regex(
492 regex,
493 flags_scalar,
494 &mut regex_cache,
495 )?;
496 count_matches(value, pattern, start)
497 })
498 .collect::<Result<Int64Array, ArrowError>>()?,
499 ))
500 }
501 (false, false, false) => {
502 if values.len() != regex_array.len() {
503 return Err(ArrowError::ComputeError(format!(
504 "regex_array must be the same length as values array; got {} and {}",
505 regex_array.len(),
506 values.len(),
507 )));
508 }
509
510 let start_array = start_array.unwrap();
511 if values.len() != start_array.len() {
512 return Err(ArrowError::ComputeError(format!(
513 "start_array must be the same length as values array; got {} and {}",
514 start_array.len(),
515 values.len(),
516 )));
517 }
518
519 let flags_array = flags_array.unwrap();
520 if values.len() != flags_array.len() {
521 return Err(ArrowError::ComputeError(format!(
522 "flags_array must be the same length as values array; got {} and {}",
523 flags_array.len(),
524 values.len(),
525 )));
526 }
527
528 Ok(Arc::new(
529 izip!(
530 values.iter(),
531 regex_array.iter(),
532 start_array.iter(),
533 flags_array.iter()
534 )
535 .map(|(value, regex, start, flags)| {
536 let regex = match regex {
537 None => return Ok(0),
538 Some(regex) => regex,
539 };
540
541 let pattern =
542 compile_and_cache_regex(regex, flags, &mut regex_cache)?;
543 count_matches(value, pattern, start)
544 })
545 .collect::<Result<Int64Array, ArrowError>>()?,
546 ))
547 }
548 }
549}
550
551fn count_matches(
552 value: Option<&str>,
553 pattern: &Regex,
554 start: Option<i64>,
555) -> Result<i64, ArrowError> {
556 let value = match value {
557 None => return Ok(0),
558 Some(value) => value,
559 };
560
561 if let Some(start) = start {
562 if start < 1 {
563 return Err(ArrowError::ComputeError(
564 "regexp_count() requires start to be 1 based".to_string(),
565 ));
566 }
567
568 let char_len = value.chars().count();
569 let start_index = (start as usize).saturating_sub(1);
570
571 if start_index > char_len {
572 return Ok(0);
573 }
574
575 let byte_offset = if start_index == char_len {
577 value.len()
578 } else {
579 value
580 .char_indices()
581 .nth(start_index)
582 .map(|(idx, _)| idx)
583 .unwrap_or(value.len())
584 };
585
586 let find_slice = &value[byte_offset..];
588 let count = pattern.find_iter(find_slice).count();
589 Ok(count as i64)
590 } else {
591 let count = pattern.find_iter(value).count();
592 Ok(count as i64)
593 }
594}
595
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{GenericStringArray, StringViewArray};
    use arrow::datatypes::Field;
    use datafusion_common::config::ConfigOptions;

    /// Constructors for every supported string scalar type. Scalar tests run
    /// once per constructor so all three string flavors are covered.
    const STRING_SCALARS: [fn(Option<String>) -> ScalarValue; 3] =
        [ScalarValue::Utf8, ScalarValue::LargeUtf8, ScalarValue::Utf8View];

    /// Invokes the UDF with all-scalar arguments, i.e. as a one-row batch.
    fn regexp_count_with_scalar_values(args: &[ScalarValue]) -> Result<ColumnarValue> {
        let args_values = args
            .iter()
            .map(|sv| ColumnarValue::Scalar(sv.clone()))
            .collect();

        let arg_fields = args
            .iter()
            .enumerate()
            .map(|(idx, a)| Field::new(format!("arg_{idx}"), a.data_type(), true).into())
            .collect::<Vec<_>>();

        RegexpCountFunc::new().invoke_with_args(ScalarFunctionArgs {
            args: args_values,
            arg_fields,
            // Every argument is a scalar, so the batch has exactly one row.
            number_rows: 1,
            return_field: Field::new("f", Int64, true).into(),
            config_options: Arc::new(ConfigOptions::default()),
        })
    }

    /// Asserts that a scalar invocation returns `Int64(Some(expected))`.
    fn assert_scalar_count(args: &[ScalarValue], expected: i64) {
        match regexp_count_with_scalar_values(args) {
            Ok(ColumnarValue::Scalar(ScalarValue::Int64(actual))) => {
                assert_eq!(
                    actual,
                    Some(expected),
                    "regexp_count scalar test failed for {args:?}"
                );
            }
            other => panic!("Unexpected result: {other:?}"),
        }
    }

    #[test]
    fn test_case_sensitive_regexp_count_scalar() {
        let values = ["", "aabca", "abcabc", "abcAbcab", "abcabcabc"];
        let regex = "abc";
        let expected = [0, 1, 2, 1, 3];

        for (value, expected) in values.into_iter().zip(expected) {
            for string in STRING_SCALARS {
                assert_scalar_count(
                    &[
                        string(Some(value.to_string())),
                        string(Some(regex.to_string())),
                    ],
                    expected,
                );
            }
        }
    }

    #[test]
    fn test_case_sensitive_regexp_count_empty_pattern_scalar() {
        let values = ["", "abc", "abc"];
        let start_positions = [1, 1, 2];
        let expected = [1, 4, 3];

        for ((value, start), expected) in
            values.into_iter().zip(start_positions).zip(expected)
        {
            for string in STRING_SCALARS {
                assert_scalar_count(
                    &[
                        string(Some(value.to_string())),
                        string(Some(String::new())),
                        ScalarValue::Int64(Some(start)),
                    ],
                    expected,
                );
            }
        }
    }

    #[test]
    fn test_case_sensitive_regexp_count_scalar_start() {
        let values = ["", "aabca", "abcabc", "abcAbcab", "abcabcabc"];
        let regex = "abc";
        let start = 2;
        let expected = [0, 1, 1, 0, 2];

        for (value, expected) in values.into_iter().zip(expected) {
            for string in STRING_SCALARS {
                assert_scalar_count(
                    &[
                        string(Some(value.to_string())),
                        string(Some(regex.to_string())),
                        ScalarValue::Int64(Some(start)),
                    ],
                    expected,
                );
            }
        }
    }

    #[test]
    fn test_case_insensitive_regexp_count_scalar_flags() {
        let values = ["", "aabca", "abcabc", "abcAbcab", "abcabcabc"];
        let regex = "abc";
        let start = 1;
        let flags = "i";
        let expected = [0, 1, 2, 2, 3];

        for (value, expected) in values.into_iter().zip(expected) {
            for string in STRING_SCALARS {
                assert_scalar_count(
                    &[
                        string(Some(value.to_string())),
                        string(Some(regex.to_string())),
                        ScalarValue::Int64(Some(start)),
                        string(Some(flags.to_string())),
                    ],
                    expected,
                );
            }
        }
    }

    #[test]
    fn test_case_sensitive_regexp_count_start_scalar_complex() {
        let values = ["", "aabca", "abcabc", "abcAbcab", "abcabcabc"];
        let regex = ["", "abc", "a", "bc", "ab"];
        let start = 5;
        let flags = ["", "i", "", "", "i"];
        let expected = [0, 0, 0, 1, 1];

        for (((value, regex), flags), expected) in values
            .into_iter()
            .zip(regex)
            .zip(flags)
            .zip(expected)
        {
            for string in STRING_SCALARS {
                assert_scalar_count(
                    &[
                        string(Some(value.to_string())),
                        string(Some(regex.to_string())),
                        ScalarValue::Int64(Some(start)),
                        string(Some(flags.to_string())),
                    ],
                    expected,
                );
            }
        }
    }

    fn check_case_sensitive_array<A>()
    where
        A: From<Vec<&'static str>> + Array + 'static,
    {
        let values = A::from(vec!["", "aabca", "abcabc", "abcAbcab", "abcabcAbc"]);
        let regex = A::from(vec!["", "abc", "a", "bc", "ab"]);

        let expected = Int64Array::from(vec![1, 1, 2, 2, 2]);

        let re = regexp_count_func(&[Arc::new(values), Arc::new(regex)]).unwrap();
        assert_eq!(re.as_ref(), &expected);
    }

    #[test]
    fn test_case_sensitive_regexp_count_array() {
        check_case_sensitive_array::<GenericStringArray<i32>>();
        check_case_sensitive_array::<GenericStringArray<i64>>();
        check_case_sensitive_array::<StringViewArray>();
    }

    fn check_case_sensitive_array_start<A>()
    where
        A: From<Vec<&'static str>> + Array + 'static,
    {
        let values = A::from(vec!["", "aAbca", "abcabc", "abcAbcab", "abcabcAbc"]);
        let regex = A::from(vec!["", "abc", "a", "bc", "ab"]);
        let start = Int64Array::from(vec![1, 2, 3, 4, 5]);

        let expected = Int64Array::from(vec![1, 0, 1, 1, 0]);

        let re = regexp_count_func(&[Arc::new(values), Arc::new(regex), Arc::new(start)])
            .unwrap();
        assert_eq!(re.as_ref(), &expected);
    }

    #[test]
    fn test_case_sensitive_regexp_count_array_start() {
        check_case_sensitive_array_start::<GenericStringArray<i32>>();
        check_case_sensitive_array_start::<GenericStringArray<i64>>();
        check_case_sensitive_array_start::<StringViewArray>();
    }

    fn check_case_insensitive_array_flags<A>()
    where
        A: From<Vec<&'static str>> + Array + 'static,
    {
        let values = A::from(vec!["", "aAbca", "abcabc", "abcAbcab", "abcabcAbc"]);
        let regex = A::from(vec!["", "abc", "a", "bc", "ab"]);
        // A one-row start array is broadcast to every value.
        let start = Int64Array::from(vec![1]);
        let flags = A::from(vec!["", "i", "", "", "i"]);

        let expected = Int64Array::from(vec![1, 1, 2, 2, 3]);

        let re = regexp_count_func(&[
            Arc::new(values),
            Arc::new(regex),
            Arc::new(start),
            Arc::new(flags),
        ])
        .unwrap();
        assert_eq!(re.as_ref(), &expected);
    }

    #[test]
    fn test_case_insensitive_regexp_count_array_flags() {
        check_case_insensitive_array_flags::<GenericStringArray<i32>>();
        check_case_insensitive_array_flags::<GenericStringArray<i64>>();
        check_case_insensitive_array_flags::<StringViewArray>();
    }

    fn check_case_sensitive_array_complex<A>()
    where
        A: From<Vec<&'static str>> + Array + 'static,
    {
        let values = A::from(vec!["", "aAbca", "abcabc", "abcAbcab", "abcabcAbc"]);
        let regex = A::from(vec!["", "abc", "a", "bc", "ab"]);
        let start = Int64Array::from(vec![1, 2, 3, 4, 5]);
        let flags = A::from(vec!["", "i", "", "", "i"]);

        let expected = Int64Array::from(vec![1, 1, 1, 1, 1]);

        let re = regexp_count_func(&[
            Arc::new(values),
            Arc::new(regex),
            Arc::new(start),
            Arc::new(flags),
        ])
        .unwrap();
        assert_eq!(re.as_ref(), &expected);
    }

    #[test]
    fn test_case_sensitive_regexp_count_array_complex() {
        check_case_sensitive_array_complex::<GenericStringArray<i32>>();
        check_case_sensitive_array_complex::<GenericStringArray<i64>>();
        check_case_sensitive_array_complex::<StringViewArray>();
    }

    /// The same pattern with different flags must not reuse a cached regex
    /// compiled for other flags.
    fn check_regexp_count_cache<A>()
    where
        A: From<Vec<&'static str>> + Array + 'static,
    {
        let values = A::from(vec!["aaa", "Aaa", "aaa"]);
        let regex = A::from(vec!["aaa", "aaa", "aaa"]);
        let start = Int64Array::from(vec![1, 1, 1]);
        let flags = A::from(vec!["", "i", ""]);

        let expected = Int64Array::from(vec![1, 1, 1]);

        let re = regexp_count_func(&[
            Arc::new(values),
            Arc::new(regex),
            Arc::new(start),
            Arc::new(flags),
        ])
        .unwrap();
        assert_eq!(re.as_ref(), &expected);
    }

    #[test]
    fn test_case_regexp_count_cache_check() {
        check_regexp_count_cache::<GenericStringArray<i32>>();
        check_regexp_count_cache::<GenericStringArray<i64>>();
        check_regexp_count_cache::<StringViewArray>();
    }
}