1use regex::Regex;
2
3use super::{
4 CaptureGroup, CompiledRegex, EngineError, EngineFlags, EngineKind, EngineResult, Match,
5 RegexEngine,
6};
7
/// Regex engine backend built on the `regex` crate.
///
/// The type carries no state, so it is cheap to copy and can be created with
/// either the unit literal `RustRegexEngine` or `RustRegexEngine::default()`.
#[derive(Debug, Clone, Copy, Default)]
pub struct RustRegexEngine;
9
10impl RegexEngine for RustRegexEngine {
11 fn kind(&self) -> EngineKind {
12 EngineKind::RustRegex
13 }
14
15 fn compile(&self, pattern: &str, flags: &EngineFlags) -> EngineResult<Box<dyn CompiledRegex>> {
16 let full_pattern = flags.wrap_pattern(pattern);
17 let re = Regex::new(&full_pattern).map_err(|e| EngineError::CompileError(e.to_string()))?;
18
19 Ok(Box::new(RustCompiledRegex { re }))
20 }
21}
22
23struct RustCompiledRegex {
24 re: Regex,
25}
26
27impl CompiledRegex for RustCompiledRegex {
28 fn find_matches(&self, text: &str) -> EngineResult<Vec<Match>> {
29 let mut matches = Vec::new();
30
31 for caps in self.re.captures_iter(text) {
32 let overall = caps.get(0).unwrap();
33 let mut captures = Vec::new();
34
35 for (i, name) in self.re.capture_names().enumerate() {
36 if i == 0 {
37 continue;
38 }
39 if let Some(m) = caps.get(i) {
40 captures.push(CaptureGroup {
41 index: i,
42 name: name.map(String::from),
43 start: m.start(),
44 end: m.end(),
45 text: m.as_str().to_string(),
46 });
47 }
48 }
49
50 matches.push(Match {
51 start: overall.start(),
52 end: overall.end(),
53 text: overall.as_str().to_string(),
54 captures,
55 });
56 }
57
58 Ok(matches)
59 }
60}
61
#[cfg(test)]
mod tests {
    use super::*;

    /// Compiles `pattern` with default flags, panicking if compilation fails.
    fn compile_with_defaults(pattern: &str) -> Box<dyn CompiledRegex> {
        RustRegexEngine
            .compile(pattern, &EngineFlags::default())
            .unwrap()
    }

    /// Two separate digit runs yield two matches, in order.
    #[test]
    fn test_simple_match() {
        let found = compile_with_defaults(r"\d+")
            .find_matches("abc 123 def 456")
            .unwrap();
        let texts: Vec<&str> = found.iter().map(|m| m.text.as_str()).collect();
        assert_eq!(texts, ["123", "456"]);
    }

    /// Positional groups are reported in pattern order with their text.
    #[test]
    fn test_capture_groups() {
        let found = compile_with_defaults(r"(\w+)@(\w+)\.(\w+)")
            .find_matches("user@example.com")
            .unwrap();
        assert_eq!(found.len(), 1);

        let group_texts: Vec<&str> = found[0]
            .captures
            .iter()
            .map(|group| group.text.as_str())
            .collect();
        assert_eq!(group_texts, ["user", "example", "com"]);
    }

    /// Named groups carry their names through to the capture records.
    #[test]
    fn test_named_captures() {
        let found = compile_with_defaults(r"(?P<name>\w+)@(?P<domain>\w+)")
            .find_matches("user@example")
            .unwrap();
        let groups = &found[0].captures;
        assert_eq!(groups[0].name.as_deref(), Some("name"));
        assert_eq!(groups[1].name.as_deref(), Some("domain"));
    }

    /// With `case_insensitive` set, all three casings of the word match.
    #[test]
    fn test_case_insensitive() {
        let flags = EngineFlags {
            case_insensitive: true,
            ..EngineFlags::default()
        };
        let found = RustRegexEngine
            .compile(r"hello", &flags)
            .unwrap()
            .find_matches("Hello HELLO hello")
            .unwrap();
        assert_eq!(found.len(), 3);
    }

    /// An unbalanced group is rejected at compile time.
    #[test]
    fn test_invalid_pattern() {
        let outcome = RustRegexEngine.compile(r"(unclosed", &EngineFlags::default());
        assert!(outcome.is_err());
    }
}