// pipa/regexp/mod.rs

1mod ast;
2mod char_table;
3mod charclass;
4pub mod compiler;
5mod dfa;
6mod engine;
7mod fast_class;
8mod fast_engine;
9mod memchr;
10pub mod opcode;
11pub mod optimizer;
12pub mod parser;
13pub mod pattern_matcher;
14mod pool;
15
16pub mod string_search;
17#[cfg(test)]
18mod tests;
19
20use std::fmt;
21
22pub use engine::Match;
23pub use fast_engine::FastRegex;
24pub use optimizer::OptimizedPattern;
25
26#[derive(Clone)]
27pub struct Regex {
28    fast_engine: Option<fast_engine::FastRegex>,
29
30    program: Option<compiler::Program>,
31
32    pattern: String,
33
34    flags: RegexFlags,
35}
36
/// One match found in an input string, together with its capture slots.
#[derive(Debug, Clone)]
pub struct MatchResult<'a> {
    /// The complete input the match was found in.
    text: &'a str,
    /// Offset into `text` at which the match begins.
    start: usize,
    /// Offset into `text` one past the end of the match.
    end: usize,
    /// Text of each capture slot; `None` for a slot without both positions.
    captures: Vec<Option<&'a str>>,
    /// Capture positions exactly as the matching engine reported them.
    capture_positions: Vec<(Option<usize>, Option<usize>)>,
}

impl<'a> MatchResult<'a> {
    /// Returns the matched text, i.e. `text[start..end]`.
    pub fn as_str(&self) -> &'a str {
        let haystack: &'a str = self.text;
        &haystack[self.range()]
    }

    /// Returns the offset at which the match begins.
    pub fn start(&self) -> usize {
        self.start
    }

    /// Returns the offset one past the end of the match.
    pub fn end(&self) -> usize {
        self.end
    }

    /// Returns `start..end` as a range.
    pub fn range(&self) -> std::ops::Range<usize> {
        std::ops::Range {
            start: self.start,
            end: self.end,
        }
    }

    /// Returns the text of capture slot `i`, or `None` when the slot is
    /// out of range or holds no text.
    pub fn get(&self, i: usize) -> Option<&'a str> {
        match self.captures.get(i) {
            Some(slot) => *slot,
            None => None,
        }
    }

    /// Returns the number of capture slots.
    pub fn len(&self) -> usize {
        self.captures.len()
    }

    /// Returns `true` when there are no capture slots.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Iterates over the capture slots in order.
    pub fn iter(&self) -> impl Iterator<Item = Option<&'a str>> + '_ {
        self.captures.iter().cloned()
    }

    /// Returns the capture positions as reported by the matching engine.
    pub fn positions(&self) -> &[(Option<usize>, Option<usize>)] {
        self.capture_positions.as_slice()
    }
}
87
88#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
89pub struct RegexFlags {
90    pub global: bool,
91
92    pub ignore_case: bool,
93
94    pub multi_line: bool,
95
96    pub dot_all: bool,
97
98    pub unicode: bool,
99
100    pub sticky: bool,
101}
102
103impl RegexFlags {
104    pub fn from_str(s: &str) -> Result<Self, String> {
105        let mut flags = Self::default();
106        for c in s.chars() {
107            match c {
108                'g' => flags.global = true,
109                'i' => flags.ignore_case = true,
110                'm' => flags.multi_line = true,
111                's' => flags.dot_all = true,
112                'u' => flags.unicode = true,
113                'y' => flags.sticky = true,
114                _ => return Err(format!("Invalid flag: {}", c)),
115            }
116        }
117        Ok(flags)
118    }
119
120    fn to_u16(&self) -> u16 {
121        let mut f = 0u16;
122        if self.global {
123            f |= opcode::FLAG_GLOBAL;
124        }
125        if self.ignore_case {
126            f |= opcode::FLAG_IGNORE_CASE;
127        }
128        if self.multi_line {
129            f |= opcode::FLAG_MULTI_LINE;
130        }
131        if self.dot_all {
132            f |= opcode::FLAG_DOT_ALL;
133        }
134        if self.unicode {
135            f |= opcode::FLAG_UNICODE;
136        }
137        if self.sticky {
138            f |= opcode::FLAG_STICKY;
139        }
140        f
141    }
142}
143
144impl fmt::Display for RegexFlags {
145    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
146        if self.global {
147            write!(f, "g")?;
148        }
149        if self.ignore_case {
150            write!(f, "i")?;
151        }
152        if self.multi_line {
153            write!(f, "m")?;
154        }
155        if self.dot_all {
156            write!(f, "s")?;
157        }
158        if self.unicode {
159            write!(f, "u")?;
160        }
161        if self.sticky {
162            write!(f, "y")?;
163        }
164        Ok(())
165    }
166}
167
168#[derive(Debug, Default)]
169pub struct RegexBuilder {
170    pattern: String,
171    flags: RegexFlags,
172}
173
174impl RegexBuilder {
175    pub fn new(pattern: &str) -> Self {
176        Self {
177            pattern: pattern.to_string(),
178            flags: RegexFlags::default(),
179        }
180    }
181
182    pub fn global(mut self, value: bool) -> Self {
183        self.flags.global = value;
184        self
185    }
186
187    pub fn ignore_case(mut self, value: bool) -> Self {
188        self.flags.ignore_case = value;
189        self
190    }
191
192    pub fn multi_line(mut self, value: bool) -> Self {
193        self.flags.multi_line = value;
194        self
195    }
196
197    pub fn dot_all(mut self, value: bool) -> Self {
198        self.flags.dot_all = value;
199        self
200    }
201
202    pub fn unicode(mut self, value: bool) -> Self {
203        self.flags.unicode = value;
204        self
205    }
206
207    pub fn sticky(mut self, value: bool) -> Self {
208        self.flags.sticky = value;
209        self
210    }
211
212    pub fn build(self) -> Result<Regex, RegexError> {
213        Regex::new_with_flags(&self.pattern, self.flags)
214    }
215}
216
/// Errors reported while building a [`Regex`].
#[derive(Debug, Clone, PartialEq)]
pub enum RegexError {
    /// The pattern or the flag string could not be parsed.
    Parse(String),
    /// The parsed pattern could not be compiled.
    Compile(String),
    /// Any other failure.
    Other(String),
}

impl fmt::Display for RegexError {
    /// Writes the message prefixed by the error category.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match *self {
            Self::Parse(ref reason) => write!(f, "Parse error: {}", reason),
            Self::Compile(ref reason) => write!(f, "Compile error: {}", reason),
            Self::Other(ref reason) => write!(f, "Error: {}", reason),
        }
    }
}

impl std::error::Error for RegexError {}
237
238impl Regex {
239    pub fn new(pattern: &str) -> Result<Self, RegexError> {
240        Self::new_with_flags(pattern, RegexFlags::default())
241    }
242
243    pub fn new_with_flags(pattern: &str, flags: RegexFlags) -> Result<Self, RegexError> {
244        let flag_bits = flags.to_u16();
245
246        match fast_engine::FastRegex::new(pattern, flag_bits) {
247            Ok(fast) => {
248                return Ok(Self {
249                    fast_engine: Some(fast),
250                    program: None,
251                    pattern: pattern.to_string(),
252                    flags,
253                });
254            }
255            Err(_) => {
256                let ast = parser::parse(pattern, flag_bits).map_err(RegexError::Parse)?;
257
258                let program = compiler::compile(&ast, flag_bits).map_err(RegexError::Compile)?;
259
260                Ok(Self {
261                    fast_engine: None,
262                    program: Some(program),
263                    pattern: pattern.to_string(),
264                    flags,
265                })
266            }
267        }
268    }
269
270    pub fn with_flags(pattern: &str, flags: &str) -> Result<Self, RegexError> {
271        let flags = RegexFlags::from_str(flags).map_err(|e| RegexError::Parse(e))?;
272        Self::new_with_flags(pattern, flags)
273    }
274
275    pub fn pattern(&self) -> &str {
276        &self.pattern
277    }
278
279    pub fn flags(&self) -> RegexFlags {
280        self.flags
281    }
282
283    pub fn find<'a>(&self, input: &'a str) -> Option<MatchResult<'a>> {
284        if let Some(ref fast) = self.fast_engine {
285            if let Some(m) = fast.find(input) {
286                return self.match_result_from_engine(input, m);
287            }
288        }
289
290        if let Some(ref program) = self.program {
291            let m = engine::execute(program, input, 0)?;
292            return self.match_result_from_engine(input, m);
293        }
294
295        None
296    }
297
298    pub fn find_at<'a>(&self, input: &'a str, start: usize) -> Option<MatchResult<'a>> {
299        if let Some(ref program) = self.program {
300            let m = engine::execute(program, input, start)?;
301            return self.match_result_from_engine(input, m);
302        }
303
304        if let Some(ref fast) = self.fast_engine {
305            if let Some(m) = fast.find(&input[start..]) {
306                let shifted = Match {
307                    start: m.start + start,
308                    end: m.end + start,
309                    captures: m.captures,
310                };
311                return self.match_result_from_engine(input, shifted);
312            }
313        }
314
315        None
316    }
317
318    pub fn find_all<'a>(&self, input: &'a str) -> Vec<MatchResult<'a>> {
319        let is_ascii = is_ascii_fast(input);
320
321        if let Some(ref fast) = self.fast_engine {
322            fast.find_all(input)
323                .into_iter()
324                .filter_map(|m| self.match_result_from_engine_fast(input, m, is_ascii))
325                .collect()
326        } else if let Some(ref program) = self.program {
327            engine::find_all(program, input)
328                .into_iter()
329                .filter_map(|m| self.match_result_from_engine_fast(input, m, is_ascii))
330                .collect()
331        } else {
332            Vec::new()
333        }
334    }
335
336    fn match_result_from_engine_fast<'a>(
337        &self,
338        input: &'a str,
339        m: engine::Match,
340        is_ascii: bool,
341    ) -> Option<MatchResult<'a>> {
342        let (start_byte, end_byte) = if is_ascii {
343            (m.start, m.end)
344        } else {
345            let char_positions: Vec<usize> = input.char_indices().map(|(i, _)| i).collect();
346            let start_byte = char_positions.get(m.start).copied().unwrap_or(0);
347            let end_byte = char_positions.get(m.end).copied().unwrap_or(input.len());
348            (start_byte, end_byte)
349        };
350
351        let mut captures = Vec::with_capacity(m.captures.len());
352        for (start, end) in &m.captures {
353            let cap = match (start, end) {
354                (Some(s), Some(e)) => {
355                    if is_ascii {
356                        Some(&input[*s..*e])
357                    } else {
358                        let start_byte = input.char_indices().nth(*s).map(|(i, _)| i).unwrap_or(0);
359                        let end_byte = input
360                            .char_indices()
361                            .nth(*e)
362                            .map(|(i, _)| i)
363                            .unwrap_or(input.len());
364                        Some(&input[start_byte..end_byte])
365                    }
366                }
367                _ => None,
368            };
369            captures.push(cap);
370        }
371
372        Some(MatchResult {
373            text: input,
374            start: start_byte,
375            end: end_byte,
376            captures,
377            capture_positions: m.captures,
378        })
379    }
380
381    pub fn is_match(&self, input: &str) -> bool {
382        if let Some(ref fast) = self.fast_engine {
383            if let Some(ref dfa) = fast.dfa() {
384                return dfa.is_match(input);
385            }
386            return fast.is_match(input);
387        }
388
389        self.find(input).is_some()
390    }
391
392    pub fn is_full_match(&self, input: &str) -> bool {
393        if let Some(m) = self.find(input) {
394            m.start() == 0 && m.end() == input.len()
395        } else {
396            false
397        }
398    }
399
400    pub fn replace<'a>(&self, input: &'a str, replacement: &str) -> String {
401        self.replace_n(input, replacement, 1)
402    }
403
404    pub fn replace_all<'a>(&self, input: &'a str, replacement: &str) -> String {
405        self.replace_n(input, replacement, usize::MAX)
406    }
407
408    pub fn replace_n<'a>(&self, input: &'a str, replacement: &str, n: usize) -> String {
409        let mut result = String::new();
410        let mut last_end = 0;
411        let mut count = 0;
412
413        for m in self.find_all(input) {
414            if count >= n {
415                break;
416            }
417            result.push_str(&input[last_end..m.start()]);
418            result.push_str(replacement);
419            last_end = m.end();
420            count += 1;
421        }
422
423        result.push_str(&input[last_end..]);
424        result
425    }
426
427    pub fn replace_fn<'a, F>(&self, input: &'a str, f: F) -> String
428    where
429        F: Fn(&MatchResult) -> String,
430    {
431        let mut result = String::new();
432        let mut last_end = 0;
433
434        for m in self.find_all(input) {
435            result.push_str(&input[last_end..m.start()]);
436            result.push_str(&f(&m));
437            last_end = m.end();
438        }
439
440        result.push_str(&input[last_end..]);
441        result
442    }
443
444    pub fn capture_count(&self) -> usize {
445        if let Some(ref program) = self.program {
446            program.capture_count
447        } else {
448            1
449        }
450    }
451
452    fn match_result_from_engine<'a>(
453        &self,
454        input: &'a str,
455        m: engine::Match,
456    ) -> Option<MatchResult<'a>> {
457        let is_ascii = is_ascii_fast(input);
458        self.match_result_from_engine_fast(input, m, is_ascii)
459    }
460}
461
/// Returns `true` when every byte of `s` is below `0x80`, i.e. `s` is pure
/// ASCII. The empty string counts as ASCII.
///
/// Delegates to the standard library's `str::is_ascii` rather than scanning
/// the bytes by hand.
#[inline]
fn is_ascii_fast(s: &str) -> bool {
    s.is_ascii()
}
487
488impl fmt::Debug for Regex {
489    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
490        write!(f, "Regex(/{}/{}", self.pattern, self.flags)
491    }
492}
493
494impl fmt::Display for Regex {
495    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
496        write!(f, "/{}/{}", self.pattern, self.flags)
497    }
498}
499
500pub fn find(pattern: &str, input: &str) -> Result<Option<String>, RegexError> {
501    let re = Regex::new(pattern)?;
502    Ok(re.find(input).map(|m| m.as_str().to_string()))
503}
504
505pub fn is_match(pattern: &str, input: &str) -> Result<bool, RegexError> {
506    let re = Regex::new(pattern)?;
507    Ok(re.is_match(input))
508}
509
510pub fn replace(pattern: &str, input: &str, replacement: &str) -> Result<String, RegexError> {
511    let re = Regex::new(pattern)?;
512    Ok(re.replace_all(input, replacement))
513}