Skip to main content

bytecode_filter/
compiler.rs

1//! Compiler for filter expressions.
2//!
3//! Compiles an AST into bytecode that can be executed by the VM.
4
5use std::collections::HashMap;
6
7use regex::bytes::Regex;
8use thiserror::Error;
9
10use crate::opcode::Opcode;
11use crate::parser::{Expr, ParseError, ParserConfig};
12use crate::vm::CompiledFilter;
13
14/// Compilation error types.
15#[derive(Debug, Clone, Error)]
16#[allow(missing_docs)]
17pub enum CompileError {
18    #[error("Parse error: {0}")]
19    Parse(#[from] ParseError),
20
21    #[error("Unknown field '{0}'. Available fields: {1}")]
22    UnknownField(String, String),
23
24    #[error("Invalid regex pattern '{pattern}': {error}")]
25    InvalidRegex { pattern: String, error: String },
26
27    #[error("Too many strings (max 65535)")]
28    TooManyStrings,
29
30    #[error("Too many regexes (max 65535)")]
31    TooManyRegexes,
32
33    #[error("Too many string sets (max 65535)")]
34    TooManySets,
35}
36
37/// Compiler state during bytecode generation.
38struct Compiler<'a> {
39    config: &'a ParserConfig,
40    bytecode: Vec<u8>,
41    strings: Vec<Vec<u8>>,
42    string_map: HashMap<Vec<u8>, u16>,
43    regexes: Vec<Regex>,
44    regex_map: HashMap<String, u16>,
45    string_sets: Vec<Vec<u16>>,
46}
47
48impl<'a> Compiler<'a> {
49    fn new(config: &'a ParserConfig) -> Self {
50        Self {
51            config,
52            bytecode: Vec::new(),
53            strings: Vec::new(),
54            string_map: HashMap::new(),
55            regexes: Vec::new(),
56            regex_map: HashMap::new(),
57            string_sets: Vec::new(),
58        }
59    }
60
61    /// Intern a string and return its index.
62    fn intern_string(&mut self, s: &str) -> Result<u16, CompileError> {
63        let bytes = s.as_bytes().to_vec();
64        if let Some(&idx) = self.string_map.get(&bytes) {
65            return Ok(idx);
66        }
67
68        let idx = self.strings.len();
69        if idx > u16::MAX as usize {
70            return Err(CompileError::TooManyStrings);
71        }
72
73        self.string_map.insert(bytes.clone(), idx as u16);
74        self.strings.push(bytes);
75        Ok(idx as u16)
76    }
77
78    /// Intern a regex and return its index.
79    fn intern_regex(&mut self, pattern: &str) -> Result<u16, CompileError> {
80        if let Some(&idx) = self.regex_map.get(pattern) {
81            return Ok(idx);
82        }
83
84        let regex = Regex::new(pattern).map_err(|e| CompileError::InvalidRegex {
85            pattern: pattern.to_string(),
86            error: e.to_string(),
87        })?;
88
89        let idx = self.regexes.len();
90        if idx > u16::MAX as usize {
91            return Err(CompileError::TooManyRegexes);
92        }
93
94        self.regex_map.insert(pattern.to_string(), idx as u16);
95        self.regexes.push(regex);
96        Ok(idx as u16)
97    }
98
99    /// Add a string set and return its index.
100    fn add_string_set(&mut self, values: &[String]) -> Result<u16, CompileError> {
101        let indices: Vec<u16> = values
102            .iter()
103            .map(|v| self.intern_string(v))
104            .collect::<Result<_, _>>()?;
105
106        let idx = self.string_sets.len();
107        if idx > u16::MAX as usize {
108            return Err(CompileError::TooManySets);
109        }
110
111        self.string_sets.push(indices);
112        Ok(idx as u16)
113    }
114
115    /// Look up a field name and return its part index.
116    fn lookup_field(&self, name: &str) -> Result<u8, CompileError> {
117        // Try case-insensitive lookup
118        let upper = name.to_uppercase();
119        self.config
120            .fields
121            .get(&upper)
122            .or_else(|| self.config.fields.get(name))
123            .copied()
124            .ok_or_else(|| {
125                let available: Vec<_> = self.config.fields.keys().cloned().collect();
126                CompileError::UnknownField(name.to_string(), available.join(", "))
127            })
128    }
129
130    /// Emit a single byte.
131    fn emit(&mut self, byte: u8) {
132        self.bytecode.push(byte);
133    }
134
135    /// Emit a u16 in little-endian format.
136    fn emit_u16(&mut self, value: u16) {
137        self.bytecode.extend_from_slice(&value.to_le_bytes());
138    }
139
140    /// Emit an i16 in little-endian format.
141    fn emit_i16(&mut self, value: i16) {
142        self.bytecode.extend_from_slice(&value.to_le_bytes());
143    }
144
145    /// Current bytecode offset (for backpatching).
146    fn offset(&self) -> usize {
147        self.bytecode.len()
148    }
149
150    /// Backpatch an i16 at the given bytecode position.
151    fn patch_i16(&mut self, pos: usize, value: i16) {
152        let bytes = value.to_le_bytes();
153        self.bytecode[pos] = bytes[0];
154        self.bytecode[pos + 1] = bytes[1];
155    }
156
157    /// Compile an expression.
158    fn compile_expr(&mut self, expr: &Expr) -> Result<(), CompileError> {
159        match expr {
160            Expr::Bool(true) => {
161                self.emit(Opcode::PushTrue as u8);
162            }
163            Expr::Bool(false) => {
164                self.emit(Opcode::PushFalse as u8);
165            }
166            Expr::Rand(n) => {
167                self.emit(Opcode::Rand as u8);
168                self.emit_u16(*n);
169            }
170
171            // Payload-wide operations
172            Expr::Contains(s) => {
173                let idx = self.intern_string(s)?;
174                self.emit(Opcode::Contains as u8);
175                self.emit_u16(idx);
176            }
177            Expr::StartsWith(s) => {
178                let idx = self.intern_string(s)?;
179                self.emit(Opcode::StartsWith as u8);
180                self.emit_u16(idx);
181            }
182            Expr::EndsWith(s) => {
183                let idx = self.intern_string(s)?;
184                self.emit(Opcode::EndsWith as u8);
185                self.emit_u16(idx);
186            }
187            Expr::Equals(s) => {
188                let idx = self.intern_string(s)?;
189                self.emit(Opcode::Equals as u8);
190                self.emit_u16(idx);
191            }
192            Expr::Matches(pattern) => {
193                let idx = self.intern_regex(pattern)?;
194                self.emit(Opcode::Matches as u8);
195                self.emit_u16(idx);
196            }
197
198            // Part-specific operations
199            Expr::PartContains { part, value } => {
200                let part_idx = self.lookup_field(part)?;
201                let str_idx = self.intern_string(value)?;
202                self.emit(Opcode::PartContains as u8);
203                self.emit(part_idx);
204                self.emit_u16(str_idx);
205            }
206            Expr::PartIContains { part, value } => {
207                let part_idx = self.lookup_field(part)?;
208                let str_idx = self.intern_string(value)?;
209                self.emit(Opcode::PartIContains as u8);
210                self.emit(part_idx);
211                self.emit_u16(str_idx);
212            }
213            Expr::PartStartsWith { part, value } => {
214                let part_idx = self.lookup_field(part)?;
215                let str_idx = self.intern_string(value)?;
216                self.emit(Opcode::PartStartsWith as u8);
217                self.emit(part_idx);
218                self.emit_u16(str_idx);
219            }
220            Expr::PartEndsWith { part, value } => {
221                let part_idx = self.lookup_field(part)?;
222                let str_idx = self.intern_string(value)?;
223                self.emit(Opcode::PartEndsWith as u8);
224                self.emit(part_idx);
225                self.emit_u16(str_idx);
226            }
227            Expr::PartEquals { part, value } => {
228                let part_idx = self.lookup_field(part)?;
229                let str_idx = self.intern_string(value)?;
230                self.emit(Opcode::PartEquals as u8);
231                self.emit(part_idx);
232                self.emit_u16(str_idx);
233            }
234            Expr::PartIEquals { part, value } => {
235                let part_idx = self.lookup_field(part)?;
236                let str_idx = self.intern_string(value)?;
237                self.emit(Opcode::PartIEquals as u8);
238                self.emit(part_idx);
239                self.emit_u16(str_idx);
240            }
241            Expr::PartNotEquals { part, value } => {
242                // Compile as NOT (PartEquals)
243                let part_idx = self.lookup_field(part)?;
244                let str_idx = self.intern_string(value)?;
245                self.emit(Opcode::PartEquals as u8);
246                self.emit(part_idx);
247                self.emit_u16(str_idx);
248                self.emit(Opcode::Not as u8);
249            }
250            Expr::PartMatches { part, pattern } => {
251                let part_idx = self.lookup_field(part)?;
252                let regex_idx = self.intern_regex(pattern)?;
253                self.emit(Opcode::PartMatches as u8);
254                self.emit(part_idx);
255                self.emit_u16(regex_idx);
256            }
257            Expr::PartInSet { part, values } => {
258                let part_idx = self.lookup_field(part)?;
259                let set_idx = self.add_string_set(values)?;
260                self.emit(Opcode::PartInSet as u8);
261                self.emit(part_idx);
262                self.emit_u16(set_idx);
263            }
264            Expr::PartIsEmpty { part } => {
265                let part_idx = self.lookup_field(part)?;
266                self.emit(Opcode::PartIsEmpty as u8);
267                self.emit(part_idx);
268            }
269            Expr::PartNotEmpty { part } => {
270                let part_idx = self.lookup_field(part)?;
271                self.emit(Opcode::PartNotEmpty as u8);
272                self.emit(part_idx);
273            }
274
275            // Header operations
276            Expr::HeaderEquals {
277                part,
278                header,
279                value,
280            } => {
281                let part_idx = self.lookup_field(part)?;
282                let hdr_idx = self.intern_string(header)?;
283                let val_idx = self.intern_string(value)?;
284                self.emit(Opcode::HeaderEquals as u8);
285                self.emit(part_idx);
286                self.emit_u16(hdr_idx);
287                self.emit_u16(val_idx);
288            }
289            Expr::HeaderIEquals {
290                part,
291                header,
292                value,
293            } => {
294                let part_idx = self.lookup_field(part)?;
295                let hdr_idx = self.intern_string(header)?;
296                let val_idx = self.intern_string(value)?;
297                self.emit(Opcode::HeaderIEquals as u8);
298                self.emit(part_idx);
299                self.emit_u16(hdr_idx);
300                self.emit_u16(val_idx);
301            }
302            Expr::HeaderContains {
303                part,
304                header,
305                value,
306            } => {
307                let part_idx = self.lookup_field(part)?;
308                let hdr_idx = self.intern_string(header)?;
309                let val_idx = self.intern_string(value)?;
310                self.emit(Opcode::HeaderContains as u8);
311                self.emit(part_idx);
312                self.emit_u16(hdr_idx);
313                self.emit_u16(val_idx);
314            }
315            Expr::HeaderExists { part, header } => {
316                let part_idx = self.lookup_field(part)?;
317                let hdr_idx = self.intern_string(header)?;
318                self.emit(Opcode::HeaderExists as u8);
319                self.emit(part_idx);
320                self.emit_u16(hdr_idx);
321            }
322
323            // Boolean operations — short-circuit with jumps
324            Expr::And(left, right) => {
325                // Emit left operand
326                self.compile_expr(left)?;
327                // JumpIfFalse: if left is false, skip right (leave false on stack)
328                let opcode_pos = self.offset();
329                self.emit(Opcode::JumpIfFalse as u8);
330                let patch_pos = self.offset();
331                self.emit_i16(0); // placeholder
332                // Emit right operand (its result becomes the AND result)
333                self.compile_expr(right)?;
334                // Backpatch: offset is relative to opcode position (VM does pc += offset)
335                let jump_target = self.offset();
336                let offset = (jump_target as isize - opcode_pos as isize) as i16;
337                self.patch_i16(patch_pos, offset);
338            }
339            Expr::Or(left, right) => {
340                // Emit left operand
341                self.compile_expr(left)?;
342                // JumpIfTrue: if left is true, skip right (leave true on stack)
343                let opcode_pos = self.offset();
344                self.emit(Opcode::JumpIfTrue as u8);
345                let patch_pos = self.offset();
346                self.emit_i16(0); // placeholder
347                // Emit right operand (its result becomes the OR result)
348                self.compile_expr(right)?;
349                // Backpatch: offset is relative to opcode position (VM does pc += offset)
350                let jump_target = self.offset();
351                let offset = (jump_target as isize - opcode_pos as isize) as i16;
352                self.patch_i16(patch_pos, offset);
353            }
354            Expr::Not(inner) => {
355                self.compile_expr(inner)?;
356                self.emit(Opcode::Not as u8);
357            }
358        }
359
360        Ok(())
361    }
362
363    /// Finish compilation and return the compiled filter.
364    fn finish(mut self, source: String) -> CompiledFilter {
365        self.emit(Opcode::Return as u8);
366
367        CompiledFilter::new(
368            self.bytecode,
369            self.strings,
370            self.regexes,
371            self.string_sets,
372            self.config.delimiter.clone(),
373            source,
374        )
375    }
376}
377
378/// Compile a filter expression string into a CompiledFilter.
379///
380/// # Arguments
381/// * `source` - The filter expression string
382/// * `config` - Parser configuration with field mappings
383///
384/// # Returns
385/// A `CompiledFilter` ready for evaluation.
386///
387/// # Example
388/// ```
389/// use bytecode_filter::{compile, ParserConfig};
390/// use bytes::Bytes;
391///
392/// let mut config = ParserConfig::default();
393/// config.add_field("STATUS", 0);
394/// config.add_field("CODE", 1);
395/// let filter = compile(r#"STATUS == "ok""#, &config).unwrap();
396///
397/// let record = Bytes::from("ok;;;200");
398/// assert!(filter.evaluate(record));
399/// ```
400///
401/// # Errors
402/// Returns `CompileError` if parsing or compilation fails.
403pub fn compile(source: &str, config: &ParserConfig) -> Result<CompiledFilter, CompileError> {
404    let expr = crate::parser::parse(source, config)?;
405    compile_expr(&expr, config, source.to_string())
406}
407
408/// Compile a pre-parsed AST into a CompiledFilter.
409///
410/// # Errors
411/// Returns `CompileError` if the expression contains invalid operations.
412pub fn compile_expr(
413    expr: &Expr,
414    config: &ParserConfig,
415    source: String,
416) -> Result<CompiledFilter, CompileError> {
417    let mut compiler = Compiler::new(config);
418    compiler.compile_expr(expr)?;
419    Ok(compiler.finish(source))
420}
421
#[cfg(test)]
mod tests {
    use bytes::Bytes;

    use super::*;

    /// Configuration shared by all tests: six named fields at indices 0 through 5.
    fn test_config() -> ParserConfig {
        let mut config = ParserConfig::default();
        for (name, index) in [
            ("LEVEL", 0),
            ("CODE", 1),
            ("METHOD", 2),
            ("PATH", 3),
            ("HEADERS", 4),
            ("BODY", 5),
        ] {
            config.add_field(name, index);
        }
        config
    }

    /// Compile `input`, evaluate it against `payload`, and assert the outcome.
    fn compile_and_test(input: &str, payload: &str, expected: bool) {
        let filter = compile(input, &test_config()).expect("Failed to compile");
        let result = filter.evaluate(Bytes::from(payload.to_string()));
        assert_eq!(
            result, expected,
            "Filter '{}' on payload '{}' expected {} but got {}",
            input, payload, expected, result
        );
    }

    /// Build a six-part payload joined by `;;;`, with the given parts filled
    /// in and every other part left empty.
    fn payload_with(filled: &[(usize, &str)]) -> String {
        let mut slots = vec![""; 6];
        for &(index, text) in filled {
            slots[index] = text;
        }
        slots.join(";;;")
    }

    #[test]
    fn test_compile_true() {
        compile_and_test("true", "", true);
    }

    #[test]
    fn test_compile_false() {
        compile_and_test("false", "", false);
    }

    #[test]
    fn test_compile_payload_contains() {
        let filter = r#"payload contains "error""#;
        for (payload, expected) in [("an error occurred", true), ("all good", false)] {
            compile_and_test(filter, payload, expected);
        }
    }

    #[test]
    fn test_compile_field_equals() {
        let filter = r#"LEVEL == "error""#;
        for (payload, expected) in [("error;;;500;;;GET", true), ("info;;;500;;;GET", false)] {
            compile_and_test(filter, payload, expected);
        }
    }

    #[test]
    fn test_compile_field_in_set() {
        let filter = r#"LEVEL in {"error", "warn", "fatal"}"#;
        for (payload, expected) in [
            ("error;;;500;;;GET", true),
            ("warn;;;500;;;GET", true),
            ("info;;;500;;;GET", false),
        ] {
            compile_and_test(filter, payload, expected);
        }
    }

    #[test]
    fn test_compile_and() {
        let filter = r#"LEVEL == "error" AND CODE == "500""#;
        for (payload, expected) in [("error;;;500;;;GET", true), ("error;;;200;;;GET", false)] {
            compile_and_test(filter, payload, expected);
        }
    }

    #[test]
    fn test_compile_or() {
        let filter = r#"LEVEL == "error" OR LEVEL == "warn""#;
        for (payload, expected) in [
            ("error;;;500;;;GET", true),
            ("warn;;;500;;;GET", true),
            ("info;;;500;;;GET", false),
        ] {
            compile_and_test(filter, payload, expected);
        }
    }

    #[test]
    fn test_compile_not() {
        let filter = r#"NOT LEVEL == "debug""#;
        for (payload, expected) in [("error;;;500;;;GET", true), ("debug;;;500;;;GET", false)] {
            compile_and_test(filter, payload, expected);
        }
    }

    #[test]
    fn test_compile_header_iequals() {
        let payload = payload_with(&[(0, "error"), (4, "X-Custom: value\r\n")]);

        let filter = compile(
            r#"HEADERS.header("x-custom") iequals "value""#,
            &test_config(),
        )
        .unwrap();

        assert!(filter.evaluate(Bytes::from(payload)));
    }

    #[test]
    fn test_compile_complex_filter() {
        let filter_str = r#"
            LEVEL == "error"
            AND CODE == "500"
            AND HEADERS.header("Content-Type") iequals "application/json"
        "#;
        let filter = compile(filter_str, &test_config()).unwrap();

        let headers = "Content-Type: application/json\r\n";

        // All three conditions hold.
        let matching = payload_with(&[(0, "error"), (1, "500"), (4, headers)]);
        assert!(filter.evaluate(Bytes::from(matching)));

        // Only the level differs, which is enough to reject the record.
        let wrong_level = payload_with(&[(0, "info"), (1, "500"), (4, headers)]);
        assert!(!filter.evaluate(Bytes::from(wrong_level)));
    }

    #[test]
    fn test_compile_rand() {
        crate::vm::reset_rand_counter();

        let filter = compile("rand(2)", &test_config()).unwrap();

        // After the counter reset, successive evaluations alternate,
        // starting with a match.
        for expected in [true, false, true, false] {
            assert_eq!(filter.evaluate(Bytes::new()), expected);
        }
    }

    #[test]
    fn test_compile_regex() {
        let filter = r#"payload matches "error_[0-9]+""#;
        for (payload, expected) in [("found error_123", true), ("no errors", false)] {
            compile_and_test(filter, payload, expected);
        }
    }

    #[test]
    fn test_compile_unknown_field() {
        let outcome = compile(r#"UNKNOWN_FIELD == "x""#, &test_config());
        assert!(matches!(outcome, Err(CompileError::UnknownField(_, _))));
    }

    #[test]
    fn test_compile_invalid_regex() {
        let outcome = compile(r#"payload matches "[invalid""#, &test_config());
        assert!(matches!(outcome, Err(CompileError::InvalidRegex { .. })));
    }

    #[test]
    fn test_bytecode_size() {
        let config = test_config();

        // PartEquals(1 + 1 + 2) + Return(1)
        let single = compile(r#"LEVEL == "error""#, &config).unwrap();
        assert_eq!(single.bytecode_len(), 5);

        // 2x PartEquals(4) + JumpIfFalse(3) + Return(1)
        let conjunction = compile(r#"LEVEL == "error" AND CODE == "500""#, &config).unwrap();
        assert_eq!(conjunction.bytecode_len(), 12);
    }
}