use core::fmt;
use crate::ir::{BufferDecl, DataType, Expr, Node, Program};
use crate::ops::{OpSpec, BYTES_TO_BYTES_INPUTS, BYTES_TO_BYTES_OUTPUTS};
/// WGSL source text for this module, assembled at compile time.
///
/// The text is `bytes.wgsl` from the shared byte primitives, then a newline separator,
/// then the base32 kernel source `wgsl/base32.wgsl`. Both paths are resolved by
/// `include_str!` relative to this source file.
///
/// NOTE(review): nothing in the visible code references `WGSL` (`program()` builds IR
/// instead). Presumably a backend consumes it elsewhere — confirm.
pub const WGSL: &str = concat!(
include_str!("../wgsl_byte_primitives/bytes.wgsl"),
// Separator so the last line of the first file cannot run into the first line of the
// second file if `bytes.wgsl` lacks a trailing newline.
"\n",
include_str!("wgsl/base32.wgsl"),
);
/// Decodes RFC 4648 base32 text into bytes.
///
/// The alphabet is `A`-`Z` and `2`-`7`, uppercase only. Decoding is strict (canonical):
///
/// - Padding is optional. If any `=` is present, then:
///   - every byte from the first `=` onward must be `=`;
///   - the total length must be a multiple of 8;
///   - the padding may only complete the final 8-character group (fewer than 8 `=`).
/// - The number of data characters must not be 1, 3 or 6 modulo 8, because no whole
///   number of bytes encodes to those lengths.
/// - The unused low bits of the last data character must be zero.
///
/// # Errors
///
/// - [`DecodeError::InvalidPadding`] for any of the padding violations above.
/// - [`DecodeError::InvalidLength`] when the data length is 1, 3 or 6 modulo 8.
/// - [`DecodeError::InvalidByte`] for the first data byte outside the alphabet, together
///   with its offset in `input`.
/// - [`DecodeError::NonZeroTrailingBits`] when the final partial group has spare bits set.
pub fn base32_decode(input: &[u8]) -> Result<Vec<u8>, DecodeError> {
    let first_pad = input.iter().position(|&byte| byte == b'=');
    let data_len = first_pad.unwrap_or(input.len());
    if let Some(pad_at) = first_pad {
        let padding = &input[pad_at..];
        if padding.iter().any(|&byte| byte != b'=') {
            return Err(DecodeError::InvalidPadding);
        }
        // Padding rounds the encoded length up to the next multiple of 8 and no further.
        // A whole group of `=` (e.g. "========" or "AA==============") is not canonical.
        if input.len() % 8 != 0 || padding.len() >= 8 {
            return Err(DecodeError::InvalidPadding);
        }
    }
    if matches!(data_len % 8, 1 | 3 | 6) {
        return Err(DecodeError::InvalidLength);
    }
    // Exact output size, floor(5 * data_len / 8). It is computed per 8-character group so
    // the multiplication cannot overflow `usize` for very long inputs.
    let mut out = Vec::with_capacity(data_len / 8 * 5 + (data_len % 8) * 5 / 8);
    // Bit accumulator. At most the low 12 bits are ever read, so bits shifted out of the
    // top of the u32 are harmless.
    let mut buffer = 0_u32;
    let mut bits = 0_u32;
    for (index, &byte) in input[..data_len].iter().enumerate() {
        let value = base32_value(byte).ok_or(DecodeError::InvalidByte { index, byte })?;
        buffer = (buffer << 5) | u32::from(value);
        bits += 5;
        if bits >= 8 {
            bits -= 8;
            out.push(((buffer >> bits) & 0xff) as u8);
        }
    }
    // Leftover bits belong to no output byte. Canonical encoders emit them as zero.
    if bits > 0 && (buffer & ((1_u32 << bits) - 1)) != 0 {
        return Err(DecodeError::NonZeroTrailingBits);
    }
    Ok(out)
}
/// Zero-sized marker type for the base32 decode operation. It holds no state.
#[derive(Clone, Copy, Debug, Default)]
pub struct Base32Decode;
/// Maps one RFC 4648 base32 alphabet character to its 5-bit value.
///
/// `'A'..='Z'` map to `0..=25` and `'2'..='7'` map to `26..=31`. Every other byte,
/// including lowercase letters and the `=` padding character, yields `None`.
pub fn base32_value(byte: u8) -> Option<u8> {
    if byte.is_ascii_uppercase() {
        return Some(byte - b'A');
    }
    if (b'2'..=b'7').contains(&byte) {
        return Some(26 + (byte - b'2'));
    }
    None
}
/// Reasons the base32 decoder can reject its input.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DecodeError {
    /// The number of data characters (before any padding) leaves a remainder
    /// of 1, 3 or 6 modulo 8.
    InvalidLength,
    /// A data character outside the RFC 4648 base32 alphabet.
    InvalidByte {
        /// Offset of the offending byte within the decoder's input.
        index: usize,
        /// The offending byte value.
        byte: u8,
    },
    /// Padding is malformed. Examples: a non-`=` byte after the first `=`, or padded
    /// input whose length is not a multiple of 8.
    InvalidPadding,
    /// The spare low bits of the last data character are not all zero.
    NonZeroTrailingBits,
}
/// Op wiring for [`Base32Decode`]: its [`OpSpec`] and the IR program it builds.
impl Base32Decode {
/// Spec for the `decode.base32` op.
///
/// Uses the shared bytes-in/bytes-out signature (`BYTES_TO_BYTES_INPUTS` /
/// `BYTES_TO_BYTES_OUTPUTS`), the laws in [`LAWS`], and [`Self::program`] as the
/// program builder.
pub const SPEC: OpSpec = OpSpec::composition(
"decode.base32",
BYTES_TO_BYTES_INPUTS,
BYTES_TO_BYTES_OUTPUTS,
LAWS,
Self::program,
);
/// Builds the IR program for this op.
///
/// Buffers, both `DataType::Bytes`:
/// - `input`: declared with `BufferDecl::read`, index 0.
/// - `out`: declared with `BufferDecl::output`, index 1.
///
/// The dispatch shape is `[64, 1, 1]` (presumably the workgroup size — confirm against
/// `Program::new`). Each invocation takes its global x id (`Expr::gid_x()`). When that
/// id is below `buf_len("out")`, it stores `input[idx]` into `out[idx]`.
///
/// NOTE(review): as written this copies the first `len(out)` input bytes unchanged. It
/// is not a base32 decode, and it does not reference [`WGSL`] (which includes
/// `wgsl/base32.wgsl`). Confirm whether this is a placeholder or whether the real
/// kernel is wired up elsewhere.
#[must_use]
pub fn program() -> Program {
let idx = Expr::gid_x();
Program::new(
vec![
BufferDecl::read("input", 0, DataType::Bytes),
BufferDecl::output("out", 1, DataType::Bytes),
],
[64, 1, 1],
vec![
// NOTE(review): binds the name "idx", but the nodes below embed the `gid_x`
// expression directly rather than referring to this binding.
Node::let_bind("idx", idx.clone()),
// Guard: only write when idx < buf_len("out").
Node::if_then(
Expr::lt(idx.clone(), Expr::buf_len("out")),
vec![Node::store("out", idx.clone(), Expr::load("input", idx))],
),
],
)
}
}
impl fmt::Display for DecodeError {
    /// Writes a `Fix:`-prefixed hint describing how to correct the rejected input.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::InvalidLength => write!(f, "Fix: base32 encoded length has an invalid residue"),
            Self::InvalidByte { index, byte } => {
                write!(
                    f,
                    "Fix: byte 0x{byte:02x} at offset {index} is not RFC 4648 base32"
                )
            }
            // The decoder reports this both for misplaced padding and for padded input
            // that does not end on an 8-character boundary, so the hint names both
            // requirements. The old text claimed only the first.
            Self::InvalidPadding => write!(
                f,
                "Fix: place base32 padding only at the end, filling the last group to a multiple of 8 characters"
            ),
            Self::NonZeroTrailingBits => {
                write!(f, "Fix: canonical base32 trailing bits must be zero")
            }
        }
    }
}
/// Marks `DecodeError` as a standard error type, using its derived `Debug` and its
/// `Display` impl. The default `source()` is kept, so it reports no underlying cause.
impl std::error::Error for DecodeError {}
/// Algebraic laws declared for the `decode.base32` op spec. Currently none.
pub const LAWS: &[crate::ops::AlgebraicLaw] = &[];
/// Checks every committed `base32_decode` known-answer vector in
/// `fixtures/reference-vectors.toml` against the CPU reference `base32_decode`.
///
/// - Vectors with `ok = true` must decode to `expected_output_hex`.
/// - Vectors with `ok = false` must be rejected. Their `expected_output_hex` is not
///   consulted, so it need not be present.
///
/// Every failure is returned as a `Fix:` error rather than a panic. Mismatches include
/// the input bytes so the failing vector can be identified.
#[test]
pub fn committed_kats_match_cpu_reference() -> Result<(), String> {
    crate::ops::fixtures::run_committed_kats(
        include_str!("../fixtures/reference-vectors.toml"),
        |case| {
            if case.op.as_deref() != Some("base32_decode") {
                return Ok(());
            }
            let input = crate::ops::fixtures::hex_to_bytes(
                case.input_hex.as_ref().ok_or("Fix: missing input_hex")?,
            )?;
            let ok = case.ok.ok_or("Fix: missing ok")?;
            match (ok, base32_decode(&input)) {
                (true, Ok(actual)) => {
                    // The expected output only matters for vectors that should decode.
                    let expected = crate::ops::fixtures::hex_to_bytes(
                        case.expected_output_hex
                            .as_ref()
                            .ok_or("Fix: missing expected_output_hex")?,
                    )?;
                    if actual != expected {
                        return Err(format!(
                            "Fix: base32 vector {input:02x?} decoded to {actual:02x?}, expected {expected:02x?}"
                        ));
                    }
                }
                (false, Err(_)) => {}
                (true, Err(error)) => {
                    return Err(format!("Fix: valid base32 vector was rejected: {error}"));
                }
                (false, Ok(actual)) => {
                    return Err(format!(
                        "Fix: invalid base32 vector should fail, decoded to {actual:02x?}"
                    ));
                }
            }
            Ok(())
        },
    )
}