use std::fmt;
/// Error produced when a byte slice fails UTF-8 validation.
///
/// Records the index at which the problem was reported and, when it can be
/// determined, how many bytes the offending sequence spans.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Utf8Error {
/// Index into the input at which the bad sequence is reported to start
/// (printed as "from index N" by the `Display` impl).
valid_up_to: usize,
/// Length in bytes of the invalid sequence, or `None` when the input ends
/// before the sequence is complete (rendered as "incomplete" by `Display`).
error_len: Option<usize>,
}
impl Utf8Error {
/// Returns the index into the input recorded in this error, i.e. where the
/// invalid or incomplete sequence was reported to begin.
#[inline]
pub fn valid_up_to(&self) -> usize {
self.valid_up_to
}
/// Returns the length in bytes of the invalid sequence.
///
/// `None` means the error was classified as an incomplete sequence rather
/// than an invalid one (see the `Display` messages).
#[inline]
pub fn error_len(&self) -> Option<usize> {
self.error_len
}
}
impl fmt::Display for Utf8Error {
    /// Renders a one-line, human-readable description of the failure.
    ///
    /// A known `error_len` yields the "invalid ... sequence" message including
    /// the length; an unknown one yields the "incomplete ... sequence" message.
    /// Both variants end with the index stored in `valid_up_to`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let index = self.valid_up_to;
        if let Some(len) = self.error_len {
            write!(
                f,
                "invalid utf-8 sequence of {} bytes from index {}",
                len, index
            )
        } else {
            write!(f, "incomplete utf-8 byte sequence from index {}", index)
        }
    }
}
// Marker impl: the derived `Debug` and the `Display` impl in this file satisfy
// the trait's supertraits, and no optional methods are overridden.
impl std::error::Error for Utf8Error {}
#[inline]
pub fn validate(bytes: &[u8]) -> Result<&str, Utf8Error> {
let result = simdutf::validate_utf8_with_errors(bytes);
if result.error == simdutf::ErrorCode::Success {
return Ok(unsafe { std::str::from_utf8_unchecked(bytes) });
}
let valid_up_to = result.count;
let error_len = compute_error_len(bytes, valid_up_to);
Err(Utf8Error {
valid_up_to,
error_len,
})
}
/// Derives the `error_len` for a validation failure reported at `valid_up_to`.
///
/// The byte at `valid_up_to` is classified as a lead byte, and the bytes that
/// should follow it are checked for the `10xxxxxx` continuation pattern:
///
/// * `None` — `valid_up_to` is at or past the end of `bytes`, or the input
///   ends before the expected sequence length and every byte that is present
///   after the lead is a continuation byte.
/// * `Some(1)` — the byte cannot start a multi-byte sequence (ASCII, a stray
///   continuation byte, `0xC0`/`0xC1`, or `0xF5..=0xFF`).
/// * `Some(i)` — the `i`-th byte after the lead is not a continuation byte.
/// * `Some(expected_len)` — all continuation bytes are present, so the
///   sequence has the right shape even though the validator rejected it.
///
/// NOTE(review): for shape-correct but invalid sequences (e.g. `ED A0 80`) this
/// reports the full sequence length, whereas `std::str::from_utf8` reports
/// `Some(1)` there — confirm callers do not rely on exact std parity.
fn compute_error_len(bytes: &[u8], valid_up_to: usize) -> Option<usize> {
    // Out-of-range or end-of-input positions have no lead byte to inspect.
    let (&lead, trailing) = bytes.get(valid_up_to..)?.split_first()?;

    let expected_len = match lead {
        0xC2..=0xDF => 2,
        0xE0..=0xEF => 3,
        0xF0..=0xF4 => 4,
        // ASCII, stray continuation bytes and bytes that can never be a lead.
        _ => return Some(1),
    };

    // Scan only the continuation slots that actually exist in the input.
    let first_bad = trailing
        .iter()
        .take(expected_len - 1)
        .position(|&byte| byte & 0xC0 != 0x80);

    match first_bad {
        // `offset` is relative to the byte after the lead.
        Some(offset) => Some(offset + 1),
        // Every byte present looked fine, but the input stopped too early.
        None if trailing.len() + 1 < expected_len => None,
        None => Some(expected_len),
    }
}
/// Validates `bytes` as UTF-8 and returns them borrowed as a `&str`.
///
/// Thin alias that forwards directly to [`validate`]; the result and error
/// are whatever `validate` produces.
#[inline]
pub fn from_utf8_fast(bytes: &[u8]) -> Result<&str, Utf8Error> {
validate(bytes)
}
/// Validates `bytes` as UTF-8 and returns them borrowed as a `&str`.
///
/// Forwards directly to [`validate`], so in this file it behaves exactly like
/// `from_utf8_fast`.
///
/// NOTE(review): the name suggests a std-compatible entry point; confirm
/// whether it is meant to diverge from `from_utf8_fast`.
#[inline]
pub fn from_utf8_compat(bytes: &[u8]) -> Result<&str, Utf8Error> {
validate(bytes)
}
/// Validates `bytes` as UTF-8 and returns an owned copy of the text.
///
/// # Errors
///
/// Propagates the [`Utf8Error`] from [`validate`] unchanged; nothing is
/// allocated in that case.
#[inline]
pub fn from_utf8_to_string(bytes: &[u8]) -> Result<String, Utf8Error> {
    let text = validate(bytes)?;
    Ok(text.to_owned())
}
/// Converts `bytes` to an owned `String`, never failing.
///
/// Input accepted by [`validate`] is copied as-is. Anything else falls back to
/// `String::from_utf8_lossy`, which substitutes U+FFFD for malformed
/// sequences.
#[inline]
pub fn from_utf8_lossy_fast(bytes: &[u8]) -> String {
    match validate(bytes) {
        Ok(text) => text.to_owned(),
        // When replacement was needed the `Cow` already owns its buffer;
        // `into_owned` hands that buffer over instead of copying it again.
        Err(_) => String::from_utf8_lossy(bytes).into_owned(),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Multi-byte content (heavy heart + variation selector) is accepted and
    /// comes back unchanged.
    #[test]
    fn test_valid_utf8() {
        let decoded = from_utf8_fast(b"Hello, \xE2\x9D\xA4\xEF\xB8\x8F UTF-8!").unwrap();
        assert_eq!(decoded, "Hello, \u{2764}\u{FE0F} UTF-8!");
    }

    /// A stray 0xFF is rejected by the strict API, while the lossy API keeps
    /// the text on both sides of it.
    #[test]
    fn test_invalid_utf8() {
        let input: &[u8] = b"Hello, \xFF invalid UTF-8!";
        assert!(from_utf8_fast(input).is_err());

        let lossy = from_utf8_lossy_fast(input);
        assert!(lossy.contains("Hello"));
        assert!(lossy.contains("invalid UTF-8!"));
    }

    /// Plain ASCII input round-trips.
    #[test]
    fn test_ascii_fast_path() {
        let decoded = from_utf8_fast(b"Pure ASCII text").unwrap();
        assert_eq!(decoded, "Pure ASCII text");
    }

    /// The error carries the failure position, a one-byte length for 0xFF and
    /// the "invalid" wording in its message.
    #[test]
    fn test_compat_error_info() {
        let failure = from_utf8_compat(b"Valid\xFF\xFEInvalid").unwrap_err();
        assert!(failure.to_string().contains("invalid utf-8"));
        assert_eq!(failure.valid_up_to(), 5);
        assert_eq!(failure.error_len(), Some(1));
    }

    /// A four-byte lead followed by a single continuation byte is incomplete.
    #[test]
    fn truncated_utf8_returns_none() {
        let failure = validate(b"\xF0\x9F").unwrap_err();
        assert_eq!(failure.valid_up_to(), 0);
        assert_eq!(failure.error_len(), None);
    }

    /// A truncated sequence whose next byte is not a continuation byte is
    /// invalid rather than incomplete.
    #[test]
    fn truncated_with_bad_continuation_returns_some() {
        let failure = validate(b"\xE0\x20").unwrap_err();
        assert_eq!(failure.valid_up_to(), 0);
        assert_eq!(failure.error_len(), Some(1));
    }

    /// A shape-correct but invalid three-byte sequence (encoded surrogate)
    /// reports its full length.
    #[test]
    fn complete_invalid_sequence_returns_some_len() {
        let failure = validate(b"\xED\xA0\x80").unwrap_err();
        assert_eq!(failure.valid_up_to(), 0);
        assert_eq!(failure.error_len(), Some(3));
    }
}