// rgx/engine/mod.rs

1pub mod fancy;
2#[cfg(feature = "pcre2-engine")]
3pub mod pcre2;
4pub mod rust_regex;
5
6use serde::Serialize;
7use std::fmt;
8
/// The regex backends this module can construct.
///
/// `Pcre2` only exists when the crate is built with the `pcre2-engine` feature.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EngineKind {
    /// Displayed as "Rust regex".
    RustRegex,
    /// Displayed as "fancy-regex".
    FancyRegex,
    /// Displayed as "PCRE2"; only compiled in with the `pcre2-engine` feature.
    #[cfg(feature = "pcre2-engine")]
    Pcre2,
}

impl EngineKind {
    /// Every engine available in this build, in the order `next` cycles through them.
    pub fn all() -> Vec<EngineKind> {
        let kinds = [
            Self::RustRegex,
            Self::FancyRegex,
            #[cfg(feature = "pcre2-engine")]
            Self::Pcre2,
        ];
        Vec::from(kinds)
    }

    /// The engine after `self` in `all()` order, wrapping back to `RustRegex`
    /// after the last available engine.
    pub fn next(self) -> EngineKind {
        match self {
            Self::RustRegex => Self::FancyRegex,
            #[cfg(feature = "pcre2-engine")]
            Self::FancyRegex => Self::Pcre2,
            #[cfg(feature = "pcre2-engine")]
            Self::Pcre2 => Self::RustRegex,
            #[cfg(not(feature = "pcre2-engine"))]
            Self::FancyRegex => Self::RustRegex,
        }
    }
}

impl fmt::Display for EngineKind {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::RustRegex => "Rust regex",
            Self::FancyRegex => "fancy-regex",
            #[cfg(feature = "pcre2-engine")]
            Self::Pcre2 => "PCRE2",
        };
        f.write_str(label)
    }
}
50
/// User-toggleable regex options, rendered as an inline flag group like `(?imx)`.
#[derive(Debug, Clone, Copy, Default)]
pub struct EngineFlags {
    /// Rendered as `i`.
    pub case_insensitive: bool,
    /// Rendered as `m`.
    pub multi_line: bool,
    /// Rendered as `s`.
    pub dot_matches_newline: bool,
    /// Rendered as `u`.
    pub unicode: bool,
    /// Rendered as `x`.
    pub extended: bool,
}

impl EngineFlags {
    /// Letters of every enabled flag, always in the order `i`, `m`, `s`, `u`, `x`.
    ///
    /// Returns an empty string when no flag is enabled.
    ///
    /// NOTE(review): whether every backend accepts each of these letters inline
    /// (notably `u`) is not visible from this module — confirm per engine.
    pub fn to_inline_prefix(&self) -> String {
        let table = [
            (self.case_insensitive, 'i'),
            (self.multi_line, 'm'),
            (self.dot_matches_newline, 's'),
            (self.unicode, 'u'),
            (self.extended, 'x'),
        ];
        table
            .iter()
            .filter(|(enabled, _)| *enabled)
            .map(|&(_, letter)| letter)
            .collect()
    }

    /// Prefix `pattern` with `(?<flags>)`, or return it unchanged if no flag is set.
    pub fn wrap_pattern(&self, pattern: &str) -> String {
        let flags = self.to_inline_prefix();
        if flags.is_empty() {
            return pattern.to_owned();
        }
        format!("(?{flags}){pattern}")
    }

    /// Invert a single boolean option in place.
    fn flip(flag: &mut bool) {
        *flag = !*flag;
    }

    /// Flip the `i` (case-insensitive) option.
    pub fn toggle_case_insensitive(&mut self) {
        Self::flip(&mut self.case_insensitive);
    }

    /// Flip the `m` (multi-line) option.
    pub fn toggle_multi_line(&mut self) {
        Self::flip(&mut self.multi_line);
    }

    /// Flip the `s` (dot matches newline) option.
    pub fn toggle_dot_matches_newline(&mut self) {
        Self::flip(&mut self.dot_matches_newline);
    }

    /// Flip the `u` option.
    pub fn toggle_unicode(&mut self) {
        Self::flip(&mut self.unicode);
    }

    /// Flip the `x` (extended) option.
    pub fn toggle_extended(&mut self) {
        Self::flip(&mut self.extended);
    }
}
106
107#[derive(Debug, Clone, Serialize)]
108pub struct Match {
109    #[serde(rename = "match")]
110    pub text: String,
111    pub start: usize,
112    pub end: usize,
113    #[serde(rename = "groups")]
114    pub captures: Vec<CaptureGroup>,
115}
116
117#[derive(Debug, Clone, Serialize)]
118pub struct CaptureGroup {
119    #[serde(rename = "group")]
120    pub index: usize,
121    #[serde(skip_serializing_if = "Option::is_none")]
122    pub name: Option<String>,
123    #[serde(rename = "value")]
124    pub text: String,
125    pub start: usize,
126    pub end: usize,
127}
128
/// Failure reported by a regex backend.
#[derive(Debug)]
pub enum EngineError {
    /// Pattern compilation failed; displayed as `Compile error: <detail>`.
    CompileError(String),
    /// Matching failed; displayed as `Match error: <detail>`.
    MatchError(String),
}

impl fmt::Display for EngineError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let (stage, detail) = match self {
            Self::CompileError(detail) => ("Compile", detail),
            Self::MatchError(detail) => ("Match", detail),
        };
        write!(f, "{stage} error: {detail}")
    }
}

impl std::error::Error for EngineError {}

/// Result alias used throughout the engine traits.
pub type EngineResult<T> = Result<T, EngineError>;
147
148pub trait RegexEngine: Send + Sync {
149    fn kind(&self) -> EngineKind;
150    fn compile(&self, pattern: &str, flags: &EngineFlags) -> EngineResult<Box<dyn CompiledRegex>>;
151}
152
153pub trait CompiledRegex: Send + Sync {
154    fn find_matches(&self, text: &str) -> EngineResult<Vec<Match>>;
155}
156
157pub fn create_engine(kind: EngineKind) -> Box<dyn RegexEngine> {
158    match kind {
159        EngineKind::RustRegex => Box::new(rust_regex::RustRegexEngine),
160        EngineKind::FancyRegex => Box::new(fancy::FancyRegexEngine),
161        #[cfg(feature = "pcre2-engine")]
162        EngineKind::Pcre2 => Box::new(pcre2::Pcre2Engine),
163    }
164}
165
// --- Replace/Substitution support ---

/// A byte range of a replacement output, tagged with where its text came from.
#[derive(Debug, Clone)]
pub struct ReplaceSegment {
    /// Byte offset where the segment starts in the output string.
    pub start: usize,
    /// Byte offset one past the segment's last byte (exclusive).
    pub end: usize,
    /// `true` if the bytes came from an expanded template, `false` if they were
    /// copied unchanged from the input text.
    pub is_replacement: bool,
}

/// Result of a replace-all pass: the rewritten text and its segment breakdown.
#[derive(Debug, Clone)]
pub struct ReplaceResult {
    /// The text after every replacement has been applied.
    pub output: String,
    /// Non-empty segments of `output`, in output order.
    pub segments: Vec<ReplaceSegment>,
}
180
181/// Expand a replacement template against a single match.
182///
183/// Supports: `$0` / `$&` (whole match), `$1`..`$99` (numbered groups),
184/// `${name}` (named groups), `$$` (literal `$`).
185fn expand_replacement(template: &str, m: &Match) -> String {
186    let mut result = String::new();
187    let mut chars = template.char_indices().peekable();
188
189    while let Some((_i, c)) = chars.next() {
190        if c == '$' {
191            match chars.peek() {
192                None => {
193                    result.push('$');
194                }
195                Some(&(_, '$')) => {
196                    chars.next();
197                    result.push('$');
198                }
199                Some(&(_, '&')) => {
200                    chars.next();
201                    result.push_str(&m.text);
202                }
203                Some(&(_, '{')) => {
204                    chars.next(); // consume '{'
205                    let brace_start = chars.peek().map(|&(idx, _)| idx).unwrap_or(template.len());
206                    if let Some(close) = template[brace_start..].find('}') {
207                        let ref_name = &template[brace_start..brace_start + close];
208                        if let Some(text) = lookup_capture(m, ref_name) {
209                            result.push_str(text);
210                        }
211                        // Advance past the content and closing brace
212                        let end_byte = brace_start + close + 1;
213                        while chars.peek().is_some_and(|&(idx, _)| idx < end_byte) {
214                            chars.next();
215                        }
216                    } else {
217                        result.push('$');
218                        result.push('{');
219                    }
220                }
221                Some(&(_, next_c)) if next_c.is_ascii_digit() => {
222                    let (_, d1) = chars.next().unwrap();
223                    let mut num_str = String::from(d1);
224                    // Grab a second digit if present
225                    if let Some(&(_, d2)) = chars.peek() {
226                        if d2.is_ascii_digit() {
227                            chars.next();
228                            num_str.push(d2);
229                        }
230                    }
231                    let idx: usize = num_str.parse().unwrap_or(0);
232                    if idx == 0 {
233                        result.push_str(&m.text);
234                    } else if let Some(cap) = m.captures.iter().find(|c| c.index == idx) {
235                        result.push_str(&cap.text);
236                    }
237                }
238                Some(_) => {
239                    result.push('$');
240                }
241            }
242        } else {
243            result.push(c);
244        }
245    }
246
247    result
248}
249
250/// Look up a capture by name or numeric string.
251pub fn lookup_capture<'a>(m: &'a Match, key: &str) -> Option<&'a str> {
252    // Try as number first
253    if let Ok(idx) = key.parse::<usize>() {
254        if idx == 0 {
255            return Some(&m.text);
256        }
257        return m
258            .captures
259            .iter()
260            .find(|c| c.index == idx)
261            .map(|c| c.text.as_str());
262    }
263    // Try as named capture
264    m.captures
265        .iter()
266        .find(|c| c.name.as_deref() == Some(key))
267        .map(|c| c.text.as_str())
268}
269
270/// Perform replacement across all matches, returning the output string and segment metadata.
271pub fn replace_all(text: &str, matches: &[Match], template: &str) -> ReplaceResult {
272    let mut output = String::new();
273    let mut segments = Vec::new();
274    let mut pos = 0;
275
276    for m in matches {
277        // Original text before this match
278        if m.start > pos {
279            let seg_start = output.len();
280            output.push_str(&text[pos..m.start]);
281            segments.push(ReplaceSegment {
282                start: seg_start,
283                end: output.len(),
284                is_replacement: false,
285            });
286        }
287        // Expanded replacement
288        let expanded = expand_replacement(template, m);
289        if !expanded.is_empty() {
290            let seg_start = output.len();
291            output.push_str(&expanded);
292            segments.push(ReplaceSegment {
293                start: seg_start,
294                end: output.len(),
295                is_replacement: true,
296            });
297        }
298        pos = m.end;
299    }
300
301    // Trailing original text
302    if pos < text.len() {
303        let seg_start = output.len();
304        output.push_str(&text[pos..]);
305        segments.push(ReplaceSegment {
306            start: seg_start,
307            end: output.len(),
308            is_replacement: false,
309        });
310    }
311
312    ReplaceResult { output, segments }
313}
314
#[cfg(test)]
mod tests {
    use super::*;

    fn make_match(start: usize, end: usize, text: &str, captures: Vec<CaptureGroup>) -> Match {
        Match {
            start,
            end,
            text: text.to_string(),
            captures,
        }
    }

    fn make_cap(
        index: usize,
        name: Option<&str>,
        start: usize,
        end: usize,
        text: &str,
    ) -> CaptureGroup {
        CaptureGroup {
            index,
            name: name.map(|s| s.to_string()),
            start,
            end,
            text: text.to_string(),
        }
    }

    #[test]
    fn test_replace_all_basic() {
        let matches = vec![make_match(
            0,
            12,
            "user@example",
            vec![
                make_cap(1, None, 0, 4, "user"),
                make_cap(2, None, 5, 12, "example"),
            ],
        )];
        let result = replace_all("user@example", &matches, "$2=$1");
        assert_eq!(result.output, "example=user");
    }

    #[test]
    fn test_replace_all_no_matches() {
        let result = replace_all("hello world", &[], "replacement");
        assert_eq!(result.output, "hello world");
        assert_eq!(result.segments.len(), 1);
        assert!(!result.segments[0].is_replacement);
    }

    #[test]
    fn test_replace_all_empty_template() {
        let matches = vec![
            make_match(4, 7, "123", vec![]),
            make_match(12, 15, "456", vec![]),
        ];
        let result = replace_all("abc 123 def 456 ghi", &matches, "");
        assert_eq!(result.output, "abc  def  ghi");
        // Empty expansions contribute no segments.
        assert_eq!(result.segments.len(), 3);
        assert!(result.segments.iter().all(|s| !s.is_replacement));
    }

    #[test]
    fn test_replace_all_literal_dollar() {
        let matches = vec![make_match(0, 3, "foo", vec![])];
        let result = replace_all("foo", &matches, "$$bar");
        assert_eq!(result.output, "$bar");
    }

    #[test]
    fn test_replace_all_named_groups() {
        let matches = vec![make_match(
            0,
            7,
            "2024-01",
            vec![
                make_cap(1, Some("y"), 0, 4, "2024"),
                make_cap(2, Some("m"), 5, 7, "01"),
            ],
        )];
        let result = replace_all("2024-01", &matches, "${m}/${y}");
        assert_eq!(result.output, "01/2024");
    }

    #[test]
    fn test_replace_all_empty_match_inserts() {
        let matches = vec![make_match(1, 1, "", vec![])];
        let result = replace_all("ab", &matches, "-");
        assert_eq!(result.output, "a-b");
        let kinds: Vec<bool> = result.segments.iter().map(|s| s.is_replacement).collect();
        assert_eq!(kinds, vec![false, true, false]);
    }

    #[test]
    fn test_replace_all_multibyte_text() {
        // "héllo " is 7 bytes because `é` is two bytes; "wörld" spans 7..13.
        let matches = vec![make_match(7, 13, "wörld", vec![])];
        let result = replace_all("héllo wörld", &matches, "[$0]");
        assert_eq!(result.output, "héllo [wörld]");
        assert_eq!(result.segments.len(), 2);
        assert_eq!(
            &result.output[result.segments[1].start..result.segments[1].end],
            "[wörld]"
        );
    }

    #[test]
    fn test_expand_replacement_whole_match() {
        let m = make_match(0, 5, "hello", vec![]);
        assert_eq!(expand_replacement("$0", &m), "hello");
        assert_eq!(expand_replacement("$&", &m), "hello");
        assert_eq!(expand_replacement("[$0]", &m), "[hello]");
    }

    #[test]
    fn test_expand_replacement_non_ascii() {
        let m = make_match(0, 5, "hello", vec![]);
        // Non-ASCII characters in replacement template should work correctly
        assert_eq!(expand_replacement("café $0", &m), "café hello");
        assert_eq!(expand_replacement("→$0←", &m), "→hello←");
        assert_eq!(expand_replacement("日本語", &m), "日本語");
        assert_eq!(expand_replacement("über $& cool", &m), "über hello cool");
    }

    #[test]
    fn test_expand_replacement_literal_dollar_fallbacks() {
        let m = make_match(0, 5, "hello", vec![]);
        // A trailing `$`, an unknown `$x`, and an unterminated `${` pass through.
        assert_eq!(expand_replacement("cost: $", &m), "cost: $");
        assert_eq!(expand_replacement("$x", &m), "$x");
        assert_eq!(expand_replacement("${oops", &m), "${oops");
        assert_eq!(expand_replacement("$$0", &m), "$0");
    }

    #[test]
    fn test_expand_replacement_numbered_groups() {
        let m = make_match(
            0,
            3,
            "abc",
            vec![make_cap(1, None, 0, 1, "a"), make_cap(12, None, 1, 2, "b")],
        );
        // At most two digits are read after `$`.
        assert_eq!(expand_replacement("$12", &m), "b");
        assert_eq!(expand_replacement("$123", &m), "b3");
        assert_eq!(expand_replacement("$1-", &m), "a-");
        // Missing groups expand to nothing.
        assert_eq!(expand_replacement("[$5]", &m), "[]");
        // Braced numeric references.
        assert_eq!(expand_replacement("${1}${12}", &m), "ab");
        assert_eq!(expand_replacement("${0}", &m), "abc");
    }

    #[test]
    fn test_lookup_capture() {
        let m = make_match(0, 7, "2024-01", vec![make_cap(1, Some("y"), 0, 4, "2024")]);
        assert_eq!(lookup_capture(&m, "0"), Some("2024-01"));
        assert_eq!(lookup_capture(&m, "1"), Some("2024"));
        assert_eq!(lookup_capture(&m, "y"), Some("2024"));
        assert_eq!(lookup_capture(&m, "2"), None);
        assert_eq!(lookup_capture(&m, "missing"), None);
    }

    #[test]
    fn test_replace_segments_tracking() {
        let matches = vec![make_match(6, 9, "123", vec![])];
        let result = replace_all("hello 123 world", &matches, "NUM");
        assert_eq!(result.output, "hello NUM world");
        assert_eq!(result.segments.len(), 3);
        // "hello " - original
        assert!(!result.segments[0].is_replacement);
        assert_eq!(
            &result.output[result.segments[0].start..result.segments[0].end],
            "hello "
        );
        // "NUM" - replacement
        assert!(result.segments[1].is_replacement);
        assert_eq!(
            &result.output[result.segments[1].start..result.segments[1].end],
            "NUM"
        );
        // " world" - original
        assert!(!result.segments[2].is_replacement);
        assert_eq!(
            &result.output[result.segments[2].start..result.segments[2].end],
            " world"
        );
    }
}