//! AVX2 accelerated operations (256-bit, x86_64).
//!
//! Processes 32 bytes at a time — double the throughput of SSE4.2.
//! Falls back to scalar for the tail bytes.

#[cfg(target_arch = "x86_64")]
use core::arch::x86_64::*;

use crate::{DelimiterResult, classify_byte};

11/// Scan `haystack` for the first HTML delimiter using AVX2 (256-bit).
12///
13/// Processes 32 bytes per iteration. Uses `_mm256_cmpeq_epi8` for each
14/// delimiter then combines with OR, extracting a 32-bit bitmask via
15/// `_mm256_movemask_epi8`.
16///
17/// # Safety
18///
19/// Caller must ensure the CPU supports AVX2 (`is_x86_feature_detected!("avx2")`).
20#[target_feature(enable = "avx2")]
21#[cfg(target_arch = "x86_64")]
22pub unsafe fn find_delimiters(haystack: &[u8]) -> DelimiterResult {
23    let len = haystack.len();
24    let ptr = haystack.as_ptr();
25    let mut offset = 0;
26
27    // SAFETY: all intrinsics below require AVX2, guaranteed by #[target_feature].
28    unsafe {
29        // Broadcast each delimiter byte into a 256-bit register.
30        let lt = _mm256_set1_epi8(b'<' as i8);
31        let gt = _mm256_set1_epi8(b'>' as i8);
32        let amp = _mm256_set1_epi8(b'&' as i8);
33        let quot = _mm256_set1_epi8(b'"' as i8);
34        let apos = _mm256_set1_epi8(b'\'' as i8);
35        let eq = _mm256_set1_epi8(b'=' as i8);
36        let slash = _mm256_set1_epi8(b'/' as i8);
37
38        while offset + 32 <= len {
39            // Load 32 bytes (unaligned).
40            let chunk = _mm256_loadu_si256(ptr.add(offset) as *const __m256i);
41
42            // Compare 32 bytes against each delimiter simultaneously.
43            let cmp_lt = _mm256_cmpeq_epi8(chunk, lt);
44            let cmp_gt = _mm256_cmpeq_epi8(chunk, gt);
45            let cmp_amp = _mm256_cmpeq_epi8(chunk, amp);
46            let cmp_quot = _mm256_cmpeq_epi8(chunk, quot);
47            let cmp_apos = _mm256_cmpeq_epi8(chunk, apos);
48            let cmp_eq = _mm256_cmpeq_epi8(chunk, eq);
49            let cmp_slash = _mm256_cmpeq_epi8(chunk, slash);
50
51            // OR all comparisons together.
52            let combined = _mm256_or_si256(
53                _mm256_or_si256(
54                    _mm256_or_si256(cmp_lt, cmp_gt),
55                    _mm256_or_si256(cmp_amp, cmp_quot),
56                ),
57                _mm256_or_si256(_mm256_or_si256(cmp_apos, cmp_eq), cmp_slash),
58            );
59
60            // Extract a 32-bit mask: bit i is set if byte i matched any delimiter.
61            let mask = _mm256_movemask_epi8(combined) as u32;
62            if mask != 0 {
63                // trailing_zeros gives the index of the first set bit.
64                let pos = offset + mask.trailing_zeros() as usize;
65                return DelimiterResult::Found {
66                    pos,
67                    byte: *ptr.add(pos),
68                };
69            }
70            offset += 32;
71        }
72    }
73
74    // Scalar tail for remaining < 32 bytes.
75    crate::scalar::find_delimiters_safe(&haystack[offset..]).offset_by(offset)
76}
77
78/// Classify each byte using AVX2 — 32 bytes at a time.
79///
80/// Uses the same algorithm as SSE4.2 but with 256-bit registers for
81/// double throughput.
82///
83/// # Safety
84///
85/// Caller must ensure AVX2 support.
86#[target_feature(enable = "avx2")]
87#[cfg(target_arch = "x86_64")]
88pub unsafe fn classify_bytes(input: &[u8]) -> Vec<u8> {
89    let len = input.len();
90    let mut result = Vec::with_capacity(len);
91    let ptr = input.as_ptr();
92    let out_ptr: *mut u8 = result.as_mut_ptr();
93    let mut offset = 0;
94
95    // SAFETY: all intrinsics below require AVX2, guaranteed by #[target_feature].
96    // Pointer arithmetic is valid because offset < len and result has capacity >= len.
97    unsafe {
98        while offset + 32 <= len {
99            let chunk = _mm256_loadu_si256(ptr.add(offset) as *const __m256i);
100
101            // Whitespace: space, tab, newline, CR.
102            let ws_mask = _mm256_or_si256(
103                _mm256_or_si256(
104                    _mm256_cmpeq_epi8(chunk, _mm256_set1_epi8(b' ' as i8)),
105                    _mm256_cmpeq_epi8(chunk, _mm256_set1_epi8(b'\t' as i8)),
106                ),
107                _mm256_or_si256(
108                    _mm256_cmpeq_epi8(chunk, _mm256_set1_epi8(b'\n' as i8)),
109                    _mm256_cmpeq_epi8(chunk, _mm256_set1_epi8(b'\r' as i8)),
110                ),
111            );
112
113            // Alpha: (b | 0x20) - 'a' <= 25 (unsigned)
114            let lower = _mm256_or_si256(chunk, _mm256_set1_epi8(0x20));
115            let sub = _mm256_sub_epi8(lower, _mm256_set1_epi8(b'a' as i8));
116            let clamped = _mm256_min_epu8(sub, _mm256_set1_epi8(25));
117            let alpha_mask = _mm256_cmpeq_epi8(sub, clamped);
118
119            // Digit: b - '0' <= 9 (unsigned)
120            let sub_d = _mm256_sub_epi8(chunk, _mm256_set1_epi8(b'0' as i8));
121            let dclamped = _mm256_min_epu8(sub_d, _mm256_set1_epi8(9));
122            let digit_mask = _mm256_cmpeq_epi8(sub_d, dclamped);
123
124            // Delimiters: 7 equality comparisons OR'd together.
125            let delim_mask = _mm256_or_si256(
126                _mm256_or_si256(
127                    _mm256_or_si256(
128                        _mm256_cmpeq_epi8(chunk, _mm256_set1_epi8(b'<' as i8)),
129                        _mm256_cmpeq_epi8(chunk, _mm256_set1_epi8(b'>' as i8)),
130                    ),
131                    _mm256_or_si256(
132                        _mm256_cmpeq_epi8(chunk, _mm256_set1_epi8(b'&' as i8)),
133                        _mm256_cmpeq_epi8(chunk, _mm256_set1_epi8(b'"' as i8)),
134                    ),
135                ),
136                _mm256_or_si256(
137                    _mm256_or_si256(
138                        _mm256_cmpeq_epi8(chunk, _mm256_set1_epi8(b'\'' as i8)),
139                        _mm256_cmpeq_epi8(chunk, _mm256_set1_epi8(b'=' as i8)),
140                    ),
141                    _mm256_cmpeq_epi8(chunk, _mm256_set1_epi8(b'/' as i8)),
142                ),
143            );
144
145            // Map masks to class constants and combine.
146            let combined = _mm256_or_si256(
147                _mm256_or_si256(
148                    _mm256_and_si256(ws_mask, _mm256_set1_epi8(crate::class::WHITESPACE as i8)),
149                    _mm256_and_si256(alpha_mask, _mm256_set1_epi8(crate::class::ALPHA as i8)),
150                ),
151                _mm256_or_si256(
152                    _mm256_and_si256(digit_mask, _mm256_set1_epi8(crate::class::DIGIT as i8)),
153                    _mm256_and_si256(delim_mask, _mm256_set1_epi8(crate::class::DELIMITER as i8)),
154                ),
155            );
156
157            _mm256_storeu_si256(out_ptr.add(offset) as *mut __m256i, combined);
158            offset += 32;
159        }
160
161        // Scalar tail.
162        while offset < len {
163            *out_ptr.add(offset) = classify_byte(*ptr.add(offset));
164            offset += 1;
165        }
166
167        result.set_len(len);
168    }
169
170    result
171}
172
173/// Skip leading whitespace using AVX2 — 32 bytes at a time.
174///
175/// # Safety
176///
177/// Caller must ensure AVX2 support.
178#[target_feature(enable = "avx2")]
179#[cfg(target_arch = "x86_64")]
180pub unsafe fn skip_whitespace(input: &[u8]) -> usize {
181    let len = input.len();
182    let ptr = input.as_ptr();
183    let mut offset = 0;
184
185    // SAFETY: all intrinsics below require AVX2, guaranteed by #[target_feature].
186    unsafe {
187        while offset + 32 <= len {
188            let chunk = _mm256_loadu_si256(ptr.add(offset) as *const __m256i);
189
190            // Check all 4 whitespace bytes.
191            let ws_mask = _mm256_or_si256(
192                _mm256_or_si256(
193                    _mm256_cmpeq_epi8(chunk, _mm256_set1_epi8(b' ' as i8)),
194                    _mm256_cmpeq_epi8(chunk, _mm256_set1_epi8(b'\t' as i8)),
195                ),
196                _mm256_or_si256(
197                    _mm256_cmpeq_epi8(chunk, _mm256_set1_epi8(b'\n' as i8)),
198                    _mm256_cmpeq_epi8(chunk, _mm256_set1_epi8(b'\r' as i8)),
199                ),
200            );
201
202            // movemask: bit i = 1 if byte i is whitespace.
203            let mask = _mm256_movemask_epi8(ws_mask) as u32;
204
205            if mask != 0xFFFF_FFFF {
206                // Not all whitespace — find first non-WS byte.
207                let non_ws = !mask;
208                return offset + non_ws.trailing_zeros() as usize;
209            }
210            offset += 32;
211        }
212    }
213
214    // Scalar tail.
215    offset + crate::scalar::skip_whitespace_safe(&input[offset..])
216}
217
/// Produce a bitmask where bit `i` is set if `block[i] == byte`.
///
/// Processes 32 bytes at a time using `_mm256_cmpeq_epi8` +
/// `_mm256_movemask_epi8`. Handles blocks up to 64 bytes: match positions
/// beyond 63 cannot be represented in the `u64` result, and `<< offset`
/// with `offset >= 64` would overflow, so the length is checked up front.
///
/// # Safety
///
/// Caller must ensure the CPU supports AVX2.
///
/// # Panics
///
/// In debug builds, panics if `block.len() > 64`.
#[target_feature(enable = "avx2")]
#[cfg(target_arch = "x86_64")]
pub unsafe fn compute_byte_mask(block: &[u8], byte: u8) -> u64 {
    let len = block.len();
    // Enforce the documented precondition instead of silently producing a
    // wrong mask (release builds would mask the shift amount modulo 64).
    debug_assert!(len <= 64, "compute_byte_mask: block longer than 64 bytes");
    let ptr = block.as_ptr();
    let mut result: u64 = 0;
    let mut offset = 0;

    // SAFETY: all intrinsics below require AVX2, guaranteed by #[target_feature];
    // ptr.add(offset) is in bounds because offset + 32 <= len.
    unsafe {
        let target = _mm256_set1_epi8(byte as i8);

        while offset + 32 <= len {
            let chunk = _mm256_loadu_si256(ptr.add(offset) as *const __m256i);
            let cmp = _mm256_cmpeq_epi8(chunk, target);
            let mask = _mm256_movemask_epi8(cmp) as u32;
            // Place this chunk's 32 match bits at their absolute position.
            result |= (mask as u64) << offset;
            offset += 32;
        }
    }

    // Scalar tail for the remaining < 32 bytes — safe iteration, no raw pointers.
    for (i, &b) in block.iter().enumerate().skip(offset) {
        if b == byte {
            result |= 1u64 << i;
        }
    }

    result
}

#[cfg(all(test, target_arch = "x86_64"))]
mod tests {
    use super::*;
    use crate::class;

    /// Runtime gate: each test returns early (vacuously passes) when the
    /// host CPU lacks AVX2, since calling the kernels would be UB.
    fn has_avx2() -> bool {
        is_x86_feature_detected!("avx2")
    }

    #[test]
    fn find_delimiters_basic() {
        if !has_avx2() {
            return;
        }
        // 36 bytes — crosses the 32-byte boundary; '<' sits at index 31,
        // the last byte of the first SIMD chunk.
        let input = b"abcdefghijklmnopqrstuvwxyz12345<div>";
        let result = unsafe { find_delimiters(input) };
        assert_eq!(
            result,
            DelimiterResult::Found {
                pos: 31,
                byte: b'<'
            },
        );
    }

    #[test]
    fn find_delimiters_not_found() {
        if !has_avx2() {
            return;
        }
        // Alphanumerics only — no delimiter in the SIMD chunk or the tail.
        let input = b"abcdefghijklmnopqrstuvwxyz0123456789";
        let result = unsafe { find_delimiters(input) };
        assert_eq!(result, DelimiterResult::NotFound);
    }

    #[test]
    fn classify_bytes_basic() {
        if !has_avx2() {
            return;
        }
        // 35 bytes to cross the 32-byte boundary (tail handled by scalar path).
        let input = b"abcd 1234\t<>&&\"'/=\nABCDwxyz09......";
        let result = unsafe { classify_bytes(input) };
        assert_eq!(result[0], class::ALPHA);
        assert_eq!(result[4], class::WHITESPACE);
        assert_eq!(result[5], class::DIGIT);
        assert_eq!(result[10], class::DELIMITER); // '<'
    }

    #[test]
    fn skip_whitespace_basic() {
        if !has_avx2() {
            return;
        }
        // 40 spaces then a letter — crosses 32-byte boundary.
        let input = b"                                        X";
        let result = unsafe { skip_whitespace(input) };
        assert_eq!(result, 40);
    }
}