1#![feature(portable_simd, iter_array_chunks)]
2
3use std::{
4 ops::BitAnd,
5 simd::{Mask, Simd, cmp::SimdPartialEq},
6};
7
// Refuse to build unless at least one SIMD instruction set the scanner is tuned for is enabled.
#[cfg(not(any(
    target_feature = "sse2",
    target_feature = "avx2",
    target_feature = "avx512f",
    target_feature = "neon"
)))]
compile_error!("you have not selected a proper SIMD instruction set (SSE2/AVX2/AVX512/NEON)");

/// Number of pattern bytes compared per SIMD operation: 64 lanes with AVX-512F.
#[cfg(target_feature = "avx512f")]
const BYTES: usize = 64;

/// Number of pattern bytes compared per SIMD operation: 32 lanes with AVX2 (and no AVX-512F).
#[cfg(all(target_feature = "avx2", not(target_feature = "avx512f")))]
const BYTES: usize = 32;

/// Number of pattern bytes compared per SIMD operation: 16 lanes with only SSE2 or NEON.
#[cfg(all(
    any(target_feature = "sse2", target_feature = "neon"),
    not(any(target_feature = "avx2", target_feature = "avx512f"))
))]
const BYTES: usize = 16;
31
/// Builds a borrowed byte pattern (`&[Option<u8>]`) from a comma-separated list.
///
/// Every literal element becomes `Some(literal as u8)`. Any other single token (by convention
/// `_`) becomes a `None` wildcard. A trailing comma is accepted.
///
/// ```ignore
/// let pat = pattern!(0xDE, 0xAD, _, 0xEF);
/// assert_eq!(pat, &[Some(0xDE), Some(0xAD), None, Some(0xEF)]);
/// ```
#[macro_export]
macro_rules! pattern {
    // Internal rule: a literal is a concrete byte that must match exactly.
    (@el $v:literal) => {
        ::core::option::Option::Some($v as u8)
    };
    // Internal rule: any other single token is a wildcard.
    (@el $v:tt) => {
        ::core::option::Option::None
    };
    // Entry point. Recurse through `$crate` so the macro also works when callers invoke it by
    // path (`some_crate::pattern!(..)`) without importing the name `pattern`.
    ($($elem:tt),+ $(,)?) => {
        &[$($crate::pattern!(@el $elem)),+]
    };
}
44
45pub type OwnedPattern = Vec<Option<u8>>;
46pub type Pattern<'a> = &'a [Option<u8>];
47
48pub struct PatternChunk {
49 pub first_byte: Simd<u8, BYTES>,
50 pub mask: Mask<i8, BYTES>,
51 pub bytes: Simd<u8, BYTES>,
52}
53
54pub struct PreparedPattern {
55 pub chunks: Vec<PatternChunk>,
56 pub orig_pat: OwnedPattern,
57 pub size: usize,
58 pub padded_size: usize,
59 pub start_offset: usize,
60}
61
62impl<'a> From<Pattern<'a>> for PreparedPattern {
63 fn from(pat: Pattern) -> Self {
64 let pat = &pat[0..=pat
66 .iter()
67 .rposition(|chr| matches!(chr, Some(_)))
68 .expect("pattern should not be a wildcard!")];
69
70 let start_offset = pat
74 .iter()
75 .position(|byte| byte.is_some())
76 .expect("pattern should not be a wildcard!");
77
78 let pat = &pat[start_offset..pat.len()];
79
80 let size = if pat.len() % BYTES == 0 {
82 pat.len()
83 } else {
84 pat.len() + (BYTES - (pat.len() % BYTES))
85 };
86
87 let bytes: Vec<u8> = pat
88 .iter()
89 .map(|x| match x {
90 Some(x) => *x,
91 None => 0u8,
92 })
93 .collect();
94
95 let mask: Vec<bool> = pat.iter().map(|x| x.is_some()).collect();
96
97 let mut bytes_extended = vec![0u8; size];
98
99 bytes_extended[0..pat.len()].copy_from_slice(&bytes);
100
101 let mut mask_extended = vec![false; size];
102
103 mask_extended[0..pat.len()].copy_from_slice(&mask);
104
105 let chunks: Vec<PatternChunk> = bytes_extended
106 .into_iter()
107 .array_chunks::<BYTES>()
108 .zip(mask_extended.into_iter().array_chunks::<BYTES>())
109 .map(|(bytes, mask)| PatternChunk {
110 first_byte: Simd::from_array([bytes[0]; BYTES]),
111 mask: Mask::from_array(mask),
112 bytes: Simd::from_array(bytes),
113 })
114 .collect();
115
116 Self {
117 chunks,
118 orig_pat: pat.to_owned(),
119 size: pat.len(),
120 padded_size: size,
121 start_offset,
122 }
123 }
124}
125
126pub struct PatternSearcher<'data> {
130 data: &'data [u8],
131 remaining_data: &'data [u8],
132 pattern: PreparedPattern,
133}
134
135impl<'data> PatternSearcher<'data> {
136 pub fn new(data: &'data [u8], pattern: Pattern) -> Self {
137 Self {
138 data,
139 remaining_data: data,
140 pattern: pattern.into(),
141 }
142 }
143}
144
145impl<'data> Iterator for PatternSearcher<'data> {
146 type Item = usize;
147
148 fn next(&mut self) -> Option<Self::Item> {
149 'main: loop {
150 if self.remaining_data.len() < self.pattern.size {
151 break None;
153 }
154
155 if self.remaining_data.len() < self.pattern.padded_size {
156 #[cold]
160 fn find_pattern(region: &[u8], pattern: Pattern) -> Option<usize> {
161 region.windows(pattern.len()).position(|wnd| {
162 wnd.iter().zip(pattern).all(|(v, p)| match p {
163 Some(x) => *v == *x,
164 None => true,
165 })
166 })
167 }
168
169 let result = find_pattern(self.remaining_data, &self.pattern.orig_pat);
170
171 break match result {
172 Some(offset) => {
173 let result = offset - self.pattern.start_offset + self.data.len()
174 - self.remaining_data.len();
175 self.remaining_data = &self.remaining_data[offset + 1..];
176
177 Some(result)
178 }
179 None => None,
180 };
181 }
182
183 let mut current_search = self.remaining_data;
184 let mut current_offset = 0usize;
185 let mut first_chunk = true;
186
187 for chunk in &self.pattern.chunks {
188 let search = Simd::from_slice(¤t_search[..BYTES]);
189
190 let first_byte = search.simd_eq(chunk.first_byte).to_bitmask();
191
192 if first_byte == 0 {
193 if first_chunk {
194 self.remaining_data = &self.remaining_data[BYTES..];
197 } else {
198 self.remaining_data = &self.remaining_data[current_offset..];
201 }
202
203 continue 'main;
204 }
205
206 if first_chunk && first_byte.trailing_zeros() != 0 {
208 self.remaining_data =
209 &self.remaining_data[first_byte.trailing_zeros() as usize..];
210 continue 'main;
211 } else if first_byte.trailing_zeros() != 0 {
212 self.remaining_data = &self.remaining_data[current_offset..];
216 continue 'main;
217 }
218
219 let search = Simd::from_slice(current_search);
221
222 let result = search.simd_eq(chunk.bytes);
223
224 let filtered_result = result.bitand(chunk.mask);
226
227 if filtered_result != chunk.mask {
228 self.remaining_data = &self.remaining_data[1..];
232
233 continue 'main;
234 }
235
236 first_chunk = false;
239 current_search = ¤t_search[BYTES..];
240 current_offset += BYTES;
241 }
242
243 let result = self.data.len() - self.remaining_data.len() - self.pattern.start_offset;
244
245 self.remaining_data = &self.remaining_data[1..];
246
247 return Some(result);
248 }
249 }
250}
251
#[test]
fn test_scan_simple() {
    // A fully concrete pattern placed well inside the buffer.
    let mut buf = vec![0u8; 500];
    buf[6..10].copy_from_slice(&[0xDE, 0xAD, 0xBE, 0xEF]);

    let first_hit = PatternSearcher::new(&buf, pattern!(0xDE, 0xAD, 0xBE, 0xEF)).next();
    assert_eq!(first_hit, Some(6))
}
266
#[test]
fn test_scan_offset() {
    // A leading wildcard moves the reported position one byte before the first concrete byte.
    let mut buf = vec![0u8; 500];
    for (slot, byte) in buf[6..].iter_mut().zip([0xDE, 0xAD, 0xBE, 0xEF]) {
        *slot = byte;
    }

    let mut searcher = PatternSearcher::new(&buf, pattern!(_, 0xDE, 0xAD, 0xBE, 0xEF));
    assert_eq!(searcher.next(), Some(5))
}
281
#[test]
fn test_scan_simd_fallback() {
    // The pattern ends on the buffer's last byte, exercising the scalar tail scan.
    const NEEDLE: [u8; 4] = [0xDE, 0xAD, 0xBE, 0xEF];

    let mut buf = vec![0u8; 500];
    let tail_start = buf.len() - NEEDLE.len();
    buf[tail_start..].copy_from_slice(&NEEDLE);

    let mut searcher = PatternSearcher::new(&buf, pattern!(0xDE, 0xAD, 0xBE, 0xEF));
    assert_eq!(searcher.next(), Some(496))
}
296
#[test]
fn test_scan_simd_fallback_offset() {
    // Same tail placement as the plain fallback test, plus a leading wildcard.
    const NEEDLE: [u8; 4] = [0xDE, 0xAD, 0xBE, 0xEF];

    let mut buf = vec![0u8; 500];
    let tail_start = buf.len() - NEEDLE.len();
    buf[tail_start..].copy_from_slice(&NEEDLE);

    let mut searcher = PatternSearcher::new(&buf, pattern!(_, 0xDE, 0xAD, 0xBE, 0xEF));
    assert_eq!(searcher.next(), Some(495))
}
311
#[test]
fn test_scan_wildcard() {
    // A wildcard in the middle of the pattern accepts whatever byte sits at offset 8.
    let mut buf = vec![0u8; 500];
    for (index, value) in [(6, 0xDE), (7, 0xAD), (9, 0xBE), (10, 0xEF)] {
        buf[index] = value;
    }

    let mut searcher = PatternSearcher::new(&buf, pattern!(0xDE, 0xAD, _, 0xBE, 0xEF));
    assert_eq!(searcher.next(), Some(6))
}
326
#[test]
fn test_scan_large_sig() {
    // A signature spanning several SIMD chunks: nine back-to-back copies of `DE AD ?? BE EF`
    // written into the buffer starting at offset 5.
    const UNIT: [Option<u8>; 5] = [Some(0xDE), Some(0xAD), None, Some(0xBE), Some(0xEF)];
    const REPEATS: usize = 9;

    let mut buf = vec![0u8; 500];
    for base in (0..REPEATS).map(|copy| 5 + copy * UNIT.len()) {
        for (k, entry) in UNIT.iter().enumerate() {
            if let Some(byte) = entry {
                buf[base + k] = *byte;
            }
        }
    }

    let signature: Vec<Option<u8>> = UNIT.repeat(REPEATS);
    let mut searcher = PatternSearcher::new(&buf, &signature);

    assert_eq!(searcher.next(), Some(5))
}