use crate::ir::{BufferDecl, DataType, Expr, Node, Program};
use crate::ops::{AlgebraicLaw, OpSpec};
/// Input data types handed to `OpSpec::composition` when building `ZstdDecompress::SPEC`.
pub const INPUTS: &[DataType] = &[DataType::U32, DataType::U32];
/// Output data types handed to `OpSpec::composition` when building `ZstdDecompress::SPEC`.
pub const OUTPUTS: &[DataType] = &[DataType::U32, DataType::U32];
/// Algebraic laws declared for the op: a single `Bounded` law spanning the whole `u32` range.
pub const LAWS: &[AlgebraicLaw] = &[AlgebraicLaw::Bounded {
lo: 0,
hi: u32::MAX,
}];
/// NOTE(review): not referenced by the code visible in this file; presumably an upper bound on
/// the output-to-input size ratio used by callers when sizing the output buffer — confirm.
pub const MAX_OUTPUT_RATIO: u32 = 1024;
/// Workgroup size passed to `Program::new` (one invocation along each axis).
pub const WORKGROUP_SIZE: [u32; 3] = [1, 1, 1];
/// Iteration bound shared by the outer block loop and the per-block raw/RLE byte loops
/// (131_072 = 128 * 1024).
pub const LOOP_LIMIT: u32 = 131_072;
/// Block-type value (2-bit header field) that selects the raw byte-copy path.
pub const BLOCK_RAW: u32 = 0;
/// Block-type value (2-bit header field) that selects the run-length fill path.
pub const BLOCK_RLE: u32 = 1;
/// Block-type value (2-bit header field) that is rejected as unsupported; larger values are
/// rejected as reserved.
pub const BLOCK_COMPRESSED: u32 = 2;
/// Status code stored in `status[0]` when the input ends before a header or payload byte.
pub const STATUS_INPUT_TRUNCATED: u32 = 1;
/// Status code stored in `status[0]` when the output position reaches the output length
/// while bytes still have to be written.
pub const STATUS_OUTPUT_OVERFLOW: u32 = 2;
/// Status code stored in `status[0]` when a block of type `BLOCK_COMPRESSED` is encountered.
pub const STATUS_COMPRESSED_BLOCK_UNSUPPORTED: u32 = 3;
/// Status code stored in `status[0]` when the block type is greater than `BLOCK_COMPRESSED`.
pub const STATUS_RESERVED_BLOCK_TYPE: u32 = 4;
/// Status code stored in `status[0]` when the block loop runs out of iterations while the
/// decoder is still active and unread input remains.
pub const STATUS_LOOP_LIMIT_EXCEEDED: u32 = 5;
/// Operation identifier used for `ZstdDecompress::SPEC` and as the program's entry op id.
pub const OP_ID: &str = "compression.zstd_decompress";
/// NOTE(review): not referenced by the code visible in this file; presumably the identifier of
/// a separate op that handles compressed blocks — confirm.
pub const COMPRESSED_BLOCK_OP: &str = "compression.zstd.compressed_block";
/// Unit type that carries the op specification and program builder for the
/// `compression.zstd_decompress` operation.
#[derive(Debug, Clone, Copy, Default)]
pub struct ZstdDecompress;
impl ZstdDecompress {
/// Op specification assembled from `OP_ID`, `INPUTS`, `OUTPUTS`, `LAWS` and
/// [`ZstdDecompress::program`] as the program builder.
pub const SPEC: OpSpec = OpSpec::composition(OP_ID, INPUTS, OUTPUTS, LAWS, Self::program);
/// Returns the decoder program; forwards to [`zstd_decompress_program`].
#[must_use]
pub fn program() -> Program {
zstd_decompress_program()
}
}
#[must_use]
pub fn zstd_decompress_program() -> Program {
Program::new(
vec![
BufferDecl::read("input", 0, DataType::U32),
BufferDecl::output("out", 1, DataType::U32),
BufferDecl::read("params", 2, DataType::U32),
BufferDecl::read_write("status", 3, DataType::U32),
],
WORKGROUP_SIZE,
vec![
Node::let_bind("ip", Expr::u32(0)),
Node::let_bind("op", Expr::u32(0)),
Node::let_bind("is_active", Expr::u32(1)),
Node::store("status", Expr::u32(0), Expr::u32(0)),
Node::Loop {
var: "block".to_string(),
from: Expr::u32(0),
to: Expr::u32(LOOP_LIMIT),
body: vec![Node::if_then(active(), zstd_block_body())],
},
Node::if_then(
and(active(), Expr::lt(Expr::var("ip"), input_len())),
vec![fail(STATUS_LOOP_LIMIT_EXCEEDED)],
),
Node::store("status", Expr::u32(1), Expr::var("op")),
],
)
.with_entry_op_id(OP_ID)
}
pub fn zstd_block_body() -> Vec<Node> {
vec![
Node::let_bind("z_b0", Expr::u32(0)),
Node::let_bind("z_b1", Expr::u32(0)),
Node::let_bind("z_b2", Expr::u32(0)),
Node::let_bind("z_header", Expr::u32(0)),
Node::let_bind("z_last", Expr::u32(0)),
Node::let_bind("z_block_type", Expr::u32(0)),
Node::let_bind("z_block_size", Expr::u32(0)),
read_header(),
decode_raw_block(),
decode_rle_block(),
reject_compressed_block(),
reject_reserved_block(),
Node::if_then(
and(active(), Expr::ne(Expr::var("z_last"), Expr::u32(0))),
vec![Node::assign("is_active", Expr::u32(0))],
),
]
}
/// Builds the node that parses the 3-byte little-endian block header at `ip`.
///
/// When three bytes are available (`ip + 2 < input_len`), the bytes are loaded into
/// `z_b0..z_b2`, `ip` advances by three, and the combined 24-bit value is split into
/// `z_last` (bit 0), `z_block_type` (bits 1-2) and `z_block_size` (bits 3 and up).
/// Otherwise the decoder fails with `STATUS_INPUT_TRUNCATED`.
pub fn read_header() -> Node {
    let ip = || Expr::var("ip");
    let header = || Expr::var("z_header");
    let header_byte = |offset: u32| read_byte("input", Expr::add(Expr::var("ip"), Expr::u32(offset)));

    let has_three_bytes = Expr::lt(Expr::add(ip(), Expr::u32(2)), input_len());

    // Assemble b0 | (b1 << 8) | (b2 << 16).
    let low_two = Expr::bitor(
        Expr::var("z_b0"),
        Expr::shl(Expr::var("z_b1"), Expr::u32(8)),
    );
    let combined = Expr::bitor(low_two, Expr::shl(Expr::var("z_b2"), Expr::u32(16)));

    let last_bit = Expr::bitand(header(), Expr::u32(1));
    let type_bits = Expr::bitand(Expr::shr(header(), Expr::u32(1)), Expr::u32(3));
    let size_bits = Expr::shr(header(), Expr::u32(3));

    let parse = vec![
        Node::assign("z_b0", read_byte("input", ip())),
        Node::assign("z_b1", header_byte(1)),
        Node::assign("z_b2", header_byte(2)),
        Node::assign("ip", Expr::add(ip(), Expr::u32(3))),
        Node::assign("z_header", combined),
        Node::assign("z_last", last_bit),
        Node::assign("z_block_type", type_bits),
        Node::assign("z_block_size", size_bits),
    ];
    let truncated = vec![fail(STATUS_INPUT_TRUNCATED)];

    Node::if_then_else(has_three_bytes, parse, truncated)
}
/// Builds the loop that copies a raw block's payload from `input` to `out`.
///
/// The loop runs `raw_i` over `0..LOOP_LIMIT`; an iteration does work only while the
/// decoder is active, the block type is `BLOCK_RAW` and `raw_i < z_block_size`.
/// Each working iteration copies one byte and advances `ip` and `op`. If either
/// position is out of range the decoder fails: `STATUS_OUTPUT_OVERFLOW` when input is
/// still available, `STATUS_INPUT_TRUNCATED` otherwise.
pub fn decode_raw_block() -> Node {
    let is_raw = Expr::eq(Expr::var("z_block_type"), Expr::u32(BLOCK_RAW));
    let bytes_pending = Expr::lt(Expr::var("raw_i"), Expr::var("z_block_size"));
    let should_copy = and(and(active(), is_raw), bytes_pending);

    let in_bounds = and(
        Expr::lt(Expr::var("ip"), input_len()),
        Expr::lt(Expr::var("op"), output_len()),
    );

    let copy_one_byte = vec![
        store_byte("out", Expr::var("op"), read_byte("input", Expr::var("ip"))),
        Node::assign("ip", Expr::add(Expr::var("ip"), Expr::u32(1))),
        Node::assign("op", Expr::add(Expr::var("op"), Expr::u32(1))),
    ];

    // Input still available means the output side was the one that ran out.
    let report_failure = Node::if_then_else(
        Expr::lt(Expr::var("ip"), input_len()),
        vec![fail(STATUS_OUTPUT_OVERFLOW)],
        vec![fail(STATUS_INPUT_TRUNCATED)],
    );

    let step = Node::if_then_else(in_bounds, copy_one_byte, vec![report_failure]);

    Node::Loop {
        var: String::from("raw_i"),
        from: Expr::u32(0),
        to: Expr::u32(LOOP_LIMIT),
        body: vec![Node::if_then(should_copy, vec![step])],
    }
}
/// Builds the node that decodes an RLE block.
///
/// Active only when the decoder is running and the block type is `BLOCK_RLE`. If a
/// byte is available at `ip` it is bound to `rle_byte`, `ip` advances by one, and the
/// fill loop from [`rle_fill_loop`] writes the repeated byte. Otherwise the decoder
/// fails with `STATUS_INPUT_TRUNCATED`.
pub fn decode_rle_block() -> Node {
    let is_rle = Expr::eq(Expr::var("z_block_type"), Expr::u32(BLOCK_RLE));
    let byte_available = Expr::lt(Expr::var("ip"), input_len());

    let expand = vec![
        Node::let_bind("rle_byte", read_byte("input", Expr::var("ip"))),
        Node::assign("ip", Expr::add(Expr::var("ip"), Expr::u32(1))),
        rle_fill_loop(),
    ];
    let truncated = vec![fail(STATUS_INPUT_TRUNCATED)];

    let decode = Node::if_then_else(byte_available, expand, truncated);
    Node::if_then(and(active(), is_rle), vec![decode])
}
/// Builds the loop that writes `rle_byte` to `out` `z_block_size` times.
///
/// The loop runs `rle_i` over `0..LOOP_LIMIT`; an iteration does work only while the
/// decoder is active and `rle_i < z_block_size`. Each working iteration stores one
/// byte at `op` and advances `op`, or fails with `STATUS_OUTPUT_OVERFLOW` once `op`
/// reaches the output length. The generated nodes read the variable `rle_byte`, which
/// the caller must have bound beforehand.
pub fn rle_fill_loop() -> Node {
    let should_fill = and(
        active(),
        Expr::lt(Expr::var("rle_i"), Expr::var("z_block_size")),
    );
    let has_room = Expr::lt(Expr::var("op"), output_len());

    let emit_byte = vec![
        store_byte("out", Expr::var("op"), Expr::var("rle_byte")),
        Node::assign("op", Expr::add(Expr::var("op"), Expr::u32(1))),
    ];
    let overflow = vec![fail(STATUS_OUTPUT_OVERFLOW)];

    let step = Node::if_then_else(has_room, emit_byte, overflow);

    Node::Loop {
        var: String::from("rle_i"),
        from: Expr::u32(0),
        to: Expr::u32(LOOP_LIMIT),
        body: vec![Node::if_then(should_fill, vec![step])],
    }
}
/// Builds the node that fails with `STATUS_COMPRESSED_BLOCK_UNSUPPORTED` when the
/// decoder is active and the current block type equals `BLOCK_COMPRESSED`.
pub fn reject_compressed_block() -> Node {
    let is_compressed = Expr::eq(Expr::var("z_block_type"), Expr::u32(BLOCK_COMPRESSED));
    let guard = and(active(), is_compressed);
    Node::if_then(guard, vec![fail(STATUS_COMPRESSED_BLOCK_UNSUPPORTED)])
}
/// Builds the node that fails with `STATUS_RESERVED_BLOCK_TYPE` when the decoder is
/// active and the current block type is greater than `BLOCK_COMPRESSED`.
pub fn reject_reserved_block() -> Node {
    let is_reserved = Expr::gt(Expr::var("z_block_type"), Expr::u32(BLOCK_COMPRESSED));
    let guard = and(active(), is_reserved);
    Node::if_then(guard, vec![fail(STATUS_RESERVED_BLOCK_TYPE)])
}
/// Builds an expression that extracts the byte at byte position `pos` from a buffer
/// of 32-bit words.
///
/// The word at index `pos >> 2` is loaded, shifted right by `(pos & 3) * 8` bits and
/// masked with `0xff`, so byte 0 of each word is its least significant byte.
pub fn read_byte(buffer: &'static str, pos: Expr) -> Expr {
    let lane = Expr::bitand(pos.clone(), Expr::u32(3));
    let word_index = Expr::shr(pos, Expr::u32(2));
    let bit_offset = Expr::mul(lane, Expr::u32(8));
    let shifted = Expr::shr(Expr::load(buffer, word_index), bit_offset);
    Expr::bitand(shifted, Expr::u32(0xff))
}
/// Builds a store node that writes `byte` at byte position `pos` of a buffer of
/// 32-bit words, leaving the other three bytes of the word untouched.
///
/// The node is a read-modify-write: the word at index `pos >> 2` is loaded, the
/// target byte lane (`0xff << ((pos & 3) * 8)`) is cleared, and the low 8 bits of
/// `byte` are shifted into that lane before the word is stored back.
pub fn store_byte(buffer: &'static str, pos: Expr, byte: Expr) -> Node {
    let lane = Expr::bitand(pos.clone(), Expr::u32(3));
    let word_index = Expr::shr(pos, Expr::u32(2));
    let bit_offset = Expr::mul(lane, Expr::u32(8));
    let lane_mask = Expr::shl(Expr::u32(0xff), bit_offset.clone());

    // Existing word with the target lane zeroed.
    let kept = Expr::bitand(
        Expr::load(buffer, word_index.clone()),
        Expr::bitnot(lane_mask.clone()),
    );
    // New byte moved into the target lane and confined to it.
    let placed = Expr::bitand(
        Expr::shl(Expr::bitand(byte, Expr::u32(0xff)), bit_offset),
        lane_mask,
    );

    Node::store(buffer, word_index, Expr::bitor(kept, placed))
}
/// Builds a block that records `code` in `status[0]` and then clears `is_active`,
/// which disables every later step guarded by `active()`.
pub fn fail(code: u32) -> Node {
    let record_code = Node::store("status", Expr::u32(0), Expr::u32(code));
    let deactivate = Node::assign("is_active", Expr::u32(0));
    Node::Block(vec![record_code, deactivate])
}
/// Builds the expression that loads the input length from `params[0]`.
pub fn input_len() -> Expr {
    let slot = Expr::u32(0);
    Expr::load("params", slot)
}
/// Builds the expression that loads the output length from `params[1]`.
pub fn output_len() -> Expr {
    let slot = Expr::u32(1);
    Expr::load("params", slot)
}
/// Builds the expression `is_active != 0`, the guard used to skip work once the
/// decoder has finished or failed.
pub fn active() -> Expr {
    let flag = Expr::var("is_active");
    Expr::ne(flag, Expr::u32(0))
}
/// Combines two condition expressions with `Expr::bitand`.
///
/// NOTE(review): presumably both operands are comparison results for which a bitwise
/// AND behaves as a logical AND (without short-circuiting) — confirm against the
/// `Expr` semantics.
pub fn and(left: Expr, right: Expr) -> Expr {
Expr::bitand(left, right)
}