use core::num;
use rayon::{
prelude::{IndexedParallelIterator, ParallelIterator},
slice::ParallelSlice,
ThreadPool, ThreadPoolBuilder,
};
use thiserror::Error;
/// Searches byte buffers for hex patterns (with `?`/`??` wildcards), running
/// every scan on this scanner's own rayon thread pool.
///
/// Construct one with [`PatternScannerBuilder`].
pub struct PatternScanner {
/// Buffer searched by `scan` and `scan_all`; copied in by the builder.
bytes: Vec<u8>,
/// Pool on which every scan is executed via `ThreadPool::install`.
threadpool: ThreadPool,
}
impl PatternScanner {
    /// Scans the scanner's stored buffer (set with
    /// [`PatternScannerBuilder::with_bytes`]) for `pattern` and returns the
    /// offset of the first match.
    ///
    /// See [`PatternScanner::scan_with_bytes`] for the pattern syntax.
    ///
    /// # Errors
    ///
    /// Returns a [`PatternScannerError`] if `pattern` cannot be parsed.
    pub fn scan<T: AsRef<str>>(&self, pattern: T) -> Result<Option<usize>, PatternScannerError> {
        self.scan_with_bytes(&self.bytes, pattern)
    }

    /// Scans `bytes` for `pattern` on this scanner's thread pool and returns
    /// the lowest offset at which the pattern matches, or `None` if it never
    /// matches.
    ///
    /// `pattern` is a whitespace-separated list of two-digit hex bytes, where
    /// `?` or `??` matches any byte (e.g. `"48 8B ?? 05"`). An empty pattern
    /// matches nothing and yields `Ok(None)`.
    ///
    /// # Errors
    ///
    /// Returns a [`PatternScannerError`] if `pattern` cannot be parsed.
    pub fn scan_with_bytes<T: AsRef<[u8]> + std::marker::Sync, U: AsRef<str>>(
        &self,
        bytes: T,
        pattern: U,
    ) -> Result<Option<usize>, PatternScannerError> {
        let pattern_bytes = create_bytes_from_string(pattern)?;
        // `par_windows` panics on a window size of 0, so an empty pattern is
        // answered up front instead of crashing the caller.
        if pattern_bytes.is_empty() {
            return Ok(None);
        }
        Ok(self.threadpool.install(|| {
            bytes
                .as_ref()
                .par_windows(pattern_bytes.len())
                // Unlike `position_any`, `position_first` always reports the
                // lowest matching offset, independent of the thread count.
                .position_first(|window| Self::window_matches(window, &pattern_bytes))
        }))
    }

    /// Scans the scanner's stored buffer for every occurrence of `pattern`.
    ///
    /// See [`PatternScanner::scan_all_with_bytes`] for details.
    ///
    /// # Errors
    ///
    /// Returns a [`PatternScannerError`] if `pattern` cannot be parsed.
    pub fn scan_all<T: AsRef<str>>(&self, pattern: T) -> Result<Vec<usize>, PatternScannerError> {
        self.scan_all_with_bytes(&self.bytes, pattern)
    }

    /// Scans `bytes` for `pattern` on this scanner's thread pool and returns
    /// every matching offset in ascending order (overlapping matches
    /// included). An empty pattern matches nothing and yields an empty `Vec`.
    ///
    /// # Errors
    ///
    /// Returns a [`PatternScannerError`] if `pattern` cannot be parsed.
    pub fn scan_all_with_bytes<T: AsRef<[u8]> + std::marker::Sync, U: AsRef<str>>(
        &self,
        bytes: T,
        pattern: U,
    ) -> Result<Vec<usize>, PatternScannerError> {
        let pattern_bytes = create_bytes_from_string(pattern)?;
        // `par_windows` panics on a window size of 0.
        if pattern_bytes.is_empty() {
            return Ok(Vec::new());
        }
        Ok(self.threadpool.install(|| {
            bytes
                .as_ref()
                .par_windows(pattern_bytes.len())
                .enumerate()
                .filter(|&(_, window)| Self::window_matches(window, &pattern_bytes))
                .map(|(offset, _)| offset)
                .collect()
        }))
    }

    /// Returns `true` if every non-wildcard (`Some`) entry of `pattern` equals
    /// the byte at the same position in `window`.
    fn window_matches(window: &[u8], pattern: &[Option<u8>]) -> bool {
        window
            .iter()
            .zip(pattern)
            .all(|(&byte, &expected)| expected.is_none() || expected == Some(byte))
    }
}
/// Builder for [`PatternScanner`]: collects the buffer to scan and the
/// thread-pool configuration.
pub struct PatternScannerBuilder {
/// Bytes moved into the scanner when it is built.
bytes: Vec<u8>,
/// Configuration used to create the scanner's thread pool.
threadpool_builder: ThreadPoolBuilder,
}
impl PatternScannerBuilder {
pub fn builder() -> Self {
Self {
bytes: Vec::new(),
threadpool_builder: ThreadPoolBuilder::new(),
}
}
pub fn with_bytes<T: AsRef<[u8]>>(mut self, bytes: T) -> Self {
self.bytes = bytes.as_ref().to_vec();
self
}
pub fn with_threads(mut self, threads: usize) -> Self {
self.threadpool_builder = self.threadpool_builder.num_threads(threads);
self
}
pub fn build(self) -> PatternScanner {
PatternScanner {
bytes: self.bytes,
threadpool: self
.threadpool_builder
.build()
.expect("failed to build threadpool"),
}
}
}
#[derive(Error, Debug, PartialEq)]
pub enum PatternScannerError {
#[error("failed to parse the pattern byte {0} as a u8")]
InvalidByte(#[from] num::ParseIntError),
#[error("the pattern byte {0} is invalid (must be 2 characters long)")]
ByteLength(String),
#[error("unknown pattern scanner error")]
Unknown,
}
/// Parses a whitespace-separated hex pattern such as `"48 8B ?? ? 05"`.
///
/// Each token is either a wildcard (`?` or `??`), which becomes `None`, or
/// exactly two hexadecimal digits (any case), which become `Some(byte)`.
/// Tokens may be separated by any whitespace, as accepted by
/// [`str::split_whitespace`]. An empty pattern yields an empty `Vec`.
///
/// # Errors
///
/// * [`PatternScannerError::ByteLength`] if a non-wildcard token is not
///   exactly two bytes long.
/// * [`PatternScannerError::InvalidByte`] if a two-byte token is not a valid
///   hexadecimal byte. This includes tokens with a sign, such as `"+F"`.
fn create_bytes_from_string<T: AsRef<str>>(
    pattern: T,
) -> Result<Vec<Option<u8>>, PatternScannerError> {
    pattern
        .as_ref()
        .split_whitespace()
        .map(parse_pattern_token)
        .collect()
}

/// Parses one pattern token: a wildcard becomes `None`, two hex digits become
/// `Some(byte)`.
fn parse_pattern_token(token: &str) -> Result<Option<u8>, PatternScannerError> {
    if token == "?" || token == "??" {
        return Ok(None);
    }
    if token.len() != 2 {
        return Err(PatternScannerError::ByteLength(token.to_owned()));
    }
    // `u8::from_str_radix` accepts an optional leading `+`, so "+F" would
    // otherwise parse as 0x0F. Parsing the lone sign instead yields the same
    // `InvalidDigit` error as any other non-hex character.
    let digits = if token.starts_with('+') { "+" } else { token };
    Ok(Some(u8::from_str_radix(digits, 16)?))
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_create_bytes_from_string_1() {
        assert_eq!(
            create_bytes_from_string("AA BB CC").unwrap(),
            vec![Some(0xAA), Some(0xBB), Some(0xCC)]
        );
    }

    #[test]
    fn test_create_bytes_from_string_2() {
        assert_eq!(
            create_bytes_from_string("AA BB CC AA BB FF").unwrap(),
            vec![
                Some(0xAA),
                Some(0xBB),
                Some(0xCC),
                Some(0xAA),
                Some(0xBB),
                Some(0xFF)
            ]
        );
    }

    #[test]
    fn test_create_bytes_from_string_wildcard_1() {
        assert_eq!(
            create_bytes_from_string("AA BB ? ? CC").unwrap(),
            vec![Some(0xAA), Some(0xBB), None, None, Some(0xCC)]
        );
    }

    #[test]
    fn test_create_bytes_from_string_wildcard_2() {
        assert_eq!(
            create_bytes_from_string("? AA BB ? ? CC ? ? ? FF").unwrap(),
            vec![
                None,
                Some(0xAA),
                Some(0xBB),
                None,
                None,
                Some(0xCC),
                None,
                None,
                None,
                Some(0xFF)
            ]
        );
    }

    #[test]
    fn test_create_bytes_from_string_double_question_wildcard() {
        assert_eq!(
            create_bytes_from_string("?? AA ??").unwrap(),
            vec![None, Some(0xAA), None]
        );
    }

    #[test]
    fn test_create_bytes_from_string_lowercase_and_mixed_whitespace() {
        assert_eq!(
            create_bytes_from_string("  aa\tBb\n0f ").unwrap(),
            vec![Some(0xAA), Some(0xBB), Some(0x0F)]
        );
    }

    #[test]
    fn test_create_bytes_from_string_error_invalid_byte() {
        assert!(create_bytes_from_string("AA GG").is_err());
    }

    #[test]
    fn test_create_bytes_from_string_error_space() {
        assert_eq!(
            create_bytes_from_string("A A BB"),
            Err(PatternScannerError::ByteLength("A".to_owned()))
        );
    }

    #[test]
    fn test_pattern_scan() {
        let result = PatternScannerBuilder::builder()
            .with_bytes(&[0x00, 0x01, 0x02, 0x33, 0x35, 0x33, 0x35, 0x07, 0x08, 0x09])
            .with_threads(1)
            .build()
            .scan("33 35")
            .unwrap();
        assert_eq!(result, Some(3));
    }

    #[test]
    fn test_pattern_scan_all() {
        let result = PatternScannerBuilder::builder()
            .with_bytes(&[0x00, 0x01, 0x02, 0x33, 0x35, 0x33, 0x35, 0x07, 0x08, 0x09])
            .build()
            .scan_all("33 35")
            .unwrap();
        assert_eq!(result, vec![3, 5]);
    }

    #[test]
    fn test_pattern_scan_all_1_million_bytes() {
        // Heap-allocated on purpose: a 1 MB stack array, moved by value into
        // `with_bytes`, risks overflowing the test thread's stack in debug builds.
        let mut bytes = vec![0u8; 1_000_000];
        bytes[600_000] = 0x33;
        bytes[600_001] = 0x35;
        let scanner = PatternScannerBuilder::builder()
            .with_bytes(&bytes)
            .with_threads(1)
            .build();
        assert_eq!(scanner.scan_all("33 35").unwrap(), vec![600_000]);
    }

    #[test]
    fn test_pattern_scan_with_wildcards() {
        let data: [u8; 7] = [0x10, 0xAA, 0x20, 0xBB, 0xAA, 0x30, 0xBB];
        let scanner = PatternScannerBuilder::builder().with_bytes(data).build();
        // Only one window has 0x20 in the middle, so the result is unambiguous.
        assert_eq!(scanner.scan("?? 20 ??").unwrap(), Some(1));
        assert_eq!(scanner.scan_all("AA ? BB").unwrap(), vec![1, 4]);
    }

    #[test]
    fn test_pattern_scan_no_match() {
        let scanner = PatternScannerBuilder::builder()
            .with_bytes([0x01u8, 0x02, 0x03])
            .build();
        assert_eq!(scanner.scan("04").unwrap(), None);
        assert_eq!(scanner.scan_all("04").unwrap(), Vec::<usize>::new());
        // The pattern is longer than the data, so no window fits.
        assert_eq!(scanner.scan("01 02 03 04").unwrap(), None);
        assert_eq!(scanner.scan_all("01 02 03 04").unwrap(), Vec::<usize>::new());
    }

    #[test]
    fn test_pattern_scan_with_bytes_ignores_stored_bytes() {
        let scanner = PatternScannerBuilder::builder()
            .with_bytes([0xFFu8])
            .build();
        assert_eq!(
            scanner.scan_with_bytes([0x01u8, 0x02, 0x03], "02 03").unwrap(),
            Some(1)
        );
        assert_eq!(
            scanner.scan_all_with_bytes(vec![0x02u8, 0x02], "02").unwrap(),
            vec![0, 1]
        );
    }

    #[test]
    fn test_pattern_scan_propagates_parse_errors() {
        let scanner = PatternScannerBuilder::builder()
            .with_bytes([0x00u8])
            .build();
        assert_eq!(
            scanner.scan("ABC"),
            Err(PatternScannerError::ByteLength("ABC".to_owned()))
        );
        assert!(scanner.scan_all("ZZ").is_err());
    }
}