Skip to main content

rgx/engine/
mod.rs

1pub mod fancy;
2#[cfg(feature = "pcre2-engine")]
3pub mod pcre2;
4#[cfg(feature = "pcre2-engine")]
5pub mod pcre2_debug;
6pub mod rust_regex;
7
8use serde::Serialize;
9use std::fmt;
10
/// The regex backends this module can construct.
///
/// The `Pcre2` variant only exists when the `pcre2-engine` feature is enabled.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EngineKind {
    RustRegex,
    FancyRegex,
    #[cfg(feature = "pcre2-engine")]
    Pcre2,
}

impl EngineKind {
    /// Lists every engine kind compiled into this build, in declaration order.
    pub fn all() -> Vec<EngineKind> {
        // `mut` is only exercised when the optional PCRE2 variant is compiled in.
        #[allow(unused_mut)]
        let mut kinds = vec![Self::RustRegex, Self::FancyRegex];
        #[cfg(feature = "pcre2-engine")]
        kinds.push(Self::Pcre2);
        kinds
    }

    /// Returns the kind that follows `self` in [`EngineKind::all`], wrapping
    /// from the last entry back to the first.
    pub fn next(self) -> EngineKind {
        let cycle = Self::all();
        let here = cycle.iter().position(|&kind| kind == self).unwrap_or(0);
        cycle[(here + 1) % cycle.len()]
    }
}

impl fmt::Display for EngineKind {
    /// Writes the human-readable label of the engine.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::RustRegex => "Rust regex",
            Self::FancyRegex => "fancy-regex",
            #[cfg(feature = "pcre2-engine")]
            Self::Pcre2 => "PCRE2",
        };
        f.write_str(label)
    }
}
52
/// Toggleable regex flags; every flag is off in the `Default` value.
#[derive(Debug, Clone, Copy, Default)]
pub struct EngineFlags {
    pub case_insensitive: bool,
    pub multi_line: bool,
    pub dot_matches_newline: bool,
    pub unicode: bool,
    pub extended: bool,
}

impl EngineFlags {
    /// Returns one letter per enabled flag, always in the order `i`, `m`, `s`, `u`, `x`.
    ///
    /// The result is empty when no flag is set.
    pub fn to_inline_prefix(&self) -> String {
        let table = [
            (self.case_insensitive, 'i'),
            (self.multi_line, 'm'),
            (self.dot_matches_newline, 's'),
            (self.unicode, 'u'),
            (self.extended, 'x'),
        ];
        table
            .iter()
            .filter(|(enabled, _)| *enabled)
            .map(|&(_, letter)| letter)
            .collect()
    }

    /// Prepends an inline `(?flags)` group to `pattern`, or returns the pattern
    /// unchanged when no flag is enabled.
    pub fn wrap_pattern(&self, pattern: &str) -> String {
        match self.to_inline_prefix().as_str() {
            "" => pattern.to_owned(),
            letters => format!("(?{letters}){pattern}"),
        }
    }

    /// Flips the case-insensitive flag.
    pub fn toggle_case_insensitive(&mut self) {
        self.case_insensitive ^= true;
    }

    /// Flips the multi-line flag.
    pub fn toggle_multi_line(&mut self) {
        self.multi_line ^= true;
    }

    /// Flips the dot-matches-newline flag.
    pub fn toggle_dot_matches_newline(&mut self) {
        self.dot_matches_newline ^= true;
    }

    /// Flips the unicode flag.
    pub fn toggle_unicode(&mut self) {
        self.unicode ^= true;
    }

    /// Flips the extended flag.
    pub fn toggle_extended(&mut self) {
        self.extended ^= true;
    }
}
108
/// One regex match, serializable via serde.
#[derive(Debug, Clone, Serialize)]
pub struct Match {
    /// Text of the whole match (what `$0` / `$&` expand to); serialized under the key `match`.
    #[serde(rename = "match")]
    pub text: String,
    /// Start of the match; `replace_all` uses it as a byte index into the searched text.
    pub start: usize,
    /// End of the match; `replace_all` resumes copying the searched text from this byte index.
    pub end: usize,
    /// Capture groups belonging to this match; serialized under the key `groups`.
    #[serde(rename = "groups")]
    pub captures: Vec<CaptureGroup>,
}
118
/// One capture group inside a [`Match`], serializable via serde.
#[derive(Debug, Clone, Serialize)]
pub struct CaptureGroup {
    /// Group number used for `$N` / `${N}` lookups; serialized under the key `group`.
    #[serde(rename = "group")]
    pub index: usize,
    /// Group name used for `${name}` lookups; left out of the serialized form when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    /// Text captured by the group; serialized under the key `value`.
    #[serde(rename = "value")]
    pub text: String,
    // NOTE(review): `start`/`end` are presumably offsets into the searched text,
    // like `Match::start`/`Match::end` — confirm against the engine implementations.
    pub start: usize,
    pub end: usize,
}
130
/// Failure reported by a regex engine, carrying a human-readable message.
#[derive(Debug)]
pub enum EngineError {
    /// Rendered with a `Compile error: ` prefix.
    CompileError(String),
    /// Rendered with a `Match error: ` prefix.
    MatchError(String),
}

impl fmt::Display for EngineError {
    /// Writes the variant's prefix followed by its message.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match *self {
            Self::CompileError(ref msg) => write!(f, "Compile error: {msg}"),
            Self::MatchError(ref msg) => write!(f, "Match error: {msg}"),
        }
    }
}

impl std::error::Error for EngineError {}

/// Shorthand for a `Result` whose error type is [`EngineError`].
pub type EngineResult<T> = Result<T, EngineError>;
149
/// A regex backend that can compile patterns; implementors must be `Send + Sync`.
pub trait RegexEngine: Send + Sync {
    /// Reports which [`EngineKind`] this engine is.
    fn kind(&self) -> EngineKind;
    /// Compiles `pattern` with the given `flags` into a boxed [`CompiledRegex`],
    /// or returns an [`EngineError`].
    fn compile(&self, pattern: &str, flags: &EngineFlags) -> EngineResult<Box<dyn CompiledRegex>>;
}
154
/// A compiled pattern that can be run against text; implementors must be `Send + Sync`.
pub trait CompiledRegex: Send + Sync {
    /// Returns the matches found in `text`, or an [`EngineError`].
    fn find_matches(&self, text: &str) -> EngineResult<Vec<Match>>;
}
158
159pub fn create_engine(kind: EngineKind) -> Box<dyn RegexEngine> {
160    match kind {
161        EngineKind::RustRegex => Box::new(rust_regex::RustRegexEngine),
162        EngineKind::FancyRegex => Box::new(fancy::FancyRegexEngine),
163        #[cfg(feature = "pcre2-engine")]
164        EngineKind::Pcre2 => Box::new(pcre2::Pcre2Engine),
165    }
166}
167
168/// Return the "power level" of an engine (higher = more capable).
169fn engine_level(kind: EngineKind) -> u8 {
170    match kind {
171        EngineKind::RustRegex => 0,
172        EngineKind::FancyRegex => 1,
173        #[cfg(feature = "pcre2-engine")]
174        EngineKind::Pcre2 => 2,
175    }
176}
177
178/// Detect the minimum engine needed for the given pattern.
179pub fn detect_minimum_engine(pattern: &str) -> EngineKind {
180    #[cfg(feature = "pcre2-engine")]
181    {
182        if needs_pcre2(pattern) {
183            return EngineKind::Pcre2;
184        }
185    }
186
187    if needs_fancy(pattern) {
188        return EngineKind::FancyRegex;
189    }
190
191    EngineKind::RustRegex
192}
193
194/// Return `true` if `suggested` is a strict upgrade over `current`.
195pub fn is_engine_upgrade(current: EngineKind, suggested: EngineKind) -> bool {
196    engine_level(suggested) > engine_level(current)
197}
198
199fn needs_fancy(pattern: &str) -> bool {
200    if pattern.contains("(?=")
201        || pattern.contains("(?!")
202        || pattern.contains("(?<=")
203        || pattern.contains("(?<!")
204    {
205        return true;
206    }
207    has_backreference(pattern)
208}
209
/// Returns `true` when `pattern` contains a backslash immediately followed by
/// a digit `1`-`9`.
///
/// Each backslash consumes the byte after it, so an escaped backslash followed
/// by a digit (`\\1`) is not reported, and `\0` is never reported.
fn has_backreference(pattern: &str) -> bool {
    let mut bytes = pattern.bytes();
    while let Some(byte) = bytes.next() {
        if byte != b'\\' {
            continue;
        }
        // Pull the escaped byte out of the iterator so it is never inspected
        // as the start of another escape.
        if let Some(escaped) = bytes.next() {
            if escaped != b'0' && escaped.is_ascii_digit() {
                return true;
            }
        }
    }
    false
}
228
/// Returns `true` when `pattern` contains one of the PCRE2-only markers listed
/// below or a subroutine call (see `has_subroutine_call`).
///
/// The marker search is a plain substring test.
#[cfg(feature = "pcre2-engine")]
fn needs_pcre2(pattern: &str) -> bool {
    const PCRE2_MARKERS: [&str; 7] = [
        "(?R)",
        "(*SKIP)",
        "(*FAIL)",
        "(*PRUNE)",
        "(*COMMIT)",
        "\\K",
        "(?(",
    ];

    PCRE2_MARKERS.iter().any(|&marker| pattern.contains(marker)) || has_subroutine_call(pattern)
}
243
/// Returns `true` when `pattern` contains a PCRE2 subroutine call.
///
/// Recognised forms, each introduced by `(?`:
/// - absolute numbered call: `(?1)`
/// - relative numbered call: `(?+1)`, `(?-1)` (a sign must be followed by a
///   digit, so inline flag groups such as `(?-i)` are not reported)
/// - named call: `(?&name)`, `(?P>name)`
///
/// The scan is not escape-aware: an escaped parenthesis followed by one of
/// these sequences (e.g. `\(?1`) is reported as well.
#[cfg(feature = "pcre2-engine")]
fn has_subroutine_call(pattern: &str) -> bool {
    let bytes = pattern.as_bytes();
    bytes.windows(2).enumerate().any(|(at, opener)| {
        if !matches!(opener, [b'(', b'?']) {
            return false;
        }
        // `at + 2 <= bytes.len()` because the window covers `at` and `at + 1`.
        match &bytes[at + 2..] {
            [digit, ..] if digit.is_ascii_digit() => true,
            [b'+' | b'-', digit, ..] if digit.is_ascii_digit() => true,
            [b'&', ..] | [b'P', b'>', ..] => true,
            _ => false,
        }
    })
}
257
258// --- Replace/Substitution support ---
259
/// A contiguous piece of a [`ReplaceResult`]'s output.
#[derive(Debug, Clone)]
pub struct ReplaceSegment {
    /// Byte offset in the output string where this segment begins.
    pub start: usize,
    /// Byte offset in the output string just past the end of this segment.
    pub end: usize,
    /// `true` for text produced by the replacement template, `false` for text
    /// copied unchanged from the input.
    pub is_replacement: bool,
}
266
/// Output of `replace_all`: the rewritten text plus a map of which parts were replaced.
#[derive(Debug, Clone)]
pub struct ReplaceResult {
    /// The text after all replacements have been applied.
    pub output: String,
    /// Segments of `output`, in order; empty pieces are not recorded.
    pub segments: Vec<ReplaceSegment>,
}
272
273/// Expand a replacement template against a single match.
274///
275/// Supports: `$0` / `$&` (whole match), `$1`..`$99` (numbered groups),
276/// `${name}` (named groups), `$$` (literal `$`).
277fn expand_replacement(template: &str, m: &Match) -> String {
278    let mut result = String::new();
279    let mut chars = template.char_indices().peekable();
280
281    while let Some((_i, c)) = chars.next() {
282        if c == '$' {
283            match chars.peek() {
284                None => {
285                    result.push('$');
286                }
287                Some(&(_, '$')) => {
288                    chars.next();
289                    result.push('$');
290                }
291                Some(&(_, '&')) => {
292                    chars.next();
293                    result.push_str(&m.text);
294                }
295                Some(&(_, '{')) => {
296                    chars.next(); // consume '{'
297                    let brace_start = chars.peek().map(|&(idx, _)| idx).unwrap_or(template.len());
298                    if let Some(close) = template[brace_start..].find('}') {
299                        let ref_name = &template[brace_start..brace_start + close];
300                        if let Some(text) = lookup_capture(m, ref_name) {
301                            result.push_str(text);
302                        }
303                        // Advance past the content and closing brace
304                        let end_byte = brace_start + close + 1;
305                        while chars.peek().is_some_and(|&(idx, _)| idx < end_byte) {
306                            chars.next();
307                        }
308                    } else {
309                        result.push('$');
310                        result.push('{');
311                    }
312                }
313                Some(&(_, next_c)) if next_c.is_ascii_digit() => {
314                    let (_, d1) = chars.next().unwrap();
315                    let mut num_str = String::from(d1);
316                    // Grab a second digit if present
317                    if let Some(&(_, d2)) = chars.peek() {
318                        if d2.is_ascii_digit() {
319                            chars.next();
320                            num_str.push(d2);
321                        }
322                    }
323                    let idx: usize = num_str.parse().unwrap_or(0);
324                    if idx == 0 {
325                        result.push_str(&m.text);
326                    } else if let Some(cap) = m.captures.iter().find(|c| c.index == idx) {
327                        result.push_str(&cap.text);
328                    }
329                }
330                Some(_) => {
331                    result.push('$');
332                }
333            }
334        } else {
335            result.push(c);
336        }
337    }
338
339    result
340}
341
342/// Look up a capture by name or numeric string.
343pub fn lookup_capture<'a>(m: &'a Match, key: &str) -> Option<&'a str> {
344    // Try as number first
345    if let Ok(idx) = key.parse::<usize>() {
346        if idx == 0 {
347            return Some(&m.text);
348        }
349        return m
350            .captures
351            .iter()
352            .find(|c| c.index == idx)
353            .map(|c| c.text.as_str());
354    }
355    // Try as named capture
356    m.captures
357        .iter()
358        .find(|c| c.name.as_deref() == Some(key))
359        .map(|c| c.text.as_str())
360}
361
362/// Perform replacement across all matches, returning the output string and segment metadata.
363pub fn replace_all(text: &str, matches: &[Match], template: &str) -> ReplaceResult {
364    let mut output = String::new();
365    let mut segments = Vec::new();
366    let mut pos = 0;
367
368    for m in matches {
369        // Original text before this match
370        if m.start > pos {
371            let seg_start = output.len();
372            output.push_str(&text[pos..m.start]);
373            segments.push(ReplaceSegment {
374                start: seg_start,
375                end: output.len(),
376                is_replacement: false,
377            });
378        }
379        // Expanded replacement
380        let expanded = expand_replacement(template, m);
381        if !expanded.is_empty() {
382            let seg_start = output.len();
383            output.push_str(&expanded);
384            segments.push(ReplaceSegment {
385                start: seg_start,
386                end: output.len(),
387                is_replacement: true,
388            });
389        }
390        pos = m.end;
391    }
392
393    // Trailing original text
394    if pos < text.len() {
395        let seg_start = output.len();
396        output.push_str(&text[pos..]);
397        segments.push(ReplaceSegment {
398            start: seg_start,
399            end: output.len(),
400            is_replacement: false,
401        });
402    }
403
404    ReplaceResult { output, segments }
405}
406
// Unit tests for replacement expansion, segment tracking, and automatic
// engine detection.
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `Match` fixture from its parts.
    fn make_match(start: usize, end: usize, text: &str, captures: Vec<CaptureGroup>) -> Match {
        Match {
            start,
            end,
            text: text.to_string(),
            captures,
        }
    }

    /// Builds a `CaptureGroup` fixture from its parts.
    fn make_cap(
        index: usize,
        name: Option<&str>,
        start: usize,
        end: usize,
        text: &str,
    ) -> CaptureGroup {
        CaptureGroup {
            index,
            name: name.map(|s| s.to_string()),
            start,
            end,
            text: text.to_string(),
        }
    }

    // Numbered group references are substituted in template order.
    #[test]
    fn test_replace_all_basic() {
        let matches = vec![make_match(
            0,
            12,
            "user@example",
            vec![
                make_cap(1, None, 0, 4, "user"),
                make_cap(2, None, 5, 12, "example"),
            ],
        )];
        let result = replace_all("user@example", &matches, "$2=$1");
        assert_eq!(result.output, "example=user");
    }

    // With no matches the input comes back as a single original segment.
    #[test]
    fn test_replace_all_no_matches() {
        let result = replace_all("hello world", &[], "replacement");
        assert_eq!(result.output, "hello world");
        assert_eq!(result.segments.len(), 1);
        assert!(!result.segments[0].is_replacement);
    }

    // An empty template deletes every match.
    #[test]
    fn test_replace_all_empty_template() {
        let matches = vec![
            make_match(4, 7, "123", vec![]),
            make_match(12, 15, "456", vec![]),
        ];
        let result = replace_all("abc 123 def 456 ghi", &matches, "");
        assert_eq!(result.output, "abc  def  ghi");
    }

    // `$$` yields a single literal dollar sign.
    #[test]
    fn test_replace_all_literal_dollar() {
        let matches = vec![make_match(0, 3, "foo", vec![])];
        let result = replace_all("foo", &matches, "$$bar");
        assert_eq!(result.output, "$bar");
    }

    // `${name}` resolves named groups.
    #[test]
    fn test_replace_all_named_groups() {
        let matches = vec![make_match(
            0,
            7,
            "2024-01",
            vec![
                make_cap(1, Some("y"), 0, 4, "2024"),
                make_cap(2, Some("m"), 5, 7, "01"),
            ],
        )];
        let result = replace_all("2024-01", &matches, "${m}/${y}");
        assert_eq!(result.output, "01/2024");
    }

    // `$0` and `$&` both expand to the whole match.
    #[test]
    fn test_expand_replacement_whole_match() {
        let m = make_match(0, 5, "hello", vec![]);
        assert_eq!(expand_replacement("$0", &m), "hello");
        assert_eq!(expand_replacement("$&", &m), "hello");
        assert_eq!(expand_replacement("[$0]", &m), "[hello]");
    }

    #[test]
    fn test_expand_replacement_non_ascii() {
        let m = make_match(0, 5, "hello", vec![]);
        // Non-ASCII characters in replacement template should work correctly
        assert_eq!(expand_replacement("café $0", &m), "café hello");
        assert_eq!(expand_replacement("→$0←", &m), "→hello←");
        assert_eq!(expand_replacement("日本語", &m), "日本語");
        assert_eq!(expand_replacement("über $& cool", &m), "über hello cool");
    }

    // Segment ranges index into the output string and alternate
    // original / replacement / original for a single mid-string match.
    #[test]
    fn test_replace_segments_tracking() {
        let matches = vec![make_match(6, 9, "123", vec![])];
        let result = replace_all("hello 123 world", &matches, "NUM");
        assert_eq!(result.output, "hello NUM world");
        assert_eq!(result.segments.len(), 3);
        // "hello " - original
        assert!(!result.segments[0].is_replacement);
        assert_eq!(
            &result.output[result.segments[0].start..result.segments[0].end],
            "hello "
        );
        // "NUM" - replacement
        assert!(result.segments[1].is_replacement);
        assert_eq!(
            &result.output[result.segments[1].start..result.segments[1].end],
            "NUM"
        );
        // " world" - original
        assert!(!result.segments[2].is_replacement);
        assert_eq!(
            &result.output[result.segments[2].start..result.segments[2].end],
            " world"
        );
    }

    // --- Auto engine detection tests ---

    #[test]
    fn test_detect_simple_pattern_uses_rust_regex() {
        assert_eq!(detect_minimum_engine(r"\d+"), EngineKind::RustRegex);
        assert_eq!(detect_minimum_engine(r"[a-z]+"), EngineKind::RustRegex);
        assert_eq!(detect_minimum_engine(r"foo|bar"), EngineKind::RustRegex);
        assert_eq!(detect_minimum_engine(r"^\w+$"), EngineKind::RustRegex);
    }

    #[test]
    fn test_detect_lookahead_needs_fancy() {
        assert_eq!(detect_minimum_engine(r"foo(?=bar)"), EngineKind::FancyRegex);
        assert_eq!(detect_minimum_engine(r"foo(?!bar)"), EngineKind::FancyRegex);
    }

    #[test]
    fn test_detect_lookbehind_needs_fancy() {
        assert_eq!(
            detect_minimum_engine(r"(?<=foo)bar"),
            EngineKind::FancyRegex,
        );
        assert_eq!(
            detect_minimum_engine(r"(?<!foo)bar"),
            EngineKind::FancyRegex,
        );
    }

    #[test]
    fn test_detect_backreference_needs_fancy() {
        assert_eq!(detect_minimum_engine(r"(\w+)\s+\1"), EngineKind::FancyRegex,);
        assert_eq!(detect_minimum_engine(r"(a)(b)\2"), EngineKind::FancyRegex);
    }

    #[test]
    fn test_detect_non_backreference_escapes_stay_rust() {
        // These look like \digit but are actually common escapes
        assert_eq!(detect_minimum_engine(r"\d"), EngineKind::RustRegex);
        assert_eq!(detect_minimum_engine(r"\w\s\b"), EngineKind::RustRegex);
        assert_eq!(detect_minimum_engine(r"\0"), EngineKind::RustRegex);
        assert_eq!(detect_minimum_engine(r"\n\r\t"), EngineKind::RustRegex);
        assert_eq!(detect_minimum_engine(r"\x41"), EngineKind::RustRegex);
        assert_eq!(detect_minimum_engine(r"\u0041"), EngineKind::RustRegex);
        assert_eq!(detect_minimum_engine(r"\p{L}"), EngineKind::RustRegex);
        assert_eq!(detect_minimum_engine(r"\P{L}"), EngineKind::RustRegex);
        assert_eq!(detect_minimum_engine(r"\B"), EngineKind::RustRegex);
    }

    #[test]
    fn test_has_backreference() {
        assert!(has_backreference(r"(\w+)\1"));
        assert!(has_backreference(r"\1"));
        assert!(has_backreference(r"(a)(b)(c)\3"));
        assert!(!has_backreference(r"\d+"));
        assert!(!has_backreference(r"\0"));
        assert!(!has_backreference(r"plain text"));
        assert!(!has_backreference(r"\w\s\b\B\n\r\t"));
    }

    #[test]
    fn test_detect_empty_pattern() {
        assert_eq!(detect_minimum_engine(""), EngineKind::RustRegex);
    }

    // An upgrade must be strictly more capable; equal or lower levels are not upgrades.
    #[test]
    fn test_is_engine_upgrade() {
        assert!(is_engine_upgrade(
            EngineKind::RustRegex,
            EngineKind::FancyRegex
        ));
        assert!(!is_engine_upgrade(
            EngineKind::FancyRegex,
            EngineKind::RustRegex
        ));
        assert!(!is_engine_upgrade(
            EngineKind::FancyRegex,
            EngineKind::FancyRegex,
        ));
    }

    // Detection tests that only make sense when the PCRE2 engine is compiled in.
    #[cfg(feature = "pcre2-engine")]
    mod pcre2_detection_tests {
        use super::*;

        #[test]
        fn test_detect_recursion_needs_pcre2() {
            assert_eq!(detect_minimum_engine(r"(?R)"), EngineKind::Pcre2);
        }

        #[test]
        fn test_detect_backtracking_verbs_need_pcre2() {
            assert_eq!(detect_minimum_engine(r"(*SKIP)(*FAIL)"), EngineKind::Pcre2);
            assert_eq!(detect_minimum_engine(r"(*PRUNE)"), EngineKind::Pcre2);
            assert_eq!(detect_minimum_engine(r"(*COMMIT)"), EngineKind::Pcre2);
        }

        #[test]
        fn test_detect_reset_match_start_needs_pcre2() {
            assert_eq!(detect_minimum_engine(r"foo\Kbar"), EngineKind::Pcre2);
        }

        #[test]
        fn test_detect_conditional_needs_pcre2() {
            assert_eq!(detect_minimum_engine(r"(?(1)yes|no)"), EngineKind::Pcre2,);
        }

        #[test]
        fn test_detect_subroutine_call_needs_pcre2() {
            assert_eq!(detect_minimum_engine(r"(\d+)(?1)"), EngineKind::Pcre2);
        }

        #[test]
        fn test_is_engine_upgrade_pcre2() {
            assert!(is_engine_upgrade(EngineKind::RustRegex, EngineKind::Pcre2));
            assert!(is_engine_upgrade(EngineKind::FancyRegex, EngineKind::Pcre2));
            assert!(!is_engine_upgrade(
                EngineKind::Pcre2,
                EngineKind::FancyRegex
            ));
            assert!(!is_engine_upgrade(EngineKind::Pcre2, EngineKind::RustRegex));
        }
    }
}