wraith/util/
simd.rs

1//! SIMD-accelerated pattern matching
2//!
3//! Uses SSE2/AVX2 for fast pattern scanning when available.
4//! Falls back to scalar implementation on unsupported platforms.
5
6#[cfg(target_arch = "x86_64")]
7use core::arch::x86_64::*;
8
9#[cfg(target_arch = "x86")]
10use core::arch::x86::*;
11
/// SIMD implementation selector
///
/// Variants are ordered from least to most capable; the scanner picks one
/// at construction via runtime feature detection.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SimdLevel {
    /// No SIMD available, use scalar fallback
    None,
    /// SSE2 available (128-bit registers, 16 bytes per compare step)
    Sse2,
    /// AVX2 available (256-bit registers, 32 bytes per compare step)
    Avx2,
}
22
23impl SimdLevel {
24    /// detect the best available SIMD level at runtime
25    #[inline]
26    pub fn detect() -> Self {
27        #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
28        {
29            if is_x86_feature_detected!("avx2") {
30                return SimdLevel::Avx2;
31            }
32            if is_x86_feature_detected!("sse2") {
33                return SimdLevel::Sse2;
34            }
35        }
36        SimdLevel::None
37    }
38}
39
/// SIMD-accelerated pattern scanner
///
/// Scans byte slices for a masked pattern. The first non-wildcard byte is
/// used as a fast SIMD anchor: it is broadcast into a vector register and
/// compared 16/32 bytes at a time, and the full pattern is only verified at
/// candidate offsets where that anchor byte occurs.
pub struct SimdScanner {
    /// pattern bytes (values at wildcard positions are ignored during matching)
    pattern: Vec<u8>,
    /// mask: true = wildcard (match any); expected to be the same length as `pattern`
    mask: Vec<bool>,
    /// first non-wildcard byte index (for SIMD skip); `None` when every position is a wildcard
    first_concrete_idx: Option<usize>,
    /// first non-wildcard byte value (0 when the pattern is all wildcards)
    first_concrete_byte: u8,
    /// detected SIMD level, chosen once at construction
    simd_level: SimdLevel,
}
53
54impl SimdScanner {
55    /// create a new SIMD scanner from pattern bytes and mask
56    pub fn new(pattern: Vec<u8>, mask: Vec<bool>) -> Self {
57        // find first non-wildcard byte for SIMD acceleration
58        let (first_concrete_idx, first_concrete_byte) = mask
59            .iter()
60            .enumerate()
61            .find(|(_, &is_wildcard)| !is_wildcard)
62            .map(|(i, _)| (Some(i), pattern[i]))
63            .unwrap_or((None, 0));
64
65        Self {
66            pattern,
67            mask,
68            first_concrete_idx,
69            first_concrete_byte,
70            simd_level: SimdLevel::detect(),
71        }
72    }
73
74    /// get pattern length
75    #[inline]
76    pub fn len(&self) -> usize {
77        self.pattern.len()
78    }
79
80    /// check if pattern is empty
81    #[inline]
82    pub fn is_empty(&self) -> bool {
83        self.pattern.is_empty()
84    }
85
86    /// scan data for pattern, returns offsets of all matches
87    pub fn scan(&self, data: &[u8]) -> Vec<usize> {
88        if self.pattern.is_empty() || data.len() < self.pattern.len() {
89            return Vec::new();
90        }
91
92        // if pattern is all wildcards, match everything
93        if self.first_concrete_idx.is_none() {
94            return (0..=data.len() - self.pattern.len()).collect();
95        }
96
97        match self.simd_level {
98            #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
99            SimdLevel::Avx2 => unsafe { self.scan_avx2(data) },
100            #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
101            SimdLevel::Sse2 => unsafe { self.scan_sse2(data) },
102            _ => self.scan_scalar(data),
103        }
104    }
105
106    /// scan data for first match only
107    pub fn scan_first(&self, data: &[u8]) -> Option<usize> {
108        if self.pattern.is_empty() || data.len() < self.pattern.len() {
109            return None;
110        }
111
112        if self.first_concrete_idx.is_none() {
113            return Some(0);
114        }
115
116        match self.simd_level {
117            #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
118            SimdLevel::Avx2 => unsafe { self.scan_first_avx2(data) },
119            #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
120            SimdLevel::Sse2 => unsafe { self.scan_first_sse2(data) },
121            _ => self.scan_first_scalar(data),
122        }
123    }
124
125    /// scalar fallback implementation
126    fn scan_scalar(&self, data: &[u8]) -> Vec<usize> {
127        let mut results = Vec::new();
128        let max_offset = data.len() - self.pattern.len();
129
130        for offset in 0..=max_offset {
131            if self.matches_at(data, offset) {
132                results.push(offset);
133            }
134        }
135
136        results
137    }
138
139    /// scalar first match
140    fn scan_first_scalar(&self, data: &[u8]) -> Option<usize> {
141        let max_offset = data.len() - self.pattern.len();
142
143        for offset in 0..=max_offset {
144            if self.matches_at(data, offset) {
145                return Some(offset);
146            }
147        }
148
149        None
150    }
151
152    /// check if pattern matches at offset
153    #[inline]
154    fn matches_at(&self, data: &[u8], offset: usize) -> bool {
155        self.pattern
156            .iter()
157            .zip(self.mask.iter())
158            .enumerate()
159            .all(|(i, (&pattern_byte, &is_wildcard))| {
160                is_wildcard || data[offset + i] == pattern_byte
161            })
162    }
163
164    /// AVX2 accelerated scan (256-bit, 32 bytes at a time)
165    #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
166    #[target_feature(enable = "avx2")]
167    unsafe fn scan_avx2(&self, data: &[u8]) -> Vec<usize> {
168        let mut results = Vec::new();
169        let pattern_len = self.pattern.len();
170        let first_idx = self.first_concrete_idx.unwrap();
171
172        if data.len() < pattern_len {
173            return results;
174        }
175
176        let max_offset = data.len() - pattern_len;
177
178        // SAFETY: avx2 is guaranteed available by target_feature
179        // broadcast the first concrete byte to all lanes
180        let needle = unsafe { _mm256_set1_epi8(self.first_concrete_byte as i8) };
181
182        // we search for first_concrete_byte, accounting for its position in pattern
183        // the first concrete byte can appear at positions first_idx through max_offset + first_idx
184        let search_start = first_idx;
185        let search_end = max_offset + first_idx + 1; // +1 for exclusive end
186
187        if search_end <= search_start {
188            return self.scan_scalar(data);
189        }
190
191        let mut pos = search_start;
192
193        // process 32 bytes at a time
194        while pos + 32 <= search_end {
195            // SAFETY: bounds checked, avx2 available
196            let chunk = unsafe { _mm256_loadu_si256(data.as_ptr().add(pos) as *const __m256i) };
197            let cmp = unsafe { _mm256_cmpeq_epi8(chunk, needle) };
198            let mut mask = unsafe { _mm256_movemask_epi8(cmp) } as u32;
199
200            while mask != 0 {
201                let bit_pos = mask.trailing_zeros() as usize;
202                let candidate_offset = pos + bit_pos - first_idx;
203
204                if candidate_offset <= max_offset && self.matches_at(data, candidate_offset) {
205                    results.push(candidate_offset);
206                }
207
208                mask &= mask - 1; // clear lowest set bit
209            }
210
211            pos += 32;
212        }
213
214        // handle remaining bytes with scalar
215        while pos < search_end {
216            if data[pos] == self.first_concrete_byte {
217                let candidate_offset = pos - first_idx;
218                if candidate_offset <= max_offset && self.matches_at(data, candidate_offset) {
219                    results.push(candidate_offset);
220                }
221            }
222            pos += 1;
223        }
224
225        results
226    }
227
228    /// AVX2 first match
229    #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
230    #[target_feature(enable = "avx2")]
231    unsafe fn scan_first_avx2(&self, data: &[u8]) -> Option<usize> {
232        let pattern_len = self.pattern.len();
233        let first_idx = self.first_concrete_idx.unwrap();
234
235        if data.len() < pattern_len {
236            return None;
237        }
238
239        let max_offset = data.len() - pattern_len;
240        // SAFETY: avx2 guaranteed by target_feature
241        let needle = unsafe { _mm256_set1_epi8(self.first_concrete_byte as i8) };
242
243        let search_start = first_idx;
244        let search_end = max_offset + first_idx + 1; // +1 for exclusive end
245
246        if search_end <= search_start {
247            return self.scan_first_scalar(data);
248        }
249
250        let mut pos = search_start;
251
252        while pos + 32 <= search_end {
253            // SAFETY: bounds checked, avx2 available
254            let chunk = unsafe { _mm256_loadu_si256(data.as_ptr().add(pos) as *const __m256i) };
255            let cmp = unsafe { _mm256_cmpeq_epi8(chunk, needle) };
256            let mut mask = unsafe { _mm256_movemask_epi8(cmp) } as u32;
257
258            while mask != 0 {
259                let bit_pos = mask.trailing_zeros() as usize;
260                let candidate_offset = pos + bit_pos - first_idx;
261
262                if candidate_offset <= max_offset && self.matches_at(data, candidate_offset) {
263                    return Some(candidate_offset);
264                }
265
266                mask &= mask - 1;
267            }
268
269            pos += 32;
270        }
271
272        // scalar remainder
273        while pos < search_end {
274            if data[pos] == self.first_concrete_byte {
275                let candidate_offset = pos - first_idx;
276                if candidate_offset <= max_offset && self.matches_at(data, candidate_offset) {
277                    return Some(candidate_offset);
278                }
279            }
280            pos += 1;
281        }
282
283        None
284    }
285
286    /// SSE2 accelerated scan (128-bit, 16 bytes at a time)
287    #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
288    #[target_feature(enable = "sse2")]
289    unsafe fn scan_sse2(&self, data: &[u8]) -> Vec<usize> {
290        let mut results = Vec::new();
291        let pattern_len = self.pattern.len();
292        let first_idx = self.first_concrete_idx.unwrap();
293
294        if data.len() < pattern_len {
295            return results;
296        }
297
298        let max_offset = data.len() - pattern_len;
299        // SAFETY: sse2 guaranteed by target_feature
300        let needle = unsafe { _mm_set1_epi8(self.first_concrete_byte as i8) };
301
302        let search_start = first_idx;
303        let search_end = max_offset + first_idx + 1; // +1 for exclusive end
304
305        if search_end <= search_start {
306            return self.scan_scalar(data);
307        }
308
309        let mut pos = search_start;
310
311        // process 16 bytes at a time
312        while pos + 16 <= search_end {
313            // SAFETY: bounds checked, sse2 available
314            let chunk = unsafe { _mm_loadu_si128(data.as_ptr().add(pos) as *const __m128i) };
315            let cmp = unsafe { _mm_cmpeq_epi8(chunk, needle) };
316            let mut mask = unsafe { _mm_movemask_epi8(cmp) } as u16;
317
318            while mask != 0 {
319                let bit_pos = mask.trailing_zeros() as usize;
320                let candidate_offset = pos + bit_pos - first_idx;
321
322                if candidate_offset <= max_offset && self.matches_at(data, candidate_offset) {
323                    results.push(candidate_offset);
324                }
325
326                mask &= mask - 1;
327            }
328
329            pos += 16;
330        }
331
332        // scalar remainder
333        while pos < search_end {
334            if data[pos] == self.first_concrete_byte {
335                let candidate_offset = pos - first_idx;
336                if candidate_offset <= max_offset && self.matches_at(data, candidate_offset) {
337                    results.push(candidate_offset);
338                }
339            }
340            pos += 1;
341        }
342
343        results
344    }
345
346    /// SSE2 first match
347    #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
348    #[target_feature(enable = "sse2")]
349    unsafe fn scan_first_sse2(&self, data: &[u8]) -> Option<usize> {
350        let pattern_len = self.pattern.len();
351        let first_idx = self.first_concrete_idx.unwrap();
352
353        if data.len() < pattern_len {
354            return None;
355        }
356
357        let max_offset = data.len() - pattern_len;
358        // SAFETY: sse2 guaranteed by target_feature
359        let needle = unsafe { _mm_set1_epi8(self.first_concrete_byte as i8) };
360
361        let search_start = first_idx;
362        let search_end = max_offset + first_idx + 1; // +1 for exclusive end
363
364        if search_end <= search_start {
365            return self.scan_first_scalar(data);
366        }
367
368        let mut pos = search_start;
369
370        while pos + 16 <= search_end {
371            // SAFETY: bounds checked, sse2 available
372            let chunk = unsafe { _mm_loadu_si128(data.as_ptr().add(pos) as *const __m128i) };
373            let cmp = unsafe { _mm_cmpeq_epi8(chunk, needle) };
374            let mut mask = unsafe { _mm_movemask_epi8(cmp) } as u16;
375
376            while mask != 0 {
377                let bit_pos = mask.trailing_zeros() as usize;
378                let candidate_offset = pos + bit_pos - first_idx;
379
380                if candidate_offset <= max_offset && self.matches_at(data, candidate_offset) {
381                    return Some(candidate_offset);
382                }
383
384                mask &= mask - 1;
385            }
386
387            pos += 16;
388        }
389
390        // scalar remainder
391        while pos < search_end {
392            if data[pos] == self.first_concrete_byte {
393                let candidate_offset = pos - first_idx;
394                if candidate_offset <= max_offset && self.matches_at(data, candidate_offset) {
395                    return Some(candidate_offset);
396                }
397            }
398            pos += 1;
399        }
400
401        None
402    }
403}
404
405/// scan data for pattern using SIMD when available
406///
407/// pattern format: space-separated hex bytes, `?` or `??` for wildcards
408pub fn simd_scan(data: &[u8], pattern: &str) -> Vec<usize> {
409    let (bytes, mask) = match parse_pattern(pattern) {
410        Some(p) => p,
411        None => return Vec::new(),
412    };
413
414    let scanner = SimdScanner::new(bytes, mask);
415    scanner.scan(data)
416}
417
418/// scan data for first pattern match using SIMD
419pub fn simd_scan_first(data: &[u8], pattern: &str) -> Option<usize> {
420    let (bytes, mask) = parse_pattern(pattern)?;
421    let scanner = SimdScanner::new(bytes, mask);
422    scanner.scan_first(data)
423}
424
/// Parse an IDA-style pattern into bytes and a wildcard mask.
///
/// Each whitespace-separated token is either a hex byte (e.g. `48`, one or
/// two digits, case-insensitive) or a wildcard (`?`, `??`, `*`, `**`) that
/// matches any byte. Wildcards store a placeholder `0` with `mask = true`.
///
/// Returns `None` for an empty/whitespace-only pattern or any malformed
/// token. Tokens with a sign prefix (e.g. `+48`) are rejected explicitly,
/// since `u8::from_str_radix` would otherwise accept a leading `+`.
fn parse_pattern(pattern: &str) -> Option<(Vec<u8>, Vec<bool>)> {
    let mut bytes = Vec::new();
    let mut mask = Vec::new();

    // split_whitespace handles leading/trailing/repeated whitespace, so no
    // separate trim or emptiness pre-check is needed
    for part in pattern.split_whitespace() {
        match part {
            "?" | "??" | "*" | "**" => {
                bytes.push(0);
                mask.push(true);
            }
            _ => {
                // require pure hex digits: from_str_radix alone would accept
                // a leading '+', which is not valid pattern syntax
                if !part.bytes().all(|b| b.is_ascii_hexdigit()) {
                    return None;
                }
                bytes.push(u8::from_str_radix(part, 16).ok()?);
                mask.push(false);
            }
        }
    }

    if bytes.is_empty() {
        None
    } else {
        Some((bytes, mask))
    }
}
453
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_simd_level_detect() {
        // detection must never panic, whatever the host CPU supports
        println!("Detected SIMD level: {:?}", SimdLevel::detect());
    }

    #[test]
    fn test_simd_scan_simple() {
        let haystack = [0x48, 0x8B, 0x05, 0x12, 0x34, 0x56, 0x78, 0x90];
        assert_eq!(simd_scan(&haystack, "48 8B 05"), vec![0]);
    }

    #[test]
    fn test_simd_scan_wildcards() {
        let haystack = [0x48, 0x8B, 0x05, 0x12, 0x34, 0x56, 0x78, 0x90];
        assert_eq!(simd_scan(&haystack, "48 8B ?? ?? 34"), vec![0]);
    }

    #[test]
    fn test_simd_scan_no_match() {
        let haystack = [0x48, 0x8B, 0x05, 0x12, 0x34, 0x56, 0x78, 0x90];
        assert!(simd_scan(&haystack, "FF FF FF").is_empty());
    }

    #[test]
    fn test_simd_scan_multiple_matches() {
        let haystack = [0x48, 0x8B, 0x48, 0x8B, 0x48, 0x8B];
        assert_eq!(simd_scan(&haystack, "48 8B"), vec![0, 2, 4]);
    }

    #[test]
    fn test_simd_scan_first() {
        let haystack = [0x48, 0x8B, 0x48, 0x8B, 0x48, 0x8B];
        assert_eq!(simd_scan_first(&haystack, "48 8B"), Some(0));
    }

    #[test]
    fn test_simd_scan_large_data() {
        // exercise the vectorized path with data wider than any SIMD register
        let mut haystack = vec![0u8; 1024];
        haystack[500..504].copy_from_slice(&[0xDE, 0xAD, 0xBE, 0xEF]);
        assert_eq!(simd_scan(&haystack, "DE AD BE EF"), vec![500]);
    }

    #[test]
    fn test_simd_scan_pattern_at_end() {
        // match flush against the end of the buffer
        let haystack = [0x00, 0x00, 0x00, 0x48, 0x8B];
        assert_eq!(simd_scan(&haystack, "48 8B"), vec![3]);
    }

    #[test]
    fn test_simd_scan_wildcard_first() {
        // leading wildcard shifts the concrete-byte anchor off offset zero
        let haystack = [0x48, 0x8B, 0x05, 0x12, 0x34];
        assert_eq!(simd_scan(&haystack, "?? 8B 05"), vec![0]);
    }
}