Skip to main content

fhp_simd/
sse42.rs

1//! SSE4.2 accelerated operations (128-bit, x86_64).
2//!
3//! Each function is marked `#[target_feature(enable = "sse4.2")]` so the
4//! compiler generates SSE4.2 instructions without requiring the entire
5//! crate to be compiled with `-C target-feature=+sse4.2`.
6
7#[cfg(target_arch = "x86_64")]
8use core::arch::x86_64::*;
9
10use crate::{DelimiterResult, classify_byte};
11
12/// Scan `haystack` for the first HTML delimiter using SSE4.2 (128-bit).
13///
14/// Processes 16 bytes at a time. Falls back to scalar for the tail.
15///
16/// # Safety
17///
18/// Caller must ensure the CPU supports SSE4.2 (`is_x86_feature_detected!("sse4.2")`).
19#[target_feature(enable = "sse4.2")]
20#[cfg(target_arch = "x86_64")]
21pub unsafe fn find_delimiters(haystack: &[u8]) -> DelimiterResult {
22    let len = haystack.len();
23    let ptr = haystack.as_ptr();
24    let mut offset = 0;
25
26    // SAFETY: all intrinsics below require SSE4.2, guaranteed by #[target_feature].
27    unsafe {
28        // Load all 7 delimiter bytes into a single 128-bit register.
29        // _mm_cmpistri can compare against up to 16 needle bytes.
30        let delims = _mm_set_epi8(
31            0,
32            0,
33            0,
34            0,
35            0,
36            0,
37            0,
38            0,
39            0, // padding (unused slots)
40            b'/' as i8,
41            b'=' as i8,
42            b'\'' as i8,
43            b'"' as i8,
44            b'&' as i8,
45            b'>' as i8,
46            b'<' as i8,
47        );
48
49        while offset + 16 <= len {
50            // Load 16 bytes from haystack (unaligned).
51            let chunk = _mm_loadu_si128(ptr.add(offset) as *const __m128i);
52
53            // _mm_cmpistri: compare each byte in `chunk` against the set of
54            // bytes in `delims`. Returns the index of the *first* matching
55            // byte (0..15), or 16 if no match.
56            // Mode 0x00 = SIDD_UBYTE_OPS | SIDD_CMP_EQUAL_ANY | SIDD_LEAST_SIGNIFICANT
57            let idx = _mm_cmpistri(delims, chunk, 0x00);
58            if idx < 16 {
59                let pos = offset + idx as usize;
60                return DelimiterResult::Found {
61                    pos,
62                    byte: *ptr.add(pos),
63                };
64            }
65            offset += 16;
66        }
67    }
68
69    // Scalar tail for remaining < 16 bytes.
70    crate::scalar::find_delimiters_safe(&haystack[offset..]).offset_by(offset)
71}
72
73/// Classify each byte using SSE4.2 SIMD — 16 bytes at a time.
74///
75/// # Safety
76///
77/// Caller must ensure SSE4.2 support.
78#[target_feature(enable = "sse4.2")]
79#[cfg(target_arch = "x86_64")]
80pub unsafe fn classify_bytes(input: &[u8]) -> Vec<u8> {
81    let len = input.len();
82    let mut result = Vec::with_capacity(len);
83    let ptr = input.as_ptr();
84    let out_ptr: *mut u8 = result.as_mut_ptr();
85    let mut offset = 0;
86
87    // SAFETY: all intrinsics below require SSE4.2, guaranteed by #[target_feature].
88    // Pointer arithmetic is valid because offset < len and result has capacity >= len.
89    unsafe {
90        while offset + 16 <= len {
91            let chunk = _mm_loadu_si128(ptr.add(offset) as *const __m128i);
92
93            // Whitespace check: compare against space, tab, newline, CR.
94            let space = _mm_set1_epi8(b' ' as i8);
95            let tab = _mm_set1_epi8(b'\t' as i8);
96            let nl = _mm_set1_epi8(b'\n' as i8);
97            let cr = _mm_set1_epi8(b'\r' as i8);
98
99            // _mm_cmpeq_epi8: 0xFF where equal, 0x00 where not.
100            let ws_mask = _mm_or_si128(
101                _mm_or_si128(_mm_cmpeq_epi8(chunk, space), _mm_cmpeq_epi8(chunk, tab)),
102                _mm_or_si128(_mm_cmpeq_epi8(chunk, nl), _mm_cmpeq_epi8(chunk, cr)),
103            );
104
105            // Alpha check: (b | 0x20) - 'a' <= 25 (unsigned)
106            let or_mask = _mm_set1_epi8(0x20);
107            let lower = _mm_or_si128(chunk, or_mask); // force lowercase
108            let a_val = _mm_set1_epi8(b'a' as i8);
109            let sub = _mm_sub_epi8(lower, a_val); // lower - 'a'
110            let bound = _mm_set1_epi8(25); // 'z' - 'a'
111            // Unsigned compare: alpha if sub == min(sub, 25)
112            let clamped = _mm_min_epu8(sub, bound);
113            let alpha_mask = _mm_cmpeq_epi8(sub, clamped);
114
115            // Digit check: b - '0' <= 9 (unsigned)
116            let zero = _mm_set1_epi8(b'0' as i8);
117            let sub_d = _mm_sub_epi8(chunk, zero);
118            let dbound = _mm_set1_epi8(9);
119            let dclamped = _mm_min_epu8(sub_d, dbound);
120            let digit_mask = _mm_cmpeq_epi8(sub_d, dclamped);
121
122            // Delimiter check: compare against each of the 7 delimiters.
123            let lt = _mm_set1_epi8(b'<' as i8);
124            let gt = _mm_set1_epi8(b'>' as i8);
125            let amp = _mm_set1_epi8(b'&' as i8);
126            let quot = _mm_set1_epi8(b'"' as i8);
127            let apos = _mm_set1_epi8(b'\'' as i8);
128            let eq = _mm_set1_epi8(b'=' as i8);
129            let slash = _mm_set1_epi8(b'/' as i8);
130
131            let delim_mask = _mm_or_si128(
132                _mm_or_si128(
133                    _mm_or_si128(_mm_cmpeq_epi8(chunk, lt), _mm_cmpeq_epi8(chunk, gt)),
134                    _mm_or_si128(_mm_cmpeq_epi8(chunk, amp), _mm_cmpeq_epi8(chunk, quot)),
135                ),
136                _mm_or_si128(
137                    _mm_or_si128(_mm_cmpeq_epi8(chunk, apos), _mm_cmpeq_epi8(chunk, eq)),
138                    _mm_cmpeq_epi8(chunk, slash),
139                ),
140            );
141
142            // Map to class constants and combine.
143            let ws_class = _mm_and_si128(ws_mask, _mm_set1_epi8(crate::class::WHITESPACE as i8));
144            let al_class = _mm_and_si128(alpha_mask, _mm_set1_epi8(crate::class::ALPHA as i8));
145            let di_class = _mm_and_si128(digit_mask, _mm_set1_epi8(crate::class::DIGIT as i8));
146            let de_class = _mm_and_si128(delim_mask, _mm_set1_epi8(crate::class::DELIMITER as i8));
147
148            let combined = _mm_or_si128(
149                _mm_or_si128(ws_class, al_class),
150                _mm_or_si128(di_class, de_class),
151            );
152
153            // Store 16 bytes of classification results.
154            _mm_storeu_si128(out_ptr.add(offset) as *mut __m128i, combined);
155            offset += 16;
156        }
157
158        // Scalar tail.
159        while offset < len {
160            *out_ptr.add(offset) = classify_byte(*ptr.add(offset));
161            offset += 1;
162        }
163
164        result.set_len(len);
165    }
166
167    result
168}
169
170/// Skip leading whitespace using SSE4.2 — 16 bytes at a time.
171///
172/// # Safety
173///
174/// Caller must ensure SSE4.2 support.
175#[target_feature(enable = "sse4.2")]
176#[cfg(target_arch = "x86_64")]
177pub unsafe fn skip_whitespace(input: &[u8]) -> usize {
178    let len = input.len();
179    let ptr = input.as_ptr();
180    let mut offset = 0;
181
182    // SAFETY: all intrinsics below require SSE4.2, guaranteed by #[target_feature].
183    unsafe {
184        // Whitespace needle: space, tab, newline, CR.
185        let ws = _mm_set_epi8(
186            0,
187            0,
188            0,
189            0,
190            0,
191            0,
192            0,
193            0,
194            0,
195            0,
196            0,
197            0,
198            b'\r' as i8,
199            b'\n' as i8,
200            b'\t' as i8,
201            b' ' as i8,
202        );
203
204        while offset + 16 <= len {
205            let chunk = _mm_loadu_si128(ptr.add(offset) as *const __m128i);
206
207            // _mm_cmpistri with EQUAL_ANY + NEGATIVE_POLARITY:
208            // returns index of first byte NOT in the whitespace set.
209            // Mode 0x10 = SIDD_UBYTE_OPS | SIDD_CMP_EQUAL_ANY | SIDD_NEGATIVE_POLARITY
210            let idx = _mm_cmpistri(ws, chunk, 0x10);
211            if idx < 16 {
212                return offset + idx as usize;
213            }
214            offset += 16;
215        }
216    }
217
218    // Scalar tail.
219    offset + crate::scalar::skip_whitespace_safe(&input[offset..])
220}
221
/// Produce a bitmask where bit `i` is set if `block[i] == byte`.
///
/// Processes 16 bytes at a time using `_mm_cmpeq_epi8` + `_mm_movemask_epi8`,
/// then finishes the remaining `< 16` bytes with a scalar loop.
///
/// A `u64` has one bit per byte, so only the first 64 bytes of `block` are
/// examined. Bytes at index 64 and beyond are ignored, so a longer block
/// does not shift by 64 or more (a panic in debug builds, a wrapped shift
/// that corrupts low bits in release builds).
///
/// # Returns
///
/// The match bitmask; `0` for an empty block or when `byte` does not occur.
///
/// # Safety
///
/// Caller must ensure the CPU supports SSE4.2.
#[target_feature(enable = "sse4.2")]
#[cfg(target_arch = "x86_64")]
pub unsafe fn compute_byte_mask(block: &[u8], byte: u8) -> u64 {
    // Number of positions representable in the returned mask.
    const MAX_LEN: usize = 64;

    let block = &block[..block.len().min(MAX_LEN)];
    let mut result: u64 = 0;
    let mut chunks = block.chunks_exact(16);

    // SAFETY: all intrinsics below require SSE4.2, guaranteed by #[target_feature].
    // Each load reads exactly the 16 bytes of a `chunks_exact(16)` slice.
    unsafe {
        let target = _mm_set1_epi8(byte as i8);

        for (i, chunk) in chunks.by_ref().enumerate() {
            let v = _mm_loadu_si128(chunk.as_ptr() as *const __m128i);
            // One bit per lane: bit k is set when chunk[k] == byte.
            let mask = _mm_movemask_epi8(_mm_cmpeq_epi8(v, target)) as u16;
            // At most 4 chunks, so the shift is one of 0, 16, 32, 48.
            result |= u64::from(mask) << (i * 16);
        }
    }

    // Scalar tail: positions continue right after the last full chunk and
    // never exceed 63 because of the clamp above.
    let tail = chunks.remainder();
    let tail_start = block.len() - tail.len();
    for (i, &b) in tail.iter().enumerate() {
        if b == byte {
            result |= 1u64 << (tail_start + i);
        }
    }

    result
}
262
#[cfg(all(test, target_arch = "x86_64"))]
mod tests {
    use super::*;
    use crate::class;

    /// Runtime probe: the tests below do nothing on CPUs without SSE4.2.
    fn has_sse42() -> bool {
        is_x86_feature_detected!("sse4.2")
    }

    /// The first delimiter in a 17-byte input is reported with its position.
    #[test]
    fn find_delimiters_basic() {
        if has_sse42() {
            // SAFETY: SSE4.2 availability was checked just above.
            let found = unsafe { find_delimiters(b"hello world <div>") };
            let expected = DelimiterResult::Found {
                pos: 12,
                byte: b'<',
            };
            assert_eq!(found, expected);
        }
    }

    /// An input without any delimiter yields `NotFound`.
    #[test]
    fn find_delimiters_not_found() {
        if has_sse42() {
            // SAFETY: SSE4.2 availability was checked just above.
            let found = unsafe { find_delimiters(b"hello world no delimiters here") };
            assert_eq!(found, DelimiterResult::NotFound);
        }
    }

    /// The first four bytes cover one letter, one digit, one space and one delimiter.
    #[test]
    fn classify_bytes_basic() {
        if has_sse42() {
            // SAFETY: SSE4.2 availability was checked just above.
            let classes = unsafe { classify_bytes(b"a1 <b2\t>Zz09&\"'/=\nhello world...") };
            let expected = [
                class::ALPHA,
                class::DIGIT,
                class::WHITESPACE,
                class::DELIMITER,
            ];
            for (i, want) in expected.iter().enumerate() {
                assert_eq!(classes[i], *want);
            }
        }
    }

    /// Mixed leading whitespace shorter than one chunk is counted.
    #[test]
    fn skip_whitespace_basic() {
        if has_sse42() {
            // SAFETY: SSE4.2 availability was checked just above.
            let skipped = unsafe { skip_whitespace(b"   \t\nhello") };
            assert_eq!(skipped, 5);
        }
    }

    /// A 20-byte all-space input is consumed entirely (one chunk plus tail).
    #[test]
    fn skip_whitespace_all_ws() {
        if has_sse42() {
            // SAFETY: SSE4.2 availability was checked just above.
            let skipped = unsafe { skip_whitespace(b"                    ") };
            assert_eq!(skipped, 20);
        }
    }
}