use crate::ir::{BufferDecl, DataType, Expr, Node, Program};
use crate::ops::{AlgebraicLaw, OpSpec};
/// Data types of the op's inputs, handed to `OpSpec::composition` by `Lz4Decompress::SPEC`.
pub const INPUTS: &[DataType] = &[DataType::U32, DataType::U32];
/// Data types of the op's outputs, handed to `OpSpec::composition` by `Lz4Decompress::SPEC`.
pub const OUTPUTS: &[DataType] = &[DataType::U32, DataType::U32];
/// Algebraic laws declared for the op: a single `Bounded` law covering `0..=u32::MAX`.
pub const LAWS: &[AlgebraicLaw] = &[AlgebraicLaw::Bounded {
lo: 0,
hi: u32::MAX,
}];
/// NOTE(review): not referenced in this part of the file; presumably the largest
/// output-to-input size ratio callers should allow — confirm against the users of this constant.
pub const MAX_OUTPUT_RATIO: u32 = 1024;
/// Workgroup size passed to `Program::new` by `lz4_decompress_program`.
pub const WORKGROUP_SIZE: [u32; 3] = [1, 1, 1];
/// Upper bound of every `Node::Loop` the program emits (sequence loop, length-extension
/// loops, literal copy loop and match copy loop).
pub const LOOP_LIMIT: u32 = 65_536;
/// Status code stored in `status[0]` when the input ends before a byte the decoder needs
/// (length-extension byte, literal byte, or the second match-offset byte).
pub const STATUS_INPUT_TRUNCATED: u32 = 1;
/// Status code stored in `status[0]` when a literal or match byte would be written at or
/// beyond the output length read from `params[1]`.
pub const STATUS_OUTPUT_OVERFLOW: u32 = 2;
/// Status code stored in `status[0]` when a match offset is zero or larger than the number
/// of bytes produced so far.
pub const STATUS_INVALID_OFFSET: u32 = 3;
/// Status code stored in `status[0]` when one of the bounded loops hits `LOOP_LIMIT`
/// before its work is finished.
pub const STATUS_LOOP_LIMIT_EXCEEDED: u32 = 4;
/// Identifier of this op; used for the `OpSpec` and as the program's entry op id.
pub const OP_ID: &str = "compression.lz4_decompress";
/// Unit marker type for the `compression.lz4_decompress` op; its associated items expose
/// the op specification and the program builder.
#[derive(Debug, Clone, Copy, Default)]
pub struct Lz4Decompress;
impl Lz4Decompress {
/// Op specification built with `OpSpec::composition` from `OP_ID`, `INPUTS`, `OUTPUTS`,
/// `LAWS` and the [`Self::program`] builder function.
pub const SPEC: OpSpec = OpSpec::composition(OP_ID, INPUTS, OUTPUTS, LAWS, Self::program);
/// Returns the program produced by `lz4_decompress_program`.
#[must_use]
pub fn program() -> Program {
lz4_decompress_program()
}
}
/// Builds the IR program that decodes an LZ4 block.
///
/// Buffers declared (name, binding):
/// * `input` (0) – read-only; accessed one byte at a time through `read_byte`.
/// * `out` (1) – output; written one byte at a time through `store_byte`.
/// * `params` (2) – read-only; `params[0]` is the input length and `params[1]` the
///   output length, both compared against byte positions.
/// * `status` (3) – read/write; `status[0]` receives the status code (0 unless a
///   failure is recorded) and `status[1]` receives the final output position `op`.
///
/// The program runs at most `LOOP_LIMIT` sequences. If it is still active with unread
/// input after that, `STATUS_LOOP_LIMIT_EXCEEDED` is recorded.
#[must_use]
pub fn lz4_decompress_program() -> Program {
    let zero = || Expr::u32(0);

    let buffers = vec![
        BufferDecl::read("input", 0, DataType::U32),
        BufferDecl::output("out", 1, DataType::U32),
        BufferDecl::read("params", 2, DataType::U32),
        BufferDecl::read_write("status", 3, DataType::U32),
    ];

    // One iteration per sequence; iterations become no-ops once `is_active` is cleared.
    let sequence_loop = Node::Loop {
        var: String::from("sequence"),
        from: zero(),
        to: Expr::u32(LOOP_LIMIT),
        body: vec![Node::if_then(active(), lz4_sequence_body())],
    };

    // Still active with input left over means the sequence budget ran out.
    let input_left = Expr::lt(Expr::var("ip"), input_len());
    let budget_check = Node::if_then(
        and(active(), input_left),
        vec![fail(STATUS_LOOP_LIMIT_EXCEEDED)],
    );

    let body = vec![
        Node::let_bind("ip", zero()),
        Node::let_bind("op", zero()),
        Node::let_bind("is_active", Expr::u32(1)),
        Node::store("status", zero(), zero()),
        sequence_loop,
        budget_check,
        Node::store("status", Expr::u32(1), Expr::var("op")),
    ];

    Program::new(buffers, WORKGROUP_SIZE, body).with_entry_op_id(OP_ID)
}
pub fn lz4_sequence_body() -> Vec<Node> {
vec![Node::if_then_else(
Expr::lt(Expr::var("ip"), input_len()),
vec![
Node::let_bind("token", read_byte("input", Expr::var("ip"))),
Node::assign("ip", Expr::add(Expr::var("ip"), Expr::u32(1))),
Node::let_bind("lit_len", Expr::shr(Expr::var("token"), Expr::u32(4))),
Node::let_bind("lit_ext", Expr::eq(Expr::var("lit_len"), Expr::u32(15))),
extend_length("lit", "lit_len", "lit_ext"),
Node::if_then(
and(active(), Expr::var("lit_ext")),
vec![fail(STATUS_LOOP_LIMIT_EXCEEDED)],
),
copy_literals(),
Node::if_then(
and(
active(),
Expr::gt(Expr::var("lit_len"), Expr::u32(LOOP_LIMIT)),
),
vec![fail(STATUS_LOOP_LIMIT_EXCEEDED)],
),
Node::let_bind("offset", Expr::u32(0)),
read_match_offset(),
Node::if_then(
active(),
vec![
Node::let_bind(
"match_len",
Expr::add(
Expr::bitand(Expr::var("token"), Expr::u32(0x0f)),
Expr::u32(4),
),
),
Node::let_bind(
"match_ext",
Expr::eq(
Expr::bitand(Expr::var("token"), Expr::u32(0x0f)),
Expr::u32(15),
),
),
extend_length("match", "match_len", "match_ext"),
Node::if_then(
and(active(), Expr::var("match_ext")),
vec![fail(STATUS_LOOP_LIMIT_EXCEEDED)],
),
validate_match_offset(),
copy_match(),
Node::if_then(
and(
active(),
Expr::gt(Expr::var("match_len"), Expr::u32(LOOP_LIMIT)),
),
vec![fail(STATUS_LOOP_LIMIT_EXCEEDED)],
),
],
),
],
vec![Node::assign("is_active", Expr::u32(0))],
)]
}
/// Builds the bounded loop that applies length-extension bytes to `len_var`.
///
/// * `prefix` – used to name the loop counter (`{prefix}_ext_i`) and the temporary
///   holding the byte just read (`{prefix}_ext_byte`).
/// * `len_var` – variable that accumulates the length.
/// * `flag_var` – variable that is true while another extension byte is expected.
///
/// Each active iteration with the flag set reads one input byte, advances `ip`, adds
/// the byte to `len_var` and keeps the flag set only if the byte was 255. Running out
/// of input records `STATUS_INPUT_TRUNCATED`. The loop runs at most `LOOP_LIMIT` times.
pub fn extend_length(prefix: &'static str, len_var: &'static str, flag_var: &'static str) -> Node {
    let byte_name = format!("{prefix}_ext_byte");
    let counter_name = format!("{prefix}_ext_i");

    let consume_byte = vec![
        Node::let_bind(&byte_name, read_byte("input", Expr::var("ip"))),
        Node::assign("ip", Expr::add(Expr::var("ip"), Expr::u32(1))),
        Node::assign(len_var, Expr::add(Expr::var(len_var), Expr::var(&byte_name))),
        Node::assign(flag_var, Expr::eq(Expr::var(&byte_name), Expr::u32(255))),
    ];

    let step = Node::if_then_else(
        Expr::lt(Expr::var("ip"), input_len()),
        consume_byte,
        vec![fail(STATUS_INPUT_TRUNCATED)],
    );

    let wants_more = and(active(), Expr::var(flag_var));

    Node::Loop {
        var: counter_name,
        from: Expr::u32(0),
        to: Expr::u32(LOOP_LIMIT),
        body: vec![Node::if_then(wants_more, vec![step])],
    }
}
/// Builds the bounded loop that copies `lit_len` literal bytes from `input` to `out`.
///
/// While active and `lit_i < lit_len`, each iteration copies one byte and advances both
/// `ip` and `op`. If the copy cannot proceed, the failure is `STATUS_OUTPUT_OVERFLOW`
/// when input is still available, otherwise `STATUS_INPUT_TRUNCATED`.
pub fn copy_literals() -> Node {
    let ip = || Expr::var("ip");
    let op = || Expr::var("op");
    let one = || Expr::u32(1);
    let input_available = || Expr::lt(ip(), input_len());

    let copy_one = vec![
        store_byte("out", op(), read_byte("input", ip())),
        Node::assign("ip", Expr::add(ip(), one())),
        Node::assign("op", Expr::add(op(), one())),
    ];

    // Input still available means the output side was the one that ran out.
    let report_failure = vec![Node::if_then_else(
        input_available(),
        vec![fail(STATUS_OUTPUT_OVERFLOW)],
        vec![fail(STATUS_INPUT_TRUNCATED)],
    )];

    let step = Node::if_then_else(
        and(input_available(), Expr::lt(op(), output_len())),
        copy_one,
        report_failure,
    );

    let pending = Expr::lt(Expr::var("lit_i"), Expr::var("lit_len"));

    Node::Loop {
        var: "lit_i".to_owned(),
        from: Expr::u32(0),
        to: Expr::u32(LOOP_LIMIT),
        body: vec![Node::if_then(and(active(), pending), vec![step])],
    }
}
/// Builds the node that reads the two-byte little-endian match offset into `offset`.
///
/// Only acts while the decoder is active:
/// * `ip` equal to the input length clears `is_active` without recording an error
///   (the sequence carried literals only);
/// * two bytes available: `offset = lo | (hi << 8)` and `ip` advances by 2;
/// * otherwise `STATUS_INPUT_TRUNCATED` is recorded.
pub fn read_match_offset() -> Node {
    let ip = || Expr::var("ip");
    let next_ip = || Expr::add(ip(), Expr::u32(1));

    let combined = Expr::bitor(
        Expr::var("offset_lo"),
        Expr::shl(Expr::var("offset_hi"), Expr::u32(8)),
    );
    let decode = vec![
        Node::let_bind("offset_lo", read_byte("input", ip())),
        Node::let_bind("offset_hi", read_byte("input", next_ip())),
        Node::assign("offset", combined),
        Node::assign("ip", Expr::add(ip(), Expr::u32(2))),
    ];

    // Both offset bytes must lie inside the input.
    let read_or_fail = Node::if_then_else(
        Expr::lt(next_ip(), input_len()),
        decode,
        vec![fail(STATUS_INPUT_TRUNCATED)],
    );

    let at_end = Expr::eq(ip(), input_len());
    let finish = vec![Node::assign("is_active", Expr::u32(0))];

    Node::if_then(
        active(),
        vec![Node::if_then_else(at_end, finish, vec![read_or_fail])],
    )
}
/// Builds the node that rejects an unusable match offset.
///
/// While active, records `STATUS_INVALID_OFFSET` when `offset` is zero or greater than
/// the current output position `op`.
pub fn validate_match_offset() -> Node {
    let offset = || Expr::var("offset");

    let is_zero = Expr::eq(offset(), Expr::u32(0));
    let reaches_before_start = Expr::gt(offset(), Expr::var("op"));

    let reject = Node::if_then(
        or(is_zero, reaches_before_start),
        vec![fail(STATUS_INVALID_OFFSET)],
    );

    Node::if_then(active(), vec![reject])
}
/// Builds the bounded loop that copies `match_len` bytes from earlier output.
///
/// While active and `match_i < match_len`, each iteration reads the byte at
/// `op - offset` from `out`, stores it at `op` and advances `op`. The source byte is
/// re-read from `out` on every iteration, so it can be a byte stored by an earlier
/// iteration of this same loop. Writing at or past the output length records
/// `STATUS_OUTPUT_OVERFLOW`.
pub fn copy_match() -> Node {
    let op = || Expr::var("op");

    let source_pos = Expr::sub(op(), Expr::var("offset"));
    let copy_one = vec![
        Node::let_bind("match_byte", read_byte("out", source_pos)),
        store_byte("out", op(), Expr::var("match_byte")),
        Node::assign("op", Expr::add(op(), Expr::u32(1))),
    ];

    let step = Node::if_then_else(
        Expr::lt(op(), output_len()),
        copy_one,
        vec![fail(STATUS_OUTPUT_OVERFLOW)],
    );

    let pending = Expr::lt(Expr::var("match_i"), Expr::var("match_len"));

    Node::Loop {
        var: "match_i".to_owned(),
        from: Expr::u32(0),
        to: Expr::u32(LOOP_LIMIT),
        body: vec![Node::if_then(and(active(), pending), vec![step])],
    }
}
/// Builds an expression that extracts the byte at byte position `pos` from `buffer`.
///
/// The word at index `pos >> 2` is loaded, shifted right by `(pos & 3) * 8` bits and
/// masked with `0xff`, so byte 0 of each word is its least significant byte.
pub fn read_byte(buffer: &'static str, pos: Expr) -> Expr {
    let lane = Expr::bitand(pos.clone(), Expr::u32(3));
    let word = Expr::load(buffer, Expr::shr(pos, Expr::u32(2)));
    let shifted = Expr::shr(word, Expr::mul(lane, Expr::u32(8)));
    Expr::bitand(shifted, Expr::u32(0xff))
}
/// Builds a store that writes `byte` at byte position `pos` of `buffer`.
///
/// This is a read-modify-write of the word at index `pos >> 2`: the target byte lane
/// (`(pos & 3) * 8` bits up) is cleared, then the low 8 bits of `byte` are inserted
/// into that lane. The other three bytes of the word are kept as loaded.
pub fn store_byte(buffer: &'static str, pos: Expr, byte: Expr) -> Node {
    let lane_shift = Expr::mul(Expr::bitand(pos.clone(), Expr::u32(3)), Expr::u32(8));
    let target_word = Expr::shr(pos, Expr::u32(2));
    let lane_mask = Expr::shl(Expr::u32(0xff), lane_shift.clone());

    let kept = Expr::bitand(
        Expr::load(buffer, target_word.clone()),
        Expr::bitnot(lane_mask.clone()),
    );
    let placed = Expr::bitand(
        Expr::shl(Expr::bitand(byte, Expr::u32(0xff)), lane_shift),
        lane_mask,
    );

    Node::store(buffer, target_word, Expr::bitor(kept, placed))
}
/// Builds a block that records `code` in `status[0]` and clears `is_active`, which
/// turns every later `active()`-guarded node into a no-op.
pub fn fail(code: u32) -> Node {
    let record_status = Node::store("status", Expr::u32(0), Expr::u32(code));
    let deactivate = Node::assign("is_active", Expr::u32(0));
    Node::Block(vec![record_status, deactivate])
}
/// Expression loading the input length from `params[0]`; the program compares it
/// against the byte position `ip`.
pub fn input_len() -> Expr {
    let slot = Expr::u32(0);
    Expr::load("params", slot)
}
/// Expression loading the output length from `params[1]`; the program compares it
/// against the byte position `op`.
pub fn output_len() -> Expr {
    let slot = Expr::u32(1);
    Expr::load("params", slot)
}
/// Expression that holds while the `is_active` variable is non-zero, i.e. until the
/// decoder has finished or a failure has been recorded.
pub fn active() -> Expr {
    let flag = Expr::var("is_active");
    Expr::ne(flag, Expr::u32(0))
}
/// Combines two condition expressions with `Expr::bitand`.
///
/// This is a bitwise AND node, not a short-circuit operator: both operands are part of
/// the resulting expression.
pub fn and(left: Expr, right: Expr) -> Expr {
Expr::bitand(left, right)
}
/// Combines two condition expressions with `Expr::bitor`.
///
/// This is a bitwise OR node, not a short-circuit operator: both operands are part of
/// the resulting expression.
pub fn or(left: Expr, right: Expr) -> Expr {
Expr::bitor(left, right)
}