/// Base64 decoding expressed as an IR program (`decode.base64`).
pub mod kernel {
    use super::super::shared::{and, in_range, BYTE_BOUNDED_LAWS};
    use crate::ir::{BufferDecl, DataType, Expr, Node, Program};
    use crate::ops::{OpSpec, BYTES_TO_BYTES_INPUTS, BYTES_TO_BYTES_OUTPUTS};

    /// Marker type carrying the `decode.base64` op specification.
    #[derive(Debug, Clone, Copy, Default)]
    pub struct Base64Decode;

    impl Base64Decode {
        /// Op specification: bytes in, bytes out, checked against `BYTE_BOUNDED_LAWS`,
        /// implemented by [`Base64Decode::program`].
        pub const SPEC: OpSpec = OpSpec::composition(
            "decode.base64",
            BYTES_TO_BYTES_INPUTS,
            BYTES_TO_BYTES_OUTPUTS,
            BYTE_BOUNDED_LAWS,
            Self::program,
        );

        /// Builds the decoding program.
        ///
        /// Invocation `idx` (bound from `gid_x`) loads the four input bytes at
        /// `4*idx ..= 4*idx + 3` and binds their 6-bit values to `a6`, `b6`, `c6` and
        /// `d6`. It then stores three packed bytes into `out` at `3*idx ..= 3*idx + 2`.
        ///
        /// An invocation does nothing unless the whole input quad lies inside `input`
        /// and all three output bytes lie inside `out`. A trailing group of fewer than
        /// four input bytes is therefore never decoded.
        ///
        /// Bytes outside the Base64 alphabet, including `=` padding, decode to 0 (see
        /// [`decode_value`]). A padded final quad therefore still writes three bytes, and
        /// the `=` positions contribute zero bits.
        #[must_use]
        pub fn program() -> Program {
            // Guard: skip the invocation unless every load and store below is in range.
            let in_bounds = and(
                Expr::lt(lane_offset(4, 3), Expr::buf_len("input")),
                Expr::lt(lane_offset(3, 2), Expr::buf_len("out")),
            );

            // Decode the four characters first; byte0/byte1/byte2 read these names.
            let mut body: Vec<Node> = ["a6", "b6", "c6", "d6"]
                .iter()
                .zip(0u32..)
                .map(|(&name, lane)| {
                    let ch = Expr::load("input", lane_offset(4, lane));
                    Node::let_bind(name, decode_value(ch))
                })
                .collect();

            // Then pack the 24 decoded bits into three output bytes.
            body.extend([
                Node::store("out", lane_offset(3, 0), byte0()),
                Node::store("out", lane_offset(3, 1), byte1()),
                Node::store("out", lane_offset(3, 2), byte2()),
            ]);

            Program::new(
                vec![
                    BufferDecl::read("input", 0, DataType::Bytes),
                    BufferDecl::output("out", 1, DataType::Bytes),
                ],
                [64, 1, 1],
                vec![
                    Node::let_bind("idx", Expr::gid_x()),
                    Node::if_then(in_bounds, body),
                ],
            )
        }
    }

    /// Builds `idx * stride + lane`, emitting the bare product for lane 0.
    fn lane_offset(stride: u32, lane: u32) -> Expr {
        let base = Expr::mul(Expr::var("idx"), Expr::u32(stride));
        if lane == 0 {
            base
        } else {
            Expr::add(base, Expr::u32(lane))
        }
    }

    /// Builds an expression mapping one Base64 character code to its 6-bit value.
    ///
    /// The mapping is:
    /// - `A`-`Z` → `0`-`25`
    /// - `a`-`z` → `26`-`51`
    /// - `0`-`9` → `52`-`61`
    /// - `+` → `62`
    /// - `/` → `63`
    ///
    /// These ranges assume `in_range` includes both bounds, which is presumably the
    /// case. Every other byte, including the `=` padding character, maps to `0`; no
    /// error is signalled.
    #[must_use]
    pub fn decode_value(byte: Expr) -> Expr {
        let offset_from = |origin: u8| Expr::sub(byte.clone(), Expr::u32(u32::from(origin)));
        let is_char = |ch: u8| Expr::eq(byte.clone(), Expr::u32(u32::from(ch)));

        // Build the select chain from the innermost fallback outwards.
        let otherwise_zero = Expr::u32(0);
        let slash_or = Expr::select(is_char(b'/'), Expr::u32(63), otherwise_zero);
        let plus_or = Expr::select(is_char(b'+'), Expr::u32(62), slash_or);
        let digit_or = Expr::select(
            in_range(byte.clone(), b'0', b'9'),
            Expr::add(offset_from(b'0'), Expr::u32(52)),
            plus_or,
        );
        let lower_or = Expr::select(
            in_range(byte.clone(), b'a', b'z'),
            Expr::add(offset_from(b'a'), Expr::u32(26)),
            digit_or,
        );
        Expr::select(in_range(byte.clone(), b'A', b'Z'), offset_from(b'A'), lower_or)
    }

    /// Builds the first decoded byte, `(a6 << 2) | (b6 >> 4)`.
    ///
    /// Reads the `a6` and `b6` variables bound by [`Base64Decode::program`].
    #[must_use]
    pub fn byte0() -> Expr {
        let high_six = Expr::shl(Expr::var("a6"), Expr::u32(2));
        let low_two = Expr::shr(Expr::var("b6"), Expr::u32(4));
        Expr::bitor(high_six, low_two)
    }

    /// Builds the second decoded byte, `((b6 & 0x0f) << 4) | (c6 >> 2)`.
    ///
    /// Reads the `b6` and `c6` variables bound by [`Base64Decode::program`].
    #[must_use]
    pub fn byte1() -> Expr {
        let high_four = Expr::shl(Expr::bitand(Expr::var("b6"), Expr::u32(0x0f)), Expr::u32(4));
        let low_four = Expr::shr(Expr::var("c6"), Expr::u32(2));
        Expr::bitor(high_four, low_four)
    }

    /// Builds the third decoded byte, `((c6 & 0x03) << 6) | d6`.
    ///
    /// Reads the `c6` and `d6` variables bound by [`Base64Decode::program`].
    #[must_use]
    pub fn byte2() -> Expr {
        let high_two = Expr::shl(Expr::bitand(Expr::var("c6"), Expr::u32(0x03)), Expr::u32(6));
        let low_six = Expr::var("d6");
        Expr::bitor(high_two, low_six)
    }
}
pub use kernel::Base64Decode;