use crate::ir::{BufferDecl, DataType, Expr, Node, Program};
use crate::ops::AlgebraicLaw;
use crate::ops::{OpSpec, BYTES_TO_U32_OUTPUTS, U32_INPUTS};
/// Zero-sized handle for the [`OP_ID`] operation.
///
/// Carries no data; its op spec and IR builder live in the inherent
/// `impl DfaScan` block (`DfaScan::SPEC`, `DfaScan::program`).
#[derive(Clone, Copy, Debug, Default)]
pub struct DfaScan;
impl DfaScan {
pub const SPEC: OpSpec =
OpSpec::composition(OP_ID, U32_INPUTS, BYTES_TO_U32_OUTPUTS, LAWS, Self::program);
#[must_use]
pub fn program() -> Program {
let pos = Expr::var("pos");
let state = Expr::var("state");
let pattern_id = Expr::var("pattern_id");
Program::new(
vec![
BufferDecl::read("input_bytes", 0, DataType::U32),
BufferDecl::read("transitions", 1, DataType::U32),
BufferDecl::read("accept_map", 2, DataType::U32),
BufferDecl::output("matches", 3, DataType::U32),
BufferDecl::read_write("match_count", 4, DataType::U32),
BufferDecl::read("params", 5, DataType::U32),
BufferDecl::read("pattern_lengths", 6, DataType::U32),
],
WORKGROUP_SIZE,
vec![
Node::let_bind("worker", Expr::gid_x()),
Node::if_then(
Expr::eq(Expr::var("worker"), Expr::u32(0)),
vec![
Node::let_bind("state", Expr::u32(0)),
Node::let_bind("pos", Expr::u32(0)),
Node::Loop {
var: "step".to_string(),
from: Expr::u32(0),
to: Expr::load("params", Expr::u32(0)),
body: vec![
Node::let_bind("byte_val", read_byte(pos.clone())),
Node::assign(
"state",
transition(state.clone(), Expr::var("byte_val")),
),
Node::let_bind(
"pattern_id",
Expr::load("accept_map", state.clone()),
),
Node::if_then(
Expr::ne(pattern_id.clone(), Expr::u32(NO_MATCH)),
vec![report_match(
pattern_id,
Expr::add(pos.clone(), Expr::u32(1)),
)],
),
Node::assign("pos", Expr::add(pos, Expr::u32(1))),
],
},
],
),
],
)
.with_entry_op_id(OP_ID)
}
}
/// Algebraic laws declared for this op: a single `Bounded` law spanning `0..=u32::MAX`.
///
/// NOTE(review): this range covers every `u32` value, so the law presumably records
/// "output is a plain u32" rather than a tighter bound — confirm against `AlgebraicLaw`.
pub const LAWS: &[AlgebraicLaw] = &[AlgebraicLaw::Bounded { lo: 0, hi: u32::MAX }];
/// Operation id passed to `OpSpec::composition` and attached to the built program via
/// `with_entry_op_id`.
pub const OP_ID: &str = "match.dfa_scan";
/// Sentinel `accept_map` entry for a state that accepts no pattern (all bits set, i.e.
/// `u32::MAX`); the scan reports a match only when the entry differs from this value.
pub const NO_MATCH: u32 = 0xFFFF_FFFF;
/// Builds an expression that extracts the byte at index `pos` from `input_bytes`.
///
/// Bytes are packed four per `u32` word. Byte `pos` lives in word `pos >> 2` at bit
/// offset `(pos & 3) * 8`, so the lowest-addressed byte occupies the least significant
/// bits (little-endian within a word). The result is masked to `0..=0xff`.
pub fn read_byte(pos: Expr) -> Expr {
    let word_index = Expr::shr(pos.clone(), Expr::u32(2));
    let packed = Expr::load("input_bytes", word_index);
    let bit_offset = Expr::mul(Expr::bitand(pos, Expr::u32(3)), Expr::u32(8));
    Expr::bitand(Expr::shr(packed, bit_offset), Expr::u32(0xff))
}
/// Builds the IR that records one match of `pattern_id` ending (exclusively) at `end_pos`.
///
/// The emitted block:
/// 1. binds `pattern_len = pattern_lengths[pattern_id]`;
/// 2. binds `match_idx` to the result of an atomic add of 1 on `match_count[0]`. The
///    value is presumably the pre-increment count; confirm against the IR's `atomic_add`;
/// 3. if `match_idx < params[2]`, writes the three-word record
///    `[pattern_id, end_pos - pattern_len, end_pos]` to `matches[3 * match_idx ..]`.
///
/// The counter is bumped even when the record is dropped for lack of room, so
/// `match_count` can exceed `params[2]`. There is no guard against
/// `pattern_len > end_pos` in the start computation.
pub fn report_match(pattern_id: Expr, end_pos: Expr) -> Node {
    // First word of this match's three-word record.
    let record_base = || Expr::mul(Expr::var("match_idx"), Expr::u32(3));

    let length_lookup = Expr::load("pattern_lengths", pattern_id.clone());
    let claim_slot = Expr::atomic_add("match_count", Expr::u32(0), Expr::u32(1));
    let has_room = Expr::lt(Expr::var("match_idx"), Expr::load("params", Expr::u32(2)));
    let start_pos = Expr::sub(end_pos.clone(), Expr::var("pattern_len"));

    let write_record = vec![
        Node::store("matches", record_base(), pattern_id),
        Node::store("matches", Expr::add(record_base(), Expr::u32(1)), start_pos),
        Node::store("matches", Expr::add(record_base(), Expr::u32(2)), end_pos),
    ];

    Node::Block(vec![
        Node::let_bind("pattern_len", length_lookup),
        Node::let_bind("match_idx", claim_slot),
        Node::if_then(has_room, write_record),
    ])
}
/// Builds an expression for the DFA's next state after reading `byte` in `state`.
///
/// `transitions` is laid out row-major with 256 entries per state, so the lookup index
/// is `state * 256 + byte`.
pub fn transition(state: Expr, byte: Expr) -> Expr {
    let row_start = Expr::mul(state, Expr::u32(256));
    Expr::load("transitions", Expr::add(row_start, byte))
}
/// Workgroup dimensions for the scan program: a single invocation per workgroup on every
/// axis.
pub const WORKGROUP_SIZE: [u32; 3] = [1; 3];