im_rope/
validations.rs

1//! Operations related to UTF-8 validation.
2//!
3//! Lightly modified from the standard library's implementation
4//! in order to operate on [`im::Vector`]s.
5
6use im::vector::Vector;
7use std::mem;
8
/// Describes why a sequence of [`u8`] could not be interpreted as a string.
///
/// The layout mirrors [`std::str::Utf8Error`]; this duplicate exists only
/// because the standard type has no public constructor.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Utf8Error {
    /// Number of leading bytes that were verified to be valid UTF-8.
    valid_up_to: usize,
    /// Length of the invalid byte sequence, or `None` when the input ended
    /// unexpectedly.
    error_len: Option<u8>,
}
19
20impl Utf8Error {
21    /// Returns the index in the given string up to which valid UTF-8 was
22    /// verified.
23    #[must_use]
24    #[inline]
25    pub const fn valid_up_to(&self) -> usize {
26        self.valid_up_to
27    }
28
29    /// Provides more information about the failure:
30    ///
31    /// * `None`: the end of the input was reached unexpectedly.
32    ///   `self.valid_up_to()` is 1 to 3 bytes from the end of the input.
33    ///   If a byte stream (such as a file or a network socket) is being decoded incrementally,
34    ///   this could be a valid `char` whose UTF-8 byte sequence is spanning multiple chunks.
35    ///
36    /// * `Some(len)`: an unexpected byte was encountered.
37    ///   The length provided is that of the invalid byte sequence
38    ///   that starts at the index given by `valid_up_to()`.
39    #[must_use]
40    #[inline]
41    pub const fn error_len(&self) -> Option<usize> {
42        // FIXME: This should become `map` again, once it's `const`
43        match self.error_len {
44            Some(len) => Some(len as usize),
45            None => None,
46        }
47    }
48}
49
50impl std::fmt::Display for Utf8Error {
51    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
52        if let Some(error_len) = self.error_len {
53            write!(
54                f,
55                "invalid utf-8 sequence of {} bytes from index {}",
56                error_len, self.valid_up_to
57            )
58        } else {
59            write!(
60                f,
61                "incomplete utf-8 byte sequence from index {}",
62                self.valid_up_to
63            )
64        }
65    }
66}
67
68impl std::error::Error for Utf8Error {}
69
70impl From<std::str::Utf8Error> for Utf8Error {
71    fn from(e: std::str::Utf8Error) -> Self {
72        Utf8Error {
73            valid_up_to: e.valid_up_to(),
74            error_len: e.error_len().map(|l| l.try_into().unwrap()),
75        }
76    }
77}
78
79/// Checks whether the byte is a UTF-8 initial byte
80#[allow(clippy::cast_possible_wrap)]
81pub(super) const fn utf8_is_first_byte(byte: u8) -> bool {
82    byte as i8 >= -0x40
83}
84
/// Returns the initial codepoint accumulator for the first byte.
///
/// Only the payload bits of a leading byte count: the bottom 5 bits for
/// width 2, 4 bits for width 3, and 3 bits for width 4.
#[inline]
const fn utf8_first_byte(byte: u8, width: u32) -> u32 {
    // Shifting `0x7F` right by `width` keeps exactly the payload bits.
    let payload_mask: u8 = 0x7F >> width;
    (byte & payload_mask) as u32
}
92
93/// Returns the value of `ch` updated with continuation byte `byte`.
94#[inline]
95const fn utf8_acc_cont_byte(ch: u32, byte: u8) -> u32 {
96    (ch << 6) | (byte & CONT_MASK) as u32
97}
98
99/// Checks whether the byte is a UTF-8 continuation byte (i.e., starts with the
100/// bits `10`).
101#[inline]
102#[allow(clippy::cast_possible_wrap)]
103pub(super) const fn utf8_is_cont_byte(byte: u8) -> bool {
104    (byte as i8) < -0x40
105}
106
107#[inline]
108pub(super) fn starts_on_utf8_boundary(v: &Vector<u8>) -> bool {
109    v.front().map_or(true, |&b| utf8_is_first_byte(b))
110}
111
112#[inline]
113pub(super) fn ends_on_utf8_boundary(v: &Vector<u8>) -> bool {
114    if v.back().map_or(true, |&b| b < 128) {
115        return true;
116    }
117
118    let mut w = v.clone();
119    w.pop_back(); // We've already ruled out the ASCII case
120
121    for expected_length in 2usize..=4 {
122        let Some(ch) = w.pop_back() else { return false };
123        
124        if utf8_is_first_byte(ch) {
125            return utf8_char_width(ch) == expected_length;
126        }
127    }
128
129    false
130}
131
132/// Reads the next code point out of a byte iterator (assuming a
133/// UTF-8-like encoding).
134///
135/// # Safety
136///
137/// `bytes` must produce a valid UTF-8-like (UTF-8 or WTF-8) string
138#[inline]
139pub(super) unsafe fn next_code_point<I: Iterator<Item = u8>>(bytes: &mut I) -> Option<u32> {
140    // Decode UTF-8
141    let x = bytes.next()?;
142    if x < 128 {
143        return Some(x.into());
144    }
145
146    // Multibyte case follows
147    // Decode from a byte combination out of: [[[x y] z] w]
148    // NOTE: Performance is sensitive to the exact formulation here
149    let init = utf8_first_byte(x, 2);
150    // SAFETY: `bytes` produces an UTF-8-like string,
151    // so the iterator must produce a value here.
152    let y = unsafe { unwrap_debug(bytes.next()) };
153    let mut ch = utf8_acc_cont_byte(init, y);
154    if x >= 0xE0 {
155        // [[x y z] w] case
156        // 5th bit in 0xE0 .. 0xEF is always clear, so `init` is still valid
157        // SAFETY: `bytes` produces an UTF-8-like string,
158        // so the iterator must produce a value here.
159        let z = unsafe { unwrap_debug(bytes.next()) };
160        let y_z = utf8_acc_cont_byte((y & CONT_MASK).into(), z);
161        ch = init << 12 | y_z;
162        if x >= 0xF0 {
163            // [x y z w] case
164            // use only the lower 3 bits of `init`
165            // SAFETY: `bytes` produces an UTF-8-like string,
166            // so the iterator must produce a value here.
167            let w = unsafe { unwrap_debug(bytes.next()) };
168            ch = (init & 7) << 18 | utf8_acc_cont_byte(y_z, w);
169        }
170    }
171
172    Some(ch)
173}
174
175/// Reads the last code point out of a byte iterator (assuming a
176/// UTF-8-like encoding).
177///
178/// # Safety
179///
180/// `bytes` must produce a valid UTF-8-like (UTF-8 or WTF-8) string
181#[inline]
182pub(super) unsafe fn next_code_point_reverse<I>(bytes: &mut I) -> Option<u32>
183where
184    I: DoubleEndedIterator<Item = u8>,
185{
186    // Decode UTF-8
187    let w = match bytes.next_back()? {
188        next_byte if next_byte < 128 => return Some(next_byte.into()),
189        back_byte => back_byte,
190    };
191
192    // Multibyte case follows
193    // Decode from a byte combination out of: [x [y [z w]]]
194
195    // SAFETY: `bytes` produces an UTF-8-like string,
196    // so the iterator must produce a value here.
197    let z = unsafe { unwrap_debug(bytes.next_back()) };
198    let mut ch = utf8_first_byte(z, 2);
199    if utf8_is_cont_byte(z) {
200        // SAFETY: `bytes` produces an UTF-8-like string,
201        // so the iterator must produce a value here.
202        let y = unsafe { unwrap_debug(bytes.next_back()) };
203        ch = utf8_first_byte(y, 3);
204        if utf8_is_cont_byte(y) {
205            // SAFETY: `bytes` produces an UTF-8-like string,
206            // so the iterator must produce a value here.
207            let x = unsafe { unwrap_debug(bytes.next_back()) };
208            ch = utf8_first_byte(x, 4);
209            ch = utf8_acc_cont_byte(ch, y);
210        }
211        ch = utf8_acc_cont_byte(ch, z);
212    }
213    ch = utf8_acc_cont_byte(ch, w);
214
215    Some(ch)
216}
217
/// A word with the high bit (`0x80`) set in every byte.
///
/// `usize::MAX / 0xFF` is the word `0x01` repeated in every byte, so
/// multiplying it by `0x80` repeats `0x80` in every byte.
const NONASCII_MASK: usize = (usize::MAX / 0xFF) * 0x80;

/// Returns `true` if any byte in the word `x` is nonascii (>= 128).
#[inline]
const fn contains_nonascii(x: usize) -> bool {
    let high_bits = x & NONASCII_MASK;
    high_bits != 0
}
225
226/// Walks through `v` checking that it's a valid UTF-8 sequence,
227/// returning `Ok(())` in that case, or, if it is invalid, `Err(err)`.
228#[inline]
229#[allow(clippy::too_many_lines)]
230pub(super) fn run_utf8_validation(v: &Vector<u8>) -> Result<(), Utf8Error> {
231    let usize_bytes = mem::size_of::<usize>();
232    let ascii_block_size = 2 * usize_bytes;
233
234    let mut success_offset: usize = 0;
235    let mut chunk_iter = v.leaves();
236
237    let mut chunk: &[u8] = &[];
238    let mut index: usize = 0;
239    let mut len: usize = 0;
240    let mut blocks_end: usize = 0;
241    let mut align: usize = 0;
242
243    macro_rules! err {
244        ($error_len: expr) => {
245            Err(Utf8Error {
246                valid_up_to: success_offset,
247                error_len: $error_len,
248            })
249        };
250    }
251
252    macro_rules! advance {
253        ($result: expr) => {{
254            chunk = match chunk_iter.next() {
255                Some(chunk) => chunk,
256                None => return $result,
257            };
258            index = 0;
259            len = chunk.len();
260            blocks_end = if len >= ascii_block_size {
261                len - ascii_block_size + 1
262            } else {
263                0
264            };
265            align = chunk.as_ptr().align_offset(usize_bytes);
266        }};
267    }
268
269    macro_rules! next {
270        () => {{
271            while (index == len) {
272                advance!(err!(None))
273            }
274            index += 1;
275            chunk[index - 1]
276        }};
277    }
278
279    loop {
280        while index == len {
281            advance!(Ok(()));
282        }
283
284        let first = chunk[index];
285
286        if first >= 128 {
287            index += 1;
288            let w = utf8_char_width(first);
289            // 2-byte encoding is for codepoints  \u{0080} to  \u{07ff}
290            //        first  C2 80        last DF BF
291            // 3-byte encoding is for codepoints  \u{0800} to  \u{ffff}
292            //        first  E0 A0 80     last EF BF BF
293            //   excluding surrogates codepoints  \u{d800} to  \u{dfff}
294            //               ED A0 80 to       ED BF BF
295            // 4-byte encoding is for codepoints \u{1000}0 to \u{10ff}ff
296            //        first  F0 90 80 80  last F4 8F BF BF
297            //
298            // Use the UTF-8 syntax from the RFC
299            //
300            // https://tools.ietf.org/html/rfc3629
301            // UTF8-1      = %x00-7F
302            // UTF8-2      = %xC2-DF UTF8-tail
303            // UTF8-3      = %xE0 %xA0-BF UTF8-tail / %xE1-EC 2( UTF8-tail ) /
304            //               %xED %x80-9F UTF8-tail / %xEE-EF 2( UTF8-tail )
305            // UTF8-4      = %xF0 %x90-BF 2( UTF8-tail ) / %xF1-F3 3( UTF8-tail ) /
306            //               %xF4 %x80-8F 2( UTF8-tail )
307            match w {
308                2 => {
309                    #[allow(clippy::cast_possible_wrap)]
310                    if next!() as i8 >= -64 {
311                        return err!(Some(1));
312                    }
313                    success_offset += 2;
314                }
315                3 => {
316                    #[allow(clippy::unnested_or_patterns)]
317                    match (first, next!()) {
318                        (0xE0, 0xA0..=0xBF)
319                        | (0xE1..=0xEC, 0x80..=0xBF)
320                        | (0xED, 0x80..=0x9F)
321                        | (0xEE..=0xEF, 0x80..=0xBF) => {}
322                        _ => return err!(Some(1)),
323                    }
324
325                    if utf8_is_first_byte(next!()) {
326                        return err!(Some(2));
327                    }
328                    success_offset += 3;
329                }
330                4 => {
331                    match (first, next!()) {
332                        (0xF0, 0x90..=0xBF) | (0xF1..=0xF3, 0x80..=0xBF) | (0xF4, 0x80..=0x8F) => {}
333                        _ => return err!(Some(1)),
334                    }
335                    if utf8_is_first_byte(next!()) {
336                        return err!(Some(2));
337                    }
338                    if utf8_is_first_byte(next!()) {
339                        return err!(Some(3));
340                    }
341                    success_offset += 4;
342                }
343                _ => return err!(Some(1)),
344            }
345        } else {
346            // Ascii case, try to skip forward quickly.
347            // When the pointer is aligned, read 2 words of data per iteration
348            // until we find a word containing a non-ascii byte.
349            if align != usize::MAX && align.wrapping_sub(index) % usize_bytes == 0 {
350                let ptr = chunk.as_ptr();
351                while index < blocks_end {
352                    // SAFETY: since `align - index` and `ascii_block_size` are
353                    // multiples of `usize_bytes`, `block = ptr.add(index)` is
354                    // always aligned with a `usize` so it's safe to dereference
355                    // both `block` and `block.add(1)`.
356                    unsafe {
357                        #[allow(clippy::cast_ptr_alignment)]
358                        let block = ptr.add(index).cast::<usize>();
359                        // break if there is a nonascii byte
360                        let zu = contains_nonascii(*block);
361                        let zv = contains_nonascii(*block.add(1));
362                        if zu || zv {
363                            break;
364                        }
365                    }
366                    index += ascii_block_size;
367                    success_offset += ascii_block_size;
368                }
369                // step from the point where the wordwise loop stopped
370                while index < len && chunk[index] < 128 {
371                    index += 1;
372                    success_offset += 1;
373                }
374            } else {
375                index += 1;
376                success_offset += 1;
377            }
378        }
379    }
380}
381
382// https://tools.ietf.org/html/rfc3629
383const UTF8_CHAR_WIDTH: &[u8; 256] = &[
384    // 1  2  3  4  5  6  7  8  9  A  B  C  D  E  F
385    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 0
386    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 1
387    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 2
388    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 3
389    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 4
390    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 5
391    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 6
392    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 7
393    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 8
394    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 9
395    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // A
396    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // B
397    0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, // C
398    2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, // D
399    3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, // E
400    4, 4, 4, 4, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // F
401];
402
403/// Given a first byte, determines how many bytes are in this UTF-8 character.
404#[must_use]
405#[inline]
406pub(super) const fn utf8_char_width(b: u8) -> usize {
407    UTF8_CHAR_WIDTH[b as usize] as usize
408}
409
/// Mask selecting the value bits (the low six bits) of a continuation byte.
const CONT_MASK: u8 = 0x3F;
412
/// Unwraps `v`, checking that it is `Some` only when debug assertions are
/// enabled.
///
/// # Safety
///
/// `v` must be `Some`. With debug assertions disabled, passing `None` is
/// undefined behavior; with them enabled it panics.
#[inline]
unsafe fn unwrap_debug<A>(v: Option<A>) -> A {
    debug_assert!(
        v.is_some(),
        "Encountered end-of-iteration on a UTF-8 character non-boundary"
    );
    // SAFETY: the caller guarantees that `v` is `Some`.
    unsafe { v.unwrap_unchecked() }
}
421
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::StreamStrategy;
    use proptest::prelude::*;

    proptest! {
        #![proptest_config(ProptestConfig {
            cases: 4096, .. ProptestConfig::default()
          })]
        /// Corrupts a valid string by flipping one or two bits and checks
        /// that this module reaches the same verdict, with the same error
        /// details, as `std::str::from_utf8`.
        #[test]
        fn validation_agrees_with_std(s in ".{0,128}", flips in proptest::collection::vec(any::<usize>(), 1usize..3)) {
            let mut bytes = s.as_bytes().to_vec();
            if !s.is_empty() {
                for flip in flips {
                    // The low three bits of `flip` pick the bit, the
                    // remaining bits pick the byte.
                    let bit = flip % 8;
                    let position = (flip / 8) % s.len();
                    bytes[position] ^= 1 << bit;
                }
            }

            let expected = std::str::from_utf8(bytes.as_slice());
            let actual = run_utf8_validation(&Vector::from(bytes.as_slice()));

            if let Err(std_err) = expected {
                prop_assert!(actual.is_err());
                let my_err = actual.unwrap_err();
                std::mem::drop(format!("{}, {:?}", &my_err, &my_err)); // Just to exercise these implementations
                prop_assert_eq!(std_err.valid_up_to(), my_err.valid_up_to());
                prop_assert_eq!(std_err.error_len(), my_err.error_len());
                prop_assert_eq!(my_err, std_err.into());
            } else {
                prop_assert_eq!(actual, Ok(()));
            }
        }
    }

    proptest! {
        /// Decodes a string from both ends in an arbitrary interleaving and
        /// checks every code point against `str::chars`.
        #[test]
        fn iteration_agrees_with_std(s in ".{0,128}", mut direction in StreamStrategy(any::<bool>())) {
            let mut byte_iter = Vector::from(s.as_bytes()).into_iter();
            let mut char_iter = s.chars();
            loop {
                // `byte_iter` starts out as the bytes of a `&str`, and this
                // loop only ever removes whole code points from either end.
                let (raw, theirs) = if direction.gen() {
                    (unsafe { next_code_point(&mut byte_iter) }, char_iter.next())
                } else {
                    (unsafe { next_code_point_reverse(&mut byte_iter) }, char_iter.next_back())
                };
                let mine = raw.map(|ch| char::from_u32(ch).unwrap());
                prop_assert_eq!(mine, theirs);
                if mine.is_none() {
                    break;
                }
            }
        }
    }
}