//! pipa/regexp/compiler.rs — compiles a regular-expression AST into bytecode.

1use super::ast::{Ast, CharClass, Quantifier};
2
3use super::opcode::*;
4
5pub struct Compiler {
6    bytecode: Vec<u8>,
7
8    capture_count: usize,
9
10    ignore_case: bool,
11
12    char_ranges: Vec<super::charclass::CharRange>,
13}
14
15#[derive(Debug, Clone)]
16pub struct Program {
17    pub bytecode: Vec<u8>,
18
19    pub capture_count: usize,
20
21    pub flags: u16,
22
23    pub char_ranges: Vec<super::charclass::CharRange>,
24}
25
26impl Program {
27    pub fn flags(&self) -> u16 {
28        u16::from_le_bytes([self.bytecode[HEADER_FLAGS], self.bytecode[HEADER_FLAGS + 1]])
29    }
30
31    pub fn capture_count(&self) -> usize {
32        self.bytecode[HEADER_CAPTURE_COUNT] as usize
33    }
34
35    pub fn code(&self) -> &[u8] {
36        &self.bytecode[HEADER_LEN..]
37    }
38}
39
40pub fn compile(ast: &Ast, flags: u16) -> Result<Program, String> {
41    let mut compiler = Compiler::new(flags);
42
43    compiler.write_header_placeholder();
44
45    compiler.emit_op_u8(OpCode::SaveStart, 0);
46
47    compiler.compile_node(ast)?;
48
49    compiler.emit_op_u8(OpCode::SaveEnd, 0);
50
51    compiler.emit_op(OpCode::Success);
52
53    compiler.update_header(flags)?;
54
55    Ok(Program {
56        bytecode: compiler.bytecode,
57        capture_count: compiler.capture_count,
58        flags,
59        char_ranges: compiler.char_ranges,
60    })
61}
62
63impl Compiler {
64    fn new(flags: u16) -> Self {
65        Self {
66            bytecode: Vec::new(),
67            capture_count: 1,
68            ignore_case: (flags & FLAG_IGNORE_CASE) != 0,
69            char_ranges: Vec::new(),
70        }
71    }
72
73    fn write_header_placeholder(&mut self) {
74        self.bytecode.extend_from_slice(&[0, 0]);
75
76        self.bytecode.push(0);
77
78        self.bytecode.push(REG_COUNT as u8);
79
80        self.bytecode.extend_from_slice(&[0, 0, 0, 0]);
81    }
82
83    fn update_header(&mut self, flags: u16) -> Result<(), String> {
84        let flag_bytes = flags.to_le_bytes();
85        self.bytecode[HEADER_FLAGS] = flag_bytes[0];
86        self.bytecode[HEADER_FLAGS + 1] = flag_bytes[1];
87
88        if self.capture_count > MAX_CAPTURES {
89            return Err(format!("Too many capture groups: {}", self.capture_count));
90        }
91        self.bytecode[HEADER_CAPTURE_COUNT] = self.capture_count as u8;
92
93        let len = self.bytecode.len() - HEADER_LEN;
94        let len_bytes = (len as u32).to_le_bytes();
95        self.bytecode[HEADER_CODE_LEN..HEADER_CODE_LEN + 4].copy_from_slice(&len_bytes);
96
97        Ok(())
98    }
99
100    fn compile_node(&mut self, node: &Ast) -> Result<(), String> {
101        match node {
102            Ast::Empty => Ok(()),
103            Ast::Char(c) => self.compile_char(*c),
104            Ast::Class(class) => self.compile_class(class),
105            Ast::Any => {
106                self.emit_op(OpCode::MatchDot);
107                Ok(())
108            }
109            Ast::AnyAll => {
110                self.emit_op(OpCode::MatchAny);
111                Ok(())
112            }
113            Ast::StartOfLine => {
114                self.emit_op(OpCode::CheckLineStart);
115                Ok(())
116            }
117            Ast::EndOfLine => {
118                self.emit_op(OpCode::CheckLineEnd);
119                Ok(())
120            }
121            Ast::WordBoundary => {
122                if self.ignore_case {
123                    self.emit_op(OpCode::CheckWordBoundaryI);
124                } else {
125                    self.emit_op(OpCode::CheckWordBoundary);
126                }
127                Ok(())
128            }
129            Ast::NotWordBoundary => {
130                if self.ignore_case {
131                    self.emit_op(OpCode::CheckNotWordBoundaryI);
132                } else {
133                    self.emit_op(OpCode::CheckNotWordBoundary);
134                }
135                Ok(())
136            }
137            Ast::Concat(nodes) => {
138                for node in nodes {
139                    self.compile_node(node)?;
140                }
141                Ok(())
142            }
143            Ast::Alt(nodes) => self.compile_alt(nodes),
144            Ast::Quant(inner, q) => self.compile_quant(inner, q),
145            Ast::Capture(inner, _name) => self.compile_capture(inner),
146            Ast::BackRef(idx) => self.compile_backref(*idx),
147            Ast::NamedBackRef(name) => Err(format!("Named backref not yet implemented: {}", name)),
148            Ast::Lookahead(inner) => self.compile_lookahead(inner, false),
149            Ast::NegativeLookahead(inner) => self.compile_lookahead(inner, true),
150        }
151    }
152
153    fn compile_char(&mut self, c: char) -> Result<(), String> {
154        let cp = c as u32;
155
156        if self.ignore_case {
157            let folded = unicode_fold_simple(cp);
158            if folded > 0xFFFF {
159                self.emit_match_char32_i(REG_POS, folded);
160            } else {
161                self.emit_match_char_i(REG_POS, folded as u16);
162            }
163        } else {
164            if cp > 0xFFFF {
165                self.emit_match_char32(REG_POS, cp);
166            } else {
167                self.emit_match_char(REG_POS, cp as u16);
168            }
169        }
170        Ok(())
171    }
172
173    fn compile_class(&mut self, class: &CharClass) -> Result<(), String> {
174        let range_idx = self.char_ranges.len();
175        self.char_ranges.push(class.ranges.clone());
176
177        if range_idx > u16::MAX as usize {
178            return Err("Too many character classes".to_string());
179        }
180
181        let opcode = if self.ignore_case {
182            OpCode::MatchClassI
183        } else {
184            OpCode::MatchClass
185        };
186
187        self.bytecode.push(opcode as u8);
188        self.bytecode.push(REG_POS as u8);
189        self.bytecode
190            .extend_from_slice(&(range_idx as u16).to_le_bytes());
191
192        Ok(())
193    }
194
195    fn compile_alt(&mut self, nodes: &[Ast]) -> Result<(), String> {
196        if nodes.is_empty() {
197            return Ok(());
198        }
199        if nodes.len() == 1 {
200            return self.compile_node(&nodes[0]);
201        }
202
203        let mut jump_offsets = Vec::new();
204
205        for (i, node) in nodes.iter().enumerate() {
206            if i < nodes.len() - 1 {
207                let push_pos = self.bytecode.len();
208                self.bytecode.push(OpCode::PushBacktrack as u8);
209                self.bytecode.extend_from_slice(&0i32.to_le_bytes());
210
211                self.compile_node(node)?;
212
213                let jmp_pos = self.bytecode.len();
214                self.bytecode.push(OpCode::Jmp as u8);
215                self.bytecode.extend_from_slice(&0i32.to_le_bytes());
216                jump_offsets.push(jmp_pos);
217
218                let fail_target = self.bytecode.len();
219                let push_offset = (fail_target as i32 - push_pos as i32 - 5) as i32;
220                self.bytecode[push_pos + 1..push_pos + 5]
221                    .copy_from_slice(&push_offset.to_le_bytes());
222            } else {
223                self.compile_node(node)?;
224            }
225        }
226
227        let end_pos = self.bytecode.len();
228        for jmp_pos in jump_offsets {
229            let offset = (end_pos as i32 - jmp_pos as i32 - 5) as i32;
230            self.bytecode[jmp_pos + 1..jmp_pos + 5].copy_from_slice(&offset.to_le_bytes());
231        }
232
233        Ok(())
234    }
235
236    fn compile_quant(&mut self, inner: &Ast, q: &Quantifier) -> Result<(), String> {
237        let min = q.min;
238        let max = q.max.unwrap_or(usize::MAX as u32) as usize;
239        let greedy = q.greedy;
240
241        if min == 0 && max == 0 {
242            return Ok(());
243        }
244
245        if min == 1 && max == 1 {
246            return self.compile_node(inner);
247        }
248
249        if min == 0 && max == 1 {
250            return self.compile_optional(inner, greedy);
251        }
252
253        if min == 0 && max == usize::MAX {
254            return self.compile_star(inner, greedy);
255        }
256
257        if min == 1 && max == usize::MAX {
258            return self.compile_plus(inner, greedy);
259        }
260
261        self.compile_repeat(inner, min as usize, max, greedy)
262    }
263
264    fn compile_optional(&mut self, inner: &Ast, greedy: bool) -> Result<(), String> {
265        if greedy {
266            let push_pos = self.bytecode.len();
267            self.bytecode.push(OpCode::PushBacktrack as u8);
268            self.bytecode.extend_from_slice(&0i32.to_le_bytes());
269
270            self.compile_node(inner)?;
271
272            let jmp_pos = self.bytecode.len();
273            self.bytecode.push(OpCode::Jmp as u8);
274            self.bytecode.extend_from_slice(&0i32.to_le_bytes());
275
276            let skip_target = self.bytecode.len();
277            let push_offset = (skip_target as i32 - push_pos as i32 - 5) as i32;
278            self.bytecode[push_pos + 1..push_pos + 5].copy_from_slice(&push_offset.to_le_bytes());
279
280            let done_target = self.bytecode.len();
281            let jmp_offset = (done_target as i32 - jmp_pos as i32 - 5) as i32;
282            self.bytecode[jmp_pos + 1..jmp_pos + 5].copy_from_slice(&jmp_offset.to_le_bytes());
283        } else {
284            let push_pos = self.bytecode.len();
285            self.bytecode.push(OpCode::PushBacktrack as u8);
286            self.bytecode.extend_from_slice(&0i32.to_le_bytes());
287
288            let jmp_pos = self.bytecode.len();
289            self.bytecode.push(OpCode::Jmp as u8);
290            self.bytecode.extend_from_slice(&0i32.to_le_bytes());
291
292            let match_target = self.bytecode.len();
293            let push_offset = (match_target as i32 - push_pos as i32 - 5) as i32;
294            self.bytecode[push_pos + 1..push_pos + 5].copy_from_slice(&push_offset.to_le_bytes());
295
296            self.compile_node(inner)?;
297
298            let done_target = self.bytecode.len();
299            let jmp_offset = (done_target as i32 - jmp_pos as i32 - 5) as i32;
300            self.bytecode[jmp_pos + 1..jmp_pos + 5].copy_from_slice(&jmp_offset.to_le_bytes());
301        }
302
303        Ok(())
304    }
305
306    fn compile_star(&mut self, inner: &Ast, _greedy: bool) -> Result<(), String> {
307        let start_pos = self.bytecode.len();
308
309        let push_pos = self.bytecode.len();
310        self.bytecode.push(OpCode::PushBacktrack as u8);
311        self.bytecode.extend_from_slice(&0i32.to_le_bytes());
312
313        self.compile_node(inner)?;
314
315        self.bytecode.push(OpCode::Jmp as u8);
316        let loop_offset = (start_pos as i32 - self.bytecode.len() as i32 - 5) as i32;
317        self.bytecode.extend_from_slice(&loop_offset.to_le_bytes());
318
319        let done_pos = self.bytecode.len();
320        let push_offset = (done_pos as i32 - push_pos as i32 - 5) as i32;
321        self.bytecode[push_pos + 1..push_pos + 5].copy_from_slice(&push_offset.to_le_bytes());
322
323        Ok(())
324    }
325
326    fn compile_plus(&mut self, inner: &Ast, greedy: bool) -> Result<(), String> {
327        self.compile_node(inner)?;
328
329        self.compile_star(inner, greedy)
330    }
331
332    fn compile_repeat(
333        &mut self,
334        inner: &Ast,
335        min: usize,
336        max: usize,
337        _greedy: bool,
338    ) -> Result<(), String> {
339        let counter_reg = REG_COUNTER;
340
341        self.emit_mov_imm(counter_reg, 0);
342
343        let min_start = self.bytecode.len();
344
345        self.bytecode.push(OpCode::CmpImm as u8);
346        self.bytecode.push(counter_reg as u8);
347        self.bytecode.extend_from_slice(&(min as u32).to_le_bytes());
348
349        let cmp_pos = self.bytecode.len();
350        self.bytecode.push(OpCode::JmpNe as u8);
351        self.bytecode.push(counter_reg as u8);
352        self.bytecode.extend_from_slice(&(min as u32).to_le_bytes());
353        self.bytecode.extend_from_slice(&0i32.to_le_bytes());
354
355        self.compile_node(inner)?;
356
357        self.bytecode.push(OpCode::Inc as u8);
358        self.bytecode.push(counter_reg as u8);
359
360        self.bytecode.push(OpCode::Jmp as u8);
361        let loop_offset = (min_start as i32 - self.bytecode.len() as i32 - 5) as i32;
362        self.bytecode.extend_from_slice(&loop_offset.to_le_bytes());
363
364        let opt_start = self.bytecode.len();
365        let jmp_offset = (opt_start as i32 - cmp_pos as i32 - 10) as i32;
366        self.bytecode[cmp_pos + 6..cmp_pos + 10].copy_from_slice(&jmp_offset.to_le_bytes());
367
368        if max > min && max < usize::MAX {
369            self.emit_mov_imm(counter_reg, 0);
370
371            let opt_loop_start = self.bytecode.len();
372
373            self.bytecode.push(OpCode::CmpImm as u8);
374            self.bytecode.push(counter_reg as u8);
375            self.bytecode
376                .extend_from_slice(&((max - min) as u32).to_le_bytes());
377
378            let cmp2_pos = self.bytecode.len();
379            self.bytecode.push(OpCode::JmpNe as u8);
380            self.bytecode.push(counter_reg as u8);
381            self.bytecode
382                .extend_from_slice(&((max - min) as u32).to_le_bytes());
383            self.bytecode.extend_from_slice(&0i32.to_le_bytes());
384
385            let push_pos = self.bytecode.len();
386            self.bytecode.push(OpCode::PushBacktrack as u8);
387            self.bytecode.extend_from_slice(&0i32.to_le_bytes());
388
389            self.compile_node(inner)?;
390
391            self.bytecode.push(OpCode::Inc as u8);
392            self.bytecode.push(counter_reg as u8);
393
394            self.bytecode.push(OpCode::Jmp as u8);
395            let loop2_offset = (opt_loop_start as i32 - self.bytecode.len() as i32 - 5) as i32;
396            self.bytecode.extend_from_slice(&loop2_offset.to_le_bytes());
397
398            let end_pos = self.bytecode.len();
399            let push2_offset = (end_pos as i32 - push_pos as i32 - 5) as i32;
400            self.bytecode[push_pos + 1..push_pos + 5].copy_from_slice(&push2_offset.to_le_bytes());
401
402            let jmp2_offset = (end_pos as i32 - cmp2_pos as i32 - 10) as i32;
403            self.bytecode[cmp2_pos + 6..cmp2_pos + 10].copy_from_slice(&jmp2_offset.to_le_bytes());
404        }
405
406        Ok(())
407    }
408
409    fn compile_capture(&mut self, inner: &Ast) -> Result<(), String> {
410        let capture_idx = self.capture_count;
411        self.capture_count += 1;
412
413        self.emit_op_u8(OpCode::SaveStart, capture_idx as u8);
414        self.compile_node(inner)?;
415        self.emit_op_u8(OpCode::SaveEnd, capture_idx as u8);
416
417        Ok(())
418    }
419
420    fn compile_backref(&mut self, idx: usize) -> Result<(), String> {
421        if idx >= MAX_CAPTURES {
422            return Err(format!("Backreference index too large: {}", idx));
423        }
424
425        let opcode = if self.ignore_case {
426            OpCode::CheckBackrefI
427        } else {
428            OpCode::CheckBackref
429        };
430
431        self.emit_op_u8(opcode, idx as u8);
432        Ok(())
433    }
434
435    fn compile_lookahead(&mut self, inner: &Ast, negative: bool) -> Result<(), String> {
436        self.bytecode.push(OpCode::Mark as u8);
437        self.bytecode.push(REG_MARK as u8);
438
439        let push_pos = self.bytecode.len();
440        self.bytecode.push(OpCode::PushBacktrack as u8);
441        self.bytecode.extend_from_slice(&0i32.to_le_bytes());
442
443        self.compile_node(inner)?;
444
445        self.bytecode.push(OpCode::PopBacktrack as u8);
446
447        if negative {
448            self.bytecode.push(OpCode::Restore as u8);
449            self.bytecode.push(REG_MARK as u8);
450            self.bytecode.push(OpCode::Fail as u8);
451        }
452
453        self.bytecode.push(OpCode::Restore as u8);
454        self.bytecode.push(REG_MARK as u8);
455
456        let end_pos = self.bytecode.len();
457        let offset = (end_pos as i32 - push_pos as i32 - 5) as i32;
458        self.bytecode[push_pos + 1..push_pos + 5].copy_from_slice(&offset.to_le_bytes());
459
460        if !negative {
461            self.bytecode.push(OpCode::Fail as u8);
462        }
463
464        Ok(())
465    }
466
467    fn emit_op(&mut self, op: OpCode) {
468        self.bytecode.push(op as u8);
469    }
470
471    fn emit_op_u8(&mut self, op: OpCode, val: u8) {
472        self.bytecode.push(op as u8);
473        self.bytecode.push(val);
474    }
475
476    fn emit_match_char(&mut self, reg: usize, ch: u16) {
477        self.bytecode.push(OpCode::MatchChar as u8);
478        self.bytecode.push(reg as u8);
479        self.bytecode.extend_from_slice(&ch.to_le_bytes());
480    }
481
482    fn emit_match_char_i(&mut self, reg: usize, ch: u16) {
483        self.bytecode.push(OpCode::MatchCharI as u8);
484        self.bytecode.push(reg as u8);
485        self.bytecode.extend_from_slice(&ch.to_le_bytes());
486    }
487
488    fn emit_match_char32(&mut self, reg: usize, ch: u32) {
489        self.bytecode.push(OpCode::MatchChar32 as u8);
490        self.bytecode.push(reg as u8);
491        self.bytecode.extend_from_slice(&ch.to_le_bytes());
492    }
493
494    fn emit_match_char32_i(&mut self, reg: usize, ch: u32) {
495        self.bytecode.push(OpCode::MatchChar32I as u8);
496        self.bytecode.push(reg as u8);
497        self.bytecode.extend_from_slice(&ch.to_le_bytes());
498    }
499
500    fn emit_mov_imm(&mut self, reg: usize, imm: usize) {
501        self.bytecode.push(OpCode::MovImm as u8);
502        self.bytecode.push(reg as u8);
503        self.bytecode.extend_from_slice(&(imm as u32).to_le_bytes());
504    }
505}
506
/// Case-folds a code point for ignore-case matching.
///
/// Only ASCII `A`–`Z` are mapped (to `a`–`z`); every other code point,
/// including all non-ASCII ones, is returned unchanged.
fn unicode_fold_simple(c: u32) -> u32 {
    match c {
        // 'A'..='Z'
        0x41..=0x5A => c + 32,
        _ => c,
    }
}
518
#[cfg(test)]
mod tests {
    use super::super::parser::parse;
    use super::*;

    /// Parses and compiles `pattern` with no flags, panicking on any error.
    fn compile_pattern(pattern: &str) -> Program {
        let ast = parse(pattern, 0).unwrap();
        compile(&ast, 0).unwrap()
    }

    /// A plain literal produces code beyond the header.
    #[test]
    fn test_compile_simple() {
        let prog = compile_pattern("abc");
        assert!(prog.bytecode.len() > HEADER_LEN);
    }

    /// One explicit group plus the implicit whole-match group.
    #[test]
    fn test_compile_capture() {
        let prog = compile_pattern("(a)");
        assert_eq!(prog.capture_count, 2);
    }

    /// An alternation produces code beyond the header.
    #[test]
    fn test_compile_alt() {
        let prog = compile_pattern("a|b");
        assert!(prog.bytecode.len() > HEADER_LEN);
    }

    /// A star quantifier produces code beyond the header.
    #[test]
    fn test_compile_quant() {
        let prog = compile_pattern("a*");
        assert!(prog.bytecode.len() > HEADER_LEN);
    }
}
551}