use core::fmt;
use crate::ir::{BufferDecl, DataType, Expr, Node, Program};
use crate::ops::{OpSpec, BYTES_TO_BYTES_INPUTS, BYTES_TO_BYTES_OUTPUTS};
/// Marker string for this op; its value equals the op id used in `HexDecodeStrict::SPEC`.
// NOTE(review): the name suggests a category-A WGSL lowering marker; how it is consumed is not
// visible in this file — confirm before relying on it matching the op id.
pub const CATEGORY_A_WGSL_MARKER: &str = "decode.hex_decode_strict";
/// Combines two predicate expressions with [`Expr::and`].
///
/// A thin prefix-call wrapper so the predicate builders in this module read uniformly.
/// The result is a new IR node; discarding it is always a mistake, hence `#[must_use]`.
#[must_use]
pub fn and(left: Expr, right: Expr) -> Expr {
    Expr::and(left, right)
}
/// Why [`hex_decode_strict`] rejected its input.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DecodeError {
    /// The input length is odd, so it cannot be split into byte pairs.
    OddLength,
    /// A byte outside `0-9`, `a-f` and `A-F` was found.
    InvalidByte {
        /// Zero-based offset of the offending byte within the input.
        index: usize,
        /// The offending byte itself.
        byte: u8,
    },
}
/// Decodes ASCII hexadecimal text into raw bytes, rejecting any malformed input.
///
/// Upper- and lower-case digits are both accepted. Each output byte comes from two consecutive
/// input bytes, high nibble first.
///
/// # Errors
///
/// * [`DecodeError::OddLength`] if `input` has an odd length. This is checked before any byte
///   is inspected.
/// * [`DecodeError::InvalidByte`] for the first non-hex byte in input order, carrying its offset
///   and value.
pub fn hex_decode_strict(input: &[u8]) -> Result<Vec<u8>, DecodeError> {
    if input.len() % 2 == 1 {
        return Err(DecodeError::OddLength);
    }
    // Translate the byte at `offset`, recording exactly where decoding failed.
    let nibble_at = |offset: usize| -> Result<u8, DecodeError> {
        let byte = input[offset];
        hex_value(byte).ok_or(DecodeError::InvalidByte {
            index: offset,
            byte,
        })
    };
    let mut decoded = Vec::with_capacity(input.len() / 2);
    for offset in (0..input.len()).step_by(2) {
        // The high nibble is validated first, so its error wins when both bytes are bad.
        let high = nibble_at(offset)?;
        let low = nibble_at(offset + 1)?;
        decoded.push((high << 4) | low);
    }
    Ok(decoded)
}
/// Zero-sized handle for the `decode.hex_decode_strict` operation.
///
/// It carries the op's `SPEC` and `program` as associated items.
#[derive(Clone, Copy, Debug, Default)]
pub struct HexDecodeStrict;
/// Builds an IR expression for the 4-bit value of the hex digit held in `byte`.
///
/// The result is:
/// * `byte - '0'` for `0-9`
/// * `byte - 'a' + 10` for `a-f`
/// * `byte - 'A' + 10` for `A-F`
/// * `0` for every other byte
///
/// A non-hex byte and `'0'` both yield 0, so callers must check validity separately
/// (for example with [`is_hex_byte`]).
#[must_use]
pub fn hex_nibble(byte: Expr) -> Expr {
    // `byte - origin`, with the origin emitted as a u32 literal.
    let offset_from = |origin: u8| Expr::sub(byte.clone(), Expr::u32(u32::from(origin)));
    // Innermost fallback: upper-case letters, otherwise 0.
    let upper_or_zero = Expr::select(
        is_upper_hex(byte.clone()),
        Expr::add(offset_from(b'A'), Expr::u32(10)),
        Expr::u32(0),
    );
    let lower_or_rest = Expr::select(
        is_lower_hex(byte.clone()),
        Expr::add(offset_from(b'a'), Expr::u32(10)),
        upper_or_zero,
    );
    Expr::select(is_digit(byte.clone()), offset_from(b'0'), lower_or_rest)
}
/// Returns the value (0-15) of an ASCII hexadecimal digit, or `None` for any other byte.
///
/// Accepts `0-9`, `a-f` and `A-F`. Unlike the IR builder `hex_nibble`, which maps invalid bytes
/// to 0, invalid input is reported here as `None`.
#[must_use]
pub fn hex_value(byte: u8) -> Option<u8> {
    // Pick the ASCII origin of the digit run `byte` falls in, and the value that origin represents.
    let (origin, origin_value) = match byte {
        b'0'..=b'9' => (b'0', 0),
        b'a'..=b'f' => (b'a', 10),
        b'A'..=b'F' => (b'A', 10),
        _ => return None,
    };
    Some(byte - origin + origin_value)
}
/// Renders a [`DecodeError`] as an actionable message starting with `Fix:`, like the other
/// error messages in this module.
impl fmt::Display for DecodeError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match *self {
            Self::InvalidByte { index, byte } => write!(
                f,
                "Fix: byte 0x{byte:02x} at offset {index} is not hexadecimal"
            ),
            Self::OddLength => f.write_str("Fix: hex input must contain an even number of bytes"),
        }
    }
}

/// `DecodeError` wraps no underlying cause, so the default `source()` (which returns `None`)
/// is kept.
impl std::error::Error for DecodeError {}
/// Op definition for the strict hex decoder: its registered spec and the IR program it builds.
impl HexDecodeStrict {
/// Spec registered under the op id `"decode.hex_decode_strict"`.
///
/// Built with `OpSpec::composition_inlinable` from the byte-to-byte input/output declarations,
/// the (currently empty) [`LAWS`] list, and [`HexDecodeStrict::program`] as the IR builder.
pub const SPEC: OpSpec = OpSpec::composition_inlinable(
"decode.hex_decode_strict",
BYTES_TO_BYTES_INPUTS,
BYTES_TO_BYTES_OUTPUTS,
LAWS,
Self::program,
);
/// Builds the IR program that decodes hex pairs, one output byte per invocation.
///
/// Buffers: `input` (binding 0, read) holds the hex text; `out` (binding 1, output) receives
/// the decoded bytes. Invocation `idx` (the global x id) writes `out[idx]` only when
/// `idx < buf_len("out")`. It reads `input[2 * idx]` as the high nibble and
/// `input[2 * idx + 1]` as the low nibble.
///
/// Unlike [`hex_decode_strict`], this program cannot report an error. If the input length is
/// odd, or either byte of the pair is not hex, the output byte is written as `0`. That value
/// is indistinguishable from a correctly decoded `"00"` pair.
///
/// NOTE(review): the two `input` loads are not compared against `buf_len("input")`. This
/// presumably relies on the caller sizing `out` to at most `input.len() / 2`, or on
/// `Expr::load` being bounds-checked — confirm.
#[must_use]
pub fn program() -> Program {
let idx = Expr::var("idx");
// Byte offset of this invocation's hex pair within `input`.
let input_idx = Expr::mul(idx.clone(), Expr::u32(2));
Program::new(
vec![
BufferDecl::read("input", 0, DataType::Bytes),
BufferDecl::output("out", 1, DataType::Bytes),
],
// NOTE(review): presumably the workgroup size (x, y, z) — confirm against `Program::new`.
[64, 1, 1],
vec![
// One invocation per output byte, indexed by the global x id.
Node::let_bind("idx", Expr::gid_x()),
// Invocations past the end of `out` do nothing.
Node::if_then(
Expr::lt(idx.clone(), Expr::buf_len("out")),
vec![
Node::let_bind("input_idx", input_idx),
Node::let_bind("hi_byte", Expr::load("input", Expr::var("input_idx"))),
Node::let_bind(
"lo_byte",
Expr::load("input", Expr::add(Expr::var("input_idx"), Expr::u32(1))),
),
Node::let_bind("hi_valid", is_hex_byte(Expr::var("hi_byte"))),
Node::let_bind("lo_valid", is_hex_byte(Expr::var("lo_byte"))),
// Whole-input parity check, evaluated by every invocation. An odd-length input
// forces every written byte to 0 via the select below.
Node::let_bind(
"even_len",
Expr::eq(
Expr::rem(Expr::buf_len("input"), Expr::u32(2)),
Expr::u32(0),
),
),
// `hex_nibble` maps non-hex bytes to 0; the validity flags decide whether the
// combined value is kept.
Node::let_bind("hi", hex_nibble(Expr::var("hi_byte"))),
Node::let_bind("lo", hex_nibble(Expr::var("lo_byte"))),
// out[idx] = (hi << 4) | lo when the input length is even and both bytes are hex;
// otherwise out[idx] = 0.
Node::store(
"out",
idx,
Expr::select(
and(
Expr::var("even_len"),
and(Expr::var("hi_valid"), Expr::var("lo_valid")),
),
Expr::bitor(
Expr::shl(Expr::var("hi"), Expr::u32(4)),
Expr::var("lo"),
),
Expr::u32(0),
),
),
],
),
],
)
}
}
/// Builds a predicate that is true when `value` lies in the inclusive range `low..=high`.
///
/// The expression is `low <= value && value <= high`, with both bounds emitted as `u32`
/// literals.
#[must_use]
pub fn in_range(value: Expr, low: u8, high: u8) -> Expr {
    let at_least_low = Expr::le(Expr::u32(u32::from(low)), value.clone());
    let at_most_high = Expr::le(value, Expr::u32(u32::from(high)));
    and(at_least_low, at_most_high)
}
/// Builds a predicate that is true when `byte` is an ASCII decimal digit (`'0'..='9'`).
#[must_use]
pub fn is_digit(byte: Expr) -> Expr {
    in_range(byte, b'0', b'9')
}
/// Builds a predicate that is true when `byte` is an ASCII hex digit (`0-9`, `a-f` or `A-F`).
#[must_use]
pub fn is_hex_byte(byte: Expr) -> Expr {
    let is_letter = or(is_lower_hex(byte.clone()), is_upper_hex(byte.clone()));
    or(is_digit(byte), is_letter)
}
/// Builds a predicate that is true when `byte` is a lower-case hex letter (`'a'..='f'`).
#[must_use]
pub fn is_lower_hex(byte: Expr) -> Expr {
    in_range(byte, b'a', b'f')
}
/// Builds a predicate that is true when `byte` is an upper-case hex letter (`'A'..='F'`).
#[must_use]
pub fn is_upper_hex(byte: Expr) -> Expr {
    in_range(byte, b'A', b'F')
}
/// Algebraic laws declared for `decode.hex_decode_strict` (passed to `OpSpec` in
/// `HexDecodeStrict::SPEC`). None are claimed, so the list is empty.
pub const LAWS: &[crate::ops::AlgebraicLaw] = &[];
/// Combines two predicate expressions with [`Expr::or`].
///
/// A thin prefix-call wrapper so the predicate builders in this module read uniformly.
/// The result is a new IR node; discarding it is always a mistake, hence `#[must_use]`.
#[must_use]
pub fn or(left: Expr, right: Expr) -> Expr {
    Expr::or(left, right)
}
/// Checks [`hex_decode_strict`] against every `hex_decode_strict` case in the committed
/// reference vectors (`../fixtures/reference-vectors.toml`).
///
/// Cases with `ok = true` must decode to `expected_output_hex`. Cases with `ok = false` must be
/// rejected. Every mismatch is returned as a descriptive `Err`, the same way the fixture
/// errors are, instead of panicking on the first bad value.
#[test]
pub fn committed_kats_match_cpu_reference() -> Result<(), String> {
    crate::ops::fixtures::run_committed_kats(
        include_str!("../fixtures/reference-vectors.toml"),
        |case| {
            if case.op.as_deref() != Some("hex_decode_strict") {
                return Ok(());
            }
            let input = crate::ops::fixtures::hex_to_bytes(
                case.input_hex.as_ref().ok_or("Fix: missing input_hex")?,
            )?;
            let expected = crate::ops::fixtures::hex_to_bytes(
                case.expected_output_hex
                    .as_ref()
                    .ok_or("Fix: missing expected_output_hex")?,
            )?;
            let ok = case.ok.ok_or("Fix: missing ok")?;
            match (ok, hex_decode_strict(&input)) {
                (true, Ok(actual)) => {
                    // Report a wrong decode through the same Err channel as the other arms.
                    if actual != expected {
                        return Err(format!(
                            "Fix: valid hex decode vector decoded to {actual:02x?}, expected {expected:02x?}"
                        ));
                    }
                }
                (false, Err(_)) => {}
                (true, Err(error)) => {
                    return Err(format!(
                        "Fix: valid hex decode vector was rejected: {error}"
                    ));
                }
                (false, Ok(actual)) => {
                    return Err(format!(
                        "Fix: invalid hex decode vector should fail, decoded to {actual:02x?}"
                    ));
                }
            }
            Ok(())
        },
    )
}