use super::ValidationError;
/// A UTF-8 string stored inline in a fixed-capacity `[u8; MAX_LEN]` buffer.
///
/// Only the first `len` bytes are meaningful. Bytes past `len` are padding and
/// are ignored by equality, hashing and `Debug`. Two values holding the same
/// text therefore compare (and hash) equal regardless of their padding.
#[derive(Clone)]
pub struct Utf8Bytes<const MAX_LEN: usize> {
    /// Backing storage; only `bytes[..len]` holds content.
    bytes: [u8; MAX_LEN],
    /// Number of meaningful bytes at the front of `bytes`.
    len: usize,
}

impl<const MAX_LEN: usize> PartialEq for Utf8Bytes<MAX_LEN> {
    fn eq(&self, other: &Self) -> bool {
        // Compare only the meaningful prefix; trailing padding is not content.
        self.bytes[..self.len] == other.bytes[..other.len]
    }
}

impl<const MAX_LEN: usize> Eq for Utf8Bytes<MAX_LEN> {}

impl<const MAX_LEN: usize> std::hash::Hash for Utf8Bytes<MAX_LEN> {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        // Must agree with `PartialEq`: hash exactly the bytes that `eq` compares.
        std::hash::Hash::hash(&self.bytes[..self.len], state);
    }
}

impl<const MAX_LEN: usize> std::fmt::Debug for Utf8Bytes<MAX_LEN> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Show only the meaningful prefix so values that compare equal print equal.
        f.debug_struct("Utf8Bytes")
            .field("bytes", &&self.bytes[..self.len])
            .field("len", &self.len)
            .finish()
    }
}
impl<const MAX_LEN: usize> Utf8Bytes<MAX_LEN> {
    /// Builds a value from a fixed-size buffer whose first `len` bytes hold the text.
    ///
    /// # Errors
    ///
    /// * `ValidationError::TooLong` when `len` exceeds `MAX_LEN`.
    /// * `ValidationError::InvalidUtf8` when `bytes[..len]` is rejected by
    ///   `is_valid_utf8`. Under `cfg(kani)` the verdict is an arbitrary
    ///   (nondeterministic) boolean instead of a real check.
    pub fn new(bytes: [u8; MAX_LEN], len: usize) -> Result<Self, ValidationError> {
        if len > MAX_LEN {
            return Err(ValidationError::TooLong { max: MAX_LEN, actual: len });
        }

        #[cfg(kani)]
        let accepted: bool = kani::any();
        #[cfg(not(kani))]
        let accepted = is_valid_utf8(&bytes[..len]);

        if accepted {
            Ok(Self { bytes, len })
        } else {
            Err(ValidationError::InvalidUtf8)
        }
    }

    /// The meaningful prefix of the buffer as raw bytes.
    pub fn as_bytes(&self) -> &[u8] {
        &self.bytes[..self.len]
    }

    /// The stored text.
    ///
    /// # Panics
    ///
    /// Panics if `bytes[..len]` is not valid UTF-8. Outside `cfg(kani)`, `new`
    /// rules this out.
    pub fn as_str(&self) -> &str {
        std::str::from_utf8(self.as_bytes()).expect("UTF-8 validated in constructor")
    }

    /// Number of bytes (not characters) of stored text.
    pub fn len(&self) -> usize {
        self.len
    }

    /// Returns `true` when no text is stored.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
/// Formats the stored text exactly like a `&str`.
///
/// Width, fill, alignment and precision flags are honoured: `{:>8}` pads and
/// `{:.3}` truncates to three characters.
impl<const MAX_LEN: usize> std::fmt::Display for Utf8Bytes<MAX_LEN> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `write!(f, "{}", ..)` starts a fresh format spec and silently drops the
        // caller's flags; `pad` applies them, as `str`'s own `Display` does.
        f.pad(self.as_str())
    }
}
/// Reports whether `bytes` is well-formed UTF-8.
///
/// It rejects:
/// * truncated sequences and stray continuation bytes;
/// * overlong encodings;
/// * UTF-16 surrogates (U+D800..=U+DFFF);
/// * code points above U+10FFFF;
/// * the lead bytes 0xF8..=0xFF.
///
/// When compiled under Kani, the input length is assumed to be at most 16 bytes.
#[inline]
pub fn is_valid_utf8(bytes: &[u8]) -> bool {
    /// `true` for bytes of the form `10xx_xxxx`.
    fn is_continuation(b: u8) -> bool {
        b & 0b1100_0000 == 0b1000_0000
    }

    /// Folds the payload bits of a lead byte and its continuation bytes into a scalar value.
    fn decode(lead_bits: u8, tail: &[u8]) -> u32 {
        tail.iter()
            .fold(u32::from(lead_bits), |acc, &b| (acc << 6) | u32::from(b & 0b0011_1111))
    }

    let len = bytes.len();
    #[cfg(kani)]
    kani::assume(len <= 16);

    let mut pos = 0;
    while pos < len {
        let lead = bytes[pos];
        // Sequence length implied by the lead byte's high bits.
        let width = if lead & 0b1000_0000 == 0 {
            1
        } else if lead & 0b1110_0000 == 0b1100_0000 {
            2
        } else if lead & 0b1111_0000 == 0b1110_0000 {
            3
        } else if lead & 0b1111_1000 == 0b1111_0000 {
            4
        } else {
            // Stray continuation byte (0x80..=0xBF) or one of 0xF8..=0xFF.
            return false;
        };

        if pos + width > len {
            return false;
        }
        let tail = &bytes[pos + 1..pos + width];
        if !tail.iter().copied().all(is_continuation) {
            return false;
        }

        let well_formed = match width {
            1 => true,
            // 0xC0/0xC1 would encode U+0000..=U+007F, which fits in one byte.
            2 => lead & 0b0001_1110 != 0,
            3 => {
                // 0xE0 followed by 0x80..=0x9F is an overlong 2-byte value.
                let overlong = lead == 0b1110_0000 && tail[0] & 0b0010_0000 == 0;
                let scalar = decode(lead & 0b0000_1111, tail);
                !overlong && !(0xD800..=0xDFFF).contains(&scalar)
            }
            _ => {
                // 0xF0 followed by 0x80..=0x8F is an overlong 3-byte value.
                let overlong = lead == 0b1111_0000 && tail[0] & 0b0011_0000 == 0;
                !overlong && decode(lead & 0b0000_0111, tail) <= 0x10_FFFF
            }
        };
        if !well_formed {
            return false;
        }
        pos += width;
    }
    true
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Copies `text` to the front of a zeroed `[u8; N]`, returning the buffer and `text.len()`.
    fn padded<const N: usize>(text: &[u8]) -> ([u8; N], usize) {
        let mut bytes = [0u8; N];
        bytes[..text.len()].copy_from_slice(text);
        (bytes, text.len())
    }

    #[test]
    fn test_ascii_valid() {
        let (bytes, len) = padded::<10>(b"hello");
        let utf8 = Utf8Bytes::<10>::new(bytes, len).unwrap();
        assert_eq!(utf8.as_str(), "hello");
        assert_eq!(utf8.as_bytes(), b"hello");
        assert_eq!(utf8.len(), 5);
        assert!(!utf8.is_empty());
    }

    #[test]
    fn test_empty_valid() {
        let bytes = [0u8; 10];
        let utf8 = Utf8Bytes::<10>::new(bytes, 0).unwrap();
        assert_eq!(utf8.as_str(), "");
        assert!(utf8.is_empty());
    }

    #[test]
    fn test_full_capacity_valid() {
        let (bytes, len) = padded::<4>(b"abcd");
        let utf8 = Utf8Bytes::<4>::new(bytes, len).unwrap();
        assert_eq!(utf8.as_str(), "abcd");
    }

    #[test]
    fn test_two_byte_utf8() {
        let (bytes, len) = padded::<10>(&[0xC2, 0xA9]);
        let utf8 = Utf8Bytes::<10>::new(bytes, len).unwrap();
        assert_eq!(utf8.as_str(), "©");
    }

    #[test]
    fn test_three_byte_utf8() {
        let (bytes, len) = padded::<10>(&[0xE2, 0x82, 0xAC]);
        let utf8 = Utf8Bytes::<10>::new(bytes, len).unwrap();
        assert_eq!(utf8.as_str(), "€");
    }

    #[test]
    fn test_four_byte_utf8() {
        let (bytes, len) = padded::<10>(&[0xF0, 0x9F, 0x98, 0x80]);
        let utf8 = Utf8Bytes::<10>::new(bytes, len).unwrap();
        assert_eq!(utf8.as_str(), "😀");
    }

    #[test]
    fn test_display_matches_as_str() {
        let (bytes, len) = padded::<10>("héllo".as_bytes());
        let utf8 = Utf8Bytes::<10>::new(bytes, len).unwrap();
        assert_eq!(utf8.to_string(), "héllo");
    }

    #[test]
    fn test_invalid_continuation() {
        let (bytes, len) = padded::<10>(&[0xC2, 0xFF]);
        assert!(Utf8Bytes::<10>::new(bytes, len).is_err());
    }

    #[test]
    fn test_overlong_encoding() {
        let (bytes, len) = padded::<10>(&[0xC0, 0x80]);
        assert!(Utf8Bytes::<10>::new(bytes, len).is_err());
    }

    #[test]
    fn test_overlong_four_byte_rejected() {
        let (bytes, len) = padded::<10>(&[0xF0, 0x80, 0x80, 0x80]);
        assert!(Utf8Bytes::<10>::new(bytes, len).is_err());
    }

    #[test]
    fn test_surrogate_rejected() {
        let (bytes, len) = padded::<10>(&[0xED, 0xA0, 0x80]);
        assert!(Utf8Bytes::<10>::new(bytes, len).is_err());
    }

    #[test]
    fn test_above_max_code_point_rejected() {
        let (bytes, len) = padded::<10>(&[0xF4, 0x90, 0x80, 0x80]);
        assert!(Utf8Bytes::<10>::new(bytes, len).is_err());
    }

    #[test]
    fn test_truncated_sequence_rejected() {
        let (bytes, len) = padded::<10>(&[0xE2, 0x82]);
        assert!(Utf8Bytes::<10>::new(bytes, len).is_err());
    }

    #[test]
    fn test_stray_continuation_rejected() {
        let (bytes, len) = padded::<10>(&[0x80]);
        assert!(matches!(
            Utf8Bytes::<10>::new(bytes, len),
            Err(ValidationError::InvalidUtf8)
        ));
    }

    #[test]
    fn test_length_too_long() {
        let bytes = [0u8; 10];
        assert!(matches!(
            Utf8Bytes::<10>::new(bytes, 11),
            Err(ValidationError::TooLong { max: 10, actual: 11 })
        ));
    }
}