use core::fmt::{self, Write as _};
use core::marker::PhantomData;
use core::num::NonZeroU8;
use core::ops::ControlFlow;
use crate::parser::str::find_split;
use crate::parser::trusted::hexdigits_to_byte;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum PctEncodedFragments<'a> {
NoPctStr(&'a str),
StrayPercent,
Char(&'a str, char),
InvalidUtf8PctTriplets(&'a str),
}
pub(crate) fn process_percent_encoded_best_effort<T, F, B>(
v: T,
mut f: F,
) -> Result<ControlFlow<B>, fmt::Error>
where
T: fmt::Display,
F: FnMut(PctEncodedFragments<'_>) -> ControlFlow<B>,
{
let mut buf = [0_u8; 12];
let mut writer = DecomposeWriter {
f: &mut f,
decoder: Default::default(),
buf: &mut buf,
result: ControlFlow::Continue(()),
_r: PhantomData,
};
if write!(writer, "{v}").is_err() {
match writer.result {
ControlFlow::Continue(_) => return Err(fmt::Error),
ControlFlow::Break(v) => return Ok(ControlFlow::Break(v)),
}
}
if let Some(len) = writer.decoder.flush(&mut buf).map(|v| usize::from(v.get())) {
let len_suffix = len % 3;
let triplets_end = len - len_suffix;
let triplets = core::str::from_utf8(&buf[..triplets_end])
.expect("percent-encoded triplets consist of ASCII characters");
if let ControlFlow::Break(v) = f(PctEncodedFragments::InvalidUtf8PctTriplets(triplets)) {
return Ok(ControlFlow::Break(v));
}
if len_suffix > 0 {
if let ControlFlow::Break(v) = f(PctEncodedFragments::StrayPercent) {
return Ok(ControlFlow::Break(v));
}
}
if len_suffix > 1 {
let after_percent =
core::str::from_utf8(&buf[(triplets_end + 1)..(triplets_end + len_suffix)])
.expect("percent-encoded triplets contains only ASCII characters");
if let ControlFlow::Break(v) = f(PctEncodedFragments::NoPctStr(after_percent)) {
return Ok(ControlFlow::Break(v));
}
}
}
Ok(ControlFlow::Continue(()))
}
struct DecomposeWriter<'a, F, B> {
f: &'a mut F,
decoder: DecoderBuffer,
buf: &'a mut [u8],
result: ControlFlow<B>,
_r: PhantomData<fn() -> B>,
}
impl<F, B> DecomposeWriter<'_, F, B>
where
F: FnMut(PctEncodedFragments<'_>) -> ControlFlow<B>,
{
#[inline(always)]
fn result_continue_or_err(&self) -> fmt::Result {
if self.result.is_break() {
return Err(fmt::Error);
}
Ok(())
}
fn output_as_undecodable(&mut self, len_undecodable: u8) -> fmt::Result {
let len_written = usize::from(len_undecodable);
let frag = core::str::from_utf8(&self.buf[..len_written])
.expect("`DecoderBuffer` writes a valid ASCII string");
let len_incomplete = len_written % 3;
let len_complete = len_written - len_incomplete;
self.result = (self.f)(PctEncodedFragments::InvalidUtf8PctTriplets(
&frag[..len_complete],
));
self.result_continue_or_err()?;
if len_incomplete > 0 {
self.result = (self.f)(PctEncodedFragments::StrayPercent);
if self.result.is_break() {
return Err(fmt::Error);
}
if len_incomplete > 1 {
debug_assert_eq!(
len_incomplete, 2,
"the length of incomplete percent-encoded triplet must be less than 2 bytes"
);
self.result = (self.f)(PctEncodedFragments::NoPctStr(
&frag[(len_complete + 1)..len_written],
));
self.result_continue_or_err()?;
}
}
Ok(())
}
}
impl<F, B> fmt::Write for DecomposeWriter<'_, F, B>
where
F: FnMut(PctEncodedFragments<'_>) -> ControlFlow<B>,
{
fn write_str(&mut self, s: &str) -> fmt::Result {
self.result_continue_or_err()?;
let mut rest = s;
while !rest.is_empty() {
let (len_consumed, result) = self.decoder.push_encoded(self.buf, rest);
if len_consumed == 0 {
if let Some(len_written) = self.decoder.flush(self.buf).map(NonZeroU8::get) {
self.output_as_undecodable(len_written)?;
rest = &rest[usize::from(len_written)..];
}
let (plain_prefix, suffix) = find_split(rest, b'%').unwrap_or((rest, ""));
debug_assert!(
!plain_prefix.is_empty(),
"`len_consumed == 0` indicates non-empty `rest` not starting with `%`"
);
self.result = (self.f)(PctEncodedFragments::NoPctStr(plain_prefix));
self.result_continue_or_err()?;
rest = suffix;
continue;
}
match result {
PushResult::Decoded(len_written, c) => {
let len_written = usize::from(len_written.get());
let frag = core::str::from_utf8(&self.buf[..len_written])
.expect("`DecoderBuffer` writes a valid ASCII string");
self.result = (self.f)(PctEncodedFragments::Char(frag, c));
self.result_continue_or_err()?;
}
PushResult::Undecodable(len_written) => {
self.output_as_undecodable(len_written)?;
}
PushResult::NeedMoreBytes => {
}
}
rest = &rest[len_consumed..];
}
Ok(())
}
}
#[derive(Debug, Clone, Copy)]
enum PushResult {
NeedMoreBytes,
Decoded(NonZeroU8, char),
Undecodable(u8),
}
#[derive(Default, Debug, Clone, Copy)]
struct DecoderBuffer {
encoded: [u8; 12],
decoded: [u8; 4],
len_encoded: u8,
}
impl DecoderBuffer {
fn write_and_pop(&mut self, dest: &mut [u8], remove_len: u8) {
let new_len = self.len_encoded - remove_len;
let remove_len = usize::from(remove_len);
let src_range = remove_len..usize::from(self.len_encoded);
dest[..remove_len].copy_from_slice(&self.encoded[..remove_len]);
if new_len == 0 {
*self = Self::default();
return;
}
self.encoded.copy_within(src_range, 0);
self.decoded
.copy_within((remove_len / 3)..usize::from(self.len_encoded / 3), 0);
self.len_encoded = new_len;
}
fn push_single_encoded_byte(&mut self, byte: u8) {
debug_assert!(
self.len_encoded < 12,
"four percent-encoded triplets are enough for a unicode code point"
);
let pos_enc = usize::from(self.len_encoded);
self.len_encoded += 1;
self.encoded[pos_enc] = byte;
if self.len_encoded % 3 == 0 {
let pos_dec = usize::from(self.len_encoded / 3 - 1);
let upper = self.encoded[pos_enc - 1];
let lower = byte;
debug_assert!(
upper.is_ascii_hexdigit() && lower.is_ascii_hexdigit(),
"the `encoded` buffer should contain valid percent-encoded triplets"
);
self.decoded[pos_dec] = hexdigits_to_byte([upper, lower]);
}
}
#[must_use]
pub(crate) fn push_encoded(&mut self, buf: &mut [u8], s: &str) -> (usize, PushResult) {
debug_assert!(
buf.len() >= 12,
"[precondition] destination buffer should be at least 12 bytes"
);
let mut chars = s.chars();
let mut len_triplet_incomplete = self.len_encoded % 3;
for c in &mut chars {
if len_triplet_incomplete == 0 {
if c != '%' {
let len_consumed = s.len() - chars.as_str().len() - 1;
let len_result = self.len_encoded;
self.write_and_pop(buf, len_result);
return (len_consumed, PushResult::Undecodable(len_result));
}
self.push_single_encoded_byte(b'%');
len_triplet_incomplete = 1;
continue;
}
if !c.is_ascii_hexdigit() {
let len_consumed = s.len() - chars.as_str().len() - 1;
let len_result = self.len_encoded;
self.write_and_pop(buf, len_result);
return (len_consumed, PushResult::Undecodable(len_result));
}
self.push_single_encoded_byte(c as u8);
if len_triplet_incomplete == 1 {
len_triplet_incomplete = 2;
continue;
} else {
debug_assert_eq!(len_triplet_incomplete, 2);
len_triplet_incomplete = 0;
}
let len_decoded = usize::from(self.len_encoded) / 3;
match core::str::from_utf8(&self.decoded[..len_decoded]) {
Ok(decoded_str) => {
let len_consumed = s.len() - chars.as_str().len();
let c = decoded_str
.chars()
.next()
.expect("`decoded` buffer is nonempty");
let len_result = NonZeroU8::new(self.len_encoded).expect(
"`encoded` buffer is nonempty since \
`push_single_encoded_byte()` was called",
);
self.write_and_pop(buf, len_result.get());
return (len_consumed, PushResult::Decoded(len_result, c));
}
Err(e) => {
assert_eq!(
e.valid_up_to(),
0,
"`decoded` buffer contains at most one character"
);
let skip_len_decoded = match e.error_len() {
None => continue,
Some(v) => v,
};
let len_consumed = s.len() - chars.as_str().len();
let len_result = skip_len_decoded as u8 * 3;
assert_ne!(skip_len_decoded, 0, "empty bytes cannot be invalid");
self.write_and_pop(buf, len_result);
return (len_consumed, PushResult::Undecodable(len_result));
}
};
}
let len_consumed = s.len() - chars.as_str().len();
(len_consumed, PushResult::NeedMoreBytes)
}
#[must_use]
pub(crate) fn flush(&mut self, buf: &mut [u8]) -> Option<NonZeroU8> {
let len_result = NonZeroU8::new(self.len_encoded)?;
self.write_and_pop(buf, len_result.get());
debug_assert_eq!(
self.len_encoded, 0,
"the buffer should be cleared after flushed"
);
Some(len_result)
}
}