1use core::num;
4use rayon::{
5 prelude::{IndexedParallelIterator, ParallelIterator},
6 slice::ParallelSlice,
7 ThreadPool, ThreadPoolBuilder,
8};
9use thiserror::Error;
10
11pub struct PatternScanner {
12 bytes: Vec<u8>,
13 threadpool: ThreadPool,
14}
15
16impl PatternScanner {
17 pub fn scan<T: AsRef<str>>(&self, pattern: T) -> Result<Option<usize>, PatternScannerError> {
18 self.scan_with_bytes(&self.bytes, pattern)
19 }
20
21 pub fn scan_with_bytes<T: AsRef<[u8]> + std::marker::Sync, U: AsRef<str>>(
22 &self,
23 bytes: T,
24 pattern: U,
25 ) -> Result<Option<usize>, PatternScannerError> {
26 let pattern_bytes = create_bytes_from_string(pattern)?;
27
28 Ok(self.threadpool.install(|| {
30 bytes
31 .as_ref()
32 .par_windows(pattern_bytes.len())
33 .position_any(|window| {
34 window
35 .iter()
36 .zip(pattern_bytes.iter())
37 .all(|(byte, pattern_byte)| {
38 pattern_byte.is_none() || Some(*byte) == *pattern_byte
39 })
40 })
41 }))
42 }
43
44 pub fn scan_all<T: AsRef<str>>(&self, pattern: T) -> Result<Vec<usize>, PatternScannerError> {
45 self.scan_all_with_bytes(&self.bytes, pattern)
46 }
47
48 pub fn scan_all_with_bytes<T: AsRef<[u8]> + std::marker::Sync, U: AsRef<str>>(
49 &self,
50 bytes: T,
51 pattern: U,
52 ) -> Result<Vec<usize>, PatternScannerError> {
53 let pattern_bytes = create_bytes_from_string(pattern)?;
54
55 Ok(self.threadpool.install(|| {
57 bytes
58 .as_ref()
59 .par_windows(pattern_bytes.len())
60 .enumerate()
61 .filter(|(_, window)| {
62 window
63 .iter()
64 .zip(pattern_bytes.iter())
65 .all(|(byte, pattern_byte)| {
66 pattern_byte.is_none() || Some(*byte) == *pattern_byte
67 })
68 })
69 .map(|(i, _)| i)
70 .collect()
71 }))
72 }
73}
74
75pub struct PatternScannerBuilder {
76 bytes: Vec<u8>,
77 threadpool_builder: ThreadPoolBuilder,
78}
79
80impl PatternScannerBuilder {
81 pub fn builder() -> Self {
83 Self {
84 bytes: Vec::new(),
85 threadpool_builder: ThreadPoolBuilder::new(),
86 }
87 }
88
89 pub fn with_bytes<T: AsRef<[u8]>>(mut self, bytes: T) -> Self {
91 self.bytes = bytes.as_ref().to_vec();
92 self
93 }
94
95 pub fn with_threads(mut self, threads: usize) -> Self {
97 self.threadpool_builder = self.threadpool_builder.num_threads(threads);
98 self
99 }
100
101 pub fn build(self) -> PatternScanner {
103 PatternScanner {
104 bytes: self.bytes,
105 threadpool: self
106 .threadpool_builder
107 .build()
108 .expect("failed to build threadpool"),
109 }
110 }
111}
112
113#[derive(Error, Debug, PartialEq)]
114pub enum PatternScannerError {
116 #[error("failed to parse the pattern byte {0} as a u8")]
117 InvalidByte(#[from] num::ParseIntError),
118 #[error("the pattern byte {0} is invalid (must be 2 characters long)")]
119 ByteLength(String),
120 #[error("unknown pattern scanner error")]
123 Unknown,
124}
125
126fn create_bytes_from_string<T: AsRef<str>>(
134 pattern: T,
135) -> Result<Vec<Option<u8>>, PatternScannerError> {
136 let split_pattern = pattern.as_ref().split_whitespace();
137
138 let mut v = Vec::new();
140 for x in split_pattern {
141 if x == "?" || x == "??" {
142 v.push(None);
143 } else {
144 if x.len() != 2 {
146 return Err(PatternScannerError::ByteLength(x.to_owned()));
147 }
148
149 v.push(Some(match u8::from_str_radix(x, 16) {
150 Ok(b) => b,
151 Err(e) => return Err(PatternScannerError::InvalidByte(e)),
152 }));
153 }
154 }
155
156 Ok(v)
157}
158
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_create_bytes_from_string_1() {
        assert_eq!(
            create_bytes_from_string("AA BB CC").unwrap(),
            vec![Some(0xAA), Some(0xBB), Some(0xCC)]
        );
    }

    #[test]
    fn test_create_bytes_from_string_2() {
        assert_eq!(
            create_bytes_from_string("AA BB CC AA BB FF").unwrap(),
            vec![
                Some(0xAA),
                Some(0xBB),
                Some(0xCC),
                Some(0xAA),
                Some(0xBB),
                Some(0xFF)
            ]
        );
    }

    #[test]
    fn test_create_bytes_from_string_wildcard_1() {
        assert_eq!(
            create_bytes_from_string("AA BB ? ? CC").unwrap(),
            vec![Some(0xAA), Some(0xBB), None, None, Some(0xCC)]
        );
    }

    #[test]
    fn test_create_bytes_from_string_wildcard_2() {
        assert_eq!(
            create_bytes_from_string("? AA BB ? ? CC ? ? ? FF").unwrap(),
            vec![
                None,
                Some(0xAA),
                Some(0xBB),
                None,
                None,
                Some(0xCC),
                None,
                None,
                None,
                Some(0xFF)
            ]
        );
    }

    #[test]
    fn test_create_bytes_from_string_double_wildcard_and_lowercase() {
        assert_eq!(
            create_bytes_from_string("aa ?? Ff").unwrap(),
            vec![Some(0xAA), None, Some(0xFF)]
        );
    }

    #[test]
    fn test_create_bytes_from_string_error_invalid_byte() {
        assert!(matches!(
            create_bytes_from_string("AA GG"),
            Err(PatternScannerError::InvalidByte(_))
        ));
    }

    #[test]
    fn test_create_bytes_from_string_error_space() {
        assert_eq!(
            create_bytes_from_string("A A BB"),
            Err(PatternScannerError::ByteLength("A".to_owned()))
        );
    }

    #[test]
    fn test_pattern_scan() {
        let result = PatternScannerBuilder::builder()
            .with_bytes(&[0x00, 0x01, 0x02, 0x33, 0x35, 0x33, 0x35, 0x07, 0x08, 0x09])
            .with_threads(1)
            .build()
            .scan("33 35")
            .unwrap();

        assert_eq!(result, Some(3));
    }

    #[test]
    fn test_pattern_scan_all() {
        let result = PatternScannerBuilder::builder()
            .with_bytes(&[0x00, 0x01, 0x02, 0x33, 0x35, 0x33, 0x35, 0x07, 0x08, 0x09])
            .build()
            .scan_all("33 35")
            .unwrap();

        assert_eq!(result, vec![3, 5]);
    }

    #[test]
    fn test_pattern_scan_wildcard() {
        let scanner = PatternScannerBuilder::builder()
            .with_bytes(&[0x00_u8, 0x33, 0x10, 0x35, 0x33, 0x20, 0x35])
            .with_threads(1)
            .build();

        assert_eq!(scanner.scan("33 ?? 35").unwrap(), Some(1));
        assert_eq!(scanner.scan_all("33 ? 35").unwrap(), vec![1, 4]);
    }

    #[test]
    fn test_pattern_scan_no_match() {
        let scanner = PatternScannerBuilder::builder()
            .with_bytes(&[0x01_u8, 0x02, 0x03])
            .with_threads(1)
            .build();

        assert_eq!(scanner.scan("04").unwrap(), None);
        // A pattern longer than the data can never match.
        assert_eq!(scanner.scan("01 02 03 04").unwrap(), None);
        assert!(scanner.scan_all("01 02 03 04").unwrap().is_empty());
    }

    #[test]
    fn test_pattern_scan_invalid_pattern() {
        let scanner = PatternScannerBuilder::builder()
            .with_bytes(&[0x01_u8])
            .with_threads(1)
            .build();

        assert_eq!(
            scanner.scan("1"),
            Err(PatternScannerError::ByteLength("1".to_owned()))
        );
    }

    #[test]
    fn test_pattern_scan_all_1_million_bytes() {
        // Heap-allocated: a 1 MB stack array, plus the copy made when passing it by value,
        // can overflow the default 2 MiB test-thread stack in debug builds.
        let mut bytes = vec![0u8; 1_000_000];
        bytes[600_000] = 0x33;
        bytes[600_001] = 0x35;

        let scanner = PatternScannerBuilder::builder()
            .with_bytes(&bytes)
            .with_threads(1)
            .build();

        let start = std::time::Instant::now();
        let result = scanner.scan_all("33 35").unwrap();
        println!("Execution time: {:?}", start.elapsed());

        assert_eq!(result, vec![600_000]);
    }
}