// wraith/util/simd.rs

1//! SIMD-accelerated pattern matching
2//!
3//! Uses SSE2/AVX2 for fast pattern scanning when available.
4//! Falls back to scalar implementation on unsupported platforms.
5//!
6//! In `no_std` mode, runtime SIMD detection is disabled and defaults to
7//! scalar implementation. Use target features to enable SIMD at compile time.
8
9#[cfg(all(not(feature = "std"), feature = "alloc"))]
10use alloc::vec::Vec;
11
12#[cfg(feature = "std")]
13use std::vec::Vec;
14
15#[cfg(target_arch = "x86_64")]
16use core::arch::x86_64::*;
17
18#[cfg(target_arch = "x86")]
19use core::arch::x86::*;
20
/// SIMD implementation selector
///
/// Ordered from least to most capable; produced by `SimdLevel::detect()`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SimdLevel {
    /// No SIMD available, use scalar fallback path
    None,
    /// SSE2 available (128-bit registers, 16 bytes per step)
    Sse2,
    /// AVX2 available (256-bit registers, 32 bytes per step)
    Avx2,
}
31
impl SimdLevel {
    /// detect the best available SIMD level at runtime
    ///
    /// in `no_std` mode, this uses compile-time target features instead of runtime detection.
    /// enable `target-feature=+avx2` or `target-feature=+sse2` at compile time for SIMD acceleration.
    #[inline]
    pub fn detect() -> Self {
        // std build on x86/x86_64: runtime CPUID-based detection.
        // prefer the widest vector width first.
        #[cfg(all(feature = "std", any(target_arch = "x86_64", target_arch = "x86")))]
        {
            if is_x86_feature_detected!("avx2") {
                return SimdLevel::Avx2;
            }
            if is_x86_feature_detected!("sse2") {
                return SimdLevel::Sse2;
            }
        }

        // in no_std mode, use compile-time feature detection
        #[cfg(all(not(feature = "std"), any(target_arch = "x86_64", target_arch = "x86")))]
        {
            #[cfg(target_feature = "avx2")]
            {
                return SimdLevel::Avx2;
            }
            // only compiled when avx2 was NOT enabled, so the two returns
            // can never both be present in one build
            #[cfg(all(not(target_feature = "avx2"), target_feature = "sse2"))]
            {
                return SimdLevel::Sse2;
            }
        }

        // non-x86 targets, or x86 without any usable SIMD support
        SimdLevel::None
    }
}
65
/// SIMD-accelerated pattern scanner
///
/// Precomputes the first non-wildcard byte of the pattern so the SIMD paths
/// can broadcast-compare just that byte, and only run the full pattern check
/// at candidate positions.
pub struct SimdScanner {
    /// pattern bytes (wildcard positions hold a placeholder value)
    pattern: Vec<u8>,
    /// mask: true = wildcard (match any)
    mask: Vec<bool>,
    /// first non-wildcard byte index (for SIMD skip); None when the
    /// pattern is entirely wildcards
    first_concrete_idx: Option<usize>,
    /// first non-wildcard byte value (0 when no concrete byte exists)
    first_concrete_byte: u8,
    /// detected SIMD level, chosen once at construction
    simd_level: SimdLevel,
}
79
80impl SimdScanner {
81    /// create a new SIMD scanner from pattern bytes and mask
82    pub fn new(pattern: Vec<u8>, mask: Vec<bool>) -> Self {
83        // find first non-wildcard byte for SIMD acceleration
84        let (first_concrete_idx, first_concrete_byte) = mask
85            .iter()
86            .enumerate()
87            .find(|(_, &is_wildcard)| !is_wildcard)
88            .map(|(i, _)| (Some(i), pattern[i]))
89            .unwrap_or((None, 0));
90
91        Self {
92            pattern,
93            mask,
94            first_concrete_idx,
95            first_concrete_byte,
96            simd_level: SimdLevel::detect(),
97        }
98    }
99
100    /// get pattern length
101    #[inline]
102    pub fn len(&self) -> usize {
103        self.pattern.len()
104    }
105
106    /// check if pattern is empty
107    #[inline]
108    pub fn is_empty(&self) -> bool {
109        self.pattern.is_empty()
110    }
111
112    /// scan data for pattern, returns offsets of all matches
113    pub fn scan(&self, data: &[u8]) -> Vec<usize> {
114        if self.pattern.is_empty() || data.len() < self.pattern.len() {
115            return Vec::new();
116        }
117
118        // if pattern is all wildcards, match everything
119        if self.first_concrete_idx.is_none() {
120            return (0..=data.len() - self.pattern.len()).collect();
121        }
122
123        match self.simd_level {
124            #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
125            SimdLevel::Avx2 => unsafe { self.scan_avx2(data) },
126            #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
127            SimdLevel::Sse2 => unsafe { self.scan_sse2(data) },
128            _ => self.scan_scalar(data),
129        }
130    }
131
132    /// scan data for first match only
133    pub fn scan_first(&self, data: &[u8]) -> Option<usize> {
134        if self.pattern.is_empty() || data.len() < self.pattern.len() {
135            return None;
136        }
137
138        if self.first_concrete_idx.is_none() {
139            return Some(0);
140        }
141
142        match self.simd_level {
143            #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
144            SimdLevel::Avx2 => unsafe { self.scan_first_avx2(data) },
145            #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
146            SimdLevel::Sse2 => unsafe { self.scan_first_sse2(data) },
147            _ => self.scan_first_scalar(data),
148        }
149    }
150
151    /// scalar fallback implementation
152    fn scan_scalar(&self, data: &[u8]) -> Vec<usize> {
153        let mut results = Vec::new();
154        let max_offset = data.len() - self.pattern.len();
155
156        for offset in 0..=max_offset {
157            if self.matches_at(data, offset) {
158                results.push(offset);
159            }
160        }
161
162        results
163    }
164
165    /// scalar first match
166    fn scan_first_scalar(&self, data: &[u8]) -> Option<usize> {
167        let max_offset = data.len() - self.pattern.len();
168
169        for offset in 0..=max_offset {
170            if self.matches_at(data, offset) {
171                return Some(offset);
172            }
173        }
174
175        None
176    }
177
178    /// check if pattern matches at offset
179    #[inline]
180    fn matches_at(&self, data: &[u8], offset: usize) -> bool {
181        self.pattern
182            .iter()
183            .zip(self.mask.iter())
184            .enumerate()
185            .all(|(i, (&pattern_byte, &is_wildcard))| {
186                is_wildcard || data[offset + i] == pattern_byte
187            })
188    }
189
190    /// AVX2 accelerated scan (256-bit, 32 bytes at a time)
191    #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
192    #[target_feature(enable = "avx2")]
193    unsafe fn scan_avx2(&self, data: &[u8]) -> Vec<usize> {
194        let mut results = Vec::new();
195        let pattern_len = self.pattern.len();
196        let first_idx = self.first_concrete_idx.unwrap();
197
198        if data.len() < pattern_len {
199            return results;
200        }
201
202        let max_offset = data.len() - pattern_len;
203
204        // SAFETY: avx2 is guaranteed available by target_feature
205        // broadcast the first concrete byte to all lanes
206        let needle = unsafe { _mm256_set1_epi8(self.first_concrete_byte as i8) };
207
208        // we search for first_concrete_byte, accounting for its position in pattern
209        // the first concrete byte can appear at positions first_idx through max_offset + first_idx
210        let search_start = first_idx;
211        let search_end = max_offset + first_idx + 1; // +1 for exclusive end
212
213        if search_end <= search_start {
214            return self.scan_scalar(data);
215        }
216
217        let mut pos = search_start;
218
219        // process 32 bytes at a time
220        while pos + 32 <= search_end {
221            // SAFETY: bounds checked, avx2 available
222            let chunk = unsafe { _mm256_loadu_si256(data.as_ptr().add(pos) as *const __m256i) };
223            let cmp = unsafe { _mm256_cmpeq_epi8(chunk, needle) };
224            let mut mask = unsafe { _mm256_movemask_epi8(cmp) } as u32;
225
226            while mask != 0 {
227                let bit_pos = mask.trailing_zeros() as usize;
228                let candidate_offset = pos + bit_pos - first_idx;
229
230                if candidate_offset <= max_offset && self.matches_at(data, candidate_offset) {
231                    results.push(candidate_offset);
232                }
233
234                mask &= mask - 1; // clear lowest set bit
235            }
236
237            pos += 32;
238        }
239
240        // handle remaining bytes with scalar
241        while pos < search_end {
242            if data[pos] == self.first_concrete_byte {
243                let candidate_offset = pos - first_idx;
244                if candidate_offset <= max_offset && self.matches_at(data, candidate_offset) {
245                    results.push(candidate_offset);
246                }
247            }
248            pos += 1;
249        }
250
251        results
252    }
253
254    /// AVX2 first match
255    #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
256    #[target_feature(enable = "avx2")]
257    unsafe fn scan_first_avx2(&self, data: &[u8]) -> Option<usize> {
258        let pattern_len = self.pattern.len();
259        let first_idx = self.first_concrete_idx.unwrap();
260
261        if data.len() < pattern_len {
262            return None;
263        }
264
265        let max_offset = data.len() - pattern_len;
266        // SAFETY: avx2 guaranteed by target_feature
267        let needle = unsafe { _mm256_set1_epi8(self.first_concrete_byte as i8) };
268
269        let search_start = first_idx;
270        let search_end = max_offset + first_idx + 1; // +1 for exclusive end
271
272        if search_end <= search_start {
273            return self.scan_first_scalar(data);
274        }
275
276        let mut pos = search_start;
277
278        while pos + 32 <= search_end {
279            // SAFETY: bounds checked, avx2 available
280            let chunk = unsafe { _mm256_loadu_si256(data.as_ptr().add(pos) as *const __m256i) };
281            let cmp = unsafe { _mm256_cmpeq_epi8(chunk, needle) };
282            let mut mask = unsafe { _mm256_movemask_epi8(cmp) } as u32;
283
284            while mask != 0 {
285                let bit_pos = mask.trailing_zeros() as usize;
286                let candidate_offset = pos + bit_pos - first_idx;
287
288                if candidate_offset <= max_offset && self.matches_at(data, candidate_offset) {
289                    return Some(candidate_offset);
290                }
291
292                mask &= mask - 1;
293            }
294
295            pos += 32;
296        }
297
298        // scalar remainder
299        while pos < search_end {
300            if data[pos] == self.first_concrete_byte {
301                let candidate_offset = pos - first_idx;
302                if candidate_offset <= max_offset && self.matches_at(data, candidate_offset) {
303                    return Some(candidate_offset);
304                }
305            }
306            pos += 1;
307        }
308
309        None
310    }
311
312    /// SSE2 accelerated scan (128-bit, 16 bytes at a time)
313    #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
314    #[target_feature(enable = "sse2")]
315    unsafe fn scan_sse2(&self, data: &[u8]) -> Vec<usize> {
316        let mut results = Vec::new();
317        let pattern_len = self.pattern.len();
318        let first_idx = self.first_concrete_idx.unwrap();
319
320        if data.len() < pattern_len {
321            return results;
322        }
323
324        let max_offset = data.len() - pattern_len;
325        // SAFETY: sse2 guaranteed by target_feature
326        let needle = unsafe { _mm_set1_epi8(self.first_concrete_byte as i8) };
327
328        let search_start = first_idx;
329        let search_end = max_offset + first_idx + 1; // +1 for exclusive end
330
331        if search_end <= search_start {
332            return self.scan_scalar(data);
333        }
334
335        let mut pos = search_start;
336
337        // process 16 bytes at a time
338        while pos + 16 <= search_end {
339            // SAFETY: bounds checked, sse2 available
340            let chunk = unsafe { _mm_loadu_si128(data.as_ptr().add(pos) as *const __m128i) };
341            let cmp = unsafe { _mm_cmpeq_epi8(chunk, needle) };
342            let mut mask = unsafe { _mm_movemask_epi8(cmp) } as u16;
343
344            while mask != 0 {
345                let bit_pos = mask.trailing_zeros() as usize;
346                let candidate_offset = pos + bit_pos - first_idx;
347
348                if candidate_offset <= max_offset && self.matches_at(data, candidate_offset) {
349                    results.push(candidate_offset);
350                }
351
352                mask &= mask - 1;
353            }
354
355            pos += 16;
356        }
357
358        // scalar remainder
359        while pos < search_end {
360            if data[pos] == self.first_concrete_byte {
361                let candidate_offset = pos - first_idx;
362                if candidate_offset <= max_offset && self.matches_at(data, candidate_offset) {
363                    results.push(candidate_offset);
364                }
365            }
366            pos += 1;
367        }
368
369        results
370    }
371
372    /// SSE2 first match
373    #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
374    #[target_feature(enable = "sse2")]
375    unsafe fn scan_first_sse2(&self, data: &[u8]) -> Option<usize> {
376        let pattern_len = self.pattern.len();
377        let first_idx = self.first_concrete_idx.unwrap();
378
379        if data.len() < pattern_len {
380            return None;
381        }
382
383        let max_offset = data.len() - pattern_len;
384        // SAFETY: sse2 guaranteed by target_feature
385        let needle = unsafe { _mm_set1_epi8(self.first_concrete_byte as i8) };
386
387        let search_start = first_idx;
388        let search_end = max_offset + first_idx + 1; // +1 for exclusive end
389
390        if search_end <= search_start {
391            return self.scan_first_scalar(data);
392        }
393
394        let mut pos = search_start;
395
396        while pos + 16 <= search_end {
397            // SAFETY: bounds checked, sse2 available
398            let chunk = unsafe { _mm_loadu_si128(data.as_ptr().add(pos) as *const __m128i) };
399            let cmp = unsafe { _mm_cmpeq_epi8(chunk, needle) };
400            let mut mask = unsafe { _mm_movemask_epi8(cmp) } as u16;
401
402            while mask != 0 {
403                let bit_pos = mask.trailing_zeros() as usize;
404                let candidate_offset = pos + bit_pos - first_idx;
405
406                if candidate_offset <= max_offset && self.matches_at(data, candidate_offset) {
407                    return Some(candidate_offset);
408                }
409
410                mask &= mask - 1;
411            }
412
413            pos += 16;
414        }
415
416        // scalar remainder
417        while pos < search_end {
418            if data[pos] == self.first_concrete_byte {
419                let candidate_offset = pos - first_idx;
420                if candidate_offset <= max_offset && self.matches_at(data, candidate_offset) {
421                    return Some(candidate_offset);
422                }
423            }
424            pos += 1;
425        }
426
427        None
428    }
429}
430
431/// scan data for pattern using SIMD when available
432///
433/// pattern format: space-separated hex bytes, `?` or `??` for wildcards
434pub fn simd_scan(data: &[u8], pattern: &str) -> Vec<usize> {
435    let (bytes, mask) = match parse_pattern(pattern) {
436        Some(p) => p,
437        None => return Vec::new(),
438    };
439
440    let scanner = SimdScanner::new(bytes, mask);
441    scanner.scan(data)
442}
443
444/// scan data for first pattern match using SIMD
445pub fn simd_scan_first(data: &[u8], pattern: &str) -> Option<usize> {
446    let (bytes, mask) = parse_pattern(pattern)?;
447    let scanner = SimdScanner::new(bytes, mask);
448    scanner.scan_first(data)
449}
450
/// parse IDA-style pattern into bytes and mask
///
/// returns `None` for an empty/whitespace-only pattern or for any token that
/// is neither a wildcard (`?`, `??`, `*`, `**`) nor a valid hex byte.
fn parse_pattern(pattern: &str) -> Option<(Vec<u8>, Vec<bool>)> {
    let mut bytes = Vec::new();
    let mut mask = Vec::new();

    // split_whitespace ignores leading/trailing whitespace, so no trim needed
    for token in pattern.split_whitespace() {
        match token {
            "?" | "??" | "*" | "**" => {
                // wildcard slot: the byte value is never compared
                bytes.push(0);
                mask.push(true);
            }
            hex => {
                bytes.push(u8::from_str_radix(hex, 16).ok()?);
                mask.push(false);
            }
        }
    }

    // empty input produces no tokens at all
    if bytes.is_empty() {
        None
    } else {
        Some((bytes, mask))
    }
}
479
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_simd_level_detect() {
        // smoke test: detection must never panic, whatever the host CPU
        println!("Detected SIMD level: {:?}", SimdLevel::detect());
    }

    #[test]
    fn test_simd_scan_simple() {
        let haystack = [0x48, 0x8B, 0x05, 0x12, 0x34, 0x56, 0x78, 0x90];
        assert_eq!(simd_scan(&haystack, "48 8B 05"), vec![0]);
    }

    #[test]
    fn test_simd_scan_wildcards() {
        let haystack = [0x48, 0x8B, 0x05, 0x12, 0x34, 0x56, 0x78, 0x90];
        assert_eq!(simd_scan(&haystack, "48 8B ?? ?? 34"), vec![0]);
    }

    #[test]
    fn test_simd_scan_no_match() {
        let haystack = [0x48, 0x8B, 0x05, 0x12, 0x34, 0x56, 0x78, 0x90];
        assert!(simd_scan(&haystack, "FF FF FF").is_empty());
    }

    #[test]
    fn test_simd_scan_multiple_matches() {
        // repeated two-byte pattern: matches at every even offset
        let haystack = [0x48, 0x8B, 0x48, 0x8B, 0x48, 0x8B];
        assert_eq!(simd_scan(&haystack, "48 8B"), vec![0, 2, 4]);
    }

    #[test]
    fn test_simd_scan_first() {
        let haystack = [0x48, 0x8B, 0x48, 0x8B, 0x48, 0x8B];
        assert_eq!(simd_scan_first(&haystack, "48 8B"), Some(0));
    }

    #[test]
    fn test_simd_scan_large_data() {
        // buffer larger than any SIMD register width, match far from the start
        let mut haystack = vec![0u8; 1024];
        haystack[500..504].copy_from_slice(&[0xDE, 0xAD, 0xBE, 0xEF]);
        assert_eq!(simd_scan(&haystack, "DE AD BE EF"), vec![500]);
    }

    #[test]
    fn test_simd_scan_pattern_at_end() {
        // match flush against the end of the buffer
        let haystack = [0x00, 0x00, 0x00, 0x48, 0x8B];
        assert_eq!(simd_scan(&haystack, "48 8B"), vec![3]);
    }

    #[test]
    fn test_simd_scan_wildcard_first() {
        // leading wildcard forces the first concrete byte to index 1,
        // exercising the offset-adjustment edge case
        let haystack = [0x48, 0x8B, 0x05, 0x12, 0x34];
        assert_eq!(simd_scan(&haystack, "?? 8B 05"), vec![0]);
    }
}