use vyre::ir::{BufferAccess, BufferDecl, DataType, Expr, Node, Program};
use vyre_foundation::execution_plan::fusion::{fuse_programs_vec, FusionError};
use crate::region::wrap_anonymous;
// Operation ids under which the emit and compaction bodies are wrapped
// (`wrap_anonymous`) and registered with the op harness.
const EMIT_HIT_OP_ID: &str = "vyre-libs::matching::emit_hit";
const COMPACT_HITS_OP_ID: &str = "vyre-libs::matching::compact_hits";
// Default layout used by `emit_hit` / `emit_hit_then_compact`: the declared
// element count of each input buffer, and the number of 4-word hit records the
// output hit buffer is declared to hold.
const DEFAULT_LANES: u32 = 4;
const DEFAULT_MAX_HITS: u32 = 4;
/// Name of the single-word read-write buffer into which `emit_hit_with_layout`
/// stores how many input records did not fit in the output hit buffer.
pub const HIT_BUFFER_OVERFLOW_COUNT: &str = "hit_buffer_overflow_count";
/// Name of the single-word output buffer into which `compact_hits_with_layout`
/// stores the number of live hit records.
pub const HIT_BUFFER_LIVE_LENGTH: &str = "hit_buffer_live_length";
/// Builds the hit-emission program using the module's default layout.
///
/// This is [`emit_hit_with_layout`] with `DEFAULT_LANES` declared input lanes
/// and room for `DEFAULT_MAX_HITS` hit records.
#[must_use]
pub fn emit_hit(
    rule_id: &str,
    file_id: &str,
    span_start: &str,
    span_len: &str,
    out_hits: &str,
    out_cursor: &str,
) -> Program {
    let (lane_count, max_hits) = (DEFAULT_LANES, DEFAULT_MAX_HITS);
    emit_hit_with_layout(
        rule_id, file_id, span_start, span_len, out_hits, out_cursor, lane_count, max_hits,
    )
}
/// Builds a program that packs per-lane hit records into a flat `u32` buffer.
///
/// Each invocation binds its id along axis 0 as `lane`. Some lanes are below
/// both the number of input records (`buf_len(rule_id)`) and the record capacity
/// of the output (`buf_len(out_hits) / 4`). Each such lane writes the 4-word record
/// `[rule_id[lane], file_id[lane], span_start[lane], span_len[lane]]` starting
/// at word `lane * 4` of `out_hits`.
///
/// Lane 0 additionally publishes:
/// * `out_cursor[0]`: the number of records written,
///   `min(buf_len(rule_id), buf_len(out_hits) / 4)`;
/// * [`HIT_BUFFER_OVERFLOW_COUNT`]`[0]`: the number of input records that did
///   not fit in `out_hits` (`0` when all of them fit).
///
/// `lane_count` is the declared element count of the four input buffers.
/// `max_hits` is the declared record capacity of `out_hits` (`max_hits * 4`
/// words, saturating). The program uses a `[64, 1, 1]` workgroup.
#[must_use]
#[allow(clippy::too_many_arguments)] // one parameter per bound buffer plus two layout knobs
pub fn emit_hit_with_layout(
    rule_id: &str,
    file_id: &str,
    span_start: &str,
    span_len: &str,
    out_hits: &str,
    out_cursor: &str,
    lane_count: u32,
    max_hits: u32,
) -> Program {
    let lane = Expr::var("lane");
    let base = Expr::mul(lane.clone(), Expr::u32(4));
    // Whole 4-word records that physically fit in `out_hits`.
    let max_capacity = Expr::div(Expr::buf_len(out_hits), Expr::u32(4));
    // Records that exist in the inputs AND fit in the output. One shared bound
    // for the lane guard, the cursor and the overflow count keeps all three
    // consistent with each other.
    let written = Expr::min(Expr::buf_len(rule_id), max_capacity);
    let body = vec![
        Node::let_bind("lane", Expr::InvocationId { axis: 0 }),
        Node::if_then(
            Expr::lt(lane.clone(), written.clone()),
            vec![
                Node::store(out_hits, base.clone(), Expr::load(rule_id, lane.clone())),
                Node::store(
                    out_hits,
                    Expr::add(base.clone(), Expr::u32(1)),
                    Expr::load(file_id, lane.clone()),
                ),
                Node::store(
                    out_hits,
                    Expr::add(base.clone(), Expr::u32(2)),
                    Expr::load(span_start, lane.clone()),
                ),
                Node::store(
                    out_hits,
                    Expr::add(base, Expr::u32(3)),
                    Expr::load(span_len, lane.clone()),
                ),
            ],
        ),
        Node::if_then(
            Expr::eq(lane, Expr::u32(0)),
            vec![
                Node::store(out_cursor, Expr::u32(0), written.clone()),
                // Stored unconditionally so that a reused buffer never keeps a
                // stale count from an earlier dispatch. Because
                // `written <= buf_len(rule_id)`, the u32 subtraction cannot
                // wrap; it is 0 when every record fit.
                Node::store(
                    HIT_BUFFER_OVERFLOW_COUNT,
                    Expr::u32(0),
                    Expr::sub(Expr::buf_len(rule_id), written),
                ),
            ],
        ),
    ];
    Program::wrapped(
        vec![
            BufferDecl::storage(rule_id, 0, BufferAccess::ReadOnly, DataType::U32)
                .with_count(lane_count),
            BufferDecl::storage(file_id, 1, BufferAccess::ReadOnly, DataType::U32)
                .with_count(lane_count),
            BufferDecl::storage(span_start, 2, BufferAccess::ReadOnly, DataType::U32)
                .with_count(lane_count),
            BufferDecl::storage(span_len, 3, BufferAccess::ReadOnly, DataType::U32)
                .with_count(lane_count),
            BufferDecl::output(out_hits, 4, DataType::U32).with_count(max_hits.saturating_mul(4)),
            BufferDecl::read_write(out_cursor, 5, DataType::U32).with_count(1),
            BufferDecl::read_write(HIT_BUFFER_OVERFLOW_COUNT, 6, DataType::U32).with_count(1),
        ],
        [64, 1, 1],
        vec![wrap_anonymous(EMIT_HIT_OP_ID, body)],
    )
}
/// Builds the hit-compaction program.
///
/// `max_capacity` is used both as the declared record capacity of `out_hits`
/// and as the upper clamp on the reported live length. See
/// [`compact_hits_with_layout`] to set those independently.
#[must_use]
pub fn compact_hits(out_hits: &str, out_cursor: &str, max_capacity: u32) -> Program {
    let hit_capacity = max_capacity;
    compact_hits_with_layout(out_hits, out_cursor, hit_capacity, max_capacity)
}
/// Builds a program that publishes how many hit records are live.
///
/// Invocation 0 (id along axis 0) reads `out_cursor[0]` and stores
/// `min(cursor, min(max_capacity, buf_len(out_hits) / 4))` into
/// [`HIT_BUFFER_LIVE_LENGTH`]`[0]`. Every other invocation does nothing.
/// `hit_capacity` only sets the declared size of `out_hits` (`hit_capacity * 4`
/// words, saturating). The program uses a `[1, 1, 1]` workgroup.
#[must_use]
pub fn compact_hits_with_layout(
    out_hits: &str,
    out_cursor: &str,
    hit_capacity: u32,
    max_capacity: u32,
) -> Program {
    let buffers = vec![
        BufferDecl::storage(out_hits, 0, BufferAccess::ReadOnly, DataType::U32)
            .with_count(hit_capacity.saturating_mul(4)),
        BufferDecl::storage(out_cursor, 1, BufferAccess::ReadOnly, DataType::U32).with_count(1),
        BufferDecl::output(HIT_BUFFER_LIVE_LENGTH, 2, DataType::U32).with_count(1),
    ];

    // Whole 4-word records that physically fit in `out_hits`.
    let records_in_buffer = Expr::div(Expr::buf_len(out_hits), Expr::u32(4));
    // The cursor may run past either the caller's cap or the real buffer size.
    let upper_bound = Expr::min(Expr::u32(max_capacity), Expr::var("buffer_cap"));
    let clamped_cursor = Expr::min(Expr::var("cursor"), upper_bound);

    let lane_zero_work = vec![
        Node::let_bind("cursor", Expr::load(out_cursor, Expr::u32(0))),
        Node::let_bind("buffer_cap", records_in_buffer),
        Node::let_bind("live_len", clamped_cursor),
        Node::store(HIT_BUFFER_LIVE_LENGTH, Expr::u32(0), Expr::var("live_len")),
    ];
    let is_lane_zero = Expr::eq(Expr::var("lane"), Expr::u32(0));
    let body = vec![
        Node::let_bind("lane", Expr::InvocationId { axis: 0 }),
        Node::if_then(is_lane_zero, lane_zero_work),
    ];

    Program::wrapped(
        buffers,
        [1, 1, 1],
        vec![wrap_anonymous(COMPACT_HITS_OP_ID, body)],
    )
}
/// Builds [`emit_hit_with_layout`] followed by [`compact_hits_with_layout`] and
/// fuses them into a single program.
///
/// The compaction step receives `max_hits` as both its declared hit capacity
/// and its live-length clamp. Its declaration of `out_hits` therefore matches
/// the emit step's (`max_hits * 4` words).
///
/// # Errors
///
/// Returns the `FusionError` produced by `fuse_programs_vec` when the two
/// programs cannot be fused.
// `Result` is already `#[must_use]`; an extra attribute trips `clippy::double_must_use`.
#[allow(clippy::too_many_arguments)] // mirrors emit_hit_with_layout's buffer/layout parameters
pub fn emit_hit_then_compact_with_layout(
    rule_id: &str,
    file_id: &str,
    span_start: &str,
    span_len: &str,
    out_hits: &str,
    out_cursor: &str,
    lane_count: u32,
    max_hits: u32,
) -> Result<Program, FusionError> {
    let emit = emit_hit_with_layout(
        rule_id, file_id, span_start, span_len, out_hits, out_cursor, lane_count, max_hits,
    );
    let compact = compact_hits_with_layout(out_hits, out_cursor, max_hits, max_hits);
    fuse_programs_vec(vec![emit, compact])
}
/// Fused emit-then-compact program using the module's default layout
/// (`DEFAULT_LANES` input lanes, `DEFAULT_MAX_HITS` hit records).
///
/// # Errors
///
/// Propagates the `FusionError` from
/// [`emit_hit_then_compact_with_layout`] when fusion fails.
// `Result` is already `#[must_use]`; an extra attribute trips `clippy::double_must_use`.
pub fn emit_hit_then_compact(
    rule_id: &str,
    file_id: &str,
    span_start: &str,
    span_len: &str,
    out_hits: &str,
    out_cursor: &str,
) -> Result<Program, FusionError> {
    emit_hit_then_compact_with_layout(
        rule_id,
        file_id,
        span_start,
        span_len,
        out_hits,
        out_cursor,
        DEFAULT_LANES,
        DEFAULT_MAX_HITS,
    )
}
/// Harness inputs for the `emit_hit` entry: one dispatch with four lanes of
/// rule ids, file ids, span starts and span lengths.
///
/// The four lanes are followed by two single zero words. These presumably seed
/// the read-write `out_cursor` and overflow-count buffers (TODO confirm against
/// the harness's input ordering).
fn emit_hit_inputs() -> Vec<Vec<Vec<u8>>> {
    let rule_ids: [u32; 4] = [7, 9, 11, 13];
    let file_ids: [u32; 4] = [101, 103, 107, 109];
    let span_starts: [u32; 4] = [5, 9, 13, 17];
    let span_lens: [u32; 4] = [2, 4, 6, 8];
    let zero_word: [u32; 1] = [0];

    let dispatch = vec![
        pack_words(&rule_ids),
        pack_words(&file_ids),
        pack_words(&span_starts),
        pack_words(&span_lens),
        pack_words(&zero_word),
        pack_words(&zero_word),
    ];
    vec![dispatch]
}
/// Expected harness outputs for the `emit_hit` entry.
///
/// The first buffer holds each lane's `[rule, file, span_start, span_len]`
/// record laid out back to back. It is followed by a cursor of 4 and an
/// overflow count of 0, since all four records fit.
fn emit_hit_expected_output() -> Vec<Vec<Vec<u8>>> {
    let records: [[u32; 4]; 4] = [
        [7, 101, 5, 2],
        [9, 103, 9, 4],
        [11, 107, 13, 6],
        [13, 109, 17, 8],
    ];
    let hits: Vec<u32> = records.iter().flatten().copied().collect();
    let cursor: [u32; 1] = [4];
    let overflow: [u32; 1] = [0];

    vec![vec![
        pack_words(hits.as_slice()),
        pack_words(&cursor),
        pack_words(&overflow),
    ]]
}
/// Harness inputs for the `compact_hits` entry.
///
/// Four packed hit records (the same words `emit_hit_expected_output` expects),
/// plus a cursor of 7. The cursor deliberately exceeds the four-record capacity.
fn compact_hits_inputs() -> Vec<Vec<Vec<u8>>> {
    let hits: [u32; 16] = [7, 101, 5, 2, 9, 103, 9, 4, 11, 107, 13, 6, 13, 109, 17, 8];
    let cursor: [u32; 1] = [7];
    let dispatch = vec![pack_words(&hits), pack_words(&cursor)];
    vec![dispatch]
}
/// Expected harness output for the `compact_hits` entry: the oversized cursor
/// is clamped down to the `DEFAULT_MAX_HITS` capacity.
fn compact_hits_expected_output() -> Vec<Vec<Vec<u8>>> {
    let live_len: [u32; 1] = [DEFAULT_MAX_HITS];
    let dispatch = vec![pack_words(&live_len)];
    vec![dispatch]
}
use crate::scan::dispatch_io::pack_u32_slice as pack_words;
#[cfg(test)]
mod emit_then_compact_tests {
    use super::*;

    /// Fusing the default-layout emit and compact programs must succeed. The
    /// result must keep emit's 64-wide workgroup and expose the emit buffers
    /// followed by the compaction's live-length buffer.
    #[test]
    fn fused_program_builds() {
        let fused = emit_hit_then_compact(
            "rule_id",
            "file_id",
            "span_start",
            "span_len",
            "out_hits",
            "out_cursor",
        )
        .expect("Fix: emit_hit and compact_hits must fuse");

        assert_eq!(fused.workgroup_size(), [64, 1, 1]);

        let buffers = fused.buffers();
        assert_eq!(buffers.len(), 8);
        let expected_names = [
            (0_usize, "rule_id"),
            (4_usize, "out_hits"),
            (7_usize, "hit_buffer_live_length"),
        ];
        for (slot, name) in expected_names {
            assert_eq!(buffers[slot].name(), name);
        }
    }
}
// Registers the default-layout `emit_hit` program with the op harness under
// EMIT_HIT_OP_ID, together with its fixed fixture inputs and expected outputs.
inventory::submit! {
    crate::harness::OpEntry::new(
        EMIT_HIT_OP_ID,
        || emit_hit(
            "rule_id",
            "file_id",
            "span_start",
            "span_len",
            "out_hits",
            "out_cursor",
        ),
        Some(emit_hit_inputs),
        Some(emit_hit_expected_output),
    )
}
// Registers the `compact_hits` program (capacity DEFAULT_MAX_HITS) with the op
// harness under COMPACT_HITS_OP_ID, together with its fixture inputs and
// expected outputs.
inventory::submit! {
    crate::harness::OpEntry::new(
        COMPACT_HITS_OP_ID,
        || compact_hits("out_hits", "out_cursor", DEFAULT_MAX_HITS),
        Some(compact_hits_inputs),
        Some(compact_hits_expected_output),
    )
}