1use regex::bytes::{Regex, RegexBuilder};
2
/// Default value of `CompileOpts::size_limit_bytes` (10 MiB).
///
/// `compile` uses that option both as the maximum accepted pattern length in
/// bytes and as the limit handed to `RegexBuilder::size_limit`.
const DEFAULT_SIZE_LIMIT_BYTES: usize = 10 * 1024 * 1024;
4
/// A search pattern in the form produced by `compile`.
#[derive(Clone, Debug)]
pub enum CompiledPattern {
    /// A literal byte string to search for.
    Literal(LiteralSearch),
    /// A pattern compiled with the byte-oriented `regex` engine.
    Regex {
        /// The compiled matcher.
        compiled: Regex,
        /// The exact text handed to the regex builder, i.e. after any
        /// escaping or `(?i)` prefixing performed by `compile`.
        raw_pattern: String,
        /// Whether case-insensitive matching was requested at compile time.
        case_insensitive: bool,
    },
}
14
/// A literal needle together with its case-handling flag.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct LiteralSearch {
    /// Bytes to search for. `compile` stores them ASCII-lowercased when
    /// `case_insensitive_ascii` is set.
    pub needle: Vec<u8>,
    /// Copied from `CompileOpts::case_insensitive` by `compile`.
    /// NOTE(review): presumably tells the matcher to ignore ASCII case; the
    /// matcher itself is not in this file, confirm against the caller.
    pub case_insensitive_ascii: bool,
}
20
/// Options controlling how `compile` interprets a pattern.
#[derive(Clone, Debug)]
pub struct CompileOpts {
    /// Treat the pattern as plain text: the unsupported-syntax check is
    /// skipped and the text is escaped if a regex has to be built.
    pub literal: bool,
    /// Request case-insensitive matching.
    pub case_insensitive: bool,
    /// Forwarded to `RegexBuilder::multi_line` when a regex is built.
    pub multi_line: bool,
    /// Maximum accepted pattern length in bytes; also forwarded to
    /// `RegexBuilder::size_limit`.
    pub size_limit_bytes: usize,
}
28
29impl Default for CompileOpts {
30 fn default() -> Self {
31 Self {
32 literal: false,
33 case_insensitive: false,
34 multi_line: true,
35 size_limit_bytes: DEFAULT_SIZE_LIMIT_BYTES,
36 }
37 }
38}
39
/// Outcome of `compile`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum CompileResult {
    /// The pattern compiled successfully.
    Ok(CompiledPattern),
    /// The pattern was rejected: it exceeded the size limit or the regex
    /// builder reported an error. `message` starts with `invalid regex: `.
    InvalidPattern { message: String, pattern: String },
    /// The pattern uses a construct named by `detect_unsupported_features`
    /// (for example `lookaround` or `backreference`).
    UnsupportedSyntax { feature: String, pattern: String },
}
46
47impl PartialEq for CompiledPattern {
48 fn eq(&self, other: &Self) -> bool {
49 match (self, other) {
50 (CompiledPattern::Literal(left), CompiledPattern::Literal(right)) => left == right,
51 (
52 CompiledPattern::Regex {
53 raw_pattern: left_pattern,
54 case_insensitive: left_case,
55 ..
56 },
57 CompiledPattern::Regex {
58 raw_pattern: right_pattern,
59 case_insensitive: right_case,
60 ..
61 },
62 ) => left_pattern == right_pattern && left_case == right_case,
63 _ => false,
64 }
65 }
66}
67
68impl Eq for CompiledPattern {}
69
70impl CompiledPattern {
71 pub fn is_literal(&self) -> bool {
72 matches!(self, CompiledPattern::Literal(_))
73 }
74
75 pub fn case_insensitive(&self) -> bool {
76 match self {
77 CompiledPattern::Literal(literal) => literal.case_insensitive_ascii,
78 CompiledPattern::Regex {
79 case_insensitive, ..
80 } => *case_insensitive,
81 }
82 }
83
84 pub fn raw_pattern_for_trigrams(&self) -> String {
85 match self {
86 CompiledPattern::Literal(literal) => {
87 String::from_utf8_lossy(&literal.needle).into_owned()
88 }
89 CompiledPattern::Regex { raw_pattern, .. } => raw_pattern.clone(),
90 }
91 }
92
93 pub fn ripgrep_pattern(&self) -> String {
94 match self {
95 CompiledPattern::Literal(literal) => {
96 String::from_utf8_lossy(&literal.needle).into_owned()
97 }
98 CompiledPattern::Regex { raw_pattern, .. } => raw_pattern.clone(),
99 }
100 }
101}
102
103pub fn compile(pattern: &str, opts: CompileOpts) -> CompileResult {
104 if pattern.len() > opts.size_limit_bytes {
105 return CompileResult::InvalidPattern {
106 message: format!(
107 "invalid regex: pattern exceeds size limit of {} bytes",
108 opts.size_limit_bytes
109 ),
110 pattern: pattern.to_string(),
111 };
112 }
113
114 if !opts.literal {
115 if let Some(feature) = detect_unsupported_features(pattern) {
116 return CompileResult::UnsupportedSyntax {
117 feature,
118 pattern: pattern.to_string(),
119 };
120 }
121 }
122
123 let has_regex_meta = has_regex_metachar(pattern);
124 let ascii_safe_literal = opts.case_insensitive && pattern.is_ascii();
125 if opts.literal || (!has_regex_meta && (!opts.case_insensitive || ascii_safe_literal)) {
126 if !opts.case_insensitive || pattern.is_ascii() {
127 let needle = if opts.case_insensitive {
128 pattern
129 .as_bytes()
130 .iter()
131 .map(|byte| byte.to_ascii_lowercase())
132 .collect()
133 } else {
134 pattern.as_bytes().to_vec()
135 };
136 return CompileResult::Ok(CompiledPattern::Literal(LiteralSearch {
137 needle,
138 case_insensitive_ascii: opts.case_insensitive,
139 }));
140 }
141 }
142
143 let mut regex_pattern = if opts.literal || !has_regex_meta {
144 regex::escape(pattern)
145 } else {
146 pattern.to_string()
147 };
148 let mut builder_case_insensitive = opts.case_insensitive;
149 if opts.case_insensitive && !pattern.is_ascii() {
150 regex_pattern = format!("(?i){regex_pattern}");
151 builder_case_insensitive = false;
152 }
153
154 let mut builder = RegexBuilder::new(®ex_pattern);
155 builder.case_insensitive(builder_case_insensitive);
156 builder.multi_line(opts.multi_line);
157 builder.size_limit(opts.size_limit_bytes);
158
159 match builder.build() {
160 Ok(compiled) => CompileResult::Ok(CompiledPattern::Regex {
161 compiled,
162 raw_pattern: regex_pattern,
163 case_insensitive: opts.case_insensitive,
164 }),
165 Err(error) => CompileResult::InvalidPattern {
166 message: format!("invalid regex: {error}"),
167 pattern: pattern.to_string(),
168 },
169 }
170}
171
/// Names the first unsupported regex feature used by `pattern`, if any.
///
/// Features are checked in a fixed priority order, independent of where they
/// occur in the pattern: `lookaround`, `backreference`,
/// `possessive quantifier`, `atomic group`. Returns `None` when none is found.
///
/// Group and quantifier syntax is only recognised where it is live: text
/// after a backslash and text inside a `[...]` character class is ignored, so
/// valid patterns such as `\(?=`, `a\++` or `[-*+/]` are not reported.
pub fn detect_unsupported_features(pattern: &str) -> Option<String> {
    if contains_live_syntax(pattern, &["(?=", "(?!", "(?<=", "(?<!"]) {
        return Some("lookaround".to_string());
    }
    if contains_live_syntax(pattern, &["(?P="]) || contains_numeric_backreference(pattern) {
        return Some("backreference".to_string());
    }
    if contains_live_syntax(pattern, &["*+", "++", "?+"]) {
        return Some("possessive quantifier".to_string());
    }
    if contains_live_syntax(pattern, &["(?>"]) {
        return Some("atomic group".to_string());
    }
    None
}

/// Reports whether any of `tokens` starts at an unescaped position outside a
/// character class.
///
/// The scan works on bytes; every token is ASCII, and UTF-8 continuation
/// bytes can never be mistaken for ASCII. Character classes follow the
/// `regex` crate's rules: they may nest, and a `]` directly after the opening
/// `[` (or `[^`) is a literal member rather than the terminator.
fn contains_live_syntax(pattern: &str, tokens: &[&str]) -> bool {
    let bytes = pattern.as_bytes();
    let mut class_depth = 0usize;
    let mut index = 0;
    while index < bytes.len() {
        match bytes[index] {
            b'\\' => {
                // Skip the backslash and the character it escapes.
                index += 2;
                continue;
            }
            b'[' => {
                class_depth += 1;
                index += 1;
                if bytes.get(index) == Some(&b'^') {
                    index += 1;
                }
                if bytes.get(index) == Some(&b']') {
                    index += 1;
                }
                continue;
            }
            b']' if class_depth > 0 => class_depth -= 1,
            _ if class_depth == 0 => {
                let rest = &bytes[index..];
                if tokens.iter().any(|token| rest.starts_with(token.as_bytes())) {
                    return true;
                }
            }
            _ => {}
        }
        index += 1;
    }
    false
}

/// Reports whether `pattern` contains any character with special meaning in
/// regex syntax, including the backslash.
fn has_regex_metachar(pattern: &str) -> bool {
    pattern.chars().any(|c| {
        matches!(
            c,
            '.' | '*' | '+' | '?' | '(' | ')' | '[' | ']' | '{' | '}' | '|' | '^' | '$' | '\\'
        )
    })
}

/// Reports whether `pattern` contains a backslash followed by a digit `1`-`9`.
///
/// An escaped backslash is consumed as a pair, so `\\1` (a literal backslash
/// followed by the digit) does not count.
fn contains_numeric_backreference(pattern: &str) -> bool {
    let mut escaped = false;
    for ch in pattern.chars() {
        if escaped {
            if ('1'..='9').contains(&ch) {
                return true;
            }
            escaped = false;
            continue;
        }
        escaped = ch == '\\';
    }
    false
}
215
// Unit tests for `compile`: literal fast path, regex fallback, and each
// rejection path.
#[cfg(test)]
mod tests {
    use super::*;

    /// Compiles `pattern` with default options (overriding only
    /// `case_insensitive`) and asserts the result is a literal with the
    /// expected needle and flag.
    fn assert_literal(pattern: &str, case_insensitive: bool, expected: &[u8]) {
        let result = compile(
            pattern,
            CompileOpts {
                case_insensitive,
                ..CompileOpts::default()
            },
        );
        match result {
            CompileResult::Ok(CompiledPattern::Literal(literal)) => {
                assert_eq!(literal.needle, expected);
                assert_eq!(literal.case_insensitive_ascii, case_insensitive);
            }
            other => panic!("expected literal, got {other:?}"),
        }
    }

    // A pattern without metacharacters is stored verbatim as a literal.
    #[test]
    fn literal_pattern_without_metachars_uses_fast_path() {
        assert_literal("needle", false, b"needle");
    }

    // ASCII case-insensitive text stays literal, with the needle lowercased.
    #[test]
    fn ascii_case_insensitive_literal_uses_lowercase_fast_path() {
        assert_literal("Needle", true, b"needle");
    }

    // Non-ASCII text cannot be folded with ASCII lowercasing, so `compile`
    // builds a regex whose stored pattern carries an inline `(?i)` flag.
    #[test]
    fn non_ascii_case_insensitive_literal_forces_regex_with_inline_flag() {
        let result = compile(
            "Äbc",
            CompileOpts {
                case_insensitive: true,
                ..CompileOpts::default()
            },
        );
        match result {
            CompileResult::Ok(CompiledPattern::Regex {
                raw_pattern,
                case_insensitive,
                ..
            }) => {
                assert!(raw_pattern.starts_with("(?i)"));
                assert!(case_insensitive);
            }
            other => panic!("expected regex, got {other:?}"),
        }
    }

    // A genuine regex keeps its text unchanged and matches byte input.
    #[test]
    fn regex_pattern_retains_raw_pattern_and_compiles_bytes_regex() {
        let result = compile("foo.*bar", CompileOpts::default());
        match result {
            CompileResult::Ok(CompiledPattern::Regex {
                compiled,
                raw_pattern,
                ..
            }) => {
                assert_eq!(raw_pattern, "foo.*bar");
                assert!(compiled.is_match(b"foo middle bar"));
            }
            other => panic!("expected regex, got {other:?}"),
        }
    }

    // An unterminated character class is rejected by the regex builder.
    #[test]
    fn invalid_pattern_surfaces_compile_error() {
        let result = compile("[", CompileOpts::default());
        assert!(matches!(result, CompileResult::InvalidPattern { .. }));
    }

    // A 4-byte pattern exceeds a 3-byte limit and is rejected up front.
    #[test]
    fn pattern_exceeding_size_limit_is_invalid() {
        let result = compile(
            "abcd",
            CompileOpts {
                size_limit_bytes: 3,
                ..CompileOpts::default()
            },
        );
        assert!(matches!(result, CompileResult::InvalidPattern { .. }));
    }

    // Lookaround, backreferences, possessive quantifiers and atomic groups
    // are each reported as unsupported syntax.
    #[test]
    fn unsupported_syntax_is_detected_before_compile() {
        for pattern in [
            "(?=foo)",
            "(?!foo)",
            "(?<=foo)",
            "(?<!foo)",
            "(?P=name)",
            r"\1",
            "foo*+",
            "(?>foo)",
        ] {
            assert!(
                matches!(
                    compile(pattern, CompileOpts::default()),
                    CompileResult::UnsupportedSyntax { .. }
                ),
                "{pattern}"
            );
        }
    }

    // With `literal: true`, metacharacters are kept as plain needle bytes.
    #[test]
    fn forced_literal_honors_regex_characters() {
        let result = compile(
            "foo.*bar",
            CompileOpts {
                literal: true,
                ..CompileOpts::default()
            },
        );
        match result {
            CompileResult::Ok(CompiledPattern::Literal(literal)) => {
                assert_eq!(literal.needle, b"foo.*bar");
            }
            other => panic!("expected literal, got {other:?}"),
        }
    }
}