lancelot_flirt/
pattern_set.rs1use anyhow::Result;
19use nom::{
20 branch::alt,
21 bytes::complete::{tag, take_while_m_n},
22 combinator::{map, map_res},
23 multi::many1,
24 IResult,
25};
26
/// A single element of a byte [`Pattern`]: either a concrete byte value
/// (`0x00..=0xFF`) or the [`WILDCARD`] sentinel (`0x100`).
///
/// A `u16` is used because 257 distinct values are needed.
#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq)]
pub struct Symbol(pub u16);

/// The symbol that matches any byte; rendered as `..`.
///
/// Its value (`0x100`) lies just past the `u8` range, so it can never equal a
/// symbol produced by `Symbol::from(u8)`.
pub const WILDCARD: Symbol = Symbol(0x100);

impl std::convert::From<u8> for Symbol {
    /// Wraps a concrete byte value as a non-wildcard symbol.
    fn from(v: u8) -> Self {
        // lossless widening; `u16::from` states that explicitly, unlike `as`.
        Symbol(u16::from(v))
    }
}

impl std::fmt::Display for Symbol {
    /// Renders a byte symbol as two upper-case hex digits (e.g. `0A`) and the
    /// wildcard as `..` — the same notation the pattern parser accepts.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        if self.0 == WILDCARD.0 {
            write!(f, "..")
        } else {
            write!(f, "{:02X}", self.0)
        }
    }
}
50
51#[derive(Hash, PartialEq, Eq, Clone)]
53pub struct Pattern(pub Vec<Symbol>);
54
55impl std::fmt::Display for Pattern {
56 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
57 let parts: Vec<String> = self.0.iter().map(|s| format!("{s}")).collect();
58 write!(f, "{}", parts.join(""))
59 }
60}
61
62impl std::fmt::Debug for Pattern {
63 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
64 write!(f, "{self}")
65 }
66}
67
/// Reports whether `c` is an ASCII hexadecimal digit (`0-9`, `a-f`, `A-F`).
fn is_hex_digit(c: char) -> bool {
    matches!(c, '0'..='9' | 'a'..='f' | 'A'..='F')
}
71
/// Parses `input` as a base-16 byte value (e.g. `"AA"` -> `0xAA`).
///
/// Accepts exactly what `u8::from_str_radix` accepts in base 16; in this file
/// it is only fed the two hex digits matched by `hex`.
fn from_hex(input: &str) -> Result<u8, std::num::ParseIntError> {
    const RADIX: u32 = 16;
    u8::from_str_radix(input, RADIX)
}
75
76fn hex(input: &str) -> IResult<&str, u8> {
78 map_res(take_while_m_n(2, 2, is_hex_digit), from_hex)(input)
79}
80
81fn sig_element(input: &str) -> IResult<&str, Symbol> {
84 alt((map(hex, Symbol::from), map(tag(".."), |_| WILDCARD)))(input)
85}
86
87fn byte_signature(input: &str) -> IResult<&str, Pattern> {
89 let (input, elems) = many1(sig_element)(input)?;
90 Ok((input, Pattern(elems)))
91}
92
93impl std::convert::From<&str> for Pattern {
95 fn from(v: &str) -> Self {
96 byte_signature(v).expect("failed to parse pattern").1
97 }
98}
99
100pub struct PatternSet {
101 patterns: Vec<Pattern>,
102 dt: super::decision_tree::DecisionTree,
103}
104
105impl std::fmt::Debug for PatternSet {
106 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
107 for pattern in self.patterns.iter() {
108 writeln!(f, " - {pattern}")?;
109 }
110 Ok(())
111 }
112}
113
114impl PatternSet {
115 pub fn r#match(&self, buf: &[u8]) -> Vec<&Pattern> {
116 self.dt
117 .matches(buf)
118 .into_iter()
119 .map(|i| &self.patterns[i as usize])
120 .collect()
121 }
122
123 pub fn builder() -> PatternSetBuilder {
124 PatternSetBuilder { patterns: vec![] }
125 }
126
127 pub fn from_patterns(patterns: Vec<Pattern>) -> PatternSet {
128 PatternSetBuilder { patterns }.build()
129 }
130}
131
132pub struct PatternSetBuilder {
133 patterns: Vec<Pattern>,
134}
135
136impl PatternSetBuilder {
137 pub fn add_pattern(&mut self, pattern: Pattern) {
138 self.patterns.push(pattern)
139 }
140
141 pub fn build(self) -> PatternSet {
142 let mut patterns = vec![];
147 for pattern in self.patterns.iter() {
148 patterns.push(format!("{pattern}"));
149 }
150
151 let dt = super::decision_tree::DecisionTree::new(&patterns);
152
153 PatternSet {
154 patterns: self.patterns,
155 dt,
156 }
157 }
158}
159
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_empty_build() {
        let pattern_set = PatternSet::builder().build();
        // an empty set has nothing to list.
        assert_eq!(format!("{pattern_set:?}"), "");
    }

    #[test]
    fn test_add_one_pattern() {
        let mut b = PatternSet::builder();
        b.add_pattern(Pattern::from("AABBCCDD"));

        assert_eq!(format!("{:?}", b.build()), " - AABBCCDD\n");
    }

    #[test]
    fn test_add_two_patterns() {
        let mut b = PatternSet::builder();
        b.add_pattern(Pattern::from("AABBCCDD"));
        b.add_pattern(Pattern::from("AABBCCCC"));

        assert_eq!(format!("{:?}", b.build()), " - AABBCCDD\n - AABBCCCC\n");
    }

    #[test]
    fn test_add_one_wildcard() {
        let mut b = PatternSet::builder();
        b.add_pattern(Pattern::from("AABBCCDD"));
        b.add_pattern(Pattern::from("AABBCC.."));

        assert_eq!(format!("{:?}", b.build()), " - AABBCCDD\n - AABBCC..\n");
    }

    #[test]
    fn test_pattern_display_round_trip() {
        for text in ["AABBCCDD", "AABB..DD", "........", "00FF7F80"] {
            assert_eq!(Pattern::from(text).to_string(), text);
        }
        // hex digits parse in either case but always render upper-case.
        assert_eq!(Pattern::from("aabb..ff").to_string(), "AABB..FF");
    }

    #[test]
    fn test_pattern_symbols() {
        let pattern = Pattern::from("AA..0f");
        assert!(pattern.0 == vec![Symbol(0xAA), WILDCARD, Symbol(0x0F)]);
    }

    #[test]
    fn test_byte_signature_leaves_unparsed_tail() {
        // the raw parser stops at the first unrecognized element and returns the rest.
        let (rest, pattern) = byte_signature("AABBZZ").unwrap();
        assert_eq!(rest, "ZZ");
        assert_eq!(pattern.to_string(), "AABB");
    }

    #[test]
    fn test_match_empty() {
        let pattern_set = PatternSet::builder().build();
        assert_eq!(pattern_set.r#match(b"\xAA\xBB\xCC\xDD").len(), 0);
    }

    #[test]
    fn test_match_one() {
        let mut b = PatternSet::builder();
        b.add_pattern(Pattern::from("AABBCCDD"));
        let pattern_set = b.build();

        assert_eq!(pattern_set.r#match(b"\xAA\xBB\xCC\xDD").len(), 1);
        assert_eq!(pattern_set.r#match(b"\xAA\xBB\xCC\xEE").len(), 0);
    }

    #[test]
    fn test_match_long() {
        // a buffer longer than the pattern still matches.
        let mut b = PatternSet::builder();
        b.add_pattern(Pattern::from("AABBCCDD"));
        let pattern_set = b.build();

        assert_eq!(pattern_set.r#match(b"\xAA\xBB\xCC\xDD\x00").len(), 1);
        assert_eq!(pattern_set.r#match(b"\xAA\xBB\xCC\xDD\x11").len(), 1);
    }

    #[test]
    fn test_match_one_tail_wildcard() {
        let mut b = PatternSet::builder();
        b.add_pattern(Pattern::from("AABBCC.."));
        b.add_pattern(Pattern::from("AABBCCDD"));
        let pattern_set = b.build();

        assert_eq!(pattern_set.r#match(b"\xAA\xBB\xCC\xDD").len(), 2);
        assert_eq!(pattern_set.r#match(b"\xAA\xBB\xCC\xEE").len(), 1);
        assert_eq!(pattern_set.r#match(b"\xAA\xBB\x00\x00").len(), 0);

        // same patterns, opposite insertion order.
        let mut b = PatternSet::builder();
        b.add_pattern(Pattern::from("AABBCCDD"));
        b.add_pattern(Pattern::from("AABBCC.."));
        let pattern_set = b.build();

        assert_eq!(pattern_set.r#match(b"\xAA\xBB\xCC\xDD").len(), 2);
        assert_eq!(pattern_set.r#match(b"\xAA\xBB\xCC\xEE").len(), 1);
        assert_eq!(pattern_set.r#match(b"\xAA\xBB\x00\x00").len(), 0);
    }

    #[test]
    fn test_match_one_middle_wildcard() {
        let pattern_set =
            PatternSet::from_patterns(vec![Pattern::from("AABB..DD"), Pattern::from("AABBCCDD")]);

        assert_eq!(pattern_set.r#match(b"\xAA\xBB\xCC\xDD").len(), 2);
        assert_eq!(pattern_set.r#match(b"\xAA\xBB\xEE\xDD").len(), 1);
        assert_eq!(pattern_set.r#match(b"\xAA\xBB\x00\x00").len(), 0);

        // same patterns, opposite insertion order.
        let pattern_set =
            PatternSet::from_patterns(vec![Pattern::from("AABBCCDD"), Pattern::from("AABB..DD")]);

        assert_eq!(pattern_set.r#match(b"\xAA\xBB\xCC\xDD").len(), 2);
        assert_eq!(pattern_set.r#match(b"\xAA\xBB\xEE\xDD").len(), 1);
        assert_eq!(pattern_set.r#match(b"\xAA\xBB\x00\x00").len(), 0);
    }

    #[test]
    fn test_match_many() {
        let pattern_set = PatternSet::from_patterns(vec![
            Pattern::from("AABB..DD"),
            Pattern::from("AABBCCDD"),
            Pattern::from("........"),
            Pattern::from("....CCDD"),
        ]);
        assert_eq!(pattern_set.r#match(b"\xAA\xBB\xCC\xDD").len(), 4);
        assert_eq!(pattern_set.r#match(b"\xAA\xBB\x00\xDD").len(), 2);
        assert_eq!(pattern_set.r#match(b"\xAA\xBB\x00\x00").len(), 1);
        assert_eq!(pattern_set.r#match(b"\x00\x00\xCC\xDD").len(), 2);
        assert_eq!(pattern_set.r#match(b"\x00\x00\x00\x00").len(), 1);
    }

    #[test]
    fn test_match_pathological_case() {
        let pattern_set = PatternSet::from_patterns(vec![
            Pattern::from("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"),
            Pattern::from("................................................................"),
        ]);
        assert_eq!(pattern_set.r#match(b"\xAA\xAA\xAA\xAA\xAA\xAA\xAA\xAA\xAA\xAA\xAA\xAA\xAA\xAA\xAA\xAA\xAA\xAA\xAA\xAA\xAA\xAA\xAA\xAA\xAA\xAA\xAA\xAA\xAA\xAA\xAA\xAA").len(), 2);
    }
}