use std::sync::Arc;
use vyre_foundation::ir::model::expr::Ident;
use vyre_foundation::ir::{BufferAccess, BufferDecl, DataType, Expr, Node, Program};
use vyre_primitives::bitset::bitset_words;
use vyre_primitives::graph::persistent_bfs::{
BINDING_CHANGED, BINDING_FRONTIER_IN, BINDING_FRONTIER_OUT,
};
use vyre_primitives::graph::program_graph::{
ProgramGraphShape, NAME_EDGE_KIND_MASK, NAME_EDGE_OFFSETS, NAME_EDGE_TARGETS,
};
/// Generator identifier attached to the `Node::Region` that wraps the body of
/// every program built in this module.
pub const OP_ID: &str = "vyre-self-substrate::optimizer::dce_program";

/// Lane count (1024) that per-thread loops stride by when a single workgroup
/// walks every node / bitset word; must agree with the workgroup width the
/// program is built with.
const DCE_WORKGROUP_X: u32 = 1 << 10;
/// Emits one relaxation sweep over the CSR program graph for the calling lane.
///
/// The lane visits source nodes `gid_x + k * DCE_WORKGROUP_X` for
/// `k in 0..ceil(node_count / DCE_WORKGROUP_X)` (always at least one pass).
/// For every visited source whose bit is set in `frontier_out`, it walks the
/// edges `[edge_offsets[src], edge_offsets[src + 1])`, keeps only edges whose
/// kind mask intersects `allow_mask`, and atomically ORs the bit of every
/// in-range target (`dst < node_count`) into `frontier_out`. When the atomic's
/// previous value shows the bit was clear, `local_changed` is assigned `1`.
///
/// `local_changed` is assigned, not declared, here: the caller must `let_bind`
/// it before splicing these nodes in.
fn parallel_csr_step_per_thread_masked(node_count: u32, allow_mask: u32) -> Vec<Node> {
    // Ceiling division that cannot overflow (the `n + d - 1` form wraps for
    // node counts near `u32::MAX`).
    let stride_count =
        node_count / DCE_WORKGROUP_X + u32::from(node_count % DCE_WORKGROUP_X != 0);

    // Innermost: set `dst` in the frontier and flag progress if it was new.
    let mark_target = vec![
        Node::let_bind("dst_word_idx", Expr::shr(Expr::var("dst"), Expr::u32(5))),
        Node::let_bind(
            "dst_bit",
            Expr::shl(Expr::u32(1), Expr::bitand(Expr::var("dst"), Expr::u32(31))),
        ),
        Node::let_bind(
            "old",
            Expr::atomic_or("frontier_out", Expr::var("dst_word_idx"), Expr::var("dst_bit")),
        ),
        Node::if_then(
            Expr::eq(Expr::bitand(Expr::var("old"), Expr::var("dst_bit")), Expr::u32(0)),
            vec![Node::assign("local_changed", Expr::u32(1))],
        ),
    ];

    // Per edge: filter by edge-kind mask, then bounds-check the target index.
    let relax_edge = vec![
        Node::let_bind("kind_mask", Expr::load(NAME_EDGE_KIND_MASK, Expr::var("e"))),
        Node::if_then(
            Expr::ne(
                Expr::bitand(Expr::var("kind_mask"), Expr::u32(allow_mask)),
                Expr::u32(0),
            ),
            vec![
                Node::let_bind("dst", Expr::load(NAME_EDGE_TARGETS, Expr::var("e"))),
                Node::if_then(Expr::lt(Expr::var("dst"), Expr::u32(node_count)), mark_target),
            ],
        ),
    ];

    // Per set source: walk its CSR edge range.
    let expand_source = vec![
        Node::let_bind("edge_start", Expr::load(NAME_EDGE_OFFSETS, Expr::var("src"))),
        Node::let_bind(
            "edge_end",
            Expr::load(NAME_EDGE_OFFSETS, Expr::add(Expr::var("src"), Expr::u32(1))),
        ),
        Node::loop_for("e", Expr::var("edge_start"), Expr::var("edge_end"), relax_edge),
    ];

    // Per in-range source: only expand nodes already present in the frontier.
    let visit_source = vec![
        Node::let_bind("src_word_idx", Expr::shr(Expr::var("src"), Expr::u32(5))),
        Node::let_bind(
            "src_bit_mask",
            Expr::shl(Expr::u32(1), Expr::bitand(Expr::var("src"), Expr::u32(31))),
        ),
        Node::let_bind("src_word", Expr::load("frontier_out", Expr::var("src_word_idx"))),
        Node::if_then(
            Expr::ne(
                Expr::bitand(Expr::var("src_word"), Expr::var("src_bit_mask")),
                Expr::u32(0),
            ),
            expand_source,
        ),
    ];

    vec![Node::loop_for(
        "stride",
        Expr::u32(0),
        Expr::u32(stride_count.max(1)),
        vec![
            Node::let_bind(
                "src",
                Expr::add(
                    Expr::gid_x(),
                    Expr::mul(Expr::var("stride"), Expr::u32(DCE_WORKGROUP_X)),
                ),
            ),
            Node::if_then(Expr::lt(Expr::var("src"), Expr::u32(node_count)), visit_source),
        ],
    )]
}
/// Builds the persistent BFS reachability program over `shape`.
///
/// Only edges whose kind mask intersects `allow_mask` are followed, for at
/// most `max_iters` relaxation rounds (clamped to at least one by the shared
/// builder).
///
/// This public entry point uses the sticky layout: `changed` has two words,
/// and `changed[1]` is only ever OR-ed to `1` when some round discovers a node.
#[must_use]
pub fn build_persistent_bfs_program(
    shape: ProgramGraphShape,
    max_iters: u32,
    allow_mask: u32,
) -> Program {
    build_persistent_bfs_program_sticky(shape, max_iters, allow_mask)
}
/// Builds the dead-code-elimination reachability program over `shape`.
///
/// Every edge kind is followed (the allow-mask is all ones), and the
/// non-sticky layout is used, so `changed` holds a single word.
#[must_use]
pub fn build_dce_bfs_program(shape: ProgramGraphShape, max_iters: u32) -> Program {
    // All 32 bits set: no edge kind is filtered out.
    let every_edge_kind = u32::MAX;
    build_persistent_bfs_program_inner(shape, max_iters, every_edge_kind)
}
/// Non-sticky variant: forwards to the shared builder with
/// `sticky_changed = false`, so only the per-round `changed[0]` word exists.
fn build_persistent_bfs_program_inner(
    shape: ProgramGraphShape,
    max_iters: u32,
    allow_mask: u32,
) -> Program {
    let sticky_changed = false;
    build_persistent_bfs_program_internal(shape, max_iters, allow_mask, sticky_changed)
}
/// Sticky variant: forwards to the shared builder with `sticky_changed = true`,
/// adding a second `changed` word that records progress across all rounds.
fn build_persistent_bfs_program_sticky(
    shape: ProgramGraphShape,
    max_iters: u32,
    allow_mask: u32,
) -> Program {
    let sticky_changed = true;
    build_persistent_bfs_program_internal(shape, max_iters, allow_mask, sticky_changed)
}
fn build_persistent_bfs_program_internal(
shape: ProgramGraphShape,
max_iters: u32,
allow_mask: u32,
sticky_changed: bool,
) -> Program {
let words = bitset_words(shape.node_count);
let t = Expr::gid_x();
let mut iter_body: Vec<Node> = vec![
Node::if_then(
Expr::eq(t.clone(), Expr::u32(0)),
vec![Node::store("changed", Expr::u32(0), Expr::u32(0))],
),
Node::barrier(),
Node::let_bind("local_changed", Expr::u32(0)),
Node::if_then(
Expr::lt(t.clone(), Expr::u32(shape.node_count)),
parallel_csr_step_per_thread_masked(shape.node_count, allow_mask),
),
Node::if_then(
Expr::eq(Expr::var("local_changed"), Expr::u32(1)),
vec![Node::let_bind(
"_dce_set",
Expr::atomic_or("changed", Expr::u32(0), Expr::u32(1)),
)],
),
];
if sticky_changed {
iter_body.push(Node::if_then(
Expr::eq(Expr::var("local_changed"), Expr::u32(1)),
vec![Node::let_bind(
"_sticky_set",
Expr::atomic_or("changed", Expr::u32(1), Expr::u32(1)),
)],
));
}
iter_body.push(Node::barrier());
iter_body.push(Node::if_then(
Expr::eq(Expr::load("changed", Expr::u32(0)), Expr::u32(0)),
vec![Node::Return],
));
let entry: Vec<Node> = vec![
Node::if_then(
Expr::lt(t.clone(), Expr::u32(words)),
vec![Node::store(
"frontier_out",
t.clone(),
Expr::load("frontier_in", t.clone()),
)],
),
Node::barrier(),
Node::loop_for("iter", Expr::u32(0), Expr::u32(max_iters.max(1)), iter_body),
];
let mut buffers = shape.read_only_buffers();
buffers.push(
BufferDecl::storage(
"frontier_in",
BINDING_FRONTIER_IN,
BufferAccess::ReadOnly,
DataType::U32,
)
.with_count(words.max(1)),
);
buffers.push(
BufferDecl::storage(
"frontier_out",
BINDING_FRONTIER_OUT,
BufferAccess::ReadWrite,
DataType::U32,
)
.with_count(words.max(1)),
);
buffers.push(
BufferDecl::storage(
"changed",
BINDING_CHANGED,
BufferAccess::ReadWrite,
DataType::U32,
)
.with_count(if sticky_changed { 2 } else { 1 }),
);
buffers.push(BufferDecl::workgroup("wg_scratch", 256, DataType::U32));
Program::wrapped(
buffers,
[1024, 1, 1],
vec![Node::Region {
generator: Ident::from(OP_ID),
source_region: None,
body: Arc::new(entry),
}],
)
}