Skip to main content

fuzzy_regex/engine/
simd_class.rs

1// Suppress pedantic lints for SIMD code
2#![allow(clippy::wildcard_imports)]
3
4//! SIMD-accelerated character class matching.
5//!
6//! This module provides fast character class membership testing using:
7//! 1. 128-bit ASCII bitmap for O(1) single-character lookups
8//! 2. SIMD vectorized scanning for finding matches in byte slices
9
10use crate::ir::HirClass;
11use crate::parser::ast::{CharClass, CharClassItem, NamedClass};
12
/// Compact membership table for the 128 ASCII code points.
///
/// The table is split across two words: `lo` covers bytes 0-63 and `hi`
/// covers bytes 64-127. Two flags describe how a lookup interprets the
/// stored bits: `negated` flips the answer, and `matches_non_ascii` is the
/// stored answer used for anything outside the ASCII range.
#[derive(Debug, Clone, Copy)]
pub struct AsciiClassBitmap {
    /// Membership bits for bytes 0-63 (bit `n` corresponds to byte `n`).
    lo: u64,
    /// Membership bits for bytes 64-127 (bit `n` corresponds to byte `n + 64`).
    hi: u64,
    /// When `true`, lookups report the complement of the stored membership.
    negated: bool,
    /// Stored membership used for every byte >= 128 and every non-ASCII `char`.
    matches_non_ascii: bool,
}
26
27impl AsciiClassBitmap {
28    /// Create an empty bitmap (matches nothing).
29    #[must_use]
30    pub fn empty() -> Self {
31        AsciiClassBitmap {
32            lo: 0,
33            hi: 0,
34            negated: false,
35            matches_non_ascii: false,
36        }
37    }
38
39    /// Create a bitmap that matches all ASCII characters.
40    #[must_use]
41    pub fn all_ascii() -> Self {
42        AsciiClassBitmap {
43            lo: u64::MAX,
44            hi: u64::MAX,
45            negated: false,
46            matches_non_ascii: false,
47        }
48    }
49
50    /// Create a bitmap from an AST `CharClass`.
51    #[must_use]
52    pub fn from_char_class(class: &CharClass) -> Self {
53        let mut bitmap = AsciiClassBitmap::empty();
54        bitmap.negated = class.negated;
55
56        for item in &class.items {
57            match item {
58                CharClassItem::Single(ch) => {
59                    if ch.is_ascii() {
60                        bitmap.set(*ch as u8);
61                    } else {
62                        bitmap.matches_non_ascii = true;
63                    }
64                }
65                CharClassItem::Range(start, end) => {
66                    let start_byte = if start.is_ascii() { *start as u8 } else { 128 };
67                    let end_byte = if end.is_ascii() { *end as u8 } else { 127 };
68
69                    for b in start_byte..=end_byte.min(127) {
70                        bitmap.set(b);
71                    }
72                    // Check if range extends into non-ASCII
73                    if *end as u32 > 127 {
74                        bitmap.matches_non_ascii = true;
75                    }
76                }
77                CharClassItem::Named(named) => {
78                    bitmap.add_named_class(*named);
79                }
80            }
81        }
82
83        bitmap
84    }
85
86    /// Create a bitmap from an IR `HirClass`.
87    #[must_use]
88    pub fn from_hir_class(class: &HirClass) -> Self {
89        let mut bitmap = AsciiClassBitmap::empty();
90        bitmap.negated = class.negated;
91
92        // Add single characters
93        for &ch in &class.chars {
94            if ch.is_ascii() {
95                bitmap.set(ch as u8);
96            } else {
97                bitmap.matches_non_ascii = true;
98            }
99        }
100
101        // Add ranges
102        for &(start, end) in &class.ranges {
103            let start_byte = if start.is_ascii() { start as u8 } else { 128 };
104            let end_byte = if end.is_ascii() { end as u8 } else { 127 };
105
106            for b in start_byte..=end_byte.min(127) {
107                bitmap.set(b);
108            }
109            // Check if range extends into non-ASCII
110            if end as u32 > 127 {
111                bitmap.matches_non_ascii = true;
112            }
113        }
114
115        // Add named classes
116        for &named in &class.named {
117            bitmap.add_named_class(named);
118        }
119
120        bitmap
121    }
122
123    /// Add a named class to the bitmap.
124    fn add_named_class(&mut self, class: NamedClass) {
125        match class {
126            NamedClass::Digit => {
127                for b in b'0'..=b'9' {
128                    self.set(b);
129                }
130            }
131            NamedClass::NotDigit => {
132                // Set all except digits
133                for b in 0u8..=127 {
134                    if !b.is_ascii_digit() {
135                        self.set(b);
136                    }
137                }
138                self.matches_non_ascii = true;
139            }
140            NamedClass::Word => {
141                for b in b'a'..=b'z' {
142                    self.set(b);
143                }
144                for b in b'A'..=b'Z' {
145                    self.set(b);
146                }
147                for b in b'0'..=b'9' {
148                    self.set(b);
149                }
150                self.set(b'_');
151            }
152            NamedClass::NotWord => {
153                for b in 0u8..=127 {
154                    let is_word = b.is_ascii_lowercase()
155                        || b.is_ascii_uppercase()
156                        || b.is_ascii_digit()
157                        || b == b'_';
158                    if !is_word {
159                        self.set(b);
160                    }
161                }
162                self.matches_non_ascii = true;
163            }
164            NamedClass::Whitespace => {
165                self.set(b' ');
166                self.set(b'\t');
167                self.set(b'\n');
168                self.set(b'\r');
169                self.set(0x0C); // form feed
170                self.set(0x0B); // vertical tab
171            }
172            NamedClass::NotWhitespace => {
173                for b in 0u8..=127 {
174                    if !matches!(b, b' ' | b'\t' | b'\n' | b'\r' | 0x0C | 0x0B) {
175                        self.set(b);
176                    }
177                }
178                self.matches_non_ascii = true;
179            }
180            NamedClass::Any | NamedClass::AnyExceptNewline => {
181                // Set all ASCII
182                self.lo = u64::MAX;
183                self.hi = u64::MAX;
184                if matches!(class, NamedClass::AnyExceptNewline) {
185                    self.clear(b'\n');
186                    self.clear(b'\r');
187                }
188                self.matches_non_ascii = true;
189            }
190        }
191    }
192
193    /// Set a bit for the given ASCII byte.
194    #[inline]
195    fn set(&mut self, byte: u8) {
196        if byte < 64 {
197            self.lo |= 1u64 << byte;
198        } else if byte < 128 {
199            self.hi |= 1u64 << (byte - 64);
200        }
201    }
202
203    /// Clear a bit for the given ASCII byte.
204    #[inline]
205    fn clear(&mut self, byte: u8) {
206        if byte < 64 {
207            self.lo &= !(1u64 << byte);
208        } else if byte < 128 {
209            self.hi &= !(1u64 << (byte - 64));
210        }
211    }
212
213    /// Check if a byte is in the class.
214    #[inline]
215    #[must_use]
216    pub fn contains(&self, byte: u8) -> bool {
217        let in_bitmap = if byte < 64 {
218            (self.lo & (1u64 << byte)) != 0
219        } else if byte < 128 {
220            (self.hi & (1u64 << (byte - 64))) != 0
221        } else {
222            self.matches_non_ascii
223        };
224
225        if self.negated { !in_bitmap } else { in_bitmap }
226    }
227
228    /// Check if a character is in the class.
229    #[inline]
230    #[must_use]
231    pub fn contains_char(&self, ch: char) -> bool {
232        if ch.is_ascii() {
233            self.contains(ch as u8)
234        } else {
235            let in_class = self.matches_non_ascii;
236            if self.negated { !in_class } else { in_class }
237        }
238    }
239
240    /// Find the first position in the slice where any byte matches the class.
241    /// Returns None if no match is found.
242    #[must_use]
243    pub fn find_first(&self, haystack: &[u8]) -> Option<usize> {
244        // Use SIMD-optimized path for longer slices
245        #[cfg(all(target_arch = "x86_64", target_feature = "sse2"))]
246        {
247            if haystack.len() >= 16 {
248                return self.find_first_simd(haystack);
249            }
250        }
251
252        #[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
253        {
254            if haystack.len() >= 16 {
255                return self.find_first_simd(haystack);
256            }
257        }
258
259        // Scalar fallback
260        self.find_first_scalar(haystack)
261    }
262
263    /// Scalar implementation of `find_first`.
264    #[inline]
265    fn find_first_scalar(&self, haystack: &[u8]) -> Option<usize> {
266        for (i, &byte) in haystack.iter().enumerate() {
267            if self.contains(byte) {
268                return Some(i);
269            }
270        }
271        None
272    }
273
274    /// SIMD implementation for `x86_64` with SSE2.
275    #[cfg(all(target_arch = "x86_64", target_feature = "sse2"))]
276    fn find_first_simd(&self, haystack: &[u8]) -> Option<usize> {
277        use std::arch::x86_64::*;
278
279        // For negated classes or classes matching non-ASCII, fall back to scalar
280        // (SIMD path handles simpler cases more efficiently)
281        if self.negated || self.matches_non_ascii {
282            return self.find_first_scalar(haystack);
283        }
284
285        let len = haystack.len();
286        let mut i = 0;
287
288        // Process 16 bytes at a time
289        unsafe {
290            while i + 16 <= len {
291                let chunk = _mm_loadu_si128(haystack.as_ptr().add(i).cast::<__m128i>());
292
293                // Check each byte against the bitmap using lookup
294                // We use a different strategy: check if any byte is in our set
295                // by building a mask of matching positions
296                let mut mask = 0u16;
297
298                // Extract bytes and check individually (SSE2 doesn't have good gather)
299                let bytes: [u8; 16] = std::mem::transmute(chunk);
300                for (j, &b) in bytes.iter().enumerate() {
301                    if self.contains(b) {
302                        mask |= 1 << j;
303                    }
304                }
305
306                if mask != 0 {
307                    return Some(i + mask.trailing_zeros() as usize);
308                }
309
310                i += 16;
311            }
312        }
313
314        // Handle remaining bytes
315        (i..len).find(|&j| self.contains(haystack[j]))
316    }
317
318    /// SIMD implementation for aarch64 with NEON.
319    #[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
320    fn find_first_simd(&self, haystack: &[u8]) -> Option<usize> {
321        // For negated classes or classes matching non-ASCII, fall back to scalar
322        if self.negated || self.matches_non_ascii {
323            return self.find_first_scalar(haystack);
324        }
325
326        let len = haystack.len();
327        let mut i = 0;
328
329        // Process 16 bytes at a time
330        unsafe {
331            use std::arch::aarch64::*;
332
333            while i + 16 <= len {
334                let chunk = vld1q_u8(haystack.as_ptr().add(i));
335
336                // Check each byte against the bitmap
337                let bytes: [u8; 16] = std::mem::transmute(chunk);
338                for (j, &b) in bytes.iter().enumerate() {
339                    if self.contains(b) {
340                        return Some(i + j);
341                    }
342                }
343
344                i += 16;
345            }
346        }
347
348        // Handle remaining bytes
349        (i..len).find(|&j| self.contains(haystack[j]))
350    }
351
352    /// Find all positions where bytes match the class.
353    /// Returns a vector of indices.
354    #[must_use]
355    pub fn find_all(&self, haystack: &[u8]) -> Vec<usize> {
356        let mut results = Vec::new();
357        let mut pos = 0;
358
359        while pos < haystack.len() {
360            if let Some(offset) = self.find_first(&haystack[pos..]) {
361                results.push(pos + offset);
362                pos += offset + 1;
363            } else {
364                break;
365            }
366        }
367
368        results
369    }
370
371    /// Count how many bytes in the slice match the class.
372    #[must_use]
373    pub fn count_matches(&self, haystack: &[u8]) -> usize {
374        haystack.iter().filter(|&&b| self.contains(b)).count()
375    }
376
377    /// Check if the bitmap matches any byte in the slice.
378    #[inline]
379    #[must_use]
380    pub fn matches_any(&self, haystack: &[u8]) -> bool {
381        self.find_first(haystack).is_some()
382    }
383}
384
385impl Default for AsciiClassBitmap {
386    fn default() -> Self {
387        Self::empty()
388    }
389}
390
391/// A precompiled character class for fast matching.
392/// Combines bitmap for ASCII and handles non-ASCII via the original `CharClass`.
393#[derive(Clone, Debug)]
394pub struct CompiledCharClass {
395    /// Fast ASCII bitmap.
396    pub bitmap: AsciiClassBitmap,
397    /// Original char class for non-ASCII and complex cases.
398    pub original: CharClass,
399    /// Unicode mode - enable Unicode character classes.
400    pub unicode: bool,
401}
402
403impl CompiledCharClass {
404    /// Create a compiled character class.
405    #[must_use]
406    pub fn new(class: &CharClass) -> Self {
407        CompiledCharClass {
408            bitmap: AsciiClassBitmap::from_char_class(class),
409            original: class.clone(),
410            unicode: false,
411        }
412    }
413
414    /// Create a compiled character class with unicode mode.
415    #[must_use]
416    pub fn new_with_unicode(class: &CharClass, unicode: bool) -> Self {
417        CompiledCharClass {
418            bitmap: AsciiClassBitmap::from_char_class(class),
419            original: class.clone(),
420            unicode,
421        }
422    }
423
424    /// Check if a character matches this class.
425    #[inline]
426    #[must_use]
427    pub fn matches(&self, ch: char) -> bool {
428        if ch.is_ascii() {
429            self.bitmap.contains(ch as u8)
430        } else if self.unicode {
431            // In unicode mode, check if non-ASCII chars match via the original class
432            // which now has NamedClass with unicode-aware matching
433            self.original.matches_unicode(ch)
434        } else {
435            self.original.matches(ch)
436        }
437    }
438
439    /// Find the first position where any byte matches.
440    #[inline]
441    #[must_use]
442    pub fn find_first(&self, haystack: &[u8]) -> Option<usize> {
443        self.bitmap.find_first(haystack)
444    }
445}
446
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a bitmap from a hand-assembled, optionally negated class.
    fn bitmap_of(negated: bool, items: Vec<CharClassItem>) -> AsciiClassBitmap {
        AsciiClassBitmap::from_char_class(&CharClass::new(negated, items))
    }

    #[test]
    fn test_ascii_bitmap_single() {
        let bitmap = bitmap_of(false, vec![CharClassItem::Single('a')]);

        assert!(bitmap.contains(b'a'));
        for &outsider in b"bA" {
            assert!(!bitmap.contains(outsider));
        }
    }

    #[test]
    fn test_ascii_bitmap_range() {
        let bitmap = bitmap_of(false, vec![CharClassItem::Range('a', 'z')]);

        for &member in b"amz" {
            assert!(bitmap.contains(member));
        }
        for &outsider in b"A0" {
            assert!(!bitmap.contains(outsider));
        }
    }

    #[test]
    fn test_ascii_bitmap_negated() {
        let bitmap = bitmap_of(true, vec![CharClassItem::Range('a', 'z')]);

        // Members of the un-negated range are rejected...
        for &excluded in b"az" {
            assert!(!bitmap.contains(excluded));
        }
        // ...and everything else is accepted.
        for &included in b"A0 " {
            assert!(bitmap.contains(included));
        }
    }

    #[test]
    fn test_ascii_bitmap_digit() {
        let bitmap = AsciiClassBitmap::from_char_class(&CharClass::digit());

        (b'0'..=b'9').for_each(|b| {
            assert!(bitmap.contains(b), "Should contain digit {}", b as char);
        });
        for &outsider in b"a " {
            assert!(!bitmap.contains(outsider));
        }
    }

    #[test]
    fn test_ascii_bitmap_word() {
        let bitmap = AsciiClassBitmap::from_char_class(&CharClass::word());

        for &member in b"aZ5_" {
            assert!(bitmap.contains(member));
        }
        for &outsider in b" -" {
            assert!(!bitmap.contains(outsider));
        }
    }

    #[test]
    fn test_find_first() {
        let bitmap = bitmap_of(false, vec![CharClassItem::Range('a', 'z')]);

        let cases: [(&[u8], Option<usize>); 4] = [
            (&b"123abc"[..], Some(3)),
            (&b"ABC"[..], None),
            (&b"hello"[..], Some(0)),
            (&b""[..], None),
        ];
        for (input, expected) in cases.iter() {
            assert_eq!(bitmap.find_first(input), *expected);
        }
    }

    #[test]
    fn test_find_first_long() {
        let bitmap = bitmap_of(false, vec![CharClassItem::Single('x')]);

        // Inputs longer than 16 bytes take the block-wise path where available.
        let hit_after_first_block = b"0123456789abcdefxyz";
        assert_eq!(bitmap.find_first(hit_after_first_block), Some(16));

        let hit_in_tail = b"01234567890123456789x";
        assert_eq!(bitmap.find_first(hit_in_tail), Some(20));
    }

    #[test]
    fn test_find_all() {
        let bitmap = bitmap_of(false, vec![CharClassItem::Range('a', 'z')]);

        assert_eq!(bitmap.find_all(b"a1b2c3"), vec![0, 2, 4]);
    }

    #[test]
    fn test_count_matches() {
        let bitmap = AsciiClassBitmap::from_char_class(&CharClass::digit());

        assert_eq!(bitmap.count_matches(b"abc123def456"), 6);
        assert_eq!(bitmap.count_matches(b"no digits"), 0);
    }

    #[test]
    fn test_compiled_char_class() {
        let compiled = CompiledCharClass::new(&CharClass::word());

        for member in ['a', 'Z', '5'].iter().copied() {
            assert!(compiled.matches(member));
        }
        assert!(!compiled.matches(' '));
    }
}