1use std::collections::HashSet;
4
5use serde_json::Value;
6
7use crate::functions::Function;
8use crate::functions::custom_error;
9use crate::functions::number_value;
10use crate::interpreter::SearchResult;
11use crate::registry::register_if_enabled;
12use crate::{Context, Runtime, arg, defn};
13
14use regex::Regex;
15
16pub fn register_filtered(runtime: &mut Runtime, enabled: &HashSet<&str>) {
18 register_if_enabled(
19 runtime,
20 "regex_match",
21 enabled,
22 Box::new(RegexMatchFn::new()),
23 );
24 register_if_enabled(
25 runtime,
26 "regex_extract",
27 enabled,
28 Box::new(RegexExtractFn::new()),
29 );
30 register_if_enabled(
31 runtime,
32 "regex_replace",
33 enabled,
34 Box::new(RegexReplaceFn::new()),
35 );
36 register_if_enabled(
37 runtime,
38 "regex_count",
39 enabled,
40 Box::new(RegexCountFn::new()),
41 );
42 register_if_enabled(
43 runtime,
44 "regex_split",
45 enabled,
46 Box::new(RegexSplitFn::new()),
47 );
48}
49
50defn!(RegexMatchFn, vec![arg!(string), arg!(string)], None);
55
56impl Function for RegexMatchFn {
57 fn evaluate(&self, args: &[Value], ctx: &mut Context<'_>) -> SearchResult {
58 self.signature.validate(args, ctx)?;
59
60 let input = args[0].as_str().unwrap();
62 let pattern = args[1].as_str().unwrap();
63
64 let re = Regex::new(pattern)
65 .map_err(|e| custom_error(ctx, &format!("Invalid regex pattern: {e}")))?;
66
67 Ok(Value::Bool(re.is_match(input)))
68 }
69}
70
71defn!(RegexExtractFn, vec![arg!(string), arg!(string)], None);
76
77impl Function for RegexExtractFn {
78 fn evaluate(&self, args: &[Value], ctx: &mut Context<'_>) -> SearchResult {
79 self.signature.validate(args, ctx)?;
80
81 let input = args[0].as_str().unwrap();
83 let pattern = args[1].as_str().unwrap();
84
85 let re = Regex::new(pattern)
86 .map_err(|e| custom_error(ctx, &format!("Invalid regex pattern: {e}")))?;
87
88 let matches: Vec<Value> = re
89 .find_iter(input)
90 .map(|m| Value::String(m.as_str().to_string()))
91 .collect();
92
93 if matches.is_empty() {
95 Ok(Value::Null)
96 } else {
97 Ok(Value::Array(matches))
98 }
99 }
100}
101
102defn!(
107 RegexReplaceFn,
108 vec![arg!(string), arg!(string), arg!(string)],
109 None
110);
111
112impl Function for RegexReplaceFn {
113 fn evaluate(&self, args: &[Value], ctx: &mut Context<'_>) -> SearchResult {
114 self.signature.validate(args, ctx)?;
115
116 let input = args[0].as_str().unwrap();
118 let pattern = args[1].as_str().unwrap();
119 let replacement = args[2].as_str().unwrap();
120
121 let re = Regex::new(pattern)
122 .map_err(|e| custom_error(ctx, &format!("Invalid regex pattern: {e}")))?;
123
124 let result = re.replace_all(input, replacement);
125 Ok(Value::String(result.into_owned()))
126 }
127}
128
129defn!(RegexCountFn, vec![arg!(string), arg!(string)], None);
134
135impl Function for RegexCountFn {
136 fn evaluate(&self, args: &[Value], ctx: &mut Context<'_>) -> SearchResult {
137 self.signature.validate(args, ctx)?;
138
139 let input = args[0].as_str().unwrap();
141 let pattern = args[1].as_str().unwrap();
142
143 let re = Regex::new(pattern)
144 .map_err(|e| custom_error(ctx, &format!("Invalid regex pattern: {e}")))?;
145
146 let count = re.find_iter(input).count();
147 Ok(number_value(count as f64))
148 }
149}
150
151defn!(RegexSplitFn, vec![arg!(string), arg!(string)], None);
156
157impl Function for RegexSplitFn {
158 fn evaluate(&self, args: &[Value], ctx: &mut Context<'_>) -> SearchResult {
159 self.signature.validate(args, ctx)?;
160
161 let input = args[0].as_str().unwrap();
163 let pattern = args[1].as_str().unwrap();
164
165 let re = Regex::new(pattern)
166 .map_err(|e| custom_error(ctx, &format!("Invalid regex pattern: {e}")))?;
167
168 let parts: Vec<Value> = re
169 .split(input)
170 .map(|s| Value::String(s.to_string()))
171 .collect();
172
173 Ok(Value::Array(parts))
174 }
175}
176
#[cfg(test)]
mod tests {
    use crate::Runtime;
    use serde_json::json;

    /// Builds a runtime with the standard functions plus every extension,
    /// so the regex functions are available to compiled expressions.
    fn setup_runtime() -> Runtime {
        Runtime::builder()
            .with_standard()
            .with_all_extensions()
            .build()
    }

    #[test]
    fn test_regex_match() {
        let runtime = setup_runtime();
        let expr = runtime.compile("regex_match(@, '^hello')").unwrap();

        let data = json!("hello world");
        let result = expr.search(&data).unwrap();
        assert_eq!(result, json!(true));

        let data = json!("world hello");
        let result = expr.search(&data).unwrap();
        assert_eq!(result, json!(false));
    }

    #[test]
    fn test_regex_extract() {
        let runtime = setup_runtime();
        let expr = runtime.compile("regex_extract(@, '[0-9]+')").unwrap();
        let data = json!("abc123def456");
        let result = expr.search(&data).unwrap();
        let arr = result.as_array().unwrap();
        assert_eq!(arr.len(), 2);
        assert_eq!(arr[0].as_str().unwrap(), "123");
        assert_eq!(arr[1].as_str().unwrap(), "456");
    }

    #[test]
    fn test_regex_replace() {
        let runtime = setup_runtime();
        let expr = runtime.compile("regex_replace(@, '[0-9]+', 'X')").unwrap();
        let data = json!("abc123def456");
        let result = expr.search(&data).unwrap();
        assert_eq!(result.as_str().unwrap(), "abcXdefX");
    }

    #[test]
    fn test_regex_count() {
        let runtime = setup_runtime();
        let expr = runtime.compile("regex_count(@, '[0-9]+')").unwrap();

        let data = json!("abc123def456ghi789");
        let result = expr.search(&data).unwrap();
        assert_eq!(result, json!(3.0));
    }

    #[test]
    fn test_regex_count_no_matches() {
        let runtime = setup_runtime();
        let expr = runtime.compile("regex_count(@, '[0-9]+')").unwrap();

        let data = json!("abcdef");
        let result = expr.search(&data).unwrap();
        assert_eq!(result, json!(0.0));
    }

    // Previously named `..._overlapping_pattern`, but a single-character
    // class like `[aeiou]` can never produce overlapping matches; this just
    // counts vowels.
    #[test]
    fn test_regex_count_vowels() {
        let runtime = setup_runtime();
        let expr = runtime.compile("regex_count(@, '[aeiou]')").unwrap();

        let data = json!("hello world");
        let result = expr.search(&data).unwrap();
        assert_eq!(result, json!(3.0));
    }

    // Pins the non-overlapping semantics: "aaaa" contains three overlapping
    // occurrences of "aa" but only two non-overlapping ones.
    #[test]
    fn test_regex_count_non_overlapping() {
        let runtime = setup_runtime();
        let expr = runtime.compile("regex_count(@, 'aa')").unwrap();

        let result = expr.search(&json!("aaaa")).unwrap();
        assert_eq!(result, json!(2.0));
    }

    #[test]
    fn test_regex_split() {
        let runtime = setup_runtime();
        let expr = runtime.compile("regex_split(@, ',')").unwrap();

        let data = json!("a,b,c");
        let result = expr.search(&data).unwrap();
        assert_eq!(result, json!(["a", "b", "c"]));
    }

    #[test]
    fn test_regex_split_whitespace() {
        let runtime = setup_runtime();
        let expr = runtime.compile("regex_split(@, '\\s+')").unwrap();

        let data = json!("hello world\tfoo");
        let result = expr.search(&data).unwrap();
        assert_eq!(result, json!(["hello", "world", "foo"]));
    }

    #[test]
    fn test_regex_split_no_match() {
        let runtime = setup_runtime();
        let expr = runtime.compile("regex_split(@, ',')").unwrap();

        let data = json!("no delimiters here");
        let result = expr.search(&data).unwrap();
        assert_eq!(result, json!(["no delimiters here"]));
    }

    #[test]
    fn test_regex_match_invalid_pattern() {
        let runtime = setup_runtime();
        let expr = runtime.compile("regex_match(@, '[invalid')").unwrap();
        let result = expr.search(&json!("test"));
        assert!(result.is_err());
    }

    #[test]
    fn test_regex_extract_no_match() {
        let runtime = setup_runtime();
        let expr = runtime.compile("regex_extract(@, '[0-9]+')").unwrap();
        let data = json!("no numbers here");
        let result = expr.search(&data).unwrap();
        assert!(result.is_null());
    }

    #[test]
    fn test_regex_replace_capture_groups() {
        let runtime = setup_runtime();
        let expr = runtime
            .compile(r#"regex_replace(@, '(\w+) (\w+)', '$2 $1')"#)
            .unwrap();
        let data = json!("John Doe");
        let result = expr.search(&data).unwrap();
        assert_eq!(result.as_str().unwrap(), "Doe John");
    }

    #[test]
    fn test_regex_replace_no_match() {
        let runtime = setup_runtime();
        let expr = runtime.compile("regex_replace(@, '[0-9]+', 'X')").unwrap();
        let data = json!("no numbers");
        let result = expr.search(&data).unwrap();
        assert_eq!(result.as_str().unwrap(), "no numbers");
    }

    #[test]
    fn test_regex_match_anchored() {
        let runtime = setup_runtime();
        let expr = runtime.compile(r"regex_match(@, '^\d{3}-\d{4}$')").unwrap();
        assert_eq!(expr.search(&json!("123-4567")).unwrap(), json!(true));
        assert_eq!(expr.search(&json!("abc-defg")).unwrap(), json!(false));
        assert_eq!(expr.search(&json!("123-45678")).unwrap(), json!(false));
    }

    #[test]
    fn test_regex_extract_email_pattern() {
        let runtime = setup_runtime();
        let expr = runtime
            .compile(r"regex_extract(@, '[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}')")
            .unwrap();
        let data = json!("Contact us at info@example.com or support@test.org");
        let result = expr.search(&data).unwrap();
        let arr = result.as_array().unwrap();
        assert_eq!(arr.len(), 2);
        assert_eq!(arr[0].as_str().unwrap(), "info@example.com");
        assert_eq!(arr[1].as_str().unwrap(), "support@test.org");
    }

    #[test]
    fn test_regex_split_complex_delimiter() {
        let runtime = setup_runtime();
        let expr = runtime.compile(r"regex_split(@, ',\s*')").unwrap();
        let data = json!("a, b,c, d");
        let result = expr.search(&data).unwrap();
        assert_eq!(result, json!(["a", "b", "c", "d"]));
    }

    #[test]
    fn test_regex_count_invalid_pattern() {
        let runtime = setup_runtime();
        let expr = runtime.compile("regex_count(@, '[bad')").unwrap();
        let result = expr.search(&json!("test"));
        assert!(result.is_err());
    }

    // The remaining functions share the same compile-error path; previously
    // only `regex_match` and `regex_count` exercised it.
    #[test]
    fn test_regex_extract_invalid_pattern() {
        let runtime = setup_runtime();
        let expr = runtime.compile("regex_extract(@, '[bad')").unwrap();
        let result = expr.search(&json!("test"));
        assert!(result.is_err());
    }

    #[test]
    fn test_regex_replace_invalid_pattern() {
        let runtime = setup_runtime();
        let expr = runtime.compile("regex_replace(@, '[bad', 'X')").unwrap();
        let result = expr.search(&json!("test"));
        assert!(result.is_err());
    }

    #[test]
    fn test_regex_split_invalid_pattern() {
        let runtime = setup_runtime();
        let expr = runtime.compile("regex_split(@, '[bad')").unwrap();
        let result = expr.search(&json!("test"));
        assert!(result.is_err());
    }
}