1#[cfg(target_arch = "x86_64")]
7use core::arch::x86_64::*;
8
9use crate::{DelimiterResult, classify_byte};
10
11#[target_feature(enable = "avx2")]
21#[cfg(target_arch = "x86_64")]
22pub unsafe fn find_delimiters(haystack: &[u8]) -> DelimiterResult {
23 let len = haystack.len();
24 let ptr = haystack.as_ptr();
25 let mut offset = 0;
26
27 unsafe {
29 let lt = _mm256_set1_epi8(b'<' as i8);
31 let gt = _mm256_set1_epi8(b'>' as i8);
32 let amp = _mm256_set1_epi8(b'&' as i8);
33 let quot = _mm256_set1_epi8(b'"' as i8);
34 let apos = _mm256_set1_epi8(b'\'' as i8);
35 let eq = _mm256_set1_epi8(b'=' as i8);
36 let slash = _mm256_set1_epi8(b'/' as i8);
37
38 while offset + 32 <= len {
39 let chunk = _mm256_loadu_si256(ptr.add(offset) as *const __m256i);
41
42 let cmp_lt = _mm256_cmpeq_epi8(chunk, lt);
44 let cmp_gt = _mm256_cmpeq_epi8(chunk, gt);
45 let cmp_amp = _mm256_cmpeq_epi8(chunk, amp);
46 let cmp_quot = _mm256_cmpeq_epi8(chunk, quot);
47 let cmp_apos = _mm256_cmpeq_epi8(chunk, apos);
48 let cmp_eq = _mm256_cmpeq_epi8(chunk, eq);
49 let cmp_slash = _mm256_cmpeq_epi8(chunk, slash);
50
51 let combined = _mm256_or_si256(
53 _mm256_or_si256(
54 _mm256_or_si256(cmp_lt, cmp_gt),
55 _mm256_or_si256(cmp_amp, cmp_quot),
56 ),
57 _mm256_or_si256(_mm256_or_si256(cmp_apos, cmp_eq), cmp_slash),
58 );
59
60 let mask = _mm256_movemask_epi8(combined) as u32;
62 if mask != 0 {
63 let pos = offset + mask.trailing_zeros() as usize;
65 return DelimiterResult::Found {
66 pos,
67 byte: *ptr.add(pos),
68 };
69 }
70 offset += 32;
71 }
72 }
73
74 crate::scalar::find_delimiters_safe(&haystack[offset..]).offset_by(offset)
76}
77
78#[target_feature(enable = "avx2")]
87#[cfg(target_arch = "x86_64")]
88pub unsafe fn classify_bytes(input: &[u8]) -> Vec<u8> {
89 let len = input.len();
90 let mut result = Vec::with_capacity(len);
91 let ptr = input.as_ptr();
92 let out_ptr: *mut u8 = result.as_mut_ptr();
93 let mut offset = 0;
94
95 unsafe {
98 while offset + 32 <= len {
99 let chunk = _mm256_loadu_si256(ptr.add(offset) as *const __m256i);
100
101 let ws_mask = _mm256_or_si256(
103 _mm256_or_si256(
104 _mm256_cmpeq_epi8(chunk, _mm256_set1_epi8(b' ' as i8)),
105 _mm256_cmpeq_epi8(chunk, _mm256_set1_epi8(b'\t' as i8)),
106 ),
107 _mm256_or_si256(
108 _mm256_cmpeq_epi8(chunk, _mm256_set1_epi8(b'\n' as i8)),
109 _mm256_cmpeq_epi8(chunk, _mm256_set1_epi8(b'\r' as i8)),
110 ),
111 );
112
113 let lower = _mm256_or_si256(chunk, _mm256_set1_epi8(0x20));
115 let sub = _mm256_sub_epi8(lower, _mm256_set1_epi8(b'a' as i8));
116 let clamped = _mm256_min_epu8(sub, _mm256_set1_epi8(25));
117 let alpha_mask = _mm256_cmpeq_epi8(sub, clamped);
118
119 let sub_d = _mm256_sub_epi8(chunk, _mm256_set1_epi8(b'0' as i8));
121 let dclamped = _mm256_min_epu8(sub_d, _mm256_set1_epi8(9));
122 let digit_mask = _mm256_cmpeq_epi8(sub_d, dclamped);
123
124 let delim_mask = _mm256_or_si256(
126 _mm256_or_si256(
127 _mm256_or_si256(
128 _mm256_cmpeq_epi8(chunk, _mm256_set1_epi8(b'<' as i8)),
129 _mm256_cmpeq_epi8(chunk, _mm256_set1_epi8(b'>' as i8)),
130 ),
131 _mm256_or_si256(
132 _mm256_cmpeq_epi8(chunk, _mm256_set1_epi8(b'&' as i8)),
133 _mm256_cmpeq_epi8(chunk, _mm256_set1_epi8(b'"' as i8)),
134 ),
135 ),
136 _mm256_or_si256(
137 _mm256_or_si256(
138 _mm256_cmpeq_epi8(chunk, _mm256_set1_epi8(b'\'' as i8)),
139 _mm256_cmpeq_epi8(chunk, _mm256_set1_epi8(b'=' as i8)),
140 ),
141 _mm256_cmpeq_epi8(chunk, _mm256_set1_epi8(b'/' as i8)),
142 ),
143 );
144
145 let combined = _mm256_or_si256(
147 _mm256_or_si256(
148 _mm256_and_si256(ws_mask, _mm256_set1_epi8(crate::class::WHITESPACE as i8)),
149 _mm256_and_si256(alpha_mask, _mm256_set1_epi8(crate::class::ALPHA as i8)),
150 ),
151 _mm256_or_si256(
152 _mm256_and_si256(digit_mask, _mm256_set1_epi8(crate::class::DIGIT as i8)),
153 _mm256_and_si256(delim_mask, _mm256_set1_epi8(crate::class::DELIMITER as i8)),
154 ),
155 );
156
157 _mm256_storeu_si256(out_ptr.add(offset) as *mut __m256i, combined);
158 offset += 32;
159 }
160
161 while offset < len {
163 *out_ptr.add(offset) = classify_byte(*ptr.add(offset));
164 offset += 1;
165 }
166
167 result.set_len(len);
168 }
169
170 result
171}
172
173#[target_feature(enable = "avx2")]
179#[cfg(target_arch = "x86_64")]
180pub unsafe fn skip_whitespace(input: &[u8]) -> usize {
181 let len = input.len();
182 let ptr = input.as_ptr();
183 let mut offset = 0;
184
185 unsafe {
187 while offset + 32 <= len {
188 let chunk = _mm256_loadu_si256(ptr.add(offset) as *const __m256i);
189
190 let ws_mask = _mm256_or_si256(
192 _mm256_or_si256(
193 _mm256_cmpeq_epi8(chunk, _mm256_set1_epi8(b' ' as i8)),
194 _mm256_cmpeq_epi8(chunk, _mm256_set1_epi8(b'\t' as i8)),
195 ),
196 _mm256_or_si256(
197 _mm256_cmpeq_epi8(chunk, _mm256_set1_epi8(b'\n' as i8)),
198 _mm256_cmpeq_epi8(chunk, _mm256_set1_epi8(b'\r' as i8)),
199 ),
200 );
201
202 let mask = _mm256_movemask_epi8(ws_mask) as u32;
204
205 if mask != 0xFFFF_FFFF {
206 let non_ws = !mask;
208 return offset + non_ws.trailing_zeros() as usize;
209 }
210 offset += 32;
211 }
212 }
213
214 offset + crate::scalar::skip_whitespace_safe(&input[offset..])
216}
217
/// Builds a bitmask in which bit `i` is set when `block[i] == byte`.
///
/// The result is a `u64`, so only the first 64 bytes of `block` can be
/// represented; any bytes beyond index 63 are ignored. Full 32-byte chunks are
/// compared with AVX2, and the remainder is compared one byte at a time.
///
/// # Safety
///
/// This function is compiled with AVX2 enabled; the caller must ensure the
/// executing CPU supports AVX2 (e.g. via `is_x86_feature_detected!("avx2")`).
#[target_feature(enable = "avx2")]
#[cfg(target_arch = "x86_64")]
pub unsafe fn compute_byte_mask(block: &[u8], byte: u8) -> u64 {
    // Shifting a u64 by 64 or more overflows: it panics in debug builds and
    // wraps the shift amount in release builds, which would fold bytes 64.. back
    // onto the low bits. Restrict the scan to the positions a u64 can hold.
    let block = &block[..block.len().min(64)];
    let len = block.len();
    let ptr = block.as_ptr();
    let mut result: u64 = 0;
    let mut offset = 0;

    // SAFETY: AVX2 support is guaranteed by the caller, and each unaligned load
    // reads `block[offset..offset + 32]`, which is in bounds because
    // `offset + 32 <= len`.
    unsafe {
        let target = _mm256_set1_epi8(byte as i8);

        while offset + 32 <= len {
            let chunk = _mm256_loadu_si256(ptr.add(offset) as *const __m256i);
            let cmp = _mm256_cmpeq_epi8(chunk, target);
            let mask = _mm256_movemask_epi8(cmp) as u32;
            // `offset` is 0 or 32 here, so the shift stays below 64.
            result |= u64::from(mask) << offset;
            offset += 32;
        }
    }

    // Scalar tail: fewer than 32 bytes remain, all at indices below 64.
    for (index, &value) in block.iter().enumerate().skip(offset) {
        if value == byte {
            result |= 1u64 << index;
        }
    }

    result
}
258
#[cfg(all(test, target_arch = "x86_64"))]
mod tests {
    use super::*;
    use crate::class;

    /// The kernels under test are compiled with `#[target_feature(enable = "avx2")]`;
    /// running them on a CPU without AVX2 is undefined behaviour, so every test
    /// returns early when the feature is absent.
    fn has_avx2() -> bool {
        is_x86_feature_detected!("avx2")
    }

    #[test]
    fn find_delimiters_basic() {
        if !has_avx2() {
            return;
        }
        // 31 non-delimiter bytes, so `<` sits in the last lane of the first window.
        let input = b"abcdefghijklmnopqrstuvwxyz12345<div>";
        let result = unsafe { find_delimiters(input) };
        assert_eq!(
            result,
            DelimiterResult::Found {
                pos: 31,
                byte: b'<'
            },
        );
    }

    #[test]
    fn find_delimiters_in_scalar_tail() {
        if !has_avx2() {
            return;
        }
        // One delimiter-free AVX2 window; the `<` lands in the scalar tail and must
        // be reported relative to the whole input.
        let mut input = vec![b'a'; 32];
        input.extend_from_slice(b"x<y");
        let result = unsafe { find_delimiters(&input) };
        assert_eq!(
            result,
            DelimiterResult::Found {
                pos: 33,
                byte: b'<'
            },
        );
    }

    #[test]
    fn find_delimiters_not_found() {
        if !has_avx2() {
            return;
        }
        let input = b"abcdefghijklmnopqrstuvwxyz0123456789";
        let result = unsafe { find_delimiters(input) };
        assert_eq!(result, DelimiterResult::NotFound);
    }

    #[test]
    fn classify_bytes_basic() {
        if !has_avx2() {
            return;
        }
        // 35 bytes: indices 0..32 take the AVX2 path, 32..35 the scalar tail.
        let input = b"abcd 1234\t<>&&\"'/=\nABCDwxyz09......";
        let result = unsafe { classify_bytes(input) };
        assert_eq!(result.len(), input.len());
        assert_eq!(result[0], class::ALPHA);
        assert_eq!(result[4], class::WHITESPACE);
        assert_eq!(result[5], class::DIGIT);
        assert_eq!(result[9], class::WHITESPACE);
        assert_eq!(result[10], class::DELIMITER);
        assert_eq!(result[17], class::DELIMITER);
        assert_eq!(result[18], class::WHITESPACE);
        assert_eq!(result[19], class::ALPHA);
        assert_eq!(result[28], class::DIGIT);
    }

    #[test]
    fn skip_whitespace_basic() {
        if !has_avx2() {
            return;
        }
        // 40 spaces then `X`: one all-whitespace AVX2 window plus 8 tail spaces.
        // Built programmatically so the count cannot be lost to reformatting.
        let mut input = vec![b' '; 40];
        input.push(b'X');
        let result = unsafe { skip_whitespace(&input) };
        assert_eq!(result, 40);
    }

    #[test]
    fn skip_whitespace_stops_inside_simd_window() {
        if !has_avx2() {
            return;
        }
        // Mixed whitespace for 16 bytes, then a non-whitespace byte inside the
        // first AVX2 window.
        let mut input = b" \t\r\n".repeat(4);
        input.push(b'X');
        input.resize(48, b' ');
        let result = unsafe { skip_whitespace(&input) };
        assert_eq!(result, 16);
    }

    #[test]
    fn compute_byte_mask_basic() {
        if !has_avx2() {
            return;
        }
        // 40 bytes: one AVX2 chunk plus an 8-byte scalar tail.
        let mut block = [b'.'; 40];
        block[0] = b'"';
        block[31] = b'"';
        block[39] = b'"';
        let mask = unsafe { compute_byte_mask(&block, b'"') };
        assert_eq!(mask, 1u64 | (1u64 << 31) | (1u64 << 39));
    }
}