Skip to main content

pipa/regexp/
pattern_matcher.rs

1use super::ast::Ast;
2use super::engine::Match;
3use crate::util::memchr::memchr;
4
5#[derive(Debug, Clone, PartialEq)]
6pub enum PatternFingerprint {
7    Unknown,
8
9    Literal(String),
10
11    OneOrMoreCharClass(CharClassType),
12
13    StartAnchored(Box<PatternFingerprint>),
14
15    EndAnchored(Box<PatternFingerprint>),
16
17    EmailLike {
18        local_part: CharClassType,
19        domain_part: CharClassType,
20        tld_part: CharClassType,
21    },
22
23    UrlLike,
24
25    DateLike,
26}
27
/// ASCII character sets that the fast matcher can test byte-by-byte.
///
/// `Eq` and `Hash` are derived alongside `PartialEq` so the type can be used as a set or map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CharClassType {
    /// `0-9`.
    Digits,
    /// `0-9`, `A-Z`, `a-z` and `_`.
    Words,
    /// ASCII whitespace; see `is_valid_for_class` for the exact bytes accepted.
    Spaces,
    /// `a-z`.
    Lowercase,
    /// `A-Z`.
    Uppercase,
    /// `A-Z` and `a-z`.
    Alpha,
    /// `0-9`, `A-Z` and `a-z`.
    Alnum,
    /// Every byte.
    Any,
}
39
40pub fn extract_fingerprint(ast: &Ast) -> PatternFingerprint {
41    match ast {
42        Ast::Empty => PatternFingerprint::Literal(String::new()),
43        Ast::Char(c) => PatternFingerprint::Literal(c.to_string()),
44        Ast::Concat(nodes) => extract_concat_fingerprint(nodes),
45        Ast::Quant(inner, q) if q.min >= 1 => {
46            if let Some(class_type) = extract_char_class_type(inner) {
47                PatternFingerprint::OneOrMoreCharClass(class_type)
48            } else {
49                PatternFingerprint::Unknown
50            }
51        }
52        Ast::StartOfLine => PatternFingerprint::Unknown,
53        Ast::EndOfLine => PatternFingerprint::Unknown,
54        _ => PatternFingerprint::Unknown,
55    }
56}
57
58fn extract_concat_fingerprint(nodes: &[Ast]) -> PatternFingerprint {
59    if nodes.is_empty() {
60        return PatternFingerprint::Literal(String::new());
61    }
62
63    if matches!(nodes.first(), Some(Ast::StartOfLine)) && nodes.len() > 1 {
64        let inner_fp = extract_fingerprint(&Ast::Concat(nodes[1..].to_vec()));
65        return PatternFingerprint::StartAnchored(Box::new(inner_fp));
66    }
67
68    if matches!(nodes.last(), Some(Ast::EndOfLine)) && nodes.len() > 1 {
69        let inner_fp = extract_fingerprint(&Ast::Concat(nodes[..nodes.len() - 1].to_vec()));
70        return PatternFingerprint::EndAnchored(Box::new(inner_fp));
71    }
72
73    if let Some(email_fp) = try_extract_email_fingerprint(nodes) {
74        return email_fp;
75    }
76
77    let mut literal = String::new();
78    for node in nodes {
79        match node {
80            Ast::Char(c) => literal.push(*c),
81            _ => return PatternFingerprint::Unknown,
82        }
83    }
84    PatternFingerprint::Literal(literal)
85}
86
87fn try_extract_email_fingerprint(nodes: &[Ast]) -> Option<PatternFingerprint> {
88    if nodes.len() < 5 {
89        return None;
90    }
91
92    for at_pos in 0..nodes.len() {
93        for dot_pos in (at_pos + 1)..nodes.len() {
94            let has_at = matches!(&nodes[at_pos], Ast::Char('@'));
95
96            let has_dot = matches!(&nodes[dot_pos], Ast::Char('.'));
97
98            if !has_at || !has_dot {
99                continue;
100            }
101
102            let local_part = if at_pos == 1 {
103                match &nodes[0] {
104                    Ast::Quant(inner, q) if q.min >= 1 => extract_char_class_type(inner),
105                    _ => None,
106                }
107            } else {
108                None
109            }?;
110
111            let domain_part = if dot_pos == at_pos + 2 {
112                match &nodes[at_pos + 1] {
113                    Ast::Quant(inner, q) if q.min >= 1 => extract_char_class_type(inner),
114                    _ => None,
115                }
116            } else {
117                None
118            }?;
119
120            let tld_part = if dot_pos + 1 < nodes.len() {
121                match &nodes[dot_pos + 1] {
122                    Ast::Quant(inner, q) if q.min >= 1 => extract_char_class_type(inner),
123                    _ => None,
124                }
125            } else {
126                None
127            }?;
128
129            return Some(PatternFingerprint::EmailLike {
130                local_part,
131                domain_part,
132                tld_part,
133            });
134        }
135    }
136
137    None
138}
139
140fn extract_char_class_type(ast: &Ast) -> Option<CharClassType> {
141    match ast {
142        Ast::Class(class) => {
143            if class.negated {
144                return None;
145            }
146
147            let ranges = class.ranges.as_slice();
148            match ranges {
149                [(48, 58)] => Some(CharClassType::Digits),
150                [(97, 123)] => Some(CharClassType::Lowercase),
151                [(65, 91)] => Some(CharClassType::Uppercase),
152                [(97, 123), (65, 91)] => Some(CharClassType::Alpha),
153                [(97, 123), (65, 91), (48, 58)] => Some(CharClassType::Alnum),
154
155                [(48, 58), (65, 91), (95, 96), (97, 123)] => Some(CharClassType::Words),
156
157                [(9, 14), (32, 33)] => Some(CharClassType::Spaces),
158
159                [(48, 58), (97, 123)] => Some(CharClassType::Alnum),
160
161                [(48, 58), (65, 91), (97, 123)] => Some(CharClassType::Alnum),
162                _ => None,
163            }
164        }
165        _ => None,
166    }
167}
168
169pub struct FastPatternMatcher {
170    fingerprint: PatternFingerprint,
171}
172
173impl FastPatternMatcher {
174    pub fn new(fingerprint: PatternFingerprint) -> Option<Self> {
175        match &fingerprint {
176            PatternFingerprint::Unknown => None,
177            _ => Some(Self { fingerprint }),
178        }
179    }
180
181    pub fn find(&self, input: &str) -> Option<Match> {
182        match &self.fingerprint {
183            PatternFingerprint::EmailLike {
184                local_part,
185                domain_part,
186                tld_part,
187            } => self.find_email_like(input, *local_part, *domain_part, *tld_part),
188            PatternFingerprint::StartAnchored(_inner) => None,
189            PatternFingerprint::EndAnchored(_inner) => None,
190            _ => None,
191        }
192    }
193
194    fn find_email_like(
195        &self,
196        input: &str,
197        _local_class: CharClassType,
198        _domain_class: CharClassType,
199        _tld_class: CharClassType,
200    ) -> Option<Match> {
201        let bytes = input.as_bytes();
202        let len = bytes.len();
203
204        let mut search_start = 0;
205        while search_start < len {
206            match memchr(b'@', &bytes[search_start..]) {
207                Some(rel_at) => {
208                    let at_abs = search_start + rel_at;
209
210                    if at_abs < 1 || at_abs + 3 >= len {
211                        search_start = at_abs + 1;
212                        continue;
213                    }
214
215                    if !is_valid_for_class(bytes[at_abs - 1], _local_class) {
216                        search_start = at_abs + 1;
217                        continue;
218                    }
219
220                    let mut found = false;
221                    let mut dot_abs = at_abs + 2;
222                    while dot_abs < len {
223                        if bytes[dot_abs] == b'.' {
224                            let mut valid = true;
225                            for i in at_abs + 1..dot_abs {
226                                if !is_valid_for_class(bytes[i], _domain_class) {
227                                    valid = false;
228                                    break;
229                                }
230                            }
231                            if valid {
232                                found = true;
233                                break;
234                            }
235                        }
236                        dot_abs += 1;
237                    }
238
239                    if !found {
240                        search_start = at_abs + 1;
241                        continue;
242                    }
243
244                    if dot_abs + 1 >= len || !is_valid_for_class(bytes[dot_abs + 1], _tld_class) {
245                        search_start = at_abs + 1;
246                        continue;
247                    }
248
249                    let mut start = at_abs;
250                    while start > 0 && is_valid_for_class(bytes[start - 1], _local_class) {
251                        start -= 1;
252                    }
253
254                    let mut end = dot_abs + 2;
255                    while end < len && is_valid_for_class(bytes[end], _tld_class) {
256                        end += 1;
257                    }
258
259                    return Some(Match {
260                        start,
261                        end,
262                        captures: vec![(Some(start), Some(end))],
263                    });
264                }
265                None => break,
266            }
267        }
268
269        None
270    }
271}
272
273fn is_valid_for_class(byte: u8, class: CharClassType) -> bool {
274    match class {
275        CharClassType::Digits => byte.is_ascii_digit(),
276        CharClassType::Words => byte.is_ascii_alphanumeric() || byte == b'_',
277        CharClassType::Spaces => matches!(byte, b' ' | b'\t' | b'\n' | b'\r'),
278        CharClassType::Lowercase => byte.is_ascii_lowercase(),
279        CharClassType::Uppercase => byte.is_ascii_uppercase(),
280        CharClassType::Alpha => byte.is_ascii_alphabetic(),
281        CharClassType::Alnum => byte.is_ascii_alphanumeric(),
282        CharClassType::Any => true,
283    }
284}