// uwuifier/bitap.rs

1#[cfg(target_arch = "x86")]
2use std::arch::x86::*;
3#[cfg(target_arch = "x86_64")]
4use std::arch::x86_64::*;
5
6use super::{A, str_to_bytes, bytes_len};
7
/// State of a shift-and ("bitap") matcher that follows several patterns at once,
/// one pattern per 16-bit lane of a 128-bit SSE register.
pub struct Bitap8x16 {
    /// Per-lane bitset of the pattern prefixes that are still matching.
    v: __m128i,
    /// Per-lane bit marking where the first character of each pattern sits.
    start_mask: __m128i,
}
12
13const fn get_masks(patterns: &[&str]) -> [A; 256] {
14    // preprecessing step to associate each character with a mask of locations
15    // in each of the 8 pattern strings
16
17    // must use const to init this array
18    const TEMP_A: A = A([0u8; 16]);
19    let mut res = [TEMP_A; 256];
20    let mut i = 0;
21    let bit5 = 0b0010_0000u8;
22
23    while i < patterns.len() {
24        let bytes = patterns[i].as_bytes();
25        // offset masks so the last character maps to the last bit of each 16-bit lane
26        // this is useful for movemask later
27        let offset = 16 - bytes.len();
28        let mut j = 0;
29
30        while j < bytes.len() {
31            let idx = i * 16 + j + offset;
32            res[bytes[j] as usize].0[idx / 8] |= 1u8 << (idx % 8);
33
34            // make sure to be case insensitive
35            if bytes[j].is_ascii_alphabetic() {
36                res[(bytes[j] ^ bit5) as usize].0[idx / 8] |= 1u8 << (idx % 8);
37            }
38
39            j += 1;
40        }
41
42        i += 1;
43    }
44
45    res
46}
47
48const fn get_start_mask(patterns: &[&str]) -> A {
49    // get a mask that indicates the first character for each pattern
50    let mut res = A([0u8; 16]);
51    let mut i = 0;
52
53    while i < patterns.len() {
54        let j = 16 - patterns[i].as_bytes().len();
55        let idx = i * 16 + j;
56        res.0[idx / 8] |= 1u8 << (idx % 8);
57        i += 1;
58    }
59
60    res
61}
62
/// Words the matcher looks for, one per 16-bit lane.
///
/// The table has exactly eight slots; the last two both hold "meow"
/// (presumably padding to fill the fixed-size table).
static PATTERNS: [&str; 8] = [
    "small", "cute", "fluff", "love", "stupid", "what", "meow", "meow",
];
73
74static MASKS: [A; 256] = get_masks(&PATTERNS);
75static START_MASK: A = get_start_mask(&PATTERNS);
76
77// important note: replacement cannot be more than 2 times longer than the corresponding pattern!
78// this is to prevent increasing the size of the output too much in certain cases
79// another note: this table has a fixed size of 8 and expanding it will require changing the
80// algorithm a little
81static REPLACE: [A; 8] = [
82    str_to_bytes("smol"),
83    str_to_bytes("kawaii~"),
84    str_to_bytes("floof"),
85    str_to_bytes("luv"),
86    str_to_bytes("baka"),
87    str_to_bytes("nani"),
88    str_to_bytes("nya~"),
89    str_to_bytes("nya~")
90];
91
92const fn get_len(a: &[A]) -> [usize; 8] {
93    let mut res = [0usize; 8];
94    let mut i = 0;
95
96    while i < a.len() {
97        res[i] = bytes_len(&a[i].0);
98        i += 1;
99    }
100
101    res
102}
103
104static REPLACE_LEN: [usize; 8] = get_len(&REPLACE);
105
/// Describes a pattern that has just been matched by [`Bitap8x16::next`].
#[derive(Debug, PartialEq)]
pub struct Match {
    /// Length in bytes of the pattern that matched.
    pub match_len: usize,
    /// Pointer to the start of the replacement's 16-byte buffer in the
    /// replacement table.
    pub replace_ptr: *const __m128i,
    /// Length of the replacement, as precomputed in the replacement-length table.
    pub replace_len: usize,
}
112
113impl Bitap8x16 {
114    #[inline]
115    #[target_feature(enable = "sse4.1")]
116    pub unsafe fn new() -> Self {
117        Self {
118            v: _mm_setzero_si128(),
119            start_mask: _mm_load_si128(START_MASK.0.as_ptr() as *const __m128i)
120        }
121    }
122
123    #[inline]
124    #[target_feature(enable = "sse4.1")]
125    pub unsafe fn next(&mut self, c: u8) -> Option<Match> {
126        self.v = _mm_slli_epi16(self.v, 1);
127        self.v = _mm_or_si128(self.v, self.start_mask);
128        let mask = _mm_load_si128(MASKS.get_unchecked(c as usize).0.as_ptr() as *const __m128i);
129        self.v = _mm_and_si128(self.v, mask);
130
131        let match_mask = (_mm_movemask_epi8(self.v) as u32) & 0xAAAAAAAAu32;
132
133        if match_mask != 0 {
134            let match_idx = (match_mask.trailing_zeros() as usize) / 2;
135
136            return Some(Match {
137                match_len: PATTERNS.get_unchecked(match_idx).len(),
138                replace_ptr: REPLACE.get_unchecked(match_idx).0.as_ptr() as *const __m128i,
139                replace_len: *REPLACE_LEN.get_unchecked(match_idx)
140            });
141        }
142
143        None
144    }
145
146    #[inline]
147    #[target_feature(enable = "sse4.1")]
148    pub unsafe fn reset(&mut self) {
149        self.v = _mm_setzero_si128();
150    }
151}
152
#[cfg(test)]
mod tests {
    use super::*;

    /// Feeds `bytes` one at a time and asserts that none of them completes a pattern.
    unsafe fn expect_no_match(matcher: &mut Bitap8x16, bytes: &[u8]) {
        for &byte in bytes {
            assert_eq!(matcher.next(byte), None);
        }
    }

    /// Feeds `byte` and asserts that it completes a pattern with the given lengths.
    unsafe fn expect_match(
        matcher: &mut Bitap8x16,
        byte: u8,
        match_len: usize,
        replace_len: usize,
    ) {
        let found = matcher.next(byte).unwrap();
        assert_eq!(found.match_len, match_len);
        assert_eq!(found.replace_len, replace_len);
    }

    #[test]
    fn test_bitap() {
        if !is_x86_feature_detected!("sse4.1") {
            panic!("sse4.1 feature not detected!");
        }

        // SAFETY: SSE4.1 support was verified just above.
        unsafe {
            let mut b = Bitap8x16::new();

            // "cute" -> "kawaii~"
            expect_no_match(&mut b, b"cut");
            expect_match(&mut b, b'e', 4, 7);

            // "what" -> "nani", starting from a clean state
            b.reset();
            expect_no_match(&mut b, b"wha");
            expect_match(&mut b, b't', 4, 4);

            // a near miss must not report anything (state deliberately not reset)
            expect_no_match(&mut b, b"whaa");

            // mixed case still matches
            expect_no_match(&mut b, b"WhA");
            expect_match(&mut b, b't', 4, 4);
        }
    }
}
193}