1use regex::Regex;
2
3use super::{
4 CaptureGroup, CompiledRegex, EngineError, EngineFlags, EngineKind, EngineResult, Match,
5 RegexEngine,
6};
7
/// [`RegexEngine`] implementation backed by the `regex` crate.
///
/// The engine is a stateless unit struct; everything specific to a pattern lives
/// in the compiled regex returned by [`RegexEngine::compile`].
#[derive(Debug, Clone, Copy, Default)]
pub struct RustRegexEngine;
9
10impl RegexEngine for RustRegexEngine {
11 fn kind(&self) -> EngineKind {
12 EngineKind::RustRegex
13 }
14
15 fn compile(&self, pattern: &str, flags: &EngineFlags) -> EngineResult<Box<dyn CompiledRegex>> {
16 let mut flag_prefix = String::new();
17 if flags.case_insensitive {
18 flag_prefix.push('i');
19 }
20 if flags.multi_line {
21 flag_prefix.push('m');
22 }
23 if flags.dot_matches_newline {
24 flag_prefix.push('s');
25 }
26 if flags.unicode {
27 flag_prefix.push('u');
28 }
29 if flags.extended {
30 flag_prefix.push('x');
31 }
32
33 let full_pattern = if flag_prefix.is_empty() {
34 pattern.to_string()
35 } else {
36 format!("(?{flag_prefix}){pattern}")
37 };
38
39 let re = Regex::new(&full_pattern).map_err(|e| EngineError::CompileError(e.to_string()))?;
40
41 Ok(Box::new(RustCompiledRegex { re }))
42 }
43}
44
45struct RustCompiledRegex {
46 re: Regex,
47}
48
49impl CompiledRegex for RustCompiledRegex {
50 fn find_matches(&self, text: &str) -> EngineResult<Vec<Match>> {
51 let mut matches = Vec::new();
52
53 for caps in self.re.captures_iter(text) {
54 let overall = caps.get(0).unwrap();
55 let mut captures = Vec::new();
56
57 for (i, name) in self.re.capture_names().enumerate() {
58 if i == 0 {
59 continue;
60 }
61 if let Some(m) = caps.get(i) {
62 captures.push(CaptureGroup {
63 index: i,
64 name: name.map(String::from),
65 start: m.start(),
66 end: m.end(),
67 text: m.as_str().to_string(),
68 });
69 }
70 }
71
72 matches.push(Match {
73 start: overall.start(),
74 end: overall.end(),
75 text: overall.as_str().to_string(),
76 captures,
77 });
78 }
79
80 Ok(matches)
81 }
82}
83
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_simple_match() {
        let engine = RustRegexEngine;
        let flags = EngineFlags::default();
        let compiled = engine.compile(r"\d+", &flags).unwrap();
        let matches = compiled.find_matches("abc 123 def 456").unwrap();
        assert_eq!(matches.len(), 2);
        assert_eq!(matches[0].text, "123");
        assert_eq!(matches[1].text, "456");
    }

    #[test]
    fn test_capture_groups() {
        let engine = RustRegexEngine;
        let flags = EngineFlags::default();
        let compiled = engine.compile(r"(\w+)@(\w+)\.(\w+)", &flags).unwrap();
        let matches = compiled.find_matches("user@example.com").unwrap();
        assert_eq!(matches.len(), 1);
        assert_eq!(matches[0].captures.len(), 3);
        assert_eq!(matches[0].captures[0].text, "user");
        assert_eq!(matches[0].captures[1].text, "example");
        assert_eq!(matches[0].captures[2].text, "com");
    }

    #[test]
    fn test_named_captures() {
        let engine = RustRegexEngine;
        let flags = EngineFlags::default();
        let compiled = engine
            .compile(r"(?P<name>\w+)@(?P<domain>\w+)", &flags)
            .unwrap();
        let matches = compiled.find_matches("user@example").unwrap();
        assert_eq!(matches[0].captures[0].name, Some("name".to_string()));
        assert_eq!(matches[0].captures[1].name, Some("domain".to_string()));
    }

    #[test]
    fn test_case_insensitive() {
        let engine = RustRegexEngine;
        let flags = EngineFlags {
            case_insensitive: true,
            ..Default::default()
        };
        let compiled = engine.compile(r"hello", &flags).unwrap();
        let matches = compiled.find_matches("Hello HELLO hello").unwrap();
        assert_eq!(matches.len(), 3);
    }

    #[test]
    fn test_invalid_pattern() {
        let engine = RustRegexEngine;
        let flags = EngineFlags::default();
        let result = engine.compile(r"(unclosed", &flags);
        assert!(result.is_err());
    }

    /// `^`/`$` only match at line boundaries when `multi_line` is set.
    #[test]
    fn test_multi_line() {
        let engine = RustRegexEngine;
        let text = "foo\nbar";

        let off = EngineFlags {
            multi_line: false,
            ..Default::default()
        };
        let compiled = engine.compile(r"^\w+$", &off).unwrap();
        assert!(compiled.find_matches(text).unwrap().is_empty());

        let on = EngineFlags {
            multi_line: true,
            ..Default::default()
        };
        let compiled = engine.compile(r"^\w+$", &on).unwrap();
        let matches = compiled.find_matches(text).unwrap();
        assert_eq!(matches.len(), 2);
        assert_eq!(matches[0].text, "foo");
        assert_eq!(matches[1].text, "bar");
    }

    /// `.` only crosses a newline when `dot_matches_newline` is set.
    #[test]
    fn test_dot_matches_newline() {
        let engine = RustRegexEngine;
        let text = "a\nb";

        let off = EngineFlags {
            dot_matches_newline: false,
            ..Default::default()
        };
        let compiled = engine.compile(r"a.b", &off).unwrap();
        assert!(compiled.find_matches(text).unwrap().is_empty());

        let on = EngineFlags {
            dot_matches_newline: true,
            ..Default::default()
        };
        let compiled = engine.compile(r"a.b", &on).unwrap();
        let matches = compiled.find_matches(text).unwrap();
        assert_eq!(matches.len(), 1);
        assert_eq!(matches[0].text, "a\nb");
    }

    /// With `extended`, whitespace and `#` comments in the pattern are ignored.
    #[test]
    fn test_extended() {
        let engine = RustRegexEngine;
        let flags = EngineFlags {
            extended: true,
            ..Default::default()
        };
        let compiled = engine.compile(r"\d+ # one or more digits", &flags).unwrap();
        let matches = compiled.find_matches("ab 42").unwrap();
        assert_eq!(matches.len(), 1);
        assert_eq!(matches[0].text, "42");
    }

    /// Offsets are byte offsets, not char offsets ("é" is two bytes in UTF-8).
    #[test]
    fn test_match_offsets_are_byte_offsets() {
        let engine = RustRegexEngine;
        let flags = EngineFlags::default();
        let compiled = engine.compile(r"\d+", &flags).unwrap();
        let matches = compiled.find_matches("é 42").unwrap();
        assert_eq!(matches.len(), 1);
        assert_eq!(matches[0].start, 3);
        assert_eq!(matches[0].end, 5);
    }

    #[test]
    fn test_no_match_returns_empty() {
        let engine = RustRegexEngine;
        let flags = EngineFlags::default();
        let compiled = engine.compile(r"\d+", &flags).unwrap();
        assert!(compiled.find_matches("abc").unwrap().is_empty());
        assert!(compiled.find_matches("").unwrap().is_empty());
    }

    /// An optional group that did not participate is left out of `captures`;
    /// the remaining group keeps its real group index.
    #[test]
    fn test_non_participating_group_is_omitted() {
        let engine = RustRegexEngine;
        let flags = EngineFlags::default();
        let compiled = engine.compile(r"(a)?(b)", &flags).unwrap();
        let matches = compiled.find_matches("b").unwrap();
        assert_eq!(matches.len(), 1);
        assert_eq!(matches[0].captures.len(), 1);
        assert_eq!(matches[0].captures[0].index, 2);
        assert_eq!(matches[0].captures[0].text, "b");
    }
}