use crate::ir::{BufferDecl, DataType, Expr, Node, Program};
use crate::ops::{OpSpec, BYTES_TO_BYTES_INPUTS, BYTES_TO_BYTES_OUTPUTS};
/// Marker string for this op's category-A WGSL; its value equals the op id
/// `encode.base64_encode` registered by `Base64Encode::SPEC`.
/// NOTE(review): the consumer of this marker is not visible here — confirm how it is matched.
pub const CATEGORY_A_WGSL_MARKER: &str = "encode.base64_encode";
/// Standard base64 alphabet (RFC 4648 §4), indexed by 6-bit sextet value `0..64`.
/// Used by the CPU reference encoder; the IR path mirrors it via `alphabet_char`.
pub const ALPHABET: &[u8; 64] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
/// Builds an IR expression that maps a sextet value `idx` to the ASCII code of its
/// symbol in the standard base64 alphabet: `A`–`Z`, `a`–`z`, `0`–`9`, `+`, `/`.
///
/// The result is a chain of nested `select`s over the ranges `[0, 26)`, `[26, 52)`,
/// `[52, 62)` and the value `62`. Every other value, including anything ≥ 63, falls
/// through to `/`, so only inputs in `0..64` are meaningful.
pub fn alphabet_char(idx: Expr) -> Expr {
    let ascii = |ch: u8| Expr::u32(u32::from(ch));
    let below = |limit: u32| Expr::lt(idx.clone(), Expr::u32(limit));
    // `base + (idx - start)`: position of `idx` inside a contiguous run of symbols.
    let offset_from = |base: u8, start: u32| {
        Expr::add(ascii(base), Expr::sub(idx.clone(), Expr::u32(start)))
    };

    // Build the select chain from the innermost (last) branch outwards.
    let plus_or_slash = Expr::select(
        Expr::eq(idx.clone(), Expr::u32(62)),
        ascii(b'+'),
        ascii(b'/'),
    );
    let digit_or_symbol = Expr::select(below(62), offset_from(b'0', 52), plus_or_slash);
    let lower_or_rest = Expr::select(below(52), offset_from(b'a', 26), digit_or_symbol);
    Expr::select(below(26), Expr::add(ascii(b'A'), idx.clone()), lower_or_rest)
}
#[must_use]
pub fn base64_encode(input: &[u8]) -> Vec<u8> {
let mut out = Vec::with_capacity(input.len().div_ceil(3) * 4);
for chunk in input.chunks(3) {
let a = chunk[0];
let b = chunk.get(1).copied().unwrap_or(0);
let c = chunk.get(2).copied().unwrap_or(0);
out.push(ALPHABET[(a >> 2) as usize]);
out.push(ALPHABET[(((a & 0x03) << 4) | (b >> 4)) as usize]);
if chunk.len() > 1 {
out.push(ALPHABET[(((b & 0x0f) << 2) | (c >> 6)) as usize]);
} else {
out.push(b'=');
}
if chunk.len() > 2 {
out.push(ALPHABET[(c & 0x3f) as usize]);
} else {
out.push(b'=');
}
}
out
}
/// Zero-sized marker type for the `encode.base64_encode` op. Its spec and IR program
/// are provided by the inherent associated items `Base64Encode::SPEC` and
/// `Base64Encode::program`.
#[derive(Debug, Clone, Copy, Default)]
pub struct Base64Encode;
/// Builds a boolean IR expression that is true when the byte at `input_idx + offset`
/// lies inside the `input` buffer.
///
/// The expression reads the IR variable `input_idx`, so it is only meaningful in a scope
/// where that name has been bound with `Node::let_bind` (as `Base64Encode::program` does).
pub fn has_input_byte(offset: u32) -> Expr {
    let probe = Expr::add(Expr::var("input_idx"), Expr::u32(offset));
    let input_len = Expr::buf_len("input");
    Expr::lt(probe, input_len)
}
impl Base64Encode {
    /// Op spec for `encode.base64_encode`. It is an inlinable composition over the shared
    /// bytes-to-bytes buffer signature, carries the laws in `LAWS`, and gets its IR from
    /// [`Base64Encode::program`].
    pub const SPEC: OpSpec = OpSpec::composition_inlinable(
        "encode.base64_encode",
        BYTES_TO_BYTES_INPUTS,
        BYTES_TO_BYTES_OUTPUTS,
        LAWS,
        Self::program,
    );

    /// Builds the IR program that base64-encodes the `input` byte buffer into `out`.
    ///
    /// Each invocation takes `chunk = gid_x`. When `chunk * 3` is inside `input`, it
    /// encodes the (up to) three bytes starting there into the four output bytes starting
    /// at `chunk * 4`. Output positions with no corresponding input byte are written as
    /// `=`, mirroring [`base64_encode`].
    ///
    /// Bytes `b` and `c` are forced to zero when they lie past the end of `input`. The
    /// second and third output symbols still consume their high bits, and the CPU
    /// reference zero-extends a short trailing group. An out-of-bounds load is not
    /// guaranteed to produce zero (WGSL, for example, may return any in-bounds value), so
    /// the padded tail must not depend on one.
    #[must_use]
    pub fn program() -> Program {
        let chunk = Expr::var("chunk");
        let input_idx = Expr::mul(chunk.clone(), Expr::u32(3));
        let out_idx = Expr::mul(chunk, Expr::u32(4));
        Program::new(
            vec![
                BufferDecl::read("input", 0, DataType::Bytes),
                BufferDecl::output("out", 1, DataType::Bytes),
            ],
            [64, 1, 1],
            vec![
                Node::let_bind("chunk", Expr::gid_x()),
                Node::if_then(
                    Expr::lt(input_idx.clone(), Expr::buf_len("input")),
                    vec![
                        Node::let_bind("input_idx", input_idx),
                        Node::let_bind("out_idx", out_idx),
                        // Presence flags come first so that `b` and `c` can be masked below.
                        Node::let_bind("has_b", has_input_byte(1)),
                        Node::let_bind("has_c", has_input_byte(2)),
                        Node::let_bind("a", Expr::load("input", Expr::var("input_idx"))),
                        // Missing trailing bytes must read as 0, as in the CPU reference;
                        // the loaded value is discarded when the byte is absent.
                        Node::let_bind(
                            "b",
                            Expr::select(
                                Expr::var("has_b"),
                                Expr::load(
                                    "input",
                                    Expr::add(Expr::var("input_idx"), Expr::u32(1)),
                                ),
                                Expr::u32(0),
                            ),
                        ),
                        Node::let_bind(
                            "c",
                            Expr::select(
                                Expr::var("has_c"),
                                Expr::load(
                                    "input",
                                    Expr::add(Expr::var("input_idx"), Expr::u32(2)),
                                ),
                                Expr::u32(0),
                            ),
                        ),
                        // Symbol 0: top six bits of `a`.
                        Node::store(
                            "out",
                            Expr::var("out_idx"),
                            alphabet_char(Expr::shr(Expr::var("a"), Expr::u32(2))),
                        ),
                        // Symbol 1: low two bits of `a` followed by the high four of `b`.
                        Node::store(
                            "out",
                            Expr::add(Expr::var("out_idx"), Expr::u32(1)),
                            alphabet_char(Expr::bitor(
                                Expr::shl(
                                    Expr::bitand(Expr::var("a"), Expr::u32(0x03)),
                                    Expr::u32(4),
                                ),
                                Expr::shr(Expr::var("b"), Expr::u32(4)),
                            )),
                        ),
                        // Symbol 2: low four bits of `b` followed by the high two of `c`,
                        // or '=' when there is no second byte.
                        Node::store(
                            "out",
                            Expr::add(Expr::var("out_idx"), Expr::u32(2)),
                            Expr::select(
                                Expr::var("has_b"),
                                alphabet_char(Expr::bitor(
                                    Expr::shl(
                                        Expr::bitand(Expr::var("b"), Expr::u32(0x0f)),
                                        Expr::u32(2),
                                    ),
                                    Expr::shr(Expr::var("c"), Expr::u32(6)),
                                )),
                                Expr::u32(u32::from(b'=')),
                            ),
                        ),
                        // Symbol 3: low six bits of `c`, or '=' when there is no third byte.
                        Node::store(
                            "out",
                            Expr::add(Expr::var("out_idx"), Expr::u32(3)),
                            Expr::select(
                                Expr::var("has_c"),
                                alphabet_char(Expr::bitand(Expr::var("c"), Expr::u32(0x3f))),
                                Expr::u32(u32::from(b'=')),
                            ),
                        ),
                    ],
                ),
            ],
        )
    }
}
/// Algebraic laws declared for `encode.base64_encode` (passed to `Base64Encode::SPEC`).
/// The list is currently empty, so no laws are declared for this op.
pub const LAWS: &[crate::ops::AlgebraicLaw] = &[];
/// Checks the CPU reference encoder against every committed known-answer vector whose
/// `op` is `base64_encode`; vectors for other ops are skipped.
///
/// For a matching vector, the closure returns `Err` when `input_hex` or
/// `expected_output_hex` is missing or fails to decode. It panics when the encoded
/// output differs from the expected bytes.
/// NOTE(review): this assumes `run_committed_kats` propagates closure errors to the
/// test's return value — confirm in `crate::ops::fixtures`.
#[test]
pub fn committed_kats_match_cpu_reference() -> Result<(), String> {
    use crate::ops::fixtures::{hex_to_bytes, run_committed_kats};

    run_committed_kats(include_str!("../fixtures/reference-vectors.toml"), |case| {
        let targets_this_op = case.op.as_deref() == Some("base64_encode");
        if targets_this_op {
            // Input is decoded and encoded before the expected bytes are touched,
            // so a missing `input_hex` is reported ahead of a missing expectation.
            let input_hex = case.input_hex.as_ref().ok_or("Fix: missing input_hex")?;
            let actual = base64_encode(&hex_to_bytes(input_hex)?);
            let expected_hex = case
                .expected_output_hex
                .as_ref()
                .ok_or("Fix: missing expected_output_hex")?;
            let expected = hex_to_bytes(expected_hex)?;
            assert_eq!(actual, expected);
        }
        Ok(())
    })
}