use std::sync::Arc;
use vyre_foundation::ir::{BinOp, BufferAccess, BufferDecl, DataType, Expr, Node, Program};
use crate::dispatch_buffers::{
decode_u32_output_exact, ensure_input_slots, write_u32_slice_le_bytes, write_zero_bytes,
};
use super::dispatcher::{DispatchError, OptimizerDispatcher};
use super::encode::EncodeError;
use super::expr_arena::{encode_expr_arena, expr_kind, ExprArenaEncoding};
/// Host-side buffers reused by `run_const_fold_kernel_with_scratch_into`
/// across calls so repeated dispatches can keep their allocations.
#[derive(Debug, Default)]
struct ConstFoldKernelScratch {
    /// Byte buffers for the eight dispatch input slots: the five arena
    /// columns, the current level, and the carried `foldable`/`value` state.
    inputs: Vec<Vec<u8>>,
    /// One-word staging array for the level written into input slot 5.
    current_level: [u32; 1],
}
/// Failure modes of [`gpu_const_fold`].
#[derive(Debug)]
pub enum ConstFoldError {
    /// Encoding the program's expressions into the arena failed.
    Encode(EncodeError),
    /// A GPU dispatch, or validation of its inputs/outputs, failed.
    Dispatch(DispatchError),
}
/// Renders `gpu_const_fold <stage> error: <cause>`.
///
/// The encode cause is shown with `Debug` formatting and the dispatch cause
/// with `Display` formatting.
impl std::fmt::Display for ConstFoldError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("gpu_const_fold ")?;
        match self {
            Self::Encode(cause) => write!(f, "encode error: {cause:?}"),
            Self::Dispatch(cause) => write!(f, "dispatch error: {cause}"),
        }
    }
}

impl std::error::Error for ConstFoldError {}
/// Folds constant `u32` expressions in `program` with a GPU kernel.
///
/// The program's expressions are first encoded into a flat arena. The
/// const-fold kernel then runs once per depth level through `dispatcher`.
/// Finally, every expression the kernel marked foldable is replaced by a
/// `LitU32` holding its computed value. A program with no encoded
/// expressions is returned unchanged, without any dispatch.
///
/// # Errors
///
/// * [`ConstFoldError::Encode`] if the program cannot be encoded into an arena.
/// * [`ConstFoldError::Dispatch`] if a dispatch or its output validation fails.
pub fn gpu_const_fold(
    program: Program,
    dispatcher: &dyn OptimizerDispatcher,
) -> Result<Program, ConstFoldError> {
    let arena = encode_expr_arena(&program).map_err(ConstFoldError::Encode)?;
    let expr_count = arena.expr_count as usize;
    if expr_count == 0 {
        return Ok(program);
    }

    let mut foldable: Vec<u32> = Vec::with_capacity(expr_count);
    let mut value: Vec<u32> = Vec::with_capacity(expr_count);
    let kernel_result = run_const_fold_kernel_with_scratch_into(
        &arena,
        dispatcher,
        &mut ConstFoldKernelScratch::default(),
        &mut foldable,
        &mut value,
    );
    kernel_result.map_err(ConstFoldError::Dispatch)?;

    Ok(rewrite_program_with_folded_values(
        program, &arena, &foldable, &value,
    ))
}
/// Test-only wrapper that runs the const-fold kernel with throwaway scratch
/// buffers, writing the final state into `foldable` and `value`.
#[cfg(test)]
fn run_const_fold_kernel_into(
    arena: &ExprArenaEncoding,
    dispatcher: &dyn OptimizerDispatcher,
    foldable: &mut Vec<u32>,
    value: &mut Vec<u32>,
) -> Result<(), DispatchError> {
    run_const_fold_kernel_with_scratch_into(
        arena,
        dispatcher,
        &mut ConstFoldKernelScratch::default(),
        foldable,
        value,
    )
}
/// Runs the level-synchronous const-fold kernel, reusing `scratch` across calls.
///
/// Input slots handed to the dispatcher match [`build_const_fold_program`]:
/// - slots 0-4 hold the arena columns (`kinds`, `arg0`, `arg1`, `arg2`, `depths`);
/// - slot 5 holds the current level;
/// - slots 6 and 7 carry the `foldable`/`value` state.
///
/// The state starts zeroed, and one dispatch is issued per level in
/// `0..=arena.max_depth`. Each dispatch's two outputs are copied back into
/// slots 6/7, so the next level can read results produced by earlier levels.
/// On success, `foldable` and `value` hold the final state, `arena.expr_count`
/// words each. Both are simply emptied, with no dispatch, when `expr_count`
/// is zero.
///
/// # Errors
///
/// * [`DispatchError::BadInputs`] if an arena column is shorter than
///   `expr_count`, or the state size overflows `usize`.
/// * [`DispatchError::BackendError`] if a dispatch does not return exactly two
///   outputs.
/// * Any error returned by the dispatcher or by `decode_u32_output_exact`.
fn run_const_fold_kernel_with_scratch_into(
    arena: &ExprArenaEncoding,
    dispatcher: &dyn OptimizerDispatcher,
    scratch: &mut ConstFoldKernelScratch,
    foldable: &mut Vec<u32>,
    value: &mut Vec<u32>,
) -> Result<(), DispatchError> {
    let n = arena.expr_count;
    if n == 0 {
        // Nothing to fold; a zero-width grid would not be a meaningful dispatch.
        foldable.clear();
        value.clear();
        return Ok(());
    }
    let words = n as usize;

    // The kernel indexes every column with `i < expr_count`. Reject short
    // columns here instead of relying on backend-specific bounds behavior.
    let columns: [(&str, &[u32]); 5] = [
        ("kinds", &arena.kinds),
        ("arg0", &arena.arg0),
        ("arg1", &arena.arg1),
        ("arg2", &arena.arg2),
        ("depths", &arena.depths),
    ];
    for &(name, column) in &columns {
        if column.len() < words {
            return Err(DispatchError::BadInputs(format!(
                "Fix: const-fold arena column `{name}` has {} entries but expr_count={n}.",
                column.len()
            )));
        }
    }

    let state_bytes = words
        .checked_mul(std::mem::size_of::<u32>())
        .ok_or_else(|| {
            DispatchError::BadInputs(format!(
                "Fix: const-fold state byte count overflows usize for expr_count={n}."
            ))
        })?;
    let analysis = build_const_fold_program(n);

    ensure_input_slots(&mut scratch.inputs, 8);
    write_u32_slice_le_bytes(&mut scratch.inputs[0], &arena.kinds);
    write_u32_slice_le_bytes(&mut scratch.inputs[1], &arena.arg0);
    write_u32_slice_le_bytes(&mut scratch.inputs[2], &arena.arg1);
    write_u32_slice_le_bytes(&mut scratch.inputs[3], &arena.arg2);
    write_u32_slice_le_bytes(&mut scratch.inputs[4], &arena.depths);
    write_zero_bytes(&mut scratch.inputs[6], state_bytes);
    write_zero_bytes(&mut scratch.inputs[7], state_bytes);

    // Compute ceil(n / WORKGROUP_X) without the `n + WORKGROUP_X - 1`
    // intermediate, which overflows u32 when n is near u32::MAX.
    let grid_x = n / WORKGROUP_X + u32::from(n % WORKGROUP_X != 0);
    for level in 0..=arena.max_depth {
        scratch.current_level[0] = level;
        write_u32_slice_le_bytes(&mut scratch.inputs[5], &scratch.current_level);
        let outputs = dispatcher.dispatch(&analysis, &scratch.inputs, Some([grid_x, 1, 1]))?;
        if outputs.len() != 2 {
            return Err(DispatchError::BackendError(format!(
                "Fix: const-fold dispatch expected exactly 2 outputs (foldable, value), got {}.",
                outputs.len()
            )));
        }
        decode_u32_output_exact(&outputs[0], words, "const-fold foldable", foldable)?;
        decode_u32_output_exact(&outputs[1], words, "const-fold value", value)?;
        // Carry this level's state into the next dispatch.
        scratch.inputs[6].clear();
        scratch.inputs[6].extend_from_slice(&outputs[0]);
        scratch.inputs[7].clear();
        scratch.inputs[7].extend_from_slice(&outputs[1]);
    }
    Ok(())
}
/// Workgroup width (X) shared by both const-fold kernels.
///
/// The host derives the per-level dispatch grid from it, and the fused
/// kernel uses it as the stride between chunks.
const WORKGROUP_X: u32 = 256;
/// Builds a single-dispatch variant of the const-fold kernel.
///
/// Instead of re-dispatching once per level, this kernel iterates over depth
/// levels on the device. Buffer bindings match [`build_const_fold_program`],
/// except slot 5. Here slot 5 holds `max_depth_buf[0]` (the deepest level to
/// process) rather than the current level.
///
/// Structure of the generated program:
/// - The outer loop binds `level` with bounds `0` and
///   `max(max_depth_iter_cap, 1)`.
/// - Levels greater than the loaded max depth are skipped.
/// - A `Barrier` runs after every level.
/// - Inside a level, each invocation visits `i = gid_x + chunk * WORKGROUP_X`
///   for every chunk, and entries whose depth equals `level` run the shared
///   per-expression body.
///
/// NOTE(review): that indexing visits each `i` exactly once only when the
/// program is dispatched as a single workgroup. The barrier presumably
/// synchronizes only within one workgroup. Confirm the dispatch shape callers
/// use.
///
/// NOTE(review): levels not covered by the `max_depth_iter_cap` loop bound are
/// never processed, so the cap must cover the arena's max depth.
#[must_use]
pub fn build_const_fold_program_fused(expr_count: u32, max_depth_iter_cap: u32) -> Program {
    // Buffers are declared with at least one element even for an empty arena.
    let slots = expr_count.max(1);
    let buffers = vec![
        BufferDecl::storage("arena_kinds", 0, BufferAccess::ReadOnly, DataType::U32)
            .with_count(slots),
        BufferDecl::storage("arena_arg0", 1, BufferAccess::ReadOnly, DataType::U32)
            .with_count(slots),
        BufferDecl::storage("arena_arg1", 2, BufferAccess::ReadOnly, DataType::U32)
            .with_count(slots),
        BufferDecl::storage("arena_arg2", 3, BufferAccess::ReadOnly, DataType::U32)
            .with_count(slots),
        BufferDecl::storage("arena_depths", 4, BufferAccess::ReadOnly, DataType::U32)
            .with_count(slots),
        BufferDecl::storage("max_depth_buf", 5, BufferAccess::ReadOnly, DataType::U32)
            .with_count(1),
        BufferDecl::storage("foldable", 6, BufferAccess::ReadWrite, DataType::U32)
            .with_count(slots),
        BufferDecl::storage("value", 7, BufferAccess::ReadWrite, DataType::U32)
            .with_count(slots),
    ];
    // Compute ceil(expr_count / WORKGROUP_X) without the
    // `expr_count + WORKGROUP_X - 1` intermediate, which overflows u32 when
    // expr_count is near u32::MAX.
    let chunk_cap = expr_count / WORKGROUP_X + u32::from(expr_count % WORKGROUP_X != 0);
    let chunk_loop = Node::loop_for(
        "chunk",
        Expr::u32(0),
        Expr::u32(chunk_cap.max(1)),
        vec![
            Node::let_bind(
                "i",
                Expr::add(
                    Expr::gid_x(),
                    Expr::mul(Expr::var("chunk"), Expr::u32(WORKGROUP_X)),
                ),
            ),
            Node::if_then(
                Expr::lt(Expr::var("i"), Expr::u32(expr_count)),
                vec![
                    Node::let_bind("my_depth", Expr::load("arena_depths", Expr::var("i"))),
                    Node::if_then(
                        Expr::eq(Expr::var("my_depth"), Expr::var("level")),
                        per_expr_body(),
                    ),
                ],
            ),
        ],
    );
    let outer = Node::loop_for(
        "level",
        Expr::u32(0),
        Expr::u32(max_depth_iter_cap.max(1)),
        vec![
            Node::let_bind("md", Expr::load("max_depth_buf", Expr::u32(0))),
            Node::if_then(
                Expr::le(Expr::var("level"), Expr::var("md")),
                vec![chunk_loop],
            ),
            // Publish this level's results before the next level reads them.
            Node::Barrier {
                ordering: vyre_foundation::MemoryOrdering::SeqCst,
            },
        ],
    );
    Program::wrapped(buffers, [WORKGROUP_X, 1, 1], vec![outer])
}
/// Builds the one-level const-fold kernel that the host dispatches once per
/// depth level.
///
/// Bindings:
/// - 0-4 are the read-only arena columns (`arena_kinds`, `arena_arg0`,
///   `arena_arg1`, `arena_arg2`, `arena_depths`).
/// - 5 is `current_level[0]`.
/// - 6 and 7 are the read-write `foldable`/`value` state.
///
/// Each invocation handles `i = gid_x`. It does nothing unless
/// `i < expr_count` and `arena_depths[i]` equals the current level, in which
/// case it runs the shared per-expression body.
#[must_use]
pub fn build_const_fold_program(expr_count: u32) -> Program {
    // Buffers are declared with at least one element even for an empty arena.
    let slots = expr_count.max(1);
    let buffers = vec![
        BufferDecl::storage("arena_kinds", 0, BufferAccess::ReadOnly, DataType::U32)
            .with_count(slots),
        BufferDecl::storage("arena_arg0", 1, BufferAccess::ReadOnly, DataType::U32)
            .with_count(slots),
        BufferDecl::storage("arena_arg1", 2, BufferAccess::ReadOnly, DataType::U32)
            .with_count(slots),
        BufferDecl::storage("arena_arg2", 3, BufferAccess::ReadOnly, DataType::U32)
            .with_count(slots),
        BufferDecl::storage("arena_depths", 4, BufferAccess::ReadOnly, DataType::U32)
            .with_count(slots),
        BufferDecl::storage("current_level", 5, BufferAccess::ReadOnly, DataType::U32)
            .with_count(1),
        BufferDecl::storage("foldable", 6, BufferAccess::ReadWrite, DataType::U32)
            .with_count(slots),
        BufferDecl::storage("value", 7, BufferAccess::ReadWrite, DataType::U32)
            .with_count(slots),
    ];
    let body = vec![
        Node::let_bind("i", Expr::gid_x()),
        Node::if_then(
            Expr::lt(Expr::var("i"), Expr::u32(expr_count)),
            vec![
                Node::let_bind("my_depth", Expr::load("arena_depths", Expr::var("i"))),
                Node::let_bind("level", Expr::load("current_level", Expr::u32(0))),
                Node::if_then(
                    Expr::eq(Expr::var("my_depth"), Expr::var("level")),
                    per_expr_body(),
                ),
            ],
        ),
    ];
    Program::wrapped(buffers, [WORKGROUP_X, 1, 1], body)
}
/// Builds the kernel body that processes arena entry `i`.
///
/// The body dispatches on `arena_kinds[i]`:
/// - A `LIT_U32` entry is marked foldable, with its value copied from
///   `arena_arg0[i]`.
/// - A `BIN_OP` entry is handled by [`bin_op_body`].
/// - Any other kind writes nothing, so `foldable[i]` keeps its prior content.
fn per_expr_body() -> Vec<Node> {
    let kind_is = |kind: u32| Expr::eq(Expr::var("kind"), Expr::u32(kind));
    let literal_arm = vec![
        Node::store("foldable", Expr::var("i"), Expr::u32(1)),
        Node::store(
            "value",
            Expr::var("i"),
            Expr::load("arena_arg0", Expr::var("i")),
        ),
    ];
    vec![
        Node::let_bind("kind", Expr::load("arena_kinds", Expr::var("i"))),
        Node::if_then(kind_is(expr_kind::LIT_U32), literal_arm),
        Node::if_then(kind_is(expr_kind::BIN_OP), bin_op_body()),
    ]
}
fn bin_op_body() -> Vec<Node> {
vec![
Node::let_bind("op", Expr::load("arena_arg0", Expr::var("i"))),
Node::let_bind("l", Expr::load("arena_arg1", Expr::var("i"))),
Node::let_bind("r", Expr::load("arena_arg2", Expr::var("i"))),
Node::let_bind("lf", Expr::load("foldable", Expr::var("l"))),
Node::let_bind("rf", Expr::load("foldable", Expr::var("r"))),
Node::if_then(
Expr::and(
Expr::eq(Expr::var("lf"), Expr::u32(1)),
Expr::eq(Expr::var("rf"), Expr::u32(1)),
),
vec![
Node::let_bind("lv", Expr::load("value", Expr::var("l"))),
Node::let_bind("rv", Expr::load("value", Expr::var("r"))),
Node::if_then(
Expr::eq(Expr::var("op"), Expr::u32(0x01)),
vec![
Node::store(
"value",
Expr::var("i"),
Expr::add(Expr::var("lv"), Expr::var("rv")),
),
Node::store("foldable", Expr::var("i"), Expr::u32(1)),
],
),
Node::if_then(
Expr::eq(Expr::var("op"), Expr::u32(0x02)),
vec![
Node::store(
"value",
Expr::var("i"),
Expr::sub(Expr::var("lv"), Expr::var("rv")),
),
Node::store("foldable", Expr::var("i"), Expr::u32(1)),
],
),
Node::if_then(
Expr::eq(Expr::var("op"), Expr::u32(0x03)),
vec![
Node::store(
"value",
Expr::var("i"),
Expr::mul(Expr::var("lv"), Expr::var("rv")),
),
Node::store("foldable", Expr::var("i"), Expr::u32(1)),
],
),
Node::if_then(
Expr::eq(Expr::var("op"), Expr::u32(0x06)),
vec![
Node::store(
"value",
Expr::var("i"),
Expr::bitand(Expr::var("lv"), Expr::var("rv")),
),
Node::store("foldable", Expr::var("i"), Expr::u32(1)),
],
),
Node::if_then(
Expr::eq(Expr::var("op"), Expr::u32(0x07)),
vec![
Node::store(
"value",
Expr::var("i"),
Expr::bitor(Expr::var("lv"), Expr::var("rv")),
),
Node::store("foldable", Expr::var("i"), Expr::u32(1)),
],
),
Node::if_then(
Expr::eq(Expr::var("op"), Expr::u32(0x08)),
vec![
Node::store(
"value",
Expr::var("i"),
Expr::bitxor(Expr::var("lv"), Expr::var("rv")),
),
Node::store("foldable", Expr::var("i"), Expr::u32(1)),
],
),
Node::if_then(
Expr::eq(Expr::var("op"), Expr::u32(0x09)),
vec![
Node::store(
"value",
Expr::var("i"),
Expr::shl(Expr::var("lv"), Expr::var("rv")),
),
Node::store("foldable", Expr::var("i"), Expr::u32(1)),
],
),
Node::if_then(
Expr::eq(Expr::var("op"), Expr::u32(0x0A)),
vec![
Node::store(
"value",
Expr::var("i"),
Expr::shr(Expr::var("lv"), Expr::var("rv")),
),
Node::store("foldable", Expr::var("i"), Expr::u32(1)),
],
),
Node::if_then(
Expr::and(
Expr::eq(Expr::var("op"), Expr::u32(0x04)),
Expr::ne(Expr::var("rv"), Expr::u32(0)),
),
vec![
Node::store(
"value",
Expr::var("i"),
Expr::div(Expr::var("lv"), Expr::var("rv")),
),
Node::store("foldable", Expr::var("i"), Expr::u32(1)),
],
),
Node::if_then(
Expr::and(
Expr::eq(Expr::var("op"), Expr::u32(0x05)),
Expr::ne(Expr::var("rv"), Expr::u32(0)),
),
vec![
Node::store(
"value",
Expr::var("i"),
Expr::rem(Expr::var("lv"), Expr::var("rv")),
),
Node::store("foldable", Expr::var("i"), Expr::u32(1)),
],
),
Node::if_then(
Expr::eq(Expr::var("op"), Expr::u32(0x15)),
vec![
Node::store(
"value",
Expr::var("i"),
Expr::Select {
cond: Box::new(Expr::lt(Expr::var("lv"), Expr::var("rv"))),
true_val: Box::new(Expr::var("lv")),
false_val: Box::new(Expr::var("rv")),
},
),
Node::store("foldable", Expr::var("i"), Expr::u32(1)),
],
),
Node::if_then(
Expr::eq(Expr::var("op"), Expr::u32(0x16)),
vec![
Node::store(
"value",
Expr::var("i"),
Expr::Select {
cond: Box::new(Expr::gt(Expr::var("lv"), Expr::var("rv"))),
true_val: Box::new(Expr::var("lv")),
false_val: Box::new(Expr::var("rv")),
},
),
Node::store("foldable", Expr::var("i"), Expr::u32(1)),
],
),
Node::if_then(
Expr::eq(Expr::var("op"), Expr::u32(0x14)),
vec![
Node::store(
"value",
Expr::var("i"),
Expr::Select {
cond: Box::new(Expr::gt(Expr::var("lv"), Expr::var("rv"))),
true_val: Box::new(Expr::sub(Expr::var("lv"), Expr::var("rv"))),
false_val: Box::new(Expr::sub(Expr::var("rv"), Expr::var("lv"))),
},
),
Node::store("foldable", Expr::var("i"), Expr::u32(1)),
],
),
Node::if_then(
Expr::eq(Expr::var("op"), Expr::u32(0x17)),
vec![
Node::let_bind("sat_sum", Expr::add(Expr::var("lv"), Expr::var("rv"))),
Node::store(
"value",
Expr::var("i"),
Expr::Select {
cond: Box::new(Expr::lt(Expr::var("sat_sum"), Expr::var("lv"))),
true_val: Box::new(Expr::u32(u32::MAX)),
false_val: Box::new(Expr::var("sat_sum")),
},
),
Node::store("foldable", Expr::var("i"), Expr::u32(1)),
],
),
Node::if_then(
Expr::eq(Expr::var("op"), Expr::u32(0x18)),
vec![
Node::store(
"value",
Expr::var("i"),
Expr::Select {
cond: Box::new(Expr::ge(Expr::var("lv"), Expr::var("rv"))),
true_val: Box::new(Expr::sub(Expr::var("lv"), Expr::var("rv"))),
false_val: Box::new(Expr::u32(0)),
},
),
Node::store("foldable", Expr::var("i"), Expr::u32(1)),
],
),
Node::if_then(
Expr::eq(Expr::var("op"), Expr::u32(0x20)),
vec![
Node::store(
"value",
Expr::var("i"),
Expr::add(Expr::var("lv"), Expr::var("rv")),
),
Node::store("foldable", Expr::var("i"), Expr::u32(1)),
],
),
Node::if_then(
Expr::eq(Expr::var("op"), Expr::u32(0x21)),
vec![
Node::store(
"value",
Expr::var("i"),
Expr::sub(Expr::var("lv"), Expr::var("rv")),
),
Node::store("foldable", Expr::var("i"), Expr::u32(1)),
],
),
Node::if_then(
Expr::eq(Expr::var("op"), Expr::u32(0x0B)),
vec![
Node::store(
"value",
Expr::var("i"),
Expr::Select {
cond: Box::new(Expr::eq(Expr::var("lv"), Expr::var("rv"))),
true_val: Box::new(Expr::u32(1)),
false_val: Box::new(Expr::u32(0)),
},
),
Node::store("foldable", Expr::var("i"), Expr::u32(1)),
],
),
Node::if_then(
Expr::eq(Expr::var("op"), Expr::u32(0x0C)),
vec![
Node::store(
"value",
Expr::var("i"),
Expr::Select {
cond: Box::new(Expr::ne(Expr::var("lv"), Expr::var("rv"))),
true_val: Box::new(Expr::u32(1)),
false_val: Box::new(Expr::u32(0)),
},
),
Node::store("foldable", Expr::var("i"), Expr::u32(1)),
],
),
Node::if_then(
Expr::eq(Expr::var("op"), Expr::u32(0x0D)),
vec![
Node::store(
"value",
Expr::var("i"),
Expr::Select {
cond: Box::new(Expr::lt(Expr::var("lv"), Expr::var("rv"))),
true_val: Box::new(Expr::u32(1)),
false_val: Box::new(Expr::u32(0)),
},
),
Node::store("foldable", Expr::var("i"), Expr::u32(1)),
],
),
Node::if_then(
Expr::eq(Expr::var("op"), Expr::u32(0x0E)),
vec![
Node::store(
"value",
Expr::var("i"),
Expr::Select {
cond: Box::new(Expr::gt(Expr::var("lv"), Expr::var("rv"))),
true_val: Box::new(Expr::u32(1)),
false_val: Box::new(Expr::u32(0)),
},
),
Node::store("foldable", Expr::var("i"), Expr::u32(1)),
],
),
Node::if_then(
Expr::eq(Expr::var("op"), Expr::u32(0x10)),
vec![
Node::store(
"value",
Expr::var("i"),
Expr::Select {
cond: Box::new(Expr::le(Expr::var("lv"), Expr::var("rv"))),
true_val: Box::new(Expr::u32(1)),
false_val: Box::new(Expr::u32(0)),
},
),
Node::store("foldable", Expr::var("i"), Expr::u32(1)),
],
),
Node::if_then(
Expr::eq(Expr::var("op"), Expr::u32(0x11)),
vec![
Node::store(
"value",
Expr::var("i"),
Expr::Select {
cond: Box::new(Expr::ge(Expr::var("lv"), Expr::var("rv"))),
true_val: Box::new(Expr::u32(1)),
false_val: Box::new(Expr::u32(0)),
},
),
Node::store("foldable", Expr::var("i"), Expr::u32(1)),
],
),
Node::if_then(
Expr::eq(Expr::var("op"), Expr::u32(0x1E)),
vec![
Node::store(
"value",
Expr::var("i"),
Expr::BinOp {
op: BinOp::RotateLeft,
left: Box::new(Expr::var("lv")),
right: Box::new(Expr::var("rv")),
},
),
Node::store("foldable", Expr::var("i"), Expr::u32(1)),
],
),
Node::if_then(
Expr::eq(Expr::var("op"), Expr::u32(0x1F)),
vec![
Node::store(
"value",
Expr::var("i"),
Expr::BinOp {
op: BinOp::RotateRight,
left: Box::new(Expr::var("lv")),
right: Box::new(Expr::var("rv")),
},
),
Node::store("foldable", Expr::var("i"), Expr::u32(1)),
],
),
Node::if_then(
Expr::eq(Expr::var("op"), Expr::u32(0x19)),
vec![
Node::let_bind("sm_prod", Expr::mul(Expr::var("lv"), Expr::var("rv"))),
Node::let_bind(
"sm_divisor",
Expr::Select {
cond: Box::new(Expr::eq(Expr::var("lv"), Expr::u32(0))),
true_val: Box::new(Expr::u32(1)),
false_val: Box::new(Expr::var("lv")),
},
),
Node::let_bind(
"sm_quot",
Expr::div(Expr::var("sm_prod"), Expr::var("sm_divisor")),
),
Node::let_bind(
"sm_no_overflow",
Expr::or(
Expr::eq(Expr::var("lv"), Expr::u32(0)),
Expr::eq(Expr::var("sm_quot"), Expr::var("rv")),
),
),
Node::store(
"value",
Expr::var("i"),
Expr::Select {
cond: Box::new(Expr::var("sm_no_overflow")),
true_val: Box::new(Expr::var("sm_prod")),
false_val: Box::new(Expr::u32(u32::MAX)),
},
),
Node::store("foldable", Expr::var("i"), Expr::u32(1)),
],
),
],
),
]
}
/// Produces `program` with every kernel-folded expression replaced by a
/// literal.
///
/// Arena ids are assigned by a single traversal whose counter starts at 0.
/// When the entry is a single `Region`, the region wrapper (generator and
/// source region) is preserved and only its body is rewritten. Otherwise the
/// entry statements are rewritten directly.
///
/// The entry is borrowed from `program` instead of being deep-cloned first;
/// the rewrite already builds fresh nodes, so the extra copy was pure
/// overhead.
fn rewrite_program_with_folded_values(
    program: Program,
    arena: &ExprArenaEncoding,
    foldable: &[u32],
    value: &[u32],
) -> Program {
    let mut counter = 0u32;
    let new_entry = match program.entry() {
        [Node::Region {
            generator,
            source_region,
            body,
        }] => vec![Node::Region {
            generator: generator.clone(),
            source_region: source_region.clone(),
            body: Arc::new(rewrite_scope(
                body.as_slice(),
                arena,
                foldable,
                value,
                &mut counter,
            )),
        }],
        entry => rewrite_scope(entry, arena, foldable, value, &mut counter),
    };
    program.with_rewritten_entry(new_entry)
}
/// Rewrites the reachable statements of `body`, in order.
///
/// Only the first `super::encode::reachable_prefix_len(body)` statements are
/// visited and returned. Later statements are dropped from the result and
/// consume no ids from `counter`.
fn rewrite_scope(
    body: &[Node],
    arena: &ExprArenaEncoding,
    foldable: &[u32],
    value: &[u32],
    counter: &mut u32,
) -> Vec<Node> {
    let reachable = super::encode::reachable_prefix_len(body);
    body[..reachable]
        .iter()
        .map(|node| rewrite_node(node, arena, foldable, value, counter))
        .collect()
}
/// Rebuilds one statement, folding constants inside every expression it holds.
///
/// Sub-expressions and nested scopes are visited in a fixed order. For
/// example:
/// - a `Store`'s index before its value;
/// - an `If`'s condition, then `then`, then `otherwise`;
/// - a `Loop`'s `from`, then `to`, then its body.
///
/// Rust evaluates call arguments and struct-literal fields left to right, and
/// `counter` hands out arena ids in that traversal order. The order therefore
/// presumably has to mirror `encode_expr_arena`; keep the two in sync.
///
/// Statements that carry no expressions, and any variant not matched
/// explicitly, are cloned unchanged without consuming ids.
fn rewrite_node(
    node: &Node,
    arena: &ExprArenaEncoding,
    foldable: &[u32],
    value: &[u32],
    counter: &mut u32,
) -> Node {
    match node {
        Node::Let { name, value: e } => Node::let_bind(
            name.clone(),
            rewrite_expr(e, arena, foldable, value, counter),
        ),
        Node::Assign { name, value: e } => Node::assign(
            name.clone(),
            rewrite_expr(e, arena, foldable, value, counter),
        ),
        // Index is numbered before the stored value.
        Node::Store {
            buffer,
            index,
            value: e,
        } => Node::store(
            buffer.clone(),
            rewrite_expr(index, arena, foldable, value, counter),
            rewrite_expr(e, arena, foldable, value, counter),
        ),
        // The branch is kept even if the condition folds; only expressions are rewritten.
        Node::If {
            cond,
            then,
            otherwise,
        } => Node::if_then_else(
            rewrite_expr(cond, arena, foldable, value, counter),
            rewrite_scope(then, arena, foldable, value, counter),
            rewrite_scope(otherwise, arena, foldable, value, counter),
        ),
        Node::Loop {
            var,
            from,
            to,
            body,
        } => Node::loop_for(
            var.clone(),
            rewrite_expr(from, arena, foldable, value, counter),
            rewrite_expr(to, arena, foldable, value, counter),
            rewrite_scope(body, arena, foldable, value, counter),
        ),
        // Offset is numbered before size.
        Node::AsyncLoad {
            source,
            destination,
            offset,
            size,
            tag,
        } => Node::AsyncLoad {
            source: source.clone(),
            destination: destination.clone(),
            offset: Box::new(rewrite_expr(offset, arena, foldable, value, counter)),
            size: Box::new(rewrite_expr(size, arena, foldable, value, counter)),
            tag: tag.clone(),
        },
        Node::AsyncStore {
            source,
            destination,
            offset,
            size,
            tag,
        } => Node::AsyncStore {
            source: source.clone(),
            destination: destination.clone(),
            offset: Box::new(rewrite_expr(offset, arena, foldable, value, counter)),
            size: Box::new(rewrite_expr(size, arena, foldable, value, counter)),
            tag: tag.clone(),
        },
        Node::Trap { address, tag } => Node::Trap {
            address: Box::new(rewrite_expr(address, arena, foldable, value, counter)),
            tag: tag.clone(),
        },
        Node::Block(body) => Node::Block(rewrite_scope(body, arena, foldable, value, counter)),
        Node::Region {
            generator,
            source_region,
            body,
        } => Node::Region {
            generator: generator.clone(),
            source_region: source_region.clone(),
            body: Arc::new(rewrite_scope(
                body.as_slice(),
                arena,
                foldable,
                value,
                counter,
            )),
        },
        // Variants with no expressions to rewrite.
        Node::Return
        | Node::Barrier { .. }
        | Node::IndirectDispatch { .. }
        | Node::AsyncWait { .. }
        | Node::Resume { .. }
        | Node::Opaque(_) => node.clone(),
        // NOTE(review): any other variant is cloned without consuming ids. If
        // such a variant carries expressions that `encode_expr_arena` numbers,
        // later ids would drift — verify against the encoder.
        _ => node.clone(),
    }
}
/// Rebuilds `expr`, replacing sub-expressions the kernel proved constant.
///
/// Ids are assigned in post-order. Every child is visited (and numbered)
/// before its parent takes the current value of `*counter`. That numbering
/// has to line up with the arena the kernel ran on, so the traversal order
/// here is load-bearing. Variants not listed below are cloned verbatim and
/// consume no ids.
///
/// Leaves (through [`decide`]), `BinOp` and `UnOp` may become a `LitU32`.
/// `Load`, `Select` and `Fma` consume an id but are always rebuilt.
///
/// # Panics
///
/// If a consumed id is out of bounds for `foldable` (or for `value` when
/// the node is folded).
// `arena` is threaded through only for the recursive calls.
#[allow(clippy::only_used_in_recursion)]
fn rewrite_expr(
    expr: &Expr,
    arena: &ExprArenaEncoding,
    foldable: &[u32],
    value: &[u32],
    counter: &mut u32,
) -> Expr {
    /// Returns the current post-order id and advances the counter.
    fn take_id(counter: &mut u32) -> u32 {
        let id = *counter;
        *counter += 1;
        id
    }
    let recurse =
        |child: &Expr, counter: &mut u32| rewrite_expr(child, arena, foldable, value, counter);

    match expr {
        Expr::LitU32(_)
        | Expr::LitI32(_)
        | Expr::LitF32(_)
        | Expr::LitBool(_)
        | Expr::Var(_)
        | Expr::BufLen { .. }
        | Expr::InvocationId { .. }
        | Expr::WorkgroupId { .. }
        | Expr::LocalId { .. }
        | Expr::SubgroupLocalId
        | Expr::SubgroupSize => decide(expr, take_id(counter), foldable, value),
        Expr::Load { buffer, index } => {
            let index = Box::new(recurse(index, counter));
            take_id(counter);
            Expr::Load {
                buffer: buffer.clone(),
                index,
            }
        }
        Expr::BinOp { op, left, right } => {
            let left = recurse(left, counter);
            let right = recurse(right, counter);
            let slot = take_id(counter) as usize;
            if foldable[slot] != 1 {
                Expr::BinOp {
                    op: *op,
                    left: Box::new(left),
                    right: Box::new(right),
                }
            } else {
                Expr::LitU32(value[slot])
            }
        }
        Expr::UnOp { op, operand } => {
            let operand = recurse(operand, counter);
            let slot = take_id(counter) as usize;
            if foldable[slot] != 1 {
                Expr::UnOp {
                    op: op.clone(),
                    operand: Box::new(operand),
                }
            } else {
                Expr::LitU32(value[slot])
            }
        }
        Expr::Select {
            cond,
            true_val,
            false_val,
        } => {
            let cond = Box::new(recurse(cond, counter));
            let true_val = Box::new(recurse(true_val, counter));
            let false_val = Box::new(recurse(false_val, counter));
            take_id(counter);
            Expr::Select {
                cond,
                true_val,
                false_val,
            }
        }
        Expr::Fma { a, b, c } => {
            let a = Box::new(recurse(a, counter));
            let b = Box::new(recurse(b, counter));
            let c = Box::new(recurse(c, counter));
            take_id(counter);
            Expr::Fma { a, b, c }
        }
        _ => expr.clone(),
    }
}
/// Chooses the replacement for a leaf expression with arena id `id`.
///
/// If the kernel marked the leaf foldable:
/// - literals are returned as-is;
/// - any other leaf becomes `LitU32(value[id])`.
///
/// Unfoldable leaves are cloned unchanged.
///
/// # Panics
///
/// If `id` is out of bounds for `foldable` (or for `value` when folding).
fn decide(expr: &Expr, id: u32, foldable: &[u32], value: &[u32]) -> Expr {
    let slot = id as usize;
    let already_literal = matches!(
        expr,
        Expr::LitU32(_) | Expr::LitI32(_) | Expr::LitF32(_) | Expr::LitBool(_)
    );
    if foldable[slot] != 1 || already_literal {
        return expr.clone();
    }
    Expr::LitU32(value[slot])
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dispatch_buffers::u32_slice_to_le_bytes;

    /// Fake backend that checks the call shape and replays canned outputs.
    struct ConstFoldDispatcher {
        outputs: Vec<Vec<u8>>,
    }

    impl ConstFoldDispatcher {
        /// Replays `foldable = [1]` and `value = [7]`: one folded expression.
        fn folded_seven() -> Self {
            Self {
                outputs: vec![u32_slice_to_le_bytes(&[1]), u32_slice_to_le_bytes(&[7])],
            }
        }
    }

    impl OptimizerDispatcher for ConstFoldDispatcher {
        fn dispatch(
            &self,
            _program: &Program,
            inputs: &[Vec<u8>],
            grid_override: Option<[u32; 3]>,
        ) -> Result<Vec<Vec<u8>>, DispatchError> {
            // A single expression fits in one workgroup.
            assert_eq!(grid_override, Some([1, 1, 1]));
            match inputs.len() {
                8 => Ok(self.outputs.clone()),
                got => Err(DispatchError::BadInputs(format!(
                    "Fix: const-fold test dispatcher expected 8 inputs, got {got}."
                ))),
            }
        }
    }

    /// A single `LIT_U32` expression at depth 0.
    fn one_expr_arena() -> ExprArenaEncoding {
        ExprArenaEncoding {
            expr_count: 1,
            kinds: vec![expr_kind::LIT_U32],
            arg0: vec![0],
            arg1: vec![0],
            arg2: vec![0],
            depths: vec![0],
            max_depth: 0,
            ..ExprArenaEncoding::default()
        }
    }

    /// Runs the kernel on `one_expr_arena` and returns the error it must produce.
    fn kernel_error(dispatcher: &ConstFoldDispatcher, reason: &str) -> DispatchError {
        let mut foldable = Vec::new();
        let mut value = Vec::new();
        run_const_fold_kernel_into(&one_expr_arena(), dispatcher, &mut foldable, &mut value)
            .expect_err(reason)
    }

    fn assert_backend_error(err: DispatchError) {
        assert!(
            matches!(err, DispatchError::BackendError(_)),
            "unexpected error: {err:?}"
        );
    }

    #[test]
    fn kernel_into_decodes_exact_outputs_into_reused_buffers() {
        let dispatcher = ConstFoldDispatcher::folded_seven();
        let mut foldable = Vec::with_capacity(4);
        let mut value = Vec::with_capacity(4);
        let (foldable_ptr, value_ptr) = (foldable.as_ptr(), value.as_ptr());
        run_const_fold_kernel_into(&one_expr_arena(), &dispatcher, &mut foldable, &mut value)
            .expect("Fix: dispatch succeeds");
        assert_eq!(foldable, vec![1]);
        assert_eq!(value, vec![7]);
        // Decoding must reuse the caller's allocations rather than replace them.
        assert_eq!(foldable.as_ptr(), foldable_ptr);
        assert_eq!(value.as_ptr(), value_ptr);
    }

    #[test]
    fn kernel_with_scratch_reuses_dispatch_state_and_outputs() {
        let dispatcher = ConstFoldDispatcher::folded_seven();
        let arena = one_expr_arena();
        let mut scratch = ConstFoldKernelScratch::default();
        let mut foldable = Vec::with_capacity(1);
        let mut value = Vec::with_capacity(1);
        // Record every buffer's capacity after each of two identical runs.
        let mut observed = Vec::with_capacity(2);
        for _ in 0..2 {
            run_const_fold_kernel_with_scratch_into(
                &arena,
                &dispatcher,
                &mut scratch,
                &mut foldable,
                &mut value,
            )
            .expect("Fix: dispatch succeeds");
            observed.push((
                scratch.inputs.iter().map(Vec::capacity).collect::<Vec<_>>(),
                foldable.capacity(),
                value.capacity(),
            ));
        }
        assert_eq!(observed[1], observed[0]);
        assert_eq!(foldable, vec![1]);
        assert_eq!(value, vec![7]);
    }

    #[test]
    fn kernel_rejects_extra_outputs() {
        let dispatcher = ConstFoldDispatcher {
            outputs: vec![
                u32_slice_to_le_bytes(&[1]),
                u32_slice_to_le_bytes(&[7]),
                u32_slice_to_le_bytes(&[0]),
            ],
        };
        assert_backend_error(kernel_error(&dispatcher, "extra outputs must be rejected"));
    }

    #[test]
    fn kernel_rejects_trailing_value_bytes() {
        let dispatcher = ConstFoldDispatcher {
            outputs: vec![u32_slice_to_le_bytes(&[1]), vec![7, 0, 0, 0, 1]],
        };
        assert_backend_error(kernel_error(&dispatcher, "trailing bytes must be rejected"));
    }
}