lancelot_flirt/
pattern_set.rs

/// ## purpose
/// match multiple byte patterns against a byte slice in parallel,
/// reporting every pattern that matches.
/// matching does not have to support scanning across the byte slice; it is
/// only anchored at the start. single byte wildcards are supported: they are
/// written `..` in pattern strings (the analogue of `.` in a regex).
///
/// the interface is modelled on [RegexSet](https://docs.rs/regex/1.3.9/regex/struct.RegexSet.html),
/// whose documentation (quoted below) describes the behavior we want;
/// the matching itself is done by `super::decision_tree::DecisionTree`.
///
/// > Match multiple (possibly overlapping) regular expressions in a single
/// > scan.
/// > A regex set corresponds to the union of two or more regular expressions.
/// > That is, a regex set will match text where at least one of its constituent
/// > regular expressions matches. A regex set as its formulated here provides a
/// > touch more power:  it will also report which regular expressions in the
/// > set match. Indeed, this is the key difference between regex sets and a
/// > single Regex with many alternates, since only one alternate can match at a
/// > time.
18use anyhow::Result;
19use nom::{
20    branch::alt,
21    bytes::complete::{tag, take_while_m_n},
22    combinator::{map, map_res},
23    multi::many1,
24    IResult,
25};
26
/// one element of a pattern: either a concrete byte value or the wildcard.
///
/// stored as a `u16` because 257 distinct unsigned values are needed:
/// every possible `u8`, plus [`WILDCARD`].
#[derive(Clone, Copy, PartialEq, Eq, Hash)]
pub struct Symbol(pub u16);

/// the wildcard symbol, encoded as 256 (one past the largest byte value).
pub const WILDCARD: Symbol = Symbol(256);

/// a byte converts to the symbol holding the same numeric value.
impl From<u8> for Symbol {
    fn from(v: u8) -> Self {
        Self(u16::from(v))
    }
}

/// the wildcard renders as `..`; any other symbol renders as upper-case hex,
/// zero-padded to at least two digits.
impl std::fmt::Display for Symbol {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        if *self == WILDCARD {
            return f.write_str("..");
        }

        let Symbol(value) = *self;
        write!(f, "{value:02X}")
    }
}
50
51// a pattern is just a sequence of symbols.
52#[derive(Hash, PartialEq, Eq, Clone)]
53pub struct Pattern(pub Vec<Symbol>);
54
55impl std::fmt::Display for Pattern {
56    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
57        let parts: Vec<String> = self.0.iter().map(|s| format!("{s}")).collect();
58        write!(f, "{}", parts.join(""))
59    }
60}
61
62impl std::fmt::Debug for Pattern {
63    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
64        write!(f, "{self}")
65    }
66}
67
/// true when `c` is an ASCII hexadecimal digit: `0-9`, `a-f` or `A-F`.
fn is_hex_digit(c: char) -> bool {
    matches!(c, '0'..='9' | 'a'..='f' | 'A'..='F')
}
71
/// interpret `input` as a base-16 number that fits in a `u8`.
///
/// returns the `ParseIntError` produced by `u8::from_str_radix` when `input`
/// cannot be parsed that way (e.g. empty, non-hex, or too large for a byte).
fn from_hex(input: &str) -> std::result::Result<u8, std::num::ParseIntError> {
    const HEX_RADIX: u32 = 16;
    u8::from_str_radix(input, HEX_RADIX)
}
75
76/// parse a single hex byte, like `AB`
77fn hex(input: &str) -> IResult<&str, u8> {
78    map_res(take_while_m_n(2, 2, is_hex_digit), from_hex)(input)
79}
80
81/// parse a single byte signature element, which is either a hex byte or a
82/// wildcard.
83fn sig_element(input: &str) -> IResult<&str, Symbol> {
84    alt((map(hex, Symbol::from), map(tag(".."), |_| WILDCARD)))(input)
85}
86
87/// parse byte signature elements, hex or wildcard.
88fn byte_signature(input: &str) -> IResult<&str, Pattern> {
89    let (input, elems) = many1(sig_element)(input)?;
90    Ok((input, Pattern(elems)))
91}
92
93/// parse a pattern from a string like `AABB..DD`.
94impl std::convert::From<&str> for Pattern {
95    fn from(v: &str) -> Self {
96        byte_signature(v).expect("failed to parse pattern").1
97    }
98}
99
100pub struct PatternSet {
101    patterns: Vec<Pattern>,
102    dt:       super::decision_tree::DecisionTree,
103}
104
105impl std::fmt::Debug for PatternSet {
106    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
107        for pattern in self.patterns.iter() {
108            writeln!(f, "  - {pattern}")?;
109        }
110        Ok(())
111    }
112}
113
114impl PatternSet {
115    pub fn r#match(&self, buf: &[u8]) -> Vec<&Pattern> {
116        self.dt
117            .matches(buf)
118            .into_iter()
119            .map(|i| &self.patterns[i as usize])
120            .collect()
121    }
122
123    pub fn builder() -> PatternSetBuilder {
124        PatternSetBuilder { patterns: vec![] }
125    }
126
127    pub fn from_patterns(patterns: Vec<Pattern>) -> PatternSet {
128        PatternSetBuilder { patterns }.build()
129    }
130}
131
132pub struct PatternSetBuilder {
133    patterns: Vec<Pattern>,
134}
135
136impl PatternSetBuilder {
137    pub fn add_pattern(&mut self, pattern: Pattern) {
138        self.patterns.push(pattern)
139    }
140
141    pub fn build(self) -> PatternSet {
142        // should not be possible to generate invalid regex from a pattern
143        // otherwise, programming error.
144        // must reject invalid patterns when deserializing from pat/sig.
145
146        let mut patterns = vec![];
147        for pattern in self.patterns.iter() {
148            patterns.push(format!("{pattern}"));
149        }
150
151        let dt = super::decision_tree::DecisionTree::new(&patterns);
152
153        PatternSet {
154            patterns: self.patterns,
155            dt,
156        }
157    }
158}
159
#[cfg(test)]
mod tests {
    use super::*;

    // build a pattern set by adding each pattern string, in order, through
    // the builder.
    fn via_builder(specs: &[&str]) -> PatternSet {
        let mut builder = PatternSet::builder();
        for &spec in specs {
            builder.add_pattern(Pattern::from(spec));
        }
        builder.build()
    }

    // build a pattern set from pattern strings, in order, through
    // `PatternSet::from_patterns`.
    fn via_from_patterns(specs: &[&str]) -> PatternSet {
        PatternSet::from_patterns(specs.iter().map(|&spec| Pattern::from(spec)).collect())
    }

    // number of patterns in `set` that are reported as matching `buf`.
    fn hits(set: &PatternSet, buf: &[u8]) -> usize {
        set.r#match(buf).len()
    }

    // building a set without any patterns must work.
    #[test]
    fn test_empty_build() {
        via_builder(&[]);
    }

    // patterns:
    //   - pat0: aabbccdd
    #[test]
    fn test_add_one_pattern() {
        let pattern_set = via_builder(&["AABBCCDD"]);

        println!("{:?}", pattern_set);
    }

    // patterns:
    //   - pat0: aabbccdd
    //   - pat1: aabbcccc
    #[test]
    fn test_add_two_patterns() {
        let pattern_set = via_builder(&["AABBCCDD", "AABBCCCC"]);

        println!("{:?}", pattern_set);
    }

    // patterns:
    //   - pat0: aabbccdd
    //   - pat1: aabbcc..
    #[test]
    fn test_add_one_wildcard() {
        let pattern_set = via_builder(&["AABBCCDD", "AABBCC.."]);

        println!("{:?}", pattern_set);
    }

    // with no patterns, nothing matches.
    #[test]
    fn test_match_empty() {
        let pattern_set = via_builder(&[]);
        assert_eq!(hits(&pattern_set, b"\xAA\xBB\xCC\xDD"), 0);
    }

    // the wanted data matches, other data does not.
    #[test]
    fn test_match_one() {
        let pattern_set = via_builder(&["AABBCCDD"]);

        // true positive
        assert_eq!(hits(&pattern_set, b"\xAA\xBB\xCC\xDD"), 1);
        // true negative
        assert_eq!(hits(&pattern_set, b"\xAA\xBB\xCC\xEE"), 0);
    }

    // matching starts at the beginning of the buffer;
    // trailing bytes beyond the length of the pattern are ignored.
    #[test]
    fn test_match_long() {
        let pattern_set = via_builder(&["AABBCCDD"]);

        assert_eq!(hits(&pattern_set, b"\xAA\xBB\xCC\xDD\x00"), 1);
        assert_eq!(hits(&pattern_set, b"\xAA\xBB\xCC\xDD\x11"), 1);
    }

    // single character wildcards at the end of a pattern work,
    // and the order of the pattern declarations does not matter.
    #[test]
    fn test_match_one_tail_wildcard() {
        let pattern_set = via_builder(&["AABBCC..", "AABBCCDD"]);

        assert_eq!(hits(&pattern_set, b"\xAA\xBB\xCC\xDD"), 2);
        assert_eq!(hits(&pattern_set, b"\xAA\xBB\xCC\xEE"), 1);
        assert_eq!(hits(&pattern_set, b"\xAA\xBB\x00\x00"), 0);

        // order of patterns should not matter
        let pattern_set = via_builder(&["AABBCCDD", "AABBCC.."]);

        assert_eq!(hits(&pattern_set, b"\xAA\xBB\xCC\xDD"), 2);
        assert_eq!(hits(&pattern_set, b"\xAA\xBB\xCC\xEE"), 1);
        assert_eq!(hits(&pattern_set, b"\xAA\xBB\x00\x00"), 0);
    }

    // wildcards can sit in the middle of patterns, too.
    #[test]
    fn test_match_one_middle_wildcard() {
        let pattern_set = via_from_patterns(&["AABB..DD", "AABBCCDD"]);

        assert_eq!(hits(&pattern_set, b"\xAA\xBB\xCC\xDD"), 2);
        assert_eq!(hits(&pattern_set, b"\xAA\xBB\xEE\xDD"), 1);
        assert_eq!(hits(&pattern_set, b"\xAA\xBB\x00\x00"), 0);

        // order of patterns should not matter
        let pattern_set = via_from_patterns(&["AABBCCDD", "AABB..DD"]);

        assert_eq!(hits(&pattern_set, b"\xAA\xBB\xCC\xDD"), 2);
        assert_eq!(hits(&pattern_set, b"\xAA\xBB\xEE\xDD"), 1);
        assert_eq!(hits(&pattern_set, b"\xAA\xBB\x00\x00"), 0);
    }

    // an arbitrary mix of wildcards and literals is supported.
    #[test]
    fn test_match_many() {
        let pattern_set = via_from_patterns(&["AABB..DD", "AABBCCDD", "........", "....CCDD"]);

        assert_eq!(hits(&pattern_set, b"\xAA\xBB\xCC\xDD"), 4);
        assert_eq!(hits(&pattern_set, b"\xAA\xBB\x00\xDD"), 2);
        assert_eq!(hits(&pattern_set, b"\xAA\xBB\x00\x00"), 1);
        assert_eq!(hits(&pattern_set, b"\x00\x00\xCC\xDD"), 2);
        assert_eq!(hits(&pattern_set, b"\x00\x00\x00\x00"), 1);
    }

    // a long all-literal pattern next to an equally long all-wildcard pattern.
    #[test]
    fn test_match_pathological_case() {
        let pattern_set = via_from_patterns(&[
            "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA",
            "................................................................",
        ]);

        let buf = b"\xAA\xAA\xAA\xAA\xAA\xAA\xAA\xAA\xAA\xAA\xAA\xAA\xAA\xAA\xAA\xAA\xAA\xAA\xAA\xAA\xAA\xAA\xAA\xAA\xAA\xAA\xAA\xAA\xAA\xAA\xAA\xAA";
        assert_eq!(hits(&pattern_set, buf), 2);
    }
}