1use std::sync::OnceLock;
9
10use crate::{AllMasks, DelimiterResult};
11
/// SIMD instruction-set level used to pick an implementation of the byte-scanning kernels.
///
/// The derived `Ord` follows declaration order: `Scalar < Sse42 < Avx2 < Neon`.
/// `detect` only ever reports x86 levels on x86_64 and `Neon` on aarch64, so ordering
/// an x86 level against `Neon` carries no meaning.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum SimdLevel {
    /// Portable fallback implementation without SIMD.
    Scalar,
    /// x86_64 SSE4.2 implementation.
    Sse42,
    /// x86_64 AVX2 implementation.
    Avx2,
    /// AArch64 NEON implementation.
    Neon,
}
24
25pub fn detect() -> SimdLevel {
30 #[cfg(target_arch = "x86_64")]
31 {
32 if is_x86_feature_detected!("avx2") {
33 return SimdLevel::Avx2;
34 }
35 if is_x86_feature_detected!("sse4.2") {
36 return SimdLevel::Sse42;
37 }
38 }
39
40 #[cfg(target_arch = "aarch64")]
41 {
42 return SimdLevel::Neon;
43 }
44
45 #[allow(unreachable_code)]
46 SimdLevel::Scalar
47}
48
49pub struct SimdOps {
56 pub find_delimiters: unsafe fn(&[u8]) -> DelimiterResult,
58
59 pub classify_bytes: unsafe fn(&[u8]) -> Vec<u8>,
61
62 pub skip_whitespace: unsafe fn(&[u8]) -> usize,
64
65 pub compute_byte_mask: unsafe fn(&[u8], u8) -> u64,
68
69 pub compute_all_masks: unsafe fn(&[u8]) -> AllMasks,
72
73 pub level: SimdLevel,
75}
76
77static SIMD_OPS: OnceLock<SimdOps> = OnceLock::new();
79
80pub fn ops() -> &'static SimdOps {
85 SIMD_OPS.get_or_init(|| {
86 let level = detect();
87 match level {
88 #[cfg(target_arch = "x86_64")]
89 SimdLevel::Avx2 => SimdOps {
90 find_delimiters: crate::avx2::find_delimiters,
91 classify_bytes: crate::avx2::classify_bytes,
92 skip_whitespace: crate::avx2::skip_whitespace,
93 compute_byte_mask: crate::avx2::compute_byte_mask,
94 compute_all_masks: crate::scalar::compute_all_masks,
95 level,
96 },
97 #[cfg(target_arch = "x86_64")]
98 SimdLevel::Sse42 => SimdOps {
99 find_delimiters: crate::sse42::find_delimiters,
100 classify_bytes: crate::sse42::classify_bytes,
101 skip_whitespace: crate::sse42::skip_whitespace,
102 compute_byte_mask: crate::sse42::compute_byte_mask,
103 compute_all_masks: crate::scalar::compute_all_masks,
104 level,
105 },
106 #[cfg(target_arch = "aarch64")]
107 SimdLevel::Neon => SimdOps {
108 find_delimiters: crate::neon::find_delimiters,
109 classify_bytes: crate::neon::classify_bytes,
110 skip_whitespace: crate::neon::skip_whitespace,
111 compute_byte_mask: crate::neon::compute_byte_mask,
112 compute_all_masks: crate::neon::compute_all_masks,
113 level,
114 },
115 _ => SimdOps {
117 find_delimiters: crate::scalar::find_delimiters,
118 classify_bytes: crate::scalar::classify_bytes,
119 skip_whitespace: crate::scalar::skip_whitespace,
120 compute_byte_mask: crate::scalar::compute_byte_mask,
121 compute_all_masks: crate::scalar::compute_all_masks,
122 level: SimdLevel::Scalar,
123 },
124 }
125 })
126}
127
#[cfg(test)]
mod tests {
    // Checks that detection and the cached dispatch table agree, and that whichever
    // implementation `ops()` selected matches the scalar reference on the same input.
    use super::*;

    #[test]
    fn detect_returns_valid_level() {
        let level = detect();
        // `level >= SimdLevel::Scalar` holds for every variant (Scalar is the smallest
        // in the derived Ord), so it could never fail. Instead check that the level is
        // one this compilation target can actually report.
        #[cfg(target_arch = "x86_64")]
        {
            assert!(
                matches!(level, SimdLevel::Scalar | SimdLevel::Sse42 | SimdLevel::Avx2),
                "unexpected level {level:?} on x86_64"
            );
            // AVX2 must be preferred over SSE4.2 whenever the CPU supports it.
            if is_x86_feature_detected!("avx2") {
                assert_eq!(level, SimdLevel::Avx2);
            }
        }
        #[cfg(target_arch = "aarch64")]
        assert_eq!(level, SimdLevel::Neon);
        #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
        assert_eq!(level, SimdLevel::Scalar);
    }

    #[test]
    fn ops_singleton_is_consistent() {
        let a = ops();
        let b = ops();
        // Both calls must hand back the very same cached table.
        assert!(std::ptr::eq(a, b));
    }

    #[test]
    fn ops_level_matches_detect() {
        let o = ops();
        let detected = detect();
        assert_eq!(o.level, detected);
    }

    #[test]
    fn dispatch_find_delimiters() {
        let o = ops();
        let input = b"hello <world>";
        let result = unsafe { (o.find_delimiters)(input) };
        assert_eq!(result, crate::DelimiterResult::Found { pos: 6, byte: b'<' });
    }

    #[test]
    fn dispatch_classify_bytes() {
        let o = ops();
        let input = b"a1 <";
        let result = unsafe { (o.classify_bytes)(input) };
        assert_eq!(result[0], crate::class::ALPHA);
        assert_eq!(result[1], crate::class::DIGIT);
        assert_eq!(result[2], crate::class::WHITESPACE);
        assert_eq!(result[3], crate::class::DELIMITER);
    }

    #[test]
    fn dispatch_skip_whitespace() {
        let o = ops();
        // Three leading spaces, matching the expected skip count below.
        let input = b"   hello";
        let result = unsafe { (o.skip_whitespace)(input) };
        assert_eq!(result, 3);
    }

    #[test]
    fn dispatch_compute_byte_mask() {
        let o = ops();
        let input = b"hello <world> & end";
        let mask = unsafe { (o.compute_byte_mask)(input, b'<') };
        assert_eq!(mask, 1 << 6);
    }

    #[test]
    fn dispatch_compute_byte_mask_matches_scalar() {
        let o = ops();
        let input = b"Hello <World> & \"test\" = 'value' / 123\n\r\t end!!";
        for &byte in b"<>&\"'=/" {
            let dispatched = unsafe { (o.compute_byte_mask)(input, byte) };
            let scalar = unsafe { crate::scalar::compute_byte_mask(input, byte) };
            assert_eq!(dispatched, scalar, "mismatch for byte 0x{byte:02X}");
        }
    }

    #[test]
    fn dispatch_matches_scalar() {
        let o = ops();
        let input = b"Hello <World> & \"test\" = 'value' / 123\n\r\t end";
        let dispatched = unsafe { (o.classify_bytes)(input) };
        let scalar = unsafe { crate::scalar::classify_bytes(input) };
        assert_eq!(dispatched, scalar);
    }

    #[test]
    fn dispatch_compute_all_masks() {
        let o = ops();
        let input = b"Hello <World> & \"test\" = 'value' / 123\n\r\t end!!";
        let dispatched = unsafe { (o.compute_all_masks)(input) };
        let scalar = crate::scalar::compute_all_masks_safe(input);
        assert_eq!(dispatched.lt, scalar.lt, "lt mismatch");
        assert_eq!(dispatched.gt, scalar.gt, "gt mismatch");
        assert_eq!(dispatched.quot, scalar.quot, "quot mismatch");
        assert_eq!(dispatched.apos, scalar.apos, "apos mismatch");
    }
}