1use std::collections::HashSet;
4
5use aho_corasick::AhoCorasick;
6use serde_json::{Number, Value};
7
8use crate::functions::Function;
9use crate::interpreter::SearchResult;
10use crate::registry::register_if_enabled;
11use crate::{Context, Runtime, arg, defn};
12
13defn!(MatchAnyFn, vec![arg!(string), arg!(array)], None);
16
17impl Function for MatchAnyFn {
18 fn evaluate(&self, args: &[Value], ctx: &mut Context<'_>) -> SearchResult {
19 self.signature.validate(args, ctx)?;
20
21 let text = args[0].as_str().unwrap();
22 let patterns_arr = args[1].as_array().unwrap();
23
24 let patterns: Vec<&str> = patterns_arr.iter().filter_map(|p| p.as_str()).collect();
25
26 if patterns.is_empty() {
27 return Ok(Value::Bool(false));
28 }
29
30 let ac = AhoCorasick::new(&patterns).unwrap();
31 let has_match = ac.find(text).is_some();
32
33 Ok(Value::Bool(has_match))
34 }
35}
36
37defn!(MatchAllFn, vec![arg!(string), arg!(array)], None);
40
41impl Function for MatchAllFn {
42 fn evaluate(&self, args: &[Value], ctx: &mut Context<'_>) -> SearchResult {
43 self.signature.validate(args, ctx)?;
44
45 let text = args[0].as_str().unwrap();
46 let patterns_arr = args[1].as_array().unwrap();
47
48 let patterns: Vec<&str> = patterns_arr.iter().filter_map(|p| p.as_str()).collect();
49
50 if patterns.is_empty() {
51 return Ok(Value::Bool(true));
52 }
53
54 let ac = AhoCorasick::new(&patterns).unwrap();
55
56 let mut found = vec![false; patterns.len()];
57
58 for mat in ac.find_iter(text) {
59 found[mat.pattern().as_usize()] = true;
60 }
61
62 let all_found = found.iter().all(|&f| f);
63 Ok(Value::Bool(all_found))
64 }
65}
66
67defn!(MatchWhichFn, vec![arg!(string), arg!(array)], None);
70
71impl Function for MatchWhichFn {
72 fn evaluate(&self, args: &[Value], ctx: &mut Context<'_>) -> SearchResult {
73 self.signature.validate(args, ctx)?;
74
75 let text = args[0].as_str().unwrap();
76 let patterns_arr = args[1].as_array().unwrap();
77
78 let patterns: Vec<&str> = patterns_arr.iter().filter_map(|p| p.as_str()).collect();
79
80 if patterns.is_empty() {
81 return Ok(Value::Array(vec![]));
82 }
83
84 let ac = AhoCorasick::new(&patterns).unwrap();
85
86 let mut found = vec![false; patterns.len()];
87
88 for mat in ac.find_iter(text) {
89 found[mat.pattern().as_usize()] = true;
90 }
91
92 let matched: Vec<Value> = patterns
93 .iter()
94 .enumerate()
95 .filter(|(i, _)| found[*i])
96 .map(|(_, p)| Value::String((*p).to_string()))
97 .collect();
98
99 Ok(Value::Array(matched))
100 }
101}
102
103defn!(MatchCountFn, vec![arg!(string), arg!(array)], None);
106
107impl Function for MatchCountFn {
108 fn evaluate(&self, args: &[Value], ctx: &mut Context<'_>) -> SearchResult {
109 self.signature.validate(args, ctx)?;
110
111 let text = args[0].as_str().unwrap();
112 let patterns_arr = args[1].as_array().unwrap();
113
114 let patterns: Vec<&str> = patterns_arr.iter().filter_map(|p| p.as_str()).collect();
115
116 if patterns.is_empty() {
117 return Ok(Value::Number(Number::from(0)));
118 }
119
120 let ac = AhoCorasick::new(&patterns).unwrap();
121 let count = ac.find_iter(text).count();
122
123 Ok(Value::Number(Number::from(count)))
124 }
125}
126
127defn!(ReplaceManyFn, vec![arg!(string), arg!(object)], None);
130
131impl Function for ReplaceManyFn {
132 fn evaluate(&self, args: &[Value], ctx: &mut Context<'_>) -> SearchResult {
133 self.signature.validate(args, ctx)?;
134
135 let text = args[0].as_str().unwrap();
136 let replacements_obj = args[1].as_object().unwrap();
137
138 if replacements_obj.is_empty() {
139 return Ok(Value::String(text.to_string()));
140 }
141
142 let mut patterns: Vec<&str> = Vec::new();
143 let mut replacements: Vec<String> = Vec::new();
144
145 for (pattern, replacement) in replacements_obj.iter() {
146 patterns.push(pattern);
147 if let Some(s) = replacement.as_str() {
148 replacements.push(s.to_string());
149 } else {
150 replacements.push(replacement.to_string());
151 }
152 }
153
154 let ac = AhoCorasick::new(&patterns).unwrap();
155 let result = ac.replace_all(text, &replacements);
156
157 Ok(Value::String(result))
158 }
159}
160
161defn!(ExtractAllFn, vec![arg!(string), arg!(array)], None);
163
164impl Function for ExtractAllFn {
165 fn evaluate(&self, args: &[Value], ctx: &mut Context<'_>) -> SearchResult {
166 self.signature.validate(args, ctx)?;
167
168 let text = args[0].as_str().unwrap();
169 let patterns_arr = args[1].as_array().unwrap();
170
171 let patterns: Vec<&str> = patterns_arr.iter().filter_map(|p| p.as_str()).collect();
172
173 if patterns.is_empty() {
174 return Ok(Value::Array(vec![]));
175 }
176
177 let ac = AhoCorasick::new(&patterns).unwrap();
178 let mut results: Vec<Value> = Vec::new();
179
180 for mat in ac.find_iter(text) {
181 let mut obj = serde_json::Map::new();
182 obj.insert(
183 "pattern".to_string(),
184 Value::String(patterns[mat.pattern().as_usize()].to_string()),
185 );
186 obj.insert(
187 "match".to_string(),
188 Value::String(text[mat.start()..mat.end()].to_string()),
189 );
190 obj.insert(
191 "start".to_string(),
192 Value::Number(Number::from(mat.start())),
193 );
194 obj.insert("end".to_string(), Value::Number(Number::from(mat.end())));
195 results.push(Value::Object(obj));
196 }
197
198 Ok(Value::Array(results))
199 }
200}
201
202defn!(MatchPositionsFn, vec![arg!(string), arg!(array)], None);
204
205impl Function for MatchPositionsFn {
206 fn evaluate(&self, args: &[Value], ctx: &mut Context<'_>) -> SearchResult {
207 self.signature.validate(args, ctx)?;
208
209 let text = args[0].as_str().unwrap();
210 let patterns_arr = args[1].as_array().unwrap();
211
212 let patterns: Vec<&str> = patterns_arr.iter().filter_map(|p| p.as_str()).collect();
213
214 if patterns.is_empty() {
215 return Ok(Value::Array(vec![]));
216 }
217
218 let ac = AhoCorasick::new(&patterns).unwrap();
219 let mut results: Vec<Value> = Vec::new();
220
221 for mat in ac.find_iter(text) {
222 let mut obj = serde_json::Map::new();
223 obj.insert(
224 "pattern".to_string(),
225 Value::String(patterns[mat.pattern().as_usize()].to_string()),
226 );
227 obj.insert(
228 "start".to_string(),
229 Value::Number(Number::from(mat.start())),
230 );
231 obj.insert("end".to_string(), Value::Number(Number::from(mat.end())));
232 results.push(Value::Object(obj));
233 }
234
235 Ok(Value::Array(results))
236 }
237}
238
239defn!(MmTokenizeFn, vec![arg!(string)], Some(arg!(any)));
242
243impl Function for MmTokenizeFn {
244 fn evaluate(&self, args: &[Value], ctx: &mut Context<'_>) -> SearchResult {
245 self.signature.validate(args, ctx)?;
246
247 let text = args[0].as_str().unwrap();
248
249 let lowercase = args
250 .get(1)
251 .and_then(|v| v.as_object())
252 .and_then(|obj| obj.get("lowercase"))
253 .and_then(|v| v.as_bool())
254 .unwrap_or(false);
255
256 let min_length = args
257 .get(1)
258 .and_then(|v| v.as_object())
259 .and_then(|obj| obj.get("min_length"))
260 .and_then(|v| v.as_f64())
261 .map(|n| n as usize)
262 .unwrap_or(1);
263
264 let tokens: Vec<Value> = text
265 .split(|c: char| !c.is_alphanumeric())
266 .filter(|s| !s.is_empty() && s.len() >= min_length)
267 .map(|s| {
268 let token = if lowercase {
269 s.to_lowercase()
270 } else {
271 s.to_string()
272 };
273 Value::String(token)
274 })
275 .collect();
276
277 Ok(Value::Array(tokens))
278 }
279}
280
281defn!(
283 ExtractBetweenFn,
284 vec![arg!(string), arg!(string), arg!(string)],
285 None
286);
287
288impl Function for ExtractBetweenFn {
289 fn evaluate(&self, args: &[Value], ctx: &mut Context<'_>) -> SearchResult {
290 self.signature.validate(args, ctx)?;
291
292 let text = args[0].as_str().unwrap();
293 let start_delim = args[1].as_str().unwrap();
294 let end_delim = args[2].as_str().unwrap();
295
296 if let Some(start_pos) = text.find(start_delim) {
297 let after_start = start_pos + start_delim.len();
298 if let Some(end_pos) = text[after_start..].find(end_delim) {
299 let extracted = &text[after_start..after_start + end_pos];
300 return Ok(Value::String(extracted.to_string()));
301 }
302 }
303
304 Ok(Value::Null)
305 }
306}
307
308defn!(SplitKeepFn, vec![arg!(string), arg!(string)], None);
310
311impl Function for SplitKeepFn {
312 fn evaluate(&self, args: &[Value], ctx: &mut Context<'_>) -> SearchResult {
313 self.signature.validate(args, ctx)?;
314
315 let text = args[0].as_str().unwrap();
316 let delimiter = args[1].as_str().unwrap();
317
318 if delimiter.is_empty() {
319 return Ok(Value::Array(vec![Value::String(text.to_string())]));
320 }
321
322 let mut result: Vec<Value> = Vec::new();
323 let mut last_end = 0;
324
325 for (start, part) in text.match_indices(delimiter) {
326 if start > last_end {
327 result.push(Value::String(text[last_end..start].to_string()));
328 }
329 result.push(Value::String(part.to_string()));
330 last_end = start + part.len();
331 }
332
333 if last_end < text.len() {
334 result.push(Value::String(text[last_end..].to_string()));
335 }
336
337 Ok(Value::Array(result))
338 }
339}
340
341pub fn register_filtered(runtime: &mut Runtime, enabled: &HashSet<&str>) {
343 register_if_enabled(runtime, "match_any", enabled, Box::new(MatchAnyFn::new()));
344 register_if_enabled(runtime, "match_all", enabled, Box::new(MatchAllFn::new()));
345 register_if_enabled(
346 runtime,
347 "match_which",
348 enabled,
349 Box::new(MatchWhichFn::new()),
350 );
351 register_if_enabled(
352 runtime,
353 "match_count",
354 enabled,
355 Box::new(MatchCountFn::new()),
356 );
357 register_if_enabled(
358 runtime,
359 "replace_many",
360 enabled,
361 Box::new(ReplaceManyFn::new()),
362 );
363 register_if_enabled(
364 runtime,
365 "extract_all",
366 enabled,
367 Box::new(ExtractAllFn::new()),
368 );
369 register_if_enabled(
370 runtime,
371 "match_positions",
372 enabled,
373 Box::new(MatchPositionsFn::new()),
374 );
375 register_if_enabled(
376 runtime,
377 "mm_tokenize",
378 enabled,
379 Box::new(MmTokenizeFn::new()),
380 );
381 register_if_enabled(
382 runtime,
383 "extract_between",
384 enabled,
385 Box::new(ExtractBetweenFn::new()),
386 );
387 register_if_enabled(runtime, "split_keep", enabled, Box::new(SplitKeepFn::new()));
388}
389
#[cfg(test)]
mod tests {
    use crate::Runtime;
    use serde_json::{Value, json};

    /// Builds a runtime with the standard library and every extension enabled.
    fn setup_runtime() -> Runtime {
        Runtime::builder()
            .with_standard()
            .with_all_extensions()
            .build()
    }

    /// Compiles `expression` on a fresh full runtime and evaluates it on `data`.
    fn eval(expression: &str, data: Value) -> Value {
        let runtime = setup_runtime();
        let compiled = runtime.compile(expression).unwrap();
        compiled.search(&data).unwrap()
    }

    /// Views an array-of-strings result as string slices.
    fn strings(value: &Value) -> Vec<&str> {
        value
            .as_array()
            .unwrap()
            .iter()
            .map(|v| v.as_str().unwrap())
            .collect()
    }

    // ---- match_any ----------------------------------------------------------

    #[test]
    fn test_match_any_found() {
        let result = eval(
            "match_any(@, ['error', 'warning', 'critical'])",
            json!("an error occurred in the system"),
        );
        assert_eq!(result, json!(true));
    }

    #[test]
    fn test_match_any_not_found() {
        let result = eval(
            "match_any(@, ['error', 'warning', 'critical'])",
            json!("everything is fine"),
        );
        assert_eq!(result, json!(false));
    }

    #[test]
    fn test_match_any_empty_patterns() {
        let result = eval(
            "match_any(text, patterns)",
            json!({"text": "some text", "patterns": []}),
        );
        assert_eq!(result, json!(false));
    }

    #[test]
    fn test_match_any_multiple_matches() {
        let result = eval(
            "match_any(@, ['error', 'warning'])",
            json!("error and warning detected"),
        );
        assert_eq!(result, json!(true));
    }

    // ---- match_all ----------------------------------------------------------

    #[test]
    fn test_match_all_all_found() {
        let result = eval(
            "match_all(@, ['error', 'warning'])",
            json!("error and warning detected"),
        );
        assert_eq!(result, json!(true));
    }

    #[test]
    fn test_match_all_some_missing() {
        let result = eval("match_all(@, ['error', 'warning'])", json!("error detected"));
        assert_eq!(result, json!(false));
    }

    #[test]
    fn test_match_all_empty_patterns() {
        let result = eval(
            "match_all(text, patterns)",
            json!({"text": "some text", "patterns": []}),
        );
        assert_eq!(result, json!(true));
    }

    // ---- match_which --------------------------------------------------------

    #[test]
    fn test_match_which_some_found() {
        let result = eval(
            "match_which(@, ['error', 'warning', 'critical'])",
            json!("error and warning detected"),
        );
        let names = strings(&result);
        assert_eq!(names.len(), 2);
        assert!(names.contains(&"error"));
        assert!(names.contains(&"warning"));
    }

    #[test]
    fn test_match_which_none_found() {
        let result = eval(
            "match_which(@, ['error', 'warning'])",
            json!("everything is fine"),
        );
        assert!(result.as_array().unwrap().is_empty());
    }

    #[test]
    fn test_match_which_preserves_order() {
        // Output follows the order of the pattern list, not the text order.
        let result = eval(
            "match_which(@, ['error', 'warning'])",
            json!("warning then error"),
        );
        let names = strings(&result);
        assert_eq!(names[0], "error");
        assert_eq!(names[1], "warning");
    }

    // ---- match_count --------------------------------------------------------

    #[test]
    fn test_match_count_multiple() {
        let result = eval(
            "match_count(@, ['error', 'warning'])",
            json!("error error warning error"),
        );
        assert_eq!(result.as_f64().unwrap(), 4.0);
    }

    #[test]
    fn test_match_count_none() {
        let result = eval(
            "match_count(@, ['error', 'warning'])",
            json!("everything is fine"),
        );
        assert_eq!(result.as_f64().unwrap(), 0.0);
    }

    #[test]
    fn test_match_count_empty_patterns() {
        let result = eval(
            "match_count(text, patterns)",
            json!({"text": "some text", "patterns": []}),
        );
        assert_eq!(result.as_f64().unwrap(), 0.0);
    }

    // ---- replace_many -------------------------------------------------------

    #[test]
    fn test_replace_many_basic() {
        let result = eval(
            "replace_many(text, {hello: 'hi', world: 'there'})",
            json!({"text": "hello world"}),
        );
        assert_eq!(result.as_str().unwrap(), "hi there");
    }

    #[test]
    fn test_replace_many_no_matches() {
        let result = eval(
            "replace_many(text, {foo: 'bar'})",
            json!({"text": "hello world"}),
        );
        assert_eq!(result.as_str().unwrap(), "hello world");
    }

    #[test]
    fn test_replace_many_empty_replacements() {
        let result = eval(
            "replace_many(text, replacements)",
            json!({"text": "hello world", "replacements": {}}),
        );
        assert_eq!(result.as_str().unwrap(), "hello world");
    }

    #[test]
    fn test_replace_many_multiple_occurrences() {
        let result = eval(
            "replace_many(text, {error: 'ERROR', connection: 'CONN'})",
            json!({"text": "error: connection error"}),
        );
        assert_eq!(result.as_str().unwrap(), "ERROR: CONN ERROR");
    }

    // ---- extract_all --------------------------------------------------------

    #[test]
    fn test_extract_all_basic() {
        let result = eval(
            "extract_all(@, ['error', 'warning'])",
            json!("error and warning detected"),
        );
        let items = result.as_array().unwrap();
        assert_eq!(items.len(), 2);
        let first = items[0].as_object().unwrap();
        assert_eq!(first.get("match").unwrap().as_str().unwrap(), "error");
        assert!(first.get("start").is_some());
        assert!(first.get("end").is_some());
    }

    #[test]
    fn test_extract_all_empty() {
        let result = eval(
            "extract_all(@, ['error', 'warning'])",
            json!("no matches here"),
        );
        assert!(result.as_array().unwrap().is_empty());
    }

    // ---- match_positions ----------------------------------------------------

    #[test]
    fn test_match_positions_basic() {
        let result = eval(
            "match_positions(@, ['quick', 'fox'])",
            json!("The quick brown fox"),
        );
        let items = result.as_array().unwrap();
        assert_eq!(items.len(), 2);
        let first = items[0].as_object().unwrap();
        assert_eq!(first.get("pattern").unwrap().as_str().unwrap(), "quick");
        assert_eq!(first.get("start").unwrap().as_f64().unwrap() as i64, 4);
        assert_eq!(first.get("end").unwrap().as_f64().unwrap() as i64, 9);
    }

    // ---- mm_tokenize --------------------------------------------------------

    #[test]
    fn test_mm_tokenize_basic() {
        let result = eval("mm_tokenize(@)", json!("Hello, world! This is a test."));
        let items = result.as_array().unwrap();
        assert!(items.len() >= 6);
        assert_eq!(items[0].as_str().unwrap(), "Hello");
    }

    #[test]
    fn test_mm_tokenize_with_options() {
        let result = eval(
            "mm_tokenize(@, {lowercase: `true`, min_length: `2`})",
            json!("Hello, world! A test."),
        );
        let tokens = strings(&result);
        assert!(tokens.contains(&"hello"));
        assert!(tokens.contains(&"world"));
        assert!(!tokens.iter().any(|t| t.len() < 2));
    }

    // ---- extract_between ----------------------------------------------------

    #[test]
    fn test_extract_between_basic() {
        let result = eval(
            "extract_between(@, '<title>', '</title>')",
            json!("<title>Page Title</title>"),
        );
        assert_eq!(result.as_str().unwrap(), "Page Title");
    }

    #[test]
    fn test_extract_between_not_found() {
        let result = eval(
            "extract_between(@, '<start>', '<end>')",
            json!("no delimiters here"),
        );
        assert!(result.is_null());
    }

    // ---- split_keep ---------------------------------------------------------

    #[test]
    fn test_split_keep_basic() {
        let result = eval("split_keep(@, '-')", json!("a-b-c"));
        let parts = result.as_array().unwrap();
        assert_eq!(parts.len(), 5);
        assert_eq!(parts[0].as_str().unwrap(), "a");
        assert_eq!(parts[1].as_str().unwrap(), "-");
        assert_eq!(parts[2].as_str().unwrap(), "b");
    }

    #[test]
    fn test_split_keep_no_delimiter() {
        let result = eval("split_keep(@, '-')", json!("no delimiters"));
        let parts = result.as_array().unwrap();
        assert_eq!(parts.len(), 1);
        assert_eq!(parts[0].as_str().unwrap(), "no delimiters");
    }
}