// polykit_core/simd_utils.rs
#[cfg(target_arch = "aarch64")]
13use std::arch::aarch64::*;
14
15#[cfg(target_arch = "x86_64")]
16use std::arch::x86_64::*;
17
18#[cfg(target_arch = "x86")]
19use std::arch::x86::*;
20
21#[inline]
30pub fn fast_str_eq(a: &str, b: &str) -> bool {
31 if a.len() != b.len() {
32 return false;
33 }
34
35 if a.len() < 16 {
36 return a == b;
37 }
38
39 #[cfg(target_arch = "aarch64")]
40 {
41 fast_str_eq_simd_aarch64(a.as_bytes(), b.as_bytes())
42 }
43
44 #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
45 {
46 if is_x86_feature_detected!("sse2") {
47 unsafe { fast_str_eq_simd_x86(a.as_bytes(), b.as_bytes()) }
48 } else {
49 a == b
50 }
51 }
52
53 #[cfg(not(any(target_arch = "aarch64", target_arch = "x86_64", target_arch = "x86")))]
54 {
55 a == b
56 }
57}
58
/// NEON implementation of byte-slice equality for `fast_str_eq`.
///
/// Precondition: `a.len() == b.len()` (enforced by the caller).
#[cfg(target_arch = "aarch64")]
#[inline]
fn fast_str_eq_simd_aarch64(a: &[u8], b: &[u8]) -> bool {
    let mut lhs = a.chunks_exact(16);
    let mut rhs = b.chunks_exact(16);

    for (la, lb) in (&mut lhs).zip(&mut rhs) {
        // SAFETY: `chunks_exact(16)` guarantees both chunks are exactly
        // 16 bytes, so the unaligned 128-bit loads stay in bounds.
        let all_equal = unsafe {
            let va = vld1q_u8(la.as_ptr());
            let vb = vld1q_u8(lb.as_ptr());
            // vceqq yields 0xFF per matching lane; the horizontal minimum
            // is 0xFF only when every lane matched.
            vminvq_u8(vceqq_u8(va, vb)) == 255
        };
        if !all_equal {
            return false;
        }
    }

    // Compare the leftover (< 16 bytes) tail the scalar way.
    lhs.remainder() == rhs.remainder()
}
89
/// SSE2 implementation of byte-slice equality for `fast_str_eq`.
///
/// Precondition: `a.len() == b.len()` (enforced by the caller).
///
/// # Safety
///
/// The caller must ensure the CPU supports SSE2 (checked via
/// `is_x86_feature_detected!` before dispatching here).
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
#[target_feature(enable = "sse2")]
#[inline]
unsafe fn fast_str_eq_simd_x86(a: &[u8], b: &[u8]) -> bool {
    let mut lhs = a.chunks_exact(16);
    let mut rhs = b.chunks_exact(16);

    for (la, lb) in (&mut lhs).zip(&mut rhs) {
        // SAFETY: each chunk is exactly 16 bytes, so the unaligned loads
        // stay in bounds.
        let va = _mm_loadu_si128(la.as_ptr() as *const __m128i);
        let vb = _mm_loadu_si128(lb.as_ptr() as *const __m128i);
        // The per-byte compare sets matching lanes to 0xFF; movemask packs
        // the sign bits, so 0xFFFF means all 16 bytes matched.
        if _mm_movemask_epi8(_mm_cmpeq_epi8(va, vb)) != 0xFFFF {
            return false;
        }
    }

    // Compare the leftover (< 16 bytes) tail the scalar way.
    lhs.remainder() == rhs.remainder()
}
119
120#[inline]
122pub fn is_ascii_fast(s: &[u8]) -> bool {
123 if s.is_empty() {
124 return true;
125 }
126
127 if s.len() < 16 {
128 return s.iter().all(|&b| b < 128);
129 }
130
131 #[cfg(target_arch = "aarch64")]
132 {
133 is_ascii_simd_aarch64(s)
134 }
135
136 #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
137 {
138 if is_x86_feature_detected!("sse2") {
139 unsafe { is_ascii_simd_x86(s) }
140 } else {
141 s.iter().all(|&b| b < 128)
142 }
143 }
144
145 #[cfg(not(any(target_arch = "aarch64", target_arch = "x86_64", target_arch = "x86")))]
146 {
147 s.iter().all(|&b| b < 128)
148 }
149}
150
/// NEON implementation of the ASCII check for `is_ascii_fast`.
#[cfg(target_arch = "aarch64")]
#[inline]
fn is_ascii_simd_aarch64(s: &[u8]) -> bool {
    let mut chunks = s.chunks_exact(16);

    // SAFETY: every chunk produced by `chunks_exact(16)` is exactly 16
    // bytes, so the 128-bit loads stay in bounds; the splat has no memory
    // preconditions.
    unsafe {
        let high_bit = vdupq_n_u8(0x80);
        for chunk in &mut chunks {
            let v = vld1q_u8(chunk.as_ptr());
            // vtstq leaves 0xFF in any lane whose high bit is set; a
            // nonzero horizontal max means a non-ASCII byte was seen.
            if vmaxvq_u8(vtstq_u8(v, high_bit)) != 0 {
                return false;
            }
        }
    }

    // Scalar check for the < 16-byte tail.
    chunks.remainder().iter().all(|&b| b < 128)
}
182
/// SSE2 implementation of the ASCII check for `is_ascii_fast`.
///
/// # Safety
///
/// The caller must ensure the CPU supports SSE2 (checked via
/// `is_x86_feature_detected!` before dispatching here).
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
#[target_feature(enable = "sse2")]
#[inline]
unsafe fn is_ascii_simd_x86(s: &[u8]) -> bool {
    let len = s.len();
    let mut offset = 0;

    while offset + 16 <= len {
        // SAFETY: `offset + 16 <= len` keeps the unaligned load in bounds.
        let chunk = _mm_loadu_si128(s.as_ptr().add(offset) as *const __m128i);
        // `_mm_movemask_epi8` collects the sign (high) bit of every byte
        // directly, so the previous AND against a 0x80 mask was redundant:
        // a nonzero mask already means some byte is >= 0x80.
        if _mm_movemask_epi8(chunk) != 0 {
            return false;
        }

        offset += 16;
    }

    // Scalar check for the < 16-byte tail.
    s[offset..].iter().all(|&b| b < 128)
}
213
214#[inline]
216pub fn find_byte_fast(haystack: &[u8], needle: u8) -> Option<usize> {
217 if haystack.is_empty() {
218 return None;
219 }
220
221 if haystack.len() < 16 {
222 return haystack.iter().position(|&b| b == needle);
223 }
224
225 #[cfg(target_arch = "aarch64")]
226 {
227 find_byte_simd_aarch64(haystack, needle)
228 }
229
230 #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
231 {
232 if is_x86_feature_detected!("sse2") {
233 unsafe { find_byte_simd_x86(haystack, needle) }
234 } else {
235 haystack.iter().position(|&b| b == needle)
236 }
237 }
238
239 #[cfg(not(any(target_arch = "aarch64", target_arch = "x86_64", target_arch = "x86")))]
240 {
241 haystack.iter().position(|&b| b == needle)
242 }
243}
244
/// NEON implementation of the byte search for `find_byte_fast`.
#[cfg(target_arch = "aarch64")]
#[inline]
fn find_byte_simd_aarch64(haystack: &[u8], needle: u8) -> Option<usize> {
    let mut chunks = haystack.chunks_exact(16);
    // SAFETY: splatting a byte into a vector register has no memory
    // preconditions.
    let target = unsafe { vdupq_n_u8(needle) };

    for (chunk_idx, chunk) in (&mut chunks).enumerate() {
        // SAFETY: each chunk is exactly 16 bytes, so the load is in bounds.
        let any_hit = unsafe {
            let v = vld1q_u8(chunk.as_ptr());
            vmaxvq_u8(vceqq_u8(v, target)) != 0
        };
        if any_hit {
            // The vector test only says "somewhere in these 16 bytes";
            // pin down the exact position with a scalar scan.
            let base = chunk_idx * 16;
            return chunk.iter().position(|&b| b == needle).map(|i| base + i);
        }
    }

    // Scalar scan of the < 16-byte tail.
    let tail = chunks.remainder();
    let tail_start = haystack.len() - tail.len();
    tail.iter().position(|&b| b == needle).map(|i| tail_start + i)
}
281
/// SSE2 implementation of the byte search for `find_byte_fast`.
///
/// # Safety
///
/// The caller must ensure the CPU supports SSE2 (checked via
/// `is_x86_feature_detected!` before dispatching here).
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
#[target_feature(enable = "sse2")]
#[inline]
unsafe fn find_byte_simd_x86(haystack: &[u8], needle: u8) -> Option<usize> {
    let len = haystack.len();
    let mut offset = 0;

    let needle_vec = _mm_set1_epi8(needle as i8);

    while offset + 16 <= len {
        // SAFETY: `offset + 16 <= len` keeps the unaligned load in bounds.
        let chunk = _mm_loadu_si128(haystack.as_ptr().add(offset) as *const __m128i);
        let cmp = _mm_cmpeq_epi8(chunk, needle_vec);
        // One mask bit per byte: bit i is set iff byte i matched.
        let mask = _mm_movemask_epi8(cmp);

        if mask != 0 {
            // The lowest set bit marks the first match, so the index falls
            // out of `trailing_zeros` — no scalar rescan of the window.
            return Some(offset + mask.trailing_zeros() as usize);
        }

        offset += 16;
    }

    // Scalar scan of the < 16-byte tail.
    haystack[offset..]
        .iter()
        .position(|&b| b == needle)
        .map(|i| offset + i)
}
317
318#[inline]
320pub fn count_byte_fast(haystack: &[u8], needle: u8) -> usize {
321 if haystack.is_empty() {
322 return 0;
323 }
324
325 if haystack.len() < 16 {
326 return haystack.iter().filter(|&&b| b == needle).count();
327 }
328
329 #[cfg(target_arch = "aarch64")]
330 {
331 count_byte_simd_aarch64(haystack, needle)
332 }
333
334 #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
335 {
336 if is_x86_feature_detected!("sse2") {
337 unsafe { count_byte_simd_x86(haystack, needle) }
338 } else {
339 haystack.iter().filter(|&&b| b == needle).count()
340 }
341 }
342
343 #[cfg(not(any(target_arch = "aarch64", target_arch = "x86_64", target_arch = "x86")))]
344 {
345 haystack.iter().filter(|&&b| b == needle).count()
346 }
347}
348
/// NEON implementation of the byte count for `count_byte_fast`.
#[cfg(target_arch = "aarch64")]
#[inline]
fn count_byte_simd_aarch64(haystack: &[u8], needle: u8) -> usize {
    let mut chunks = haystack.chunks_exact(16);
    let mut total = 0usize;

    // SAFETY: `chunks_exact(16)` guarantees every load covers exactly 16
    // in-bounds bytes; the splats have no memory preconditions.
    unsafe {
        let target = vdupq_n_u8(needle);
        let one = vdupq_n_u8(1);

        for chunk in &mut chunks {
            let v = vld1q_u8(chunk.as_ptr());
            // Matching lanes are 0xFF; AND with 1 turns each into 0 or 1,
            // so the horizontal add (at most 16, no u8 overflow) is the
            // number of hits in this chunk.
            let hits = vandq_u8(vceqq_u8(v, target), one);
            total += vaddvq_u8(hits) as usize;
        }
    }

    // Scalar count over the < 16-byte tail.
    total + chunks.remainder().iter().filter(|&&b| b == needle).count()
}
378
/// SSE2 implementation of the byte count for `count_byte_fast`.
///
/// # Safety
///
/// The caller must ensure the CPU supports SSE2 (checked via
/// `is_x86_feature_detected!` before dispatching here).
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
#[target_feature(enable = "sse2")]
#[inline]
unsafe fn count_byte_simd_x86(haystack: &[u8], needle: u8) -> usize {
    let mut chunks = haystack.chunks_exact(16);
    let mut total = 0usize;

    let target = _mm_set1_epi8(needle as i8);

    for chunk in &mut chunks {
        // SAFETY: each chunk is exactly 16 bytes, so the load is in bounds.
        let v = _mm_loadu_si128(chunk.as_ptr() as *const __m128i);
        // movemask packs one bit per matching byte; popcount gives the
        // number of matches in this chunk.
        total += _mm_movemask_epi8(_mm_cmpeq_epi8(v, target)).count_ones() as usize;
    }

    // Scalar count over the < 16-byte tail.
    total + chunks.remainder().iter().filter(|&&b| b == needle).count()
}
406
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_fast_str_eq() {
        // Short inputs: equal, different content, different length.
        assert!(fast_str_eq("hello", "hello"));
        assert!(!fast_str_eq("hello", "world"));
        assert!(!fast_str_eq("hello", "hello!"));

        // Long inputs exercise the SIMD path plus the scalar tail.
        let same = "a".repeat(100);
        assert!(fast_str_eq(&same, &same));
        assert!(!fast_str_eq(&same, &"b".repeat(100)));
    }

    #[test]
    fn test_is_ascii_fast() {
        // Short and chunk-sized ASCII inputs, then a non-ASCII one.
        assert!(is_ascii_fast(b"hello world"));
        assert!(is_ascii_fast(b"0123456789abcdefghijklmnop"));
        assert!(!is_ascii_fast("hello 世界".as_bytes()));
    }

    #[test]
    fn test_find_byte_fast() {
        assert!(find_byte_fast(b"hello", b'e') == Some(1));
        assert!(find_byte_fast(b"hello world!", b'w') == Some(6));
        assert!(find_byte_fast(b"hello", b'x') == None);

        // Needle sits past several full SIMD chunks.
        let mut data = b"a".repeat(100);
        data.push(b'b');
        assert_eq!(find_byte_fast(&data, b'b'), Some(100));
    }

    #[test]
    fn test_count_byte_fast() {
        assert_eq!(count_byte_fast(b"hello", b'l'), 2);
        assert_eq!(count_byte_fast(b"aaabbbccc", b'b'), 3);
        assert_eq!(count_byte_fast(b"hello", b'x'), 0);

        // Six full chunks plus a tail, every byte matching.
        assert_eq!(count_byte_fast(&vec![b'a'; 100], b'a'), 100);
    }
}