patternscanner/
lib.rs

//! This crate provides a simple API for searching an array of bytes for a pattern, using either a single thread or multiple threads. It supports finding either a single match or all matches of a pattern.
2
3use core::num;
4use rayon::{
5    prelude::{IndexedParallelIterator, ParallelIterator},
6    slice::ParallelSlice,
7    ThreadPool, ThreadPoolBuilder,
8};
9use thiserror::Error;
10
11pub struct PatternScanner {
12    bytes: Vec<u8>,
13    threadpool: ThreadPool,
14}
15
16impl PatternScanner {
17    pub fn scan<T: AsRef<str>>(&self, pattern: T) -> Result<Option<usize>, PatternScannerError> {
18        self.scan_with_bytes(&self.bytes, pattern)
19    }
20
21    pub fn scan_with_bytes<T: AsRef<[u8]> + std::marker::Sync, U: AsRef<str>>(
22        &self,
23        bytes: T,
24        pattern: U,
25    ) -> Result<Option<usize>, PatternScannerError> {
26        let pattern_bytes = create_bytes_from_string(pattern)?;
27
28        // Scan the bytes for the unique pattern using the rayon crate
29        Ok(self.threadpool.install(|| {
30            bytes
31                .as_ref()
32                .par_windows(pattern_bytes.len())
33                .position_any(|window| {
34                    window
35                        .iter()
36                        .zip(pattern_bytes.iter())
37                        .all(|(byte, pattern_byte)| {
38                            pattern_byte.is_none() || Some(*byte) == *pattern_byte
39                        })
40                })
41        }))
42    }
43
44    pub fn scan_all<T: AsRef<str>>(&self, pattern: T) -> Result<Vec<usize>, PatternScannerError> {
45        self.scan_all_with_bytes(&self.bytes, pattern)
46    }
47
48    pub fn scan_all_with_bytes<T: AsRef<[u8]> + std::marker::Sync, U: AsRef<str>>(
49        &self,
50        bytes: T,
51        pattern: U,
52    ) -> Result<Vec<usize>, PatternScannerError> {
53        let pattern_bytes = create_bytes_from_string(pattern)?;
54
55        // Scan the bytes for all matches of the pattern using the rayon crate
56        Ok(self.threadpool.install(|| {
57            bytes
58                .as_ref()
59                .par_windows(pattern_bytes.len())
60                .enumerate()
61                .filter(|(_, window)| {
62                    window
63                        .iter()
64                        .zip(pattern_bytes.iter())
65                        .all(|(byte, pattern_byte)| {
66                            pattern_byte.is_none() || Some(*byte) == *pattern_byte
67                        })
68                })
69                .map(|(i, _)| i)
70                .collect()
71        }))
72    }
73}
74
75pub struct PatternScannerBuilder {
76    bytes: Vec<u8>,
77    threadpool_builder: ThreadPoolBuilder,
78}
79
80impl PatternScannerBuilder {
81    ///  Create a new pattern scanner builder
82    pub fn builder() -> Self {
83        Self {
84            bytes: Vec::new(),
85            threadpool_builder: ThreadPoolBuilder::new(),
86        }
87    }
88
89    /// Set the bytes to scan
90    pub fn with_bytes<T: AsRef<[u8]>>(mut self, bytes: T) -> Self {
91        self.bytes = bytes.as_ref().to_vec();
92        self
93    }
94
95    /// Set the number of threads to use
96    pub fn with_threads(mut self, threads: usize) -> Self {
97        self.threadpool_builder = self.threadpool_builder.num_threads(threads);
98        self
99    }
100
101    /// Build the pattern scanner
102    pub fn build(self) -> PatternScanner {
103        PatternScanner {
104            bytes: self.bytes,
105            threadpool: self
106                .threadpool_builder
107                .build()
108                .expect("failed to build threadpool"),
109        }
110    }
111}
112
113#[derive(Error, Debug, PartialEq)]
114// The error types for the pattern scanner
115pub enum PatternScannerError {
116    #[error("failed to parse the pattern byte {0} as a u8")]
117    InvalidByte(#[from] num::ParseIntError),
118    #[error("the pattern byte {0} is invalid (must be 2 characters long)")]
119    ByteLength(String),
120    //#[error("invalid header (expected {expected:?}, found {found:?})")]
121    //InvalidHeader { expected: String, found: String },
122    #[error("unknown pattern scanner error")]
123    Unknown,
124}
125
126/// Create a vector of bytes from a pattern string
127///
128/// # Arguments
129/// * `pattern` - The pattern string
130///
131/// # Returns
132/// * A vector of bytes
133fn create_bytes_from_string<T: AsRef<str>>(
134    pattern: T,
135) -> Result<Vec<Option<u8>>, PatternScannerError> {
136    let split_pattern = pattern.as_ref().split_whitespace();
137
138    // Create a Vec of Option<u8> where None represents a ? character in the pattern string
139    let mut v = Vec::new();
140    for x in split_pattern {
141        if x == "?" || x == "??" {
142            v.push(None);
143        } else {
144            // Check that the pattern byte string is 2 characters long
145            if x.len() != 2 {
146                return Err(PatternScannerError::ByteLength(x.to_owned()));
147            }
148
149            v.push(Some(match u8::from_str_radix(x, 16) {
150                Ok(b) => b,
151                Err(e) => return Err(PatternScannerError::InvalidByte(e)),
152            }));
153        }
154    }
155
156    Ok(v)
157}
158
#[cfg(test)]
mod tests {
    use super::*;

    /// Bytes shared by the small scanning tests; "33 35" occurs at offsets 3 and 5.
    const SAMPLE: [u8; 10] = [0x00, 0x01, 0x02, 0x33, 0x35, 0x33, 0x35, 0x07, 0x08, 0x09];

    // Test the create_bytes_from_string function with a valid string
    #[test]
    fn test_create_bytes_from_string_1() {
        assert_eq!(
            create_bytes_from_string("AA BB CC").unwrap(),
            vec![Some(0xAA), Some(0xBB), Some(0xCC)]
        );
    }

    // Test the create_bytes_from_string function with a longer valid string
    #[test]
    fn test_create_bytes_from_string_2() {
        assert_eq!(
            create_bytes_from_string("AA BB CC AA BB FF").unwrap(),
            vec![
                Some(0xAA),
                Some(0xBB),
                Some(0xCC),
                Some(0xAA),
                Some(0xBB),
                Some(0xFF)
            ]
        );
    }

    // Test the create_bytes_from_string function with a wildcard "?"
    #[test]
    fn test_create_bytes_from_string_wildcard_1() {
        assert_eq!(
            create_bytes_from_string("AA BB ? ? CC").unwrap(),
            vec![Some(0xAA), Some(0xBB), None, None, Some(0xCC)]
        );
    }

    // Test the create_bytes_from_string function with leading and repeated wildcards
    #[test]
    fn test_create_bytes_from_string_wildcard_2() {
        assert_eq!(
            create_bytes_from_string("? AA BB ? ? CC ? ? ? FF").unwrap(),
            vec![
                None,
                Some(0xAA),
                Some(0xBB),
                None,
                None,
                Some(0xCC),
                None,
                None,
                None,
                Some(0xFF)
            ]
        );
    }

    // Test the create_bytes_from_string function with an invalid byte "GG"
    #[test]
    fn test_create_bytes_from_string_error_invalid_byte() {
        // A ParseIntError cannot be constructed directly, so check the variant instead of
        // comparing against a hand-built error value.
        assert!(matches!(
            create_bytes_from_string("AA GG"),
            Err(PatternScannerError::InvalidByte(_))
        ));
    }

    // Test the create_bytes_from_string function with a string that contains a space between the bytes
    #[test]
    fn test_create_bytes_from_string_error_space() {
        assert_eq!(
            create_bytes_from_string("A A BB"),
            Err(PatternScannerError::ByteLength("A".to_owned()))
        );
    }

    // A single thread is used so that the match reported by `scan` is deterministic
    #[test]
    fn test_pattern_scan() {
        let result = PatternScannerBuilder::builder()
            .with_bytes(&SAMPLE)
            .with_threads(1)
            .build()
            .scan("33 35")
            .unwrap();

        assert_eq!(result, Some(3));
    }

    #[test]
    fn test_pattern_scan_all() {
        let result = PatternScannerBuilder::builder()
            .with_bytes(&SAMPLE)
            .build()
            .scan_all("33 35")
            .unwrap();

        assert_eq!(result, vec![3, 5]);
    }

    // This test measures the execution time of the scan_all function with 1 million bytes and 1 thread
    #[test]
    fn test_pattern_scan_all_1_million_bytes() {
        // Create a buffer of 1 million bytes on the heap; an array of this size would sit
        // on the test thread's stack and be copied when passed by value.
        let mut bytes = vec![0u8; 1_000_000];
        bytes[600_000] = 0x33;
        bytes[600_001] = 0x35;

        // Create the pattern scanner
        let scanner = PatternScannerBuilder::builder()
            .with_bytes(&bytes)
            .with_threads(1)
            .build();

        // Start measuring the execution time
        let start = std::time::Instant::now();

        // Scan the bytes
        let result = scanner.scan_all("33 35").unwrap();

        // Stop measuring the execution time
        let duration = start.elapsed();

        // Print the execution time
        println!("Execution time: {:?}", duration);

        assert_eq!(result, vec![600_000]);
    }
}