Skip to main content

fhp_simd/
scalar.rs

//! Portable scalar fallback for all SIMD operations.
//!
//! Every function here processes bytes one at a time. It is the baseline
//! that SIMD backends must match semantically (and beat in throughput).
5
6use crate::{DelimiterResult, classify_byte};
7
8/// Scan `haystack` for the first occurrence of any HTML delimiter.
9///
10/// Delimiters: `<`, `>`, `&`, `"`, `'`, `=`, `/`.
11///
12/// # Safety
13///
14/// This function is safe. The `unsafe fn` signature exists so that it
15/// can be stored in the same function-pointer slot as the SIMD variants
16/// (which require `unsafe` due to target-feature intrinsics).
17pub unsafe fn find_delimiters(haystack: &[u8]) -> DelimiterResult {
18    find_delimiters_safe(haystack)
19}
20
21/// Safe inner implementation of [`find_delimiters`].
22#[inline]
23pub fn find_delimiters_safe(haystack: &[u8]) -> DelimiterResult {
24    for (i, &b) in haystack.iter().enumerate() {
25        if is_delimiter(b) {
26            return DelimiterResult::Found { pos: i, byte: b };
27        }
28    }
29    DelimiterResult::NotFound
30}
31
32/// Classify each byte in `input` into a category bitmask.
33///
34/// Returns a `Vec<u8>` of the same length as `input`, where each element
35/// is one of the constants from [`crate::class`].
36///
37/// # Safety
38///
39/// This function is safe. The `unsafe fn` signature matches the SIMD
40/// dispatch slot.
41pub unsafe fn classify_bytes(input: &[u8]) -> Vec<u8> {
42    classify_bytes_safe(input)
43}
44
45/// Safe inner implementation of [`classify_bytes`].
46#[inline]
47pub fn classify_bytes_safe(input: &[u8]) -> Vec<u8> {
48    input.iter().map(|&b| classify_byte(b)).collect()
49}
50
51/// Skip leading whitespace bytes and return the byte offset of the first
52/// non-whitespace byte (or `input.len()` if the entire slice is whitespace).
53///
54/// # Safety
55///
56/// This function is safe. The `unsafe fn` signature matches the SIMD
57/// dispatch slot.
58pub unsafe fn skip_whitespace(input: &[u8]) -> usize {
59    skip_whitespace_safe(input)
60}
61
/// Safe inner implementation of [`skip_whitespace`].
///
/// Counts the run of leading ASCII-whitespace bytes. That count is also
/// the index of the first non-whitespace byte, or `input.len()` when the
/// whole slice is whitespace.
#[inline]
pub fn skip_whitespace_safe(input: &[u8]) -> usize {
    input
        .iter()
        .take_while(|byte| byte.is_ascii_whitespace())
        .count()
}
70
71/// Produce a bitmask where bit `i` is set if `block[i] == byte`.
72///
73/// Processes up to 64 bytes. Bits beyond `block.len()` are always 0.
74///
75/// # Safety
76///
77/// This function is safe. The `unsafe fn` signature matches the SIMD
78/// dispatch slot.
79pub unsafe fn compute_byte_mask(block: &[u8], byte: u8) -> u64 {
80    compute_byte_mask_safe(block, byte)
81}
82
/// Safe inner implementation of [`compute_byte_mask`].
///
/// Bit `i` of the result is set iff `block[i] == byte`. At most the first
/// 64 bytes of `block` are examined, since a `u64` has exactly 64 bit
/// positions; bytes at index 64 and beyond are ignored. Bits at positions
/// `>= block.len()` are always 0.
#[inline]
pub fn compute_byte_mask_safe(block: &[u8], byte: u8) -> u64 {
    let mut mask = 0u64;
    // Cap the scan at 64 bytes. Without the cap, `1u64 << i` for `i >= 64`
    // panics in debug builds, and in release builds (where the shift amount
    // is masked) sets the wrong bit.
    for (i, &b) in block.iter().take(u64::BITS as usize).enumerate() {
        if b == byte {
            mask |= 1u64 << i;
        }
    }
    mask
}
94
95/// Compute all seven delimiter bitmasks in a single pass over the block.
96///
97/// # Safety
98///
99/// This function is safe. The `unsafe fn` signature matches the SIMD
100/// dispatch slot.
101pub unsafe fn compute_all_masks(block: &[u8]) -> crate::AllMasks {
102    compute_all_masks_safe(block)
103}
104
105/// Safe inner implementation of [`compute_all_masks`].
106#[inline]
107pub fn compute_all_masks_safe(block: &[u8]) -> crate::AllMasks {
108    let mut masks = crate::AllMasks::default();
109    for (i, &b) in block.iter().enumerate() {
110        let bit = 1u64 << i;
111        match b {
112            b'<' => masks.lt |= bit,
113            b'>' => masks.gt |= bit,
114            b'"' => masks.quot |= bit,
115            b'\'' => masks.apos |= bit,
116            _ => {}
117        }
118    }
119    masks
120}
121
/// The seven bytes treated as HTML delimiters by this module.
///
/// `is_delimiter` reports whether `b` is one of `<`, `>`, `&`, `"`, `'`,
/// `=` or `/`.
#[inline(always)]
fn is_delimiter(b: u8) -> bool {
    b"<>&\"'=/".contains(&b)
}
127
#[cfg(test)]
mod tests {
    use super::*;
    use crate::class;

    // Safe shims over the `unsafe fn` entry points so each test body stays
    // free of `unsafe` blocks.
    //
    // SAFETY: the scalar entry points contain no unsafe operations; their
    // `unsafe` qualifier exists only to fit the SIMD dispatch slot.
    fn find(input: &[u8]) -> DelimiterResult {
        unsafe { find_delimiters(input) }
    }

    fn classify(input: &[u8]) -> Vec<u8> {
        unsafe { classify_bytes(input) }
    }

    fn skip(input: &[u8]) -> usize {
        unsafe { skip_whitespace(input) }
    }

    fn byte_mask(block: &[u8], byte: u8) -> u64 {
        unsafe { compute_byte_mask(block, byte) }
    }

    fn all_masks(block: &[u8]) -> crate::AllMasks {
        unsafe { compute_all_masks(block) }
    }

    // --- find_delimiters ---------------------------------------------------

    #[test]
    fn find_delimiters_lt() {
        let expected = DelimiterResult::Found { pos: 6, byte: b'<' };
        assert_eq!(find(b"hello <world>"), expected);
    }

    #[test]
    fn find_delimiters_amp() {
        let expected = DelimiterResult::Found { pos: 2, byte: b'&' };
        assert_eq!(find(b"a &amp; b"), expected);
    }

    #[test]
    fn find_delimiters_none() {
        assert_eq!(find(b"hello world"), DelimiterResult::NotFound);
    }

    #[test]
    fn find_delimiters_empty() {
        assert_eq!(find(b""), DelimiterResult::NotFound);
    }

    #[test]
    fn find_delimiters_first_byte() {
        let expected = DelimiterResult::Found { pos: 0, byte: b'<' };
        assert_eq!(find(b"<html>"), expected);
    }

    #[test]
    fn find_delimiters_all_types() {
        for delim in b"<>&\"'=/".iter().copied() {
            let expected = DelimiterResult::Found { pos: 2, byte: delim };
            assert_eq!(
                find(&[b'x', b'x', delim, b'x']),
                expected,
                "failed for delimiter 0x{delim:02X}"
            );
        }
    }

    // --- classify_bytes ----------------------------------------------------

    #[test]
    fn classify_bytes_mixed() {
        let result = classify(b"a1 <");
        // Expected classes for 'a', '1', ' ', '<' respectively.
        let expected = [
            class::ALPHA,
            class::DIGIT,
            class::WHITESPACE,
            class::DELIMITER,
        ];
        assert_eq!(&result[..4], &expected[..]);
    }

    #[test]
    fn classify_bytes_empty() {
        assert!(classify(b"").is_empty());
    }

    // --- skip_whitespace ---------------------------------------------------

    #[test]
    fn skip_whitespace_leading() {
        assert_eq!(skip(b"   hello"), 3);
    }

    #[test]
    fn skip_whitespace_mixed() {
        assert_eq!(skip(b" \t\n\rX"), 4);
    }

    #[test]
    fn skip_whitespace_all() {
        assert_eq!(skip(b"   "), 3);
    }

    #[test]
    fn skip_whitespace_none() {
        assert_eq!(skip(b"hello"), 0);
    }

    #[test]
    fn skip_whitespace_empty() {
        assert_eq!(skip(b""), 0);
    }

    // --- compute_byte_mask -------------------------------------------------

    #[test]
    fn compute_byte_mask_basic() {
        assert_eq!(byte_mask(b"hello <world>", b'<'), 1 << 6);
    }

    #[test]
    fn compute_byte_mask_multiple() {
        assert_eq!(byte_mask(b"a<b<c", b'<'), (1 << 1) | (1 << 3));
    }

    #[test]
    fn compute_byte_mask_none() {
        assert_eq!(byte_mask(b"hello world", b'<'), 0);
    }

    #[test]
    fn compute_byte_mask_empty() {
        assert_eq!(byte_mask(b"", b'<'), 0);
    }

    // --- compute_all_masks -------------------------------------------------

    #[test]
    fn compute_all_masks_basic() {
        let masks = all_masks(b"<div class=\"foo\">");
        assert_eq!(masks.lt, 1 << 0); // '<' at index 0
        assert_eq!(masks.gt, 1 << 16); // '>' at index 16
        assert_eq!(masks.quot, (1 << 11) | (1 << 15)); // '"' at 11 and 15
    }

    #[test]
    fn compute_all_masks_empty() {
        let masks = all_masks(b"");
        assert_eq!(masks.lt, 0);
        assert_eq!(masks.gt, 0);
    }

    #[test]
    fn compute_all_masks_matches_individual() {
        let input = b"Hello <World> & \"test\" = 'value' / 123\n\r\t end!!";
        let masks = all_masks(input);
        let pairs = [
            (masks.lt, b'<'),
            (masks.gt, b'>'),
            (masks.quot, b'"'),
            (masks.apos, b'\''),
        ];
        for (combined, byte) in pairs {
            assert_eq!(combined, byte_mask(input, byte));
        }
    }
}
276        assert_eq!(masks.gt, unsafe { compute_byte_mask(input, b'>') });
277        assert_eq!(masks.quot, unsafe { compute_byte_mask(input, b'"') });
278        assert_eq!(masks.apos, unsafe { compute_byte_mask(input, b'\'') });
279    }
280}