Skip to main content

datafusion_functions/regex/
regexpcount.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use crate::regex::{compile_and_cache_regex, compile_regex};
19use arrow::array::{Array, ArrayRef, AsArray, Datum, Int64Array, StringArrayType};
20use arrow::datatypes::{DataType, Int64Type};
21use arrow::datatypes::{
22    DataType::Int64, DataType::LargeUtf8, DataType::Utf8, DataType::Utf8View,
23};
24use arrow::error::ArrowError;
25use datafusion_common::{Result, ScalarValue, exec_err, internal_err};
26use datafusion_expr::{
27    ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
28    TypeSignature::Exact, TypeSignature::Uniform, Volatility,
29};
30use datafusion_macros::user_doc;
31use itertools::izip;
32use regex::Regex;
33use std::collections::HashMap;
34use std::sync::Arc;
35
36#[user_doc(
37    doc_section(label = "Regular Expression Functions"),
38    description = "Returns the number of matches that a [regular expression](https://docs.rs/regex/latest/regex/#syntax) has in a string.",
39    syntax_example = "regexp_count(str, regexp[, start, flags])",
40    sql_example = r#"```sql
41> select regexp_count('abcAbAbc', 'abc', 2, 'i');
42+---------------------------------------------------------------+
43| regexp_count(Utf8("abcAbAbc"),Utf8("abc"),Int64(2),Utf8("i")) |
44+---------------------------------------------------------------+
45| 1                                                             |
46+---------------------------------------------------------------+
47```"#,
48    standard_argument(name = "str", prefix = "String"),
49    standard_argument(name = "regexp", prefix = "Regular"),
50    argument(
51        name = "start",
52        description = "- **start**: Optional start position (the first position is 1) to search for the regular expression. Can be a constant, column, or function."
53    ),
54    argument(
55        name = "flags",
56        description = r#"Optional regular expression flags that control the behavior of the regular expression. The following flags are supported:
57  - **i**: case-insensitive: letters match both upper and lower case
58  - **m**: multi-line mode: ^ and $ match begin/end of line
59  - **s**: allow . to match \n
60  - **R**: enables CRLF mode: when multi-line mode is enabled, \r\n is used
61  - **U**: swap the meaning of x* and x*?"#
62    )
63)]
64#[derive(Debug, PartialEq, Eq, Hash)]
65pub struct RegexpCountFunc {
66    signature: Signature,
67}
68
69impl Default for RegexpCountFunc {
70    fn default() -> Self {
71        Self::new()
72    }
73}
74
75impl RegexpCountFunc {
76    pub fn new() -> Self {
77        Self {
78            signature: Signature::one_of(
79                vec![
80                    Uniform(2, vec![Utf8View, LargeUtf8, Utf8]),
81                    Exact(vec![Utf8View, Utf8View, Int64]),
82                    Exact(vec![LargeUtf8, LargeUtf8, Int64]),
83                    Exact(vec![Utf8, Utf8, Int64]),
84                    Exact(vec![Utf8View, Utf8View, Int64, Utf8View]),
85                    Exact(vec![LargeUtf8, LargeUtf8, Int64, LargeUtf8]),
86                    Exact(vec![Utf8, Utf8, Int64, Utf8]),
87                ],
88                Volatility::Immutable,
89            ),
90        }
91    }
92}
93
94impl ScalarUDFImpl for RegexpCountFunc {
95    fn name(&self) -> &str {
96        "regexp_count"
97    }
98
99    fn signature(&self) -> &Signature {
100        &self.signature
101    }
102
103    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
104        Ok(Int64)
105    }
106
107    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
108        let args = &args.args;
109
110        let len = args
111            .iter()
112            .fold(Option::<usize>::None, |acc, arg| match arg {
113                ColumnarValue::Scalar(_) => acc,
114                ColumnarValue::Array(a) => Some(a.len()),
115            });
116
117        let is_scalar = len.is_none();
118        let inferred_length = len.unwrap_or(1);
119        let args = args
120            .iter()
121            .map(|arg| arg.to_array(inferred_length))
122            .collect::<Result<Vec<_>>>()?;
123
124        let result = regexp_count_func(&args);
125        if is_scalar {
126            // If all inputs are scalar, keeps output as scalar
127            let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0));
128            result.map(ColumnarValue::Scalar)
129        } else {
130            result.map(ColumnarValue::Array)
131        }
132    }
133
134    fn documentation(&self) -> Option<&Documentation> {
135        self.doc()
136    }
137}
138
139pub fn regexp_count_func(args: &[ArrayRef]) -> Result<ArrayRef> {
140    let args_len = args.len();
141    if !(2..=4).contains(&args_len) {
142        return exec_err!(
143            "regexp_count was called with {args_len} arguments. It requires at least 2 and at most 4."
144        );
145    }
146
147    let values = &args[0];
148    match values.data_type() {
149        Utf8 | LargeUtf8 | Utf8View => (),
150        other => {
151            return internal_err!(
152                "Unsupported data type {other:?} for function regexp_count"
153            );
154        }
155    }
156
157    regexp_count(
158        values,
159        &args[1],
160        if args_len > 2 { Some(&args[2]) } else { None },
161        if args_len > 3 { Some(&args[3]) } else { None },
162    )
163    .map_err(|e| e.into())
164}
165
166/// `arrow-rs` style implementation of `regexp_count` function.
167/// This function `regexp_count` is responsible for counting the occurrences of a regular expression pattern
168/// within a string array. It supports optional start positions and flags for case insensitivity.
169///
170/// The function accepts a variable number of arguments:
171/// - `values`: The array of strings to search within.
172/// - `regex_array`: The array of regular expression patterns to search for.
173/// - `start_array` (optional): The array of start positions for the search.
174/// - `flags_array` (optional): The array of flags to modify the search behavior (e.g., case insensitivity).
175///
176/// The function handles different combinations of scalar and array inputs for the regex patterns, start positions,
177/// and flags. It uses a cache to store compiled regular expressions for efficiency.
178///
179/// # Errors
180/// Returns an error if the input arrays have mismatched lengths or if the regular expression fails to compile.
181fn regexp_count(
182    values: &dyn Array,
183    regex_array: &dyn Datum,
184    start_array: Option<&dyn Datum>,
185    flags_array: Option<&dyn Datum>,
186) -> Result<ArrayRef, ArrowError> {
187    let (regex_array, is_regex_scalar) = regex_array.get();
188    let (start_array, is_start_scalar) = start_array.map_or((None, true), |start| {
189        let (start, is_start_scalar) = start.get();
190        (Some(start), is_start_scalar)
191    });
192    let (flags_array, is_flags_scalar) = flags_array.map_or((None, true), |flags| {
193        let (flags, is_flags_scalar) = flags.get();
194        (Some(flags), is_flags_scalar)
195    });
196
197    match (values.data_type(), regex_array.data_type(), flags_array) {
198        (Utf8, Utf8, None) => regexp_count_inner(
199            &values.as_string::<i32>(),
200            &regex_array.as_string::<i32>(),
201            is_regex_scalar,
202            start_array.map(|start| start.as_primitive::<Int64Type>()),
203            is_start_scalar,
204            None,
205            is_flags_scalar,
206        ),
207        (Utf8, Utf8, Some(flags_array)) if *flags_array.data_type() == Utf8 => regexp_count_inner(
208            &values.as_string::<i32>(),
209            &regex_array.as_string::<i32>(),
210            is_regex_scalar,
211            start_array.map(|start| start.as_primitive::<Int64Type>()),
212            is_start_scalar,
213            Some(&flags_array.as_string::<i32>()),
214            is_flags_scalar,
215        ),
216        (LargeUtf8, LargeUtf8, None) => regexp_count_inner(
217            &values.as_string::<i64>(),
218            &regex_array.as_string::<i64>(),
219            is_regex_scalar,
220            start_array.map(|start| start.as_primitive::<Int64Type>()),
221            is_start_scalar,
222            None,
223            is_flags_scalar,
224        ),
225        (LargeUtf8, LargeUtf8, Some(flags_array)) if *flags_array.data_type() == LargeUtf8 => regexp_count_inner(
226            &values.as_string::<i64>(),
227            &regex_array.as_string::<i64>(),
228            is_regex_scalar,
229            start_array.map(|start| start.as_primitive::<Int64Type>()),
230            is_start_scalar,
231            Some(&flags_array.as_string::<i64>()),
232            is_flags_scalar,
233        ),
234        (Utf8View, Utf8View, None) => regexp_count_inner(
235            &values.as_string_view(),
236            &regex_array.as_string_view(),
237            is_regex_scalar,
238            start_array.map(|start| start.as_primitive::<Int64Type>()),
239            is_start_scalar,
240            None,
241            is_flags_scalar,
242        ),
243        (Utf8View, Utf8View, Some(flags_array)) if *flags_array.data_type() == Utf8View => regexp_count_inner(
244            &values.as_string_view(),
245            &regex_array.as_string_view(),
246            is_regex_scalar,
247            start_array.map(|start| start.as_primitive::<Int64Type>()),
248            is_start_scalar,
249            Some(&flags_array.as_string_view()),
250            is_flags_scalar,
251        ),
252        _ => Err(ArrowError::ComputeError(
253            "regexp_count() expected the input arrays to be of type Utf8, LargeUtf8, or Utf8View and the data types of the values, regex_array, and flags_array to match".to_string(),
254        )),
255    }
256}
257
258fn regexp_count_inner<'a, S>(
259    values: &S,
260    regex_array: &S,
261    is_regex_scalar: bool,
262    start_array: Option<&Int64Array>,
263    is_start_scalar: bool,
264    flags_array: Option<&S>,
265    is_flags_scalar: bool,
266) -> Result<ArrayRef, ArrowError>
267where
268    S: StringArrayType<'a>,
269{
270    let (regex_scalar, is_regex_scalar) = if is_regex_scalar || regex_array.len() == 1 {
271        (
272            (!regex_array.is_null(0)).then(|| regex_array.value(0)),
273            true,
274        )
275    } else {
276        (None, false)
277    };
278
279    let (start_array, start_scalar, is_start_scalar) =
280        if let Some(start_array) = start_array {
281            if is_start_scalar || start_array.len() == 1 {
282                (None, Some(start_array.value(0)), true)
283            } else {
284                (Some(start_array), None, false)
285            }
286        } else {
287            (None, Some(1), true)
288        };
289
290    let (flags_array, flags_scalar, is_flags_scalar) =
291        if let Some(flags_array) = flags_array {
292            if is_flags_scalar || flags_array.len() == 1 {
293                (None, Some(flags_array.value(0)), true)
294            } else {
295                (Some(flags_array), None, false)
296            }
297        } else {
298            (None, None, true)
299        };
300
301    let mut regex_cache = HashMap::new();
302
303    match (is_regex_scalar, is_start_scalar, is_flags_scalar) {
304        (true, true, true) => {
305            let regex = match regex_scalar {
306                None => {
307                    return Ok(Arc::new(Int64Array::from(vec![0; values.len()])));
308                }
309                Some(regex) => regex,
310            };
311
312            let pattern = compile_regex(regex, flags_scalar)?;
313
314            Ok(Arc::new(
315                values
316                    .iter()
317                    .map(|value| count_matches(value, &pattern, start_scalar))
318                    .collect::<Result<Int64Array, ArrowError>>()?,
319            ))
320        }
321        (true, true, false) => {
322            let regex = match regex_scalar {
323                None => {
324                    return Ok(Arc::new(Int64Array::from(vec![0; values.len()])));
325                }
326                Some(regex) => regex,
327            };
328
329            let flags_array = flags_array.unwrap();
330            if values.len() != flags_array.len() {
331                return Err(ArrowError::ComputeError(format!(
332                    "flags_array must be the same length as values array; got {} and {}",
333                    flags_array.len(),
334                    values.len(),
335                )));
336            }
337
338            Ok(Arc::new(
339                values
340                    .iter()
341                    .zip(flags_array.iter())
342                    .map(|(value, flags)| {
343                        let pattern =
344                            compile_and_cache_regex(regex, flags, &mut regex_cache)?;
345                        count_matches(value, pattern, start_scalar)
346                    })
347                    .collect::<Result<Int64Array, ArrowError>>()?,
348            ))
349        }
350        (true, false, true) => {
351            let regex = match regex_scalar {
352                None => {
353                    return Ok(Arc::new(Int64Array::from(vec![0; values.len()])));
354                }
355                Some(regex) => regex,
356            };
357
358            let pattern = compile_regex(regex, flags_scalar)?;
359
360            let start_array = start_array.unwrap();
361
362            Ok(Arc::new(
363                values
364                    .iter()
365                    .zip(start_array.iter())
366                    .map(|(value, start)| count_matches(value, &pattern, start))
367                    .collect::<Result<Int64Array, ArrowError>>()?,
368            ))
369        }
370        (true, false, false) => {
371            let regex = match regex_scalar {
372                None => {
373                    return Ok(Arc::new(Int64Array::from(vec![0; values.len()])));
374                }
375                Some(regex) => regex,
376            };
377
378            let flags_array = flags_array.unwrap();
379            if values.len() != flags_array.len() {
380                return Err(ArrowError::ComputeError(format!(
381                    "flags_array must be the same length as values array; got {} and {}",
382                    flags_array.len(),
383                    values.len(),
384                )));
385            }
386
387            Ok(Arc::new(
388                izip!(
389                    values.iter(),
390                    start_array.unwrap().iter(),
391                    flags_array.iter()
392                )
393                .map(|(value, start, flags)| {
394                    let pattern =
395                        compile_and_cache_regex(regex, flags, &mut regex_cache)?;
396
397                    count_matches(value, pattern, start)
398                })
399                .collect::<Result<Int64Array, ArrowError>>()?,
400            ))
401        }
402        (false, true, true) => {
403            if values.len() != regex_array.len() {
404                return Err(ArrowError::ComputeError(format!(
405                    "regex_array must be the same length as values array; got {} and {}",
406                    regex_array.len(),
407                    values.len(),
408                )));
409            }
410
411            Ok(Arc::new(
412                values
413                    .iter()
414                    .zip(regex_array.iter())
415                    .map(|(value, regex)| {
416                        let regex = match regex {
417                            None => return Ok(0),
418                            Some(regex) => regex,
419                        };
420
421                        let pattern = compile_and_cache_regex(
422                            regex,
423                            flags_scalar,
424                            &mut regex_cache,
425                        )?;
426                        count_matches(value, pattern, start_scalar)
427                    })
428                    .collect::<Result<Int64Array, ArrowError>>()?,
429            ))
430        }
431        (false, true, false) => {
432            if values.len() != regex_array.len() {
433                return Err(ArrowError::ComputeError(format!(
434                    "regex_array must be the same length as values array; got {} and {}",
435                    regex_array.len(),
436                    values.len(),
437                )));
438            }
439
440            let flags_array = flags_array.unwrap();
441            if values.len() != flags_array.len() {
442                return Err(ArrowError::ComputeError(format!(
443                    "flags_array must be the same length as values array; got {} and {}",
444                    flags_array.len(),
445                    values.len(),
446                )));
447            }
448
449            Ok(Arc::new(
450                izip!(values.iter(), regex_array.iter(), flags_array.iter())
451                    .map(|(value, regex, flags)| {
452                        let regex = match regex {
453                            None => return Ok(0),
454                            Some(regex) => regex,
455                        };
456
457                        let pattern =
458                            compile_and_cache_regex(regex, flags, &mut regex_cache)?;
459
460                        count_matches(value, pattern, start_scalar)
461                    })
462                    .collect::<Result<Int64Array, ArrowError>>()?,
463            ))
464        }
465        (false, false, true) => {
466            if values.len() != regex_array.len() {
467                return Err(ArrowError::ComputeError(format!(
468                    "regex_array must be the same length as values array; got {} and {}",
469                    regex_array.len(),
470                    values.len(),
471                )));
472            }
473
474            let start_array = start_array.unwrap();
475            if values.len() != start_array.len() {
476                return Err(ArrowError::ComputeError(format!(
477                    "start_array must be the same length as values array; got {} and {}",
478                    start_array.len(),
479                    values.len(),
480                )));
481            }
482
483            Ok(Arc::new(
484                izip!(values.iter(), regex_array.iter(), start_array.iter())
485                    .map(|(value, regex, start)| {
486                        let regex = match regex {
487                            None => return Ok(0),
488                            Some(regex) => regex,
489                        };
490
491                        let pattern = compile_and_cache_regex(
492                            regex,
493                            flags_scalar,
494                            &mut regex_cache,
495                        )?;
496                        count_matches(value, pattern, start)
497                    })
498                    .collect::<Result<Int64Array, ArrowError>>()?,
499            ))
500        }
501        (false, false, false) => {
502            if values.len() != regex_array.len() {
503                return Err(ArrowError::ComputeError(format!(
504                    "regex_array must be the same length as values array; got {} and {}",
505                    regex_array.len(),
506                    values.len(),
507                )));
508            }
509
510            let start_array = start_array.unwrap();
511            if values.len() != start_array.len() {
512                return Err(ArrowError::ComputeError(format!(
513                    "start_array must be the same length as values array; got {} and {}",
514                    start_array.len(),
515                    values.len(),
516                )));
517            }
518
519            let flags_array = flags_array.unwrap();
520            if values.len() != flags_array.len() {
521                return Err(ArrowError::ComputeError(format!(
522                    "flags_array must be the same length as values array; got {} and {}",
523                    flags_array.len(),
524                    values.len(),
525                )));
526            }
527
528            Ok(Arc::new(
529                izip!(
530                    values.iter(),
531                    regex_array.iter(),
532                    start_array.iter(),
533                    flags_array.iter()
534                )
535                .map(|(value, regex, start, flags)| {
536                    let regex = match regex {
537                        None => return Ok(0),
538                        Some(regex) => regex,
539                    };
540
541                    let pattern =
542                        compile_and_cache_regex(regex, flags, &mut regex_cache)?;
543                    count_matches(value, pattern, start)
544                })
545                .collect::<Result<Int64Array, ArrowError>>()?,
546            ))
547        }
548    }
549}
550
551fn count_matches(
552    value: Option<&str>,
553    pattern: &Regex,
554    start: Option<i64>,
555) -> Result<i64, ArrowError> {
556    let value = match value {
557        None => return Ok(0),
558        Some(value) => value,
559    };
560
561    if let Some(start) = start {
562        if start < 1 {
563            return Err(ArrowError::ComputeError(
564                "regexp_count() requires start to be 1 based".to_string(),
565            ));
566        }
567
568        let char_len = value.chars().count();
569        let start_index = (start as usize).saturating_sub(1);
570
571        if start_index > char_len {
572            return Ok(0);
573        }
574
575        // Find the byte offset for the start position (1-based character index)
576        let byte_offset = if start_index == char_len {
577            value.len()
578        } else {
579            value
580                .char_indices()
581                .nth(start_index)
582                .map(|(idx, _)| idx)
583                .unwrap_or(value.len())
584        };
585
586        // Use string slicing instead of collecting chars into a new String
587        let find_slice = &value[byte_offset..];
588        let count = pattern.find_iter(find_slice).count();
589        Ok(count as i64)
590    } else {
591        let count = pattern.find_iter(value).count();
592        Ok(count as i64)
593    }
594}
595
596#[cfg(test)]
597mod tests {
598    use super::*;
599    use arrow::array::{GenericStringArray, StringViewArray};
600    use arrow::datatypes::Field;
601    use datafusion_common::config::ConfigOptions;
602
    /// Single test entry point that runs every `regexp_count` scenario helper
    /// in this module, instantiating the generic array helpers for each
    /// string array type.
    #[test]
    fn test_regexp_count() {
        // Scalar-argument scenarios.
        test_case_sensitive_regexp_count_scalar();
        test_case_sensitive_regexp_count_empty_pattern_scalar();
        test_case_sensitive_regexp_count_scalar_start();
        test_case_insensitive_regexp_count_scalar_flags();
        test_case_sensitive_regexp_count_start_scalar_complex();

        // Array-argument scenarios, once per string array type.
        test_case_sensitive_regexp_count_array::<GenericStringArray<i32>>();
        test_case_sensitive_regexp_count_array::<GenericStringArray<i64>>();
        test_case_sensitive_regexp_count_array::<StringViewArray>();

        test_case_sensitive_regexp_count_array_start::<GenericStringArray<i32>>();
        test_case_sensitive_regexp_count_array_start::<GenericStringArray<i64>>();
        test_case_sensitive_regexp_count_array_start::<StringViewArray>();

        test_case_insensitive_regexp_count_array_flags::<GenericStringArray<i32>>();
        test_case_insensitive_regexp_count_array_flags::<GenericStringArray<i64>>();
        test_case_insensitive_regexp_count_array_flags::<StringViewArray>();

        test_case_sensitive_regexp_count_array_complex::<GenericStringArray<i32>>();
        test_case_sensitive_regexp_count_array_complex::<GenericStringArray<i64>>();
        test_case_sensitive_regexp_count_array_complex::<StringViewArray>();

        // Only instantiated for `GenericStringArray<i32>`.
        test_case_regexp_count_cache_check::<GenericStringArray<i32>>();
    }
629
630    fn regexp_count_with_scalar_values(args: &[ScalarValue]) -> Result<ColumnarValue> {
631        let args_values = args
632            .iter()
633            .map(|sv| ColumnarValue::Scalar(sv.clone()))
634            .collect();
635
636        let arg_fields = args
637            .iter()
638            .enumerate()
639            .map(|(idx, a)| Field::new(format!("arg_{idx}"), a.data_type(), true).into())
640            .collect::<Vec<_>>();
641
642        RegexpCountFunc::new().invoke_with_args(ScalarFunctionArgs {
643            args: args_values,
644            arg_fields,
645            number_rows: args.len(),
646            return_field: Field::new("f", Int64, true).into(),
647            config_options: Arc::new(ConfigOptions::default()),
648        })
649    }
650
651    fn test_case_sensitive_regexp_count_scalar() {
652        let values = ["", "aabca", "abcabc", "abcAbcab", "abcabcabc"];
653        let regex = "abc";
654        let expected: Vec<i64> = vec![0, 1, 2, 1, 3];
655
656        values.iter().enumerate().for_each(|(pos, &v)| {
657            // utf8
658            let v_sv = ScalarValue::Utf8(Some(v.to_string()));
659            let regex_sv = ScalarValue::Utf8(Some(regex.to_string()));
660            let expected = expected.get(pos).cloned();
661            let re = regexp_count_with_scalar_values(&[v_sv, regex_sv]);
662            match re {
663                Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => {
664                    assert_eq!(v, expected, "regexp_count scalar test failed");
665                }
666                _ => panic!("Unexpected result"),
667            }
668
669            // largeutf8
670            let v_sv = ScalarValue::LargeUtf8(Some(v.to_string()));
671            let regex_sv = ScalarValue::LargeUtf8(Some(regex.to_string()));
672            let re = regexp_count_with_scalar_values(&[v_sv, regex_sv]);
673            match re {
674                Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => {
675                    assert_eq!(v, expected, "regexp_count scalar test failed");
676                }
677                _ => panic!("Unexpected result"),
678            }
679
680            // utf8view
681            let v_sv = ScalarValue::Utf8View(Some(v.to_string()));
682            let regex_sv = ScalarValue::Utf8View(Some(regex.to_string()));
683            let re = regexp_count_with_scalar_values(&[v_sv, regex_sv]);
684            match re {
685                Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => {
686                    assert_eq!(v, expected, "regexp_count scalar test failed");
687                }
688                _ => panic!("Unexpected result"),
689            }
690        });
691    }
692
693    fn test_case_sensitive_regexp_count_empty_pattern_scalar() {
694        let values = ["", "abc", "abc"];
695        let start_positions = [1, 1, 2];
696        let expected: Vec<i64> = vec![1, 4, 3];
697
698        values
699            .iter()
700            .zip(start_positions.iter())
701            .enumerate()
702            .for_each(|(pos, (&value, &start))| {
703                let expected = expected.get(pos).cloned();
704                let start_sv = ScalarValue::Int64(Some(start));
705
706                let re = regexp_count_with_scalar_values(&[
707                    ScalarValue::Utf8(Some(value.to_string())),
708                    ScalarValue::Utf8(Some("".to_string())),
709                    start_sv.clone(),
710                ]);
711                match re {
712                    Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => {
713                        assert_eq!(v, expected, "regexp_count scalar test failed");
714                    }
715                    _ => panic!("Unexpected result"),
716                }
717
718                let re = regexp_count_with_scalar_values(&[
719                    ScalarValue::LargeUtf8(Some(value.to_string())),
720                    ScalarValue::LargeUtf8(Some("".to_string())),
721                    start_sv.clone(),
722                ]);
723                match re {
724                    Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => {
725                        assert_eq!(v, expected, "regexp_count scalar test failed");
726                    }
727                    _ => panic!("Unexpected result"),
728                }
729
730                let re = regexp_count_with_scalar_values(&[
731                    ScalarValue::Utf8View(Some(value.to_string())),
732                    ScalarValue::Utf8View(Some("".to_string())),
733                    start_sv,
734                ]);
735                match re {
736                    Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => {
737                        assert_eq!(v, expected, "regexp_count scalar test failed");
738                    }
739                    _ => panic!("Unexpected result"),
740                }
741            });
742    }
743
744    fn test_case_sensitive_regexp_count_scalar_start() {
745        let values = ["", "aabca", "abcabc", "abcAbcab", "abcabcabc"];
746        let regex = "abc";
747        let start = 2;
748        let expected: Vec<i64> = vec![0, 1, 1, 0, 2];
749
750        values.iter().enumerate().for_each(|(pos, &v)| {
751            // utf8
752            let v_sv = ScalarValue::Utf8(Some(v.to_string()));
753            let regex_sv = ScalarValue::Utf8(Some(regex.to_string()));
754            let start_sv = ScalarValue::Int64(Some(start));
755            let expected = expected.get(pos).cloned();
756            let re = regexp_count_with_scalar_values(&[v_sv, regex_sv, start_sv.clone()]);
757            match re {
758                Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => {
759                    assert_eq!(v, expected, "regexp_count scalar test failed");
760                }
761                _ => panic!("Unexpected result"),
762            }
763
764            // largeutf8
765            let v_sv = ScalarValue::LargeUtf8(Some(v.to_string()));
766            let regex_sv = ScalarValue::LargeUtf8(Some(regex.to_string()));
767            let re = regexp_count_with_scalar_values(&[v_sv, regex_sv, start_sv.clone()]);
768            match re {
769                Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => {
770                    assert_eq!(v, expected, "regexp_count scalar test failed");
771                }
772                _ => panic!("Unexpected result"),
773            }
774
775            // utf8view
776            let v_sv = ScalarValue::Utf8View(Some(v.to_string()));
777            let regex_sv = ScalarValue::Utf8View(Some(regex.to_string()));
778            let re = regexp_count_with_scalar_values(&[v_sv, regex_sv, start_sv.clone()]);
779            match re {
780                Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => {
781                    assert_eq!(v, expected, "regexp_count scalar test failed");
782                }
783                _ => panic!("Unexpected result"),
784            }
785        });
786    }
787
788    fn test_case_insensitive_regexp_count_scalar_flags() {
789        let values = ["", "aabca", "abcabc", "abcAbcab", "abcabcabc"];
790        let regex = "abc";
791        let start = 1;
792        let flags = "i";
793        let expected: Vec<i64> = vec![0, 1, 2, 2, 3];
794
795        values.iter().enumerate().for_each(|(pos, &v)| {
796            // utf8
797            let v_sv = ScalarValue::Utf8(Some(v.to_string()));
798            let regex_sv = ScalarValue::Utf8(Some(regex.to_string()));
799            let start_sv = ScalarValue::Int64(Some(start));
800            let flags_sv = ScalarValue::Utf8(Some(flags.to_string()));
801            let expected = expected.get(pos).cloned();
802
803            let re = regexp_count_with_scalar_values(&[
804                v_sv,
805                regex_sv,
806                start_sv.clone(),
807                flags_sv.clone(),
808            ]);
809            match re {
810                Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => {
811                    assert_eq!(v, expected, "regexp_count scalar test failed");
812                }
813                _ => panic!("Unexpected result"),
814            }
815
816            // largeutf8
817            let v_sv = ScalarValue::LargeUtf8(Some(v.to_string()));
818            let regex_sv = ScalarValue::LargeUtf8(Some(regex.to_string()));
819            let flags_sv = ScalarValue::LargeUtf8(Some(flags.to_string()));
820
821            let re = regexp_count_with_scalar_values(&[
822                v_sv,
823                regex_sv,
824                start_sv.clone(),
825                flags_sv.clone(),
826            ]);
827            match re {
828                Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => {
829                    assert_eq!(v, expected, "regexp_count scalar test failed");
830                }
831                _ => panic!("Unexpected result"),
832            }
833
834            // utf8view
835            let v_sv = ScalarValue::Utf8View(Some(v.to_string()));
836            let regex_sv = ScalarValue::Utf8View(Some(regex.to_string()));
837            let flags_sv = ScalarValue::Utf8View(Some(flags.to_string()));
838
839            let re = regexp_count_with_scalar_values(&[
840                v_sv,
841                regex_sv,
842                start_sv.clone(),
843                flags_sv.clone(),
844            ]);
845            match re {
846                Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => {
847                    assert_eq!(v, expected, "regexp_count scalar test failed");
848                }
849                _ => panic!("Unexpected result"),
850            }
851        });
852    }
853
854    fn test_case_sensitive_regexp_count_array<A>()
855    where
856        A: From<Vec<&'static str>> + Array + 'static,
857    {
858        let values = A::from(vec!["", "aabca", "abcabc", "abcAbcab", "abcabcAbc"]);
859        let regex = A::from(vec!["", "abc", "a", "bc", "ab"]);
860
861        let expected = Int64Array::from(vec![1, 1, 2, 2, 2]);
862
863        let re = regexp_count_func(&[Arc::new(values), Arc::new(regex)]).unwrap();
864        assert_eq!(re.as_ref(), &expected);
865    }
866
867    fn test_case_sensitive_regexp_count_array_start<A>()
868    where
869        A: From<Vec<&'static str>> + Array + 'static,
870    {
871        let values = A::from(vec!["", "aAbca", "abcabc", "abcAbcab", "abcabcAbc"]);
872        let regex = A::from(vec!["", "abc", "a", "bc", "ab"]);
873        let start = Int64Array::from(vec![1, 2, 3, 4, 5]);
874
875        let expected = Int64Array::from(vec![1, 0, 1, 1, 0]);
876
877        let re = regexp_count_func(&[Arc::new(values), Arc::new(regex), Arc::new(start)])
878            .unwrap();
879        assert_eq!(re.as_ref(), &expected);
880    }
881
882    fn test_case_insensitive_regexp_count_array_flags<A>()
883    where
884        A: From<Vec<&'static str>> + Array + 'static,
885    {
886        let values = A::from(vec!["", "aAbca", "abcabc", "abcAbcab", "abcabcAbc"]);
887        let regex = A::from(vec!["", "abc", "a", "bc", "ab"]);
888        let start = Int64Array::from(vec![1]);
889        let flags = A::from(vec!["", "i", "", "", "i"]);
890
891        let expected = Int64Array::from(vec![1, 1, 2, 2, 3]);
892
893        let re = regexp_count_func(&[
894            Arc::new(values),
895            Arc::new(regex),
896            Arc::new(start),
897            Arc::new(flags),
898        ])
899        .unwrap();
900        assert_eq!(re.as_ref(), &expected);
901    }
902
903    fn test_case_sensitive_regexp_count_start_scalar_complex() {
904        let values = ["", "aabca", "abcabc", "abcAbcab", "abcabcabc"];
905        let regex = ["", "abc", "a", "bc", "ab"];
906        let start = 5;
907        let flags = ["", "i", "", "", "i"];
908        let expected: Vec<i64> = vec![0, 0, 0, 1, 1];
909
910        values.iter().enumerate().for_each(|(pos, &v)| {
911            // utf8
912            let v_sv = ScalarValue::Utf8(Some(v.to_string()));
913            let regex_sv = ScalarValue::Utf8(regex.get(pos).map(|s| (*s).to_string()));
914            let start_sv = ScalarValue::Int64(Some(start));
915            let flags_sv = ScalarValue::Utf8(flags.get(pos).map(|f| (*f).to_string()));
916            let expected = expected.get(pos).cloned();
917            let re = regexp_count_with_scalar_values(&[
918                v_sv,
919                regex_sv,
920                start_sv.clone(),
921                flags_sv.clone(),
922            ]);
923            match re {
924                Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => {
925                    assert_eq!(v, expected, "regexp_count scalar test failed");
926                }
927                _ => panic!("Unexpected result"),
928            }
929
930            // largeutf8
931            let v_sv = ScalarValue::LargeUtf8(Some(v.to_string()));
932            let regex_sv =
933                ScalarValue::LargeUtf8(regex.get(pos).map(|s| (*s).to_string()));
934            let flags_sv =
935                ScalarValue::LargeUtf8(flags.get(pos).map(|f| (*f).to_string()));
936            let re = regexp_count_with_scalar_values(&[
937                v_sv,
938                regex_sv,
939                start_sv.clone(),
940                flags_sv.clone(),
941            ]);
942            match re {
943                Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => {
944                    assert_eq!(v, expected, "regexp_count scalar test failed");
945                }
946                _ => panic!("Unexpected result"),
947            }
948
949            // utf8view
950            let v_sv = ScalarValue::Utf8View(Some(v.to_string()));
951            let regex_sv =
952                ScalarValue::Utf8View(regex.get(pos).map(|s| (*s).to_string()));
953            let flags_sv =
954                ScalarValue::Utf8View(flags.get(pos).map(|f| (*f).to_string()));
955            let re = regexp_count_with_scalar_values(&[
956                v_sv,
957                regex_sv,
958                start_sv.clone(),
959                flags_sv.clone(),
960            ]);
961            match re {
962                Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => {
963                    assert_eq!(v, expected, "regexp_count scalar test failed");
964                }
965                _ => panic!("Unexpected result"),
966            }
967        });
968    }
969
970    fn test_case_sensitive_regexp_count_array_complex<A>()
971    where
972        A: From<Vec<&'static str>> + Array + 'static,
973    {
974        let values = A::from(vec!["", "aAbca", "abcabc", "abcAbcab", "abcabcAbc"]);
975        let regex = A::from(vec!["", "abc", "a", "bc", "ab"]);
976        let start = Int64Array::from(vec![1, 2, 3, 4, 5]);
977        let flags = A::from(vec!["", "i", "", "", "i"]);
978
979        let expected = Int64Array::from(vec![1, 1, 1, 1, 1]);
980
981        let re = regexp_count_func(&[
982            Arc::new(values),
983            Arc::new(regex),
984            Arc::new(start),
985            Arc::new(flags),
986        ])
987        .unwrap();
988        assert_eq!(re.as_ref(), &expected);
989    }
990
991    fn test_case_regexp_count_cache_check<A>()
992    where
993        A: From<Vec<&'static str>> + Array + 'static,
994    {
995        let values = A::from(vec!["aaa", "Aaa", "aaa"]);
996        let regex = A::from(vec!["aaa", "aaa", "aaa"]);
997        let start = Int64Array::from(vec![1, 1, 1]);
998        let flags = A::from(vec!["", "i", ""]);
999
1000        let expected = Int64Array::from(vec![1, 1, 1]);
1001
1002        let re = regexp_count_func(&[
1003            Arc::new(values),
1004            Arc::new(regex),
1005            Arc::new(start),
1006            Arc::new(flags),
1007        ])
1008        .unwrap();
1009        assert_eq!(re.as_ref(), &expected);
1010    }
1011}