Skip to main content

mago_bytes/
lib.rs

1//! Shared byte-slice utilities for the Mago toolchain.
2//!
3//! PHP source code is binary-safe, so the toolchain handles identifiers, comments, and
4//! string literals as `&[u8]` end-to-end. Diagnostic messages and human-facing logs are
5//! UTF-8 strings, so a tiny adapter layer is needed at the display boundary. This crate
6//! is that adapter, plus a few SIMD-accelerated byte-trimming primitives.
7
8#![deny(unsafe_op_in_unsafe_fn)]
9
10use std::fmt;
11
12#[cfg(target_arch = "aarch64")]
13use std::arch::aarch64::vceqq_u8;
14#[cfg(target_arch = "aarch64")]
15use std::arch::aarch64::vdupq_n_u8;
16#[cfg(target_arch = "aarch64")]
17use std::arch::aarch64::vld1q_u8;
18#[cfg(target_arch = "aarch64")]
19use std::arch::aarch64::vminvq_u8;
20#[cfg(target_arch = "x86_64")]
21use std::arch::x86_64::__m128i;
22#[cfg(target_arch = "x86_64")]
23use std::arch::x86_64::_mm_cmpeq_epi8;
24#[cfg(target_arch = "x86_64")]
25use std::arch::x86_64::_mm_loadu_si128;
26#[cfg(target_arch = "x86_64")]
27use std::arch::x86_64::_mm_movemask_epi8;
28#[cfg(target_arch = "x86_64")]
29use std::arch::x86_64::_mm_set1_epi8;
30
/// Streams `bytes` into `f` as text: every run of valid UTF-8 is emitted unchanged, and
/// each byte that cannot be decoded is emitted as an uppercase `\xHH` escape.
///
/// This is the display-boundary alternative to lossy decoding. Lossy decoding would turn
/// every stray byte into `U+FFFD`, so different inputs would render the same; here each
/// undecodable byte keeps its own value in the output. For example, `Caf` followed by the
/// bytes `C9 E9 FF` renders as `Caf\xC9\xE9\xFF`.
///
/// Valid text is never rewritten, so a backslash that is already part of valid input is
/// written as-is and is not itself escaped.
///
/// Nothing is written for an empty `bytes`.
///
/// # Errors
///
/// Propagates the first error reported by the underlying [`fmt::Formatter`].
pub fn write_escaped(f: &mut fmt::Formatter<'_>, bytes: &[u8]) -> fmt::Result {
    let mut pending = bytes;
    while !pending.is_empty() {
        let error = match std::str::from_utf8(pending) {
            // Everything left decodes cleanly: emit it in one go and finish.
            Ok(text) => return f.write_str(text),
            Err(error) => error,
        };

        // Split into the longest decodable prefix and the part starting at the bad byte.
        let (decodable, tail) = pending.split_at(error.valid_up_to());
        if !decodable.is_empty() {
            // SAFETY: `Utf8Error::valid_up_to` is the length of the longest prefix of the
            // checked input that is valid UTF-8, and `decodable` is exactly that prefix.
            f.write_str(unsafe { std::str::from_utf8_unchecked(decodable) })?;
        }

        // `error_len()` is `None` when the input stops in the middle of a sequence; the
        // whole tail is then undecodable and gets escaped.
        let bad_len = error.error_len().unwrap_or(tail.len());
        let (undecodable, after) = tail.split_at(bad_len);
        undecodable.iter().try_for_each(|byte| write!(f, "\\x{byte:02X}"))?;

        pending = after;
    }

    Ok(())
}
73
74/// Renders a byte slice as text for diagnostic messages.
75///
76/// Valid UTF-8 is shown verbatim; bytes that are not valid UTF-8 are escaped as `\xHH`.
77/// Use this in `format!`/`write!`/`println!` whenever a `&[u8]` needs to surface in
78/// user-facing output (issue messages, log lines, error reports).
79#[derive(Debug, Clone, Copy)]
80#[repr(transparent)]
81pub struct BytesDisplay<'src>(pub &'src [u8]);
82
83impl fmt::Display for BytesDisplay<'_> {
84    #[inline]
85    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
86        write_escaped(f, self.0)
87    }
88}
89
90/// Strips all leading occurrences of `byte` from `s`.
91///
92/// SIMD-accelerated on x86_64 (SSE2, 16-byte chunks) and aarch64 (NEON, 16-byte chunks);
93/// scalar elsewhere. SSE2 and NEON are baseline on their respective targets, so no
94/// runtime feature detection is required.
95#[inline]
96#[must_use]
97pub fn trim_start_byte(s: &[u8], byte: u8) -> &[u8] {
98    let mut i = simd_skip_leading(s, byte);
99    while i < s.len() && s[i] == byte {
100        i += 1;
101    }
102    &s[i..]
103}
104
105/// Strips all trailing occurrences of `byte` from `s`.
106///
107/// SIMD-accelerated on x86_64 (SSE2, 16-byte chunks) and aarch64 (NEON, 16-byte chunks);
108/// scalar elsewhere.
109#[inline]
110#[must_use]
111pub fn trim_end_byte(s: &[u8], byte: u8) -> &[u8] {
112    let mut end = simd_skip_trailing(s, byte);
113    while end > 0 && s[end - 1] == byte {
114        end -= 1;
115    }
116    &s[..end]
117}
118
119/// Strips all leading and trailing occurrences of `byte` from `s`.
120#[inline]
121#[must_use]
122pub fn trim_byte(s: &[u8], byte: u8) -> &[u8] {
123    trim_end_byte(trim_start_byte(s, byte), byte)
124}
125
/// Scans `s` from the front in 16-byte SSE2 chunks and reports how far the leading run
/// of `byte` is known to extend.
///
/// The result is either the exact index of the first non-matching byte (when that byte
/// lies inside a scanned chunk) or the index just past the last fully matching chunk
/// (when fewer than 16 bytes remain). The caller's scalar loop continues from there.
#[cfg(target_arch = "x86_64")]
#[inline]
fn simd_skip_leading(s: &[u8], byte: u8) -> usize {
    const LANES: usize = 16;
    // `_mm_movemask_epi8` yields one bit per lane; all 16 set means the chunk matched.
    const ALL_LANES: u32 = 0xFFFF;

    let len = s.len();
    let base = s.as_ptr();
    let mut offset = 0;
    // SAFETY: SSE2 is baseline on x86_64. `offset` never exceeds `len`, and a load only
    // happens while at least `LANES` bytes remain past `offset`, so every unaligned
    // 16-byte load stays inside the slice.
    #[allow(clippy::multiple_unsafe_ops_per_block)]
    unsafe {
        let needle = _mm_set1_epi8(byte as i8);
        while len - offset >= LANES {
            // `_mm_loadu_si128` is an unaligned load, so the pointer need not satisfy
            // `__m128i`'s 16-byte alignment requirement.
            #[allow(clippy::cast_ptr_alignment)]
            let lanes = _mm_loadu_si128(base.add(offset).cast::<__m128i>());
            let matched = _mm_movemask_epi8(_mm_cmpeq_epi8(lanes, needle)) as u32;
            if matched != ALL_LANES {
                // Bit `k` belongs to byte `k` of the chunk, so the run of low set bits is
                // the number of matching bytes before the first mismatch.
                return offset + matched.trailing_ones() as usize;
            }
            offset += LANES;
        }
    }
    offset
}
154
/// Scans `s` from the front in 16-byte NEON chunks and returns the index just past the
/// last chunk made up entirely of `byte`.
///
/// Scanning stops at the first chunk containing any other byte, without locating that
/// byte; the caller's scalar loop resumes from the returned index.
#[cfg(target_arch = "aarch64")]
#[inline]
fn simd_skip_leading(s: &[u8], byte: u8) -> usize {
    const LANES: usize = 16;

    let len = s.len();
    let base = s.as_ptr();
    let mut offset = 0;
    // SAFETY: NEON is baseline on aarch64. `offset` never exceeds `len`, and a load only
    // happens while at least `LANES` bytes remain past `offset`, so every unaligned
    // 16-byte load stays inside the slice.
    #[allow(clippy::multiple_unsafe_ops_per_block)]
    unsafe {
        let needle = vdupq_n_u8(byte);
        // `vceqq_u8` sets a lane to 0xFF on a match; the minimum across lanes is 0xFF
        // only when every lane matched.
        while len - offset >= LANES && vminvq_u8(vceqq_u8(vld1q_u8(base.add(offset)), needle)) == 0xFF {
            offset += LANES;
        }
    }
    offset
}
176
/// Fallback for targets other than x86_64 and aarch64: no chunked scan is performed.
///
/// Always returns `0`, so the caller's scalar loop examines the slice from its first byte.
#[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
#[inline]
fn simd_skip_leading(_s: &[u8], _byte: u8) -> usize {
    0
}
182
/// Scans `s` from the back in 16-byte SSE2 chunks and reports where the trailing run of
/// `byte` is known to begin.
///
/// The result is an exclusive end index: either one past the last non-matching byte (when
/// that byte lies inside a scanned chunk) or the start of the earliest fully matching
/// chunk (when fewer than 16 bytes remain before it). The caller's scalar loop trims any
/// matching bytes still left before that index.
#[cfg(target_arch = "x86_64")]
#[inline]
fn simd_skip_trailing(s: &[u8], byte: u8) -> usize {
    const LANES: usize = 16;
    // `_mm_movemask_epi8` yields one bit per lane; all 16 set means the chunk matched.
    const ALL_LANES: u32 = 0xFFFF;

    let base = s.as_ptr();
    let mut kept = s.len();
    // SAFETY: SSE2 is baseline on x86_64. A load only happens when `kept - LANES` does
    // not underflow, and `kept` never exceeds `s.len()`, so the 16 bytes starting at
    // `kept - LANES` lie inside the slice.
    #[allow(clippy::multiple_unsafe_ops_per_block)]
    unsafe {
        let needle = _mm_set1_epi8(byte as i8);
        while let Some(chunk_start) = kept.checked_sub(LANES) {
            // `_mm_loadu_si128` is an unaligned load, so the pointer need not satisfy
            // `__m128i`'s 16-byte alignment requirement.
            #[allow(clippy::cast_ptr_alignment)]
            let lanes = _mm_loadu_si128(base.add(chunk_start).cast::<__m128i>());
            let matched = _mm_movemask_epi8(_mm_cmpeq_epi8(lanes, needle)) as u32;
            if matched != ALL_LANES {
                // Move the 16 lane bits into the top half of the `u32` so that
                // `leading_ones` counts from the chunk's last byte backwards; that count
                // is the number of matching bytes at the end of this chunk.
                let trailing_matches = (matched << 16).leading_ones() as usize;
                return kept - trailing_matches;
            }
            kept = chunk_start;
        }
    }
    kept
}
216
/// Scans `s` from the back in 16-byte NEON chunks and returns the start index of the
/// earliest chunk, counting from the end, that consists entirely of `byte`.
///
/// Scanning stops at the first chunk containing any other byte, without locating that
/// byte; the caller's scalar loop trims the matching bytes still left before the
/// returned index.
#[cfg(target_arch = "aarch64")]
#[inline]
fn simd_skip_trailing(s: &[u8], byte: u8) -> usize {
    const LANES: usize = 16;

    let base = s.as_ptr();
    let mut kept = s.len();
    // SAFETY: NEON is baseline on aarch64. A load only happens when `kept - LANES` does
    // not underflow, and `kept` never exceeds `s.len()`, so the 16 bytes starting at
    // `kept - LANES` lie inside the slice.
    #[allow(clippy::multiple_unsafe_ops_per_block)]
    unsafe {
        let needle = vdupq_n_u8(byte);
        while let Some(chunk_start) = kept.checked_sub(LANES) {
            let matched = vceqq_u8(vld1q_u8(base.add(chunk_start)), needle);
            // Minimum across lanes is 0xFF only when every lane matched.
            if vminvq_u8(matched) != 0xFF {
                break;
            }
            kept = chunk_start;
        }
    }
    kept
}
238
/// Fallback for targets other than x86_64 and aarch64: no chunked scan is performed.
///
/// Always returns `s.len()`, so the caller's scalar loop examines the slice from its
/// last byte.
#[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
#[inline]
fn simd_skip_trailing(s: &[u8], _byte: u8) -> usize {
    s.len()
}
244
#[cfg(test)]
mod tests {
    use super::*;

    /// Byte that every trimming test strips.
    const PAD: u8 = b'x';

    /// Renders `bytes` through [`BytesDisplay`] and returns the resulting text.
    fn escaped(bytes: &[u8]) -> String {
        BytesDisplay(bytes).to_string()
    }

    /// Asserts that `input` renders as `expected`.
    #[track_caller]
    fn assert_escaped(input: &[u8], expected: &str) {
        assert_eq!(escaped(input), expected);
    }

    /// Asserts the result of stripping leading `PAD` bytes from `input`.
    #[track_caller]
    fn assert_trim_start(input: &[u8], expected: &[u8]) {
        assert_eq!(trim_start_byte(input, PAD), expected);
    }

    /// Asserts the result of stripping trailing `PAD` bytes from `input`.
    #[track_caller]
    fn assert_trim_end(input: &[u8], expected: &[u8]) {
        assert_eq!(trim_end_byte(input, PAD), expected);
    }

    /// Asserts the result of stripping `PAD` bytes from both ends of `input`.
    #[track_caller]
    fn assert_trim(input: &[u8], expected: &[u8]) {
        assert_eq!(trim_byte(input, PAD), expected);
    }

    /// Concatenates two byte strings into an owned buffer.
    fn join(head: &[u8], tail: &[u8]) -> Vec<u8> {
        [head, tail].concat()
    }

    #[test]
    fn escape_pure_ascii_is_verbatim() {
        for text in ["", "hello", "a b\tc"] {
            assert_escaped(text.as_bytes(), text);
        }
    }

    #[test]
    fn escape_valid_utf8_is_verbatim() {
        // Both strings are valid UTF-8, so they must come back unchanged, not escaped.
        for text in ["café", "日本語"] {
            assert_escaped(text.as_bytes(), text);
        }
    }

    #[test]
    fn escape_invalid_bytes() {
        // 0xC9 is a 2-byte lead with no valid continuation; 0xFF is never valid UTF-8.
        assert_escaped(b"Caf\xC9\xE9\xFF", "Caf\\xC9\\xE9\\xFF");
        assert_escaped(b"\xFF\xFE", "\\xFF\\xFE");
        // An invalid byte sandwiched between valid runs.
        assert_escaped(b"a\xFFb", "a\\xFFb");
        // A valid multibyte character directly followed by an invalid byte.
        assert_escaped(b"\xC3\xA9\xFF", "é\\xFF");
        // Input ending mid-sequence (the `error_len() == None` path).
        assert_escaped(b"ab\xC9", "ab\\xC9");
    }

    #[test]
    fn trim_start_byte_basic() {
        assert_trim_start(b"", b"");
        assert_trim_start(b"xxx", b"");
        assert_trim_start(b"xxxa", b"a");
        assert_trim_start(b"axxx", b"axxx");
        assert_trim_start(b"abc", b"abc");
    }

    #[test]
    fn trim_start_byte_long() {
        // Padding runs long enough to reach the chunked path, followed by a short core.
        assert_trim_start(&join(b"xxxxxxxxxxxxxxxx", b"abc"), b"abc");
        assert_trim_start(&join(b"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx", b"yz"), b"yz");
        assert_trim_start(&join(b"xxxxxxxxxxxxxxxxx", b"q"), b"q");
        // Nothing but padding.
        assert_trim_start(&[PAD; 64], b"");
        // Mismatch in the last byte of a single 16-byte chunk.
        let mut chunk = [PAD; 16];
        chunk[15] = b'q';
        assert_trim_start(&chunk, b"q");
    }

    #[test]
    fn trim_end_byte_basic() {
        assert_trim_end(b"", b"");
        assert_trim_end(b"xxx", b"");
        assert_trim_end(b"axxx", b"a");
        assert_trim_end(b"xxxa", b"xxxa");
        assert_trim_end(b"abc", b"abc");
    }

    #[test]
    fn trim_end_byte_long() {
        // A short core followed by padding runs long enough to reach the chunked path.
        assert_trim_end(&join(b"abc", b"xxxxxxxxxxxxxxxx"), b"abc");
        assert_trim_end(&join(b"yz", b"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"), b"yz");
        assert_trim_end(&join(b"q", b"xxxxxxxxxxxxxxxxx"), b"q");
        // Nothing but padding.
        assert_trim_end(&[PAD; 64], b"");
        // Mismatch in the first byte of a single 16-byte chunk.
        let mut chunk = [PAD; 16];
        chunk[0] = b'q';
        assert_trim_end(&chunk, b"q");
    }

    #[test]
    fn trim_byte_both_sides() {
        assert_trim(b"xxxabcxxx", b"abc");
        assert_trim(b"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxabcxxxxxxxxxxxxxxxx", b"abc");
        assert_trim(b"abc", b"abc");
        assert_trim(b"", b"");
        assert_trim(b"xxxx", b"");
    }
}