use anyhow::Result;
use nom::{
branch::alt,
bytes::complete::{tag, take_while_m_n},
combinator::{map, map_res},
multi::many1,
IResult,
};
/// One element of a byte pattern, stored as a 16-bit value.
///
/// Symbols converted from a `u8` occupy `0x00..=0xFF`. The wider `u16`
/// storage leaves room for the out-of-band `WILDCARD` value (`0x100`).
#[derive(Clone, Copy, PartialEq, Eq, Hash)]
pub struct Symbol(pub u16);
/// The wildcard symbol: value `0x100`, i.e. one past the largest `u8`, so it
/// can never collide with a symbol converted from a concrete byte.
///
/// It is rendered as `..` by `Symbol`'s `Display` impl and is what the parser
/// produces for a `..` token.
pub const WILDCARD: Symbol = Symbol(u8::MAX as u16 + 1);
impl std::convert::From<u8> for Symbol {
fn from(v: u8) -> Self {
Symbol(v as u16)
}
}
impl std::fmt::Display for Symbol {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
if self.0 == WILDCARD.0 {
write!(f, "..")
} else {
write!(f, r"{:02X}", self.0)
}
}
}
/// An ordered sequence of [`Symbol`]s (concrete bytes and/or wildcards).
///
/// `Display` and `Debug` both render it as the concatenation of its symbols,
/// e.g. `AABB..DD`.
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct Pattern(pub Vec<Symbol>);
/// Renders the pattern as its symbols written back-to-back with no separator.
impl std::fmt::Display for Pattern {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Stream each symbol straight into the formatter; stops at the first
        // write error.
        self.0.iter().try_for_each(|symbol| write!(f, "{symbol}"))
    }
}
/// Debug output is intentionally identical to the `Display` output.
impl std::fmt::Debug for Pattern {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        std::fmt::Display::fmt(self, f)
    }
}
/// Returns `true` when `c` is an ASCII hexadecimal digit (`0-9`, `a-f`, `A-F`).
fn is_hex_digit(c: char) -> bool {
    matches!(c, '0'..='9' | 'a'..='f' | 'A'..='F')
}
/// Interprets `input` as a base-16 number and returns it as a `u8`.
///
/// # Errors
/// Returns the `ParseIntError` from `u8::from_str_radix` when `input` is
/// empty, contains a non-hex character, or does not fit in a `u8`.
fn from_hex(input: &str) -> Result<u8, std::num::ParseIntError> {
    const HEX_RADIX: u32 = 16;
    u8::from_str_radix(input, HEX_RADIX)
}
/// nom parser: consumes exactly two hex digits from the front of `input` and
/// yields the byte they encode.
fn hex(input: &str) -> IResult<&str, u8> {
    // Minimum and maximum are both 2, so exactly two characters are taken.
    let two_hex_digits = take_while_m_n(2, 2, is_hex_digit);
    map_res(two_hex_digits, from_hex)(input)
}
/// nom parser for a single signature element: either a two-digit hex byte or
/// the `..` wildcard token.
fn sig_element(input: &str) -> IResult<&str, Symbol> {
    let concrete_byte = map(hex, Symbol::from);
    let wildcard = map(tag(".."), |_| WILDCARD);
    // The concrete byte is attempted first; the wildcard is the fallback.
    alt((concrete_byte, wildcard))(input)
}
fn byte_signature(input: &str) -> IResult<&str, Pattern> {
let (input, elems) = many1(sig_element)(input)?;
Ok((input, Pattern(elems)))
}
/// Builds a [`Pattern`] from its textual form, e.g. `"AABB..DD"`.
///
/// Each element is either two hex digits (a concrete byte) or `..` (the
/// wildcard); elements are written back-to-back with no separator.
///
/// # Panics
/// Panics if `v` does not start with at least one valid element, or if any
/// input is left over after the last valid element (for example `"AABBZZ"`
/// or the odd-length `"AAB"`). Previously such trailing input was silently
/// discarded, yielding a truncated pattern.
impl std::convert::From<&str> for Pattern {
    fn from(v: &str) -> Self {
        let (rest, pattern) = byte_signature(v).expect("failed to parse pattern");
        // Leftover input means the string is not a well-formed signature;
        // refuse it rather than silently matching on a shorter pattern.
        assert!(
            rest.is_empty(),
            "failed to parse pattern: unexpected trailing input {rest:?}"
        );
        pattern
    }
}
/// A set of byte patterns together with the decision tree used to match them.
///
/// Constructed through `PatternSetBuilder::build` (or `PatternSet::from_patterns`),
/// which builds `dt` from the textual rendering of every pattern.
pub struct PatternSet {
/// The patterns, in the order they were added; `r#match` indexes into this
/// vector with the values reported by `dt`.
patterns: Vec<Pattern>,
/// Decision tree built from the rendered patterns.
dt: super::decision_tree::DecisionTree,
}
/// Debug output lists every pattern on its own line, prefixed with ` - `.
/// The decision tree is not printed.
impl std::fmt::Debug for PatternSet {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        self.patterns
            .iter()
            .try_for_each(|pattern| writeln!(f, " - {pattern}"))
    }
}
impl PatternSet {
    /// Returns references to the patterns that the decision tree reports as
    /// matching `buf`.
    ///
    /// The decision tree yields pattern indices; each one is resolved against
    /// the stored pattern list.
    ///
    /// # Panics
    /// Panics if the decision tree reports an index outside the pattern list.
    pub fn r#match(&self, buf: &[u8]) -> Vec<&Pattern> {
        let matched_indices = self.dt.matches(buf);
        matched_indices
            .into_iter()
            .map(|index| &self.patterns[index as usize])
            .collect()
    }

    /// Creates an empty [`PatternSetBuilder`].
    pub fn builder() -> PatternSetBuilder {
        PatternSetBuilder {
            patterns: Vec::new(),
        }
    }

    /// Builds a pattern set directly from `patterns`, without going through
    /// repeated `add_pattern` calls.
    pub fn from_patterns(patterns: Vec<Pattern>) -> Self {
        let builder = PatternSetBuilder { patterns };
        builder.build()
    }
}
/// Accumulates patterns and turns them into a `PatternSet` via `build`.
pub struct PatternSetBuilder {
/// Patterns added so far, in insertion order.
patterns: Vec<Pattern>,
}
impl PatternSetBuilder {
    /// Appends `pattern` to the set under construction.
    pub fn add_pattern(&mut self, pattern: Pattern) {
        self.patterns.push(pattern);
    }

    /// Consumes the builder and produces a [`PatternSet`].
    ///
    /// Every pattern is rendered to its textual form and the resulting list
    /// is handed to `DecisionTree::new`; the original patterns are moved into
    /// the returned set unchanged and in the same order.
    pub fn build(self) -> PatternSet {
        let rendered: Vec<String> = self.patterns.iter().map(ToString::to_string).collect();
        let dt = super::decision_tree::DecisionTree::new(&rendered);
        PatternSet {
            patterns: self.patterns,
            dt,
        }
    }
}
// Unit tests for building pattern sets and matching byte buffers against them.
// The match expectations exercise the sibling `decision_tree` module through `PatternSet`.
#[cfg(test)]
mod tests {
use super::*;
// Building a set with no patterns must not panic.
#[test]
fn test_empty_build() {
PatternSet::builder().build();
}
// Building with a single fully-concrete pattern; the Debug output is only printed, not asserted.
#[test]
fn test_add_one_pattern() {
let mut b = PatternSet::builder();
b.add_pattern(Pattern::from("AABBCCDD"));
println!("{:?}", b.build());
}
// Two concrete patterns sharing a three-byte prefix and differing in the last byte.
#[test]
fn test_add_two_patterns() {
let mut b = PatternSet::builder();
b.add_pattern(Pattern::from("AABBCCDD"));
b.add_pattern(Pattern::from("AABBCCCC"));
println!("{:?}", b.build());
}
// A concrete pattern alongside one whose last element is a wildcard.
#[test]
fn test_add_one_wildcard() {
let mut b = PatternSet::builder();
b.add_pattern(Pattern::from("AABBCCDD"));
b.add_pattern(Pattern::from("AABBCC.."));
println!("{:?}", b.build());
}
// An empty set matches nothing.
#[test]
fn test_match_empty() {
let pattern_set = PatternSet::builder().build();
assert_eq!(pattern_set.r#match(b"\xAA\xBB\xCC\xDD").len(), 0);
}
// A concrete pattern matches the identical buffer and rejects one differing in the last byte.
#[test]
fn test_match_one() {
let mut b = PatternSet::builder();
b.add_pattern(Pattern::from("AABBCCDD"));
let pattern_set = b.build();
assert_eq!(pattern_set.r#match(b"\xAA\xBB\xCC\xDD").len(), 1);
assert_eq!(pattern_set.r#match(b"\xAA\xBB\xCC\xEE").len(), 0);
}
// A buffer longer than the pattern still matches: extra trailing bytes are expected to be ignored.
#[test]
fn test_match_long() {
let mut b = PatternSet::builder();
b.add_pattern(Pattern::from("AABBCCDD"));
let pattern_set = b.build();
assert_eq!(pattern_set.r#match(b"\xAA\xBB\xCC\xDD\x00").len(), 1);
assert_eq!(pattern_set.r#match(b"\xAA\xBB\xCC\xDD\x11").len(), 1);
}
// Trailing wildcard: expected to match any final byte, in addition to the concrete pattern.
// The same assertions are repeated with the patterns added in the opposite order.
#[test]
fn test_match_one_tail_wildcard() {
let mut b = PatternSet::builder();
b.add_pattern(Pattern::from("AABBCC.."));
b.add_pattern(Pattern::from("AABBCCDD"));
let pattern_set = b.build();
assert_eq!(pattern_set.r#match(b"\xAA\xBB\xCC\xDD").len(), 2);
assert_eq!(pattern_set.r#match(b"\xAA\xBB\xCC\xEE").len(), 1);
assert_eq!(pattern_set.r#match(b"\xAA\xBB\x00\x00").len(), 0);
// Same patterns, reversed insertion order.
let mut b = PatternSet::builder();
b.add_pattern(Pattern::from("AABBCCDD"));
b.add_pattern(Pattern::from("AABBCC.."));
let pattern_set = b.build();
assert_eq!(pattern_set.r#match(b"\xAA\xBB\xCC\xDD").len(), 2);
assert_eq!(pattern_set.r#match(b"\xAA\xBB\xCC\xEE").len(), 1);
assert_eq!(pattern_set.r#match(b"\xAA\xBB\x00\x00").len(), 0);
}
// Wildcard in the middle of the pattern, checked in both insertion orders.
#[test]
fn test_match_one_middle_wildcard() {
let pattern_set = PatternSet::from_patterns(vec![Pattern::from("AABB..DD"), Pattern::from("AABBCCDD")]);
assert_eq!(pattern_set.r#match(b"\xAA\xBB\xCC\xDD").len(), 2);
assert_eq!(pattern_set.r#match(b"\xAA\xBB\xEE\xDD").len(), 1);
assert_eq!(pattern_set.r#match(b"\xAA\xBB\x00\x00").len(), 0);
// Same patterns, reversed insertion order.
let pattern_set = PatternSet::from_patterns(vec![Pattern::from("AABBCCDD"), Pattern::from("AABB..DD")]);
assert_eq!(pattern_set.r#match(b"\xAA\xBB\xCC\xDD").len(), 2);
assert_eq!(pattern_set.r#match(b"\xAA\xBB\xEE\xDD").len(), 1);
assert_eq!(pattern_set.r#match(b"\xAA\xBB\x00\x00").len(), 0);
}
// Four overlapping patterns, including an all-wildcard one that is expected to match every 4-byte buffer.
#[test]
fn test_match_many() {
let pattern_set = PatternSet::from_patterns(vec![
Pattern::from("AABB..DD"),
Pattern::from("AABBCCDD"),
Pattern::from("........"),
Pattern::from("....CCDD"),
]);
assert_eq!(pattern_set.r#match(b"\xAA\xBB\xCC\xDD").len(), 4);
assert_eq!(pattern_set.r#match(b"\xAA\xBB\x00\xDD").len(), 2);
assert_eq!(pattern_set.r#match(b"\xAA\xBB\x00\x00").len(), 1);
assert_eq!(pattern_set.r#match(b"\x00\x00\xCC\xDD").len(), 2);
assert_eq!(pattern_set.r#match(b"\x00\x00\x00\x00").len(), 1);
}
// 32-byte all-concrete pattern versus a 32-element all-wildcard pattern: both must match the same buffer.
#[test]
fn test_match_pathological_case() {
let pattern_set = PatternSet::from_patterns(vec![
Pattern::from("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"),
Pattern::from("................................................................"),
]);
assert_eq!(pattern_set.r#match(b"\xAA\xAA\xAA\xAA\xAA\xAA\xAA\xAA\xAA\xAA\xAA\xAA\xAA\xAA\xAA\xAA\xAA\xAA\xAA\xAA\xAA\xAA\xAA\xAA\xAA\xAA\xAA\xAA\xAA\xAA\xAA\xAA").len(), 2);
}
}