Skip to main content

pipa/regexp/
parser.rs

1use super::ast::{Ast, CharClass, ParseState, Quantifier};
2use super::charclass::{CharRange, class_digit, class_space, class_word};
3
4pub fn parse(pattern: &str, flags: u16) -> Result<Ast, String> {
5    let mut state = ParseState::new(pattern, flags);
6    let ast = parse_disjunction(&mut state)?;
7
8    if !state.is_eof() {
9        return Err(format!("Unexpected character at position {}", state.pos));
10    }
11
12    Ok(ast.simplify())
13}
14
15fn parse_disjunction(state: &mut ParseState) -> Result<Ast, String> {
16    let mut alternatives = vec![parse_alternative(state)?];
17
18    while state.peek() == Some('|') {
19        state.next();
20        alternatives.push(parse_alternative(state)?);
21    }
22
23    if alternatives.len() == 1 {
24        Ok(alternatives.into_iter().next().unwrap())
25    } else {
26        Ok(Ast::Alt(alternatives))
27    }
28}
29
30fn parse_alternative(state: &mut ParseState) -> Result<Ast, String> {
31    let mut terms = Vec::new();
32
33    while let Some(c) = state.peek() {
34        if c == '|' || c == ')' {
35            break;
36        }
37        terms.push(parse_term(state)?);
38    }
39
40    if terms.is_empty() {
41        Ok(Ast::Empty)
42    } else if terms.len() == 1 {
43        Ok(terms.into_iter().next().unwrap())
44    } else {
45        Ok(Ast::Concat(terms))
46    }
47}
48
49fn parse_term(state: &mut ParseState) -> Result<Ast, String> {
50    let atom = parse_atom(state)?;
51
52    if let Some(q) = parse_quantifier(state)? {
53        Ok(Ast::Quant(Box::new(atom), q))
54    } else {
55        Ok(atom)
56    }
57}
58
59fn parse_quantifier(state: &mut ParseState) -> Result<Option<Quantifier>, String> {
60    let c = match state.peek() {
61        Some(c) if "*+?{".contains(c) => c,
62        _ => return Ok(None),
63    };
64
65    let q = match c {
66        '*' => {
67            state.next();
68            Quantifier::star()
69        }
70        '+' => {
71            state.next();
72            Quantifier::plus()
73        }
74        '?' => {
75            state.next();
76            Quantifier::question()
77        }
78        '{' => parse_brace_quantifier(state)?,
79        _ => unreachable!(),
80    };
81
82    let q = if state.peek() == Some('?') {
83        state.next();
84        Quantifier { greedy: false, ..q }
85    } else {
86        q
87    };
88
89    Ok(Some(q))
90}
91
92fn parse_brace_quantifier(state: &mut ParseState) -> Result<Quantifier, String> {
93    state.expect('{')?;
94
95    let min = state
96        .parse_number()
97        .ok_or_else(|| "Expected number in quantifier".to_string())?;
98
99    let max = if state.peek() == Some(',') {
100        state.next();
101        if state.peek() == Some('}') {
102            None
103        } else {
104            Some(
105                state
106                    .parse_number()
107                    .ok_or_else(|| "Expected number in quantifier".to_string())?,
108            )
109        }
110    } else {
111        Some(min)
112    };
113
114    state.expect('}')?;
115
116    if let Some(max_val) = max {
117        if min > max_val {
118            return Err("Invalid quantifier: min > max".to_string());
119        }
120    }
121
122    Ok(Quantifier::range(min, max))
123}
124
125fn parse_atom(state: &mut ParseState) -> Result<Ast, String> {
126    match state.peek() {
127        None => Err("Unexpected end of pattern".to_string()),
128        Some('^') => {
129            state.next();
130            Ok(Ast::StartOfLine)
131        }
132        Some('$') => {
133            state.next();
134            Ok(Ast::EndOfLine)
135        }
136        Some('.') => {
137            state.next();
138            if state.dot_all {
139                Ok(Ast::AnyAll)
140            } else {
141                Ok(Ast::Any)
142            }
143        }
144        Some('\\') => parse_escape(state),
145        Some('(') => parse_group(state),
146        Some('[') => parse_char_class(state),
147        Some(c) if is_special_char(c) => Err(format!(
148            "Unexpected special character '{}' at position {}",
149            c, state.pos
150        )),
151        Some(_) => {
152            let c = state.next().unwrap();
153            Ok(Ast::Char(c))
154        }
155    }
156}
157
158fn parse_escape(state: &mut ParseState) -> Result<Ast, String> {
159    state.expect('\\')?;
160
161    let c = state
162        .next()
163        .ok_or_else(|| "Unexpected end after \\".to_string())?;
164
165    match c {
166        'b' => Ok(Ast::Char('\x08')),
167        'f' => Ok(Ast::Char('\x0C')),
168        'n' => Ok(Ast::Char('\n')),
169        'r' => Ok(Ast::Char('\r')),
170        't' => Ok(Ast::Char('\t')),
171        'v' => Ok(Ast::Char('\x0B')),
172        'd' => Ok(Ast::Class(CharClass {
173            negated: false,
174            ranges: class_digit(),
175        })),
176        'D' => Ok(Ast::Class(CharClass {
177            negated: true,
178            ranges: class_digit(),
179        })),
180        's' => Ok(Ast::Class(CharClass {
181            negated: false,
182            ranges: class_space(),
183        })),
184        'S' => Ok(Ast::Class(CharClass {
185            negated: true,
186            ranges: class_space(),
187        })),
188        'w' => Ok(Ast::Class(CharClass {
189            negated: false,
190            ranges: class_word(),
191        })),
192        'W' => Ok(Ast::Class(CharClass {
193            negated: true,
194            ranges: class_word(),
195        })),
196        'B' => Ok(Ast::NotWordBoundary),
197        'x' => parse_hex_escape(state, 2),
198        'u' => parse_unicode_escape(state),
199        '0' => {
200            if state.is_unicode {
201                if state.peek().map_or(false, |c| c.is_ascii_digit()) {
202                    return Err("Invalid octal escape in unicode mode".to_string());
203                }
204                Ok(Ast::Char('\0'))
205            } else {
206                parse_octal_escape(state)
207            }
208        }
209        '1'..='9' => {
210            let start = state.pos - 1;
211            let num = state.parse_number().unwrap_or(0);
212
213            if num > 0 && num < state.capture_count as u32 {
214                Ok(Ast::BackRef(num as usize))
215            } else if !state.is_unicode {
216                state.pos = start;
217                parse_octal_escape(state)
218            } else {
219                Err(format!("Invalid backreference: {}", num))
220            }
221        }
222        'p' | 'P' if state.is_unicode => parse_unicode_property(state, c == 'P'),
223        'k' => parse_named_backref(state),
224        c => Ok(Ast::Char(c)),
225    }
226}
227
228fn parse_group(state: &mut ParseState) -> Result<Ast, String> {
229    state.expect('(')?;
230
231    if state.peek() == Some('?') {
232        state.next();
233
234        match state.peek() {
235            Some(':') => {
236                state.next();
237                let inner = parse_disjunction(state)?;
238                state.expect(')')?;
239                Ok(inner)
240            }
241            Some('=') => {
242                state.next();
243                let inner = parse_disjunction(state)?;
244                state.expect(')')?;
245                Ok(Ast::Lookahead(Box::new(inner)))
246            }
247            Some('!') => {
248                state.next();
249                let inner = parse_disjunction(state)?;
250                state.expect(')')?;
251                Ok(Ast::NegativeLookahead(Box::new(inner)))
252            }
253            Some('<') => {
254                state.next();
255                let name = parse_group_name(state)?;
256                state.expect('>')?;
257
258                state.capture_count += 1;
259                state.named_groups.push(name.clone());
260
261                let inner = parse_disjunction(state)?;
262                state.expect(')')?;
263
264                Ok(Ast::Capture(Box::new(inner), Some(name)))
265            }
266            Some(c) => Err(format!("Unknown group extension: ?{}", c)),
267            None => Err("Unexpected end after (?".to_string()),
268        }
269    } else {
270        state.capture_count += 1;
271
272        let inner = parse_disjunction(state)?;
273        state.expect(')')?;
274
275        Ok(Ast::Capture(Box::new(inner), None))
276    }
277}
278
279fn parse_char_class(state: &mut ParseState) -> Result<Ast, String> {
280    state.expect('[')?;
281
282    let negated = if state.peek() == Some('^') {
283        state.next();
284        true
285    } else {
286        false
287    };
288
289    let mut ranges = CharRange::new();
290    let mut prev_char: Option<char> = None;
291
292    loop {
293        match state.peek() {
294            None => return Err("Unterminated character class".to_string()),
295            Some(']') if prev_char.is_some() => {
296                state.next();
297                break;
298            }
299            Some(']') if ranges.is_empty() => {
300                state.next();
301                ranges.add_char(']');
302                prev_char = Some(']');
303            }
304            Some(']') => {
305                state.next();
306                break;
307            }
308            Some('-') if prev_char.is_some() && state.peek_at(1) != Some(']') => {
309                state.next();
310                let end = parse_class_atom(state)?;
311                if let Some(start) = prev_char {
312                    ranges.add_range(start as u32, end as u32);
313                }
314                prev_char = None;
315            }
316            Some('-') => {
317                state.next();
318                ranges.add_char('-');
319                prev_char = Some('-');
320            }
321            Some('\\') => {
322                let atom = parse_class_escape(state)?;
323                if let Some(c) = atom {
324                    ranges.add_char(c);
325                    prev_char = Some(c);
326                } else {
327                    prev_char = None;
328                }
329            }
330            Some(c) => {
331                state.next();
332                ranges.add_char(c);
333                prev_char = Some(c);
334            }
335        }
336    }
337
338    if negated {
339        ranges.invert();
340    }
341
342    Ok(Ast::Class(CharClass { negated, ranges }))
343}
344
345fn parse_class_atom(state: &mut ParseState) -> Result<char, String> {
346    match state.peek() {
347        None => Err("Unexpected end in character class".to_string()),
348        Some('\\') => parse_class_escape(state)?
349            .ok_or_else(|| "Class escapes not allowed in range".to_string()),
350        Some(c) => {
351            state.next();
352            Ok(c)
353        }
354    }
355}
356
357fn parse_class_escape(state: &mut ParseState) -> Result<Option<char>, String> {
358    state.expect('\\')?;
359
360    let c = state
361        .next()
362        .ok_or_else(|| "Unexpected end after \\".to_string())?;
363
364    match c {
365        'b' => Ok(Some('\x08')),
366        'f' => Ok(Some('\x0C')),
367        'n' => Ok(Some('\n')),
368        'r' => Ok(Some('\r')),
369        't' => Ok(Some('\t')),
370        'v' => Ok(Some('\x0B')),
371        'x' => parse_hex_escape_char(state, 2),
372        'u' => parse_unicode_escape_char(state),
373        '0' if state.is_unicode => {
374            if state.peek().map_or(false, |c| c.is_ascii_digit()) {
375                Err("Invalid octal escape".to_string())
376            } else {
377                Ok(Some('\0'))
378            }
379        }
380
381        'd' | 'D' | 's' | 'S' | 'w' | 'W' => Ok(None),
382        c => Ok(Some(c)),
383    }
384}
385
386fn parse_hex_escape(state: &mut ParseState, digits: usize) -> Result<Ast, String> {
387    let c = parse_hex_escape_char(state, digits)?.ok_or("Invalid hex escape")?;
388    Ok(Ast::Char(c))
389}
390
391fn parse_hex_escape_char(state: &mut ParseState, digits: usize) -> Result<Option<char>, String> {
392    let mut val: u32 = 0;
393    for _ in 0..digits {
394        let c = state
395            .next()
396            .ok_or_else(|| "Unexpected end in hex escape".to_string())?;
397        let digit = c
398            .to_digit(16)
399            .ok_or_else(|| format!("Invalid hex digit: {}", c))?;
400        val = val * 16 + digit;
401    }
402    char::from_u32(val)
403        .map(Some)
404        .ok_or_else(|| "Invalid Unicode code point".to_string())
405}
406
407fn parse_unicode_escape(state: &mut ParseState) -> Result<Ast, String> {
408    let c = parse_unicode_escape_char(state)?;
409    Ok(Ast::Char(c.unwrap_or('\u{FFFD}')))
410}
411
412fn parse_unicode_escape_char(state: &mut ParseState) -> Result<Option<char>, String> {
413    if state.peek() == Some('{') {
414        state.next();
415        let mut val: u32 = 0;
416        loop {
417            match state.next() {
418                Some('}') => break,
419                Some(c) => {
420                    let digit = c
421                        .to_digit(16)
422                        .ok_or_else(|| format!("Invalid hex digit: {}", c))?;
423                    val = val * 16 + digit;
424                    if val > 0x10FFFF {
425                        return Err("Unicode code point out of range".to_string());
426                    }
427                }
428                None => return Err("Unterminated Unicode escape".to_string()),
429            }
430        }
431        char::from_u32(val)
432            .map(Some)
433            .ok_or_else(|| "Invalid Unicode code point".to_string())
434    } else {
435        parse_hex_escape_char(state, 4)
436    }
437}
438
439fn parse_octal_escape(state: &mut ParseState) -> Result<Ast, String> {
440    let mut val: u32 = 0;
441    let mut count = 0;
442
443    while count < 3 {
444        match state.peek() {
445            Some(c) if c >= '0' && c <= '7' => {
446                state.next();
447                val = val * 8 + (c as u32 - '0' as u32);
448                count += 1;
449            }
450            _ => break,
451        }
452    }
453
454    if count == 0 {
455        Ok(Ast::Char('\0'))
456    } else {
457        char::from_u32(val)
458            .map(Ast::Char)
459            .ok_or_else(|| "Invalid octal escape".to_string())
460    }
461}
462
463fn parse_unicode_property(state: &mut ParseState, negated: bool) -> Result<Ast, String> {
464    state.expect('{')?;
465
466    let mut name = String::new();
467    loop {
468        match state.next() {
469            Some('}') => break,
470            Some(c) => name.push(c),
471            None => return Err("Unterminated property name".to_string()),
472        }
473    }
474
475    let ranges = match name.as_str() {
476        "ASCII" => CharRange::from_range('\0', '\x7F'),
477        "ASCII_Hex_Digit" => {
478            let mut r = CharRange::from_range('0', '9');
479            r.add_range('A' as u32, 'F' as u32);
480            r.add_range('a' as u32, 'f' as u32);
481            r
482        }
483        _ => CharRange::new(),
484    };
485
486    Ok(Ast::Class(CharClass { negated, ranges }))
487}
488
489fn parse_named_backref(state: &mut ParseState) -> Result<Ast, String> {
490    state.expect('<')?;
491    let name = parse_group_name(state)?;
492    state.expect('>')?;
493
494    Ok(Ast::NamedBackRef(name))
495}
496
497fn parse_group_name(state: &mut ParseState) -> Result<String, String> {
498    let mut name = String::new();
499
500    loop {
501        match state.peek() {
502            Some('>') | Some(')') => break,
503            Some(c) if c.is_alphanumeric() || c == '_' => {
504                state.next();
505                name.push(c);
506            }
507            Some(c) => return Err(format!("Invalid character in group name: {}", c)),
508            None => return Err("Unexpected end in group name".to_string()),
509        }
510    }
511
512    if name.is_empty() {
513        Err("Empty group name".to_string())
514    } else {
515        Ok(name)
516    }
517}
518
/// Returns `true` when `c` is one of the regexp syntax characters
/// `^ $ \ . * + ? ( ) [ ] { } |`.
fn is_special_char(c: char) -> bool {
    const SPECIAL: &str = "^$\\.*+?()[]{}|";
    SPECIAL.contains(c)
}
525
526trait PeekAt {
527    fn peek_at(&self, offset: usize) -> Option<char>;
528}
529
530impl PeekAt for ParseState {
531    fn peek_at(&self, offset: usize) -> Option<char> {
532        self.pattern.get(self.pos + offset).copied()
533    }
534}
535
#[cfg(test)]
mod tests {
    use super::*;

    /// Three literals in a row become a three-element concatenation.
    #[test]
    fn test_parse_simple() {
        let parsed = parse("abc", 0).unwrap();
        if let Ast::Concat(nodes) = parsed {
            assert_eq!(nodes.len(), 3);
        } else {
            panic!("Expected concat");
        }
    }

    /// `|` splits the pattern into one branch per alternative.
    #[test]
    fn test_parse_alt() {
        let parsed = parse("a|b|c", 0).unwrap();
        if let Ast::Alt(nodes) = parsed {
            assert_eq!(nodes.len(), 3);
        } else {
            panic!("Expected alt");
        }
    }

    /// `*` is greedy, with a minimum of zero and no maximum.
    #[test]
    fn test_parse_quantifier() {
        let parsed = parse("a*", 0).unwrap();
        if let Ast::Quant(_inner, q) = parsed {
            assert_eq!(q.min, 0);
            assert_eq!(q.max, None);
            assert!(q.greedy);
        } else {
            panic!("Expected quant");
        }
    }

    /// A quantified group wraps an unnamed capture.
    #[test]
    fn test_parse_capture() {
        let parsed = parse("(ab)+", 0).unwrap();
        if let Ast::Quant(inner, _) = parsed {
            if !matches!(*inner, Ast::Capture(_, None)) {
                panic!("Expected capture");
            }
        } else {
            panic!("Expected quant");
        }
    }

    /// `[a-z]` contains its endpoints and excludes upper-case letters.
    #[test]
    fn test_parse_char_class() {
        let parsed = parse("[a-z]", 0).unwrap();
        if let Ast::Class(cc) = parsed {
            assert!(!cc.negated);
            assert!(cc.ranges.contains_char('a'));
            assert!(cc.ranges.contains_char('z'));
            assert!(!cc.ranges.contains_char('A'));
        } else {
            panic!("Expected class");
        }
    }
}