Skip to main content

marki_parse/
special_char.rs

/// Named single-byte ASCII characters.
///
/// The enum is `#[repr(u8)]` and every discriminant is the byte value of the
/// character itself, so `variant as u8` yields the raw byte. Variants are
/// listed in ascending byte order.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum SpecialChar {
    /// Horizontal tab, `\t` (0x09).
    Tab = b'\t',
    /// Line feed, `\n` (0x0A).
    Newline = b'\n',
    /// Carriage return, `\r` (0x0D).
    CarriageReturn = b'\r',
    /// Space (0x20).
    Space = b' ',
    /// `!` (0x21).
    ExclamationMark = b'!',
    /// `"` (0x22).
    DoubleQuote = b'"',
    /// `#` (0x23).
    Hash = b'#',
    /// `'` (0x27).
    SingleQuote = b'\'',
    /// `(` (0x28).
    OpenParen = b'(',
    /// `)` (0x29).
    CloseParen = b')',
    /// `*` (0x2A).
    Asterisk = b'*',
    /// `+` (0x2B).
    Plus = b'+',
    /// `-` (0x2D).
    Dash = b'-',
    /// `.` (0x2E).
    Dot = b'.',
    /// The digit `0` (0x30).
    Zero = b'0',
    /// `>` (0x3E).
    GreaterThan = b'>',
    /// `[` (0x5B).
    OpenBracket = b'[',
    /// `\` (0x5C).
    Backslash = b'\\',
    /// `]` (0x5D).
    CloseBracket = b']',
    /// `_` (0x5F).
    Underscore = b'_',
    /// Backtick (0x60).
    Backtick = b'`',
    /// `~` (0x7E).
    Tilde = b'~',
}
27
28/// Static lookup table for `from_byte`. Built at compile time.
29static FROM_BYTE: [Option<SpecialChar>; 256] = {
30    use SpecialChar as S;
31    let mut table: [Option<SpecialChar>; 256] = [None; 256];
32    table[b'\t' as usize] = Some(S::Tab);
33    table[b'\n' as usize] = Some(S::Newline);
34    table[b'\r' as usize] = Some(S::CarriageReturn);
35    table[b' ' as usize] = Some(S::Space);
36    table[b'!' as usize] = Some(S::ExclamationMark);
37    table[b'"' as usize] = Some(S::DoubleQuote);
38    table[b'#' as usize] = Some(S::Hash);
39    table[b'\'' as usize] = Some(S::SingleQuote);
40    table[b'(' as usize] = Some(S::OpenParen);
41    table[b')' as usize] = Some(S::CloseParen);
42    table[b'*' as usize] = Some(S::Asterisk);
43    table[b'+' as usize] = Some(S::Plus);
44    table[b'-' as usize] = Some(S::Dash);
45    table[b'.' as usize] = Some(S::Dot);
46    table[b'0' as usize] = Some(S::Zero);
47    table[b'>' as usize] = Some(S::GreaterThan);
48    table[b'[' as usize] = Some(S::OpenBracket);
49    table[b'\\' as usize] = Some(S::Backslash);
50    table[b']' as usize] = Some(S::CloseBracket);
51    table[b'_' as usize] = Some(S::Underscore);
52    table[b'~' as usize] = Some(S::Tilde);
53    table[b'`' as usize] = Some(S::Backtick);
54    table
55};
56
57impl SpecialChar {
58    /// Returns the `u8` value of this character.
59    #[inline]
60    #[must_use]
61    pub const fn byte(self) -> u8 {
62        self as u8
63    }
64
65    /// Look up a byte in the static table. O(1).
66    #[inline]
67    #[must_use]
68    pub fn from_byte(b: u8) -> Option<Self> {
69        FROM_BYTE[b as usize]
70    }
71
72    #[inline]
73    #[must_use]
74    pub const fn is_list_char(self) -> bool {
75        matches!(self, Self::Dash | Self::Asterisk | Self::Plus)
76    }
77
78    #[inline]
79    #[must_use]
80    pub fn count_leading_bytes(self, bytes: &[u8]) -> usize {
81        #[cfg(target_arch = "x86_64")]
82        {
83            // SAFETY: SSE2 is baseline on all x86_64 processors.
84            unsafe { self.count_leading_sse2(bytes) }
85        }
86        #[cfg(not(target_arch = "x86_64"))]
87        {
88            self.count_leading_scalar(bytes)
89        }
90    }
91
92    #[cfg(not(target_arch = "x86_64"))]
93    fn count_leading_scalar(self, bytes: &[u8]) -> usize {
94        let needle = self.byte();
95        let mut n = 0;
96        while n < bytes.len() && bytes[n] == needle {
97            n += 1;
98        }
99        n
100    }
101
102    #[cfg(target_arch = "x86_64")]
103    #[target_feature(enable = "sse2")]
104    #[allow(clippy::cast_ptr_alignment)]
105    unsafe fn count_leading_sse2(self, bytes: &[u8]) -> usize {
106        use std::arch::x86_64::{
107            _mm_cmpeq_epi8, _mm_loadu_si128, _mm_movemask_epi8, _mm_set1_epi8,
108        };
109        let needle = self.byte();
110        let len = bytes.len();
111        let ptr = bytes.as_ptr();
112        unsafe {
113            let n = _mm_set1_epi8(i8::from_ne_bytes([needle]));
114            let mut i = 0;
115            while i + 16 <= len {
116                let chunk = _mm_loadu_si128(ptr.add(i).cast());
117                let mask = u16::try_from(_mm_movemask_epi8(_mm_cmpeq_epi8(chunk, n)))
118                    .expect("movemask produces a 16-bit value");
119                if mask == 0xFFFF {
120                    i += 16;
121                } else {
122                    let matched: usize = mask
123                        .trailing_ones()
124                        .try_into()
125                        .expect("trailing ones fit in usize");
126                    return i + matched;
127                }
128            }
129            while i < len && bytes[i] == needle {
130                i += 1;
131            }
132            i
133        }
134    }
135}
136
137impl PartialEq<u8> for SpecialChar {
138    #[inline]
139    fn eq(&self, other: &u8) -> bool {
140        self.byte() == *other
141    }
142}
143
144impl PartialEq<SpecialChar> for u8 {
145    #[inline]
146    fn eq(&self, other: &SpecialChar) -> bool {
147        *self == other.byte()
148    }
149}
150
151impl PartialEq<SpecialChar> for Option<&u8> {
152    #[inline]
153    fn eq(&self, other: &SpecialChar) -> bool {
154        matches!(self, Some(b) if **b == other.byte())
155    }
156}
157
158impl PartialEq<SpecialChar> for Option<u8> {
159    #[inline]
160    fn eq(&self, other: &SpecialChar) -> bool {
161        matches!(self, Some(b) if *b == other.byte())
162    }
163}
164
165impl std::fmt::Display for SpecialChar {
166    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
167        write!(f, "{}", self.byte() as char)
168    }
169}
170
/// Unit tests for `SpecialChar`.
///
/// The `count_leading_bytes` cases are chosen around the 16-byte step of the
/// SSE2 path: inputs shorter than one chunk, exactly one chunk, a mismatch
/// inside the first chunk, inside a later chunk, and inside the leftover tail.
#[cfg(test)]
mod tests {
    use super::SpecialChar;

    #[test]
    fn count_leading_bytes_empty() {
        assert_eq!(SpecialChar::Asterisk.count_leading_bytes(b""), 0);
    }

    #[test]
    fn count_leading_bytes_none() {
        assert_eq!(SpecialChar::Asterisk.count_leading_bytes(b"abc"), 0);
    }

    #[test]
    fn count_leading_bytes_short_run() {
        assert_eq!(SpecialChar::Backtick.count_leading_bytes(b"``code"), 2);
    }

    #[test]
    fn count_leading_bytes_exact_boundary() {
        // Exactly 16 characters exercises the SSE2 fast-path boundary.
        let input = b"****************";
        assert_eq!(SpecialChar::Asterisk.count_leading_bytes(input), 16);
    }

    #[test]
    fn count_leading_bytes_long_run() {
        let input = vec![SpecialChar::Hash.byte(); 100];
        assert_eq!(SpecialChar::Hash.count_leading_bytes(&input), 100);
    }

    #[test]
    fn count_leading_bytes_partial_chunk() {
        let input = b"*****x**********";
        assert_eq!(SpecialChar::Asterisk.count_leading_bytes(input), 5);
    }

    #[test]
    fn count_leading_bytes_tilde() {
        assert_eq!(SpecialChar::Tilde.count_leading_bytes(b"~~~rust"), 3);
    }

    #[test]
    fn count_leading_bytes_run_spills_past_one_chunk() {
        // 17 matching bytes: one full chunk plus a one-byte tail.
        let input = b"*****************";
        assert_eq!(SpecialChar::Asterisk.count_leading_bytes(input), 17);
    }

    #[test]
    fn count_leading_bytes_stops_right_after_a_full_chunk() {
        // The byte immediately after a fully matching chunk differs.
        let input = b"****************x";
        assert_eq!(SpecialChar::Asterisk.count_leading_bytes(input), 16);
    }

    #[test]
    fn count_leading_bytes_mismatch_in_later_chunk() {
        // Index 20 lies inside the second 16-byte chunk.
        let mut input = vec![SpecialChar::Asterisk.byte(); 40];
        input[20] = b'x';
        assert_eq!(SpecialChar::Asterisk.count_leading_bytes(&input), 20);
    }

    #[test]
    fn count_leading_bytes_mismatch_in_tail() {
        // 37 bytes = two full chunks + a 5-byte tail; index 35 is in the tail.
        let mut input = vec![SpecialChar::Asterisk.byte(); 37];
        input[35] = b'x';
        assert_eq!(SpecialChar::Asterisk.count_leading_bytes(&input), 35);
    }

    #[test]
    fn from_byte_round_trips_through_byte() {
        assert_eq!(SpecialChar::from_byte(b'*'), Some(SpecialChar::Asterisk));
        for b in 0..=255u8 {
            if let Some(c) = SpecialChar::from_byte(b) {
                assert_eq!(c.byte(), b);
            }
        }
    }

    #[test]
    fn from_byte_rejects_unlisted_bytes() {
        assert_eq!(SpecialChar::from_byte(b'a'), None);
        assert_eq!(SpecialChar::from_byte(0), None);
        assert_eq!(SpecialChar::from_byte(255), None);
    }

    #[test]
    fn is_list_char_accepts_only_dash_asterisk_plus() {
        assert!(SpecialChar::Dash.is_list_char());
        assert!(SpecialChar::Asterisk.is_list_char());
        assert!(SpecialChar::Plus.is_list_char());
        assert!(!SpecialChar::Hash.is_list_char());
        assert!(!SpecialChar::Tilde.is_list_char());
    }
}