Skip to main content

datafusion_functions/regex/
regexpmatch.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! Regex expressions
19use arrow::array::{Array, ArrayRef, AsArray};
20use arrow::compute::kernels::regexp;
21use arrow::datatypes::DataType;
22use arrow::datatypes::Field;
23use datafusion_common::Result;
24use datafusion_common::ScalarValue;
25use datafusion_common::exec_err;
26use datafusion_common::{arrow_datafusion_err, plan_err};
27use datafusion_expr::{ColumnarValue, Documentation, ScalarFunctionArgs, TypeSignature};
28use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
29use datafusion_macros::user_doc;
30use std::sync::Arc;
31
32#[user_doc(
33    doc_section(label = "Regular Expression Functions"),
34    description = "Returns the first [regular expression](https://docs.rs/regex/latest/regex/#syntax) matches in a string.",
35    syntax_example = "regexp_match(str, regexp[, flags])",
36    sql_example = r#"```sql
37            > select regexp_match('Köln', '[a-zA-Z]ö[a-zA-Z]{2}');
38            +---------------------------------------------------------+
39            | regexp_match(Utf8("Köln"),Utf8("[a-zA-Z]ö[a-zA-Z]{2}")) |
40            +---------------------------------------------------------+
41            | [Köln]                                                  |
42            +---------------------------------------------------------+
43            SELECT regexp_match('aBc', '(b|d)', 'i');
44            +---------------------------------------------------+
45            | regexp_match(Utf8("aBc"),Utf8("(b|d)"),Utf8("i")) |
46            +---------------------------------------------------+
47            | [B]                                               |
48            +---------------------------------------------------+
49```
50Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/builtin_functions/regexp.rs)
51"#,
52    standard_argument(name = "str", prefix = "String"),
53    argument(
54        name = "regexp",
55        description = "Regular expression to match against.
56            Can be a constant, column, or function."
57    ),
58    argument(
59        name = "flags",
60        description = r#"Optional regular expression flags that control the behavior of the regular expression. The following flags are supported:
61  - **i**: case-insensitive: letters match both upper and lower case
62  - **m**: multi-line mode: ^ and $ match begin/end of line
63  - **s**: allow . to match \n
64  - **R**: enables CRLF mode: when multi-line mode is enabled, \r\n is used
65  - **U**: swap the meaning of x* and x*?"#
66    )
67)]
68#[derive(Debug, PartialEq, Eq, Hash)]
69pub struct RegexpMatchFunc {
70    signature: Signature,
71}
72
73impl Default for RegexpMatchFunc {
74    fn default() -> Self {
75        Self::new()
76    }
77}
78
79impl RegexpMatchFunc {
80    pub fn new() -> Self {
81        use DataType::*;
82        Self {
83            signature: Signature::one_of(
84                vec![
85                    // Planner attempts coercion to the target type starting with the most preferred candidate.
86                    // For example, given input `(Utf8View, Utf8)`, it first tries coercing to `(Utf8View, Utf8View)`.
87                    // If that fails, it proceeds to `(Utf8, Utf8)`.
88                    TypeSignature::Exact(vec![Utf8View, Utf8View]),
89                    TypeSignature::Exact(vec![Utf8, Utf8]),
90                    TypeSignature::Exact(vec![LargeUtf8, LargeUtf8]),
91                    TypeSignature::Exact(vec![Utf8View, Utf8View, Utf8View]),
92                    TypeSignature::Exact(vec![Utf8, Utf8, Utf8]),
93                    TypeSignature::Exact(vec![LargeUtf8, LargeUtf8, LargeUtf8]),
94                ],
95                Volatility::Immutable,
96            ),
97        }
98    }
99}
100
101impl ScalarUDFImpl for RegexpMatchFunc {
102    fn name(&self) -> &str {
103        "regexp_match"
104    }
105
106    fn signature(&self) -> &Signature {
107        &self.signature
108    }
109
110    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
111        Ok(match &arg_types[0] {
112            DataType::Null => DataType::Null,
113            other => DataType::List(Arc::new(Field::new_list_field(other.clone(), true))),
114        })
115    }
116
117    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
118        let args = &args.args;
119        let len = args
120            .iter()
121            .fold(Option::<usize>::None, |acc, arg| match arg {
122                ColumnarValue::Scalar(_) => acc,
123                ColumnarValue::Array(a) => Some(a.len()),
124            });
125
126        let is_scalar = len.is_none();
127        let inferred_length = len.unwrap_or(1);
128        let args = args
129            .iter()
130            .map(|arg| arg.to_array(inferred_length))
131            .collect::<Result<Vec<_>>>()?;
132
133        let result = regexp_match(&args);
134        if is_scalar {
135            // If all inputs are scalar, keeps output as scalar
136            let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0));
137            result.map(ColumnarValue::Scalar)
138        } else {
139            result.map(ColumnarValue::Array)
140        }
141    }
142
143    fn documentation(&self) -> Option<&Documentation> {
144        self.doc()
145    }
146}
147
148pub fn regexp_match(args: &[ArrayRef]) -> Result<ArrayRef> {
149    match args.len() {
150        2 => regexp::regexp_match(&args[0], &args[1], None)
151            .map_err(|e| arrow_datafusion_err!(e)),
152        3 => {
153            match args[2].data_type() {
154                DataType::Utf8View => {
155                    if args[2].as_string_view().iter().any(|s| s == Some("g")) {
156                        return plan_err!(
157                            "regexp_match() does not support the \"global\" option"
158                        );
159                    }
160                }
161                DataType::Utf8 => {
162                    if args[2].as_string::<i32>().iter().any(|s| s == Some("g")) {
163                        return plan_err!(
164                            "regexp_match() does not support the \"global\" option"
165                        );
166                    }
167                }
168                DataType::LargeUtf8 => {
169                    if args[2].as_string::<i64>().iter().any(|s| s == Some("g")) {
170                        return plan_err!(
171                            "regexp_match() does not support the \"global\" option"
172                        );
173                    }
174                }
175                e => {
176                    return plan_err!(
177                        "regexp_match was called with unexpected data type {e:?}"
178                    );
179                }
180            }
181
182            regexp::regexp_match(&args[0], &args[1], Some(&args[2]))
183                .map_err(|e| arrow_datafusion_err!(e))
184        }
185        other => exec_err!(
186            "regexp_match was called with {other} arguments. It requires at least 2 and at most 3."
187        ),
188    }
189}
#[cfg(test)]
mod tests {
    use crate::regex::regexpmatch::regexp_match;
    use arrow::array::{
        ArrayRef, GenericStringBuilder, Int32Array, LargeStringArray, ListBuilder,
        StringArray, StringViewArray,
    };
    use std::sync::Arc;

    #[test]
    fn test_case_sensitive_regexp_match() {
        let values = StringArray::from(vec!["abc"; 5]);
        let patterns =
            StringArray::from(vec!["^(a)", "^(A)", "(b|d)", "(B|D)", "^(b|c)"]);

        let elem_builder: GenericStringBuilder<i32> = GenericStringBuilder::new();
        let mut expected_builder = ListBuilder::new(elem_builder);
        expected_builder.values().append_value("a");
        expected_builder.append(true);
        expected_builder.append(false);
        expected_builder.values().append_value("b");
        expected_builder.append(true);
        expected_builder.append(false);
        expected_builder.append(false);
        let expected = expected_builder.finish();

        let re = regexp_match(&[Arc::new(values), Arc::new(patterns)]).unwrap();

        assert_eq!(re.as_ref(), &expected);
    }

    #[test]
    fn test_case_insensitive_regexp_match() {
        let values = StringArray::from(vec!["abc"; 5]);
        let patterns =
            StringArray::from(vec!["^(a)", "^(A)", "(b|d)", "(B|D)", "^(b|c)"]);
        let flags = StringArray::from(vec!["i"; 5]);

        let elem_builder: GenericStringBuilder<i32> = GenericStringBuilder::new();
        let mut expected_builder = ListBuilder::new(elem_builder);
        expected_builder.values().append_value("a");
        expected_builder.append(true);
        expected_builder.values().append_value("a");
        expected_builder.append(true);
        expected_builder.values().append_value("b");
        expected_builder.append(true);
        expected_builder.values().append_value("b");
        expected_builder.append(true);
        expected_builder.append(false);
        let expected = expected_builder.finish();

        let re = regexp_match(&[Arc::new(values), Arc::new(patterns), Arc::new(flags)])
            .unwrap();

        assert_eq!(re.as_ref(), &expected);
    }

    #[test]
    fn test_unsupported_global_flag_regexp_match() {
        let values = StringArray::from(vec!["abc"]);
        let patterns = StringArray::from(vec!["^(a)"]);
        let flags = StringArray::from(vec!["g"]);

        let re_err =
            regexp_match(&[Arc::new(values), Arc::new(patterns), Arc::new(flags)])
                .expect_err("unsupported flag should have failed");

        assert_eq!(
            re_err.strip_backtrace(),
            "Error during planning: regexp_match() does not support the \"global\" option"
        );
    }

    /// The "global" flag is rejected whichever string type carries the flags.
    /// This covers the `Utf8View` and `LargeUtf8` branches.
    #[test]
    fn test_unsupported_global_flag_other_string_types() {
        let flag_arrays: [ArrayRef; 2] = [
            Arc::new(StringViewArray::from(vec!["g"])),
            Arc::new(LargeStringArray::from(vec!["g"])),
        ];

        for flags in flag_arrays {
            let values: ArrayRef = Arc::new(StringArray::from(vec!["abc"]));
            let patterns: ArrayRef = Arc::new(StringArray::from(vec!["^(a)"]));

            let re_err = regexp_match(&[values, patterns, flags])
                .expect_err("unsupported flag should have failed");

            assert_eq!(
                re_err.strip_backtrace(),
                "Error during planning: regexp_match() does not support the \"global\" option"
            );
        }
    }

    /// Flags that are not a string array are reported as a planning error.
    #[test]
    fn test_unexpected_flags_type_regexp_match() {
        let values: ArrayRef = Arc::new(StringArray::from(vec!["abc"]));
        let patterns: ArrayRef = Arc::new(StringArray::from(vec!["^(a)"]));
        let flags: ArrayRef = Arc::new(Int32Array::from(vec![1]));

        let re_err = regexp_match(&[values, patterns, flags])
            .expect_err("non-string flags should have failed");

        assert_eq!(
            re_err.strip_backtrace(),
            "Error during planning: regexp_match was called with unexpected data type Int32"
        );
    }

    /// Calls with anything other than 2 or 3 arguments are rejected.
    #[test]
    fn test_wrong_argument_count_regexp_match() {
        let values: ArrayRef = Arc::new(StringArray::from(vec!["abc"]));

        let re_err = regexp_match(&[values])
            .expect_err("a single argument should have failed");

        assert_eq!(
            re_err.strip_backtrace(),
            "Execution error: regexp_match was called with 1 arguments. It requires at least 2 and at most 3."
        );
    }
}