use crate::ir::{BufferDecl, DataType, Expr, Node, Program};
use crate::ops::{OpSpec, BYTES_TO_BYTES_INPUTS, BYTES_TO_BYTES_OUTPUTS};
/// WGSL source text for this op, assembled at compile time: the contents of
/// `../wgsl_byte_primitives/bytes.wgsl`, a single newline, then the contents
/// of `wgsl/url_percent.wgsl` (both paths relative to this source file).
pub const WGSL: &str = concat!(
include_str!("../wgsl_byte_primitives/bytes.wgsl"),
"\n",
include_str!("wgsl/url_percent.wgsl"),
);
/// Converts one ASCII hexadecimal digit into its numeric value.
///
/// Accepts `0-9`, `a-f` and `A-F`. Returns `Some(0..=15)` for those bytes and
/// `None` for every other byte value.
pub fn hex_value(byte: u8) -> Option<u8> {
    // Folding case first lets the upper- and lower-case letters share one
    // branch; digits and non-letters pass through `to_ascii_lowercase` untouched.
    let folded = byte.to_ascii_lowercase();
    if folded.is_ascii_digit() {
        Some(folded - b'0')
    } else if (b'a'..=b'f').contains(&folded) {
        Some(folded - b'a' + 10)
    } else {
        None
    }
}
impl UrlPercentDecode {
    /// Op specification registered under the id `decode.url_percent`, built
    /// from the bytes-to-bytes input/output signatures, [`LAWS`], and
    /// [`Self::program`] as the program constructor.
    pub const SPEC: OpSpec = OpSpec::composition(
        "decode.url_percent",
        BYTES_TO_BYTES_INPUTS,
        BYTES_TO_BYTES_OUTPUTS,
        LAWS,
        Self::program,
    );

    /// Builds the IR [`Program`] for this op.
    ///
    /// The program declares a read buffer `input` (index 0) and an output
    /// buffer `out` (index 1), both of `DataType::Bytes`. Its body binds
    /// `idx` to `gid_x` and, when that index is below the length of `out`,
    /// stores `input[idx]` into `out[idx]`.
    ///
    /// NOTE(review): as written, this IR is an element-wise copy; no `%XX`
    /// decoding appears in it. Presumably the decoding lives elsewhere (for
    /// example in [`WGSL`]) — confirm. The `[64, 1, 1]` argument is
    /// presumably the workgroup size — confirm against `Program::new`.
    #[must_use]
    pub fn program() -> Program {
        let lane = Expr::gid_x();

        let buffers = vec![
            BufferDecl::read("input", 0, DataType::Bytes),
            BufferDecl::output("out", 1, DataType::Bytes),
        ];

        let bind_index = Node::let_bind("idx", lane.clone());

        // Only indices inside `out` perform the store.
        let within_output = Expr::lt(lane.clone(), Expr::buf_len("out"));
        let copy_element = Node::store("out", lane.clone(), Expr::load("input", lane));
        let guarded_copy = Node::if_then(within_output, vec![copy_element]);

        Program::new(buffers, [64, 1, 1], vec![bind_index, guarded_copy])
    }
}
/// Algebraic laws supplied to `OpSpec::composition` for this op.
/// The list is currently empty.
pub const LAWS: &[crate::ops::AlgebraicLaw] = &[];
/// CPU reference implementation of lenient URL percent-decoding.
///
/// Every `%XY` triplet where `X` and `Y` are both hexadecimal digits (either
/// case) is replaced by the single byte `0xXY`. Any other byte is copied
/// through unchanged, including a `%` that is followed by fewer than two
/// bytes or by non-hex bytes. The function never fails.
///
/// Returns the decoded bytes; the result is never longer than `input`.
pub fn url_percent_decode(input: &[u8]) -> Vec<u8> {
    let mut decoded = Vec::with_capacity(input.len());
    let mut remaining = input;

    while let Some((&head, tail)) = remaining.split_first() {
        if head == b'%' {
            // A triplet needs two more bytes after the `%`.
            if let [hi_digit, lo_digit, after_escape @ ..] = tail {
                if let (Some(hi), Some(lo)) = (hex_value(*hi_digit), hex_value(*lo_digit)) {
                    decoded.push((hi << 4) | lo);
                    remaining = after_escape;
                    continue;
                }
            }
        }
        // Not a well-formed escape: emit the byte verbatim and advance by one.
        decoded.push(head);
        remaining = tail;
    }

    decoded
}
/// Unit type for the `decode.url_percent` op; its associated items `SPEC`
/// and `program` carry the op's specification and IR program builder.
#[derive(Debug, Clone, Copy, Default)]
pub struct UrlPercentDecode;
/// Runs the committed known-answer vectors in
/// `../fixtures/reference-vectors.toml` against [`url_percent_decode`].
///
/// Cases whose `op` is not `url_percent_decode` are skipped. For a case with
/// `ok = true` the decoded output must equal `expected_output_hex`; for
/// `ok = false` the decoder must return the input unchanged.
///
/// # Errors
///
/// Returns an error when a matching case lacks `input_hex`,
/// `expected_output_hex` or `ok`, or when a hex field fails to parse.
#[test]
pub fn committed_kats_match_cpu_reference() -> Result<(), String> {
    use crate::ops::fixtures::{hex_to_bytes, run_committed_kats};

    let vectors = include_str!("../fixtures/reference-vectors.toml");
    run_committed_kats(vectors, |case| {
        // The vector file is shared; ignore cases that belong to other ops.
        if !matches!(case.op.as_deref(), Some("url_percent_decode")) {
            return Ok(());
        }

        let input_hex = case.input_hex.as_ref().ok_or("Fix: missing input_hex")?;
        let input = hex_to_bytes(input_hex)?;

        let expected_hex = case
            .expected_output_hex
            .as_ref()
            .ok_or("Fix: missing expected_output_hex")?;
        let expected = hex_to_bytes(expected_hex)?;

        let should_decode = case.ok.ok_or("Fix: missing ok")?;
        let actual = url_percent_decode(&input);

        if !should_decode {
            assert_eq!(
                actual, input,
                "lenient decode must pass through malformed input"
            );
            return Ok(());
        }

        assert_eq!(actual, expected);
        Ok(())
    })
}