fuzzy_regex/engine/
simd_class.rs1#![allow(clippy::wildcard_imports)]
3
4use crate::ir::HirClass;
11use crate::parser::ast::{CharClass, CharClassItem, NamedClass};
12
/// A 128-bit membership table for the ASCII range, plus two flags that
/// describe how the class treats everything outside ASCII.
///
/// Bit `n` of `lo` stands for byte `n` (`0..64`); bit `n` of `hi` stands for
/// byte `64 + n` (`64..128`). Two bitmaps compare equal when all four fields
/// are equal, which makes the type usable in assertions and as a map key.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct AsciiClassBitmap {
    /// Membership bits for bytes `0..64`.
    lo: u64,
    /// Membership bits for bytes `64..128` (bit `n` is byte `64 + n`).
    hi: u64,
    /// When true, lookups report the inverse of the stored membership.
    negated: bool,
    /// Coarse answer for anything outside ASCII (bytes `>= 128` and non-ASCII
    /// `char`s), applied before `negated` is taken into account.
    matches_non_ascii: bool,
}
26
27impl AsciiClassBitmap {
28 #[must_use]
30 pub fn empty() -> Self {
31 AsciiClassBitmap {
32 lo: 0,
33 hi: 0,
34 negated: false,
35 matches_non_ascii: false,
36 }
37 }
38
39 #[must_use]
41 pub fn all_ascii() -> Self {
42 AsciiClassBitmap {
43 lo: u64::MAX,
44 hi: u64::MAX,
45 negated: false,
46 matches_non_ascii: false,
47 }
48 }
49
50 #[must_use]
52 pub fn from_char_class(class: &CharClass) -> Self {
53 let mut bitmap = AsciiClassBitmap::empty();
54 bitmap.negated = class.negated;
55
56 for item in &class.items {
57 match item {
58 CharClassItem::Single(ch) => {
59 if ch.is_ascii() {
60 bitmap.set(*ch as u8);
61 } else {
62 bitmap.matches_non_ascii = true;
63 }
64 }
65 CharClassItem::Range(start, end) => {
66 let start_byte = if start.is_ascii() { *start as u8 } else { 128 };
67 let end_byte = if end.is_ascii() { *end as u8 } else { 127 };
68
69 for b in start_byte..=end_byte.min(127) {
70 bitmap.set(b);
71 }
72 if *end as u32 > 127 {
74 bitmap.matches_non_ascii = true;
75 }
76 }
77 CharClassItem::Named(named) => {
78 bitmap.add_named_class(*named);
79 }
80 }
81 }
82
83 bitmap
84 }
85
86 #[must_use]
88 pub fn from_hir_class(class: &HirClass) -> Self {
89 let mut bitmap = AsciiClassBitmap::empty();
90 bitmap.negated = class.negated;
91
92 for &ch in &class.chars {
94 if ch.is_ascii() {
95 bitmap.set(ch as u8);
96 } else {
97 bitmap.matches_non_ascii = true;
98 }
99 }
100
101 for &(start, end) in &class.ranges {
103 let start_byte = if start.is_ascii() { start as u8 } else { 128 };
104 let end_byte = if end.is_ascii() { end as u8 } else { 127 };
105
106 for b in start_byte..=end_byte.min(127) {
107 bitmap.set(b);
108 }
109 if end as u32 > 127 {
111 bitmap.matches_non_ascii = true;
112 }
113 }
114
115 for &named in &class.named {
117 bitmap.add_named_class(named);
118 }
119
120 bitmap
121 }
122
123 fn add_named_class(&mut self, class: NamedClass) {
125 match class {
126 NamedClass::Digit => {
127 for b in b'0'..=b'9' {
128 self.set(b);
129 }
130 }
131 NamedClass::NotDigit => {
132 for b in 0u8..=127 {
134 if !b.is_ascii_digit() {
135 self.set(b);
136 }
137 }
138 self.matches_non_ascii = true;
139 }
140 NamedClass::Word => {
141 for b in b'a'..=b'z' {
142 self.set(b);
143 }
144 for b in b'A'..=b'Z' {
145 self.set(b);
146 }
147 for b in b'0'..=b'9' {
148 self.set(b);
149 }
150 self.set(b'_');
151 }
152 NamedClass::NotWord => {
153 for b in 0u8..=127 {
154 let is_word = b.is_ascii_lowercase()
155 || b.is_ascii_uppercase()
156 || b.is_ascii_digit()
157 || b == b'_';
158 if !is_word {
159 self.set(b);
160 }
161 }
162 self.matches_non_ascii = true;
163 }
164 NamedClass::Whitespace => {
165 self.set(b' ');
166 self.set(b'\t');
167 self.set(b'\n');
168 self.set(b'\r');
169 self.set(0x0C); self.set(0x0B); }
172 NamedClass::NotWhitespace => {
173 for b in 0u8..=127 {
174 if !matches!(b, b' ' | b'\t' | b'\n' | b'\r' | 0x0C | 0x0B) {
175 self.set(b);
176 }
177 }
178 self.matches_non_ascii = true;
179 }
180 NamedClass::Any | NamedClass::AnyExceptNewline => {
181 self.lo = u64::MAX;
183 self.hi = u64::MAX;
184 if matches!(class, NamedClass::AnyExceptNewline) {
185 self.clear(b'\n');
186 self.clear(b'\r');
187 }
188 self.matches_non_ascii = true;
189 }
190 }
191 }
192
193 #[inline]
195 fn set(&mut self, byte: u8) {
196 if byte < 64 {
197 self.lo |= 1u64 << byte;
198 } else if byte < 128 {
199 self.hi |= 1u64 << (byte - 64);
200 }
201 }
202
203 #[inline]
205 fn clear(&mut self, byte: u8) {
206 if byte < 64 {
207 self.lo &= !(1u64 << byte);
208 } else if byte < 128 {
209 self.hi &= !(1u64 << (byte - 64));
210 }
211 }
212
213 #[inline]
215 #[must_use]
216 pub fn contains(&self, byte: u8) -> bool {
217 let in_bitmap = if byte < 64 {
218 (self.lo & (1u64 << byte)) != 0
219 } else if byte < 128 {
220 (self.hi & (1u64 << (byte - 64))) != 0
221 } else {
222 self.matches_non_ascii
223 };
224
225 if self.negated { !in_bitmap } else { in_bitmap }
226 }
227
228 #[inline]
230 #[must_use]
231 pub fn contains_char(&self, ch: char) -> bool {
232 if ch.is_ascii() {
233 self.contains(ch as u8)
234 } else {
235 let in_class = self.matches_non_ascii;
236 if self.negated { !in_class } else { in_class }
237 }
238 }
239
240 #[must_use]
243 pub fn find_first(&self, haystack: &[u8]) -> Option<usize> {
244 #[cfg(all(target_arch = "x86_64", target_feature = "sse2"))]
246 {
247 if haystack.len() >= 16 {
248 return self.find_first_simd(haystack);
249 }
250 }
251
252 #[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
253 {
254 if haystack.len() >= 16 {
255 return self.find_first_simd(haystack);
256 }
257 }
258
259 self.find_first_scalar(haystack)
261 }
262
263 #[inline]
265 fn find_first_scalar(&self, haystack: &[u8]) -> Option<usize> {
266 for (i, &byte) in haystack.iter().enumerate() {
267 if self.contains(byte) {
268 return Some(i);
269 }
270 }
271 None
272 }
273
274 #[cfg(all(target_arch = "x86_64", target_feature = "sse2"))]
276 fn find_first_simd(&self, haystack: &[u8]) -> Option<usize> {
277 use std::arch::x86_64::*;
278
279 if self.negated || self.matches_non_ascii {
282 return self.find_first_scalar(haystack);
283 }
284
285 let len = haystack.len();
286 let mut i = 0;
287
288 unsafe {
290 while i + 16 <= len {
291 let chunk = _mm_loadu_si128(haystack.as_ptr().add(i).cast::<__m128i>());
292
293 let mut mask = 0u16;
297
298 let bytes: [u8; 16] = std::mem::transmute(chunk);
300 for (j, &b) in bytes.iter().enumerate() {
301 if self.contains(b) {
302 mask |= 1 << j;
303 }
304 }
305
306 if mask != 0 {
307 return Some(i + mask.trailing_zeros() as usize);
308 }
309
310 i += 16;
311 }
312 }
313
314 (i..len).find(|&j| self.contains(haystack[j]))
316 }
317
318 #[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
320 fn find_first_simd(&self, haystack: &[u8]) -> Option<usize> {
321 if self.negated || self.matches_non_ascii {
323 return self.find_first_scalar(haystack);
324 }
325
326 let len = haystack.len();
327 let mut i = 0;
328
329 unsafe {
331 use std::arch::aarch64::*;
332
333 while i + 16 <= len {
334 let chunk = vld1q_u8(haystack.as_ptr().add(i));
335
336 let bytes: [u8; 16] = std::mem::transmute(chunk);
338 for (j, &b) in bytes.iter().enumerate() {
339 if self.contains(b) {
340 return Some(i + j);
341 }
342 }
343
344 i += 16;
345 }
346 }
347
348 (i..len).find(|&j| self.contains(haystack[j]))
350 }
351
352 #[must_use]
355 pub fn find_all(&self, haystack: &[u8]) -> Vec<usize> {
356 let mut results = Vec::new();
357 let mut pos = 0;
358
359 while pos < haystack.len() {
360 if let Some(offset) = self.find_first(&haystack[pos..]) {
361 results.push(pos + offset);
362 pos += offset + 1;
363 } else {
364 break;
365 }
366 }
367
368 results
369 }
370
371 #[must_use]
373 pub fn count_matches(&self, haystack: &[u8]) -> usize {
374 haystack.iter().filter(|&&b| self.contains(b)).count()
375 }
376
377 #[inline]
379 #[must_use]
380 pub fn matches_any(&self, haystack: &[u8]) -> bool {
381 self.find_first(haystack).is_some()
382 }
383}
384
385impl Default for AsciiClassBitmap {
386 fn default() -> Self {
387 Self::empty()
388 }
389}
390
/// A parsed character class bundled with a precomputed [`AsciiClassBitmap`],
/// so that ASCII characters can be answered from the bitmap.
#[derive(Clone, Debug)]
pub struct CompiledCharClass {
    /// ASCII lookup table. The constructors build it from `original`.
    pub bitmap: AsciiClassBitmap,
    /// The class as parsed. `matches` consults it for non-ASCII characters.
    pub original: CharClass,
    /// Selects `CharClass::matches_unicode` (true) or `CharClass::matches`
    /// (false) for non-ASCII characters in `matches`.
    pub unicode: bool,
}
402
403impl CompiledCharClass {
404 #[must_use]
406 pub fn new(class: &CharClass) -> Self {
407 CompiledCharClass {
408 bitmap: AsciiClassBitmap::from_char_class(class),
409 original: class.clone(),
410 unicode: false,
411 }
412 }
413
414 #[must_use]
416 pub fn new_with_unicode(class: &CharClass, unicode: bool) -> Self {
417 CompiledCharClass {
418 bitmap: AsciiClassBitmap::from_char_class(class),
419 original: class.clone(),
420 unicode,
421 }
422 }
423
424 #[inline]
426 #[must_use]
427 pub fn matches(&self, ch: char) -> bool {
428 if ch.is_ascii() {
429 self.bitmap.contains(ch as u8)
430 } else if self.unicode {
431 self.original.matches_unicode(ch)
434 } else {
435 self.original.matches(ch)
436 }
437 }
438
439 #[inline]
441 #[must_use]
442 pub fn find_first(&self, haystack: &[u8]) -> Option<usize> {
443 self.bitmap.find_first(haystack)
444 }
445}
446
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ascii_bitmap_single() {
        let class = CharClass::new(false, vec![CharClassItem::Single('a')]);
        let bitmap = AsciiClassBitmap::from_char_class(&class);

        assert!(bitmap.contains(b'a'));
        assert!(!bitmap.contains(b'b'));
        assert!(!bitmap.contains(b'A'));
    }

    #[test]
    fn test_ascii_bitmap_range() {
        let class = CharClass::new(false, vec![CharClassItem::Range('a', 'z')]);
        let bitmap = AsciiClassBitmap::from_char_class(&class);

        assert!(bitmap.contains(b'a'));
        assert!(bitmap.contains(b'm'));
        assert!(bitmap.contains(b'z'));
        assert!(!bitmap.contains(b'A'));
        assert!(!bitmap.contains(b'0'));
    }

    #[test]
    fn test_ascii_bitmap_negated() {
        let class = CharClass::new(true, vec![CharClassItem::Range('a', 'z')]);
        let bitmap = AsciiClassBitmap::from_char_class(&class);

        assert!(!bitmap.contains(b'a'));
        assert!(!bitmap.contains(b'z'));
        assert!(bitmap.contains(b'A'));
        assert!(bitmap.contains(b'0'));
        assert!(bitmap.contains(b' '));
    }

    #[test]
    fn test_ascii_bitmap_digit() {
        let class = CharClass::digit();
        let bitmap = AsciiClassBitmap::from_char_class(&class);

        for b in b'0'..=b'9' {
            assert!(bitmap.contains(b), "Should contain digit {}", b as char);
        }
        assert!(!bitmap.contains(b'a'));
        assert!(!bitmap.contains(b' '));
    }

    #[test]
    fn test_ascii_bitmap_word() {
        let class = CharClass::word();
        let bitmap = AsciiClassBitmap::from_char_class(&class);

        assert!(bitmap.contains(b'a'));
        assert!(bitmap.contains(b'Z'));
        assert!(bitmap.contains(b'5'));
        assert!(bitmap.contains(b'_'));
        assert!(!bitmap.contains(b' '));
        assert!(!bitmap.contains(b'-'));
    }

    #[test]
    fn test_ascii_bitmap_range_past_ascii() {
        // 'é' is U+00E9: the ASCII part is stored exactly, the rest sets the flag.
        let class = CharClass::new(false, vec![CharClassItem::Range('x', 'é')]);
        let bitmap = AsciiClassBitmap::from_char_class(&class);

        assert!(bitmap.contains(b'x'));
        assert!(bitmap.contains(0x7F));
        assert!(!bitmap.contains(b'w'));
        assert!(bitmap.contains_char('é'));
    }

    #[test]
    fn test_ascii_bitmap_any_except_newline() {
        let class = CharClass::new(
            false,
            vec![CharClassItem::Named(NamedClass::AnyExceptNewline)],
        );
        let bitmap = AsciiClassBitmap::from_char_class(&class);

        assert!(bitmap.contains(b'a'));
        assert!(bitmap.contains(0));
        assert!(!bitmap.contains(b'\n'));
        assert!(!bitmap.contains(b'\r'));
        assert!(bitmap.contains_char('é'));
    }

    #[test]
    fn test_ascii_bitmap_not_digit() {
        let class = CharClass::new(false, vec![CharClassItem::Named(NamedClass::NotDigit)]);
        let bitmap = AsciiClassBitmap::from_char_class(&class);

        assert!(!bitmap.contains(b'7'));
        assert!(bitmap.contains(b'a'));
        assert!(bitmap.contains_char('ß'));
    }

    #[test]
    fn test_find_first() {
        let class = CharClass::new(false, vec![CharClassItem::Range('a', 'z')]);
        let bitmap = AsciiClassBitmap::from_char_class(&class);

        assert_eq!(bitmap.find_first(b"123abc"), Some(3));
        assert_eq!(bitmap.find_first(b"ABC"), None);
        assert_eq!(bitmap.find_first(b"hello"), Some(0));
        assert_eq!(bitmap.find_first(b""), None);
    }

    #[test]
    fn test_find_first_long() {
        let class = CharClass::new(false, vec![CharClassItem::Single('x')]);
        let bitmap = AsciiClassBitmap::from_char_class(&class);

        // The match sits just past the first 16-byte chunk.
        let text = b"0123456789abcdefxyz";
        assert_eq!(bitmap.find_first(text), Some(16));

        let text2 = b"01234567890123456789x";
        assert_eq!(bitmap.find_first(text2), Some(20));
    }

    #[test]
    fn test_find_first_long_negated() {
        let class = CharClass::new(true, vec![CharClassItem::Range('0', '9')]);
        let bitmap = AsciiClassBitmap::from_char_class(&class);

        assert_eq!(bitmap.find_first(b"01234567890123456789-"), Some(20));
        assert_eq!(bitmap.find_first(b"0123456789012345678"), None);
    }

    #[test]
    fn test_find_all() {
        let class = CharClass::new(false, vec![CharClassItem::Range('a', 'z')]);
        let bitmap = AsciiClassBitmap::from_char_class(&class);

        let positions = bitmap.find_all(b"a1b2c3");
        assert_eq!(positions, vec![0, 2, 4]);
    }

    #[test]
    fn test_find_all_long() {
        let class = CharClass::new(false, vec![CharClassItem::Single('x')]);
        let bitmap = AsciiClassBitmap::from_char_class(&class);

        // Matches spread across several 16-byte chunks, including the tail.
        let mut text = vec![b'.'; 40];
        text[0] = b'x';
        text[17] = b'x';
        text[39] = b'x';
        assert_eq!(bitmap.find_all(&text), vec![0, 17, 39]);
        assert!(bitmap.matches_any(&text));
        assert!(!bitmap.matches_any(&[b'.'; 20]));
    }

    #[test]
    fn test_count_matches() {
        let class = CharClass::digit();
        let bitmap = AsciiClassBitmap::from_char_class(&class);

        assert_eq!(bitmap.count_matches(b"abc123def456"), 6);
        assert_eq!(bitmap.count_matches(b"no digits"), 0);
    }

    #[test]
    fn test_compiled_char_class() {
        let class = CharClass::word();
        let compiled = CompiledCharClass::new(&class);

        assert!(compiled.matches('a'));
        assert!(compiled.matches('Z'));
        assert!(compiled.matches('5'));
        assert!(!compiled.matches(' '));
    }
}