Skip to main content

aft/
pattern_compile.rs

1use regex::bytes::{Regex, RegexBuilder};
2
/// Default value for `CompileOpts::size_limit_bytes`: 10 MiB.
const DEFAULT_SIZE_LIMIT_BYTES: usize = 10 << 20;
4
/// A search pattern produced by [`compile`].
///
/// `compile` yields `Literal` for forced-literal or metacharacter-free patterns
/// (unless case-insensitive matching of a non-ASCII pattern is requested), and
/// `Regex` otherwise.
///
/// Equality is implemented by hand: it compares `raw_pattern` and
/// `case_insensitive` for the `Regex` variant and never looks at `compiled`.
#[derive(Clone, Debug)]
pub enum CompiledPattern {
    /// Plain byte-string search; see [`LiteralSearch`].
    Literal(LiteralSearch),
    /// A compiled byte-oriented regular expression.
    Regex {
        /// The compiled `regex::bytes::Regex`.
        compiled: Regex,
        /// The exact string handed to the regex builder. It can differ from
        /// the caller's input: `compile` may have escaped it and/or prefixed
        /// it with an inline `(?i)` flag.
        raw_pattern: String,
        /// Whether the caller requested case-insensitive matching.
        case_insensitive: bool,
    },
}
14
/// Needle for the literal (non-regex) fast path built by [`compile`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct LiteralSearch {
    /// The pattern bytes. When `case_insensitive_ascii` is set, `compile`
    /// stores them ASCII-lowercased.
    pub needle: Vec<u8>,
    /// Set by `compile` when case-insensitive matching was requested; it only
    /// takes this path for all-ASCII patterns. NOTE(review): the matcher that
    /// consumes this flag is not in this file. Presumably it folds ASCII case
    /// only; confirm at the call site.
    pub case_insensitive_ascii: bool,
}
20
/// Options controlling how [`compile`] interprets a pattern.
#[derive(Clone, Debug)]
pub struct CompileOpts {
    /// Treat the pattern as literal text. Unsupported-syntax detection is
    /// skipped and regex metacharacters are never interpreted.
    pub literal: bool,
    /// Request case-insensitive matching.
    pub case_insensitive: bool,
    /// Forwarded to `RegexBuilder::multi_line` (`^`/`$` match at line
    /// boundaries). Not consulted on the literal fast path.
    pub multi_line: bool,
    /// Maximum pattern length in bytes, checked before anything else. The same
    /// value is forwarded to `RegexBuilder::size_limit`.
    pub size_limit_bytes: usize,
}
28
29impl Default for CompileOpts {
30    fn default() -> Self {
31        Self {
32            literal: false,
33            case_insensitive: false,
34            multi_line: true,
35            size_limit_bytes: DEFAULT_SIZE_LIMIT_BYTES,
36        }
37    }
38}
39
/// Outcome of [`compile`]. Failures are reported as variants, not as `Err`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum CompileResult {
    /// The pattern compiled successfully.
    Ok(CompiledPattern),
    /// Rejected because the pattern exceeded the size limit or the regex
    /// builder returned an error. `message` is human-readable; `pattern` is the
    /// caller's original input.
    InvalidPattern { message: String, pattern: String },
    /// Rejected before compilation because it uses syntax the `regex` crate
    /// does not support. `feature` names the construct (e.g. "lookaround");
    /// `pattern` is the caller's original input.
    UnsupportedSyntax { feature: String, pattern: String },
}
46
47impl PartialEq for CompiledPattern {
48    fn eq(&self, other: &Self) -> bool {
49        match (self, other) {
50            (CompiledPattern::Literal(left), CompiledPattern::Literal(right)) => left == right,
51            (
52                CompiledPattern::Regex {
53                    raw_pattern: left_pattern,
54                    case_insensitive: left_case,
55                    ..
56                },
57                CompiledPattern::Regex {
58                    raw_pattern: right_pattern,
59                    case_insensitive: right_case,
60                    ..
61                },
62            ) => left_pattern == right_pattern && left_case == right_case,
63            _ => false,
64        }
65    }
66}
67
// The hand-written `PartialEq` for `CompiledPattern` compares only
// `LiteralSearch` values (which derive `Eq`) and `String`/`bool` fields. It is
// therefore reflexive, symmetric and transitive, as `Eq` requires.
impl Eq for CompiledPattern {}
69
70impl CompiledPattern {
71    pub fn is_literal(&self) -> bool {
72        matches!(self, CompiledPattern::Literal(_))
73    }
74
75    pub fn case_insensitive(&self) -> bool {
76        match self {
77            CompiledPattern::Literal(literal) => literal.case_insensitive_ascii,
78            CompiledPattern::Regex {
79                case_insensitive, ..
80            } => *case_insensitive,
81        }
82    }
83
84    pub fn raw_pattern_for_trigrams(&self) -> String {
85        match self {
86            CompiledPattern::Literal(literal) => {
87                String::from_utf8_lossy(&literal.needle).into_owned()
88            }
89            CompiledPattern::Regex { raw_pattern, .. } => raw_pattern.clone(),
90        }
91    }
92
93    pub fn ripgrep_pattern(&self) -> String {
94        match self {
95            CompiledPattern::Literal(literal) => {
96                String::from_utf8_lossy(&literal.needle).into_owned()
97            }
98            CompiledPattern::Regex { raw_pattern, .. } => raw_pattern.clone(),
99        }
100    }
101}
102
103pub fn compile(pattern: &str, opts: CompileOpts) -> CompileResult {
104    if pattern.len() > opts.size_limit_bytes {
105        return CompileResult::InvalidPattern {
106            message: format!(
107                "invalid regex: pattern exceeds size limit of {} bytes",
108                opts.size_limit_bytes
109            ),
110            pattern: pattern.to_string(),
111        };
112    }
113
114    if !opts.literal {
115        if let Some(feature) = detect_unsupported_features(pattern) {
116            return CompileResult::UnsupportedSyntax {
117                feature,
118                pattern: pattern.to_string(),
119            };
120        }
121    }
122
123    let has_regex_meta = has_regex_metachar(pattern);
124    let ascii_safe_literal = opts.case_insensitive && pattern.is_ascii();
125    if opts.literal || (!has_regex_meta && (!opts.case_insensitive || ascii_safe_literal)) {
126        if !opts.case_insensitive || pattern.is_ascii() {
127            let needle = if opts.case_insensitive {
128                pattern
129                    .as_bytes()
130                    .iter()
131                    .map(|byte| byte.to_ascii_lowercase())
132                    .collect()
133            } else {
134                pattern.as_bytes().to_vec()
135            };
136            return CompileResult::Ok(CompiledPattern::Literal(LiteralSearch {
137                needle,
138                case_insensitive_ascii: opts.case_insensitive,
139            }));
140        }
141    }
142
143    let mut regex_pattern = if opts.literal || !has_regex_meta {
144        regex::escape(pattern)
145    } else {
146        pattern.to_string()
147    };
148    let mut builder_case_insensitive = opts.case_insensitive;
149    if opts.case_insensitive && !pattern.is_ascii() {
150        regex_pattern = format!("(?i){regex_pattern}");
151        builder_case_insensitive = false;
152    }
153
154    let mut builder = RegexBuilder::new(&regex_pattern);
155    builder.case_insensitive(builder_case_insensitive);
156    builder.multi_line(opts.multi_line);
157    builder.size_limit(opts.size_limit_bytes);
158
159    match builder.build() {
160        Ok(compiled) => CompileResult::Ok(CompiledPattern::Regex {
161            compiled,
162            raw_pattern: regex_pattern,
163            case_insensitive: opts.case_insensitive,
164        }),
165        Err(error) => CompileResult::InvalidPattern {
166            message: format!("invalid regex: {error}"),
167            pattern: pattern.to_string(),
168        },
169    }
170}
171
/// Reports PCRE-style syntax that the `regex` crate cannot compile.
///
/// Returns the name of the unsupported construct, or `None` if none is found.
/// When several are present, the result follows a fixed priority:
/// `"lookaround"`, `"backreference"`, `"possessive quantifier"`,
/// `"atomic group"`.
///
/// The scan is syntax-aware. A character preceded by `\` is a literal, and so
/// is anything inside a bracketed class (nested classes and a leading `]` are
/// handled). This means valid patterns are no longer rejected just because a
/// substring looks like an operator, for example:
/// - `\++` (one or more `+`)
/// - `[*+]`
/// - `\(?=`
///
/// Numeric backreferences (`\1`..`\9`) are reported wherever they occur.
pub fn detect_unsupported_features(pattern: &str) -> Option<String> {
    // Every construct of interest is ASCII, and UTF-8 continuation bytes never
    // equal an ASCII byte, so a byte-level scan is safe.
    let bytes = pattern.as_bytes();
    let mut lookaround = false;
    let mut backreference = false;
    let mut possessive = false;
    let mut atomic = false;
    // Depth of nested bracketed classes (`regex` accepts `[a[bc]]`).
    let mut class_depth = 0usize;

    let mut i = 0;
    while i < bytes.len() {
        match bytes[i] {
            b'\\' => {
                if matches!(bytes.get(i + 1), Some(b'1'..=b'9')) {
                    backreference = true;
                }
                // Skip the escaped character so it is never read as syntax.
                i += 2;
                continue;
            }
            b'[' => {
                class_depth += 1;
                i += 1;
                // A `]` directly after `[` or `[^` is a literal class member.
                if bytes.get(i) == Some(&b'^') {
                    i += 1;
                }
                if bytes.get(i) == Some(&b']') {
                    i += 1;
                }
                continue;
            }
            b']' if class_depth > 0 => class_depth -= 1,
            // Inside a class every other character is a literal.
            _ if class_depth > 0 => {}
            b'(' => {
                let rest = &bytes[i + 1..];
                if rest.starts_with(b"?=")
                    || rest.starts_with(b"?!")
                    || rest.starts_with(b"?<=")
                    || rest.starts_with(b"?<!")
                {
                    lookaround = true;
                } else if rest.starts_with(b"?P=") {
                    backreference = true;
                } else if rest.starts_with(b"?>") {
                    atomic = true;
                }
            }
            // An unescaped quantifier immediately followed by `+`.
            b'*' | b'+' | b'?' if bytes.get(i + 1) == Some(&b'+') => possessive = true,
            _ => {}
        }
        i += 1;
    }

    let feature = if lookaround {
        "lookaround"
    } else if backreference {
        "backreference"
    } else if possessive {
        "possessive quantifier"
    } else if atomic {
        "atomic group"
    } else {
        return None;
    };
    Some(feature.to_string())
}

/// Returns `true` if `pattern` contains any of the regex metacharacters
/// `. * + ? ( ) [ ] { } | ^ $ \`.
fn has_regex_metachar(pattern: &str) -> bool {
    pattern.chars().any(|c| {
        matches!(
            c,
            '.' | '*' | '+' | '?' | '(' | ')' | '[' | ']' | '{' | '}' | '|' | '^' | '$' | '\\'
        )
    })
}
215
#[cfg(test)]
mod tests {
    use super::*;

    /// Compiles `pattern` with default options except for `case_insensitive`.
    fn compile_with_case(pattern: &str, case_insensitive: bool) -> CompileResult {
        compile(
            pattern,
            CompileOpts {
                case_insensitive,
                ..CompileOpts::default()
            },
        )
    }

    /// Asserts that `pattern` takes the literal fast path with needle `expected`.
    fn assert_literal(pattern: &str, case_insensitive: bool, expected: &[u8]) {
        match compile_with_case(pattern, case_insensitive) {
            CompileResult::Ok(CompiledPattern::Literal(LiteralSearch {
                needle,
                case_insensitive_ascii,
            })) => {
                assert_eq!(needle, expected);
                assert_eq!(case_insensitive_ascii, case_insensitive);
            }
            other => panic!("expected literal, got {other:?}"),
        }
    }

    #[test]
    fn literal_pattern_without_metachars_uses_fast_path() {
        assert_literal("needle", false, b"needle");
    }

    #[test]
    fn ascii_case_insensitive_literal_uses_lowercase_fast_path() {
        assert_literal("Needle", true, b"needle");
    }

    #[test]
    fn non_ascii_case_insensitive_literal_forces_regex_with_inline_flag() {
        match compile_with_case("Äbc", true) {
            CompileResult::Ok(CompiledPattern::Regex {
                raw_pattern,
                case_insensitive,
                ..
            }) => {
                assert!(raw_pattern.starts_with("(?i)"));
                assert!(case_insensitive);
            }
            other => panic!("expected regex, got {other:?}"),
        }
    }

    #[test]
    fn regex_pattern_retains_raw_pattern_and_compiles_bytes_regex() {
        match compile("foo.*bar", CompileOpts::default()) {
            CompileResult::Ok(CompiledPattern::Regex {
                compiled,
                raw_pattern,
                ..
            }) => {
                assert_eq!(raw_pattern, "foo.*bar");
                assert!(compiled.is_match(b"foo middle bar"));
            }
            other => panic!("expected regex, got {other:?}"),
        }
    }

    #[test]
    fn invalid_pattern_surfaces_compile_error() {
        let outcome = compile("[", CompileOpts::default());
        assert!(matches!(outcome, CompileResult::InvalidPattern { .. }));
    }

    #[test]
    fn pattern_exceeding_size_limit_is_invalid() {
        let tiny_limit = CompileOpts {
            size_limit_bytes: 3,
            ..CompileOpts::default()
        };
        assert!(matches!(
            compile("abcd", tiny_limit),
            CompileResult::InvalidPattern { .. }
        ));
    }

    #[test]
    fn unsupported_syntax_is_detected_before_compile() {
        let unsupported = [
            "(?=foo)",
            "(?!foo)",
            "(?<=foo)",
            "(?<!foo)",
            "(?P=name)",
            r"\1",
            "foo*+",
            "(?>foo)",
        ];
        for pattern in unsupported {
            let outcome = compile(pattern, CompileOpts::default());
            assert!(
                matches!(outcome, CompileResult::UnsupportedSyntax { .. }),
                "{pattern}"
            );
        }
    }

    #[test]
    fn forced_literal_honors_regex_characters() {
        let forced = CompileOpts {
            literal: true,
            ..CompileOpts::default()
        };
        match compile("foo.*bar", forced) {
            CompileResult::Ok(CompiledPattern::Literal(literal)) => {
                assert_eq!(literal.needle, b"foo.*bar");
            }
            other => panic!("expected literal, got {other:?}"),
        }
    }
}