use vyre::ir::{BufferAccess, BufferDecl, DataType, Expr, Node, Program};
use crate::region::wrap_anonymous;
const OP_ID: &str = "vyre-libs::nn::mlp_backward";
#[must_use]
#[allow(clippy::too_many_arguments)]
pub fn mlp_backward(
x: &str,
w1: &str,
b1: &str,
w2: &str,
grad_out: &str,
grad_x: &str,
model_dim: u32,
hidden_dim: u32,
) -> Program {
let i = Expr::var("i");
let body = vec![
Node::let_bind("i", Expr::InvocationId { axis: 0 }),
Node::if_then(
Expr::lt(i.clone(), Expr::u32(model_dim)),
vec![
Node::let_bind("gx", Expr::f32(0.0)),
Node::loop_for(
"j",
Expr::u32(0),
Expr::u32(hidden_dim),
vec![
Node::let_bind("h_acc", Expr::load(b1, Expr::var("j"))),
Node::loop_for(
"m",
Expr::u32(0),
Expr::u32(model_dim),
vec![Node::assign(
"h_acc",
Expr::add(
Expr::var("h_acc"),
Expr::mul(
Expr::load(x, Expr::var("m")),
Expr::load(
w1,
Expr::add(
Expr::mul(Expr::var("m"), Expr::u32(hidden_dim)),
Expr::var("j"),
),
),
),
),
)],
),
Node::let_bind(
"d_act",
Expr::max(
Expr::mul(Expr::f32(0.5), Expr::var("h_acc")),
Expr::mul(Expr::f32(2.0), Expr::var("h_acc")),
),
),
Node::let_bind("gh_act", Expr::f32(0.0)),
Node::loop_for(
"k",
Expr::u32(0),
Expr::u32(model_dim),
vec![Node::assign(
"gh_act",
Expr::add(
Expr::var("gh_act"),
Expr::mul(
Expr::load(grad_out, Expr::var("k")),
Expr::load(
w2,
Expr::add(
Expr::mul(Expr::var("j"), Expr::u32(model_dim)),
Expr::var("k"),
),
),
),
),
)],
),
Node::let_bind("gh", Expr::mul(Expr::var("gh_act"), Expr::var("d_act"))),
Node::assign(
"gx",
Expr::add(
Expr::var("gx"),
Expr::mul(
Expr::var("gh"),
Expr::load(
w1,
Expr::add(
Expr::mul(i.clone(), Expr::u32(hidden_dim)),
Expr::var("j"),
),
),
),
),
),
],
),
Node::Store {
buffer: grad_x.into(),
index: i,
value: Expr::var("gx"),
},
],
),
];
Program::wrapped(
vec![
BufferDecl::storage(x, 0, BufferAccess::ReadOnly, DataType::F32).with_count(model_dim),
BufferDecl::storage(w1, 1, BufferAccess::ReadOnly, DataType::F32)
.with_count(model_dim * hidden_dim),
BufferDecl::storage(b1, 2, BufferAccess::ReadOnly, DataType::F32)
.with_count(hidden_dim),
BufferDecl::storage(w2, 3, BufferAccess::ReadOnly, DataType::F32)
.with_count(hidden_dim * model_dim),
BufferDecl::storage(grad_out, 4, BufferAccess::ReadOnly, DataType::F32)
.with_count(model_dim),
BufferDecl::output(grad_x, 5, DataType::F32).with_count(model_dim),
],
[64, 1, 1],
vec![wrap_anonymous(OP_ID, body)],
)
}
// Registers `mlp_backward` with the op harness, along with a small hand-checkable fixture.
inventory::submit! {
    crate::harness::OpEntry {
        id: OP_ID,
        build: || mlp_backward("x", "w1", "b1", "w2", "grad_out", "grad_x", 2, 2),
        // Fixture: model_dim = hidden_dim = 2, identity w1 and w2, zero b1, grad_out = [1, 1].
        test_inputs: Some(|| {
            use vyre_primitives::wire::pack_f32_slice;

            let identity: [f32; 4] = [1.0, 0.0, 0.0, 1.0];
            let x = pack_f32_slice(&[1.0_f32, 2.0]);
            let w1 = pack_f32_slice(&identity);
            let b1 = pack_f32_slice(&[0.0_f32, 0.0]);
            let w2 = pack_f32_slice(&identity);
            let grad_out = pack_f32_slice(&[1.0_f32, 1.0]);
            // Zero-initialised output buffer: two f32 elements.
            let grad_x = vec![0u8; 2 * std::mem::size_of::<f32>()];

            vec![vec![x, w1, b1, w2, grad_out, grad_x]]
        }),
        // Hand-derived: h = [1, 2], d_act = [2, 4], gh_act = [1, 1] => grad_x = [2, 4].
        expected_output: Some(|| {
            let grad_x: [f32; 2] = [2.0, 4.0];
            vec![vec![vyre_primitives::wire::pack_f32_slice(&grad_x)]]
        }),
        category: Some("nn"),
    }
}