use crate::normalize::{Transform, apply_chain};
use wafrift_grammar::grammar::{bestfit, nfkc_preimage};
/// A single byte-level transformation step.
///
/// Each variant names one decoding / normalisation operation; the actual
/// work is performed by `Stage::apply`, which maps an input byte slice to a
/// freshly allocated output buffer.
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
pub enum Stage {
/// Returns a copy of the input unchanged.
Identity,
/// One pass of percent-decoding via `url_decode_once`.
UrlDecode {
/// When `true`, a literal `+` is decoded to a space.
plus_is_space: bool,
},
/// Two consecutive percent-decoding passes, both with `plus_is_space`
/// set to `false`.
DoubleUrlDecode,
/// Delegates to `Transform::HtmlEntityDecode`.
HtmlEntityDecode,
/// Resolves JSON backslash escapes via `json_unescape`.
JsonUnescape,
/// Runs `nfkc_preimage::normalize` over the input when it is valid UTF-8;
/// input that is not valid UTF-8 is passed through unchanged.
NfkcNormalize,
/// Runs `bestfit::normalize` over the input when it is valid UTF-8;
/// input that is not valid UTF-8 is passed through unchanged.
BestFitDownconvert,
/// Removes every `0x00` byte.
StripNulls,
/// Collapses two-byte sequences with a `0xC0`/`0xC1` lead byte and a
/// continuation byte into the single byte they encode.
OverlongUtf8Decode,
/// Decodes with the base64 `STANDARD` engine; on a decode error the
/// input is passed through unchanged.
Base64Decode,
/// Decodes with `hex::decode`; on a decode error the input is passed
/// through unchanged.
HexDecode,
/// Applies the contained `Transform` chain through `apply_chain`.
CrsView(Vec<Transform>),
}
impl Stage {
    /// Runs this stage over `input` and returns the transformed bytes.
    ///
    /// The input is never modified; every variant produces a new buffer.
    /// The fallible decoders (`Base64Decode`, `HexDecode`) hand back a copy
    /// of the input when decoding fails instead of reporting an error.
    #[must_use]
    pub fn apply(&self, input: &[u8]) -> Vec<u8> {
        use base64::Engine as _;

        match self {
            Self::Identity => input.to_vec(),
            Self::UrlDecode { plus_is_space } => url_decode_once(input, *plus_is_space),
            Self::DoubleUrlDecode => {
                // Neither pass treats '+' as a space.
                let first_pass = url_decode_once(input, false);
                url_decode_once(&first_pass, false)
            }
            Self::HtmlEntityDecode => Transform::HtmlEntityDecode.apply(input),
            Self::JsonUnescape => json_unescape(input),
            Self::NfkcNormalize => normalize_text_stage(input, nfkc_preimage::normalize),
            Self::BestFitDownconvert => normalize_text_stage(input, bestfit::normalize),
            Self::StripNulls => input.iter().filter(|byte| **byte != 0).copied().collect(),
            Self::OverlongUtf8Decode => overlong_utf8_decode(input),
            Self::Base64Decode => {
                match base64::engine::general_purpose::STANDARD.decode(input) {
                    Ok(decoded) => decoded,
                    Err(_) => input.to_vec(),
                }
            }
            Self::HexDecode => match hex::decode(input) {
                Ok(decoded) => decoded,
                Err(_) => input.to_vec(),
            },
            Self::CrsView(chain) => apply_chain(chain, input),
        }
    }
}
/// Collapses two-byte overlong UTF-8 forms into the single byte they encode.
///
/// A lead byte of `0xC0` or `0xC1` followed by a continuation byte
/// (`0x80..=0xBF`) is replaced by `((lead & 0x1F) << 6) | (cont & 0x3F)`.
/// Every other byte, including a `0xC0`/`0xC1` that is not followed by a
/// continuation byte, is copied through untouched.
fn overlong_utf8_decode(input: &[u8]) -> Vec<u8> {
    let mut decoded = Vec::with_capacity(input.len());
    let mut rest = input;
    while let Some((&lead, tail)) = rest.split_first() {
        match (lead, tail.split_first()) {
            (0xC0 | 0xC1, Some((&cont, after))) if (0x80..=0xBF).contains(&cont) => {
                decoded.push(((lead & 0x1F) << 6) | (cont & 0x3F));
                rest = after;
            }
            _ => {
                decoded.push(lead);
                rest = tail;
            }
        }
    }
    decoded
}
/// Applies the text normaliser `f` to `input` when it is valid UTF-8.
///
/// Returns the bytes of `f`'s result on success. Input that is not valid
/// UTF-8 is returned as an unchanged copy and `f` is not called.
fn normalize_text_stage(input: &[u8], f: impl Fn(&str) -> String) -> Vec<u8> {
    std::str::from_utf8(input).map_or_else(|_| input.to_vec(), |text| f(text).into_bytes())
}
/// An ordered sequence of [`Stage`]s.
///
/// The stages are stored in the public tuple field in the order in which
/// they are applied.
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
pub struct Pipeline(pub Vec<Stage>);
impl Pipeline {
#[must_use]
pub fn apply(&self, input: &[u8]) -> Vec<u8> {
self.0.iter().fold(input.to_vec(), |acc, s| s.apply(&acc))
}
#[must_use]
pub fn len(&self) -> usize {
self.0.len()
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.0.is_empty()
}
}
/// Returns the numeric value (0..=15) of an ASCII hexadecimal digit.
///
/// Accepts `0-9`, `a-f` and `A-F`; any other byte yields `None`.
fn hexv(b: u8) -> Option<u8> {
    // `to_digit(16)` only ever yields 0..=15, so the narrowing always succeeds.
    char::from(b)
        .to_digit(16)
        .and_then(|digit| u8::try_from(digit).ok())
}
/// Performs a single pass of percent-decoding over `input`.
///
/// * `%XY`, where `X` and `Y` are both hexadecimal digits, becomes the byte
///   `0xXY`.
/// * A `%` that is not followed by two hexadecimal digits (including one too
///   close to the end of the input) is copied through literally.
/// * `+` becomes a space only when `plus_is_space` is `true`.
/// * Every other byte is copied through unchanged.
///
/// Decoded bytes are not re-scanned, so `%2541` yields `%41`, not `A`.
#[must_use]
pub fn url_decode_once(input: &[u8], plus_is_space: bool) -> Vec<u8> {
    let mut decoded = Vec::with_capacity(input.len());
    let mut pos = 0;
    while let Some(&byte) = input.get(pos) {
        // A complete escape needs exactly two hex digits after the '%'.
        let escape = if byte == b'%' {
            match input.get(pos + 1..pos + 3) {
                Some(&[hi, lo]) => hexv(hi).zip(hexv(lo)),
                _ => None,
            }
        } else {
            None
        };

        if let Some((hi, lo)) = escape {
            decoded.push((hi << 4) | lo);
            pos += 3;
        } else if byte == b'+' && plus_is_space {
            decoded.push(b' ');
            pos += 1;
        } else {
            decoded.push(byte);
            pos += 1;
        }
    }
    decoded
}
/// Appends the UTF-8 encoding of code point `cp` to `out`.
///
/// Values that are not Unicode scalar values — the surrogate range
/// `0xD800..=0xDFFF` and anything above `0x10FFFF` — cannot be represented
/// in well-formed UTF-8, so U+FFFD REPLACEMENT CHARACTER is appended in
/// their place. All valid scalar values are encoded in their shortest form.
fn push_utf8(out: &mut Vec<u8>, cp: u32) {
    // `char::from_u32` rejects exactly the values that have no valid UTF-8
    // encoding, which keeps the output well-formed for any `cp`.
    let ch = char::from_u32(cp).unwrap_or(char::REPLACEMENT_CHARACTER);
    let mut buf = [0u8; 4];
    out.extend_from_slice(ch.encode_utf8(&mut buf).as_bytes());
}
/// Parses the first four bytes of `b` as a big-endian hexadecimal number.
///
/// Returns `None` when `b` holds fewer than four bytes or when any of the
/// first four is not a hexadecimal digit. Bytes after the fourth are ignored.
fn read_u4(b: &[u8]) -> Option<u32> {
    let digits = b.get(..4)?;
    digits
        .iter()
        .try_fold(0u32, |acc, &c| hexv(c).map(|nibble| (acc << 4) | u32::from(nibble)))
}
#[must_use]
pub fn json_unescape(input: &[u8]) -> Vec<u8> {
let mut out = Vec::with_capacity(input.len());
let mut i = 0;
while i < input.len() {
if input[i] != b'\\' || i + 1 >= input.len() {
out.push(input[i]);
i += 1;
continue;
}
match input[i + 1] {
b'"' => {
out.push(b'"');
i += 2;
}
b'\\' => {
out.push(b'\\');
i += 2;
}
b'/' => {
out.push(b'/');
i += 2;
}
b'b' => {
out.push(0x08);
i += 2;
}
b'f' => {
out.push(0x0C);
i += 2;
}
b'n' => {
out.push(b'\n');
i += 2;
}
b'r' => {
out.push(b'\r');
i += 2;
}
b't' => {
out.push(b'\t');
i += 2;
}
b'u' => {
if let Some(hi) = read_u4(&input[i + 2..]) {
if (0xD800..=0xDBFF).contains(&hi)
&& input.get(i + 6) == Some(&b'\\')
&& input.get(i + 7) == Some(&b'u')
&& let Some(lo) = read_u4(&input[i + 8..])
&& (0xDC00..=0xDFFF).contains(&lo)
{
let cp = 0x10000 + ((hi - 0xD800) << 10) + (lo - 0xDC00);
push_utf8(&mut out, cp);
i += 12;
} else if (0xD800..=0xDFFF).contains(&hi) {
push_utf8(&mut out, 0xFFFD);
i += 6;
} else {
push_utf8(&mut out, hi);
i += 6;
}
} else {
out.push(b'\\');
i += 1;
}
}
_ => {
out.push(b'\\');
i += 1;
}
}
}
out
}