coolfindpattern/
lib.rs

1#![feature(portable_simd, iter_array_chunks)]
2
3use std::{
4    ops::BitAnd,
5    simd::{Mask, Simd, cmp::SimdPartialEq},
6};
7
// The SIMD vector width has to be a compile-time constant, and it cannot be
// inferred for us, so it is derived here from the target features rustc was
// told to compile for.
#[cfg(not(any(
    target_feature = "sse2",
    target_feature = "avx2",
    target_feature = "avx512f",
    target_feature = "neon"
)))]
compile_error!("you have not selected a proper SIMD instruction set (SSE2/AVX2/AVX512/NEON)");

/// Number of bytes handled per SIMD vector: 64 with AVX-512F, 32 with AVX2,
/// and 16 otherwise (SSE2 or NEON).
const BYTES: usize = if cfg!(target_feature = "avx512f") {
    64
} else if cfg!(target_feature = "avx2") {
    32
} else {
    16
};
31
/// Builds a `&[Option<u8>; N]` byte pattern from a comma-separated list.
///
/// Every literal becomes `Some(literal as u8)`; any other single token
/// (conventionally `_`) becomes `None`, i.e. a wildcard position.
/// A trailing comma is accepted.
///
/// ```ignore
/// let pat = pattern!(0xDE, 0xAD, _, 0xBE, 0xEF);
/// ```
#[macro_export]
macro_rules! pattern {
    // internal rule: a literal is a concrete byte
    (@el $v:literal) => {
        Some($v as u8)
    };
    // internal rule: anything else is a wildcard
    (@el $v:tt) => {
        None
    };
    // `$crate::` keeps the recursive call resolvable when the macro is invoked
    // by path from another crate without importing it first.
    ($($elem:tt),+ $(,)?) => {
        &[$($crate::pattern!(@el $elem)),+]
    };
}
44
/// A borrowed byte pattern: `Some(b)` requires the byte `b`, `None` is a
/// wildcard that matches any byte.
pub type Pattern<'a> = &'a [Option<u8>];
/// Owned counterpart of [`Pattern`].
pub type OwnedPattern = Vec<Option<u8>>;
47
48pub struct PatternChunk {
49    pub first_byte: Simd<u8, BYTES>,
50    pub mask: Mask<i8, BYTES>,
51    pub bytes: Simd<u8, BYTES>,
52}
53
54pub struct PreparedPattern {
55    pub chunks: Vec<PatternChunk>,
56    pub orig_pat: OwnedPattern,
57    pub size: usize,
58    pub padded_size: usize,
59    pub start_offset: usize,
60}
61
62impl<'a> From<Pattern<'a>> for PreparedPattern {
63    fn from(pat: Pattern) -> Self {
64        // remove trailing wildcard bytes
65        let pat = &pat[0..=pat
66            .iter()
67            .rposition(|chr| matches!(chr, Some(_)))
68            .expect("pattern should not be a wildcard!")];
69
70        // don't include the first n wildcard bytes in the actual search pattern, saving valuable space
71        // doing this naively would cause an unexpected shift in the returned matches, therefore
72        // we simply re-apply the offset when returning pattern matches to the user.
73        let start_offset = pat
74            .iter()
75            .position(|byte| byte.is_some())
76            .expect("pattern should not be a wildcard!");
77
78        let pat = &pat[start_offset..pat.len()];
79
80        // get size extended to next chunk
81        let size = if pat.len() % BYTES == 0 {
82            pat.len()
83        } else {
84            pat.len() + (BYTES - (pat.len() % BYTES))
85        };
86
87        let bytes: Vec<u8> = pat
88            .iter()
89            .map(|x| match x {
90                Some(x) => *x,
91                None => 0u8,
92            })
93            .collect();
94
95        let mask: Vec<bool> = pat.iter().map(|x| x.is_some()).collect();
96
97        let mut bytes_extended = vec![0u8; size];
98
99        bytes_extended[0..pat.len()].copy_from_slice(&bytes);
100
101        let mut mask_extended = vec![false; size];
102
103        mask_extended[0..pat.len()].copy_from_slice(&mask);
104
105        let chunks: Vec<PatternChunk> = bytes_extended
106            .into_iter()
107            .array_chunks::<BYTES>()
108            .zip(mask_extended.into_iter().array_chunks::<BYTES>())
109            .map(|(bytes, mask)| PatternChunk {
110                first_byte: Simd::from_array([bytes[0]; BYTES]),
111                mask: Mask::from_array(mask),
112                bytes: Simd::from_array(bytes),
113            })
114            .collect();
115
116        Self {
117            chunks,
118            orig_pat: pat.to_owned(),
119            size: pat.len(),
120            padded_size: size,
121            start_offset,
122        }
123    }
124}
125
126// precompute data for pattern in SIMD chunks.
127// SIMD search binary
128
129pub struct PatternSearcher<'data> {
130    data: &'data [u8],
131    remaining_data: &'data [u8],
132    pattern: PreparedPattern,
133}
134
135impl<'data> PatternSearcher<'data> {
136    pub fn new(data: &'data [u8], pattern: Pattern) -> Self {
137        Self {
138            data,
139            remaining_data: data,
140            pattern: pattern.into(),
141        }
142    }
143}
144
145impl<'data> Iterator for PatternSearcher<'data> {
146    type Item = usize;
147
148    fn next(&mut self) -> Option<Self::Item> {
149        'main: loop {
150            if self.remaining_data.len() < self.pattern.size {
151                // pattern is not findable anymore.
152                break None;
153            }
154
155            if self.remaining_data.len() < self.pattern.padded_size {
156                // pattern is no longer SIMD-findable. manually find.
157
158                // this is a very cold path.
159                #[cold]
160                fn find_pattern(region: &[u8], pattern: Pattern) -> Option<usize> {
161                    region.windows(pattern.len()).position(|wnd| {
162                        wnd.iter().zip(pattern).all(|(v, p)| match p {
163                            Some(x) => *v == *x,
164                            None => true,
165                        })
166                    })
167                }
168
169                let result = find_pattern(self.remaining_data, &self.pattern.orig_pat);
170
171                break match result {
172                    Some(offset) => {
173                        let result = offset - self.pattern.start_offset + self.data.len()
174                            - self.remaining_data.len();
175                        self.remaining_data = &self.remaining_data[offset + 1..];
176
177                        Some(result)
178                    }
179                    None => None,
180                };
181            }
182
183            let mut current_search = self.remaining_data;
184            let mut current_offset = 0usize;
185            let mut first_chunk = true;
186
187            for chunk in &self.pattern.chunks {
188                let search = Simd::from_slice(&current_search[..BYTES]);
189
190                let first_byte = search.simd_eq(chunk.first_byte).to_bitmask();
191
192                if first_byte == 0 {
193                    if first_chunk {
194                        // this is the first block. the next block may contain the first again
195                        // advance current cursor to the next block and restart pattern verification
196                        self.remaining_data = &self.remaining_data[BYTES..];
197                    } else {
198                        // this is a continuation block. the first pattern chunk might still be in this data chunk
199                        // only this chunk has failed, we need to restart pattern verification in this same block, just this time with the first chunk
200                        self.remaining_data = &self.remaining_data[current_offset..];
201                    }
202
203                    continue 'main;
204                }
205
206                // if this is the first chunk, allow advancing to the next occurrence of the first byte and restart check
207                if first_chunk && first_byte.trailing_zeros() != 0 {
208                    self.remaining_data =
209                        &self.remaining_data[first_byte.trailing_zeros() as usize..];
210                    continue 'main;
211                } else if first_byte.trailing_zeros() != 0 {
212                    // not the first chunk, but we are not aligned to the first byte.
213                    // this means we did not match.
214                    // restart pattern verification from the current data chunk.
215                    self.remaining_data = &self.remaining_data[current_offset..];
216                    continue 'main;
217                }
218
219                // we are now aligned to the first byte of the chunk
220                let search = Simd::from_slice(current_search);
221
222                let result = search.simd_eq(chunk.bytes);
223
224                // filtered result is smaller than the mask
225                let filtered_result = result.bitand(chunk.mask);
226
227                if filtered_result != chunk.mask {
228                    // we did not match. restart pattern scan in one byte
229
230                    // increase index by one to avoid scanning the same chunk again
231                    self.remaining_data = &self.remaining_data[1..];
232
233                    continue 'main;
234                }
235
236                // we matched. go on to next chunk. if the remaining chunks also match, we gracefully leave the loop and return a match.
237
238                first_chunk = false;
239                current_search = &current_search[BYTES..];
240                current_offset += BYTES;
241            }
242
243            let result = self.data.len() - self.remaining_data.len() - self.pattern.start_offset;
244
245            self.remaining_data = &self.remaining_data[1..];
246
247            return Some(result);
248        }
249    }
250}
251
/// A fully concrete pattern near the start of the buffer is found at its
/// exact position.
#[test]
fn test_scan_simple() {
    let mut buf = vec![0u8; 500];
    buf[6..10].copy_from_slice(&[0xDE, 0xAD, 0xBE, 0xEF]);

    let first_hit = PatternSearcher::new(&buf, pattern!(0xDE, 0xAD, 0xBE, 0xEF)).next();

    assert_eq!(first_hit, Some(6));
}
266
/// A leading wildcard shifts the reported position one byte before the
/// first concrete byte.
#[test]
fn test_scan_offset() {
    let mut buf = vec![0u8; 500];
    buf[6..10].copy_from_slice(&[0xDE, 0xAD, 0xBE, 0xEF]);

    let first_hit = PatternSearcher::new(&buf, pattern!(_, 0xDE, 0xAD, 0xBE, 0xEF)).next();

    assert_eq!(first_hit, Some(5));
}
281
/// A pattern in the last four bytes of the buffer is still found, which
/// requires the scalar fallback at the end of the data.
#[test]
fn test_scan_simd_fallback() {
    let mut buf = vec![0u8; 500];
    buf[496..500].copy_from_slice(&[0xDE, 0xAD, 0xBE, 0xEF]);

    let first_hit = PatternSearcher::new(&buf, pattern!(0xDE, 0xAD, 0xBE, 0xEF)).next();

    assert_eq!(first_hit, Some(496));
}
296
/// Same as the fallback test, but with a leading wildcard: the reported
/// position must be shifted back by one.
#[test]
fn test_scan_simd_fallback_offset() {
    let mut buf = vec![0u8; 500];
    buf[496..500].copy_from_slice(&[0xDE, 0xAD, 0xBE, 0xEF]);

    let first_hit = PatternSearcher::new(&buf, pattern!(_, 0xDE, 0xAD, 0xBE, 0xEF)).next();

    assert_eq!(first_hit, Some(495));
}
311
/// A wildcard in the middle of the pattern matches whatever byte sits there.
#[test]
fn test_scan_wildcard() {
    let mut buf = vec![0u8; 500];
    buf[6..8].copy_from_slice(&[0xDE, 0xAD]);
    // buf[8] is left untouched: it is covered by the wildcard
    buf[9..11].copy_from_slice(&[0xBE, 0xEF]);

    let first_hit = PatternSearcher::new(&buf, pattern!(0xDE, 0xAD, _, 0xBE, 0xEF)).next();

    assert_eq!(first_hit, Some(6));
}
326
/// A 45-byte signature (nine repetitions of `DE AD _ BE EF`) is found at
/// its start position.
#[test]
fn test_scan_large_sig() {
    let mut buf = vec![0u8; 500];

    // nine 5-byte groups starting at offsets 5, 10, ..., 45; the third byte of
    // each group is left at zero and is covered by a wildcard in the pattern
    for group_start in (5..50).step_by(5) {
        buf[group_start] = 0xDE;
        buf[group_start + 1] = 0xAD;
        buf[group_start + 3] = 0xBE;
        buf[group_start + 4] = 0xEF;
    }

    let pattern = pattern!(
        0xDE, 0xAD, _, 0xBE, 0xEF, 0xDE, 0xAD, _, 0xBE, 0xEF, 0xDE, 0xAD, _, 0xBE, 0xEF, 0xDE,
        0xAD, _, 0xBE, 0xEF, 0xDE, 0xAD, _, 0xBE, 0xEF, 0xDE, 0xAD, _, 0xBE, 0xEF, 0xDE, 0xAD, _,
        0xBE, 0xEF, 0xDE, 0xAD, _, 0xBE, 0xEF, 0xDE, 0xAD, _, 0xBE, 0xEF
    );

    let first_hit = PatternSearcher::new(&buf, pattern).next();

    assert_eq!(first_hit, Some(5));
}