use std::fmt::{self, Display};
use std::io::Read;
use std::str::FromStr;
/// Size in bytes (`0x800` = 2048) of the fixed buffer that [`Matches`] reads input into.
///
/// [`Matches::from_pattern`] rejects patterns longer than `CHUNK_SIZE / 2` bytes.
pub const CHUNK_SIZE: usize = 0x800;
/// Scans everything `reader` yields and returns the offset of every match of `pattern`.
///
/// `pattern` is a whitespace-separated list of tokens: each token is either a byte written in
/// hexadecimal (for example `7c`) or `?`, which matches any byte. Offsets are counted in bytes
/// from the start of the reader's output.
///
/// # Errors
///
/// Returns an [`Error`] if `pattern` cannot be parsed, if it is longer than
/// `CHUNK_SIZE / 2` bytes, or if reading from `reader` fails.
pub fn scan(reader: impl Read, pattern: &str) -> Result<Vec<usize>, Error> {
    // Collecting into `Result<Vec<_>, _>` stops at the first read error and returns it.
    Matches::from_pattern_str(reader, pattern)?.collect()
}
/// Scans `reader` for `pattern` and returns the offset of the first match, if there is one.
///
/// The pattern syntax is the same as for [`scan`]. Reading stops as soon as a match has been
/// found; `Ok(None)` means the input ended without a match.
///
/// # Errors
///
/// Returns an [`Error`] if `pattern` cannot be parsed, if it is longer than
/// `CHUNK_SIZE / 2` bytes, or if reading from `reader` fails before a match is found.
pub fn scan_first_match(reader: impl Read, pattern: &str) -> Result<Option<usize>, Error> {
    // `next()` gives `Option<Result<_, _>>`; `transpose` flips it into the
    // `Result<Option<_>, _>` shape this function returns.
    Matches::from_pattern_str(reader, pattern)?.next().transpose()
}
/// Reports whether `bytes` begins with a match for `pattern`.
///
/// Only the leading `pattern.len()` bytes of `bytes` are inspected; anything after them is
/// ignored. A slice shorter than the pattern never matches.
pub fn pattern_matches(bytes: &[u8], pattern: &Pattern) -> bool {
    // The length test comes first so that the comparison only ever runs on a slice that is
    // long enough to hold the whole pattern.
    bytes.len() >= pattern.len() && pattern == bytes
}
/// Error produced by the pattern-parsing and scanning code in this file.
///
/// It carries nothing but a human-readable message.
#[derive(Debug)]
pub struct Error {
    /// Description of what went wrong.
    e: String,
}

impl Error {
    /// Creates an error carrying the message `e`.
    pub fn new(e: String) -> Self {
        Error { e }
    }
}

impl Display for Error {
    /// Writes the message prefixed with `Pattern scan error: `.
    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
        fmt.write_fmt(format_args!("Pattern scan error: {}", self.e))
    }
}

// No `source()` override: the underlying cause is only kept as text inside the message.
impl std::error::Error for Error {}
/// One element of a [`Pattern`]: either a concrete byte or a wildcard.
///
/// The type is a tiny plain value, so it is `Copy`, and it implements `Debug` so that it can
/// appear in assertions and diagnostics.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PatternByte {
    /// Matches exactly this byte value.
    Byte(u8),
    /// Matches any byte value; written as `?` in a pattern string.
    Any,
}
impl FromStr for PatternByte {
type Err = Error;
fn from_str(s: &str) -> Result<Self, Error> {
if s == "?" {
Ok(Self::Any)
} else {
let n = match u8::from_str_radix(s, 16) {
Ok(n) => Ok(n),
Err(e) => Err(Error::new(format!("from_str_radix failed: {}", e))),
}?;
Ok(Self::Byte(n))
}
}
}
impl PartialEq<u8> for PatternByte {
    /// Returns `true` when this pattern element accepts the byte `other`.
    ///
    /// A wildcard accepts every byte; a concrete byte accepts only an equal value.
    fn eq(&self, other: &u8) -> bool {
        if let PatternByte::Byte(expected) = self {
            expected == other
        } else {
            // The only remaining variant is the wildcard.
            true
        }
    }
}
/// A byte pattern: an ordered sequence of [`PatternByte`]s.
///
/// Values are normally obtained by parsing a pattern string through the `FromStr`
/// implementation.
#[derive(PartialEq, Eq)]
pub struct Pattern {
    /// Pattern elements, in the order they must appear in the input.
    bytes: Vec<PatternByte>,
}

impl Pattern {
    /// Wraps an already-parsed list of pattern elements.
    fn new(bytes: Vec<PatternByte>) -> Self {
        Pattern { bytes }
    }

    /// Number of input bytes one match of this pattern spans.
    fn len(&self) -> usize {
        let Pattern { bytes } = self;
        bytes.len()
    }
}
impl FromStr for Pattern {
type Err = Error;
fn from_str(s: &str) -> Result<Self, Error> {
let mut bytes = Vec::new();
for segment in s.split_ascii_whitespace() {
bytes.push(PatternByte::from_str(segment)?);
}
Ok(Self::new(bytes))
}
}
impl PartialEq<[u8]> for Pattern {
    /// Returns `true` when `other` starts with bytes that satisfy every element of this
    /// pattern, in order.
    ///
    /// Bytes of `other` beyond the pattern's length are ignored. A slice shorter than the
    /// pattern never compares equal.
    fn eq(&self, other: &[u8]) -> bool {
        // `zip` stops at the shorter of its two sides. Without the length guard a slice
        // shorter than the pattern would compare equal after satisfying only a prefix of the
        // pattern (and an empty slice would equal every pattern).
        other.len() >= self.bytes.len()
            && self
                .bytes
                .iter()
                .zip(other)
                .all(|(expected, actual)| expected == actual)
    }
}
pub struct Matches<R: Read> {
pub reader: R,
pub pattern: Pattern,
bytes_buf: [u8; CHUNK_SIZE],
last_bytes_read: usize,
abs_position: usize,
rel_position: usize,
}
impl<R: Read> Matches<R> {
pub fn from_pattern(mut reader: R, pattern: Pattern) -> Result<Self, Error> {
if 2 * pattern.len() > CHUNK_SIZE {
return Err(Error::new(format!(
"Pattern too long: It can be at most {} bytes",
CHUNK_SIZE / 2
)));
}
let mut bytes_buf = [0; CHUNK_SIZE];
let bytes_read = reader
.read(&mut bytes_buf)
.map_err(|e| Error::new(format!("Failed to read bytes: {}", e)))?;
Ok(Self {
reader,
pattern,
bytes_buf,
last_bytes_read: bytes_read,
abs_position: 0,
rel_position: 0,
})
}
pub fn from_pattern_str(reader: R, pattern: &str) -> Result<Self, Error> {
let pattern = Pattern::from_str(pattern)?;
Self::from_pattern(reader, pattern)
}
}
impl<R: Read> Iterator for Matches<R> {
type Item = Result<usize, Error>;
fn next(&mut self) -> Option<Self::Item> {
loop {
if self.rel_position == CHUNK_SIZE - self.pattern.len() {
let len = self.pattern.len();
let boundary_bytes = &self.bytes_buf[CHUNK_SIZE - len..].to_owned();
self.bytes_buf[..len].copy_from_slice(&boundary_bytes);
self.last_bytes_read = match self.reader.read(&mut self.bytes_buf[len..]) {
Ok(b) => b,
Err(e) => return Some(Err(Error::new(format!("Failed to read bytes: {}", e)))),
};
self.rel_position = 0;
}
if self.rel_position == self.last_bytes_read + self.pattern.len() {
break;
}
for i in self.rel_position..self.last_bytes_read + self.pattern.len() {
if i == CHUNK_SIZE - self.pattern.len() {
break;
}
self.abs_position += 1;
self.rel_position += 1;
if pattern_matches(&self.bytes_buf[i..], &self.pattern) {
return Some(Ok(self.abs_position - 1));
}
}
if self.last_bytes_read == 0 {
break;
}
}
None
}
}
#[cfg(test)]
mod tests {
    use std::io::Cursor;

    /// Five ascending bytes used by the simple positional tests.
    const RAMP: [u8; 5] = [0x10, 0x20, 0x30, 0x40, 0x50];
    /// Seven mixed bytes ending in `0x00`, used by the wildcard and no-match tests.
    const SAMPLE: [u8; 7] = [0xff, 0xfe, 0x7c, 0x88, 0xfd, 0x90, 0x00];

    /// Runs `scan` over `haystack`, panicking if it reports an error.
    fn offsets(haystack: &[u8], pattern: &str) -> Vec<usize> {
        crate::scan(Cursor::new(haystack), pattern).unwrap()
    }

    /// Runs `scan_first_match` over `haystack`, panicking if it reports an error.
    fn first_offset(haystack: &[u8], pattern: &str) -> Option<usize> {
        crate::scan_first_match(Cursor::new(haystack), pattern).unwrap()
    }

    /// Builds `zeros` zero bytes followed by `aa bb cc dd`.
    fn zeros_then_marker(zeros: usize) -> Vec<u8> {
        let mut bytes = vec![0u8; zeros];
        bytes.extend_from_slice(&[0xaa, 0xbb, 0xcc, 0xdd]);
        bytes
    }

    #[test]
    fn simple_scan_start() {
        assert_eq!(offsets(&RAMP, "10 20 30"), vec![0]);
    }

    #[test]
    fn simple_scan_middle() {
        assert_eq!(offsets(&RAMP, "20 30 40"), vec![1]);
    }

    #[test]
    fn scan_bad_exceeds() {
        assert!(offsets(&RAMP, "40 50 60").is_empty());
    }

    #[test]
    fn scan_exists() {
        assert_eq!(offsets(&SAMPLE, "fe 7c 88 fd 90 0"), vec![1]);
    }

    #[test]
    fn scan_exists_multiple_q() {
        assert_eq!(offsets(&SAMPLE, "fe ? ? ? 90"), vec![1]);
    }

    #[test]
    fn scan_exists_multiple_q_starts() {
        assert_eq!(offsets(&SAMPLE, "? ? ? ? fd"), vec![0]);
    }

    #[test]
    fn scan_nexists_1() {
        assert!(offsets(&SAMPLE, "78 90 cc dd fe").is_empty());
    }

    #[test]
    fn scan_nexists_2() {
        assert!(offsets(&SAMPLE, "fe 7c 88 fd 90 1").is_empty());
    }

    #[test]
    fn scan_pattern_larger_than_bytes() {
        assert!(offsets(&SAMPLE, "fe 7c 88 fd 90 0 1").is_empty());
    }

    #[test]
    fn scan_multiple_instances_of_pattern() {
        let haystack = [0x10u8, 0x20, 0x30, 0x10, 0x20, 0x30];
        assert_eq!(offsets(&haystack, "10 20 30"), vec![0, 3]);
    }

    #[test]
    fn scan_multiple_instances_q() {
        let haystack = [0x10u8, 0x20, 0x30, 0x10, 0x40, 0x30];
        assert_eq!(offsets(&haystack, "10 ? 30"), vec![0, 3]);
    }

    #[test]
    fn scan_rejects_invalid_pattern() {
        // `fff` does not fit into a single byte.
        let outcome = crate::scan(Cursor::new([0x10u8, 0x20, 0x30]), "10 fff 20");
        assert!(outcome.is_err());
    }

    #[test]
    fn scan_first_match_simple_start() {
        assert_eq!(first_offset(&RAMP, "10 20 30"), Some(0));
    }

    #[test]
    fn scan_first_match_simple_middle() {
        assert_eq!(first_offset(&RAMP, "20 30 40"), Some(1));
    }

    #[test]
    fn scan_first_match_no_match() {
        assert_eq!(first_offset(&RAMP, "10 11 12"), None);
    }

    #[test]
    fn find_across_chunk_boundary() {
        // The marker starts two bytes before the end of the first chunk.
        let bytes = zeros_then_marker(super::CHUNK_SIZE - 2);
        assert!(first_offset(&bytes, "aa bb cc dd").is_some());
    }

    #[test]
    fn correct_index_noninitial_chunk() {
        // The marker starts exactly where the second chunk begins.
        let bytes = zeros_then_marker(super::CHUNK_SIZE);
        assert_eq!(
            first_offset(&bytes, "aa bb cc dd"),
            Some(super::CHUNK_SIZE)
        );
    }
}