Skip to main content

rgx/engine/
rust_regex.rs

1use regex::Regex;
2
3use super::{
4    CaptureGroup, CompiledRegex, EngineError, EngineFlags, EngineKind, EngineResult, Match,
5    RegexEngine,
6};
7
8pub struct RustRegexEngine;
9
10impl RegexEngine for RustRegexEngine {
11    fn kind(&self) -> EngineKind {
12        EngineKind::RustRegex
13    }
14
15    fn compile(&self, pattern: &str, flags: &EngineFlags) -> EngineResult<Box<dyn CompiledRegex>> {
16        let mut flag_prefix = String::new();
17        if flags.case_insensitive {
18            flag_prefix.push('i');
19        }
20        if flags.multi_line {
21            flag_prefix.push('m');
22        }
23        if flags.dot_matches_newline {
24            flag_prefix.push('s');
25        }
26        if flags.unicode {
27            flag_prefix.push('u');
28        }
29        if flags.extended {
30            flag_prefix.push('x');
31        }
32
33        let full_pattern = if flag_prefix.is_empty() {
34            pattern.to_string()
35        } else {
36            format!("(?{flag_prefix}){pattern}")
37        };
38
39        let re = Regex::new(&full_pattern).map_err(|e| EngineError::CompileError(e.to_string()))?;
40
41        Ok(Box::new(RustCompiledRegex { re }))
42    }
43}
44
45struct RustCompiledRegex {
46    re: Regex,
47}
48
49impl CompiledRegex for RustCompiledRegex {
50    fn find_matches(&self, text: &str) -> EngineResult<Vec<Match>> {
51        let mut matches = Vec::new();
52
53        for caps in self.re.captures_iter(text) {
54            let overall = caps.get(0).unwrap();
55            let mut captures = Vec::new();
56
57            for (i, name) in self.re.capture_names().enumerate() {
58                if i == 0 {
59                    continue;
60                }
61                if let Some(m) = caps.get(i) {
62                    captures.push(CaptureGroup {
63                        index: i,
64                        name: name.map(String::from),
65                        start: m.start(),
66                        end: m.end(),
67                        text: m.as_str().to_string(),
68                    });
69                }
70            }
71
72            matches.push(Match {
73                start: overall.start(),
74                end: overall.end(),
75                text: overall.as_str().to_string(),
76                captures,
77            });
78        }
79
80        Ok(matches)
81    }
82}
83
#[cfg(test)]
mod tests {
    use super::*;

    /// Compiles `pattern` with `flags` on a fresh engine, panicking on failure.
    fn compile_with(pattern: &str, flags: &EngineFlags) -> Box<dyn CompiledRegex> {
        RustRegexEngine.compile(pattern, flags).unwrap()
    }

    /// Compiles `pattern` with the default flags, panicking on failure.
    fn compile_default(pattern: &str) -> Box<dyn CompiledRegex> {
        compile_with(pattern, &EngineFlags::default())
    }

    // Plain digit runs are found left to right.
    #[test]
    fn test_simple_match() {
        let found = compile_default(r"\d+")
            .find_matches("abc 123 def 456")
            .unwrap();
        let texts: Vec<&str> = found.iter().map(|m| m.text.as_str()).collect();
        assert_eq!(texts, ["123", "456"]);
    }

    // Numbered groups are reported in pattern order.
    #[test]
    fn test_capture_groups() {
        let found = compile_default(r"(\w+)@(\w+)\.(\w+)")
            .find_matches("user@example.com")
            .unwrap();
        assert_eq!(found.len(), 1);

        let groups: Vec<&str> = found[0].captures.iter().map(|g| g.text.as_str()).collect();
        assert_eq!(groups, ["user", "example", "com"]);
    }

    // Named groups carry their names through to the capture records.
    #[test]
    fn test_named_captures() {
        let found = compile_default(r"(?P<name>\w+)@(?P<domain>\w+)")
            .find_matches("user@example")
            .unwrap();
        let groups = &found[0].captures;
        assert_eq!(groups[0].name.as_deref(), Some("name"));
        assert_eq!(groups[1].name.as_deref(), Some("domain"));
    }

    // The case-insensitive flag lets one pattern match every casing.
    #[test]
    fn test_case_insensitive() {
        let flags = EngineFlags {
            case_insensitive: true,
            ..Default::default()
        };
        let found = compile_with(r"hello", &flags)
            .find_matches("Hello HELLO hello")
            .unwrap();
        assert_eq!(found.len(), 3);
    }

    // A malformed pattern surfaces as an error rather than a panic.
    #[test]
    fn test_invalid_pattern() {
        let outcome = RustRegexEngine.compile(r"(unclosed", &EngineFlags::default());
        assert!(outcome.is_err());
    }
}