use crate::spec::types::OpSpec;
pub(super) fn invalid_program_expr(op: &OpSpec, rule_id: &str) -> Option<(String, String)> {
match rule_id {
"V001" => duplicate_buffer_program(op),
"V021" => wrong_argument_count_program(op),
_ => None,
}
}
/// Builds a program expression in which two buffer declarations share a name.
/// The result is labelled `"duplicate buffer name"`.
///
/// The buffer list is the op's usual one: `input{i}` at binding `i`, then `out0`
/// at binding `inputs.len()`. One extra `U32` declaration at binding 999 is
/// inserted just before `out0`. That extra entry reuses the name `input0`, or
/// `out0` when the op has no inputs. Presumably the intent is that only the name
/// collides, not the binding index; confirm.
///
/// The emitted call references an `op_id_str` variable that the surrounding
/// generated code must define.
fn duplicate_buffer_program(op: &OpSpec) -> Option<(String, String)> {
    let inputs = &op.signature.inputs;
    let mut decls: Vec<String> = inputs
        .iter()
        .enumerate()
        .map(|(binding, dt)| {
            format!(
                "vyre::ir::BufferDecl::read(\"input{binding}\", {binding}, {})",
                ir_data_type_expr(dt)
            )
        })
        .collect();

    // Reuse a name that is already declared (input0) or that will be declared
    // right below (out0).
    let clashing_decl = if inputs.is_empty() {
        "vyre::ir::BufferDecl::read_write(\"out0\", 999, vyre::ir::DataType::U32)"
    } else {
        "vyre::ir::BufferDecl::read(\"input0\", 999, vyre::ir::DataType::U32)"
    };
    decls.push(clashing_decl.to_owned());

    decls.push(format!(
        "vyre::ir::BufferDecl::read_write(\"out0\", {}, {})",
        inputs.len(),
        ir_data_type_expr(&op.signature.output)
    ));

    let decl_list = decls.join(", ");
    let call_arguments = call_args(op);
    let dims = workgroup_array(op);
    let program = format!(
        "vyre::ir::Program::new(vec![{decl_list}], {dims}, \
         vec![vyre::ir::Node::if_then(\
         vyre::ir::Expr::lt(vyre::ir::Expr::gid_x(), vyre::ir::Expr::buf_len(\"out0\")), \
         vec![vyre::ir::Node::store(\"out0\", vyre::ir::Expr::gid_x(), \
         vyre::ir::Expr::Call {{ op_id: op_id_str, args: vec![{call_arguments}] }})])])"
    );
    Some((program, "duplicate buffer name".to_owned()))
}
/// Builds a program expression whose op call passes the wrong number of
/// arguments. The result is labelled `"V021"`.
///
/// Buffers are declared as usual: `input{i}` at binding `i`, and `out0` at
/// binding `inputs.len()`. An op with no inputs is called with a single
/// `u32(0)` argument; an op with inputs is called with no arguments at all.
///
/// The emitted call references an `op_id_str` variable that the surrounding
/// generated code must define.
fn wrong_argument_count_program(op: &OpSpec) -> Option<(String, String)> {
    let inputs = &op.signature.inputs;
    let mut decls: Vec<String> = Vec::new();
    decls.extend(inputs.iter().enumerate().map(|(binding, dt)| {
        format!(
            "vyre::ir::BufferDecl::read(\"input{binding}\", {binding}, {})",
            ir_data_type_expr(dt)
        )
    }));
    decls.push(format!(
        "vyre::ir::BufferDecl::read_write(\"out0\", {}, {})",
        inputs.len(),
        ir_data_type_expr(&op.signature.output)
    ));

    // The call must not match the op's arity: a nullary op gets one argument
    // too many, and any other op has all of its arguments dropped.
    let mismatched_args = match inputs.len() {
        0 => "vyre::ir::Expr::u32(0)",
        _ => "",
    };

    let decl_list = decls.join(", ");
    let dims = workgroup_array(op);
    let program = format!(
        "vyre::ir::Program::new(vec![{decl_list}], {dims}, \
         vec![vyre::ir::Node::if_then(\
         vyre::ir::Expr::lt(vyre::ir::Expr::gid_x(), vyre::ir::Expr::buf_len(\"out0\")), \
         vec![vyre::ir::Node::store(\"out0\", vyre::ir::Expr::gid_x(), \
         vyre::ir::Expr::Call {{ op_id: op_id_str, args: vec![{mismatched_args}] }})])])"
    );
    Some((program, "V021".to_owned()))
}
/// Builds the source text of a `vyre::ir::Program` that exercises `op`.
///
/// The program declares:
/// - one buffer `input{i}` at binding `i` per input type, via
///   `BufferDecl::read`;
/// - `out0` at binding `inputs.len()`, via `BufferDecl::read_write`.
///
/// For each invocation with `gid_x() < buf_len("out0")`, the program stores
/// into `out0[gid_x()]` the result of calling the op on `input{i}[gid_x()]`
/// for every input. The op is identified through an `op_id_str` variable that
/// the surrounding generated code must define.
pub(super) fn program_expr(op: &OpSpec) -> String {
    let inputs = &op.signature.inputs;
    let mut decls: Vec<String> = inputs
        .iter()
        .enumerate()
        .map(|(binding, dt)| {
            format!(
                "vyre::ir::BufferDecl::read(\"input{binding}\", {binding}, {})",
                ir_data_type_expr(dt)
            )
        })
        .collect();
    decls.push(format!(
        "vyre::ir::BufferDecl::read_write(\"out0\", {}, {})",
        inputs.len(),
        ir_data_type_expr(&op.signature.output)
    ));

    let decl_list = decls.join(", ");
    let call_arguments = call_args(op);
    let dims = workgroup_array(op);
    format!(
        "vyre::ir::Program::new(vec![{decl_list}], {dims}, \
         vec![vyre::ir::Node::if_then(\
         vyre::ir::Expr::lt(vyre::ir::Expr::gid_x(), vyre::ir::Expr::buf_len(\"out0\")), \
         vec![vyre::ir::Node::store(\"out0\", vyre::ir::Expr::gid_x(), \
         vyre::ir::Expr::Call {{ op_id: op_id_str, args: vec![{call_arguments}] }})])])"
    )
}
/// Renders the argument list for the op call.
///
/// Emits one `vyre::ir::Expr::load("input{i}", vyre::ir::Expr::gid_x())` per
/// input, separated by `", "`. Returns an empty string for an op with no inputs.
fn call_args(op: &OpSpec) -> String {
    let mut rendered = String::new();
    for slot in 0..op.signature.inputs.len() {
        if slot != 0 {
            rendered.push_str(", ");
        }
        rendered.push_str(&format!(
            "vyre::ir::Expr::load(\"input{slot}\", vyre::ir::Expr::gid_x())"
        ));
    }
    rendered
}
/// Renders the workgroup size as the array literal `"[x, 1, 1]"`.
///
/// `x` is `op.workgroup_size`, or 1 when that field holds no value. Only the
/// X dimension varies; Y and Z are always 1.
fn workgroup_array(op: &OpSpec) -> String {
    let x_dim = op.workgroup_size.unwrap_or(1);
    format!("[{}, 1, 1]", x_dim)
}
/// Maps a spec-level `DataType` to the Rust source path of the matching
/// `vyre::ir::DataType` variant.
///
/// # Panics
///
/// Panics with a `"Fix: unsupported DataType ..."` message for any spec
/// `DataType` variant not listed below.
fn ir_data_type_expr(dt: &crate::spec::types::DataType) -> &'static str {
    use crate::spec::types::DataType as SpecType;

    let ir_path = match dt {
        SpecType::U32 => Some("vyre::ir::DataType::U32"),
        SpecType::I32 => Some("vyre::ir::DataType::I32"),
        SpecType::U64 => Some("vyre::ir::DataType::U64"),
        SpecType::Vec2U32 => Some("vyre::ir::DataType::Vec2U32"),
        SpecType::Vec4U32 => Some("vyre::ir::DataType::Vec4U32"),
        SpecType::Bool => Some("vyre::ir::DataType::Bool"),
        SpecType::Bytes => Some("vyre::ir::DataType::Bytes"),
        _ => None,
    };
    ir_path.unwrap_or_else(|| {
        panic!(
            "Fix: unsupported DataType {:?} for IR program generation",
            dt
        )
    })
}