Skip to main content

pandrs/optimized/jit/
simd_string.rs

1//! SIMD-accelerated string operations
2//!
3//! Provides high-performance string operations using SIMD instructions (AVX2/SSE2)
4//! for common string operations like case conversion, character classification,
5//! and pattern matching.
6
7use std::sync::atomic::{AtomicBool, Ordering};
8
// Feature detection cache.
//
// The three flags below are filled in by `detect_features` and read back by
// `has_avx2` / `has_sse2`. All of them start out `false`.

/// Cached answer to "does this CPU support AVX2?" (only set on x86_64).
static AVX2_SUPPORTED: AtomicBool = AtomicBool::new(false);
/// Cached answer to "does this CPU support SSE2?"; `detect_features` also
/// sets it unconditionally on aarch64.
static SSE2_SUPPORTED: AtomicBool = AtomicBool::new(false);
/// Becomes `true` once `detect_features` has populated the flags above.
static FEATURES_DETECTED: AtomicBool = AtomicBool::new(false);
13
14/// Initialize SIMD feature detection
15fn detect_features() {
16    if FEATURES_DETECTED.load(Ordering::Relaxed) {
17        return;
18    }
19
20    #[cfg(target_arch = "x86_64")]
21    {
22        AVX2_SUPPORTED.store(is_x86_feature_detected!("avx2"), Ordering::Relaxed);
23        SSE2_SUPPORTED.store(is_x86_feature_detected!("sse2"), Ordering::Relaxed);
24    }
25
26    #[cfg(target_arch = "aarch64")]
27    {
28        // ARM NEON is always available on AArch64
29        SSE2_SUPPORTED.store(true, Ordering::Relaxed);
30    }
31
32    FEATURES_DETECTED.store(true, Ordering::Relaxed);
33}
34
35/// Check if AVX2 is available
36#[inline]
37pub fn has_avx2() -> bool {
38    detect_features();
39    AVX2_SUPPORTED.load(Ordering::Relaxed)
40}
41
42/// Check if SSE2 is available
43#[inline]
44pub fn has_sse2() -> bool {
45    detect_features();
46    SSE2_SUPPORTED.load(Ordering::Relaxed)
47}
48
49// ============================================================================
50// ASCII Detection
51// ============================================================================
52
53/// Check if a string is ASCII-only using SIMD
54///
55/// Returns true if all bytes are in the ASCII range (0-127)
56pub fn is_ascii_simd(s: &str) -> bool {
57    let bytes = s.as_bytes();
58
59    #[cfg(target_arch = "x86_64")]
60    {
61        if has_avx2() && bytes.len() >= 32 {
62            return unsafe { is_ascii_avx2(bytes) };
63        }
64        if has_sse2() && bytes.len() >= 16 {
65            return unsafe { is_ascii_sse2(bytes) };
66        }
67    }
68
69    // Scalar fallback
70    bytes.iter().all(|&b| b < 128)
71}
72
/// AVX2 kernel: returns `true` when no byte has its high bit set.
///
/// # Safety
///
/// The caller must ensure the CPU supports AVX2.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn is_ascii_avx2(bytes: &[u8]) -> bool {
    use std::arch::x86_64::*;

    let mut blocks = bytes.chunks_exact(32);

    // Whole 32-byte blocks: movemask gathers the top bit of every lane, so a
    // non-zero mask means at least one byte is >= 128.
    for block in blocks.by_ref() {
        let lanes = _mm256_loadu_si256(block.as_ptr() as *const __m256i);
        if _mm256_movemask_epi8(lanes) != 0 {
            return false;
        }
    }

    // Leftover bytes (fewer than 32) are checked one by one.
    blocks.remainder().iter().all(|&byte| byte < 128)
}
101
/// SSE2 kernel: returns `true` when no byte has its high bit set.
///
/// # Safety
///
/// The caller must ensure the CPU supports SSE2.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse2")]
unsafe fn is_ascii_sse2(bytes: &[u8]) -> bool {
    use std::arch::x86_64::*;

    let mut blocks = bytes.chunks_exact(16);

    // Whole 16-byte blocks: a non-zero movemask means some byte is >= 128.
    for block in blocks.by_ref() {
        let lanes = _mm_loadu_si128(block.as_ptr() as *const __m128i);
        if _mm_movemask_epi8(lanes) != 0 {
            return false;
        }
    }

    // Leftover bytes (fewer than 16) are checked one by one.
    blocks.remainder().iter().all(|&byte| byte < 128)
}
130
131// ============================================================================
132// Case Conversion (ASCII fast path)
133// ============================================================================
134
135/// Convert ASCII string to uppercase using SIMD
136///
137/// For non-ASCII strings, falls back to standard conversion
138pub fn to_uppercase_simd(s: &str) -> String {
139    if is_ascii_simd(s) {
140        let bytes = s.as_bytes();
141
142        #[cfg(target_arch = "x86_64")]
143        {
144            if has_avx2() && bytes.len() >= 32 {
145                return unsafe { to_uppercase_avx2(bytes) };
146            }
147            if has_sse2() && bytes.len() >= 16 {
148                return unsafe { to_uppercase_sse2(bytes) };
149            }
150        }
151
152        // Scalar ASCII fast path
153        return to_uppercase_ascii_scalar(bytes);
154    }
155
156    // Non-ASCII fallback
157    s.to_uppercase()
158}
159
160/// Convert ASCII string to lowercase using SIMD
161pub fn to_lowercase_simd(s: &str) -> String {
162    if is_ascii_simd(s) {
163        let bytes = s.as_bytes();
164
165        #[cfg(target_arch = "x86_64")]
166        {
167            if has_avx2() && bytes.len() >= 32 {
168                return unsafe { to_lowercase_avx2(bytes) };
169            }
170            if has_sse2() && bytes.len() >= 16 {
171                return unsafe { to_lowercase_sse2(bytes) };
172            }
173        }
174
175        // Scalar ASCII fast path
176        return to_lowercase_ascii_scalar(bytes);
177    }
178
179    // Non-ASCII fallback
180    s.to_lowercase()
181}
182
/// Scalar uppercase conversion: maps `a..=z` to `A..=Z` and copies every
/// other byte unchanged.
fn to_uppercase_ascii_scalar(bytes: &[u8]) -> String {
    let upper: Vec<u8> = bytes.iter().map(|byte| byte.to_ascii_uppercase()).collect();

    // SAFETY: only ASCII letters are rewritten (to other ASCII letters), so
    // the output is valid UTF-8 whenever the input is. The callers visible in
    // this file pass the bytes of a `&str`.
    unsafe { String::from_utf8_unchecked(upper) }
}
194
/// Scalar lowercase conversion: maps `A..=Z` to `a..=z` and copies every
/// other byte unchanged.
fn to_lowercase_ascii_scalar(bytes: &[u8]) -> String {
    let lower = bytes.to_ascii_lowercase();

    // SAFETY: only ASCII letters are rewritten (to other ASCII letters), so
    // the output is valid UTF-8 whenever the input is. The callers visible in
    // this file pass the bytes of a `&str`.
    unsafe { String::from_utf8_unchecked(lower) }
}
206
/// AVX2 kernel: uppercase `a..=z`, copy every other byte unchanged.
///
/// # Safety
///
/// - The caller must ensure the CPU supports AVX2.
/// - `bytes` must be valid UTF-8 (the dispatcher in this file only passes
///   ASCII), because the result is built with `String::from_utf8_unchecked`.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn to_uppercase_avx2(bytes: &[u8]) -> String {
    use std::arch::x86_64::*;

    let len = bytes.len();
    // The buffer is filled through raw pointers and its length is set only
    // after every byte has been written, so uninitialized memory is never
    // exposed through the `Vec`.
    let mut result: Vec<u8> = Vec::with_capacity(len);
    let src = bytes.as_ptr();
    let dst = result.as_mut_ptr();

    // Exclusive bounds for the signed `>` comparisons. Bytes >= 128 are
    // negative as i8 and therefore never fall inside the range.
    let below_a = _mm256_set1_epi8((b'a' - 1) as i8);
    let above_z = _mm256_set1_epi8((b'z' + 1) as i8);
    let case_offset = _mm256_set1_epi8(32);

    let mut i = 0;

    // Process 32 bytes at a time
    while i + 32 <= len {
        let chunk = _mm256_loadu_si256(src.add(i) as *const __m256i);

        // Lanes holding a lowercase letter become all-ones.
        let is_lower = _mm256_and_si256(
            _mm256_cmpgt_epi8(chunk, below_a),
            _mm256_cmpgt_epi8(above_z, chunk),
        );

        // Subtract 32 from lowercase letters only.
        let converted = _mm256_sub_epi8(chunk, _mm256_and_si256(is_lower, case_offset));

        _mm256_storeu_si256(dst.add(i) as *mut __m256i, converted);
        i += 32;
    }

    // Process remaining bytes
    while i < len {
        dst.add(i).write(bytes[i].to_ascii_uppercase());
        i += 1;
    }

    // SAFETY: indices 0..len were all written above and `len` equals the
    // allocated capacity.
    result.set_len(len);
    String::from_utf8_unchecked(result)
}
248
/// SSE2 kernel: uppercase `a..=z`, copy every other byte unchanged.
///
/// # Safety
///
/// - The caller must ensure the CPU supports SSE2.
/// - `bytes` must be valid UTF-8 (the dispatcher in this file only passes
///   ASCII), because the result is built with `String::from_utf8_unchecked`.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse2")]
unsafe fn to_uppercase_sse2(bytes: &[u8]) -> String {
    use std::arch::x86_64::*;

    let len = bytes.len();
    // The buffer is filled through raw pointers and its length is set only
    // after every byte has been written, so uninitialized memory is never
    // exposed through the `Vec`.
    let mut result: Vec<u8> = Vec::with_capacity(len);
    let src = bytes.as_ptr();
    let dst = result.as_mut_ptr();

    // Exclusive bounds for the signed `>` comparisons. Bytes >= 128 are
    // negative as i8 and therefore never fall inside the range.
    let below_a = _mm_set1_epi8((b'a' - 1) as i8);
    let above_z = _mm_set1_epi8((b'z' + 1) as i8);
    let case_offset = _mm_set1_epi8(32);

    let mut i = 0;

    // Process 16 bytes at a time
    while i + 16 <= len {
        let chunk = _mm_loadu_si128(src.add(i) as *const __m128i);

        // Lanes holding a lowercase letter become all-ones.
        let is_lower = _mm_and_si128(
            _mm_cmpgt_epi8(chunk, below_a),
            _mm_cmpgt_epi8(above_z, chunk),
        );

        // Subtract 32 from lowercase letters only.
        let converted = _mm_sub_epi8(chunk, _mm_and_si128(is_lower, case_offset));

        _mm_storeu_si128(dst.add(i) as *mut __m128i, converted);
        i += 16;
    }

    // Process remaining bytes
    while i < len {
        dst.add(i).write(bytes[i].to_ascii_uppercase());
        i += 1;
    }

    // SAFETY: indices 0..len were all written above and `len` equals the
    // allocated capacity.
    result.set_len(len);
    String::from_utf8_unchecked(result)
}
290
/// AVX2 kernel: lowercase `A..=Z`, copy every other byte unchanged.
///
/// # Safety
///
/// - The caller must ensure the CPU supports AVX2.
/// - `bytes` must be valid UTF-8 (the dispatcher in this file only passes
///   ASCII), because the result is built with `String::from_utf8_unchecked`.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn to_lowercase_avx2(bytes: &[u8]) -> String {
    use std::arch::x86_64::*;

    let len = bytes.len();
    // The buffer is filled through raw pointers and its length is set only
    // after every byte has been written, so uninitialized memory is never
    // exposed through the `Vec`.
    let mut result: Vec<u8> = Vec::with_capacity(len);
    let src = bytes.as_ptr();
    let dst = result.as_mut_ptr();

    // Exclusive bounds for the signed `>` comparisons. Bytes >= 128 are
    // negative as i8 and therefore never fall inside the range.
    let below_a = _mm256_set1_epi8((b'A' - 1) as i8);
    let above_z = _mm256_set1_epi8((b'Z' + 1) as i8);
    let case_offset = _mm256_set1_epi8(32);

    let mut i = 0;

    // Process 32 bytes at a time
    while i + 32 <= len {
        let chunk = _mm256_loadu_si256(src.add(i) as *const __m256i);

        // Lanes holding an uppercase letter become all-ones.
        let is_upper = _mm256_and_si256(
            _mm256_cmpgt_epi8(chunk, below_a),
            _mm256_cmpgt_epi8(above_z, chunk),
        );

        // Add 32 to uppercase letters only.
        let converted = _mm256_add_epi8(chunk, _mm256_and_si256(is_upper, case_offset));

        _mm256_storeu_si256(dst.add(i) as *mut __m256i, converted);
        i += 32;
    }

    // Process remaining bytes
    while i < len {
        dst.add(i).write(bytes[i].to_ascii_lowercase());
        i += 1;
    }

    // SAFETY: indices 0..len were all written above and `len` equals the
    // allocated capacity.
    result.set_len(len);
    String::from_utf8_unchecked(result)
}
332
/// SSE2 kernel: lowercase `A..=Z`, copy every other byte unchanged.
///
/// # Safety
///
/// - The caller must ensure the CPU supports SSE2.
/// - `bytes` must be valid UTF-8 (the dispatcher in this file only passes
///   ASCII), because the result is built with `String::from_utf8_unchecked`.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse2")]
unsafe fn to_lowercase_sse2(bytes: &[u8]) -> String {
    use std::arch::x86_64::*;

    let len = bytes.len();
    // The buffer is filled through raw pointers and its length is set only
    // after every byte has been written, so uninitialized memory is never
    // exposed through the `Vec`.
    let mut result: Vec<u8> = Vec::with_capacity(len);
    let src = bytes.as_ptr();
    let dst = result.as_mut_ptr();

    // Exclusive bounds for the signed `>` comparisons. Bytes >= 128 are
    // negative as i8 and therefore never fall inside the range.
    let below_a = _mm_set1_epi8((b'A' - 1) as i8);
    let above_z = _mm_set1_epi8((b'Z' + 1) as i8);
    let case_offset = _mm_set1_epi8(32);

    let mut i = 0;

    // Process 16 bytes at a time
    while i + 16 <= len {
        let chunk = _mm_loadu_si128(src.add(i) as *const __m128i);

        // Lanes holding an uppercase letter become all-ones.
        let is_upper = _mm_and_si128(
            _mm_cmpgt_epi8(chunk, below_a),
            _mm_cmpgt_epi8(above_z, chunk),
        );

        // Add 32 to uppercase letters only.
        let converted = _mm_add_epi8(chunk, _mm_and_si128(is_upper, case_offset));

        _mm_storeu_si128(dst.add(i) as *mut __m128i, converted);
        i += 16;
    }

    // Process remaining bytes
    while i < len {
        dst.add(i).write(bytes[i].to_ascii_lowercase());
        i += 1;
    }

    // SAFETY: indices 0..len were all written above and `len` equals the
    // allocated capacity.
    result.set_len(len);
    String::from_utf8_unchecked(result)
}
374
375// ============================================================================
376// Character Classification (Batch Operations)
377// ============================================================================
378
379/// Count ASCII digits in a string using SIMD
380pub fn count_digits_simd(s: &str) -> usize {
381    let bytes = s.as_bytes();
382
383    #[cfg(target_arch = "x86_64")]
384    {
385        if has_avx2() && bytes.len() >= 32 {
386            return unsafe { count_digits_avx2(bytes) };
387        }
388        if has_sse2() && bytes.len() >= 16 {
389            return unsafe { count_digits_sse2(bytes) };
390        }
391    }
392
393    // Scalar fallback
394    bytes.iter().filter(|&&b| b >= b'0' && b <= b'9').count()
395}
396
397/// Count ASCII alphabetic characters using SIMD
398pub fn count_alpha_simd(s: &str) -> usize {
399    let bytes = s.as_bytes();
400
401    #[cfg(target_arch = "x86_64")]
402    {
403        if has_avx2() && bytes.len() >= 32 {
404            return unsafe { count_alpha_avx2(bytes) };
405        }
406        if has_sse2() && bytes.len() >= 16 {
407            return unsafe { count_alpha_sse2(bytes) };
408        }
409    }
410
411    // Scalar fallback
412    bytes
413        .iter()
414        .filter(|&&b| (b >= b'a' && b <= b'z') || (b >= b'A' && b <= b'Z'))
415        .count()
416}
417
418/// Count whitespace characters using SIMD
419pub fn count_whitespace_simd(s: &str) -> usize {
420    let bytes = s.as_bytes();
421
422    #[cfg(target_arch = "x86_64")]
423    {
424        if has_avx2() && bytes.len() >= 32 {
425            return unsafe { count_whitespace_avx2(bytes) };
426        }
427        if has_sse2() && bytes.len() >= 16 {
428            return unsafe { count_whitespace_sse2(bytes) };
429        }
430    }
431
432    // Scalar fallback
433    bytes
434        .iter()
435        .filter(|&&b| b == b' ' || b == b'\t' || b == b'\n' || b == b'\r')
436        .count()
437}
438
/// AVX2 kernel: count bytes in `b'0'..=b'9'`.
///
/// # Safety
///
/// The caller must ensure the CPU supports AVX2.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn count_digits_avx2(bytes: &[u8]) -> usize {
    use std::arch::x86_64::*;

    // Exclusive bounds for the signed `>` comparisons. Bytes >= 128 are
    // negative as i8 and so never pass the lower bound.
    let below_zero = _mm256_set1_epi8((b'0' - 1) as i8);
    let above_nine = _mm256_set1_epi8((b'9' + 1) as i8);

    let mut total = 0usize;
    let mut blocks = bytes.chunks_exact(32);

    // Whole 32-byte blocks: one mask bit per matching lane, then popcount.
    for block in blocks.by_ref() {
        let lanes = _mm256_loadu_si256(block.as_ptr() as *const __m256i);
        let in_range = _mm256_and_si256(
            _mm256_cmpgt_epi8(lanes, below_zero),
            _mm256_cmpgt_epi8(above_nine, lanes),
        );
        total += (_mm256_movemask_epi8(in_range) as u32).count_ones() as usize;
    }

    // Leftover bytes are counted one by one.
    total + blocks.remainder().iter().filter(|byte| byte.is_ascii_digit()).count()
}
476
/// SSE2 kernel: count bytes in `b'0'..=b'9'`.
///
/// # Safety
///
/// The caller must ensure the CPU supports SSE2.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse2")]
unsafe fn count_digits_sse2(bytes: &[u8]) -> usize {
    use std::arch::x86_64::*;

    // Exclusive bounds for the signed `>` comparisons. Bytes >= 128 are
    // negative as i8 and so never pass the lower bound.
    let below_zero = _mm_set1_epi8((b'0' - 1) as i8);
    let above_nine = _mm_set1_epi8((b'9' + 1) as i8);

    let mut total = 0usize;
    let mut blocks = bytes.chunks_exact(16);

    // Whole 16-byte blocks: one mask bit per matching lane, then popcount.
    for block in blocks.by_ref() {
        let lanes = _mm_loadu_si128(block.as_ptr() as *const __m128i);
        let in_range = _mm_and_si128(
            _mm_cmpgt_epi8(lanes, below_zero),
            _mm_cmpgt_epi8(above_nine, lanes),
        );
        total += (_mm_movemask_epi8(in_range) as u32).count_ones() as usize;
    }

    // Leftover bytes are counted one by one.
    total + blocks.remainder().iter().filter(|byte| byte.is_ascii_digit()).count()
}
514
/// AVX2 kernel: count bytes in `a..=z` or `A..=Z`.
///
/// # Safety
///
/// The caller must ensure the CPU supports AVX2.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn count_alpha_avx2(bytes: &[u8]) -> usize {
    use std::arch::x86_64::*;

    // Exclusive bounds for the signed `>` comparisons of each letter range.
    let below_lower = _mm256_set1_epi8((b'a' - 1) as i8);
    let above_lower = _mm256_set1_epi8((b'z' + 1) as i8);
    let below_upper = _mm256_set1_epi8((b'A' - 1) as i8);
    let above_upper = _mm256_set1_epi8((b'Z' + 1) as i8);

    let mut total = 0usize;
    let mut blocks = bytes.chunks_exact(32);

    for block in blocks.by_ref() {
        let lanes = _mm256_loadu_si256(block.as_ptr() as *const __m256i);

        let lowercase = _mm256_and_si256(
            _mm256_cmpgt_epi8(lanes, below_lower),
            _mm256_cmpgt_epi8(above_lower, lanes),
        );
        let uppercase = _mm256_and_si256(
            _mm256_cmpgt_epi8(lanes, below_upper),
            _mm256_cmpgt_epi8(above_upper, lanes),
        );

        // A lane is a letter when it matched either range.
        let letters = _mm256_or_si256(lowercase, uppercase);
        total += (_mm256_movemask_epi8(letters) as u32).count_ones() as usize;
    }

    // Scalar remainder
    total
        + blocks
            .remainder()
            .iter()
            .filter(|byte| byte.is_ascii_alphabetic())
            .count()
}
561
/// SSE2 kernel: count bytes in `a..=z` or `A..=Z`.
///
/// # Safety
///
/// The caller must ensure the CPU supports SSE2.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse2")]
unsafe fn count_alpha_sse2(bytes: &[u8]) -> usize {
    use std::arch::x86_64::*;

    // Exclusive bounds for the signed `>` comparisons of each letter range.
    let below_lower = _mm_set1_epi8((b'a' - 1) as i8);
    let above_lower = _mm_set1_epi8((b'z' + 1) as i8);
    let below_upper = _mm_set1_epi8((b'A' - 1) as i8);
    let above_upper = _mm_set1_epi8((b'Z' + 1) as i8);

    let mut total = 0usize;
    let mut blocks = bytes.chunks_exact(16);

    for block in blocks.by_ref() {
        let lanes = _mm_loadu_si128(block.as_ptr() as *const __m128i);

        let lowercase = _mm_and_si128(
            _mm_cmpgt_epi8(lanes, below_lower),
            _mm_cmpgt_epi8(above_lower, lanes),
        );
        let uppercase = _mm_and_si128(
            _mm_cmpgt_epi8(lanes, below_upper),
            _mm_cmpgt_epi8(above_upper, lanes),
        );

        // A lane is a letter when it matched either range.
        let letters = _mm_or_si128(lowercase, uppercase);
        total += (_mm_movemask_epi8(letters) as u32).count_ones() as usize;
    }

    // Scalar remainder
    total
        + blocks
            .remainder()
            .iter()
            .filter(|byte| byte.is_ascii_alphabetic())
            .count()
}
608
/// AVX2 kernel: count space, `\t`, `\n` and `\r` bytes.
///
/// # Safety
///
/// The caller must ensure the CPU supports AVX2.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn count_whitespace_avx2(bytes: &[u8]) -> usize {
    use std::arch::x86_64::*;

    let space = _mm256_set1_epi8(b' ' as i8);
    let tab = _mm256_set1_epi8(b'\t' as i8);
    let line_feed = _mm256_set1_epi8(b'\n' as i8);
    let carriage_return = _mm256_set1_epi8(b'\r' as i8);

    let mut total = 0usize;
    let mut blocks = bytes.chunks_exact(32);

    for block in blocks.by_ref() {
        let lanes = _mm256_loadu_si256(block.as_ptr() as *const __m256i);

        // Compare against each whitespace byte and merge the four results.
        let blank = _mm256_or_si256(
            _mm256_cmpeq_epi8(lanes, space),
            _mm256_cmpeq_epi8(lanes, tab),
        );
        let line_end = _mm256_or_si256(
            _mm256_cmpeq_epi8(lanes, line_feed),
            _mm256_cmpeq_epi8(lanes, carriage_return),
        );
        let whitespace = _mm256_or_si256(blank, line_end);

        total += (_mm256_movemask_epi8(whitespace) as u32).count_ones() as usize;
    }

    // Scalar remainder
    total
        + blocks
            .remainder()
            .iter()
            .filter(|&&byte| matches!(byte, b' ' | b'\t' | b'\n' | b'\r'))
            .count()
}
652
/// SSE2 kernel: count space, `\t`, `\n` and `\r` bytes.
///
/// # Safety
///
/// The caller must ensure the CPU supports SSE2.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse2")]
unsafe fn count_whitespace_sse2(bytes: &[u8]) -> usize {
    use std::arch::x86_64::*;

    let space = _mm_set1_epi8(b' ' as i8);
    let tab = _mm_set1_epi8(b'\t' as i8);
    let line_feed = _mm_set1_epi8(b'\n' as i8);
    let carriage_return = _mm_set1_epi8(b'\r' as i8);

    let mut total = 0usize;
    let mut blocks = bytes.chunks_exact(16);

    for block in blocks.by_ref() {
        let lanes = _mm_loadu_si128(block.as_ptr() as *const __m128i);

        // Compare against each whitespace byte and merge the four results.
        let blank = _mm_or_si128(
            _mm_cmpeq_epi8(lanes, space),
            _mm_cmpeq_epi8(lanes, tab),
        );
        let line_end = _mm_or_si128(
            _mm_cmpeq_epi8(lanes, line_feed),
            _mm_cmpeq_epi8(lanes, carriage_return),
        );
        let whitespace = _mm_or_si128(blank, line_end);

        total += (_mm_movemask_epi8(whitespace) as u32).count_ones() as usize;
    }

    // Scalar remainder
    total
        + blocks
            .remainder()
            .iter()
            .filter(|&&byte| matches!(byte, b' ' | b'\t' | b'\n' | b'\r'))
            .count()
}
696
697// ============================================================================
698// Pattern Matching
699// ============================================================================
700
701/// Find first occurrence of a single byte using SIMD
702pub fn find_byte_simd(haystack: &[u8], needle: u8) -> Option<usize> {
703    #[cfg(target_arch = "x86_64")]
704    {
705        if has_avx2() && haystack.len() >= 32 {
706            return unsafe { find_byte_avx2(haystack, needle) };
707        }
708        if has_sse2() && haystack.len() >= 16 {
709            return unsafe { find_byte_sse2(haystack, needle) };
710        }
711    }
712
713    // Scalar fallback
714    haystack.iter().position(|&b| b == needle)
715}
716
717/// Count occurrences of a single byte using SIMD
718pub fn count_byte_simd(haystack: &[u8], needle: u8) -> usize {
719    #[cfg(target_arch = "x86_64")]
720    {
721        if has_avx2() && haystack.len() >= 32 {
722            return unsafe { count_byte_avx2(haystack, needle) };
723        }
724        if has_sse2() && haystack.len() >= 16 {
725            return unsafe { count_byte_sse2(haystack, needle) };
726        }
727    }
728
729    // Scalar fallback
730    haystack.iter().filter(|&&b| b == needle).count()
731}
732
/// AVX2 kernel: index of the first byte equal to `needle`, if any.
///
/// # Safety
///
/// The caller must ensure the CPU supports AVX2.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn find_byte_avx2(haystack: &[u8], needle: u8) -> Option<usize> {
    use std::arch::x86_64::*;

    let pattern = _mm256_set1_epi8(needle as i8);
    let mut blocks = haystack.chunks_exact(32);
    let mut offset = 0usize;

    for block in blocks.by_ref() {
        let lanes = _mm256_loadu_si256(block.as_ptr() as *const __m256i);
        let hits = _mm256_movemask_epi8(_mm256_cmpeq_epi8(lanes, pattern)) as u32;

        // Bit k of `hits` corresponds to byte k of the block, so the lowest
        // set bit marks the first match.
        if hits != 0 {
            return Some(offset + hits.trailing_zeros() as usize);
        }
        offset += 32;
    }

    // Scalar remainder
    blocks
        .remainder()
        .iter()
        .position(|&byte| byte == needle)
        .map(|within_tail| offset + within_tail)
}
763
/// SSE2 kernel: index of the first byte equal to `needle`, if any.
///
/// # Safety
///
/// The caller must ensure the CPU supports SSE2.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse2")]
unsafe fn find_byte_sse2(haystack: &[u8], needle: u8) -> Option<usize> {
    use std::arch::x86_64::*;

    let pattern = _mm_set1_epi8(needle as i8);
    let mut blocks = haystack.chunks_exact(16);
    let mut offset = 0usize;

    for block in blocks.by_ref() {
        let lanes = _mm_loadu_si128(block.as_ptr() as *const __m128i);
        let hits = _mm_movemask_epi8(_mm_cmpeq_epi8(lanes, pattern)) as u32;

        // Bit k of `hits` corresponds to byte k of the block, so the lowest
        // set bit marks the first match.
        if hits != 0 {
            return Some(offset + hits.trailing_zeros() as usize);
        }
        offset += 16;
    }

    // Scalar remainder
    blocks
        .remainder()
        .iter()
        .position(|&byte| byte == needle)
        .map(|within_tail| offset + within_tail)
}
794
/// AVX2 kernel: number of bytes equal to `needle`.
///
/// # Safety
///
/// The caller must ensure the CPU supports AVX2.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn count_byte_avx2(haystack: &[u8], needle: u8) -> usize {
    use std::arch::x86_64::*;

    let pattern = _mm256_set1_epi8(needle as i8);
    let mut blocks = haystack.chunks_exact(32);
    let mut total = 0usize;

    // Whole 32-byte blocks: one mask bit per matching lane, then popcount.
    for block in blocks.by_ref() {
        let lanes = _mm256_loadu_si256(block.as_ptr() as *const __m256i);
        let hits = _mm256_movemask_epi8(_mm256_cmpeq_epi8(lanes, pattern)) as u32;
        total += hits.count_ones() as usize;
    }

    // Scalar remainder
    total + blocks.remainder().iter().filter(|&&byte| byte == needle).count()
}
823
/// SSE2 kernel: number of bytes equal to `needle`.
///
/// # Safety
///
/// The caller must ensure the CPU supports SSE2.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse2")]
unsafe fn count_byte_sse2(haystack: &[u8], needle: u8) -> usize {
    use std::arch::x86_64::*;

    let pattern = _mm_set1_epi8(needle as i8);
    let mut blocks = haystack.chunks_exact(16);
    let mut total = 0usize;

    // Whole 16-byte blocks: one mask bit per matching lane, then popcount.
    for block in blocks.by_ref() {
        let lanes = _mm_loadu_si128(block.as_ptr() as *const __m128i);
        let hits = _mm_movemask_epi8(_mm_cmpeq_epi8(lanes, pattern)) as u32;
        total += hits.count_ones() as usize;
    }

    // Scalar remainder
    total + blocks.remainder().iter().filter(|&&byte| byte == needle).count()
}
852
853// ============================================================================
854// Batch Operations for Series
855// ============================================================================
856
857/// Batch uppercase conversion for a vector of strings
858pub fn batch_uppercase(strings: &[String]) -> Vec<String> {
859    strings.iter().map(|s| to_uppercase_simd(s)).collect()
860}
861
862/// Batch lowercase conversion for a vector of strings
863pub fn batch_lowercase(strings: &[String]) -> Vec<String> {
864    strings.iter().map(|s| to_lowercase_simd(s)).collect()
865}
866
867/// Batch ASCII check for a vector of strings
868pub fn batch_is_ascii(strings: &[String]) -> Vec<bool> {
869    strings.iter().map(|s| is_ascii_simd(s)).collect()
870}
871
872/// Batch digit count for a vector of strings
873pub fn batch_count_digits(strings: &[String]) -> Vec<usize> {
874    strings.iter().map(|s| count_digits_simd(s)).collect()
875}
876
877/// Batch alphabetic count for a vector of strings
878pub fn batch_count_alpha(strings: &[String]) -> Vec<usize> {
879    strings.iter().map(|s| count_alpha_simd(s)).collect()
880}
881
882/// Batch whitespace count for a vector of strings
883pub fn batch_count_whitespace(strings: &[String]) -> Vec<usize> {
884    strings.iter().map(|s| count_whitespace_simd(s)).collect()
885}
886
887// ============================================================================
888// Parallel Batch Operations (using Rayon)
889// ============================================================================
890
891/// Parallel batch uppercase conversion
892pub fn parallel_batch_uppercase(strings: &[String]) -> Vec<String> {
893    use rayon::prelude::*;
894    strings.par_iter().map(|s| to_uppercase_simd(s)).collect()
895}
896
897/// Parallel batch lowercase conversion
898pub fn parallel_batch_lowercase(strings: &[String]) -> Vec<String> {
899    use rayon::prelude::*;
900    strings.par_iter().map(|s| to_lowercase_simd(s)).collect()
901}
902
903/// Parallel batch ASCII check
904pub fn parallel_batch_is_ascii(strings: &[String]) -> Vec<bool> {
905    use rayon::prelude::*;
906    strings.par_iter().map(|s| is_ascii_simd(s)).collect()
907}
908
909// ============================================================================
910// Statistics
911// ============================================================================
912
913/// SIMD string operation statistics
914#[derive(Debug, Clone, Default)]
915pub struct SimdStringStats {
916    /// Whether AVX2 is available
917    pub avx2_available: bool,
918    /// Whether SSE2 is available
919    pub sse2_available: bool,
920    /// Number of strings processed with SIMD
921    pub simd_operations: u64,
922    /// Number of strings processed with scalar fallback
923    pub scalar_operations: u64,
924}
925
926impl SimdStringStats {
927    /// Create new stats with current feature detection
928    pub fn new() -> Self {
929        detect_features();
930        Self {
931            avx2_available: AVX2_SUPPORTED.load(Ordering::Relaxed),
932            sse2_available: SSE2_SUPPORTED.load(Ordering::Relaxed),
933            simd_operations: 0,
934            scalar_operations: 0,
935        }
936    }
937
938    /// Get the best available SIMD level
939    pub fn simd_level(&self) -> &'static str {
940        if self.avx2_available {
941            "AVX2 (256-bit)"
942        } else if self.sse2_available {
943            "SSE2 (128-bit)"
944        } else {
945            "Scalar (no SIMD)"
946        }
947    }
948}
949
950// ============================================================================
951// Tests
952// ============================================================================
953
#[cfg(test)]
mod tests {
    //! Unit tests for the SIMD string helpers.
    //!
    //! Each operation is exercised with short inputs (scalar path), inputs of
    //! 16..32 bytes (128-bit path on x86_64) and inputs of 32+ bytes (256-bit
    //! path when AVX2 is present), including bytes just outside each range.

    use super::*;

    #[test]
    fn test_is_ascii_simd() {
        assert!(is_ascii_simd("Hello, World!"));
        assert!(is_ascii_simd("12345"));
        assert!(is_ascii_simd(""));
        assert!(!is_ascii_simd("Hello, 世界!"));
        assert!(!is_ascii_simd("café"));

        // 16..32 bytes
        assert!(is_ascii_simd("abcdefghijklmnopqrst"));
        assert!(!is_ascii_simd("abcdefghijklmnopqré"));

        // Long string test
        let long_ascii = "a".repeat(1000);
        assert!(is_ascii_simd(&long_ascii));

        let long_mixed = format!("{}世界", "a".repeat(100));
        assert!(!is_ascii_simd(&long_mixed));
    }

    #[test]
    fn test_to_uppercase_simd() {
        assert_eq!(to_uppercase_simd("hello"), "HELLO");
        assert_eq!(to_uppercase_simd("Hello World"), "HELLO WORLD");
        assert_eq!(to_uppercase_simd("123abc"), "123ABC");
        assert_eq!(to_uppercase_simd(""), "");

        // 16..32 bytes
        assert_eq!(
            to_uppercase_simd("abcdefghijklmnopqrst"),
            "ABCDEFGHIJKLMNOPQRST"
        );

        // Bytes adjacent to the letter ranges must be left alone.
        let edges = "`az{@AZ[".repeat(5);
        assert_eq!(to_uppercase_simd(&edges), "`AZ{@AZ[".repeat(5));

        // Long string test
        let long = "hello ".repeat(100);
        let expected = "HELLO ".repeat(100);
        assert_eq!(to_uppercase_simd(&long), expected);
    }

    #[test]
    fn test_to_lowercase_simd() {
        assert_eq!(to_lowercase_simd("HELLO"), "hello");
        assert_eq!(to_lowercase_simd("Hello World"), "hello world");
        assert_eq!(to_lowercase_simd("123ABC"), "123abc");
        assert_eq!(to_lowercase_simd(""), "");

        // 16..32 bytes
        assert_eq!(
            to_lowercase_simd("ABCDEFGHIJKLMNOPQRST"),
            "abcdefghijklmnopqrst"
        );

        // Bytes adjacent to the letter ranges must be left alone.
        let edges = "`az{@AZ[".repeat(5);
        assert_eq!(to_lowercase_simd(&edges), "`az{@az[".repeat(5));

        // Long string test
        let long = "HELLO ".repeat(100);
        let expected = "hello ".repeat(100);
        assert_eq!(to_lowercase_simd(&long), expected);
    }

    #[test]
    fn test_count_digits_simd() {
        assert_eq!(count_digits_simd("abc123def456"), 6);
        assert_eq!(count_digits_simd("no digits"), 0);
        assert_eq!(count_digits_simd("12345678901234567890"), 20);
        assert_eq!(count_digits_simd(""), 0);

        // '/' and ':' sit directly below '0' and above '9'.
        let edges = "/0123456789:".repeat(4);
        assert_eq!(count_digits_simd(&edges), 40);

        // Non-ASCII bytes are never digits.
        let mixed = "é1".repeat(20);
        assert_eq!(count_digits_simd(&mixed), 20);

        // Long string test
        let long = "a1b2c3d4e5".repeat(50);
        assert_eq!(count_digits_simd(&long), 250);
    }

    #[test]
    fn test_count_alpha_simd() {
        assert_eq!(count_alpha_simd("abc123DEF"), 6);
        assert_eq!(count_alpha_simd("12345"), 0);
        assert_eq!(count_alpha_simd("AbCdEfGh"), 8);
        assert_eq!(count_alpha_simd(""), 0);

        // Bytes adjacent to the letter ranges are not letters.
        let edges = "@AZ[`az{".repeat(4);
        assert_eq!(count_alpha_simd(&edges), 16);

        // Only ASCII letters are counted.
        let mixed = "éa".repeat(20);
        assert_eq!(count_alpha_simd(&mixed), 20);

        // Long string test
        let long = "a1b2c3".repeat(100);
        assert_eq!(count_alpha_simd(&long), 300);
    }

    #[test]
    fn test_count_whitespace_simd() {
        assert_eq!(count_whitespace_simd("hello world"), 1);
        assert_eq!(count_whitespace_simd("a\tb\nc\rd"), 3);
        assert_eq!(count_whitespace_simd("no_whitespace"), 0);
        // "   \t\n\r   " = 3 spaces + tab + newline + cr + 3 spaces = 9
        assert_eq!(count_whitespace_simd("   \t\n\r   "), 9);

        // 16..32 bytes
        let medium = "a ".repeat(10);
        assert_eq!(count_whitespace_simd(&medium), 10);

        // Long string test
        let long = "a b ".repeat(100);
        assert_eq!(count_whitespace_simd(&long), 200);
    }

    #[test]
    fn test_find_byte_simd() {
        assert_eq!(find_byte_simd(b"hello world", b'o'), Some(4));
        assert_eq!(find_byte_simd(b"hello world", b'x'), None);
        assert_eq!(find_byte_simd(b"", b'a'), None);

        // 16..32 bytes
        let medium = "a".repeat(19) + "b";
        assert_eq!(find_byte_simd(medium.as_bytes(), b'b'), Some(19));

        // The first of several matches wins.
        let many = "aaab".repeat(20);
        assert_eq!(find_byte_simd(many.as_bytes(), b'b'), Some(3));

        // Long string test
        let long = "a".repeat(100) + "b";
        assert_eq!(find_byte_simd(long.as_bytes(), b'b'), Some(100));
    }

    #[test]
    fn test_count_byte_simd() {
        assert_eq!(count_byte_simd(b"hello world", b'o'), 2);
        assert_eq!(count_byte_simd(b"hello world", b'l'), 3);
        assert_eq!(count_byte_simd(b"hello world", b'x'), 0);

        // 16..32 bytes
        let medium = "ab".repeat(10);
        assert_eq!(count_byte_simd(medium.as_bytes(), b'a'), 10);

        // Long string test
        let long = "aba".repeat(100);
        assert_eq!(count_byte_simd(long.as_bytes(), b'a'), 200);
    }

    #[test]
    fn test_batch_operations() {
        let strings = vec![
            "hello".to_string(),
            "WORLD".to_string(),
            "Test123".to_string(),
        ];

        let upper = batch_uppercase(&strings);
        assert_eq!(upper, vec!["HELLO", "WORLD", "TEST123"]);

        let lower = batch_lowercase(&strings);
        assert_eq!(lower, vec!["hello", "world", "test123"]);

        let ascii = batch_is_ascii(&strings);
        assert_eq!(ascii, vec![true, true, true]);

        let digits = batch_count_digits(&strings);
        assert_eq!(digits, vec![0, 0, 3]);

        let alpha = batch_count_alpha(&strings);
        assert_eq!(alpha, vec![5, 5, 4]);

        let whitespace = batch_count_whitespace(&strings);
        assert_eq!(whitespace, vec![0, 0, 0]);
    }

    #[test]
    fn test_parallel_batch_matches_sequential() {
        let strings: Vec<String> = (0..64)
            .map(|i| format!("Item {} héllo WORLD", i).repeat(i % 4 + 1))
            .collect();

        assert_eq!(parallel_batch_uppercase(&strings), batch_uppercase(&strings));
        assert_eq!(parallel_batch_lowercase(&strings), batch_lowercase(&strings));
        assert_eq!(parallel_batch_is_ascii(&strings), batch_is_ascii(&strings));
    }

    #[test]
    fn test_simd_stats() {
        let stats = SimdStringStats::new();
        println!("SIMD Level: {}", stats.simd_level());
        println!("AVX2: {}", stats.avx2_available);
        println!("SSE2: {}", stats.sse2_available);

        // A fresh snapshot mirrors the detection helpers and starts at zero.
        assert_eq!(stats.avx2_available, has_avx2());
        assert_eq!(stats.sse2_available, has_sse2());
        assert_eq!(stats.simd_operations, 0);
        assert_eq!(stats.scalar_operations, 0);

        // SSE2 is part of the x86_64 baseline.
        #[cfg(target_arch = "x86_64")]
        assert!(stats.sse2_available);

        // The label depends only on the flags, with AVX2 taking precedence.
        let scalar = SimdStringStats::default();
        assert_eq!(scalar.simd_level(), "Scalar (no SIMD)");

        let sse = SimdStringStats {
            sse2_available: true,
            ..Default::default()
        };
        assert_eq!(sse.simd_level(), "SSE2 (128-bit)");

        let avx = SimdStringStats {
            avx2_available: true,
            sse2_available: true,
            ..Default::default()
        };
        assert_eq!(avx.simd_level(), "AVX2 (256-bit)");
    }

    #[test]
    fn test_non_ascii_fallback() {
        // Test that non-ASCII strings fall back gracefully
        assert_eq!(to_uppercase_simd("café"), "CAFÉ");
        assert_eq!(to_lowercase_simd("CAFÉ"), "café");
        assert_eq!(to_uppercase_simd("日本語"), "日本語");
    }
}