simd_bmh/
lib.rs

1#![no_std]
2#![feature(portable_simd)]
3#![allow(dead_code)]
4#![allow(unused_imports)]
5#![feature(maybe_uninit_uninit_array)]
6
7extern crate alloc;
8use crate::alloc::vec::Vec;
9
10use core::simd::{Simd, cmp::SimdPartialEq, LaneCount, SupportedLaneCount};
11use core::ops::{BitAnd, BitAndAssign};
12use core::arch::x86_64::*;
13use log::debug;
14use simd_bmh_macro::parse_pattern;
15
/// A fixed-length byte pattern with per-byte bit masks, plus the data the
/// search routine uses to locate candidates quickly.
///
/// Only the bits set in `masks[k]` take part in the comparison at position
/// `k`; e.g. the pattern string `"A?C?FF"` is expected to yield
/// `bytes == [0xA0, 0xC0, 0xFF]` and `masks == [0xF0, 0xF0, 0xFF]`.
///
/// The struct is aligned to 32 bytes.
#[derive(Clone, Debug)]
#[repr(align(32))]
pub struct Pattern<const N: usize> {
    /// Pattern byte values; compared against the text under `masks`.
    pub bytes: [u8; N],
    /// Per-byte bit masks: a text byte `t` matches position `k` when
    /// `t & masks[k] == bytes[k] & masks[k]`.
    pub masks: [u8; N],
    /// Value the SIMD prefilter compares `text_byte & best_skip_mask` against.
    pub best_skip_value: u8,
    /// Mask applied to every text byte by the SIMD prefilter before the
    /// comparison with `best_skip_value`.
    pub best_skip_mask: u8,
    /// Not read by the search code in this file.
    // NOTE(review): presumably the largest entry of `shift_table` — confirm
    // against the `parse_pattern!` macro.
    pub max_skip: usize,
    /// Offset, within the pattern, of the byte the prefilter looks for; a
    /// prefilter hit at text position `p` is a candidate starting at
    /// `p - best_skip_offset`.
    pub best_skip_offset: usize,
    /// Shift distances indexed by byte value.
    // NOTE(review): looks like a Horspool-style bad-character table — confirm
    // how `parse_pattern!` fills it for masked (wildcard) positions.
    pub shift_table: [usize; 256],
}
27
28impl<const N: usize> Pattern<N> {
29    #[inline(always)]
30    pub fn find_all_matches(&self, text: &[u8]) -> Vec<usize> {
31        find_all_matches_sse::<N>(text, self)
32    }
33}
34
35pub fn find_all_matches_sse<const PATTERN_LEN: usize>(text: &[u8], pattern: &Pattern<PATTERN_LEN>) -> Vec<usize> {
36    if PATTERN_LEN > text.len() {
37        return Vec::new();
38    }
39
40    let mut matches = Vec::new();
41    let mut i = 0;
42
43    let best_skip = pattern.best_skip_value as i32;
44    let best_mask = pattern.best_skip_mask as i32;
45    let best_skip_offset = pattern.best_skip_offset as i32;
46
47    unsafe {
48        let skip_vector = _mm_set1_epi8(best_skip as i8);
49        let mask_vector = _mm_set1_epi8(best_mask as i8);
50
51        while i + 16 <= text.len() {
52            let mut match_masks = _mm_setzero_si128();
53            let chunk = _mm_loadu_si128(text.as_ptr().add(i) as *const __m128i);
54            let masked_chunk = _mm_and_si128(chunk, mask_vector);
55            let cmp_result = _mm_cmpeq_epi8(masked_chunk, skip_vector);
56            match_masks = _mm_or_si128(match_masks, cmp_result);
57
58            let match_positions = _mm_movemask_epi8(match_masks);
59            if match_positions != 0 {
60                for pos in 0..16 {
61                    if (match_positions & (1 << pos)) != 0 {
62                        let match_pos = i + pos;
63                        let start_pos = match_pos - best_skip_offset as usize;
64                        
65                        let mut valid = true;
66                        for k in 0..PATTERN_LEN {
67                            let pattern_byte = pattern.bytes[k];
68                            let pattern_mask = pattern.masks[k];
69                            let text_index = start_pos + k;
70
71                            let masked_pattern_byte = pattern_byte & pattern_mask;
72                            let masked_text_byte = text[text_index] & pattern_mask;
73                            if masked_text_byte != masked_pattern_byte {
74                                valid = false;
75                                break;
76                            }
77                        }
78
79                        if valid {
80                            matches.push(start_pos);
81                        }
82                    }
83                }
84            }
85
86            i += 16;
87        }
88    }
89
90    while i + PATTERN_LEN <= text.len() {
91        let start_pos = i;
92        let mut match_found = true;
93
94        for k in 0..PATTERN_LEN {
95            let pattern_byte = pattern.bytes[k];
96            let pattern_mask = pattern.masks[k];
97            let text_index = start_pos + k;
98
99            let masked_pattern_byte = pattern_byte & pattern_mask;
100            let masked_text_byte = text[text_index] & pattern_mask;
101
102            if masked_text_byte != masked_pattern_byte {
103                match_found = false;
104                break;
105            }
106        }
107
108        if match_found {
109            matches.push(start_pos);
110            i += PATTERN_LEN;
111        } else {
112            i += pattern.shift_table[text[start_pos + PATTERN_LEN - 1] as usize];
113        }
114    }
115
116    matches
117}
118
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec;
    use rand::Rng;

    /// `?` is a nibble wildcard: the wildcard nibble is zero in both the
    /// byte value and the mask.
    #[test]
    fn test_parse_pattern() {
        let pattern = parse_pattern!("A?C?FF");
        assert_eq!(&pattern.bytes[..], &[0xA0, 0xC0, 0xFF]);
        assert_eq!(&pattern.masks[..], &[0xF0, 0xF0, 0xFF]);
    }

    /// Text shorter than one SIMD chunk: two occurrences separated by one byte.
    #[test]
    fn test_match() {
        let pattern = parse_pattern!("A?C?FF");
        let text = b"\xA0\xC0\xFF\x00\xA0\xC0\xFF";

        let matches = pattern.find_all_matches(text);
        assert_eq!(matches, [0, 4]);
    }

    /// Plants one known occurrence in a random buffer and checks that exactly
    /// that offset is reported (random data may add further hits, so only
    /// membership is asserted).
    #[test]
    fn test_random_pool_with_fixed_pattern() {
        let buffer_size = 2_000;
        // One RNG handle for the whole buffer instead of one per byte.
        let mut rng = rand::rng();
        let mut random_buffer: Vec<u8> = (0..buffer_size).map(|_| rng.random()).collect();
        random_buffer[1337..1342].copy_from_slice(b"\xAA\xCC\xFF\xFF\xFF");

        let pattern = parse_pattern!("A?C?FF");
        let matches = find_all_matches_sse(&random_buffer, &pattern);
        assert!(
            matches.contains(&1337),
            "Planted pattern at offset 1337 should be reported!"
        );
    }
}