1use std::collections::HashMap;
6
7use regex::bytes::Regex;
8use thiserror::Error;
9
10use crate::opcode::Opcode;
11use crate::parser::{Expr, ParseError, ParserConfig};
12use crate::vm::CompiledFilter;
13
14#[derive(Debug, Clone, Error)]
16#[allow(missing_docs)]
17pub enum CompileError {
18 #[error("Parse error: {0}")]
19 Parse(#[from] ParseError),
20
21 #[error("Unknown field '{0}'. Available fields: {1}")]
22 UnknownField(String, String),
23
24 #[error("Invalid regex pattern '{pattern}': {error}")]
25 InvalidRegex { pattern: String, error: String },
26
27 #[error("Too many strings (max 65535)")]
28 TooManyStrings,
29
30 #[error("Too many regexes (max 65535)")]
31 TooManyRegexes,
32
33 #[error("Too many string sets (max 65535)")]
34 TooManySets,
35}
36
37struct Compiler<'a> {
39 config: &'a ParserConfig,
40 bytecode: Vec<u8>,
41 strings: Vec<Vec<u8>>,
42 string_map: HashMap<Vec<u8>, u16>,
43 regexes: Vec<Regex>,
44 regex_map: HashMap<String, u16>,
45 string_sets: Vec<Vec<u16>>,
46}
47
48impl<'a> Compiler<'a> {
49 fn new(config: &'a ParserConfig) -> Self {
50 Self {
51 config,
52 bytecode: Vec::new(),
53 strings: Vec::new(),
54 string_map: HashMap::new(),
55 regexes: Vec::new(),
56 regex_map: HashMap::new(),
57 string_sets: Vec::new(),
58 }
59 }
60
61 fn intern_string(&mut self, s: &str) -> Result<u16, CompileError> {
63 let bytes = s.as_bytes().to_vec();
64 if let Some(&idx) = self.string_map.get(&bytes) {
65 return Ok(idx);
66 }
67
68 let idx = self.strings.len();
69 if idx > u16::MAX as usize {
70 return Err(CompileError::TooManyStrings);
71 }
72
73 self.string_map.insert(bytes.clone(), idx as u16);
74 self.strings.push(bytes);
75 Ok(idx as u16)
76 }
77
78 fn intern_regex(&mut self, pattern: &str) -> Result<u16, CompileError> {
80 if let Some(&idx) = self.regex_map.get(pattern) {
81 return Ok(idx);
82 }
83
84 let regex = Regex::new(pattern).map_err(|e| CompileError::InvalidRegex {
85 pattern: pattern.to_string(),
86 error: e.to_string(),
87 })?;
88
89 let idx = self.regexes.len();
90 if idx > u16::MAX as usize {
91 return Err(CompileError::TooManyRegexes);
92 }
93
94 self.regex_map.insert(pattern.to_string(), idx as u16);
95 self.regexes.push(regex);
96 Ok(idx as u16)
97 }
98
99 fn add_string_set(&mut self, values: &[String]) -> Result<u16, CompileError> {
101 let indices: Vec<u16> = values
102 .iter()
103 .map(|v| self.intern_string(v))
104 .collect::<Result<_, _>>()?;
105
106 let idx = self.string_sets.len();
107 if idx > u16::MAX as usize {
108 return Err(CompileError::TooManySets);
109 }
110
111 self.string_sets.push(indices);
112 Ok(idx as u16)
113 }
114
115 fn lookup_field(&self, name: &str) -> Result<u8, CompileError> {
117 let upper = name.to_uppercase();
119 self.config
120 .fields
121 .get(&upper)
122 .or_else(|| self.config.fields.get(name))
123 .copied()
124 .ok_or_else(|| {
125 let available: Vec<_> = self.config.fields.keys().cloned().collect();
126 CompileError::UnknownField(name.to_string(), available.join(", "))
127 })
128 }
129
130 fn emit(&mut self, byte: u8) {
132 self.bytecode.push(byte);
133 }
134
135 fn emit_u16(&mut self, value: u16) {
137 self.bytecode.extend_from_slice(&value.to_le_bytes());
138 }
139
140 fn emit_i16(&mut self, value: i16) {
142 self.bytecode.extend_from_slice(&value.to_le_bytes());
143 }
144
145 fn offset(&self) -> usize {
147 self.bytecode.len()
148 }
149
150 fn patch_i16(&mut self, pos: usize, value: i16) {
152 let bytes = value.to_le_bytes();
153 self.bytecode[pos] = bytes[0];
154 self.bytecode[pos + 1] = bytes[1];
155 }
156
157 fn compile_expr(&mut self, expr: &Expr) -> Result<(), CompileError> {
159 match expr {
160 Expr::Bool(true) => {
161 self.emit(Opcode::PushTrue as u8);
162 }
163 Expr::Bool(false) => {
164 self.emit(Opcode::PushFalse as u8);
165 }
166 Expr::Rand(n) => {
167 self.emit(Opcode::Rand as u8);
168 self.emit_u16(*n);
169 }
170
171 Expr::Contains(s) => {
173 let idx = self.intern_string(s)?;
174 self.emit(Opcode::Contains as u8);
175 self.emit_u16(idx);
176 }
177 Expr::StartsWith(s) => {
178 let idx = self.intern_string(s)?;
179 self.emit(Opcode::StartsWith as u8);
180 self.emit_u16(idx);
181 }
182 Expr::EndsWith(s) => {
183 let idx = self.intern_string(s)?;
184 self.emit(Opcode::EndsWith as u8);
185 self.emit_u16(idx);
186 }
187 Expr::Equals(s) => {
188 let idx = self.intern_string(s)?;
189 self.emit(Opcode::Equals as u8);
190 self.emit_u16(idx);
191 }
192 Expr::Matches(pattern) => {
193 let idx = self.intern_regex(pattern)?;
194 self.emit(Opcode::Matches as u8);
195 self.emit_u16(idx);
196 }
197
198 Expr::PartContains { part, value } => {
200 let part_idx = self.lookup_field(part)?;
201 let str_idx = self.intern_string(value)?;
202 self.emit(Opcode::PartContains as u8);
203 self.emit(part_idx);
204 self.emit_u16(str_idx);
205 }
206 Expr::PartIContains { part, value } => {
207 let part_idx = self.lookup_field(part)?;
208 let str_idx = self.intern_string(value)?;
209 self.emit(Opcode::PartIContains as u8);
210 self.emit(part_idx);
211 self.emit_u16(str_idx);
212 }
213 Expr::PartStartsWith { part, value } => {
214 let part_idx = self.lookup_field(part)?;
215 let str_idx = self.intern_string(value)?;
216 self.emit(Opcode::PartStartsWith as u8);
217 self.emit(part_idx);
218 self.emit_u16(str_idx);
219 }
220 Expr::PartEndsWith { part, value } => {
221 let part_idx = self.lookup_field(part)?;
222 let str_idx = self.intern_string(value)?;
223 self.emit(Opcode::PartEndsWith as u8);
224 self.emit(part_idx);
225 self.emit_u16(str_idx);
226 }
227 Expr::PartEquals { part, value } => {
228 let part_idx = self.lookup_field(part)?;
229 let str_idx = self.intern_string(value)?;
230 self.emit(Opcode::PartEquals as u8);
231 self.emit(part_idx);
232 self.emit_u16(str_idx);
233 }
234 Expr::PartIEquals { part, value } => {
235 let part_idx = self.lookup_field(part)?;
236 let str_idx = self.intern_string(value)?;
237 self.emit(Opcode::PartIEquals as u8);
238 self.emit(part_idx);
239 self.emit_u16(str_idx);
240 }
241 Expr::PartNotEquals { part, value } => {
242 let part_idx = self.lookup_field(part)?;
244 let str_idx = self.intern_string(value)?;
245 self.emit(Opcode::PartEquals as u8);
246 self.emit(part_idx);
247 self.emit_u16(str_idx);
248 self.emit(Opcode::Not as u8);
249 }
250 Expr::PartMatches { part, pattern } => {
251 let part_idx = self.lookup_field(part)?;
252 let regex_idx = self.intern_regex(pattern)?;
253 self.emit(Opcode::PartMatches as u8);
254 self.emit(part_idx);
255 self.emit_u16(regex_idx);
256 }
257 Expr::PartInSet { part, values } => {
258 let part_idx = self.lookup_field(part)?;
259 let set_idx = self.add_string_set(values)?;
260 self.emit(Opcode::PartInSet as u8);
261 self.emit(part_idx);
262 self.emit_u16(set_idx);
263 }
264 Expr::PartIsEmpty { part } => {
265 let part_idx = self.lookup_field(part)?;
266 self.emit(Opcode::PartIsEmpty as u8);
267 self.emit(part_idx);
268 }
269 Expr::PartNotEmpty { part } => {
270 let part_idx = self.lookup_field(part)?;
271 self.emit(Opcode::PartNotEmpty as u8);
272 self.emit(part_idx);
273 }
274
275 Expr::HeaderEquals {
277 part,
278 header,
279 value,
280 } => {
281 let part_idx = self.lookup_field(part)?;
282 let hdr_idx = self.intern_string(header)?;
283 let val_idx = self.intern_string(value)?;
284 self.emit(Opcode::HeaderEquals as u8);
285 self.emit(part_idx);
286 self.emit_u16(hdr_idx);
287 self.emit_u16(val_idx);
288 }
289 Expr::HeaderIEquals {
290 part,
291 header,
292 value,
293 } => {
294 let part_idx = self.lookup_field(part)?;
295 let hdr_idx = self.intern_string(header)?;
296 let val_idx = self.intern_string(value)?;
297 self.emit(Opcode::HeaderIEquals as u8);
298 self.emit(part_idx);
299 self.emit_u16(hdr_idx);
300 self.emit_u16(val_idx);
301 }
302 Expr::HeaderContains {
303 part,
304 header,
305 value,
306 } => {
307 let part_idx = self.lookup_field(part)?;
308 let hdr_idx = self.intern_string(header)?;
309 let val_idx = self.intern_string(value)?;
310 self.emit(Opcode::HeaderContains as u8);
311 self.emit(part_idx);
312 self.emit_u16(hdr_idx);
313 self.emit_u16(val_idx);
314 }
315 Expr::HeaderExists { part, header } => {
316 let part_idx = self.lookup_field(part)?;
317 let hdr_idx = self.intern_string(header)?;
318 self.emit(Opcode::HeaderExists as u8);
319 self.emit(part_idx);
320 self.emit_u16(hdr_idx);
321 }
322
323 Expr::And(left, right) => {
325 self.compile_expr(left)?;
327 let opcode_pos = self.offset();
329 self.emit(Opcode::JumpIfFalse as u8);
330 let patch_pos = self.offset();
331 self.emit_i16(0); self.compile_expr(right)?;
334 let jump_target = self.offset();
336 let offset = (jump_target as isize - opcode_pos as isize) as i16;
337 self.patch_i16(patch_pos, offset);
338 }
339 Expr::Or(left, right) => {
340 self.compile_expr(left)?;
342 let opcode_pos = self.offset();
344 self.emit(Opcode::JumpIfTrue as u8);
345 let patch_pos = self.offset();
346 self.emit_i16(0); self.compile_expr(right)?;
349 let jump_target = self.offset();
351 let offset = (jump_target as isize - opcode_pos as isize) as i16;
352 self.patch_i16(patch_pos, offset);
353 }
354 Expr::Not(inner) => {
355 self.compile_expr(inner)?;
356 self.emit(Opcode::Not as u8);
357 }
358 }
359
360 Ok(())
361 }
362
363 fn finish(mut self, source: String) -> CompiledFilter {
365 self.emit(Opcode::Return as u8);
366
367 CompiledFilter::new(
368 self.bytecode,
369 self.strings,
370 self.regexes,
371 self.string_sets,
372 self.config.delimiter.clone(),
373 source,
374 )
375 }
376}
377
378pub fn compile(source: &str, config: &ParserConfig) -> Result<CompiledFilter, CompileError> {
404 let expr = crate::parser::parse(source, config)?;
405 compile_expr(&expr, config, source.to_string())
406}
407
408pub fn compile_expr(
413 expr: &Expr,
414 config: &ParserConfig,
415 source: String,
416) -> Result<CompiledFilter, CompileError> {
417 let mut compiler = Compiler::new(config);
418 compiler.compile_expr(expr)?;
419 Ok(compiler.finish(source))
420}
421
#[cfg(test)]
mod tests {
    use bytes::Bytes;

    use super::*;

    /// Builds a parser configuration with six named fields at indices 0..=5.
    fn test_config() -> ParserConfig {
        let mut config = ParserConfig::default();
        for (name, index) in [
            ("LEVEL", 0),
            ("CODE", 1),
            ("METHOD", 2),
            ("PATH", 3),
            ("HEADERS", 4),
            ("BODY", 5),
        ] {
            config.add_field(name, index);
        }
        config
    }

    /// Compiles `input`, evaluates it against `payload`, and asserts the
    /// outcome equals `expected`.
    fn compile_and_test(input: &str, payload: &str, expected: bool) {
        let result = compile(input, &test_config())
            .expect("Failed to compile")
            .evaluate(Bytes::from(payload.to_owned()));
        assert_eq!(
            result, expected,
            "Filter '{}' on payload '{}' expected {} but got {}",
            input, payload, expected, result
        );
    }

    #[test]
    fn test_compile_true() {
        compile_and_test("true", "", true);
    }

    #[test]
    fn test_compile_false() {
        compile_and_test("false", "", false);
    }

    #[test]
    fn test_compile_payload_contains() {
        let filter = r#"payload contains "error""#;
        for (payload, expected) in [("an error occurred", true), ("all good", false)] {
            compile_and_test(filter, payload, expected);
        }
    }

    #[test]
    fn test_compile_field_equals() {
        let filter = r#"LEVEL == "error""#;
        for (payload, expected) in [("error;;;500;;;GET", true), ("info;;;500;;;GET", false)] {
            compile_and_test(filter, payload, expected);
        }
    }

    #[test]
    fn test_compile_field_in_set() {
        let filter = r#"LEVEL in {"error", "warn", "fatal"}"#;
        for (payload, expected) in [
            ("error;;;500;;;GET", true),
            ("warn;;;500;;;GET", true),
            ("info;;;500;;;GET", false),
        ] {
            compile_and_test(filter, payload, expected);
        }
    }

    #[test]
    fn test_compile_and() {
        let filter = r#"LEVEL == "error" AND CODE == "500""#;
        for (payload, expected) in [("error;;;500;;;GET", true), ("error;;;200;;;GET", false)] {
            compile_and_test(filter, payload, expected);
        }
    }

    #[test]
    fn test_compile_or() {
        let filter = r#"LEVEL == "error" OR LEVEL == "warn""#;
        for (payload, expected) in [
            ("error;;;500;;;GET", true),
            ("warn;;;500;;;GET", true),
            ("info;;;500;;;GET", false),
        ] {
            compile_and_test(filter, payload, expected);
        }
    }

    #[test]
    fn test_compile_not() {
        let filter = r#"NOT LEVEL == "debug""#;
        for (payload, expected) in [("error;;;500;;;GET", true), ("debug;;;500;;;GET", false)] {
            compile_and_test(filter, payload, expected);
        }
    }

    #[test]
    fn test_compile_header_iequals() {
        // Six fields; only LEVEL (0) and HEADERS (4) are populated.
        let payload = ["error", "", "", "", "X-Custom: value\r\n", ""].join(";;;");

        let filter = compile(
            r#"HEADERS.header("x-custom") iequals "value""#,
            &test_config(),
        )
        .unwrap();

        assert!(filter.evaluate(Bytes::from(payload)));
    }

    #[test]
    fn test_compile_complex_filter() {
        let filter_str = r#"
            LEVEL == "error"
            AND CODE == "500"
            AND HEADERS.header("Content-Type") iequals "application/json"
        "#;

        let filter = compile(filter_str, &test_config()).unwrap();

        // Same payload each time except for the LEVEL field.
        let payload_with_level = |level: &str| {
            [
                level,
                "500",
                "",
                "",
                "Content-Type: application/json\r\n",
                "",
            ]
            .join(";;;")
        };

        assert!(filter.evaluate(Bytes::from(payload_with_level("error"))));
        assert!(!filter.evaluate(Bytes::from(payload_with_level("info"))));
    }

    #[test]
    fn test_compile_rand() {
        crate::vm::reset_rand_counter();

        let filter = compile("rand(2)", &test_config()).unwrap();

        // After the counter reset, successive evaluations alternate.
        for expected in [true, false, true, false] {
            assert_eq!(filter.evaluate(Bytes::new()), expected);
        }
    }

    #[test]
    fn test_compile_regex() {
        let filter = r#"payload matches "error_[0-9]+""#;
        for (payload, expected) in [("found error_123", true), ("no errors", false)] {
            compile_and_test(filter, payload, expected);
        }
    }

    #[test]
    fn test_compile_unknown_field() {
        let outcome = compile(r#"UNKNOWN_FIELD == "x""#, &test_config());
        assert!(matches!(outcome, Err(CompileError::UnknownField(_, _))));
    }

    #[test]
    fn test_compile_invalid_regex() {
        let outcome = compile(r#"payload matches "[invalid""#, &test_config());
        assert!(matches!(outcome, Err(CompileError::InvalidRegex { .. })));
    }

    #[test]
    fn test_bytecode_size() {
        let config = test_config();

        // One field comparison (opcode + field byte + u16 operand) plus the
        // trailing Return: 4 + 1 bytes.
        let single = compile(r#"LEVEL == "error""#, &config).unwrap();
        assert_eq!(single.bytecode_len(), 5);

        // Two comparisons joined by a jump (opcode + i16 offset) plus Return:
        // 4 + 3 + 4 + 1 bytes.
        let conjunction = compile(r#"LEVEL == "error" AND CODE == "500""#, &config).unwrap();
        assert_eq!(conjunction.bytecode_len(), 12);
    }
}