use crate::cea_match::remove_pulled;
use crate::consts::{DECOMP, FCD};
use unicode_canonical_combining_class::get_canonical_combining_class_u32 as get_ccc;
/// Code point source that reads `u32` values out of a mutably borrowed vector
/// through a cursor.
pub struct VecSource<'a> {
/// Borrowed buffer of code points. It is mutable because
/// `remove_pulled_lookahead` hands it to `remove_pulled`.
chars: &'a mut Vec<u32>,
/// Cursor position at construction time; `consumed()` reports `pos - start`.
#[cfg(feature = "pipeline-stats")]
start: usize,
/// Index in `chars` of the next unread code point.
pos: usize,
/// Logical length of the readable data. Initialised from `chars.len()` and
/// afterwards passed by `&mut` to `remove_pulled`, which may adjust it.
len: usize,
}
/// Cursor-style access to a sequence of code points with lookahead.
///
/// Implemented in this file by `VecSource` (an in-memory `u32` buffer) and
/// `Utf8Source` (UTF-8 bytes decoded on demand). Several methods take
/// `&mut self` because `Utf8Source` has to decode ahead to answer them.
pub trait CodePointSource {
/// Returns `true` when there is nothing left to read.
fn is_empty(&mut self) -> bool;
/// Returns how many code points are available from the cursor onwards.
///
/// `VecSource` reports the exact count; `Utf8Source` only looks as far as its
/// four-entry lookahead window.
fn remaining(&mut self) -> usize;
/// Returns `true` when the source has stopped producing code points.
///
/// `VecSource` never blocks. `Utf8Source` blocks on undecodable bytes, on a
/// code point rejected by its normalization check, or after
/// `remove_pulled_lookahead` has been called.
fn is_blocked(&self) -> bool;
/// Returns how much input has been consumed so far.
///
/// NOTE(review): the unit differs per implementation — `VecSource` counts
/// code points (`pos - start`), `Utf8Source` counts bytes (`byte_pos`).
#[cfg(feature = "pipeline-stats")]
fn consumed(&self) -> usize;
/// Returns the code point `offset` positions ahead of the cursor without
/// consuming it, or `None` when no code point is available there.
fn peek(&mut self, offset: usize) -> Option<u32>;
/// Advances the cursor past the next `count` code points.
fn consume(&mut self, count: usize);
/// Tells the source that the code point(s) at `offset` ahead of the cursor
/// have been pulled out of sequence and must not be yielded again.
///
/// NOTE(review): `pulled_two` presumably selects removing two entries instead
/// of one — confirm against `cea_match::remove_pulled`.
fn remove_pulled_lookahead(&mut self, offset: usize, pulled_two: bool);
}
impl<'a> VecSource<'a> {
    /// Creates a source over `chars` whose cursor starts at index `pos`.
    ///
    /// The readable length is captured from `chars.len()` at construction
    /// time. With the `pipeline-stats` feature enabled, `pos` is also kept as
    /// the baseline for `consumed()`.
    pub const fn new(chars: &'a mut Vec<u32>, pos: usize) -> Self {
        // Read the length before the reference is moved into the struct.
        let len = chars.len();
        Self {
            chars,
            #[cfg(feature = "pipeline-stats")]
            start: pos,
            pos,
            len,
        }
    }
}
impl CodePointSource for VecSource<'_> {
    /// Returns `true` once the cursor has reached or passed the logical end
    /// of the buffer.
    fn is_empty(&mut self) -> bool {
        self.pos >= self.len
    }

    /// Returns the number of unread code points.
    ///
    /// Saturates at zero: `consume` does not bound-check, and `is_empty`
    /// already tolerates `pos > len`, so a plain subtraction here could
    /// panic in debug builds or wrap to a huge count in release builds.
    fn remaining(&mut self) -> usize {
        self.len.saturating_sub(self.pos)
    }

    /// An in-memory buffer never blocks.
    fn is_blocked(&self) -> bool {
        false
    }

    /// Returns the number of code points consumed since construction.
    #[cfg(feature = "pipeline-stats")]
    fn consumed(&self) -> usize {
        self.pos - self.start
    }

    /// Returns the code point `offset` positions ahead of the cursor, or
    /// `None` when that position lies at or beyond the logical length.
    fn peek(&mut self, offset: usize) -> Option<u32> {
        let index = self.pos + offset;
        // Compare against `len`, not `chars.len()`: the logical length is
        // tracked separately and is what `remove_pulled` updates.
        (index < self.len).then(|| self.chars[index])
    }

    /// Advances the cursor by `count` code points without bound-checking.
    fn consume(&mut self, count: usize) {
        self.pos += count;
    }

    /// Forwards to `remove_pulled` with the absolute index `pos + offset`
    /// and a mutable reference to the logical length.
    fn remove_pulled_lookahead(&mut self, offset: usize, pulled_two: bool) {
        let index = self.pos + offset;
        remove_pulled(self.chars, index, &mut self.len, pulled_two);
    }
}
/// Code point source that decodes UTF-8 bytes on demand into a small
/// fixed-size lookahead window.
pub struct Utf8Source<'a> {
/// Input bytes; decoded lazily and never modified.
bytes: &'a [u8],
/// Byte offset of the first code point that has not been consumed yet,
/// i.e. where `lookahead[0]` starts.
byte_pos: usize,
/// Decoded but not yet consumed code points; only the first
/// `lookahead_len` entries are valid.
lookahead: [u32; 4],
/// Encoded length in bytes of the corresponding `lookahead` entry.
lookahead_bytes: [usize; 4],
/// Number of valid entries in `lookahead` / `lookahead_bytes`.
lookahead_len: usize,
/// Set once decoding has stopped (malformed bytes, a rejected code point,
/// or a call to `remove_pulled_lookahead`); `fill` decodes nothing further
/// afterwards.
blocked: bool,
/// Trailing combining class recorded for the most recently accepted code
/// point; compared against the next code point's leading class.
prev_trail_cc: u8,
}
impl<'a> Utf8Source<'a> {
    /// Creates a source positioned at the start of `bytes` with an empty
    /// lookahead window.
    pub const fn new(bytes: &'a [u8]) -> Self {
        Self {
            bytes,
            byte_pos: 0,
            lookahead: [0; 4],
            lookahead_bytes: [0; 4],
            lookahead_len: 0,
            blocked: false,
            prev_trail_cc: 0,
        }
    }

    /// Decodes ahead until the lookahead window contains index `offset`.
    ///
    /// Stops early when the input is exhausted, when the source is blocked,
    /// or when the window is full. The capacity check matters: without it a
    /// request for `offset >= lookahead.len()` on long enough input would
    /// index past the end of the fixed-size window and panic. With it, such
    /// a request leaves the window full and the caller's `peek` returns
    /// `None`.
    ///
    /// Marks the source blocked, leaving the offending bytes undecoded, when
    /// the next bytes are not well-formed UTF-8 or the decoded code point is
    /// rejected by `accepts_normalization_boundary`.
    fn fill(&mut self, offset: usize) {
        while self.lookahead_len <= offset
            && self.lookahead_len < self.lookahead.len()
            && !self.blocked
        {
            let start = self.decoded_byte_end();
            if start >= self.bytes.len() {
                // Clean end of input: not a blocking condition.
                return;
            }
            let Some((code_point, len)) = decode_utf8(&self.bytes[start..]) else {
                self.blocked = true;
                return;
            };
            if !self.accepts_normalization_boundary(code_point) {
                self.blocked = true;
                return;
            }
            self.lookahead[self.lookahead_len] = code_point;
            self.lookahead_bytes[self.lookahead_len] = len;
            self.lookahead_len += 1;
        }
    }

    /// Returns the byte offset just past the last decoded code point:
    /// `byte_pos` plus the encoded lengths of everything in the window.
    fn decoded_byte_end(&self) -> usize {
        self.byte_pos
            + self.lookahead_bytes[..self.lookahead_len]
                .iter()
                .sum::<usize>()
    }

    /// Decides whether `code_point` may be appended to the lookahead window,
    /// updating `prev_trail_cc` when it is accepted.
    ///
    /// Returns `false` (the caller then blocks the source) for:
    /// * U+0F81 and the Hangul syllable range U+AC00..=U+D7A3,
    /// * any code point with an entry in the `DECOMP` table,
    /// * a code point whose leading combining class is non-zero and lower
    ///   than the trailing class of the previously accepted code point.
    ///
    /// NOTE(review): these look like the cases where the input would need
    /// normalization before use — confirm against the consumer of this source.
    fn accepts_normalization_boundary(&mut self, code_point: u32) -> bool {
        // Everything below U+00C0 is accepted outright and resets the
        // combining-class state.
        if code_point < 0xC0 {
            self.prev_trail_cc = 0;
            return true;
        }
        if code_point == 0x0F81 || (0xAC00..=0xD7A3).contains(&code_point) {
            return false;
        }
        if DECOMP.get(code_point).is_some() {
            return false;
        }
        // The FCD table packs (lead, trail) into the big-endian bytes of its
        // value; code points without an entry use their canonical combining
        // class for both.
        let (lead_cc, trail_cc) = FCD.get(code_point).map_or_else(
            || {
                let cc = get_ccc(code_point) as u8;
                (cc, cc)
            },
            |vals| vals.to_be_bytes().into(),
        );
        if lead_cc != 0 && lead_cc < self.prev_trail_cc {
            return false;
        }
        self.prev_trail_cc = trail_cc;
        true
    }
}
impl CodePointSource for Utf8Source<'_> {
    /// Returns `true` when nothing is buffered and every input byte has been
    /// consumed. Tries to decode one code point first.
    fn is_empty(&mut self) -> bool {
        self.fill(0);
        let nothing_buffered = self.lookahead_len == 0;
        let input_consumed = self.byte_pos >= self.bytes.len();
        nothing_buffered && input_consumed
    }

    /// Returns how many code points are known to be available, looking at
    /// most four ahead.
    ///
    /// When decoding has blocked or has reached the end of the input, the
    /// buffered count is exact; otherwise at least four is reported.
    fn remaining(&mut self) -> usize {
        self.fill(3);
        let buffered = self.lookahead_len;
        let fully_decoded = self.decoded_byte_end() == self.bytes.len();
        if self.blocked || fully_decoded {
            buffered
        } else {
            buffered.max(4)
        }
    }

    /// Returns whether decoding has been stopped.
    fn is_blocked(&self) -> bool {
        self.blocked
    }

    /// Returns the number of input bytes consumed so far.
    #[cfg(feature = "pipeline-stats")]
    fn consumed(&self) -> usize {
        self.byte_pos
    }

    /// Returns the code point `offset` positions ahead after decoding as far
    /// as needed, or `None` when the window does not reach that far.
    fn peek(&mut self, offset: usize) -> Option<u32> {
        self.fill(offset);
        if offset < self.lookahead_len {
            Some(self.lookahead[offset])
        } else {
            None
        }
    }

    /// Drops the first `count` buffered code points and advances the byte
    /// position by their combined encoded length.
    ///
    /// `count` must not exceed the number of buffered entries; the slice
    /// operations below panic otherwise.
    fn consume(&mut self, count: usize) {
        let buffered = self.lookahead_len;
        let advance: usize = self.lookahead_bytes[..count].iter().sum();
        self.byte_pos += advance;
        // Shift the surviving entries to the front of both parallel arrays.
        self.lookahead.copy_within(count..buffered, 0);
        self.lookahead_bytes.copy_within(count..buffered, 0);
        self.lookahead_len = buffered - count;
    }

    /// Marks the source as blocked instead of removing anything; both
    /// arguments are ignored.
    ///
    /// NOTE(review): presumably because the borrowed byte slice cannot be
    /// edited in place — confirm with the caller's fallback path.
    fn remove_pulled_lookahead(&mut self, _offset: usize, _pulled_two: bool) {
        self.blocked = true;
    }
}
/// Decodes the first UTF-8 encoded scalar value in `bytes`.
///
/// Returns the code point and the number of bytes it occupies, or `None`
/// when `bytes` is empty, starts with an invalid lead byte, is truncated,
/// contains a bad continuation byte, or encodes a value that
/// `valid_utf8_scalar` rejects for that sequence length.
fn decode_utf8(bytes: &[u8]) -> Option<(u32, usize)> {
    let (&lead, rest) = bytes.split_first()?;

    // Classify the lead byte: payload mask and total sequence length.
    let (payload_mask, len): (u8, usize) = match lead {
        0x00..=0x7F => return Some((u32::from(lead), 1)),
        0xC0..=0xDF => (0x1F, 2),
        0xE0..=0xEF => (0x0F, 3),
        0xF0..=0xF7 => (0x07, 4),
        // Stray continuation bytes (0x80..=0xBF) and 0xF8..=0xFF.
        _ => return None,
    };

    // `None` here means the sequence is cut short by the end of the input.
    let continuation = rest.get(..len - 1)?;

    let mut code_point = u32::from(lead & payload_mask);
    for &byte in continuation {
        if byte & 0xC0 != 0x80 {
            return None;
        }
        code_point = (code_point << 6) | u32::from(byte & 0x3F);
    }

    if valid_utf8_scalar(code_point, len) {
        Some((code_point, len))
    } else {
        None
    }
}
/// Checks that `code_point` is a legal result for a UTF-8 sequence of `len`
/// bytes.
///
/// Only multi-byte lengths (2, 3 and 4) are handled; any other `len` yields
/// `false`. Values below the minimum for their length (overlong encodings),
/// surrogates U+D800..=U+DFFF, and values above U+10FFFF are rejected.
const fn valid_utf8_scalar(code_point: u32, len: usize) -> bool {
    match (len, code_point) {
        (2, 0x80..=0x7FF) => true,
        // Three-byte range with the surrogate block cut out.
        (3, 0x800..=0xD7FF | 0xE000..=0xFFFF) => true,
        (4, 0x1_0000..=0x10_FFFF) => true,
        _ => false,
    }
}