// polykit_core/simd_utils.rs

//! SIMD-optimized utilities for performance-critical operations.
//!
//! This module provides SIMD-accelerated functions for common operations:
//! - String comparison
//! - ASCII validation
//! - Byte searching
//! - Byte counting
//!
//! Implementations are architecture-specific and selected at compile time, with an additional
//! runtime SSE2 check on x86/x86_64. All functions fall back to scalar code when SIMD is not
//! available or the input is shorter than 16 bytes.
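//!
//! # Example
//!
//! A minimal usage sketch (marked `ignore` because the `polykit_core::simd_utils` import path
//! is assumed and may differ from where this module is actually mounted):
//!
//! ```ignore
//! use polykit_core::simd_utils::{count_byte_fast, fast_str_eq, find_byte_fast, is_ascii_fast};
//!
//! assert!(fast_str_eq("configuration_key", "configuration_key"));
//! assert!(is_ascii_fast(b"plain ascii text"));
//! assert_eq!(find_byte_fast(b"key=value", b'='), Some(3));
//! assert_eq!(count_byte_fast(b"a,b,c,d", b','), 3);
//! ```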
#[cfg(target_arch = "aarch64")]
use std::arch::aarch64::*;

#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;

#[cfg(target_arch = "x86")]
use std::arch::x86::*;
/// Fast string comparison using SIMD when available.
///
/// Uses architecture-specific SIMD instructions on:
/// - ARM64/aarch64 (Apple Silicon, ARM servers)
/// - x86_64 (Intel/AMD with SSE2)
/// - x86 (32-bit Intel/AMD with SSE2)
///
/// Falls back to standard comparison on other architectures and on strings shorter than 16 bytes.
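///
/// # Examples
///
/// A minimal sketch; the doctest is marked `ignore` because the `polykit_core::simd_utils`
/// path is assumed:
///
/// ```ignore
/// use polykit_core::simd_utils::fast_str_eq;
///
/// assert!(fast_str_eq("the quick brown fox", "the quick brown fox"));
/// assert!(!fast_str_eq("the quick brown fox", "the quick brown cat"));
/// ```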
#[inline]
pub fn fast_str_eq(a: &str, b: &str) -> bool {
    if a.len() != b.len() {
        return false;
    }

    if a.len() < 16 {
        return a == b;
    }

    #[cfg(target_arch = "aarch64")]
    {
        fast_str_eq_simd_aarch64(a.as_bytes(), b.as_bytes())
    }

    #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
    {
        if is_x86_feature_detected!("sse2") {
            unsafe { fast_str_eq_simd_x86(a.as_bytes(), b.as_bytes()) }
        } else {
            a == b
        }
    }

    #[cfg(not(any(target_arch = "aarch64", target_arch = "x86_64", target_arch = "x86")))]
    {
        a == b
    }
}
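// Shared structure of the SIMD helpers below: walk the input in 16-byte chunks with unaligned
// vector loads, then finish the remaining tail bytes with a plain scalar loop. The public
// wrappers only call these helpers for inputs of at least 16 bytes, but the helpers themselves
// handle any length.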
#[cfg(target_arch = "aarch64")]
#[inline]
fn fast_str_eq_simd_aarch64(a: &[u8], b: &[u8]) -> bool {
    let len = a.len();
    let mut offset = 0;

    unsafe {
        while offset + 16 <= len {
            let a_chunk = vld1q_u8(a.as_ptr().add(offset));
            let b_chunk = vld1q_u8(b.as_ptr().add(offset));
            let cmp = vceqq_u8(a_chunk, b_chunk);
            // vminvq_u8 reduces to the smallest lane: 255 only if every lane compared equal.
            let mask = vminvq_u8(cmp);

            if mask != 255 {
                return false;
            }

            offset += 16;
        }

        #[allow(clippy::needless_range_loop)]
        for i in offset..len {
            if a[i] != b[i] {
                return false;
            }
        }

        true
    }
}

#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
#[target_feature(enable = "sse2")]
#[inline]
unsafe fn fast_str_eq_simd_x86(a: &[u8], b: &[u8]) -> bool {
    let len = a.len();
    let mut offset = 0;

    while offset + 16 <= len {
        let a_chunk = _mm_loadu_si128(a.as_ptr().add(offset) as *const __m128i);
        let b_chunk = _mm_loadu_si128(b.as_ptr().add(offset) as *const __m128i);
        let cmp = _mm_cmpeq_epi8(a_chunk, b_chunk);
        // _mm_movemask_epi8 packs one bit per byte lane: 0xFFFF only if all 16 lanes matched.
        let mask = _mm_movemask_epi8(cmp);

        if mask != 0xFFFF {
            return false;
        }

        offset += 16;
    }

    #[allow(clippy::needless_range_loop)]
    for i in offset..len {
        if a[i] != b[i] {
            return false;
        }
    }

    true
}
/// Fast check whether a byte slice contains only ASCII bytes, using SIMD when available.
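///
/// # Examples
///
/// A minimal sketch; the doctest is marked `ignore` because the `polykit_core::simd_utils`
/// path is assumed:
///
/// ```ignore
/// use polykit_core::simd_utils::is_ascii_fast;
///
/// assert!(is_ascii_fast(b"all printable ascii here"));
/// assert!(!is_ascii_fast("naïve".as_bytes()));
/// ```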
#[inline]
pub fn is_ascii_fast(s: &[u8]) -> bool {
    if s.is_empty() {
        return true;
    }

    if s.len() < 16 {
        return s.iter().all(|&b| b < 128);
    }

    #[cfg(target_arch = "aarch64")]
    {
        is_ascii_simd_aarch64(s)
    }

    #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
    {
        if is_x86_feature_detected!("sse2") {
            unsafe { is_ascii_simd_x86(s) }
        } else {
            s.iter().all(|&b| b < 128)
        }
    }

    #[cfg(not(any(target_arch = "aarch64", target_arch = "x86_64", target_arch = "x86")))]
    {
        s.iter().all(|&b| b < 128)
    }
}

#[cfg(target_arch = "aarch64")]
#[inline]
fn is_ascii_simd_aarch64(s: &[u8]) -> bool {
    let len = s.len();
    let mut offset = 0;

    unsafe {
        // A byte is non-ASCII exactly when its high bit (0x80) is set.
        let ascii_mask = vdupq_n_u8(0x80);

        while offset + 16 <= len {
            let chunk = vld1q_u8(s.as_ptr().add(offset));
            let test = vtstq_u8(chunk, ascii_mask);
            // vmaxvq_u8 reduces to the largest lane: non-zero if any lane had the high bit set.
            let any_high = vmaxvq_u8(test);

            if any_high != 0 {
                return false;
            }

            offset += 16;
        }

        #[allow(clippy::needless_range_loop)]
        for i in offset..len {
            if s[i] >= 128 {
                return false;
            }
        }

        true
    }
}

#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
#[target_feature(enable = "sse2")]
#[inline]
unsafe fn is_ascii_simd_x86(s: &[u8]) -> bool {
    let len = s.len();
    let mut offset = 0;

    let ascii_mask = _mm_set1_epi8(0x80u8 as i8);

    while offset + 16 <= len {
        let chunk = _mm_loadu_si128(s.as_ptr().add(offset) as *const __m128i);
        let test = _mm_and_si128(chunk, ascii_mask);
        // movemask collects the high bit of every byte: non-zero means a non-ASCII byte.
        let mask = _mm_movemask_epi8(test);

        if mask != 0 {
            return false;
        }

        offset += 16;
    }

    #[allow(clippy::needless_range_loop)]
    for i in offset..len {
        if s[i] >= 128 {
            return false;
        }
    }

    true
}
/// Fast search for the first occurrence of `needle` in `haystack`, using SIMD when available.
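///
/// # Examples
///
/// A minimal sketch; the doctest is marked `ignore` because the `polykit_core::simd_utils`
/// path is assumed:
///
/// ```ignore
/// use polykit_core::simd_utils::find_byte_fast;
///
/// assert_eq!(find_byte_fast(b"name: value", b':'), Some(4));
/// assert_eq!(find_byte_fast(b"name: value", b'#'), None);
/// ```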
#[inline]
pub fn find_byte_fast(haystack: &[u8], needle: u8) -> Option<usize> {
    if haystack.is_empty() {
        return None;
    }

    if haystack.len() < 16 {
        return haystack.iter().position(|&b| b == needle);
    }

    #[cfg(target_arch = "aarch64")]
    {
        find_byte_simd_aarch64(haystack, needle)
    }

    #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
    {
        if is_x86_feature_detected!("sse2") {
            unsafe { find_byte_simd_x86(haystack, needle) }
        } else {
            haystack.iter().position(|&b| b == needle)
        }
    }

    #[cfg(not(any(target_arch = "aarch64", target_arch = "x86_64", target_arch = "x86")))]
    {
        haystack.iter().position(|&b| b == needle)
    }
}

#[cfg(target_arch = "aarch64")]
#[inline]
fn find_byte_simd_aarch64(haystack: &[u8], needle: u8) -> Option<usize> {
    let len = haystack.len();
    let mut offset = 0;

    unsafe {
        let needle_vec = vdupq_n_u8(needle);

        while offset + 16 <= len {
            let chunk = vld1q_u8(haystack.as_ptr().add(offset));
            let cmp = vceqq_u8(chunk, needle_vec);
            let mask = vmaxvq_u8(cmp);

            if mask != 0 {
                // The needle is somewhere in this 16-byte window; re-scan it to find the index.
                #[allow(clippy::needless_range_loop)]
                for i in 0..16 {
                    if haystack[offset + i] == needle {
                        return Some(offset + i);
                    }
                }
            }

            offset += 16;
        }

        #[allow(clippy::needless_range_loop)]
        for i in offset..len {
            if haystack[i] == needle {
                return Some(i);
            }
        }

        None
    }
}

#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
#[target_feature(enable = "sse2")]
#[inline]
unsafe fn find_byte_simd_x86(haystack: &[u8], needle: u8) -> Option<usize> {
    let len = haystack.len();
    let mut offset = 0;

    let needle_vec = _mm_set1_epi8(needle as i8);

    while offset + 16 <= len {
        let chunk = _mm_loadu_si128(haystack.as_ptr().add(offset) as *const __m128i);
        let cmp = _mm_cmpeq_epi8(chunk, needle_vec);
        let mask = _mm_movemask_epi8(cmp);

        if mask != 0 {
            // The needle is somewhere in this 16-byte window; re-scan it to find the index.
            #[allow(clippy::needless_range_loop)]
            for i in 0..16 {
                if haystack[offset + i] == needle {
                    return Some(offset + i);
                }
            }
        }

        offset += 16;
    }

    #[allow(clippy::needless_range_loop)]
    for i in offset..len {
        if haystack[i] == needle {
            return Some(i);
        }
    }

    None
}
/// Fast count of occurrences of `needle` in `haystack`, using SIMD when available.
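///
/// # Examples
///
/// A minimal sketch; the doctest is marked `ignore` because the `polykit_core::simd_utils`
/// path is assumed:
///
/// ```ignore
/// use polykit_core::simd_utils::count_byte_fast;
///
/// assert_eq!(count_byte_fast(b"one\ntwo\nthree\n", b'\n'), 3);
/// assert_eq!(count_byte_fast(b"no matches here", b'z'), 0);
/// ```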
#[inline]
pub fn count_byte_fast(haystack: &[u8], needle: u8) -> usize {
    if haystack.is_empty() {
        return 0;
    }

    if haystack.len() < 16 {
        return haystack.iter().filter(|&&b| b == needle).count();
    }

    #[cfg(target_arch = "aarch64")]
    {
        count_byte_simd_aarch64(haystack, needle)
    }

    #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
    {
        if is_x86_feature_detected!("sse2") {
            unsafe { count_byte_simd_x86(haystack, needle) }
        } else {
            haystack.iter().filter(|&&b| b == needle).count()
        }
    }

    #[cfg(not(any(target_arch = "aarch64", target_arch = "x86_64", target_arch = "x86")))]
    {
        haystack.iter().filter(|&&b| b == needle).count()
    }
}

#[cfg(target_arch = "aarch64")]
#[inline]
fn count_byte_simd_aarch64(haystack: &[u8], needle: u8) -> usize {
    let len = haystack.len();
    let mut offset = 0;
    let mut count = 0;

    unsafe {
        let needle_vec = vdupq_n_u8(needle);
        let ones = vdupq_n_u8(1);

        while offset + 16 <= len {
            let chunk = vld1q_u8(haystack.as_ptr().add(offset));
            let cmp = vceqq_u8(chunk, needle_vec);
            // Convert the 0xFF/0x00 comparison lanes to 1/0, then sum across the vector.
            let masked = vandq_u8(cmp, ones);
            count += vaddvq_u8(masked) as usize;
            offset += 16;
        }

        #[allow(clippy::needless_range_loop)]
        for i in offset..len {
            if haystack[i] == needle {
                count += 1;
            }
        }

        count
    }
}

#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
#[target_feature(enable = "sse2")]
#[inline]
unsafe fn count_byte_simd_x86(haystack: &[u8], needle: u8) -> usize {
    let len = haystack.len();
    let mut offset = 0;
    let mut count = 0;

    let needle_vec = _mm_set1_epi8(needle as i8);

    while offset + 16 <= len {
        let chunk = _mm_loadu_si128(haystack.as_ptr().add(offset) as *const __m128i);
        let cmp = _mm_cmpeq_epi8(chunk, needle_vec);
        let mask = _mm_movemask_epi8(cmp);
        // Each set bit in the movemask corresponds to one matching byte in the chunk.
        count += mask.count_ones() as usize;
        offset += 16;
    }

    #[allow(clippy::needless_range_loop)]
    for i in offset..len {
        if haystack[i] == needle {
            count += 1;
        }
    }

    count
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_fast_str_eq() {
        assert!(fast_str_eq("hello", "hello"));
        assert!(!fast_str_eq("hello", "world"));
        assert!(!fast_str_eq("hello", "hello!"));

        let long_str = "a".repeat(100);
        assert!(fast_str_eq(&long_str, &long_str));
        assert!(!fast_str_eq(&long_str, &"b".repeat(100)));
    }

    #[test]
    fn test_is_ascii_fast() {
        assert!(is_ascii_fast(b"hello world"));
        assert!(is_ascii_fast(b"0123456789abcdefghijklmnop"));
        assert!(!is_ascii_fast("hello 世界".as_bytes()));
    }

    #[test]
    fn test_find_byte_fast() {
        assert_eq!(find_byte_fast(b"hello", b'e'), Some(1));
        assert_eq!(find_byte_fast(b"hello world!", b'w'), Some(6));
        assert_eq!(find_byte_fast(b"hello", b'x'), None);

        let mut long_bytes = b"a".repeat(100);
        long_bytes.push(b'b');
        assert_eq!(find_byte_fast(&long_bytes, b'b'), Some(100));
    }

    #[test]
    fn test_count_byte_fast() {
        assert_eq!(count_byte_fast(b"hello", b'l'), 2);
        assert_eq!(count_byte_fast(b"aaabbbccc", b'b'), 3);
        assert_eq!(count_byte_fast(b"hello", b'x'), 0);

        let long_bytes = vec![b'a'; 100];
        assert_eq!(count_byte_fast(&long_bytes, b'a'), 100);
    }
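
    // Added sketch: boundary-condition tests that exercise the exact 16-byte chunk size and
    // the scalar tail path after the last full chunk. These use only the functions defined above.
    #[test]
    fn test_exact_chunk_boundary() {
        let sixteen = "abcdefghijklmnop"; // exactly 16 bytes: one full SIMD chunk, empty tail
        assert!(fast_str_eq(sixteen, sixteen));
        assert!(is_ascii_fast(sixteen.as_bytes()));
        assert_eq!(find_byte_fast(sixteen.as_bytes(), b'p'), Some(15));
        assert_eq!(count_byte_fast(sixteen.as_bytes(), b'a'), 1);
    }

    #[test]
    fn test_scalar_tail_path() {
        // 20 bytes: one full 16-byte chunk plus a 4-byte scalar tail.
        let mut bytes = vec![b'x'; 16];
        bytes.extend_from_slice(b"xyzy");
        assert_eq!(find_byte_fast(&bytes, b'z'), Some(18));
        assert_eq!(count_byte_fast(&bytes, b'y'), 2);
        assert!(is_ascii_fast(&bytes));

        let a = format!("{}tail", "q".repeat(16));
        let b = format!("{}tail", "q".repeat(16));
        let c = format!("{}tall", "q".repeat(16));
        assert!(fast_str_eq(&a, &b));
        assert!(!fast_str_eq(&a, &c));
    }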
}