datafusion_functions/regex/
regexpmatch.rs1use arrow::array::{Array, ArrayRef, AsArray};
20use arrow::compute::kernels::regexp;
21use arrow::datatypes::DataType;
22use arrow::datatypes::Field;
23use datafusion_common::Result;
24use datafusion_common::ScalarValue;
25use datafusion_common::exec_err;
26use datafusion_common::{arrow_datafusion_err, plan_err};
27use datafusion_expr::{ColumnarValue, Documentation, TypeSignature};
28use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
29use datafusion_macros::user_doc;
30use std::any::Any;
31use std::sync::Arc;
32
/// Scalar UDF implementing `regexp_match(str, regexp[, flags])`.
///
/// The user-facing documentation (description, syntax, arguments and SQL
/// examples) is declared through the `#[user_doc]` attribute below;
/// presumably this is what `self.doc()` returns in
/// `ScalarUDFImpl::documentation` — TODO confirm against the macro.
///
/// NOTE(review): `description` says "the first ... matches in a string";
/// "match" reads better, but changing it also changes the generated user docs.
#[user_doc(
    doc_section(label = "Regular Expression Functions"),
    description = "Returns the first [regular expression](https://docs.rs/regex/latest/regex/#syntax) matches in a string.",
    syntax_example = "regexp_match(str, regexp[, flags])",
    sql_example = r#"```sql
 > select regexp_match('Köln', '[a-zA-Z]ö[a-zA-Z]{2}');
 +---------------------------------------------------------+
 | regexp_match(Utf8("Köln"),Utf8("[a-zA-Z]ö[a-zA-Z]{2}")) |
 +---------------------------------------------------------+
 | [Köln] |
 +---------------------------------------------------------+
 SELECT regexp_match('aBc', '(b|d)', 'i');
 +---------------------------------------------------+
 | regexp_match(Utf8("aBc"),Utf8("(b|d)"),Utf8("i")) |
 +---------------------------------------------------+
 | [B] |
 +---------------------------------------------------+
```
Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/builtin_functions/regexp.rs)
"#,
    standard_argument(name = "str", prefix = "String"),
    argument(
        name = "regexp",
        description = "Regular expression to match against.
 Can be a constant, column, or function."
    ),
    argument(
        name = "flags",
        description = r#"Optional regular expression flags that control the behavior of the regular expression. The following flags are supported:
 - **i**: case-insensitive: letters match both upper and lower case
 - **m**: multi-line mode: ^ and $ match begin/end of line
 - **s**: allow . to match \n
 - **R**: enables CRLF mode: when multi-line mode is enabled, \r\n is used
 - **U**: swap the meaning of x* and x*?"#
    )
)]
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct RegexpMatchFunc {
    /// Accepted argument-type combinations and volatility, built in
    /// `RegexpMatchFunc::new`.
    signature: Signature,
}
73
74impl Default for RegexpMatchFunc {
75 fn default() -> Self {
76 Self::new()
77 }
78}
79
80impl RegexpMatchFunc {
81 pub fn new() -> Self {
82 use DataType::*;
83 Self {
84 signature: Signature::one_of(
85 vec![
86 TypeSignature::Exact(vec![Utf8View, Utf8View]),
90 TypeSignature::Exact(vec![Utf8, Utf8]),
91 TypeSignature::Exact(vec![LargeUtf8, LargeUtf8]),
92 TypeSignature::Exact(vec![Utf8View, Utf8View, Utf8View]),
93 TypeSignature::Exact(vec![Utf8, Utf8, Utf8]),
94 TypeSignature::Exact(vec![LargeUtf8, LargeUtf8, LargeUtf8]),
95 ],
96 Volatility::Immutable,
97 ),
98 }
99 }
100}
101
102impl ScalarUDFImpl for RegexpMatchFunc {
103 fn as_any(&self) -> &dyn Any {
104 self
105 }
106
107 fn name(&self) -> &str {
108 "regexp_match"
109 }
110
111 fn signature(&self) -> &Signature {
112 &self.signature
113 }
114
115 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
116 Ok(match &arg_types[0] {
117 DataType::Null => DataType::Null,
118 other => DataType::List(Arc::new(Field::new_list_field(other.clone(), true))),
119 })
120 }
121
122 fn invoke_with_args(
123 &self,
124 args: datafusion_expr::ScalarFunctionArgs,
125 ) -> Result<ColumnarValue> {
126 let args = &args.args;
127 let len = args
128 .iter()
129 .fold(Option::<usize>::None, |acc, arg| match arg {
130 ColumnarValue::Scalar(_) => acc,
131 ColumnarValue::Array(a) => Some(a.len()),
132 });
133
134 let is_scalar = len.is_none();
135 let inferred_length = len.unwrap_or(1);
136 let args = args
137 .iter()
138 .map(|arg| arg.to_array(inferred_length))
139 .collect::<Result<Vec<_>>>()?;
140
141 let result = regexp_match(&args);
142 if is_scalar {
143 let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0));
145 result.map(ColumnarValue::Scalar)
146 } else {
147 result.map(ColumnarValue::Array)
148 }
149 }
150
151 fn documentation(&self) -> Option<&Documentation> {
152 self.doc()
153 }
154}
155
156pub fn regexp_match(args: &[ArrayRef]) -> Result<ArrayRef> {
157 match args.len() {
158 2 => regexp::regexp_match(&args[0], &args[1], None)
159 .map_err(|e| arrow_datafusion_err!(e)),
160 3 => {
161 match args[2].data_type() {
162 DataType::Utf8View => {
163 if args[2].as_string_view().iter().any(|s| s == Some("g")) {
164 return plan_err!(
165 "regexp_match() does not support the \"global\" option"
166 );
167 }
168 }
169 DataType::Utf8 => {
170 if args[2].as_string::<i32>().iter().any(|s| s == Some("g")) {
171 return plan_err!(
172 "regexp_match() does not support the \"global\" option"
173 );
174 }
175 }
176 DataType::LargeUtf8 => {
177 if args[2].as_string::<i64>().iter().any(|s| s == Some("g")) {
178 return plan_err!(
179 "regexp_match() does not support the \"global\" option"
180 );
181 }
182 }
183 e => {
184 return plan_err!(
185 "regexp_match was called with unexpected data type {e:?}"
186 );
187 }
188 }
189
190 regexp::regexp_match(&args[0], &args[1], Some(&args[2]))
191 .map_err(|e| arrow_datafusion_err!(e))
192 }
193 other => exec_err!(
194 "regexp_match was called with {other} arguments. It requires at least 2 and at most 3."
195 ),
196 }
197}
#[cfg(test)]
mod tests {
    use crate::regex::regexpmatch::regexp_match;
    use arrow::array::StringArray;
    use arrow::array::{GenericStringBuilder, ListBuilder};
    use std::sync::Arc;

    #[test]
    fn test_case_sensitive_regexp_match() {
        let haystacks = StringArray::from(vec!["abc"; 5]);
        let regexes =
            StringArray::from(vec!["^(a)", "^(A)", "(b|d)", "(B|D)", "^(b|c)"]);

        // One entry per row: `Some(capture)` for a match, `None` for a null list.
        let mut want = ListBuilder::new(GenericStringBuilder::<i32>::new());
        for row in [Some("a"), None, Some("b"), None, None] {
            if let Some(capture) = row {
                want.values().append_value(capture);
            }
            want.append(row.is_some());
        }
        let want = want.finish();

        let got = regexp_match(&[Arc::new(haystacks), Arc::new(regexes)]).unwrap();

        assert_eq!(got.as_ref(), &want);
    }

    #[test]
    fn test_case_insensitive_regexp_match() {
        let haystacks = StringArray::from(vec!["abc"; 5]);
        let regexes =
            StringArray::from(vec!["^(a)", "^(A)", "(b|d)", "(B|D)", "^(b|c)"]);
        let options = StringArray::from(vec!["i"; 5]);

        // With the `i` flag the upper-case patterns match as well.
        let mut want = ListBuilder::new(GenericStringBuilder::<i32>::new());
        for row in [Some("a"), Some("a"), Some("b"), Some("b"), None] {
            if let Some(capture) = row {
                want.values().append_value(capture);
            }
            want.append(row.is_some());
        }
        let want = want.finish();

        let got = regexp_match(&[
            Arc::new(haystacks),
            Arc::new(regexes),
            Arc::new(options),
        ])
        .unwrap();

        assert_eq!(got.as_ref(), &want);
    }

    #[test]
    fn test_unsupported_global_flag_regexp_match() {
        let inputs = [
            Arc::new(StringArray::from(vec!["abc"])),
            Arc::new(StringArray::from(vec!["^(a)"])),
            Arc::new(StringArray::from(vec!["g"])),
        ];

        let failure = regexp_match(&[
            inputs[0].clone(),
            inputs[1].clone(),
            inputs[2].clone(),
        ])
        .expect_err("unsupported flag should have failed");

        assert_eq!(
            failure.strip_backtrace(),
            "Error during planning: regexp_match() does not support the \"global\" option"
        );
    }
}