// fhp_simd/neon.rs
//! ARM NEON accelerated operations (128-bit, aarch64).
//!
//! NEON is always available on aarch64 — no runtime detection needed.
//! Each function processes 16 bytes at a time, matching SSE4.2 throughput.

#[cfg(target_arch = "aarch64")]
use core::arch::aarch64::*;

use crate::{DelimiterResult, classify_byte};
11/// Scan `haystack` for the first HTML delimiter using NEON (128-bit).
12///
13/// Processes 16 bytes at a time using `vceqq_u8` for each delimiter,
14/// then OR's the results and extracts a bitmask.
15///
16/// # Safety
17///
18/// Caller must ensure the CPU supports NEON (always true on aarch64).
19#[target_feature(enable = "neon")]
20#[cfg(target_arch = "aarch64")]
21pub unsafe fn find_delimiters(haystack: &[u8]) -> DelimiterResult {
22    let len = haystack.len();
23    let ptr = haystack.as_ptr();
24    let mut offset = 0;
25
26    // SAFETY: all intrinsics below require NEON, guaranteed by #[target_feature].
27    unsafe {
28        // Broadcast each delimiter into a 128-bit register (16 copies).
29        let lt = vdupq_n_u8(b'<');
30        let gt = vdupq_n_u8(b'>');
31        let amp = vdupq_n_u8(b'&');
32        let quot = vdupq_n_u8(b'"');
33        let apos = vdupq_n_u8(b'\'');
34        let eq = vdupq_n_u8(b'=');
35        let slash = vdupq_n_u8(b'/');
36
37        while offset + 16 <= len {
38            // Load 16 bytes (unaligned load is free on aarch64).
39            let chunk = vld1q_u8(ptr.add(offset));
40
41            // vceqq_u8: 0xFF where equal, 0x00 where not.
42            let cmp_lt = vceqq_u8(chunk, lt);
43            let cmp_gt = vceqq_u8(chunk, gt);
44            let cmp_amp = vceqq_u8(chunk, amp);
45            let cmp_quot = vceqq_u8(chunk, quot);
46            let cmp_apos = vceqq_u8(chunk, apos);
47            let cmp_eq = vceqq_u8(chunk, eq);
48            let cmp_slash = vceqq_u8(chunk, slash);
49
50            // OR all comparisons.
51            let combined = vorrq_u8(
52                vorrq_u8(vorrq_u8(cmp_lt, cmp_gt), vorrq_u8(cmp_amp, cmp_quot)),
53                vorrq_u8(vorrq_u8(cmp_apos, cmp_eq), cmp_slash),
54            );
55
56            // Extract bitmask: NEON doesn't have movemask, use pairwise reduction.
57            let mask = neon_movemask(combined);
58            if mask != 0 {
59                let bit_pos = mask.trailing_zeros() as usize;
60                let pos = offset + bit_pos;
61                return DelimiterResult::Found {
62                    pos,
63                    byte: *ptr.add(pos),
64                };
65            }
66            offset += 16;
67        }
68    }
69
70    // Scalar tail.
71    crate::scalar::find_delimiters_safe(&haystack[offset..]).offset_by(offset)
72}
73
74/// Classify each byte using NEON — 16 bytes at a time.
75///
76/// # Safety
77///
78/// Caller must ensure NEON support (always true on aarch64).
79#[target_feature(enable = "neon")]
80#[cfg(target_arch = "aarch64")]
81pub unsafe fn classify_bytes(input: &[u8]) -> Vec<u8> {
82    let len = input.len();
83    let mut result = Vec::with_capacity(len);
84    let ptr = input.as_ptr();
85    let out_ptr: *mut u8 = result.as_mut_ptr();
86    let mut offset = 0;
87
88    // SAFETY: all intrinsics below require NEON, guaranteed by #[target_feature].
89    // Pointer arithmetic is valid because offset < len and result has capacity >= len.
90    unsafe {
91        while offset + 16 <= len {
92            let chunk = vld1q_u8(ptr.add(offset));
93
94            // Whitespace: space, tab, newline, CR.
95            let ws_mask = vorrq_u8(
96                vorrq_u8(
97                    vceqq_u8(chunk, vdupq_n_u8(b' ')),
98                    vceqq_u8(chunk, vdupq_n_u8(b'\t')),
99                ),
100                vorrq_u8(
101                    vceqq_u8(chunk, vdupq_n_u8(b'\n')),
102                    vceqq_u8(chunk, vdupq_n_u8(b'\r')),
103                ),
104            );
105
106            // Alpha: (b | 0x20) - 'a' <= 25
107            let lower = vorrq_u8(chunk, vdupq_n_u8(0x20));
108            let sub = vsubq_u8(lower, vdupq_n_u8(b'a'));
109            // vcleq_u8: 0xFF where sub[i] <= 25
110            let alpha_mask = vcleq_u8(sub, vdupq_n_u8(25));
111
112            // Digit: b - '0' <= 9
113            let sub_d = vsubq_u8(chunk, vdupq_n_u8(b'0'));
114            let digit_mask = vcleq_u8(sub_d, vdupq_n_u8(9));
115
116            // Delimiters.
117            let delim_mask = vorrq_u8(
118                vorrq_u8(
119                    vorrq_u8(
120                        vceqq_u8(chunk, vdupq_n_u8(b'<')),
121                        vceqq_u8(chunk, vdupq_n_u8(b'>')),
122                    ),
123                    vorrq_u8(
124                        vceqq_u8(chunk, vdupq_n_u8(b'&')),
125                        vceqq_u8(chunk, vdupq_n_u8(b'"')),
126                    ),
127                ),
128                vorrq_u8(
129                    vorrq_u8(
130                        vceqq_u8(chunk, vdupq_n_u8(b'\'')),
131                        vceqq_u8(chunk, vdupq_n_u8(b'=')),
132                    ),
133                    vceqq_u8(chunk, vdupq_n_u8(b'/')),
134                ),
135            );
136
137            // Map to class constants: AND mask with the class value, then OR all.
138            let ws_class = vandq_u8(ws_mask, vdupq_n_u8(crate::class::WHITESPACE));
139            let al_class = vandq_u8(alpha_mask, vdupq_n_u8(crate::class::ALPHA));
140            let di_class = vandq_u8(digit_mask, vdupq_n_u8(crate::class::DIGIT));
141            let de_class = vandq_u8(delim_mask, vdupq_n_u8(crate::class::DELIMITER));
142
143            let combined = vorrq_u8(vorrq_u8(ws_class, al_class), vorrq_u8(di_class, de_class));
144
145            // Store 16 classified bytes.
146            vst1q_u8(out_ptr.add(offset), combined);
147            offset += 16;
148        }
149
150        // Scalar tail.
151        while offset < len {
152            *out_ptr.add(offset) = classify_byte(*ptr.add(offset));
153            offset += 1;
154        }
155
156        result.set_len(len);
157    }
158
159    result
160}
161
162/// Skip leading whitespace using NEON — 16 bytes at a time.
163///
164/// # Safety
165///
166/// Caller must ensure NEON support (always true on aarch64).
167#[target_feature(enable = "neon")]
168#[cfg(target_arch = "aarch64")]
169pub unsafe fn skip_whitespace(input: &[u8]) -> usize {
170    let len = input.len();
171    let ptr = input.as_ptr();
172    let mut offset = 0;
173
174    // SAFETY: all intrinsics below require NEON, guaranteed by #[target_feature].
175    unsafe {
176        while offset + 16 <= len {
177            let chunk = vld1q_u8(ptr.add(offset));
178
179            let ws_mask = vorrq_u8(
180                vorrq_u8(
181                    vceqq_u8(chunk, vdupq_n_u8(b' ')),
182                    vceqq_u8(chunk, vdupq_n_u8(b'\t')),
183                ),
184                vorrq_u8(
185                    vceqq_u8(chunk, vdupq_n_u8(b'\n')),
186                    vceqq_u8(chunk, vdupq_n_u8(b'\r')),
187                ),
188            );
189
190            let mask = neon_movemask(ws_mask);
191            if mask != 0xFFFF {
192                // Not all whitespace — find first non-WS.
193                let non_ws = !mask;
194                return offset + non_ws.trailing_zeros() as usize;
195            }
196            offset += 16;
197        }
198    }
199
200    // Scalar tail.
201    offset + crate::scalar::skip_whitespace_safe(&input[offset..])
202}
203
204/// Produce a bitmask where bit `i` is set if `block[i] == byte`.
205///
206/// Processes 16 bytes at a time using NEON `vceqq_u8`, then extracts
207/// a bitmask via an internal movemask helper. Handles blocks up to 64 bytes.
208///
209/// # Safety
210///
211/// Caller must ensure the CPU supports NEON (always true on aarch64).
212#[target_feature(enable = "neon")]
213#[cfg(target_arch = "aarch64")]
214pub unsafe fn compute_byte_mask(block: &[u8], byte: u8) -> u64 {
215    let len = block.len();
216    let ptr = block.as_ptr();
217    let mut result: u64 = 0;
218    let mut offset = 0;
219
220    // SAFETY: all intrinsics below require NEON, guaranteed by #[target_feature].
221    unsafe {
222        let target = vdupq_n_u8(byte);
223
224        while offset + 16 <= len {
225            let chunk = vld1q_u8(ptr.add(offset));
226            let cmp = vceqq_u8(chunk, target);
227            let mask = neon_movemask(cmp);
228            result |= (mask as u64) << offset;
229            offset += 16;
230        }
231    }
232
233    // Scalar tail.
234    while offset < len {
235        // SAFETY: offset < len, so ptr.add(offset) is valid.
236        if unsafe { *ptr.add(offset) } == byte {
237            result |= 1u64 << offset;
238        }
239        offset += 1;
240    }
241
242    result
243}
244
245/// Compute four delimiter bitmasks in a single pass over the block.
246///
247/// Loads each 16-byte chunk once and produces all 4 masks simultaneously.
248/// Only `<`, `>`, `"`, `'` are needed by the fused tokenizer pipeline.
249///
250/// # Safety
251///
252/// Caller must ensure the CPU supports NEON (always true on aarch64).
253#[target_feature(enable = "neon")]
254#[cfg(target_arch = "aarch64")]
255pub unsafe fn compute_all_masks(block: &[u8]) -> crate::AllMasks {
256    let len = block.len();
257    let ptr = block.as_ptr();
258    let mut masks = crate::AllMasks::default();
259    let mut offset = 0;
260
261    // SAFETY: all intrinsics below require NEON, guaranteed by #[target_feature].
262    unsafe {
263        // Broadcast each delimiter into a 128-bit register.
264        let v_lt = vdupq_n_u8(b'<');
265        let v_gt = vdupq_n_u8(b'>');
266        let v_quot = vdupq_n_u8(b'"');
267        let v_apos = vdupq_n_u8(b'\'');
268
269        while offset + 16 <= len {
270            // Single load per 16-byte chunk.
271            let chunk = vld1q_u8(ptr.add(offset));
272
273            // 4 comparisons on the same loaded chunk.
274            let m_lt = neon_movemask(vceqq_u8(chunk, v_lt)) as u64;
275            let m_gt = neon_movemask(vceqq_u8(chunk, v_gt)) as u64;
276            let m_quot = neon_movemask(vceqq_u8(chunk, v_quot)) as u64;
277            let m_apos = neon_movemask(vceqq_u8(chunk, v_apos)) as u64;
278
279            masks.lt |= m_lt << offset;
280            masks.gt |= m_gt << offset;
281            masks.quot |= m_quot << offset;
282            masks.apos |= m_apos << offset;
283
284            offset += 16;
285        }
286    }
287
288    // Scalar tail for remaining bytes.
289    while offset < len {
290        let b = block[offset];
291        let bit = 1u64 << offset;
292        match b {
293            b'<' => masks.lt |= bit,
294            b'>' => masks.gt |= bit,
295            b'"' => masks.quot |= bit,
296            b'\'' => masks.apos |= bit,
297            _ => {}
298        }
299        offset += 1;
300    }
301
302    masks
303}
304
305/// Emulate x86 `_mm_movemask_epi8` on NEON.
306///
307/// Takes a 128-bit vector where each byte is either 0x00 or 0xFF.
308/// Returns a `u16` where bit `i` corresponds to byte `i`'s high bit.
309///
310/// # Safety
311///
312/// Requires NEON support.
313#[target_feature(enable = "neon")]
314#[cfg(target_arch = "aarch64")]
315#[inline]
316unsafe fn neon_movemask(v: uint8x16_t) -> u16 {
317    // SAFETY: all intrinsics below require NEON, guaranteed by #[target_feature].
318    unsafe {
319        // Bit-select table: each byte contributes its corresponding bit.
320        // AND with power-of-2 per lane, then pairwise-add reduce to form bitmask.
321        // Static avoids stack allocation + load-use latency on every call.
322        static BIT_MASK: [u8; 16] = [1, 2, 4, 8, 16, 32, 64, 128, 1, 2, 4, 8, 16, 32, 64, 128];
323        let bitmask = vld1q_u8(BIT_MASK.as_ptr());
324
325        // AND: each lane is either 0 or its bit value.
326        let masked = vandq_u8(v, bitmask);
327
328        // Split into low and high 64-bit halves.
329        let lo = vget_low_u8(masked);
330        let hi = vget_high_u8(masked);
331
332        // Horizontal add: 8 bytes -> 4 u16 -> 2 u32 -> 1 u64 per half.
333        let lo_pairs = vpaddl_u8(lo);
334        let lo_quads = vpaddl_u16(lo_pairs);
335        let lo_single = vpaddl_u32(lo_quads);
336        let lo_byte = vget_lane_u64(lo_single, 0) as u8;
337
338        let hi_pairs = vpaddl_u8(hi);
339        let hi_quads = vpaddl_u16(hi_pairs);
340        let hi_single = vpaddl_u32(hi_quads);
341        let hi_byte = vget_lane_u64(hi_single, 0) as u8;
342
343        (lo_byte as u16) | ((hi_byte as u16) << 8)
344    }
345}
346
#[cfg(all(test, target_arch = "aarch64"))]
mod tests {
    use super::*;
    use crate::class;

    #[test]
    fn find_delimiters_basic() {
        // '<' at index 12 sits inside the first (and only full) 16-byte
        // chunk, so this exercises the SIMD path.
        let input = b"hello world <div>";
        let result = unsafe { find_delimiters(input) };
        assert_eq!(
            result,
            DelimiterResult::Found {
                pos: 12,
                byte: b'<'
            }
        );
    }

    #[test]
    fn find_delimiters_not_found() {
        // 42 bytes, no delimiter anywhere: SIMD chunks plus scalar tail all miss.
        let input = b"hello world no delimiters here at all okay";
        let result = unsafe { find_delimiters(input) };
        assert_eq!(result, DelimiterResult::NotFound);
    }

    #[test]
    fn find_delimiters_all_types() {
        // Every one of the 7 delimiter bytes must be detected individually.
        for &delim in b"<>&\"'=/" {
            let mut input = vec![b'x'; 20];
            input[15] = delim;
            let result = unsafe { find_delimiters(&input) };
            assert_eq!(
                result,
                DelimiterResult::Found {
                    pos: 15,
                    byte: delim
                },
                "failed for delimiter 0x{delim:02X}"
            );
        }
    }

    #[test]
    fn find_delimiters_in_tail() {
        // 25-byte input: index 20 falls past the single full 16-byte chunk,
        // so the scalar-tail path must find it.
        let mut input = vec![b'x'; 25];
        input[20] = b'<';
        let result = unsafe { find_delimiters(&input) };
        assert_eq!(
            result,
            DelimiterResult::Found {
                pos: 20,
                byte: b'<'
            }
        );
    }

    #[test]
    fn find_delimiters_empty() {
        // Empty input skips both loops entirely.
        let result = unsafe { find_delimiters(b"") };
        assert_eq!(result, DelimiterResult::NotFound);
    }

    #[test]
    fn classify_bytes_basic() {
        // 32 bytes covering all four classes within the SIMD path.
        let input = b"a1 <b2\t>Zz09&\"'/=\nhello world...";
        let result = unsafe { classify_bytes(input) };
        assert_eq!(result[0], class::ALPHA); // 'a'
        assert_eq!(result[1], class::DIGIT); // '1'
        assert_eq!(result[2], class::WHITESPACE); // ' '
        assert_eq!(result[3], class::DELIMITER); // '<'
        assert_eq!(result[4], class::ALPHA); // 'b'
        assert_eq!(result[5], class::DIGIT); // '2'
        assert_eq!(result[6], class::WHITESPACE); // '\t'
        assert_eq!(result[7], class::DELIMITER); // '>'
    }

    #[test]
    fn classify_bytes_matches_scalar() {
        // Differential test: NEON output must equal the scalar reference
        // byte-for-byte (input length exercises SIMD chunks plus tail).
        let input = b"Hello <World> & \"test\" = 'value' / 123\n\r\t end";
        let neon_result = unsafe { classify_bytes(input) };
        let scalar_result = unsafe { crate::scalar::classify_bytes(input) };
        assert_eq!(neon_result, scalar_result);
    }

    #[test]
    fn classify_bytes_empty() {
        let result = unsafe { classify_bytes(b"") };
        assert!(result.is_empty());
    }

    #[test]
    fn skip_whitespace_basic() {
        // 5 leading whitespace bytes before 'h'.
        let result = unsafe { skip_whitespace(b"   \t\nhello") };
        assert_eq!(result, 5);
    }

    #[test]
    fn skip_whitespace_all_ws() {
        // 20 spaces: one full SIMD chunk plus scalar tail, result == len.
        let result = unsafe { skip_whitespace(b"                    ") };
        assert_eq!(result, 20);
    }

    #[test]
    fn skip_whitespace_none() {
        let result = unsafe { skip_whitespace(b"hello") };
        assert_eq!(result, 0);
    }

    #[test]
    fn skip_whitespace_empty() {
        let result = unsafe { skip_whitespace(b"") };
        assert_eq!(result, 0);
    }

    #[test]
    fn skip_whitespace_matches_scalar() {
        // Differential test against the scalar reference across short and
        // long inputs, all-whitespace, and empty cases.
        let inputs: &[&[u8]] = &[
            b"   hello",
            b"\t\n\r world",
            b"no_leading_ws",
            b"                                extra",
            b"",
            b"    ",
        ];
        for &input in inputs {
            let neon_result = unsafe { skip_whitespace(input) };
            let scalar_result = unsafe { crate::scalar::skip_whitespace(input) };
            assert_eq!(
                neon_result,
                scalar_result,
                "mismatch for input {:?}",
                std::str::from_utf8(input)
            );
        }
    }

    #[test]
    fn compute_byte_mask_basic() {
        // Single '<' at index 12 -> exactly bit 12 set.
        let input = b"hello world <div>";
        let mask = unsafe { compute_byte_mask(input, b'<') };
        assert_eq!(mask, 1 << 12);
    }

    #[test]
    fn compute_byte_mask_multiple_hits() {
        // 20 bytes — crosses 16-byte boundary, so one hit lands in the SIMD
        // chunk (bit 3) and one in the scalar tail (bit 17).
        let mut input = vec![b'x'; 20];
        input[3] = b'<';
        input[17] = b'<';
        let mask = unsafe { compute_byte_mask(&input, b'<') };
        assert_eq!(mask, (1 << 3) | (1 << 17));
    }

    #[test]
    fn compute_byte_mask_matches_scalar() {
        // Differential test per delimiter byte against the scalar reference.
        let input = b"Hello <World> & \"test\" = 'value' / 123\n\r\t end!!";
        for &byte in b"<>&\"'=/" {
            let neon_result = unsafe { compute_byte_mask(input, byte) };
            let scalar_result = unsafe { crate::scalar::compute_byte_mask(input, byte) };
            assert_eq!(neon_result, scalar_result, "mismatch for byte 0x{byte:02X}");
        }
    }

    #[test]
    fn compute_byte_mask_64_bytes() {
        // Maximum supported block size: hits placed at chunk boundaries
        // (0/15/16/31/48) and at the very last representable bit (63).
        let mut input = vec![b'a'; 64];
        input[0] = b'<';
        input[15] = b'<';
        input[16] = b'<';
        input[31] = b'<';
        input[48] = b'<';
        input[63] = b'<';
        let mask = unsafe { compute_byte_mask(&input, b'<') };
        assert_eq!(
            mask,
            (1u64 << 0) | (1u64 << 15) | (1u64 << 16) | (1u64 << 31) | (1u64 << 48) | (1u64 << 63)
        );
    }

    #[test]
    fn neon_movemask_all_zero() {
        unsafe {
            let v = vdupq_n_u8(0);
            assert_eq!(neon_movemask(v), 0);
        }
    }

    #[test]
    fn neon_movemask_all_ones() {
        unsafe {
            let v = vdupq_n_u8(0xFF);
            assert_eq!(neon_movemask(v), 0xFFFF);
        }
    }

    #[test]
    fn neon_movemask_specific_bits() {
        // One 0xFF lane in each 64-bit half: checks both the low-half and
        // high-half reduction paths independently.
        unsafe {
            let mut bytes = [0u8; 16];
            bytes[0] = 0xFF;
            bytes[8] = 0xFF;
            let v = vld1q_u8(bytes.as_ptr());
            let mask = neon_movemask(v);
            assert_eq!(mask, (1 << 0) | (1 << 8));
        }
    }

    #[test]
    fn compute_all_masks_matches_scalar() {
        // Differential test: all four masks must agree with the scalar
        // reference on a mixed input.
        let input = b"Hello <World> & \"test\" = 'value' / 123\n\r\t end!!";
        let neon_masks = unsafe { compute_all_masks(input) };
        let scalar_masks = crate::scalar::compute_all_masks_safe(input);
        assert_eq!(neon_masks.lt, scalar_masks.lt, "lt mismatch");
        assert_eq!(neon_masks.gt, scalar_masks.gt, "gt mismatch");
        assert_eq!(neon_masks.quot, scalar_masks.quot, "quot mismatch");
        assert_eq!(neon_masks.apos, scalar_masks.apos, "apos mismatch");
    }

    #[test]
    fn compute_all_masks_64_bytes() {
        // Maximum supported block size with one hit per delimiter, spread
        // across all four 16-byte chunks.
        let mut input = vec![b'x'; 64];
        input[0] = b'<';
        input[15] = b'>';
        input[31] = b'"';
        input[48] = b'\'';
        let masks = unsafe { compute_all_masks(&input) };
        assert_eq!(masks.lt, 1u64 << 0);
        assert_eq!(masks.gt, 1u64 << 15);
        assert_eq!(masks.quot, 1u64 << 31);
        assert_eq!(masks.apos, 1u64 << 48);
    }
}