1use crate::regex::{compile_and_cache_regex, compile_regex};
19use arrow::array::{Array, ArrayRef, AsArray, Datum, Int64Array, StringArrayType};
20use arrow::datatypes::{DataType, Int64Type};
21use arrow::datatypes::{
22 DataType::Int64, DataType::LargeUtf8, DataType::Utf8, DataType::Utf8View,
23};
24use arrow::error::ArrowError;
25use datafusion_common::{Result, ScalarValue, exec_err, internal_err};
26use datafusion_expr::{
27 ColumnarValue, Documentation, ScalarUDFImpl, Signature, TypeSignature::Exact,
28 TypeSignature::Uniform, Volatility,
29};
30use datafusion_macros::user_doc;
31use itertools::izip;
32use regex::Regex;
33use std::collections::HashMap;
34use std::sync::Arc;
35
36#[user_doc(
37 doc_section(label = "Regular Expression Functions"),
38 description = "Returns the number of matches that a [regular expression](https://docs.rs/regex/latest/regex/#syntax) has in a string.",
39 syntax_example = "regexp_count(str, regexp[, start, flags])",
40 sql_example = r#"```sql
41> select regexp_count('abcAbAbc', 'abc', 2, 'i');
42+---------------------------------------------------------------+
43| regexp_count(Utf8("abcAbAbc"),Utf8("abc"),Int64(2),Utf8("i")) |
44+---------------------------------------------------------------+
45| 1 |
46+---------------------------------------------------------------+
47```"#,
48 standard_argument(name = "str", prefix = "String"),
49 standard_argument(name = "regexp", prefix = "Regular"),
50 argument(
51 name = "start",
52 description = "- **start**: Optional start position (the first position is 1) to search for the regular expression. Can be a constant, column, or function."
53 ),
54 argument(
55 name = "flags",
56 description = r#"Optional regular expression flags that control the behavior of the regular expression. The following flags are supported:
57 - **i**: case-insensitive: letters match both upper and lower case
58 - **m**: multi-line mode: ^ and $ match begin/end of line
59 - **s**: allow . to match \n
60 - **R**: enables CRLF mode: when multi-line mode is enabled, \r\n is used
61 - **U**: swap the meaning of x* and x*?"#
62 )
63)]
64#[derive(Debug, PartialEq, Eq, Hash)]
65pub struct RegexpCountFunc {
66 signature: Signature,
67}
68
69impl Default for RegexpCountFunc {
70 fn default() -> Self {
71 Self::new()
72 }
73}
74
75impl RegexpCountFunc {
76 pub fn new() -> Self {
77 Self {
78 signature: Signature::one_of(
79 vec![
80 Uniform(2, vec![Utf8View, LargeUtf8, Utf8]),
81 Exact(vec![Utf8View, Utf8View, Int64]),
82 Exact(vec![LargeUtf8, LargeUtf8, Int64]),
83 Exact(vec![Utf8, Utf8, Int64]),
84 Exact(vec![Utf8View, Utf8View, Int64, Utf8View]),
85 Exact(vec![LargeUtf8, LargeUtf8, Int64, LargeUtf8]),
86 Exact(vec![Utf8, Utf8, Int64, Utf8]),
87 ],
88 Volatility::Immutable,
89 ),
90 }
91 }
92}
93
94impl ScalarUDFImpl for RegexpCountFunc {
95 fn as_any(&self) -> &dyn std::any::Any {
96 self
97 }
98
99 fn name(&self) -> &str {
100 "regexp_count"
101 }
102
103 fn signature(&self) -> &Signature {
104 &self.signature
105 }
106
107 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
108 Ok(Int64)
109 }
110
111 fn invoke_with_args(
112 &self,
113 args: datafusion_expr::ScalarFunctionArgs,
114 ) -> Result<ColumnarValue> {
115 let args = &args.args;
116
117 let len = args
118 .iter()
119 .fold(Option::<usize>::None, |acc, arg| match arg {
120 ColumnarValue::Scalar(_) => acc,
121 ColumnarValue::Array(a) => Some(a.len()),
122 });
123
124 let is_scalar = len.is_none();
125 let inferred_length = len.unwrap_or(1);
126 let args = args
127 .iter()
128 .map(|arg| arg.to_array(inferred_length))
129 .collect::<Result<Vec<_>>>()?;
130
131 let result = regexp_count_func(&args);
132 if is_scalar {
133 let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0));
135 result.map(ColumnarValue::Scalar)
136 } else {
137 result.map(ColumnarValue::Array)
138 }
139 }
140
141 fn documentation(&self) -> Option<&Documentation> {
142 self.doc()
143 }
144}
145
146pub fn regexp_count_func(args: &[ArrayRef]) -> Result<ArrayRef> {
147 let args_len = args.len();
148 if !(2..=4).contains(&args_len) {
149 return exec_err!(
150 "regexp_count was called with {args_len} arguments. It requires at least 2 and at most 4."
151 );
152 }
153
154 let values = &args[0];
155 match values.data_type() {
156 Utf8 | LargeUtf8 | Utf8View => (),
157 other => {
158 return internal_err!(
159 "Unsupported data type {other:?} for function regexp_count"
160 );
161 }
162 }
163
164 regexp_count(
165 values,
166 &args[1],
167 if args_len > 2 { Some(&args[2]) } else { None },
168 if args_len > 3 { Some(&args[3]) } else { None },
169 )
170 .map_err(|e| e.into())
171}
172
173fn regexp_count(
189 values: &dyn Array,
190 regex_array: &dyn Datum,
191 start_array: Option<&dyn Datum>,
192 flags_array: Option<&dyn Datum>,
193) -> Result<ArrayRef, ArrowError> {
194 let (regex_array, is_regex_scalar) = regex_array.get();
195 let (start_array, is_start_scalar) = start_array.map_or((None, true), |start| {
196 let (start, is_start_scalar) = start.get();
197 (Some(start), is_start_scalar)
198 });
199 let (flags_array, is_flags_scalar) = flags_array.map_or((None, true), |flags| {
200 let (flags, is_flags_scalar) = flags.get();
201 (Some(flags), is_flags_scalar)
202 });
203
204 match (values.data_type(), regex_array.data_type(), flags_array) {
205 (Utf8, Utf8, None) => regexp_count_inner(
206 &values.as_string::<i32>(),
207 ®ex_array.as_string::<i32>(),
208 is_regex_scalar,
209 start_array.map(|start| start.as_primitive::<Int64Type>()),
210 is_start_scalar,
211 None,
212 is_flags_scalar,
213 ),
214 (Utf8, Utf8, Some(flags_array)) if *flags_array.data_type() == Utf8 => regexp_count_inner(
215 &values.as_string::<i32>(),
216 ®ex_array.as_string::<i32>(),
217 is_regex_scalar,
218 start_array.map(|start| start.as_primitive::<Int64Type>()),
219 is_start_scalar,
220 Some(&flags_array.as_string::<i32>()),
221 is_flags_scalar,
222 ),
223 (LargeUtf8, LargeUtf8, None) => regexp_count_inner(
224 &values.as_string::<i64>(),
225 ®ex_array.as_string::<i64>(),
226 is_regex_scalar,
227 start_array.map(|start| start.as_primitive::<Int64Type>()),
228 is_start_scalar,
229 None,
230 is_flags_scalar,
231 ),
232 (LargeUtf8, LargeUtf8, Some(flags_array)) if *flags_array.data_type() == LargeUtf8 => regexp_count_inner(
233 &values.as_string::<i64>(),
234 ®ex_array.as_string::<i64>(),
235 is_regex_scalar,
236 start_array.map(|start| start.as_primitive::<Int64Type>()),
237 is_start_scalar,
238 Some(&flags_array.as_string::<i64>()),
239 is_flags_scalar,
240 ),
241 (Utf8View, Utf8View, None) => regexp_count_inner(
242 &values.as_string_view(),
243 ®ex_array.as_string_view(),
244 is_regex_scalar,
245 start_array.map(|start| start.as_primitive::<Int64Type>()),
246 is_start_scalar,
247 None,
248 is_flags_scalar,
249 ),
250 (Utf8View, Utf8View, Some(flags_array)) if *flags_array.data_type() == Utf8View => regexp_count_inner(
251 &values.as_string_view(),
252 ®ex_array.as_string_view(),
253 is_regex_scalar,
254 start_array.map(|start| start.as_primitive::<Int64Type>()),
255 is_start_scalar,
256 Some(&flags_array.as_string_view()),
257 is_flags_scalar,
258 ),
259 _ => Err(ArrowError::ComputeError(
260 "regexp_count() expected the input arrays to be of type Utf8, LargeUtf8, or Utf8View and the data types of the values, regex_array, and flags_array to match".to_string(),
261 )),
262 }
263}
264
265fn regexp_count_inner<'a, S>(
266 values: &S,
267 regex_array: &S,
268 is_regex_scalar: bool,
269 start_array: Option<&Int64Array>,
270 is_start_scalar: bool,
271 flags_array: Option<&S>,
272 is_flags_scalar: bool,
273) -> Result<ArrayRef, ArrowError>
274where
275 S: StringArrayType<'a>,
276{
277 let (regex_scalar, is_regex_scalar) = if is_regex_scalar || regex_array.len() == 1 {
278 (Some(regex_array.value(0)), true)
279 } else {
280 (None, false)
281 };
282
283 let (start_array, start_scalar, is_start_scalar) =
284 if let Some(start_array) = start_array {
285 if is_start_scalar || start_array.len() == 1 {
286 (None, Some(start_array.value(0)), true)
287 } else {
288 (Some(start_array), None, false)
289 }
290 } else {
291 (None, Some(1), true)
292 };
293
294 let (flags_array, flags_scalar, is_flags_scalar) =
295 if let Some(flags_array) = flags_array {
296 if is_flags_scalar || flags_array.len() == 1 {
297 (None, Some(flags_array.value(0)), true)
298 } else {
299 (Some(flags_array), None, false)
300 }
301 } else {
302 (None, None, true)
303 };
304
305 let mut regex_cache = HashMap::new();
306
307 match (is_regex_scalar, is_start_scalar, is_flags_scalar) {
308 (true, true, true) => {
309 let regex = match regex_scalar {
310 None | Some("") => {
311 return Ok(Arc::new(Int64Array::from(vec![0; values.len()])));
312 }
313 Some(regex) => regex,
314 };
315
316 let pattern = compile_regex(regex, flags_scalar)?;
317
318 Ok(Arc::new(
319 values
320 .iter()
321 .map(|value| count_matches(value, &pattern, start_scalar))
322 .collect::<Result<Int64Array, ArrowError>>()?,
323 ))
324 }
325 (true, true, false) => {
326 let regex = match regex_scalar {
327 None | Some("") => {
328 return Ok(Arc::new(Int64Array::from(vec![0; values.len()])));
329 }
330 Some(regex) => regex,
331 };
332
333 let flags_array = flags_array.unwrap();
334 if values.len() != flags_array.len() {
335 return Err(ArrowError::ComputeError(format!(
336 "flags_array must be the same length as values array; got {} and {}",
337 flags_array.len(),
338 values.len(),
339 )));
340 }
341
342 Ok(Arc::new(
343 values
344 .iter()
345 .zip(flags_array.iter())
346 .map(|(value, flags)| {
347 let pattern =
348 compile_and_cache_regex(regex, flags, &mut regex_cache)?;
349 count_matches(value, pattern, start_scalar)
350 })
351 .collect::<Result<Int64Array, ArrowError>>()?,
352 ))
353 }
354 (true, false, true) => {
355 let regex = match regex_scalar {
356 None | Some("") => {
357 return Ok(Arc::new(Int64Array::from(vec![0; values.len()])));
358 }
359 Some(regex) => regex,
360 };
361
362 let pattern = compile_regex(regex, flags_scalar)?;
363
364 let start_array = start_array.unwrap();
365
366 Ok(Arc::new(
367 values
368 .iter()
369 .zip(start_array.iter())
370 .map(|(value, start)| count_matches(value, &pattern, start))
371 .collect::<Result<Int64Array, ArrowError>>()?,
372 ))
373 }
374 (true, false, false) => {
375 let regex = match regex_scalar {
376 None | Some("") => {
377 return Ok(Arc::new(Int64Array::from(vec![0; values.len()])));
378 }
379 Some(regex) => regex,
380 };
381
382 let flags_array = flags_array.unwrap();
383 if values.len() != flags_array.len() {
384 return Err(ArrowError::ComputeError(format!(
385 "flags_array must be the same length as values array; got {} and {}",
386 flags_array.len(),
387 values.len(),
388 )));
389 }
390
391 Ok(Arc::new(
392 izip!(
393 values.iter(),
394 start_array.unwrap().iter(),
395 flags_array.iter()
396 )
397 .map(|(value, start, flags)| {
398 let pattern =
399 compile_and_cache_regex(regex, flags, &mut regex_cache)?;
400
401 count_matches(value, pattern, start)
402 })
403 .collect::<Result<Int64Array, ArrowError>>()?,
404 ))
405 }
406 (false, true, true) => {
407 if values.len() != regex_array.len() {
408 return Err(ArrowError::ComputeError(format!(
409 "regex_array must be the same length as values array; got {} and {}",
410 regex_array.len(),
411 values.len(),
412 )));
413 }
414
415 Ok(Arc::new(
416 values
417 .iter()
418 .zip(regex_array.iter())
419 .map(|(value, regex)| {
420 let regex = match regex {
421 None | Some("") => return Ok(0),
422 Some(regex) => regex,
423 };
424
425 let pattern = compile_and_cache_regex(
426 regex,
427 flags_scalar,
428 &mut regex_cache,
429 )?;
430 count_matches(value, pattern, start_scalar)
431 })
432 .collect::<Result<Int64Array, ArrowError>>()?,
433 ))
434 }
435 (false, true, false) => {
436 if values.len() != regex_array.len() {
437 return Err(ArrowError::ComputeError(format!(
438 "regex_array must be the same length as values array; got {} and {}",
439 regex_array.len(),
440 values.len(),
441 )));
442 }
443
444 let flags_array = flags_array.unwrap();
445 if values.len() != flags_array.len() {
446 return Err(ArrowError::ComputeError(format!(
447 "flags_array must be the same length as values array; got {} and {}",
448 flags_array.len(),
449 values.len(),
450 )));
451 }
452
453 Ok(Arc::new(
454 izip!(values.iter(), regex_array.iter(), flags_array.iter())
455 .map(|(value, regex, flags)| {
456 let regex = match regex {
457 None | Some("") => return Ok(0),
458 Some(regex) => regex,
459 };
460
461 let pattern =
462 compile_and_cache_regex(regex, flags, &mut regex_cache)?;
463
464 count_matches(value, pattern, start_scalar)
465 })
466 .collect::<Result<Int64Array, ArrowError>>()?,
467 ))
468 }
469 (false, false, true) => {
470 if values.len() != regex_array.len() {
471 return Err(ArrowError::ComputeError(format!(
472 "regex_array must be the same length as values array; got {} and {}",
473 regex_array.len(),
474 values.len(),
475 )));
476 }
477
478 let start_array = start_array.unwrap();
479 if values.len() != start_array.len() {
480 return Err(ArrowError::ComputeError(format!(
481 "start_array must be the same length as values array; got {} and {}",
482 start_array.len(),
483 values.len(),
484 )));
485 }
486
487 Ok(Arc::new(
488 izip!(values.iter(), regex_array.iter(), start_array.iter())
489 .map(|(value, regex, start)| {
490 let regex = match regex {
491 None | Some("") => return Ok(0),
492 Some(regex) => regex,
493 };
494
495 let pattern = compile_and_cache_regex(
496 regex,
497 flags_scalar,
498 &mut regex_cache,
499 )?;
500 count_matches(value, pattern, start)
501 })
502 .collect::<Result<Int64Array, ArrowError>>()?,
503 ))
504 }
505 (false, false, false) => {
506 if values.len() != regex_array.len() {
507 return Err(ArrowError::ComputeError(format!(
508 "regex_array must be the same length as values array; got {} and {}",
509 regex_array.len(),
510 values.len(),
511 )));
512 }
513
514 let start_array = start_array.unwrap();
515 if values.len() != start_array.len() {
516 return Err(ArrowError::ComputeError(format!(
517 "start_array must be the same length as values array; got {} and {}",
518 start_array.len(),
519 values.len(),
520 )));
521 }
522
523 let flags_array = flags_array.unwrap();
524 if values.len() != flags_array.len() {
525 return Err(ArrowError::ComputeError(format!(
526 "flags_array must be the same length as values array; got {} and {}",
527 flags_array.len(),
528 values.len(),
529 )));
530 }
531
532 Ok(Arc::new(
533 izip!(
534 values.iter(),
535 regex_array.iter(),
536 start_array.iter(),
537 flags_array.iter()
538 )
539 .map(|(value, regex, start, flags)| {
540 let regex = match regex {
541 None | Some("") => return Ok(0),
542 Some(regex) => regex,
543 };
544
545 let pattern =
546 compile_and_cache_regex(regex, flags, &mut regex_cache)?;
547 count_matches(value, pattern, start)
548 })
549 .collect::<Result<Int64Array, ArrowError>>()?,
550 ))
551 }
552 }
553}
554
555fn count_matches(
556 value: Option<&str>,
557 pattern: &Regex,
558 start: Option<i64>,
559) -> Result<i64, ArrowError> {
560 let value = match value {
561 None | Some("") => return Ok(0),
562 Some(value) => value,
563 };
564
565 if let Some(start) = start {
566 if start < 1 {
567 return Err(ArrowError::ComputeError(
568 "regexp_count() requires start to be 1 based".to_string(),
569 ));
570 }
571
572 let byte_offset = value
574 .char_indices()
575 .nth((start as usize).saturating_sub(1))
576 .map(|(idx, _)| idx)
577 .unwrap_or(value.len());
578
579 let find_slice = &value[byte_offset..];
581 let count = pattern.find_iter(find_slice).count();
582 Ok(count as i64)
583 } else {
584 let count = pattern.find_iter(value).count();
585 Ok(count as i64)
586 }
587}
588
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{GenericStringArray, StringViewArray};
    use arrow::datatypes::Field;
    use datafusion_common::config::ConfigOptions;
    use datafusion_expr::ScalarFunctionArgs;

    #[test]
    fn test_regexp_count() {
        test_case_sensitive_regexp_count_scalar();
        test_case_sensitive_regexp_count_scalar_start();
        test_case_insensitive_regexp_count_scalar_flags();
        test_case_sensitive_regexp_count_start_scalar_complex();

        test_case_sensitive_regexp_count_array::<GenericStringArray<i32>>();
        test_case_sensitive_regexp_count_array::<GenericStringArray<i64>>();
        test_case_sensitive_regexp_count_array::<StringViewArray>();

        test_case_sensitive_regexp_count_array_start::<GenericStringArray<i32>>();
        test_case_sensitive_regexp_count_array_start::<GenericStringArray<i64>>();
        test_case_sensitive_regexp_count_array_start::<StringViewArray>();

        test_case_insensitive_regexp_count_array_flags::<GenericStringArray<i32>>();
        test_case_insensitive_regexp_count_array_flags::<GenericStringArray<i64>>();
        test_case_insensitive_regexp_count_array_flags::<StringViewArray>();

        test_case_sensitive_regexp_count_array_complex::<GenericStringArray<i32>>();
        test_case_sensitive_regexp_count_array_complex::<GenericStringArray<i64>>();
        test_case_sensitive_regexp_count_array_complex::<StringViewArray>();

        test_case_regexp_count_cache_check::<GenericStringArray<i32>>();
        test_case_regexp_count_cache_check::<GenericStringArray<i64>>();
        test_case_regexp_count_cache_check::<StringViewArray>();
    }

    /// Invokes the UDF with every argument passed as a scalar.
    fn regexp_count_with_scalar_values(args: &[ScalarValue]) -> Result<ColumnarValue> {
        let args_values = args
            .iter()
            .map(|sv| ColumnarValue::Scalar(sv.clone()))
            .collect();

        let arg_fields = args
            .iter()
            .enumerate()
            .map(|(idx, a)| Field::new(format!("arg_{idx}"), a.data_type(), true).into())
            .collect::<Vec<_>>();

        RegexpCountFunc::new().invoke_with_args(ScalarFunctionArgs {
            args: args_values,
            arg_fields,
            // All arguments are scalars, so this is a single-row evaluation.
            // `args.len()` would be the argument count, not a row count.
            number_rows: 1,
            return_field: Field::new("f", Int64, true).into(),
            config_options: Arc::new(ConfigOptions::default()),
        })
    }

    /// Asserts that a scalar invocation returns `Int64(expected)`.
    fn assert_scalar_count(args: &[ScalarValue], expected: Option<i64>) {
        match regexp_count_with_scalar_values(args) {
            Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => {
                assert_eq!(v, expected, "regexp_count scalar test failed");
            }
            other => panic!("Unexpected result: {other:?}"),
        }
    }

    /// Wraps `s` in one scalar per supported string type, in the order
    /// Utf8, LargeUtf8, Utf8View.
    fn string_scalars(s: Option<&str>) -> [ScalarValue; 3] {
        let s = s.map(str::to_string);
        [
            ScalarValue::Utf8(s.clone()),
            ScalarValue::LargeUtf8(s.clone()),
            ScalarValue::Utf8View(s),
        ]
    }

    fn test_case_sensitive_regexp_count_scalar() {
        let values = ["", "aabca", "abcabc", "abcAbcab", "abcabcabc"];
        let regex = "abc";
        let expected: Vec<i64> = vec![0, 1, 2, 1, 3];

        for (pos, &v) in values.iter().enumerate() {
            let expected = expected.get(pos).copied();
            let typed = string_scalars(Some(v))
                .into_iter()
                .zip(string_scalars(Some(regex)));
            for (v_sv, regex_sv) in typed {
                assert_scalar_count(&[v_sv, regex_sv], expected);
            }
        }
    }

    fn test_case_sensitive_regexp_count_scalar_start() {
        let values = ["", "aabca", "abcabc", "abcAbcab", "abcabcabc"];
        let regex = "abc";
        let start = 2;
        let expected: Vec<i64> = vec![0, 1, 1, 0, 2];
        let start_sv = ScalarValue::Int64(Some(start));

        for (pos, &v) in values.iter().enumerate() {
            let expected = expected.get(pos).copied();
            let typed = string_scalars(Some(v))
                .into_iter()
                .zip(string_scalars(Some(regex)));
            for (v_sv, regex_sv) in typed {
                assert_scalar_count(&[v_sv, regex_sv, start_sv.clone()], expected);
            }
        }
    }

    fn test_case_insensitive_regexp_count_scalar_flags() {
        let values = ["", "aabca", "abcabc", "abcAbcab", "abcabcabc"];
        let regex = "abc";
        let start = 1;
        let flags = "i";
        let expected: Vec<i64> = vec![0, 1, 2, 2, 3];
        let start_sv = ScalarValue::Int64(Some(start));

        for (pos, &v) in values.iter().enumerate() {
            let expected = expected.get(pos).copied();
            let typed = string_scalars(Some(v))
                .into_iter()
                .zip(string_scalars(Some(regex)))
                .zip(string_scalars(Some(flags)));
            for ((v_sv, regex_sv), flags_sv) in typed {
                assert_scalar_count(
                    &[v_sv, regex_sv, start_sv.clone(), flags_sv],
                    expected,
                );
            }
        }
    }

    fn test_case_sensitive_regexp_count_array<A>()
    where
        A: From<Vec<&'static str>> + Array + 'static,
    {
        let values = A::from(vec!["", "aabca", "abcabc", "abcAbcab", "abcabcAbc"]);
        let regex = A::from(vec!["", "abc", "a", "bc", "ab"]);

        let expected = Int64Array::from(vec![0, 1, 2, 2, 2]);

        let re = regexp_count_func(&[Arc::new(values), Arc::new(regex)]).unwrap();
        assert_eq!(re.as_ref(), &expected);
    }

    fn test_case_sensitive_regexp_count_array_start<A>()
    where
        A: From<Vec<&'static str>> + Array + 'static,
    {
        let values = A::from(vec!["", "aAbca", "abcabc", "abcAbcab", "abcabcAbc"]);
        let regex = A::from(vec!["", "abc", "a", "bc", "ab"]);
        let start = Int64Array::from(vec![1, 2, 3, 4, 5]);

        let expected = Int64Array::from(vec![0, 0, 1, 1, 0]);

        let re = regexp_count_func(&[Arc::new(values), Arc::new(regex), Arc::new(start)])
            .unwrap();
        assert_eq!(re.as_ref(), &expected);
    }

    fn test_case_insensitive_regexp_count_array_flags<A>()
    where
        A: From<Vec<&'static str>> + Array + 'static,
    {
        let values = A::from(vec!["", "aAbca", "abcabc", "abcAbcab", "abcabcAbc"]);
        let regex = A::from(vec!["", "abc", "a", "bc", "ab"]);
        let start = Int64Array::from(vec![1]);
        let flags = A::from(vec!["", "i", "", "", "i"]);

        let expected = Int64Array::from(vec![0, 1, 2, 2, 3]);

        let re = regexp_count_func(&[
            Arc::new(values),
            Arc::new(regex),
            Arc::new(start),
            Arc::new(flags),
        ])
        .unwrap();
        assert_eq!(re.as_ref(), &expected);
    }

    fn test_case_sensitive_regexp_count_start_scalar_complex() {
        let values = ["", "aabca", "abcabc", "abcAbcab", "abcabcabc"];
        let regex = ["", "abc", "a", "bc", "ab"];
        let start = 5;
        let flags = ["", "i", "", "", "i"];
        let expected: Vec<i64> = vec![0, 0, 0, 1, 1];
        let start_sv = ScalarValue::Int64(Some(start));

        for (pos, &v) in values.iter().enumerate() {
            let expected = expected.get(pos).copied();
            let typed = string_scalars(Some(v))
                .into_iter()
                .zip(string_scalars(regex.get(pos).copied()))
                .zip(string_scalars(flags.get(pos).copied()));
            for ((v_sv, regex_sv), flags_sv) in typed {
                assert_scalar_count(
                    &[v_sv, regex_sv, start_sv.clone(), flags_sv],
                    expected,
                );
            }
        }
    }

    fn test_case_sensitive_regexp_count_array_complex<A>()
    where
        A: From<Vec<&'static str>> + Array + 'static,
    {
        let values = A::from(vec!["", "aAbca", "abcabc", "abcAbcab", "abcabcAbc"]);
        let regex = A::from(vec!["", "abc", "a", "bc", "ab"]);
        let start = Int64Array::from(vec![1, 2, 3, 4, 5]);
        let flags = A::from(vec!["", "i", "", "", "i"]);

        let expected = Int64Array::from(vec![0, 1, 1, 1, 1]);

        let re = regexp_count_func(&[
            Arc::new(values),
            Arc::new(regex),
            Arc::new(start),
            Arc::new(flags),
        ])
        .unwrap();
        assert_eq!(re.as_ref(), &expected);
    }

    fn test_case_regexp_count_cache_check<A>()
    where
        A: From<Vec<&'static str>> + Array + 'static,
    {
        let values = A::from(vec!["aaa", "Aaa", "aaa"]);
        let regex = A::from(vec!["aaa", "aaa", "aaa"]);
        let start = Int64Array::from(vec![1, 1, 1]);
        let flags = A::from(vec!["", "i", ""]);

        let expected = Int64Array::from(vec![1, 1, 1]);

        let re = regexp_count_func(&[
            Arc::new(values),
            Arc::new(regex),
            Arc::new(start),
            Arc::new(flags),
        ])
        .unwrap();
        assert_eq!(re.as_ref(), &expected);
    }
}