Skip to main content

jpx_core/extensions/
multi_match.rs

1//! Multi-pattern matching functions.
2
3use std::collections::HashSet;
4
5use aho_corasick::AhoCorasick;
6use serde_json::{Number, Value};
7
8use crate::functions::Function;
9use crate::interpreter::SearchResult;
10use crate::registry::register_if_enabled;
11use crate::{Context, Runtime, arg, defn};
12
13// match_any(string, patterns) -> boolean
14// Returns true if any of the patterns match the string
15defn!(MatchAnyFn, vec![arg!(string), arg!(array)], None);
16
17impl Function for MatchAnyFn {
18    fn evaluate(&self, args: &[Value], ctx: &mut Context<'_>) -> SearchResult {
19        self.signature.validate(args, ctx)?;
20
21        let text = args[0].as_str().unwrap();
22        let patterns_arr = args[1].as_array().unwrap();
23
24        let patterns: Vec<&str> = patterns_arr.iter().filter_map(|p| p.as_str()).collect();
25
26        if patterns.is_empty() {
27            return Ok(Value::Bool(false));
28        }
29
30        let ac = AhoCorasick::new(&patterns).unwrap();
31        let has_match = ac.find(text).is_some();
32
33        Ok(Value::Bool(has_match))
34    }
35}
36
37// match_all(string, patterns) -> boolean
38// Returns true if all patterns match the string
39defn!(MatchAllFn, vec![arg!(string), arg!(array)], None);
40
41impl Function for MatchAllFn {
42    fn evaluate(&self, args: &[Value], ctx: &mut Context<'_>) -> SearchResult {
43        self.signature.validate(args, ctx)?;
44
45        let text = args[0].as_str().unwrap();
46        let patterns_arr = args[1].as_array().unwrap();
47
48        let patterns: Vec<&str> = patterns_arr.iter().filter_map(|p| p.as_str()).collect();
49
50        if patterns.is_empty() {
51            return Ok(Value::Bool(true));
52        }
53
54        let ac = AhoCorasick::new(&patterns).unwrap();
55
56        let mut found = vec![false; patterns.len()];
57
58        for mat in ac.find_iter(text) {
59            found[mat.pattern().as_usize()] = true;
60        }
61
62        let all_found = found.iter().all(|&f| f);
63        Ok(Value::Bool(all_found))
64    }
65}
66
67// match_which(string, patterns) -> array
68// Returns array of patterns that matched
69defn!(MatchWhichFn, vec![arg!(string), arg!(array)], None);
70
71impl Function for MatchWhichFn {
72    fn evaluate(&self, args: &[Value], ctx: &mut Context<'_>) -> SearchResult {
73        self.signature.validate(args, ctx)?;
74
75        let text = args[0].as_str().unwrap();
76        let patterns_arr = args[1].as_array().unwrap();
77
78        let patterns: Vec<&str> = patterns_arr.iter().filter_map(|p| p.as_str()).collect();
79
80        if patterns.is_empty() {
81            return Ok(Value::Array(vec![]));
82        }
83
84        let ac = AhoCorasick::new(&patterns).unwrap();
85
86        let mut found = vec![false; patterns.len()];
87
88        for mat in ac.find_iter(text) {
89            found[mat.pattern().as_usize()] = true;
90        }
91
92        let matched: Vec<Value> = patterns
93            .iter()
94            .enumerate()
95            .filter(|(i, _)| found[*i])
96            .map(|(_, p)| Value::String((*p).to_string()))
97            .collect();
98
99        Ok(Value::Array(matched))
100    }
101}
102
103// match_count(string, patterns) -> number
104// Count total number of matches (non-overlapping) across all patterns
105defn!(MatchCountFn, vec![arg!(string), arg!(array)], None);
106
107impl Function for MatchCountFn {
108    fn evaluate(&self, args: &[Value], ctx: &mut Context<'_>) -> SearchResult {
109        self.signature.validate(args, ctx)?;
110
111        let text = args[0].as_str().unwrap();
112        let patterns_arr = args[1].as_array().unwrap();
113
114        let patterns: Vec<&str> = patterns_arr.iter().filter_map(|p| p.as_str()).collect();
115
116        if patterns.is_empty() {
117            return Ok(Value::Number(Number::from(0)));
118        }
119
120        let ac = AhoCorasick::new(&patterns).unwrap();
121        let count = ac.find_iter(text).count();
122
123        Ok(Value::Number(Number::from(count)))
124    }
125}
126
127// replace_many(string, replacements) -> string
128// Replace multiple patterns at once. replacements is an object {pattern: replacement, ...}
129defn!(ReplaceManyFn, vec![arg!(string), arg!(object)], None);
130
131impl Function for ReplaceManyFn {
132    fn evaluate(&self, args: &[Value], ctx: &mut Context<'_>) -> SearchResult {
133        self.signature.validate(args, ctx)?;
134
135        let text = args[0].as_str().unwrap();
136        let replacements_obj = args[1].as_object().unwrap();
137
138        if replacements_obj.is_empty() {
139            return Ok(Value::String(text.to_string()));
140        }
141
142        let mut patterns: Vec<&str> = Vec::new();
143        let mut replacements: Vec<String> = Vec::new();
144
145        for (pattern, replacement) in replacements_obj.iter() {
146            patterns.push(pattern);
147            if let Some(s) = replacement.as_str() {
148                replacements.push(s.to_string());
149            } else {
150                replacements.push(replacement.to_string());
151            }
152        }
153
154        let ac = AhoCorasick::new(&patterns).unwrap();
155        let result = ac.replace_all(text, &replacements);
156
157        Ok(Value::String(result))
158    }
159}
160
161// extract_all(string, patterns) -> array of matches with pattern info
162defn!(ExtractAllFn, vec![arg!(string), arg!(array)], None);
163
164impl Function for ExtractAllFn {
165    fn evaluate(&self, args: &[Value], ctx: &mut Context<'_>) -> SearchResult {
166        self.signature.validate(args, ctx)?;
167
168        let text = args[0].as_str().unwrap();
169        let patterns_arr = args[1].as_array().unwrap();
170
171        let patterns: Vec<&str> = patterns_arr.iter().filter_map(|p| p.as_str()).collect();
172
173        if patterns.is_empty() {
174            return Ok(Value::Array(vec![]));
175        }
176
177        let ac = AhoCorasick::new(&patterns).unwrap();
178        let mut results: Vec<Value> = Vec::new();
179
180        for mat in ac.find_iter(text) {
181            let mut obj = serde_json::Map::new();
182            obj.insert(
183                "pattern".to_string(),
184                Value::String(patterns[mat.pattern().as_usize()].to_string()),
185            );
186            obj.insert(
187                "match".to_string(),
188                Value::String(text[mat.start()..mat.end()].to_string()),
189            );
190            obj.insert(
191                "start".to_string(),
192                Value::Number(Number::from(mat.start())),
193            );
194            obj.insert("end".to_string(), Value::Number(Number::from(mat.end())));
195            results.push(Value::Object(obj));
196        }
197
198        Ok(Value::Array(results))
199    }
200}
201
202// match_positions(string, patterns) -> array of positions
203defn!(MatchPositionsFn, vec![arg!(string), arg!(array)], None);
204
205impl Function for MatchPositionsFn {
206    fn evaluate(&self, args: &[Value], ctx: &mut Context<'_>) -> SearchResult {
207        self.signature.validate(args, ctx)?;
208
209        let text = args[0].as_str().unwrap();
210        let patterns_arr = args[1].as_array().unwrap();
211
212        let patterns: Vec<&str> = patterns_arr.iter().filter_map(|p| p.as_str()).collect();
213
214        if patterns.is_empty() {
215            return Ok(Value::Array(vec![]));
216        }
217
218        let ac = AhoCorasick::new(&patterns).unwrap();
219        let mut results: Vec<Value> = Vec::new();
220
221        for mat in ac.find_iter(text) {
222            let mut obj = serde_json::Map::new();
223            obj.insert(
224                "pattern".to_string(),
225                Value::String(patterns[mat.pattern().as_usize()].to_string()),
226            );
227            obj.insert(
228                "start".to_string(),
229                Value::Number(Number::from(mat.start())),
230            );
231            obj.insert("end".to_string(), Value::Number(Number::from(mat.end())));
232            results.push(Value::Object(obj));
233        }
234
235        Ok(Value::Array(results))
236    }
237}
238
239// mm_tokenize(string, options?) -> array of tokens
240// Smart word tokenization with optional configuration
241defn!(MmTokenizeFn, vec![arg!(string)], Some(arg!(any)));
242
243impl Function for MmTokenizeFn {
244    fn evaluate(&self, args: &[Value], ctx: &mut Context<'_>) -> SearchResult {
245        self.signature.validate(args, ctx)?;
246
247        let text = args[0].as_str().unwrap();
248
249        let lowercase = args
250            .get(1)
251            .and_then(|v| v.as_object())
252            .and_then(|obj| obj.get("lowercase"))
253            .and_then(|v| v.as_bool())
254            .unwrap_or(false);
255
256        let min_length = args
257            .get(1)
258            .and_then(|v| v.as_object())
259            .and_then(|obj| obj.get("min_length"))
260            .and_then(|v| v.as_f64())
261            .map(|n| n as usize)
262            .unwrap_or(1);
263
264        let tokens: Vec<Value> = text
265            .split(|c: char| !c.is_alphanumeric())
266            .filter(|s| !s.is_empty() && s.len() >= min_length)
267            .map(|s| {
268                let token = if lowercase {
269                    s.to_lowercase()
270                } else {
271                    s.to_string()
272                };
273                Value::String(token)
274            })
275            .collect();
276
277        Ok(Value::Array(tokens))
278    }
279}
280
281// extract_between(string, start, end) -> string or null
282defn!(
283    ExtractBetweenFn,
284    vec![arg!(string), arg!(string), arg!(string)],
285    None
286);
287
288impl Function for ExtractBetweenFn {
289    fn evaluate(&self, args: &[Value], ctx: &mut Context<'_>) -> SearchResult {
290        self.signature.validate(args, ctx)?;
291
292        let text = args[0].as_str().unwrap();
293        let start_delim = args[1].as_str().unwrap();
294        let end_delim = args[2].as_str().unwrap();
295
296        if let Some(start_pos) = text.find(start_delim) {
297            let after_start = start_pos + start_delim.len();
298            if let Some(end_pos) = text[after_start..].find(end_delim) {
299                let extracted = &text[after_start..after_start + end_pos];
300                return Ok(Value::String(extracted.to_string()));
301            }
302        }
303
304        Ok(Value::Null)
305    }
306}
307
308// split_keep(string, delimiter) -> array keeping delimiters
309defn!(SplitKeepFn, vec![arg!(string), arg!(string)], None);
310
311impl Function for SplitKeepFn {
312    fn evaluate(&self, args: &[Value], ctx: &mut Context<'_>) -> SearchResult {
313        self.signature.validate(args, ctx)?;
314
315        let text = args[0].as_str().unwrap();
316        let delimiter = args[1].as_str().unwrap();
317
318        if delimiter.is_empty() {
319            return Ok(Value::Array(vec![Value::String(text.to_string())]));
320        }
321
322        let mut result: Vec<Value> = Vec::new();
323        let mut last_end = 0;
324
325        for (start, part) in text.match_indices(delimiter) {
326            if start > last_end {
327                result.push(Value::String(text[last_end..start].to_string()));
328            }
329            result.push(Value::String(part.to_string()));
330            last_end = start + part.len();
331        }
332
333        if last_end < text.len() {
334            result.push(Value::String(text[last_end..].to_string()));
335        }
336
337        Ok(Value::Array(result))
338    }
339}
340
341/// Register multi-match functions filtered by the enabled set.
342pub fn register_filtered(runtime: &mut Runtime, enabled: &HashSet<&str>) {
343    register_if_enabled(runtime, "match_any", enabled, Box::new(MatchAnyFn::new()));
344    register_if_enabled(runtime, "match_all", enabled, Box::new(MatchAllFn::new()));
345    register_if_enabled(
346        runtime,
347        "match_which",
348        enabled,
349        Box::new(MatchWhichFn::new()),
350    );
351    register_if_enabled(
352        runtime,
353        "match_count",
354        enabled,
355        Box::new(MatchCountFn::new()),
356    );
357    register_if_enabled(
358        runtime,
359        "replace_many",
360        enabled,
361        Box::new(ReplaceManyFn::new()),
362    );
363    register_if_enabled(
364        runtime,
365        "extract_all",
366        enabled,
367        Box::new(ExtractAllFn::new()),
368    );
369    register_if_enabled(
370        runtime,
371        "match_positions",
372        enabled,
373        Box::new(MatchPositionsFn::new()),
374    );
375    register_if_enabled(
376        runtime,
377        "mm_tokenize",
378        enabled,
379        Box::new(MmTokenizeFn::new()),
380    );
381    register_if_enabled(
382        runtime,
383        "extract_between",
384        enabled,
385        Box::new(ExtractBetweenFn::new()),
386    );
387    register_if_enabled(runtime, "split_keep", enabled, Box::new(SplitKeepFn::new()));
388}
389
#[cfg(test)]
mod tests {
    use crate::Runtime;
    use serde_json::{Value, json};

    fn setup_runtime() -> Runtime {
        Runtime::builder()
            .with_standard()
            .with_all_extensions()
            .build()
    }

    /// Compiles `expression` with a fresh runtime and evaluates it against
    /// `data`, panicking if either step fails.
    fn eval(expression: &str, data: Value) -> Value {
        let runtime = setup_runtime();
        let compiled = runtime.compile(expression).unwrap();
        compiled.search(&data).unwrap()
    }

    /// Borrows every element of a JSON array as `&str`; panics on a
    /// non-array or a non-string element.
    fn strs(value: &Value) -> Vec<&str> {
        value
            .as_array()
            .unwrap()
            .iter()
            .map(|v| v.as_str().unwrap())
            .collect()
    }

    // ---- match_any ----

    #[test]
    fn test_match_any_found() {
        let result = eval(
            "match_any(@, ['error', 'warning', 'critical'])",
            json!("an error occurred in the system"),
        );
        assert_eq!(result, json!(true));
    }

    #[test]
    fn test_match_any_not_found() {
        let result = eval(
            "match_any(@, ['error', 'warning', 'critical'])",
            json!("everything is fine"),
        );
        assert_eq!(result, json!(false));
    }

    #[test]
    fn test_match_any_empty_patterns() {
        let result = eval(
            "match_any(text, patterns)",
            json!({"text": "some text", "patterns": []}),
        );
        assert_eq!(result, json!(false));
    }

    #[test]
    fn test_match_any_multiple_matches() {
        let result = eval(
            "match_any(@, ['error', 'warning'])",
            json!("error and warning detected"),
        );
        assert_eq!(result, json!(true));
    }

    // ---- match_all ----

    #[test]
    fn test_match_all_all_found() {
        let result = eval(
            "match_all(@, ['error', 'warning'])",
            json!("error and warning detected"),
        );
        assert_eq!(result, json!(true));
    }

    #[test]
    fn test_match_all_some_missing() {
        let result = eval(
            "match_all(@, ['error', 'warning'])",
            json!("error detected"),
        );
        assert_eq!(result, json!(false));
    }

    #[test]
    fn test_match_all_empty_patterns() {
        let result = eval(
            "match_all(text, patterns)",
            json!({"text": "some text", "patterns": []}),
        );
        assert_eq!(result, json!(true));
    }

    // ---- match_which ----

    #[test]
    fn test_match_which_some_found() {
        let result = eval(
            "match_which(@, ['error', 'warning', 'critical'])",
            json!("error and warning detected"),
        );
        let names = strs(&result);
        assert_eq!(names.len(), 2);
        assert!(names.contains(&"error"));
        assert!(names.contains(&"warning"));
    }

    #[test]
    fn test_match_which_none_found() {
        let result = eval(
            "match_which(@, ['error', 'warning'])",
            json!("everything is fine"),
        );
        assert!(result.as_array().unwrap().is_empty());
    }

    #[test]
    fn test_match_which_preserves_order() {
        let result = eval(
            "match_which(@, ['error', 'warning'])",
            json!("warning then error"),
        );
        let order = result.as_array().unwrap();
        // Pattern order, not order of occurrence in the text.
        assert_eq!(order[0].as_str().unwrap(), "error");
        assert_eq!(order[1].as_str().unwrap(), "warning");
    }

    // ---- match_count ----

    #[test]
    fn test_match_count_multiple() {
        let result = eval(
            "match_count(@, ['error', 'warning'])",
            json!("error error warning error"),
        );
        assert_eq!(result.as_f64().unwrap(), 4.0);
    }

    #[test]
    fn test_match_count_none() {
        let result = eval(
            "match_count(@, ['error', 'warning'])",
            json!("everything is fine"),
        );
        assert_eq!(result.as_f64().unwrap(), 0.0);
    }

    #[test]
    fn test_match_count_empty_patterns() {
        let result = eval(
            "match_count(text, patterns)",
            json!({"text": "some text", "patterns": []}),
        );
        assert_eq!(result.as_f64().unwrap(), 0.0);
    }

    // ---- replace_many ----

    #[test]
    fn test_replace_many_basic() {
        let result = eval(
            "replace_many(text, {hello: 'hi', world: 'there'})",
            json!({"text": "hello world"}),
        );
        assert_eq!(result.as_str().unwrap(), "hi there");
    }

    #[test]
    fn test_replace_many_no_matches() {
        let result = eval(
            "replace_many(text, {foo: 'bar'})",
            json!({"text": "hello world"}),
        );
        assert_eq!(result.as_str().unwrap(), "hello world");
    }

    #[test]
    fn test_replace_many_empty_replacements() {
        let result = eval(
            "replace_many(text, replacements)",
            json!({"text": "hello world", "replacements": {}}),
        );
        assert_eq!(result.as_str().unwrap(), "hello world");
    }

    #[test]
    fn test_replace_many_multiple_occurrences() {
        let result = eval(
            "replace_many(text, {error: 'ERROR', connection: 'CONN'})",
            json!({"text": "error: connection error"}),
        );
        assert_eq!(result.as_str().unwrap(), "ERROR: CONN ERROR");
    }

    // ---- extract_all ----

    #[test]
    fn test_extract_all_basic() {
        let result = eval(
            "extract_all(@, ['error', 'warning'])",
            json!("error and warning detected"),
        );
        let matches = result.as_array().unwrap();
        assert_eq!(matches.len(), 2);
        let first = matches[0].as_object().unwrap();
        assert_eq!(first["match"].as_str().unwrap(), "error");
        assert!(first.contains_key("start"));
        assert!(first.contains_key("end"));
    }

    #[test]
    fn test_extract_all_empty() {
        let result = eval(
            "extract_all(@, ['error', 'warning'])",
            json!("no matches here"),
        );
        assert!(result.as_array().unwrap().is_empty());
    }

    // ---- match_positions ----

    #[test]
    fn test_match_positions_basic() {
        let result = eval(
            "match_positions(@, ['quick', 'fox'])",
            json!("The quick brown fox"),
        );
        let positions = result.as_array().unwrap();
        assert_eq!(positions.len(), 2);
        let first = positions[0].as_object().unwrap();
        assert_eq!(first["pattern"].as_str().unwrap(), "quick");
        assert_eq!(first["start"].as_f64().unwrap() as i64, 4);
        assert_eq!(first["end"].as_f64().unwrap() as i64, 9);
    }

    // ---- mm_tokenize ----

    #[test]
    fn test_mm_tokenize_basic() {
        let result = eval("mm_tokenize(@)", json!("Hello, world! This is a test."));
        let tokens = result.as_array().unwrap();
        assert!(tokens.len() >= 6);
        assert_eq!(tokens[0].as_str().unwrap(), "Hello");
    }

    #[test]
    fn test_mm_tokenize_with_options() {
        let result = eval(
            "mm_tokenize(@, {lowercase: `true`, min_length: `2`})",
            json!("Hello, world! A test."),
        );
        // Expected: hello, world, test ("A" is dropped by min_length).
        let tokens = strs(&result);
        assert!(tokens.contains(&"hello"));
        assert!(tokens.contains(&"world"));
        assert!(tokens.iter().all(|t| t.len() >= 2));
    }

    // ---- extract_between ----

    #[test]
    fn test_extract_between_basic() {
        let result = eval(
            "extract_between(@, '<title>', '</title>')",
            json!("<title>Page Title</title>"),
        );
        assert_eq!(result.as_str().unwrap(), "Page Title");
    }

    #[test]
    fn test_extract_between_not_found() {
        let result = eval(
            "extract_between(@, '<start>', '<end>')",
            json!("no delimiters here"),
        );
        assert!(result.is_null());
    }

    // ---- split_keep ----

    #[test]
    fn test_split_keep_basic() {
        let result = eval("split_keep(@, '-')", json!("a-b-c"));
        let parts = result.as_array().unwrap();
        assert_eq!(parts.len(), 5);
        assert_eq!(parts[0].as_str().unwrap(), "a");
        assert_eq!(parts[1].as_str().unwrap(), "-");
        assert_eq!(parts[2].as_str().unwrap(), "b");
    }

    #[test]
    fn test_split_keep_no_delimiter() {
        let result = eval("split_keep(@, '-')", json!("no delimiters"));
        let parts = result.as_array().unwrap();
        assert_eq!(parts.len(), 1);
        assert_eq!(parts[0].as_str().unwrap(), "no delimiters");
    }
}