Skip to main content

fhp_simd/
dispatch.rs

1//! Runtime SIMD feature detection and function-pointer dispatch.
2//!
3//! On first access the best available SIMD level is detected via CPUID
4//! (x86_64) or compile-time knowledge (aarch64, where NEON is mandatory).
//! The result is cached in a `OnceLock`, so subsequent calls reduce to a single
//! atomic load.
7
8use std::sync::OnceLock;
9
10use crate::{AllMasks, DelimiterResult};
11
/// Supported SIMD instruction set level.
///
/// Variants are declared from least to most capable *within* one
/// architecture (`Scalar` < `Sse42` < `Avx2` on x86_64), and the derived
/// `PartialOrd`/`Ord` follow declaration order. `Neon` sorts after the x86_64
/// levels only because it is declared last. Comparing an x86_64 level with
/// `Neon` says nothing about relative capability, since [`detect`] never
/// reports both kinds on the same machine.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum SimdLevel {
    /// Portable scalar fallback — always available.
    Scalar,
    /// x86_64 SSE4.2 — 128-bit registers.
    Sse42,
    /// x86_64 AVX2 — 256-bit registers.
    Avx2,
    /// ARM aarch64 NEON — 128-bit registers (always available on aarch64).
    Neon,
}
24
25/// Detect the best SIMD level supported by the current CPU at runtime.
26///
27/// On x86_64 this queries CPUID; on aarch64 NEON is always present.
28/// All other architectures fall back to `Scalar`.
29pub fn detect() -> SimdLevel {
30    #[cfg(target_arch = "x86_64")]
31    {
32        if is_x86_feature_detected!("avx2") {
33            return SimdLevel::Avx2;
34        }
35        if is_x86_feature_detected!("sse4.2") {
36            return SimdLevel::Sse42;
37        }
38    }
39
40    #[cfg(target_arch = "aarch64")]
41    {
42        return SimdLevel::Neon;
43    }
44
45    #[allow(unreachable_code)]
46    SimdLevel::Scalar
47}
48
49/// Function-pointer table for the three core SIMD operations.
50///
51/// Each slot holds an `unsafe fn` because the SIMD variants require
52/// `#[target_feature]` and therefore cannot be called without `unsafe`.
53/// The scalar implementations are inherently safe but share the same
54/// signature for uniformity.
55pub struct SimdOps {
56    /// Scan for the first HTML delimiter (`<`, `>`, `&`, `"`, `'`, `=`, `/`).
57    pub find_delimiters: unsafe fn(&[u8]) -> DelimiterResult,
58
59    /// Classify every byte into a category bitmask (see [`crate::class`]).
60    pub classify_bytes: unsafe fn(&[u8]) -> Vec<u8>,
61
62    /// Return the offset of the first non-whitespace byte.
63    pub skip_whitespace: unsafe fn(&[u8]) -> usize,
64
65    /// Produce a u64 bitmask where bit `i` is set if `block[i] == byte`.
66    /// The block must be at most 64 bytes.
67    pub compute_byte_mask: unsafe fn(&[u8], u8) -> u64,
68
69    /// Compute all seven delimiter bitmasks in a single pass.
70    /// Loads each 16-byte chunk once, producing all masks simultaneously.
71    pub compute_all_masks: unsafe fn(&[u8]) -> AllMasks,
72
73    /// The SIMD level backing these function pointers.
74    pub level: SimdLevel,
75}
76
77/// Global singleton — initialised once on first access.
78static SIMD_OPS: OnceLock<SimdOps> = OnceLock::new();
79
80/// Return a reference to the auto-detected [`SimdOps`] dispatch table.
81///
82/// The detection runs exactly once (via `OnceLock`); subsequent calls
83/// are a single atomic load.
84pub fn ops() -> &'static SimdOps {
85    SIMD_OPS.get_or_init(|| {
86        let level = detect();
87        match level {
88            #[cfg(target_arch = "x86_64")]
89            SimdLevel::Avx2 => SimdOps {
90                find_delimiters: crate::avx2::find_delimiters,
91                classify_bytes: crate::avx2::classify_bytes,
92                skip_whitespace: crate::avx2::skip_whitespace,
93                compute_byte_mask: crate::avx2::compute_byte_mask,
94                compute_all_masks: crate::scalar::compute_all_masks,
95                level,
96            },
97            #[cfg(target_arch = "x86_64")]
98            SimdLevel::Sse42 => SimdOps {
99                find_delimiters: crate::sse42::find_delimiters,
100                classify_bytes: crate::sse42::classify_bytes,
101                skip_whitespace: crate::sse42::skip_whitespace,
102                compute_byte_mask: crate::sse42::compute_byte_mask,
103                compute_all_masks: crate::scalar::compute_all_masks,
104                level,
105            },
106            #[cfg(target_arch = "aarch64")]
107            SimdLevel::Neon => SimdOps {
108                find_delimiters: crate::neon::find_delimiters,
109                classify_bytes: crate::neon::classify_bytes,
110                skip_whitespace: crate::neon::skip_whitespace,
111                compute_byte_mask: crate::neon::compute_byte_mask,
112                compute_all_masks: crate::neon::compute_all_masks,
113                level,
114            },
115            // Scalar fallback for any architecture or if no SIMD detected.
116            _ => SimdOps {
117                find_delimiters: crate::scalar::find_delimiters,
118                classify_bytes: crate::scalar::classify_bytes,
119                skip_whitespace: crate::scalar::skip_whitespace,
120                compute_byte_mask: crate::scalar::compute_byte_mask,
121                compute_all_masks: crate::scalar::compute_all_masks,
122                level: SimdLevel::Scalar,
123            },
124        }
125    })
126}
127
#[cfg(test)]
mod tests {
    use super::*;

    /// The seven HTML delimiter bytes listed in the `find_delimiters` docs.
    const DELIMS: &[u8] = b"<>&\"'=/";

    /// Build a deterministic `len`-byte input cycling through letters, digits,
    /// whitespace and delimiters. Sweeping `len` across the 16/32/64-byte
    /// register-width boundaries exercises both the vector loop and the tail
    /// handling of the SIMD implementations.
    fn mixed_input(len: usize) -> Vec<u8> {
        const ALPHABET: &[u8] = b"aZ1 <x>&\"'=/\t\n9\r";
        (0..len)
            .map(|i| ALPHABET[(i * 7 + 3) % ALPHABET.len()])
            .collect()
    }

    #[test]
    fn detect_returns_valid_level() {
        let level = detect();
        // Each architecture can only produce its own subset of levels.
        // (`level >= Scalar` would be vacuous: `Scalar` is the minimum.)
        #[cfg(target_arch = "x86_64")]
        assert!(
            matches!(level, SimdLevel::Scalar | SimdLevel::Sse42 | SimdLevel::Avx2),
            "unexpected level {level:?} on x86_64"
        );
        #[cfg(target_arch = "aarch64")]
        assert_eq!(level, SimdLevel::Neon);
        #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
        assert_eq!(level, SimdLevel::Scalar);
    }

    #[test]
    fn ops_singleton_is_consistent() {
        let a = ops();
        let b = ops();
        // Same pointer — OnceLock guarantees singleton.
        assert!(std::ptr::eq(a, b));
    }

    #[test]
    fn ops_level_matches_detect() {
        let o = ops();
        let detected = detect();
        assert_eq!(o.level, detected);
    }

    #[test]
    fn dispatch_find_delimiters() {
        let o = ops();
        let input = b"hello <world>";
        let result = unsafe { (o.find_delimiters)(input) };
        assert_eq!(result, crate::DelimiterResult::Found { pos: 6, byte: b'<' });
    }

    #[test]
    fn dispatch_classify_bytes() {
        let o = ops();
        let input = b"a1 <";
        let result = unsafe { (o.classify_bytes)(input) };
        assert_eq!(result[0], crate::class::ALPHA);
        assert_eq!(result[1], crate::class::DIGIT);
        assert_eq!(result[2], crate::class::WHITESPACE);
        assert_eq!(result[3], crate::class::DELIMITER);
    }

    #[test]
    fn dispatch_skip_whitespace() {
        let o = ops();
        let input = b"   hello";
        let result = unsafe { (o.skip_whitespace)(input) };
        assert_eq!(result, 3);
    }

    #[test]
    fn dispatch_compute_byte_mask() {
        let o = ops();
        let input = b"hello <world> & end";
        let mask = unsafe { (o.compute_byte_mask)(input, b'<') };
        assert_eq!(mask, 1 << 6);
    }

    #[test]
    fn dispatch_compute_byte_mask_matches_scalar() {
        let o = ops();
        let input = b"Hello <World> & \"test\" = 'value' / 123\n\r\t end!!";
        for &byte in b"<>&\"'=/" {
            let dispatched = unsafe { (o.compute_byte_mask)(input, byte) };
            let scalar = unsafe { crate::scalar::compute_byte_mask(input, byte) };
            assert_eq!(dispatched, scalar, "mismatch for byte 0x{byte:02X}");
        }
    }

    #[test]
    fn dispatch_matches_scalar() {
        let o = ops();
        let input = b"Hello <World> & \"test\" = 'value' / 123\n\r\t end";
        let dispatched = unsafe { (o.classify_bytes)(input) };
        let scalar = unsafe { crate::scalar::classify_bytes(input) };
        assert_eq!(dispatched, scalar);
    }

    #[test]
    fn dispatch_compute_all_masks() {
        let o = ops();
        let input = b"Hello <World> & \"test\" = 'value' / 123\n\r\t end!!";
        let dispatched = unsafe { (o.compute_all_masks)(input) };
        let scalar = crate::scalar::compute_all_masks_safe(input);
        // NOTE(review): only four of the seven masks are compared here —
        // extend once the remaining `AllMasks` field names are confirmed.
        assert_eq!(dispatched.lt, scalar.lt, "lt mismatch");
        assert_eq!(dispatched.gt, scalar.gt, "gt mismatch");
        assert_eq!(dispatched.quot, scalar.quot, "quot mismatch");
        assert_eq!(dispatched.apos, scalar.apos, "apos mismatch");
    }

    #[test]
    fn dispatch_find_delimiters_matches_scalar_across_lengths() {
        let o = ops();
        // Binding through a fn pointer works whether or not the scalar
        // function itself is declared `unsafe`.
        let scalar: unsafe fn(&[u8]) -> crate::DelimiterResult = crate::scalar::find_delimiters;
        for len in [0, 1, 15, 16, 17, 31, 32, 33, 63, 64, 65, 100] {
            let mut input = vec![b'a'; len];
            // SAFETY: `ops()` selects implementations supported by this CPU;
            // the scalar implementation needs no CPU features.
            let (d, s) = unsafe { ((o.find_delimiters)(&input), scalar(&input)) };
            assert_eq!(d, s, "no-delimiter mismatch at len {len}");
            for pos in 0..len {
                for &delim in DELIMS {
                    input[pos] = delim;
                    // SAFETY: as above.
                    let (d, s) = unsafe { ((o.find_delimiters)(&input), scalar(&input)) };
                    assert_eq!(d, s, "mismatch at len {len}, pos {pos}, byte {delim:#04x}");
                    input[pos] = b'a';
                }
            }
        }
    }

    #[test]
    fn dispatch_classify_bytes_matches_scalar_across_lengths() {
        let o = ops();
        let scalar: unsafe fn(&[u8]) -> Vec<u8> = crate::scalar::classify_bytes;
        for len in 0..=130 {
            let input = mixed_input(len);
            // SAFETY: `ops()` selects implementations supported by this CPU;
            // the scalar implementation needs no CPU features.
            let (d, s) = unsafe { ((o.classify_bytes)(&input), scalar(&input)) };
            assert_eq!(d, s, "classify_bytes mismatch at len {len}");
        }
    }

    #[test]
    fn dispatch_skip_whitespace_matches_scalar_across_lengths() {
        const WS: &[u8] = b" \t\n\r";
        let o = ops();
        let scalar: unsafe fn(&[u8]) -> usize = crate::scalar::skip_whitespace;
        for run in 0..=70 {
            // Whitespace-only input, then the same run followed by a non-space.
            let mut mixed: Vec<u8> = (0..run).map(|i| WS[i % WS.len()]).collect();
            // SAFETY: `ops()` selects implementations supported by this CPU;
            // the scalar implementation needs no CPU features.
            let (d, s) = unsafe { ((o.skip_whitespace)(&mixed), scalar(&mixed)) };
            assert_eq!(d, s, "whitespace-only mismatch at run {run}");
            mixed.push(b'x');
            // SAFETY: as above.
            let (d, s) = unsafe { ((o.skip_whitespace)(&mixed), scalar(&mixed)) };
            assert_eq!(d, s, "mixed-whitespace mismatch at run {run}");

            let mut spaces = vec![b' '; run];
            spaces.push(b'x');
            // SAFETY: as above.
            let d = unsafe { (o.skip_whitespace)(&spaces) };
            assert_eq!(d, run, "wrong offset after {run} spaces");
        }
    }

    #[test]
    fn dispatch_compute_byte_mask_matches_reference_for_all_lengths() {
        let o = ops();
        let full = mixed_input(64);
        for len in 0..=64 {
            let block = &full[..len];
            for &byte in DELIMS {
                // Reference straight from the documented contract:
                // bit `i` is set iff `block[i] == byte`.
                let expected = block
                    .iter()
                    .enumerate()
                    .filter(|&(_, &b)| b == byte)
                    .fold(0u64, |mask, (i, _)| mask | (1u64 << i));
                // SAFETY: `ops()` selects implementations supported by this
                // CPU, and `block` is at most 64 bytes.
                let got = unsafe { (o.compute_byte_mask)(block, byte) };
                assert_eq!(got, expected, "mask mismatch at len {len}, byte {byte:#04x}");
            }
        }
    }
}