1#[cfg(target_arch = "aarch64")]
7use core::arch::aarch64::*;
8
9use crate::{DelimiterResult, classify_byte};
10
/// Locates the first delimiter byte in `haystack`, scanning 16 bytes at a time with NEON.
///
/// The vector loop looks for any of `<`, `>`, `&`, `"`, `'`, `=` or `/`. Once fewer than
/// 16 bytes remain, the rest of the slice is handed to
/// `crate::scalar::find_delimiters_safe` and its result is shifted back into `haystack`
/// coordinates with `offset_by`.
///
/// # Returns
///
/// `DelimiterResult::Found { pos, byte }` for the first match seen by the vector loop,
/// otherwise whatever the scalar fallback reports for the tail.
///
/// # Safety
///
/// The caller must ensure the executing CPU supports the `neon` target feature.
#[target_feature(enable = "neon")]
#[cfg(target_arch = "aarch64")]
pub unsafe fn find_delimiters(haystack: &[u8]) -> DelimiterResult {
    let mut blocks = haystack.chunks_exact(16);
    // Index in `haystack` of the first byte of the block currently being examined.
    let mut base = 0usize;

    // SAFETY: every `block` yielded by `chunks_exact(16)` is exactly 16 readable bytes,
    // so each `vld1q_u8` stays in bounds. NEON availability is this function's contract.
    unsafe {
        let lt = vdupq_n_u8(b'<');
        let gt = vdupq_n_u8(b'>');
        let amp = vdupq_n_u8(b'&');
        let quot = vdupq_n_u8(b'"');
        let apos = vdupq_n_u8(b'\'');
        let eq = vdupq_n_u8(b'=');
        let slash = vdupq_n_u8(b'/');

        for block in blocks.by_ref() {
            let data = vld1q_u8(block.as_ptr());

            // A lane is 0xFF when it equals any of the seven delimiters, 0x00 otherwise.
            let angle = vorrq_u8(vceqq_u8(data, lt), vceqq_u8(data, gt));
            let quotes = vorrq_u8(vceqq_u8(data, quot), vceqq_u8(data, apos));
            let others = vorrq_u8(
                vorrq_u8(vceqq_u8(data, amp), vceqq_u8(data, eq)),
                vceqq_u8(data, slash),
            );
            let hits = vorrq_u8(vorrq_u8(angle, quotes), others);

            let mask = neon_movemask(hits);
            if mask != 0 {
                // The lowest set bit corresponds to the earliest matching lane.
                let lane = mask.trailing_zeros() as usize;
                return DelimiterResult::Found {
                    pos: base + lane,
                    byte: block[lane],
                };
            }
            base += 16;
        }
    }

    // Fewer than 16 bytes are left; let the scalar routine finish the job.
    crate::scalar::find_delimiters_safe(blocks.remainder()).offset_by(base)
}
73
/// Classifies every byte of `input`, producing one class value per input byte.
///
/// Full 16-byte chunks are classified with NEON. Within those chunks a lane receives
/// `crate::class::WHITESPACE` for space, tab, LF or CR; `crate::class::ALPHA` for ASCII
/// letters; `crate::class::DIGIT` for ASCII digits; `crate::class::DELIMITER` for one of
/// `<`, `>`, `&`, `"`, `'`, `=`, `/`; and `0` when none of these apply. The trailing
/// `input.len() % 16` bytes are classified one at a time with `classify_byte`
/// (presumably equivalent to the vector rules; verify against its definition).
///
/// # Returns
///
/// A vector of exactly `input.len()` class bytes, in input order.
///
/// # Safety
///
/// The caller must ensure the executing CPU supports the `neon` target feature.
#[target_feature(enable = "neon")]
#[cfg(target_arch = "aarch64")]
pub unsafe fn classify_bytes(input: &[u8]) -> Vec<u8> {
    let mut result: Vec<u8> = Vec::with_capacity(input.len());
    let mut chunks = input.chunks_exact(16);
    // Number of output bytes initialised by the vector loop so far.
    let mut written = 0usize;

    // SAFETY: each `chunk` is exactly 16 readable bytes. `written + 16` never exceeds
    // `input.len()`, which is the capacity reserved above, so every `vst1q_u8` writes
    // inside the allocation, and `set_len(written)` only exposes bytes that were stored.
    unsafe {
        let out_ptr = result.as_mut_ptr();

        // Loop-invariant broadcast constants, built once.
        let space = vdupq_n_u8(b' ');
        let tab = vdupq_n_u8(b'\t');
        let newline = vdupq_n_u8(b'\n');
        let carriage_return = vdupq_n_u8(b'\r');

        let case_bit = vdupq_n_u8(0x20);
        let lower_a = vdupq_n_u8(b'a');
        let alpha_span = vdupq_n_u8(25);
        let zero_digit = vdupq_n_u8(b'0');
        let digit_span = vdupq_n_u8(9);

        let lt = vdupq_n_u8(b'<');
        let gt = vdupq_n_u8(b'>');
        let amp = vdupq_n_u8(b'&');
        let quot = vdupq_n_u8(b'"');
        let apos = vdupq_n_u8(b'\'');
        let eq = vdupq_n_u8(b'=');
        let slash = vdupq_n_u8(b'/');

        let ws_value = vdupq_n_u8(crate::class::WHITESPACE);
        let alpha_value = vdupq_n_u8(crate::class::ALPHA);
        let digit_value = vdupq_n_u8(crate::class::DIGIT);
        let delim_value = vdupq_n_u8(crate::class::DELIMITER);

        for chunk in chunks.by_ref() {
            let data = vld1q_u8(chunk.as_ptr());

            let ws_mask = vorrq_u8(
                vorrq_u8(vceqq_u8(data, space), vceqq_u8(data, tab)),
                vorrq_u8(vceqq_u8(data, newline), vceqq_u8(data, carriage_return)),
            );

            // Setting bit 0x20 folds 'A'..='Z' onto 'a'..='z'; the wrapping subtraction
            // followed by an unsigned `<= 25` is a single-compare range check.
            let folded = vorrq_u8(data, case_bit);
            let alpha_mask = vcleq_u8(vsubq_u8(folded, lower_a), alpha_span);

            // Same range-check trick for '0'..='9'.
            let digit_mask = vcleq_u8(vsubq_u8(data, zero_digit), digit_span);

            let delim_mask = vorrq_u8(
                vorrq_u8(
                    vorrq_u8(vceqq_u8(data, lt), vceqq_u8(data, gt)),
                    vorrq_u8(vceqq_u8(data, amp), vceqq_u8(data, quot)),
                ),
                vorrq_u8(
                    vorrq_u8(vceqq_u8(data, apos), vceqq_u8(data, eq)),
                    vceqq_u8(data, slash),
                ),
            );

            // Each mask lane is 0xFF or 0x00, so AND-ing with the class value selects it.
            let classes = vorrq_u8(
                vorrq_u8(vandq_u8(ws_mask, ws_value), vandq_u8(alpha_mask, alpha_value)),
                vorrq_u8(vandq_u8(digit_mask, digit_value), vandq_u8(delim_mask, delim_value)),
            );

            vst1q_u8(out_ptr.add(written), classes);
            written += 16;
        }

        result.set_len(written);

        // Tail of fewer than 16 bytes: plain pushes into the already reserved capacity.
        for &byte in chunks.remainder() {
            result.push(classify_byte(byte));
        }
    }

    result
}
161
/// Counts the leading whitespace bytes (space, tab, LF, CR) of `input`.
///
/// Whole 16-byte blocks are tested with NEON. As soon as a block contains a
/// non-whitespace byte, the index of that byte is returned. If every full block is
/// whitespace, the remaining tail (fewer than 16 bytes) is measured by
/// `crate::scalar::skip_whitespace_safe` and added to the bytes already skipped.
///
/// # Safety
///
/// The caller must ensure the executing CPU supports the `neon` target feature.
#[target_feature(enable = "neon")]
#[cfg(target_arch = "aarch64")]
pub unsafe fn skip_whitespace(input: &[u8]) -> usize {
    let mut blocks = input.chunks_exact(16);
    // Bytes known to be whitespace so far (always a multiple of 16).
    let mut consumed = 0usize;

    // SAFETY: every `block` is exactly 16 readable bytes, so `vld1q_u8` stays in bounds.
    // NEON availability is this function's contract.
    unsafe {
        let space = vdupq_n_u8(b' ');
        let tab = vdupq_n_u8(b'\t');
        let newline = vdupq_n_u8(b'\n');
        let carriage_return = vdupq_n_u8(b'\r');

        for block in blocks.by_ref() {
            let data = vld1q_u8(block.as_ptr());
            let blanks = vorrq_u8(
                vorrq_u8(vceqq_u8(data, space), vceqq_u8(data, tab)),
                vorrq_u8(vceqq_u8(data, newline), vceqq_u8(data, carriage_return)),
            );

            // Length of the whitespace run at the start of this block; 16 means the
            // whole block is whitespace and scanning must continue.
            let run = neon_movemask(blanks).trailing_ones() as usize;
            if run < 16 {
                return consumed + run;
            }
            consumed += 16;
        }
    }

    consumed + crate::scalar::skip_whitespace_safe(blocks.remainder())
}
203
/// Builds a bitmask of the positions in `block` that hold `byte`.
///
/// Bit `i` of the result is set exactly when `block[i] == byte`. A `u64` can describe at
/// most 64 positions, so only the first 64 bytes of `block` are examined; anything past
/// that is ignored instead of overflowing the shift (which would panic in debug builds
/// and alias onto low bits in release builds).
///
/// # Safety
///
/// The caller must ensure the executing CPU supports the `neon` target feature.
#[target_feature(enable = "neon")]
#[cfg(target_arch = "aarch64")]
pub unsafe fn compute_byte_mask(block: &[u8], byte: u8) -> u64 {
    // Keep every shift amount below 64.
    let block = &block[..block.len().min(u64::BITS as usize)];
    let mut chunks = block.chunks_exact(16);
    let mut result = 0u64;
    // Bit position of the first byte of the chunk being processed.
    let mut offset = 0usize;

    // SAFETY: every `chunk` is exactly 16 readable bytes, so `vld1q_u8` stays in bounds.
    // NEON availability is this function's contract.
    unsafe {
        let target = vdupq_n_u8(byte);

        for chunk in chunks.by_ref() {
            let cmp = vceqq_u8(vld1q_u8(chunk.as_ptr()), target);
            result |= u64::from(neon_movemask(cmp)) << offset;
            offset += 16;
        }
    }

    // Fewer than 16 bytes remain; handle them one at a time.
    for (i, &b) in chunks.remainder().iter().enumerate() {
        if b == byte {
            result |= 1u64 << (offset + i);
        }
    }

    result
}
244
/// Computes position bitmasks for `<`, `>`, `"` and `'` over `block` in a single pass.
///
/// For each of the four fields (`lt`, `gt`, `quot`, `apos`) of the returned
/// `crate::AllMasks`, bit `i` is set exactly when `block[i]` is the corresponding byte.
/// A `u64` can describe at most 64 positions, so only the first 64 bytes of `block` are
/// examined; anything past that is ignored instead of overflowing the shift (which would
/// panic in debug builds and alias onto low bits in release builds).
///
/// # Safety
///
/// The caller must ensure the executing CPU supports the `neon` target feature.
#[target_feature(enable = "neon")]
#[cfg(target_arch = "aarch64")]
pub unsafe fn compute_all_masks(block: &[u8]) -> crate::AllMasks {
    // Keep every shift amount below 64.
    let block = &block[..block.len().min(u64::BITS as usize)];
    let mut masks = crate::AllMasks::default();
    let mut chunks = block.chunks_exact(16);
    // Bit position of the first byte of the chunk being processed.
    let mut offset = 0usize;

    // SAFETY: every `chunk` is exactly 16 readable bytes, so `vld1q_u8` stays in bounds.
    // NEON availability is this function's contract.
    unsafe {
        let v_lt = vdupq_n_u8(b'<');
        let v_gt = vdupq_n_u8(b'>');
        let v_quot = vdupq_n_u8(b'"');
        let v_apos = vdupq_n_u8(b'\'');

        for chunk in chunks.by_ref() {
            // One load feeds all four comparisons.
            let data = vld1q_u8(chunk.as_ptr());

            masks.lt |= u64::from(neon_movemask(vceqq_u8(data, v_lt))) << offset;
            masks.gt |= u64::from(neon_movemask(vceqq_u8(data, v_gt))) << offset;
            masks.quot |= u64::from(neon_movemask(vceqq_u8(data, v_quot))) << offset;
            masks.apos |= u64::from(neon_movemask(vceqq_u8(data, v_apos))) << offset;

            offset += 16;
        }
    }

    // Fewer than 16 bytes remain; handle them one at a time.
    for (i, &b) in chunks.remainder().iter().enumerate() {
        let bit = 1u64 << (offset + i);
        match b {
            b'<' => masks.lt |= bit,
            b'>' => masks.gt |= bit,
            b'"' => masks.quot |= bit,
            b'\'' => masks.apos |= bit,
            _ => {}
        }
    }

    masks
}
304
/// Collapses a 16-lane byte vector into a 16-bit mask, one bit per lane.
///
/// Bit `i` of the result is set when lane `i` of `v` has bit `i % 8` set. That is always
/// true for the all-ones lanes produced by comparisons such as `vceqq_u8`, and never true
/// for all-zero lanes, so for comparison results bit `i` simply means "lane `i` matched".
///
/// # Safety
///
/// The caller must ensure the executing CPU supports the `neon` target feature.
#[target_feature(enable = "neon")]
#[cfg(target_arch = "aarch64")]
#[inline]
unsafe fn neon_movemask(v: uint8x16_t) -> u16 {
    // Lane `i` keeps only the bit whose weight is `1 << (i % 8)`.
    static LANE_WEIGHTS: [u8; 16] = [1, 2, 4, 8, 16, 32, 64, 128, 1, 2, 4, 8, 16, 32, 64, 128];

    // SAFETY: `LANE_WEIGHTS` is a 16-byte static, so the load reads valid memory.
    // NEON availability is this function's contract.
    unsafe {
        let weighted = vandq_u8(v, vld1q_u8(LANE_WEIGHTS.as_ptr()));

        // Each half now holds distinct powers of two (or zero) per lane, so the
        // horizontal sum is at most 255 and equals the OR of the surviving bits.
        let low = vaddv_u8(vget_low_u8(weighted));
        let high = vaddv_u8(vget_high_u8(weighted));

        u16::from(low) | (u16::from(high) << 8)
    }
}
346
/// Unit tests for the NEON routines; only built for test runs on AArch64.
#[cfg(all(test, target_arch = "aarch64"))]
mod tests {
    use super::*;
    use crate::class;

    // ---------------------------------------------------------------- find_delimiters

    /// A delimiter inside the first 16-byte chunk is reported with its position and byte.
    #[test]
    fn find_delimiters_basic() {
        let input = b"hello world <div>";
        let result = unsafe { find_delimiters(input) };
        assert_eq!(
            result,
            DelimiterResult::Found {
                pos: 12,
                byte: b'<'
            }
        );
    }

    /// Input without any delimiter byte yields `NotFound`.
    #[test]
    fn find_delimiters_not_found() {
        let input = b"hello world no delimiters here at all okay";
        let result = unsafe { find_delimiters(input) };
        assert_eq!(result, DelimiterResult::NotFound);
    }

    /// Each of the seven delimiter bytes is recognised, here in the last lane of a chunk.
    #[test]
    fn find_delimiters_all_types() {
        for &delim in b"<>&\"'=/" {
            let mut input = vec![b'x'; 20];
            input[15] = delim;
            let result = unsafe { find_delimiters(&input) };
            assert_eq!(
                result,
                DelimiterResult::Found {
                    pos: 15,
                    byte: delim
                },
                "failed for delimiter 0x{delim:02X}"
            );
        }
    }

    /// A delimiter past the last full 16-byte chunk is found via the scalar tail path.
    #[test]
    fn find_delimiters_in_tail() {
        let mut input = vec![b'x'; 25];
        input[20] = b'<';
        let result = unsafe { find_delimiters(&input) };
        assert_eq!(
            result,
            DelimiterResult::Found {
                pos: 20,
                byte: b'<'
            }
        );
    }

    /// Empty input yields `NotFound`.
    #[test]
    fn find_delimiters_empty() {
        let result = unsafe { find_delimiters(b"") };
        assert_eq!(result, DelimiterResult::NotFound);
    }

    // ----------------------------------------------------------------- classify_bytes

    /// Spot-checks the class of the first eight bytes of a mixed input.
    #[test]
    fn classify_bytes_basic() {
        let input = b"a1 <b2\t>Zz09&\"'/=\nhello world...";
        let result = unsafe { classify_bytes(input) };
        assert_eq!(result[0], class::ALPHA); // 'a'
        assert_eq!(result[1], class::DIGIT); // '1'
        assert_eq!(result[2], class::WHITESPACE); // ' '
        assert_eq!(result[3], class::DELIMITER); // '<'
        assert_eq!(result[4], class::ALPHA); // 'b'
        assert_eq!(result[5], class::DIGIT); // '2'
        assert_eq!(result[6], class::WHITESPACE); // '\t'
        assert_eq!(result[7], class::DELIMITER); // '>'
    }

    /// The NEON output equals the scalar implementation's output byte for byte.
    #[test]
    fn classify_bytes_matches_scalar() {
        let input = b"Hello <World> & \"test\" = 'value' / 123\n\r\t end";
        let neon_result = unsafe { classify_bytes(input) };
        let scalar_result = unsafe { crate::scalar::classify_bytes(input) };
        assert_eq!(neon_result, scalar_result);
    }

    /// Empty input produces an empty vector.
    #[test]
    fn classify_bytes_empty() {
        let result = unsafe { classify_bytes(b"") };
        assert!(result.is_empty());
    }

    // ---------------------------------------------------------------- skip_whitespace

    /// Five leading whitespace bytes (space, tab, LF, CR, space) are skipped.
    #[test]
    fn skip_whitespace_basic() {
        // No two spaces are adjacent, so the count survives whitespace-normalising tools.
        let result = unsafe { skip_whitespace(b" \t\n\r hello") };
        assert_eq!(result, 5);
    }

    /// An all-whitespace input longer than one chunk is skipped entirely.
    #[test]
    fn skip_whitespace_all_ws() {
        // Built from an explicit length so the byte count cannot drift from the assert.
        let input = [b' '; 20];
        let result = unsafe { skip_whitespace(&input) };
        assert_eq!(result, 20);
    }

    /// Input that starts with a non-whitespace byte skips nothing.
    #[test]
    fn skip_whitespace_none() {
        let result = unsafe { skip_whitespace(b"hello") };
        assert_eq!(result, 0);
    }

    /// Empty input skips nothing.
    #[test]
    fn skip_whitespace_empty() {
        let result = unsafe { skip_whitespace(b"") };
        assert_eq!(result, 0);
    }

    /// NEON and scalar implementations agree on a range of inputs.
    #[test]
    fn skip_whitespace_matches_scalar() {
        let inputs: &[&[u8]] = &[
            b" hello",
            b"\t\n\r world",
            b"no_leading_ws",
            b" extra",
            b"",
            b" ",
            // More than 16 leading whitespace bytes, so the vector loop runs a full
            // all-whitespace block before reaching the first word.
            b" \t \n \r \t \n \r \t \n \r \t tail",
            // Entirely whitespace and longer than one block.
            &[b' '; 20],
        ];
        for &input in inputs {
            let neon_result = unsafe { skip_whitespace(input) };
            let scalar_result = unsafe { crate::scalar::skip_whitespace(input) };
            assert_eq!(
                neon_result,
                scalar_result,
                "mismatch for input {:?}",
                std::str::from_utf8(input)
            );
        }
    }

    // -------------------------------------------------------------- compute_byte_mask

    /// A single occurrence sets exactly the bit at its index.
    #[test]
    fn compute_byte_mask_basic() {
        let input = b"hello world <div>";
        let mask = unsafe { compute_byte_mask(input, b'<') };
        assert_eq!(mask, 1 << 12);
    }

    /// One hit inside the first chunk and one in the tail are both reported.
    #[test]
    fn compute_byte_mask_multiple_hits() {
        let mut input = vec![b'x'; 20];
        input[3] = b'<';
        input[17] = b'<';
        let mask = unsafe { compute_byte_mask(&input, b'<') };
        assert_eq!(mask, (1 << 3) | (1 << 17));
    }

    /// NEON and scalar implementations agree for every delimiter byte.
    #[test]
    fn compute_byte_mask_matches_scalar() {
        let input = b"Hello <World> & \"test\" = 'value' / 123\n\r\t end!!";
        for &byte in b"<>&\"'=/" {
            let neon_result = unsafe { compute_byte_mask(input, byte) };
            let scalar_result = unsafe { crate::scalar::compute_byte_mask(input, byte) };
            assert_eq!(neon_result, scalar_result, "mismatch for byte 0x{byte:02X}");
        }
    }

    /// A full 64-byte block with hits on chunk boundaries, including bits 0 and 63.
    #[test]
    fn compute_byte_mask_64_bytes() {
        let mut input = vec![b'a'; 64];
        input[0] = b'<';
        input[15] = b'<';
        input[16] = b'<';
        input[31] = b'<';
        input[48] = b'<';
        input[63] = b'<';
        let mask = unsafe { compute_byte_mask(&input, b'<') };
        assert_eq!(
            mask,
            (1u64 << 0) | (1u64 << 15) | (1u64 << 16) | (1u64 << 31) | (1u64 << 48) | (1u64 << 63)
        );
    }

    // ------------------------------------------------------------------ neon_movemask

    /// An all-zero vector maps to an empty mask.
    #[test]
    fn neon_movemask_all_zero() {
        unsafe {
            let v = vdupq_n_u8(0);
            assert_eq!(neon_movemask(v), 0);
        }
    }

    /// An all-ones vector maps to a full 16-bit mask.
    #[test]
    fn neon_movemask_all_ones() {
        unsafe {
            let v = vdupq_n_u8(0xFF);
            assert_eq!(neon_movemask(v), 0xFFFF);
        }
    }

    /// Lanes 0 and 8 (the first lane of each 8-byte half) map to bits 0 and 8.
    #[test]
    fn neon_movemask_specific_bits() {
        unsafe {
            let mut bytes = [0u8; 16];
            bytes[0] = 0xFF;
            bytes[8] = 0xFF;
            let v = vld1q_u8(bytes.as_ptr());
            let mask = neon_movemask(v);
            assert_eq!(mask, (1 << 0) | (1 << 8));
        }
    }

    // -------------------------------------------------------------- compute_all_masks

    /// All four masks agree with the scalar implementation.
    #[test]
    fn compute_all_masks_matches_scalar() {
        let input = b"Hello <World> & \"test\" = 'value' / 123\n\r\t end!!";
        let neon_masks = unsafe { compute_all_masks(input) };
        let scalar_masks = crate::scalar::compute_all_masks_safe(input);
        assert_eq!(neon_masks.lt, scalar_masks.lt, "lt mismatch");
        assert_eq!(neon_masks.gt, scalar_masks.gt, "gt mismatch");
        assert_eq!(neon_masks.quot, scalar_masks.quot, "quot mismatch");
        assert_eq!(neon_masks.apos, scalar_masks.apos, "apos mismatch");
    }

    /// In a full 64-byte block each tracked byte lands only in its own mask.
    #[test]
    fn compute_all_masks_64_bytes() {
        let mut input = vec![b'x'; 64];
        input[0] = b'<';
        input[15] = b'>';
        input[31] = b'"';
        input[48] = b'\'';
        let masks = unsafe { compute_all_masks(&input) };
        assert_eq!(masks.lt, 1u64 << 0);
        assert_eq!(masks.gt, 1u64 << 15);
        assert_eq!(masks.quot, 1u64 << 31);
        assert_eq!(masks.apos, 1u64 << 48);
    }
}