1use super::ast::{Ast, CharClass, Quantifier};
2
3use super::opcode::*;
4
5pub struct Compiler {
6 bytecode: Vec<u8>,
7
8 capture_count: usize,
9
10 ignore_case: bool,
11
12 char_ranges: Vec<super::charclass::CharRange>,
13}
14
15#[derive(Debug, Clone)]
16pub struct Program {
17 pub bytecode: Vec<u8>,
18
19 pub capture_count: usize,
20
21 pub flags: u16,
22
23 pub char_ranges: Vec<super::charclass::CharRange>,
24}
25
26impl Program {
27 pub fn flags(&self) -> u16 {
28 u16::from_le_bytes([self.bytecode[HEADER_FLAGS], self.bytecode[HEADER_FLAGS + 1]])
29 }
30
31 pub fn capture_count(&self) -> usize {
32 self.bytecode[HEADER_CAPTURE_COUNT] as usize
33 }
34
35 pub fn code(&self) -> &[u8] {
36 &self.bytecode[HEADER_LEN..]
37 }
38}
39
40pub fn compile(ast: &Ast, flags: u16) -> Result<Program, String> {
41 let mut compiler = Compiler::new(flags);
42
43 compiler.write_header_placeholder();
44
45 compiler.emit_op_u8(OpCode::SaveStart, 0);
46
47 compiler.compile_node(ast)?;
48
49 compiler.emit_op_u8(OpCode::SaveEnd, 0);
50
51 compiler.emit_op(OpCode::Success);
52
53 compiler.update_header(flags)?;
54
55 Ok(Program {
56 bytecode: compiler.bytecode,
57 capture_count: compiler.capture_count,
58 flags,
59 char_ranges: compiler.char_ranges,
60 })
61}
62
63impl Compiler {
64 fn new(flags: u16) -> Self {
65 Self {
66 bytecode: Vec::new(),
67 capture_count: 1,
68 ignore_case: (flags & FLAG_IGNORE_CASE) != 0,
69 char_ranges: Vec::new(),
70 }
71 }
72
73 fn write_header_placeholder(&mut self) {
74 self.bytecode.extend_from_slice(&[0, 0]);
75
76 self.bytecode.push(0);
77
78 self.bytecode.push(REG_COUNT as u8);
79
80 self.bytecode.extend_from_slice(&[0, 0, 0, 0]);
81 }
82
83 fn update_header(&mut self, flags: u16) -> Result<(), String> {
84 let flag_bytes = flags.to_le_bytes();
85 self.bytecode[HEADER_FLAGS] = flag_bytes[0];
86 self.bytecode[HEADER_FLAGS + 1] = flag_bytes[1];
87
88 if self.capture_count > MAX_CAPTURES {
89 return Err(format!("Too many capture groups: {}", self.capture_count));
90 }
91 self.bytecode[HEADER_CAPTURE_COUNT] = self.capture_count as u8;
92
93 let len = self.bytecode.len() - HEADER_LEN;
94 let len_bytes = (len as u32).to_le_bytes();
95 self.bytecode[HEADER_CODE_LEN..HEADER_CODE_LEN + 4].copy_from_slice(&len_bytes);
96
97 Ok(())
98 }
99
100 fn compile_node(&mut self, node: &Ast) -> Result<(), String> {
101 match node {
102 Ast::Empty => Ok(()),
103 Ast::Char(c) => self.compile_char(*c),
104 Ast::Class(class) => self.compile_class(class),
105 Ast::Any => {
106 self.emit_op(OpCode::MatchDot);
107 Ok(())
108 }
109 Ast::AnyAll => {
110 self.emit_op(OpCode::MatchAny);
111 Ok(())
112 }
113 Ast::StartOfLine => {
114 self.emit_op(OpCode::CheckLineStart);
115 Ok(())
116 }
117 Ast::EndOfLine => {
118 self.emit_op(OpCode::CheckLineEnd);
119 Ok(())
120 }
121 Ast::WordBoundary => {
122 if self.ignore_case {
123 self.emit_op(OpCode::CheckWordBoundaryI);
124 } else {
125 self.emit_op(OpCode::CheckWordBoundary);
126 }
127 Ok(())
128 }
129 Ast::NotWordBoundary => {
130 if self.ignore_case {
131 self.emit_op(OpCode::CheckNotWordBoundaryI);
132 } else {
133 self.emit_op(OpCode::CheckNotWordBoundary);
134 }
135 Ok(())
136 }
137 Ast::Concat(nodes) => {
138 for node in nodes {
139 self.compile_node(node)?;
140 }
141 Ok(())
142 }
143 Ast::Alt(nodes) => self.compile_alt(nodes),
144 Ast::Quant(inner, q) => self.compile_quant(inner, q),
145 Ast::Capture(inner, _name) => self.compile_capture(inner),
146 Ast::BackRef(idx) => self.compile_backref(*idx),
147 Ast::NamedBackRef(name) => Err(format!("Named backref not yet implemented: {}", name)),
148 Ast::Lookahead(inner) => self.compile_lookahead(inner, false),
149 Ast::NegativeLookahead(inner) => self.compile_lookahead(inner, true),
150 }
151 }
152
153 fn compile_char(&mut self, c: char) -> Result<(), String> {
154 let cp = c as u32;
155
156 if self.ignore_case {
157 let folded = unicode_fold_simple(cp);
158 if folded > 0xFFFF {
159 self.emit_match_char32_i(REG_POS, folded);
160 } else {
161 self.emit_match_char_i(REG_POS, folded as u16);
162 }
163 } else {
164 if cp > 0xFFFF {
165 self.emit_match_char32(REG_POS, cp);
166 } else {
167 self.emit_match_char(REG_POS, cp as u16);
168 }
169 }
170 Ok(())
171 }
172
173 fn compile_class(&mut self, class: &CharClass) -> Result<(), String> {
174 let range_idx = self.char_ranges.len();
175 self.char_ranges.push(class.ranges.clone());
176
177 if range_idx > u16::MAX as usize {
178 return Err("Too many character classes".to_string());
179 }
180
181 let opcode = if self.ignore_case {
182 OpCode::MatchClassI
183 } else {
184 OpCode::MatchClass
185 };
186
187 self.bytecode.push(opcode as u8);
188 self.bytecode.push(REG_POS as u8);
189 self.bytecode
190 .extend_from_slice(&(range_idx as u16).to_le_bytes());
191
192 Ok(())
193 }
194
195 fn compile_alt(&mut self, nodes: &[Ast]) -> Result<(), String> {
196 if nodes.is_empty() {
197 return Ok(());
198 }
199 if nodes.len() == 1 {
200 return self.compile_node(&nodes[0]);
201 }
202
203 let mut jump_offsets = Vec::new();
204
205 for (i, node) in nodes.iter().enumerate() {
206 if i < nodes.len() - 1 {
207 let push_pos = self.bytecode.len();
208 self.bytecode.push(OpCode::PushBacktrack as u8);
209 self.bytecode.extend_from_slice(&0i32.to_le_bytes());
210
211 self.compile_node(node)?;
212
213 let jmp_pos = self.bytecode.len();
214 self.bytecode.push(OpCode::Jmp as u8);
215 self.bytecode.extend_from_slice(&0i32.to_le_bytes());
216 jump_offsets.push(jmp_pos);
217
218 let fail_target = self.bytecode.len();
219 let push_offset = (fail_target as i32 - push_pos as i32 - 5) as i32;
220 self.bytecode[push_pos + 1..push_pos + 5]
221 .copy_from_slice(&push_offset.to_le_bytes());
222 } else {
223 self.compile_node(node)?;
224 }
225 }
226
227 let end_pos = self.bytecode.len();
228 for jmp_pos in jump_offsets {
229 let offset = (end_pos as i32 - jmp_pos as i32 - 5) as i32;
230 self.bytecode[jmp_pos + 1..jmp_pos + 5].copy_from_slice(&offset.to_le_bytes());
231 }
232
233 Ok(())
234 }
235
236 fn compile_quant(&mut self, inner: &Ast, q: &Quantifier) -> Result<(), String> {
237 let min = q.min;
238 let max = q.max.unwrap_or(usize::MAX as u32) as usize;
239 let greedy = q.greedy;
240
241 if min == 0 && max == 0 {
242 return Ok(());
243 }
244
245 if min == 1 && max == 1 {
246 return self.compile_node(inner);
247 }
248
249 if min == 0 && max == 1 {
250 return self.compile_optional(inner, greedy);
251 }
252
253 if min == 0 && max == usize::MAX {
254 return self.compile_star(inner, greedy);
255 }
256
257 if min == 1 && max == usize::MAX {
258 return self.compile_plus(inner, greedy);
259 }
260
261 self.compile_repeat(inner, min as usize, max, greedy)
262 }
263
264 fn compile_optional(&mut self, inner: &Ast, greedy: bool) -> Result<(), String> {
265 if greedy {
266 let push_pos = self.bytecode.len();
267 self.bytecode.push(OpCode::PushBacktrack as u8);
268 self.bytecode.extend_from_slice(&0i32.to_le_bytes());
269
270 self.compile_node(inner)?;
271
272 let jmp_pos = self.bytecode.len();
273 self.bytecode.push(OpCode::Jmp as u8);
274 self.bytecode.extend_from_slice(&0i32.to_le_bytes());
275
276 let skip_target = self.bytecode.len();
277 let push_offset = (skip_target as i32 - push_pos as i32 - 5) as i32;
278 self.bytecode[push_pos + 1..push_pos + 5].copy_from_slice(&push_offset.to_le_bytes());
279
280 let done_target = self.bytecode.len();
281 let jmp_offset = (done_target as i32 - jmp_pos as i32 - 5) as i32;
282 self.bytecode[jmp_pos + 1..jmp_pos + 5].copy_from_slice(&jmp_offset.to_le_bytes());
283 } else {
284 let push_pos = self.bytecode.len();
285 self.bytecode.push(OpCode::PushBacktrack as u8);
286 self.bytecode.extend_from_slice(&0i32.to_le_bytes());
287
288 let jmp_pos = self.bytecode.len();
289 self.bytecode.push(OpCode::Jmp as u8);
290 self.bytecode.extend_from_slice(&0i32.to_le_bytes());
291
292 let match_target = self.bytecode.len();
293 let push_offset = (match_target as i32 - push_pos as i32 - 5) as i32;
294 self.bytecode[push_pos + 1..push_pos + 5].copy_from_slice(&push_offset.to_le_bytes());
295
296 self.compile_node(inner)?;
297
298 let done_target = self.bytecode.len();
299 let jmp_offset = (done_target as i32 - jmp_pos as i32 - 5) as i32;
300 self.bytecode[jmp_pos + 1..jmp_pos + 5].copy_from_slice(&jmp_offset.to_le_bytes());
301 }
302
303 Ok(())
304 }
305
306 fn compile_star(&mut self, inner: &Ast, _greedy: bool) -> Result<(), String> {
307 let start_pos = self.bytecode.len();
308
309 let push_pos = self.bytecode.len();
310 self.bytecode.push(OpCode::PushBacktrack as u8);
311 self.bytecode.extend_from_slice(&0i32.to_le_bytes());
312
313 self.compile_node(inner)?;
314
315 self.bytecode.push(OpCode::Jmp as u8);
316 let loop_offset = (start_pos as i32 - self.bytecode.len() as i32 - 5) as i32;
317 self.bytecode.extend_from_slice(&loop_offset.to_le_bytes());
318
319 let done_pos = self.bytecode.len();
320 let push_offset = (done_pos as i32 - push_pos as i32 - 5) as i32;
321 self.bytecode[push_pos + 1..push_pos + 5].copy_from_slice(&push_offset.to_le_bytes());
322
323 Ok(())
324 }
325
326 fn compile_plus(&mut self, inner: &Ast, greedy: bool) -> Result<(), String> {
327 self.compile_node(inner)?;
328
329 self.compile_star(inner, greedy)
330 }
331
332 fn compile_repeat(
333 &mut self,
334 inner: &Ast,
335 min: usize,
336 max: usize,
337 _greedy: bool,
338 ) -> Result<(), String> {
339 let counter_reg = REG_COUNTER;
340
341 self.emit_mov_imm(counter_reg, 0);
342
343 let min_start = self.bytecode.len();
344
345 self.bytecode.push(OpCode::CmpImm as u8);
346 self.bytecode.push(counter_reg as u8);
347 self.bytecode.extend_from_slice(&(min as u32).to_le_bytes());
348
349 let cmp_pos = self.bytecode.len();
350 self.bytecode.push(OpCode::JmpNe as u8);
351 self.bytecode.push(counter_reg as u8);
352 self.bytecode.extend_from_slice(&(min as u32).to_le_bytes());
353 self.bytecode.extend_from_slice(&0i32.to_le_bytes());
354
355 self.compile_node(inner)?;
356
357 self.bytecode.push(OpCode::Inc as u8);
358 self.bytecode.push(counter_reg as u8);
359
360 self.bytecode.push(OpCode::Jmp as u8);
361 let loop_offset = (min_start as i32 - self.bytecode.len() as i32 - 5) as i32;
362 self.bytecode.extend_from_slice(&loop_offset.to_le_bytes());
363
364 let opt_start = self.bytecode.len();
365 let jmp_offset = (opt_start as i32 - cmp_pos as i32 - 10) as i32;
366 self.bytecode[cmp_pos + 6..cmp_pos + 10].copy_from_slice(&jmp_offset.to_le_bytes());
367
368 if max > min && max < usize::MAX {
369 self.emit_mov_imm(counter_reg, 0);
370
371 let opt_loop_start = self.bytecode.len();
372
373 self.bytecode.push(OpCode::CmpImm as u8);
374 self.bytecode.push(counter_reg as u8);
375 self.bytecode
376 .extend_from_slice(&((max - min) as u32).to_le_bytes());
377
378 let cmp2_pos = self.bytecode.len();
379 self.bytecode.push(OpCode::JmpNe as u8);
380 self.bytecode.push(counter_reg as u8);
381 self.bytecode
382 .extend_from_slice(&((max - min) as u32).to_le_bytes());
383 self.bytecode.extend_from_slice(&0i32.to_le_bytes());
384
385 let push_pos = self.bytecode.len();
386 self.bytecode.push(OpCode::PushBacktrack as u8);
387 self.bytecode.extend_from_slice(&0i32.to_le_bytes());
388
389 self.compile_node(inner)?;
390
391 self.bytecode.push(OpCode::Inc as u8);
392 self.bytecode.push(counter_reg as u8);
393
394 self.bytecode.push(OpCode::Jmp as u8);
395 let loop2_offset = (opt_loop_start as i32 - self.bytecode.len() as i32 - 5) as i32;
396 self.bytecode.extend_from_slice(&loop2_offset.to_le_bytes());
397
398 let end_pos = self.bytecode.len();
399 let push2_offset = (end_pos as i32 - push_pos as i32 - 5) as i32;
400 self.bytecode[push_pos + 1..push_pos + 5].copy_from_slice(&push2_offset.to_le_bytes());
401
402 let jmp2_offset = (end_pos as i32 - cmp2_pos as i32 - 10) as i32;
403 self.bytecode[cmp2_pos + 6..cmp2_pos + 10].copy_from_slice(&jmp2_offset.to_le_bytes());
404 }
405
406 Ok(())
407 }
408
409 fn compile_capture(&mut self, inner: &Ast) -> Result<(), String> {
410 let capture_idx = self.capture_count;
411 self.capture_count += 1;
412
413 self.emit_op_u8(OpCode::SaveStart, capture_idx as u8);
414 self.compile_node(inner)?;
415 self.emit_op_u8(OpCode::SaveEnd, capture_idx as u8);
416
417 Ok(())
418 }
419
420 fn compile_backref(&mut self, idx: usize) -> Result<(), String> {
421 if idx >= MAX_CAPTURES {
422 return Err(format!("Backreference index too large: {}", idx));
423 }
424
425 let opcode = if self.ignore_case {
426 OpCode::CheckBackrefI
427 } else {
428 OpCode::CheckBackref
429 };
430
431 self.emit_op_u8(opcode, idx as u8);
432 Ok(())
433 }
434
435 fn compile_lookahead(&mut self, inner: &Ast, negative: bool) -> Result<(), String> {
436 self.bytecode.push(OpCode::Mark as u8);
437 self.bytecode.push(REG_MARK as u8);
438
439 let push_pos = self.bytecode.len();
440 self.bytecode.push(OpCode::PushBacktrack as u8);
441 self.bytecode.extend_from_slice(&0i32.to_le_bytes());
442
443 self.compile_node(inner)?;
444
445 self.bytecode.push(OpCode::PopBacktrack as u8);
446
447 if negative {
448 self.bytecode.push(OpCode::Restore as u8);
449 self.bytecode.push(REG_MARK as u8);
450 self.bytecode.push(OpCode::Fail as u8);
451 }
452
453 self.bytecode.push(OpCode::Restore as u8);
454 self.bytecode.push(REG_MARK as u8);
455
456 let end_pos = self.bytecode.len();
457 let offset = (end_pos as i32 - push_pos as i32 - 5) as i32;
458 self.bytecode[push_pos + 1..push_pos + 5].copy_from_slice(&offset.to_le_bytes());
459
460 if !negative {
461 self.bytecode.push(OpCode::Fail as u8);
462 }
463
464 Ok(())
465 }
466
467 fn emit_op(&mut self, op: OpCode) {
468 self.bytecode.push(op as u8);
469 }
470
471 fn emit_op_u8(&mut self, op: OpCode, val: u8) {
472 self.bytecode.push(op as u8);
473 self.bytecode.push(val);
474 }
475
476 fn emit_match_char(&mut self, reg: usize, ch: u16) {
477 self.bytecode.push(OpCode::MatchChar as u8);
478 self.bytecode.push(reg as u8);
479 self.bytecode.extend_from_slice(&ch.to_le_bytes());
480 }
481
482 fn emit_match_char_i(&mut self, reg: usize, ch: u16) {
483 self.bytecode.push(OpCode::MatchCharI as u8);
484 self.bytecode.push(reg as u8);
485 self.bytecode.extend_from_slice(&ch.to_le_bytes());
486 }
487
488 fn emit_match_char32(&mut self, reg: usize, ch: u32) {
489 self.bytecode.push(OpCode::MatchChar32 as u8);
490 self.bytecode.push(reg as u8);
491 self.bytecode.extend_from_slice(&ch.to_le_bytes());
492 }
493
494 fn emit_match_char32_i(&mut self, reg: usize, ch: u32) {
495 self.bytecode.push(OpCode::MatchChar32I as u8);
496 self.bytecode.push(reg as u8);
497 self.bytecode.extend_from_slice(&ch.to_le_bytes());
498 }
499
500 fn emit_mov_imm(&mut self, reg: usize, imm: usize) {
501 self.bytecode.push(OpCode::MovImm as u8);
502 self.bytecode.push(reg as u8);
503 self.bytecode.extend_from_slice(&(imm as u32).to_le_bytes());
504 }
505}
506
/// Case-folds a code point for case-insensitive literal matching.
///
/// Only ASCII `A`–`Z` are mapped (to `a`–`z`). Every other code point,
/// including non-ASCII letters, is returned unchanged.
fn unicode_fold_simple(c: u32) -> u32 {
    const UPPER_A: u32 = 0x41;
    const UPPER_Z: u32 = 0x5A;
    const CASE_DISTANCE: u32 = 0x20;

    match c {
        UPPER_A..=UPPER_Z => c + CASE_DISTANCE,
        _ => c,
    }
}
518
#[cfg(test)]
mod tests {
    use super::super::parser::parse;
    use super::*;

    /// Parses and compiles `pattern` with no flags, unwrapping both steps.
    fn build(pattern: &str) -> Program {
        let ast = parse(pattern, 0).unwrap();
        compile(&ast, 0).unwrap()
    }

    /// Asserts that the program contains instructions beyond the fixed header.
    fn assert_has_code(prog: &Program) {
        assert!(prog.bytecode.len() > HEADER_LEN);
    }

    #[test]
    fn test_compile_simple() {
        assert_has_code(&build("abc"));
    }

    #[test]
    fn test_compile_capture() {
        // Group 0 (the whole match) plus the one explicit group.
        assert_eq!(build("(a)").capture_count, 2);
    }

    #[test]
    fn test_compile_alt() {
        assert_has_code(&build("a|b"));
    }

    #[test]
    fn test_compile_quant() {
        assert_has_code(&build("a*"));
    }
}