Skip to main content

rgx/engine/
rust_regex.rs

1use regex::Regex;
2
3use super::{
4    CaptureGroup, CompiledRegex, EngineError, EngineFlags, EngineKind, EngineResult, Match,
5    RegexEngine,
6};
7
8pub struct RustRegexEngine;
9
10impl RegexEngine for RustRegexEngine {
11    fn kind(&self) -> EngineKind {
12        EngineKind::RustRegex
13    }
14
15    fn compile(&self, pattern: &str, flags: &EngineFlags) -> EngineResult<Box<dyn CompiledRegex>> {
16        let full_pattern = flags.wrap_pattern(pattern);
17        let re = Regex::new(&full_pattern).map_err(|e| EngineError::CompileError(e.to_string()))?;
18
19        Ok(Box::new(RustCompiledRegex { re }))
20    }
21}
22
23struct RustCompiledRegex {
24    re: Regex,
25}
26
27impl CompiledRegex for RustCompiledRegex {
28    fn find_matches(&self, text: &str) -> EngineResult<Vec<Match>> {
29        let mut matches = Vec::new();
30
31        for caps in self.re.captures_iter(text) {
32            let overall = caps.get(0).unwrap();
33            let mut captures = Vec::new();
34
35            for (i, name) in self.re.capture_names().enumerate() {
36                if i == 0 {
37                    continue;
38                }
39                if let Some(m) = caps.get(i) {
40                    captures.push(CaptureGroup {
41                        index: i,
42                        name: name.map(String::from),
43                        start: m.start(),
44                        end: m.end(),
45                        text: m.as_str().to_string(),
46                    });
47                }
48            }
49
50            matches.push(Match {
51                start: overall.start(),
52                end: overall.end(),
53                text: overall.as_str().to_string(),
54                captures,
55            });
56        }
57
58        Ok(matches)
59    }
60}
61
#[cfg(test)]
mod tests {
    use super::*;

    /// Compiles `pattern` with `flags` and returns every match in `haystack`,
    /// panicking if either step fails.
    fn run(pattern: &str, flags: &EngineFlags, haystack: &str) -> Vec<Match> {
        let compiled = RustRegexEngine.compile(pattern, flags).unwrap();
        compiled.find_matches(haystack).unwrap()
    }

    // Plain pattern: each run of digits is reported, in order.
    #[test]
    fn test_simple_match() {
        let found = run(r"\d+", &EngineFlags::default(), "abc 123 def 456");
        let texts: Vec<&str> = found.iter().map(|m| m.text.as_str()).collect();
        assert_eq!(texts, ["123", "456"]);
    }

    // Unnamed groups come back in pattern order with their matched text.
    #[test]
    fn test_capture_groups() {
        let found = run(
            r"(\w+)@(\w+)\.(\w+)",
            &EngineFlags::default(),
            "user@example.com",
        );
        assert_eq!(found.len(), 1);

        let groups: Vec<&str> = found[0].captures.iter().map(|c| c.text.as_str()).collect();
        assert_eq!(groups, ["user", "example", "com"]);
    }

    // Named groups carry their names through to the result.
    #[test]
    fn test_named_captures() {
        let found = run(
            r"(?P<name>\w+)@(?P<domain>\w+)",
            &EngineFlags::default(),
            "user@example",
        );
        let groups = &found[0].captures;
        assert_eq!(groups[0].name.as_deref(), Some("name"));
        assert_eq!(groups[1].name.as_deref(), Some("domain"));
    }

    // With `case_insensitive` set, all three spellings match.
    #[test]
    fn test_case_insensitive() {
        let flags = EngineFlags {
            case_insensitive: true,
            ..Default::default()
        };
        let found = run(r"hello", &flags, "Hello HELLO hello");
        assert_eq!(found.len(), 3);
    }

    // A malformed pattern is reported as an error rather than a panic.
    #[test]
    fn test_invalid_pattern() {
        let outcome = RustRegexEngine.compile(r"(unclosed", &EngineFlags::default());
        assert!(outcome.is_err());
    }
}