simd_bmh_macro/
lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{parse_macro_input, LitStr};
4
5#[proc_macro]
6pub fn parse_pattern(input: TokenStream) -> TokenStream {
7    let pattern = parse_macro_input!(input as LitStr).value().replace(" ", "");
8
9    if pattern.is_empty() {
10        panic!("Pattern cannot be empty.");
11    }
12
13    let mut bytes = Vec::new();
14    let mut masks = Vec::new();
15    let mut shift_table = vec![pattern.len(); 256]; // Initialize shift table with maximum skip value
16
17    let mut current_pos = 0;
18    let mut best_skip_value = 0u8;
19    let mut best_skip_mask = 0xFFu8;
20    let mut max_skip = 1usize;
21    let mut best_skip_offset = 0usize;
22
23    // Iterate over the pattern to generate byte values, masks, and the skip table
24    let mut chars = pattern.chars().peekable();
25    while let Some(ch) = chars.next() {
26        let next_ch = chars.peek().cloned();
27
28        match (ch, next_ch) {
29            // Case for `??` wildcard (matches any byte)
30            ('?', Some('?')) => {
31                bytes.push(0x00); // Wildcard byte
32                masks.push(0x00); // Full wildcard mask
33                chars.next(); // Consume the second `?`
34            }
35            // Case for `?A` (matches lower nibble, stores upper as wildcard)
36            ('?', Some(c)) if c.is_ascii_hexdigit() => {
37                let byte = u8::from_str_radix(&c.to_string(), 16).expect("Invalid nibble");
38                bytes.push(byte);
39                masks.push(0x0F); // Lower nibble match mask
40                chars.next(); // Consume the hex character
41            }
42            // Case for `A?` (matches upper nibble, stores lower as wildcard)
43            (c, Some('?')) if c.is_ascii_hexdigit() => {
44                let byte = u8::from_str_radix(&c.to_string(), 16).expect("Invalid nibble");
45                bytes.push(byte << 4); // Shift the byte to the upper nibble
46                masks.push(0xF0); // Upper nibble match mask
47                chars.next(); // Consume the `?`
48            }
49            // Case for exact two-byte hex match, e.g., `AA`, `BB`, etc.
50            (c1, Some(c2)) if c1.is_ascii_hexdigit() && c2.is_ascii_hexdigit() => {
51                let byte_str = format!("{}{}", c1, c2);
52                let byte = u8::from_str_radix(&byte_str, 16).expect("Invalid hex byte");
53                bytes.push(byte);
54                masks.push(0xFF); // Exact byte match mask
55                chars.next(); // Consume the second hex digit
56            }
57            _ => {
58                panic!("Invalid pattern token: {}", ch);
59            }
60        }
61
62        // Update the skip table for Boyer-Moore-Horspool
63        if masks.last() != Some(&0x00) {
64            if let Some(last_byte) = bytes.last() {
65                let skip_value = current_pos + 1;
66
67                // Check if the current mask is a full byte match (0xFF)
68                if *masks.last().unwrap() == 0xFF {
69                    if skip_value > max_skip {
70                        max_skip = skip_value;
71                        best_skip_offset = current_pos; // Track the position of the best skip byte
72                        best_skip_value = *last_byte;  // Use the full byte value
73                        best_skip_mask = *masks.last().unwrap(); // Use the full byte mask
74                    }
75                } else if *masks.last().unwrap() == 0xF0 || *masks.last().unwrap() == 0x0F {
76                    // If no full byte match, fallback to nibble match (0xF0 or 0x0F)
77                    if !best_skip_value == 0 && skip_value > max_skip {
78                        max_skip = skip_value;
79                        best_skip_offset = current_pos; // Track the position of the best skip byte
80                        best_skip_value = *last_byte; // Use the nibble value
81                        best_skip_mask = *masks.last().unwrap(); // Use the nibble mask
82                    }
83                }
84            }
85        }
86
87        current_pos += 1;
88    }
89
90    // Build the shift table: this is the BMH skip table
91    for i in 0..bytes.len() - 1 {
92        let byte = bytes[i] & masks[i];
93        shift_table[byte as usize] = bytes.len() - 1 - i;
94    }
95
96    // Return the parsed pattern data as a TokenStream for code generation
97    let expanded = quote! {
98        Pattern {
99            bytes: [#(#bytes),*], // Pattern bytes
100            masks: [#(#masks),*], // Masks for matching
101            best_skip_value: #best_skip_value, // Best skip byte
102            best_skip_mask: #best_skip_mask, // Best skip mask
103            max_skip: #max_skip, // Maximum skip distance
104            best_skip_offset: #best_skip_offset, // Best skip byte position
105            shift_table: [#(#shift_table),*], // Boyer-Moore-Horspool skip table
106        }
107    };
108
109    TokenStream::from(expanded)
110}