use std::error::Error;
use std::fmt;
use std::num::NonZeroU8;
use std::ops::AddAssign;
use std::ops::Sub;
use std::simd::*;
use std::io;
use std::borrow::Cow;
use std::convert::TryInto;
/// Default SIMD chunk width, in bytes, used when scanning UTF-8 for bytes that need re-encoding.
pub const DEFAULT_CHUNK: usize = 64;
/// Mask selecting the six payload bits of a continuation byte (`10xx_xxxx`).
pub(crate) const CONT_MASK: u8 = 0x3F;
/// Tag bits (`1000_0000`) that mark a continuation byte.
pub(crate) const TAG_CONT_U8: u8 = 0x80;
/// Mask applied to a byte before comparing with [`TAG_LEN4_VAL`].
const TAG_LEN4_MASK: u8 = 0xF8;
/// Masked value (`1111_0xxx`) identifying the lead byte of a 4-byte UTF-8 sequence.
const TAG_LEN4_VAL: u8 = 0xF0;
/// The two-byte `C0 80` sequence that modified UTF-8 uses in place of a raw nul byte.
pub(crate) const MUTF8_ZERO: [u8; 2] = [0xC0, 0x80];
/// Decodes CESU-8 (or MUTF-8 when `ENCODE_NUL`) bytes into a UTF-8 string.
///
/// If the input is already valid UTF-8 it is returned without copying. Otherwise the
/// bytes are rewritten in place and then truncated:
/// - each 6-byte surrogate pair becomes its 4-byte UTF-8 form;
/// - when `ENCODE_NUL`, each `C0 80` becomes a single `0x00`.
///
/// # Safety
///
/// `cesu` must be well-formed CESU-8 (MUTF-8 when `ENCODE_NUL`). Debug builds validate
/// the result and panic on malformed input. Release builds skip the surrogate checks and
/// build the `String` with `from_utf8_unchecked`, so malformed input is undefined behavior.
pub(crate) unsafe fn cesu8_to_utf8<const ENCODE_NUL: bool>(cesu: Cow<[u8]>) -> Cow<str> {
    debug_assert!(!(ENCODE_NUL && cesu.contains(&b'\0')), "nul-byte included in mutf8 string");
    let (e, mut cesu) = match std::str::from_utf8(&cesu) {
        Ok(_) => {
            // SAFETY: `from_utf8` just confirmed these exact bytes are valid UTF-8.
            unsafe {
                return match cesu {
                    Cow::Borrowed(b) => Cow::Borrowed(std::str::from_utf8_unchecked(b)),
                    Cow::Owned(v) => Cow::Owned(String::from_utf8_unchecked(v)),
                };
            }
        }
        Err(e) => (e, cesu.into_owned()),
    };

    // `ir` is the read cursor and `iw` the write cursor. Every rewrite shrinks its input
    // (6 -> 4 bytes, or 2 -> 1), so `iw <= ir` holds and writes never clobber unread bytes.
    let mut iw = e.valid_up_to();
    let mut ir = e.valid_up_to();
    while ir < cesu.len() {
        let rest = &cesu[ir..];
        if ENCODE_NUL && rest.starts_with(&MUTF8_ZERO) {
            cesu[iw] = b'\0';
            iw += 1;
            ir += 2;
        } else if let Some(&[first, second, third, fourth, fifth, sixth]) = rest.get(..6) {
            debug_assert!(
                first == 0xED && fourth == 0xED,
                "expected surrogate pair, recieved something else (err bytes[..6]: {:X?})",
                &rest[..6]
            );
            let utf8bytes: [u8; 4] = dec_surrogates_infallable(second, third, fifth, sixth);
            cesu[iw..iw + 4].copy_from_slice(&utf8bytes);
            iw += 4;
            ir += 6;
        } else {
            let strtype = if ENCODE_NUL { "MUTF8" } else { "CESU8" };
            let encnulstr = if ENCODE_NUL { "encoded nul or " } else { "" };
            unreachable!(
                "{} decoding error. expected {}surrogate pair, got something else. (string up to this point: {:?}, next few bytes: {:X?})",
                strtype, encnulstr,
                String::from_utf8_lossy(&cesu[..iw]),
                // The unconsumed input starts at the read cursor, not the write cursor.
                &cesu[ir..cesu.len().min(ir + 16)],
            );
        }
        // Slide the following run of plain UTF-8 down to the write cursor in one move.
        let valid_utf8 = match std::str::from_utf8(&cesu[ir..]) {
            Ok(s) => s.len(),
            Err(e) => e.valid_up_to(),
        };
        cesu.copy_within(ir..ir + valid_utf8, iw);
        ir += valid_utf8;
        iw += valid_utf8;
    }
    debug_assert!(iw < cesu.len());
    cesu.truncate(iw);

    if cfg!(debug_assertions) {
        match String::from_utf8(cesu) {
            Ok(s) => Cow::Owned(s),
            Err(e) => panic!(
                "reencoded cesu into invalid utf8: (ir={}, iw={}) err={:X?}, lossy={:?}, bytes={:X?}",
                ir, iw, e, String::from_utf8_lossy(e.as_bytes()), e.as_bytes(),
            ),
        }
    } else {
        // SAFETY: the caller guarantees well-formed CESU-8/MUTF-8 input; every rewrite
        // above emits valid UTF-8 and everything else was copied from validated UTF-8 runs.
        Cow::Owned(unsafe { String::from_utf8_unchecked(cesu) })
    }
}
/// Byte-at-a-time fallback scan for the first byte that cannot be copied verbatim.
///
/// The result is the index of the first 4-byte UTF-8 lead byte (`F0..=F7`), or, when
/// `ENCODE_NUL`, the first nul byte — whichever comes first. `None` means nothing in
/// `b` needs re-encoding.
fn utf8_to_cesu8_piecewise<const ENCODE_NUL: bool>(b: &[u8]) -> Option<usize> {
    b.iter().position(|&byte| {
        let is_four_byte_lead = (byte & TAG_LEN4_MASK) == TAG_LEN4_VAL;
        let is_forbidden_nul = ENCODE_NUL && byte == 0;
        is_forbidden_nul || is_four_byte_lead
    })
}
/// SIMD test for whether any byte of `arr` needs re-encoding.
///
/// Returns `true` if some lane holds a 4-byte UTF-8 lead byte, or, when `ENCODE_NUL`,
/// a nul byte.
fn utf8_to_cesu8_check_lane<const LANES: usize, const ENCODE_NUL: bool>(arr: [u8; LANES]) -> bool
where
    LaneCount<LANES>: SupportedLaneCount
{
    let lanes: Simd<u8, LANES> = Simd::from_array(arr);

    let lead_bits: Simd<u8, LANES> = lanes & Simd::<u8, LANES>::splat(TAG_LEN4_MASK);
    if lead_bits.simd_eq(Simd::<u8, LANES>::splat(TAG_LEN4_VAL)).any() {
        return true;
    }

    // A raw nul byte only needs re-encoding for MUTF-8.
    ENCODE_NUL && lanes.simd_eq(Simd::<u8, LANES>::splat(0)).any()
}
pub(crate) fn utf8_to_cesu8_simd<const CHUNK_SIZE: usize, const MIN_SIMD: usize, const MIN_SIMD_FACTOR: usize, const ENCODE_NUL: bool>(b: &[u8]) -> Option<usize>
where
LaneCount<CHUNK_SIZE>: SupportedLaneCount
{
let (chunk_start, chunk) = (0..b.len()).step_by(CHUNK_SIZE)
.filter_map(|i| {
let chunk = b.get(i..i+CHUNK_SIZE)?; let chunkarr = chunk.try_into().unwrap();
utf8_to_cesu8_check_lane::<CHUNK_SIZE, ENCODE_NUL>(chunkarr).then_some((i, chunk))
})
.next() .or_else(|| { let r = b.len() % CHUNK_SIZE;
if r == 0 { return None; }
let i = b.len()-r;
let rest = &b[i..];
if r < MIN_SIMD || r < CHUNK_SIZE/MIN_SIMD_FACTOR {
return Some((i, rest)); }
let mut arr = [b'#'; CHUNK_SIZE];
arr[..r].copy_from_slice(rest);
utf8_to_cesu8_check_lane::<CHUNK_SIZE, ENCODE_NUL>(arr).then_some((i, rest))
})?;
utf8_to_cesu8_piecewise::<ENCODE_NUL>(chunk).map(|i| chunk_start+i)
}
/// Returns the index of the first byte in `b` that cannot be copied verbatim into
/// CESU-8/MUTF-8 output, or `None` if the whole slice can be passed through unchanged.
///
/// A byte cannot be copied verbatim if it is a 4-byte UTF-8 lead byte, or, when
/// `ENCODE_NUL`, a nul byte.
///
/// `CHUNK_SIZE` selects the SIMD lane width: 8, 16, 32 and 64 are used as given, and any
/// other value falls back to [`DEFAULT_CHUNK`]. The chunk size only affects speed. The
/// returned index is always the first offending byte.
pub(crate) fn check_utf8_to_cesu8<const CHUNK_SIZE: usize, const ENCODE_NUL: bool>(b: &[u8]) -> Option<usize> {
    // `utf8_to_cesu8_simd` requires `LaneCount<N>: SupportedLaneCount`. That bound cannot be
    // stated for the generic CHUNK_SIZE without forcing it onto every caller, so map the
    // const onto concrete supported widths. The match folds away at monomorphization.
    match CHUNK_SIZE {
        8 => utf8_to_cesu8_simd::<8, 4, 2, ENCODE_NUL>(b),
        16 => utf8_to_cesu8_simd::<16, 8, 2, ENCODE_NUL>(b),
        32 => utf8_to_cesu8_simd::<32, 16, 2, ENCODE_NUL>(b),
        64 => utf8_to_cesu8_simd::<64, 32, 2, ENCODE_NUL>(b),
        _ => utf8_to_cesu8_simd::<DEFAULT_CHUNK, { DEFAULT_CHUNK / 2 }, 2, ENCODE_NUL>(b),
    }
}
/// Converts UTF-8 `src` into CESU-8 (or MUTF-8 when `ENCODE_NUL`) bytes.
///
/// When no character needs re-encoding, `src`'s own buffer is returned: borrowed stays
/// borrowed, and owned is converted without copying. Otherwise a new buffer is allocated.
/// The untouched prefix is copied into it, and the remainder is re-encoded through
/// `utf8_to_cesu8_io`.
pub(crate) fn utf8_to_cesu8_vec<const CHUNK_SIZE: usize, const ENCODE_NUL: bool>(src: Cow<str>) -> Cow<[u8]> {
    // Use the caller's ENCODE_NUL. A nul byte only needs re-encoding for MUTF-8, so a
    // hard-coded `true` pushed CESU-8 strings containing nul onto the allocating path.
    let Some(first_bad_idx) = check_utf8_to_cesu8::<CHUNK_SIZE, ENCODE_NUL>(src.as_bytes()) else {
        return match src {
            Cow::Borrowed(s) => Cow::Borrowed(s.as_bytes()),
            Cow::Owned(s) => Cow::Owned(s.into_bytes()),
        };
    };
    let mut dst = Vec::with_capacity(crate::default_cesu8_capacity(src.len()));
    let (valid, rest) = src.split_at(first_bad_idx);
    dst.extend_from_slice(valid.as_bytes());
    // `rest` starts exactly at the first character needing re-encoding, hence the `true` hint.
    utf8_to_cesu8_io::<CHUNK_SIZE, ENCODE_NUL, _>(rest, true, &mut dst, &mut BufferUsage::default())
        .expect("writing into a Vec<u8> cannot fail");
    Cow::Owned(dst)
}
/// Byte counters describing the progress of a transcoding operation.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct BufferUsage {
    /// Number of bytes consumed from the source.
    pub read: usize,
    /// Number of bytes emitted to the destination.
    pub written: usize,
}
impl BufferUsage {
    /// Records `amount` bytes that were copied verbatim, i.e. read and written one-for-one.
    pub(crate) fn inc(&mut self, amount: usize) {
        *self += (amount, amount);
    }
}
impl AddAssign<(usize, usize)> for BufferUsage {
    /// Adds a `(read, written)` delta to the counters.
    fn add_assign(&mut self, delta: (usize, usize)) {
        let (read_delta, written_delta) = delta;
        self.read += read_delta;
        self.written += written_delta;
    }
}
impl Sub for BufferUsage {
    type Output = BufferUsage;
    /// Field-wise difference of two snapshots.
    ///
    /// With overflow checks enabled, this panics if `rhs` exceeds `self` in either field.
    fn sub(self, rhs: Self) -> Self::Output {
        let read = self.read - rhs.read;
        let written = self.written - rhs.written;
        BufferUsage { read, written }
    }
}
/// Streams `src` into `w` as CESU-8 (MUTF-8 when `ENCODE_NUL`), resuming from `buf_usage`.
///
/// Characters that need no change are copied through. Each 4-byte UTF-8 character is
/// written as a 6-byte surrogate pair. When `ENCODE_NUL`, each nul is written as `C0 80`.
///
/// `buf_usage` is read and updated:
/// - On entry, `buf_usage.read` is the number of bytes of `src` already consumed by an
///   earlier call. Pass `BufferUsage::default()` to start from the beginning.
/// - It is updated as data is written, so after an I/O error the call can be repeated
///   with the same `src` and `buf_usage`.
/// - Pass-through bytes are tracked exactly, even across partial writes. A failure in the
///   middle of a re-encoded sequence (written with `write_all`) is not recorded, so a
///   retry writes that whole sequence again.
///
/// `hint_bad_start` lets the caller state that the unconsumed input begins with a
/// character needing re-encoding. This skips one scan. A wrong hint is harmless: the
/// character is left alone and the input is rescanned normally.
///
/// Returns the bytes read from `src` and written to `w` by this call.
///
/// # Errors
///
/// Propagates any error from `w` other than `Interrupted` (which is retried). Returns
/// `WriteZero` if `w` accepts zero bytes.
///
/// # Panics
///
/// Panics if `buf_usage.read > src.len()`.
#[inline]
pub(crate) fn utf8_to_cesu8_io<const CHUNK_SIZE: usize, const ENCODE_NUL: bool, W: io::Write + fmt::Debug>(mut src: &str, mut hint_bad_start: bool, mut w: W, buf_usage: &mut BufferUsage) -> io::Result<BufferUsage> {
    /// Writes `bytes` verbatim, counting every accepted byte as one read and one written.
    /// Unlike `write_all`, progress is recorded after each partial write.
    fn write_passthrough<W: io::Write + ?Sized>(w: &mut W, mut bytes: &[u8], usage: &mut BufferUsage) -> io::Result<()> {
        while !bytes.is_empty() {
            match w.write(bytes) {
                Ok(0) => {
                    return Err(io::Error::new(
                        io::ErrorKind::WriteZero,
                        "failed to write whole buffer",
                    ));
                }
                Ok(n) => {
                    usage.inc(n);
                    // Byte slicing: `n` may fall inside a multi-byte char, which is fine here.
                    bytes = &bytes[n..];
                }
                Err(e) if e.kind() == io::ErrorKind::Interrupted => {}
                Err(e) => return Err(e),
            }
        }
        Ok(())
    }

    let buf_usage_orig = *buf_usage;
    assert!(buf_usage.read <= src.len(), "buf_usage.read is past the end of src");

    // A previous call may have stopped after a partial pass-through write that split a
    // multi-byte character. Only pass-through bytes are ever partially written, so the
    // rest of that character is finished verbatim before resuming char-wise processing.
    let mut start = buf_usage.read;
    if !src.is_char_boundary(start) {
        let boundary = (start..src.len())
            .find(|&i| src.is_char_boundary(i))
            .unwrap_or(src.len());
        write_passthrough(&mut w, &src.as_bytes()[start..boundary], buf_usage)?;
        start = boundary;
    }
    src = &src[start..];

    loop {
        // The hint means "index 0 is bad": skip the scan once, but only if there is input.
        let err_ind_opt = if hint_bad_start && !src.is_empty() {
            Some(0)
        } else {
            check_utf8_to_cesu8::<CHUNK_SIZE, ENCODE_NUL>(src.as_bytes())
        };

        let Some(err_ind) = err_ind_opt else {
            // Nothing left needs re-encoding: copy the remainder through and finish.
            write_passthrough(&mut w, src.as_bytes(), buf_usage)?;
            return Ok(*buf_usage - buf_usage_orig);
        };

        let (valid, rest) = src.split_at(err_ind);
        write_passthrough(&mut w, valid.as_bytes(), buf_usage)?;
        src = rest;

        let mut chars = src.chars();
        let Some(ch) = chars.next() else {
            unreachable!("check_utf8_to_cesu8 reported an error index at the end of the input");
        };
        if ENCODE_NUL && ch == '\0' {
            w.write_all(MUTF8_ZERO.as_slice())?;
            *buf_usage += (1, 2);
            src = chars.as_str();
        } else if ch.len_utf8() == 4 {
            let cesu_bytes = enc_surrogates(ch as u32);
            w.write_all(&cesu_bytes)?;
            *buf_usage += (4, 6);
            src = chars.as_str();
        } else {
            // Only an inaccurate caller hint can point at a character that needs no
            // re-encoding. Drop the hint and let the next iteration scan normally.
            assert!(hint_bad_start, "check_utf8_to_cesu8 returned an unexpected error");
        }
        hint_bad_start = false;
    }
}
/// Describes where and how CESU-8/MUTF-8 validation failed.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub struct EncodingError {
    /// Number of leading bytes that were valid before the error.
    pub(crate) valid_up_to: usize,
    /// Approximate length of the invalid sequence. `None` means the input ended part-way
    /// through a sequence.
    pub(crate) error_len: Option<NonZeroU8>,
}
impl EncodingError {
    /// Byte offset of the first invalid byte, i.e. the length of the valid prefix.
    pub fn valid_up_to(&self) -> usize {
        let Self { valid_up_to, .. } = *self;
        valid_up_to
    }
    /// Approximate length of the invalid sequence. `None` if the input ended mid-sequence.
    pub fn error_len(&self) -> Option<NonZeroU8> {
        let Self { error_len, .. } = *self;
        error_len
    }
}
impl std::error::Error for EncodingError {}
impl fmt::Display for EncodingError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let index = self.valid_up_to;
        if let Some(len) = self.error_len {
            write!(f, "invalid data found at index {} of cesu8/mutf8 encoded byte string, of approximate length {}", index, len)
        } else {
            write!(f, "invalid data found at end (index {}) of cesu8/mutf8 encoded byte string, partial codepoint", index)
        }
    }
}
/// Classifies the sequence at the start of `b`, a position where plain UTF-8 decoding failed.
///
/// Returns:
/// - `Ok(2)` for the MUTF-8 nul encoding `C0 80` (only when `ENCODE_NUL`);
/// - `Ok(6)` for a valid 6-byte surrogate pair;
/// - `Err(None)` when `b` ends before a possible sequence is complete;
/// - `Err(Some(()))` when the bytes are invalid.
#[inline(never)]
fn valid_cesu8_char<const ENCODE_NUL: bool>(b: &[u8]) -> Result<usize, Option<()>> {
    const NEED_MORE: Result<usize, Option<()>> = Err(None);
    const INVALID: Result<usize, Option<()>> = Err(Some(()));
    match *b {
        [0xC0, 0x80, ..] if ENCODE_NUL => Ok(2),
        [0xC0] if ENCODE_NUL => NEED_MORE,
        [0xED, second, third, 0xED, fifth, sixth, ..] => {
            match dec_surrogates::<true>(second, third, fifth, sixth) {
                Ok(_) => Ok(6),
                Err(_) => INVALID,
            }
        }
        // A truncated prefix of a surrogate pair at the very end of the input.
        [0xED] | [0xED, _] | [0xED, _, _] | [0xED, _, _, 0xED] | [0xED, _, _, 0xED, _] => NEED_MORE,
        _ => INVALID,
    }
}
pub(crate) fn validate_cesu8<const CHUNK_SIZE: usize, const ENCODE_NUL: bool>(source: &[u8]) -> Result<(), EncodingError> {
use std::slice::SliceIndex;
fn subslice<I: SliceIndex<[u8]>>(buf: &[u8], r: I) -> &<I as SliceIndex<[u8]>>::Output {
if cfg!(debug_assertions) {
&buf[r]
} else {
unsafe { buf.get_unchecked(r) }
}
}
let mut base = 0;
loop {
let (valid_utf8, utf8_err) = match std::str::from_utf8(subslice(source, base..)) {
Ok(s) => (s.as_bytes(), Ok(())),
Err(e) => (
subslice(source, base..base+e.valid_up_to()),
Err(e.error_len())
),
};
let cesu8_err = check_utf8_to_cesu8::<CHUNK_SIZE, ENCODE_NUL>(valid_utf8);
match (cesu8_err, utf8_err) {
(Some(i), _) => return Err(EncodingError {
valid_up_to: base + i,
error_len: Some(1.try_into().unwrap()),
}),
(None, Ok(())) => return Ok(()),
(None, Err(None)) => return Err(EncodingError {
valid_up_to: base + valid_utf8.len(),
error_len: None,
}),
(None, Err(Some(bad_utf8_len))) => {
base += valid_utf8.len();
match valid_cesu8_char::<true>(subslice(source, base..)) {
Ok(len) => {
base += len;
},
Err(None) => return Err(EncodingError {
valid_up_to: base,
error_len: None,
}),
Err(Some(())) => return Err(EncodingError {
valid_up_to: base,
error_len: Some((bad_utf8_len as u8).try_into().unwrap()),
}),
}
}
}
}
}
/// Error from `dec_surrogates` when the bytes do not form a valid CESU-8 surrogate pair.
///
/// Holds the four variable bytes `[second, third, fifth, sixth]`. The two `0xED` lead
/// bytes are implied and are re-inserted when the error is displayed.
#[derive(Debug)]
pub(crate) struct InvalidCesu8SurrogatePair([u8; 4]);
impl Error for InvalidCesu8SurrogatePair {}
impl fmt::Display for InvalidCesu8SurrogatePair {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let [second, third, fifth, sixth] = self.0;
        let full_sequence: [u8; 6] = [0xED, second, third, 0xED, fifth, sixth];
        write!(f, "attempt to decode invalid cesu-8 6-byte surrogate pair: {:X?}", full_sequence)
    }
}
/// Decodes a surrogate pair that the caller already knows to be valid.
///
/// Debug builds still validate the bytes and panic on malformed input. Release builds
/// skip validation; with `CHECK_INVALID = false`, `dec_surrogates` never returns `Err`.
#[inline]
pub(crate) fn dec_surrogates_infallable(second: u8, third: u8, fifth: u8, sixth: u8) -> [u8; 4] {
    if !cfg!(debug_assertions) {
        return dec_surrogates::<false>(second, third, fifth, sixth).unwrap();
    }
    dec_surrogates::<true>(second, third, fifth, sixth)
        .expect("failed cesu surrogate pair decode when expected to be infallible")
}
/// Decodes the CESU-8 surrogate pair `ED second third ED fifth sixth` into the 4-byte
/// UTF-8 encoding of the character it represents.
///
/// The two `0xED` lead bytes are not passed in and are assumed to have been checked by
/// the caller.
///
/// With `CHECK_INVALID`:
/// - `second`/`third` must encode a high surrogate (`second` in `A0..=AF`);
/// - `fifth`/`sixth` must encode a low surrogate (`fifth` in `B0..=BF`);
/// - otherwise [`InvalidCesu8SurrogatePair`] is returned.
///
/// Without `CHECK_INVALID`, nothing is validated. Malformed bytes give a meaningless
/// result, and the surrogate arithmetic may overflow.
#[inline]
pub(crate) fn dec_surrogates<const CHECK_INVALID: bool>(second: u8, third: u8, fifth: u8, sixth: u8) -> Result<[u8; 4], InvalidCesu8SurrogatePair> {
    /// Rebuilds a 16-bit surrogate code unit from the two continuation bytes of its
    /// 3-byte `0xED ..` encoding.
    fn dec_surrogate(hi: u8, lo: u8) -> u32 {
        let hi_bits = u32::from(hi & CONT_MASK);
        let lo_bits = u32::from(lo & CONT_MASK);
        0xD000 | (hi_bits << 6) | lo_bits
    }

    if CHECK_INVALID {
        let is_cont = |byte: u8| (byte & !CONT_MASK) == TAG_CONT_U8;
        let well_formed = is_cont(second)
            && (second & 0b1111_0000) == 0b1010_0000
            && is_cont(third)
            && is_cont(fifth)
            && (fifth & 0b1111_0000) == 0b1011_0000
            && is_cont(sixth);
        if !well_formed {
            return Err(InvalidCesu8SurrogatePair([second, third, fifth, sixth]));
        }
    }

    let high = dec_surrogate(second, third);
    let low = dec_surrogate(fifth, sixth);
    let code_point = 0x10000 + (((high - 0xD800) << 10) | (low - 0xDC00));
    if CHECK_INVALID && !(0x01_0000..=0x10_FFFF).contains(&code_point) {
        return Err(InvalidCesu8SurrogatePair([second, third, fifth, sixth]));
    }

    // Standard 4-byte UTF-8 layout: 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx.
    Ok([
        0b1111_0000 | ((code_point >> 18) & 0x07) as u8,
        TAG_CONT_U8 | ((code_point >> 12) & 0x3F) as u8,
        TAG_CONT_U8 | ((code_point >> 6) & 0x3F) as u8,
        TAG_CONT_U8 | (code_point & 0x3F) as u8,
    ])
}
/// Encodes a supplementary code point as a CESU-8 6-byte surrogate pair.
///
/// Expects `ch` in `0x10000..=0x10FFFF`. Smaller values underflow `ch - 0x10000`.
#[inline]
pub(crate) fn enc_surrogates(ch: u32) -> [u8; 6] {
    /// Encodes one surrogate code unit as a 3-byte `1110xxxx 10xxxxxx 10xxxxxx` sequence.
    #[inline]
    fn enc_surrogate(unit: u16) -> [u8; 3] {
        if cfg!(debug_assertions) || cfg!(validate_release) {
            assert!(
                (0xD800..=0xDFFF).contains(&unit),
                "trying to encode invalid surrogate pair"
            );
        }
        let lead = 0b1110_0000 | (unit >> 12) as u8;
        let middle = TAG_CONT_U8 | ((unit >> 6) & 0x3F) as u8;
        let last = TAG_CONT_U8 | (unit & 0x3F) as u8;
        [lead, middle, last]
    }

    let offset = ch - 0x1_0000;
    let [h0, h1, h2] = enc_surrogate(0xD800 | (offset >> 10) as u16);
    let [l0, l1, l2] = enc_surrogate(0xDC00 | (offset & 0x3FF) as u16);
    [h0, h1, h2, l0, l1, l2]
}