datafusion_functions/regex/
regexpmatch.rs1use arrow::array::{Array, ArrayRef, AsArray};
20use arrow::compute::kernels::regexp;
21use arrow::datatypes::DataType;
22use arrow::datatypes::Field;
23use datafusion_common::Result;
24use datafusion_common::ScalarValue;
25use datafusion_common::exec_err;
26use datafusion_common::{arrow_datafusion_err, plan_err};
27use datafusion_expr::{ColumnarValue, Documentation, ScalarFunctionArgs, TypeSignature};
28use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
29use datafusion_macros::user_doc;
30use std::sync::Arc;
31
32#[user_doc(
33 doc_section(label = "Regular Expression Functions"),
34 description = "Returns the first [regular expression](https://docs.rs/regex/latest/regex/#syntax) matches in a string.",
35 syntax_example = "regexp_match(str, regexp[, flags])",
36 sql_example = r#"```sql
37 > select regexp_match('Köln', '[a-zA-Z]ö[a-zA-Z]{2}');
38 +---------------------------------------------------------+
39 | regexp_match(Utf8("Köln"),Utf8("[a-zA-Z]ö[a-zA-Z]{2}")) |
40 +---------------------------------------------------------+
41 | [Köln] |
42 +---------------------------------------------------------+
43 SELECT regexp_match('aBc', '(b|d)', 'i');
44 +---------------------------------------------------+
45 | regexp_match(Utf8("aBc"),Utf8("(b|d)"),Utf8("i")) |
46 +---------------------------------------------------+
47 | [B] |
48 +---------------------------------------------------+
49```
50Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/builtin_functions/regexp.rs)
51"#,
52 standard_argument(name = "str", prefix = "String"),
53 argument(
54 name = "regexp",
55 description = "Regular expression to match against.
56 Can be a constant, column, or function."
57 ),
58 argument(
59 name = "flags",
60 description = r#"Optional regular expression flags that control the behavior of the regular expression. The following flags are supported:
61 - **i**: case-insensitive: letters match both upper and lower case
62 - **m**: multi-line mode: ^ and $ match begin/end of line
63 - **s**: allow . to match \n
64 - **R**: enables CRLF mode: when multi-line mode is enabled, \r\n is used
65 - **U**: swap the meaning of x* and x*?"#
66 )
67)]
68#[derive(Debug, PartialEq, Eq, Hash)]
69pub struct RegexpMatchFunc {
70 signature: Signature,
71}
72
73impl Default for RegexpMatchFunc {
74 fn default() -> Self {
75 Self::new()
76 }
77}
78
79impl RegexpMatchFunc {
80 pub fn new() -> Self {
81 use DataType::*;
82 Self {
83 signature: Signature::one_of(
84 vec![
85 TypeSignature::Exact(vec![Utf8View, Utf8View]),
89 TypeSignature::Exact(vec![Utf8, Utf8]),
90 TypeSignature::Exact(vec![LargeUtf8, LargeUtf8]),
91 TypeSignature::Exact(vec![Utf8View, Utf8View, Utf8View]),
92 TypeSignature::Exact(vec![Utf8, Utf8, Utf8]),
93 TypeSignature::Exact(vec![LargeUtf8, LargeUtf8, LargeUtf8]),
94 ],
95 Volatility::Immutable,
96 ),
97 }
98 }
99}
100
101impl ScalarUDFImpl for RegexpMatchFunc {
102 fn name(&self) -> &str {
103 "regexp_match"
104 }
105
106 fn signature(&self) -> &Signature {
107 &self.signature
108 }
109
110 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
111 Ok(match &arg_types[0] {
112 DataType::Null => DataType::Null,
113 other => DataType::List(Arc::new(Field::new_list_field(other.clone(), true))),
114 })
115 }
116
117 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
118 let args = &args.args;
119 let len = args
120 .iter()
121 .fold(Option::<usize>::None, |acc, arg| match arg {
122 ColumnarValue::Scalar(_) => acc,
123 ColumnarValue::Array(a) => Some(a.len()),
124 });
125
126 let is_scalar = len.is_none();
127 let inferred_length = len.unwrap_or(1);
128 let args = args
129 .iter()
130 .map(|arg| arg.to_array(inferred_length))
131 .collect::<Result<Vec<_>>>()?;
132
133 let result = regexp_match(&args);
134 if is_scalar {
135 let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0));
137 result.map(ColumnarValue::Scalar)
138 } else {
139 result.map(ColumnarValue::Array)
140 }
141 }
142
143 fn documentation(&self) -> Option<&Documentation> {
144 self.doc()
145 }
146}
147
148pub fn regexp_match(args: &[ArrayRef]) -> Result<ArrayRef> {
149 match args.len() {
150 2 => regexp::regexp_match(&args[0], &args[1], None)
151 .map_err(|e| arrow_datafusion_err!(e)),
152 3 => {
153 match args[2].data_type() {
154 DataType::Utf8View => {
155 if args[2].as_string_view().iter().any(|s| s == Some("g")) {
156 return plan_err!(
157 "regexp_match() does not support the \"global\" option"
158 );
159 }
160 }
161 DataType::Utf8 => {
162 if args[2].as_string::<i32>().iter().any(|s| s == Some("g")) {
163 return plan_err!(
164 "regexp_match() does not support the \"global\" option"
165 );
166 }
167 }
168 DataType::LargeUtf8 => {
169 if args[2].as_string::<i64>().iter().any(|s| s == Some("g")) {
170 return plan_err!(
171 "regexp_match() does not support the \"global\" option"
172 );
173 }
174 }
175 e => {
176 return plan_err!(
177 "regexp_match was called with unexpected data type {e:?}"
178 );
179 }
180 }
181
182 regexp::regexp_match(&args[0], &args[1], Some(&args[2]))
183 .map_err(|e| arrow_datafusion_err!(e))
184 }
185 other => exec_err!(
186 "regexp_match was called with {other} arguments. It requires at least 2 and at most 3."
187 ),
188 }
189}
#[cfg(test)]
mod tests {
    use crate::regex::regexpmatch::regexp_match;
    use arrow::array::{
        GenericStringBuilder, LargeStringArray, ListBuilder, StringArray, StringViewArray,
    };
    use std::sync::Arc;

    /// Error text (backtrace stripped) produced when the `g` flag is supplied.
    const GLOBAL_FLAG_ERROR: &str =
        "Error during planning: regexp_match() does not support the \"global\" option";

    #[test]
    fn test_case_sensitive_regexp_match() {
        let values = StringArray::from(vec!["abc"; 5]);
        let patterns =
            StringArray::from(vec!["^(a)", "^(A)", "(b|d)", "(B|D)", "^(b|c)"]);

        let elem_builder: GenericStringBuilder<i32> = GenericStringBuilder::new();
        let mut expected_builder = ListBuilder::new(elem_builder);
        expected_builder.values().append_value("a");
        expected_builder.append(true);
        expected_builder.append(false);
        expected_builder.values().append_value("b");
        expected_builder.append(true);
        expected_builder.append(false);
        expected_builder.append(false);
        let expected = expected_builder.finish();

        let re = regexp_match(&[Arc::new(values), Arc::new(patterns)]).unwrap();

        assert_eq!(re.as_ref(), &expected);
    }

    #[test]
    fn test_case_insensitive_regexp_match() {
        let values = StringArray::from(vec!["abc"; 5]);
        let patterns =
            StringArray::from(vec!["^(a)", "^(A)", "(b|d)", "(B|D)", "^(b|c)"]);
        let flags = StringArray::from(vec!["i"; 5]);

        let elem_builder: GenericStringBuilder<i32> = GenericStringBuilder::new();
        let mut expected_builder = ListBuilder::new(elem_builder);
        expected_builder.values().append_value("a");
        expected_builder.append(true);
        expected_builder.values().append_value("a");
        expected_builder.append(true);
        expected_builder.values().append_value("b");
        expected_builder.append(true);
        expected_builder.values().append_value("b");
        expected_builder.append(true);
        expected_builder.append(false);
        let expected = expected_builder.finish();

        let re = regexp_match(&[Arc::new(values), Arc::new(patterns), Arc::new(flags)])
            .unwrap();

        assert_eq!(re.as_ref(), &expected);
    }

    #[test]
    fn test_unsupported_global_flag_regexp_match() {
        let values = StringArray::from(vec!["abc"]);
        let patterns = StringArray::from(vec!["^(a)"]);
        let flags = StringArray::from(vec!["g"]);

        let re_err =
            regexp_match(&[Arc::new(values), Arc::new(patterns), Arc::new(flags)])
                .expect_err("unsupported flag should have failed");

        assert_eq!(re_err.strip_backtrace(), GLOBAL_FLAG_ERROR);
    }

    #[test]
    fn test_unsupported_global_flag_regexp_match_utf8view() {
        let values = StringViewArray::from(vec!["abc"]);
        let patterns = StringViewArray::from(vec!["^(a)"]);
        let flags = StringViewArray::from(vec!["g"]);

        let re_err =
            regexp_match(&[Arc::new(values), Arc::new(patterns), Arc::new(flags)])
                .expect_err("unsupported flag should have failed for Utf8View");

        assert_eq!(re_err.strip_backtrace(), GLOBAL_FLAG_ERROR);
    }

    #[test]
    fn test_unsupported_global_flag_regexp_match_large_utf8() {
        let values = LargeStringArray::from(vec!["abc"]);
        let patterns = LargeStringArray::from(vec!["^(a)"]);
        let flags = LargeStringArray::from(vec!["g"]);

        let re_err =
            regexp_match(&[Arc::new(values), Arc::new(patterns), Arc::new(flags)])
                .expect_err("unsupported flag should have failed for LargeUtf8");

        assert_eq!(re_err.strip_backtrace(), GLOBAL_FLAG_ERROR);
    }

    #[test]
    fn test_regexp_match_invalid_argument_count() {
        let values = StringArray::from(vec!["abc"]);

        let re_err = regexp_match(&[Arc::new(values)])
            .expect_err("a single argument should have failed");

        assert_eq!(
            re_err.strip_backtrace(),
            "Execution error: regexp_match was called with 1 arguments. It requires at least 2 and at most 3."
        );
    }
}