use vyre::ir::{BinOp, BufferAccess, BufferDecl, DataType, Expr, Node, Program, UnOp};
use vyre_foundation::ir::model::expr::GeneratorRef;
use vyre_primitives::nn::attention_passes::{
ATTENTION_MAX_PASS_OP_ID, ATTENTION_SUM_PASS_OP_ID, ATTENTION_WRITE_PASS_OP_ID,
};
use crate::region::{wrap_anonymous, wrap_child};
use vyre_primitives::nn::attention_stability::{
bounded_exp_arg, bounded_score, flush_tiny, positive_denominator,
};
/// Operation identifier for grouped-query attention. It names the parent generator of the
/// three attention passes, tags the anonymous region wrapping the program body, and is the
/// `id` of the `OpEntry` submitted to the harness in this file.
const OP_ID: &str = "vyre-libs::nn::gqa_attention";
/// Builds the IR [`Program`] for grouped-query attention (GQA).
///
/// Every invocation handles one output element `flat` (axis 0 of the invocation id),
/// decomposed as `flat = q_head * (seq_len * head_dim) + row * head_dim + col`. Query head
/// `q_head` reads keys and values from KV head `q_head / (n_q_heads / n_kv_heads)`, so
/// consecutive groups of query heads share one KV head. Each query row attends to all
/// `seq_len` key rows (no mask is applied here).
///
/// The softmax is computed in three passes over the key index `j`, each recomputing the
/// `q . k` dot product:
/// 1. max pass: running maximum of the scaled score,
/// 2. sum pass: sum of `exp(score - max_score)`,
/// 3. write pass: accumulates `exp(score - max_score) / denom * v[j][col]` and stores it.
///
/// # Parameters
/// * `q`, `k`, `v_buf`, `output` - buffer names, declared at bindings 0, 1, 2 and 3. `q` and
///   `output` hold `n_q_heads * seq_len * head_dim` `f32` elements; `k` and `v_buf` hold
///   `n_kv_heads * seq_len * head_dim`.
/// * `n_q_heads`, `n_kv_heads` - head counts; `n_q_heads` must be a multiple of `n_kv_heads`.
/// * `seq_len`, `head_dim` - rows per head and elements per row.
///
/// # Errors
/// Returns `Err` with a `"Fix: ..."` message when
/// * any dimension is zero,
/// * `n_q_heads` is not a multiple of `n_kv_heads`,
/// * `n_q_heads * seq_len * head_dim` does not fit in `u32` (previously this overflowed:
///   a panic in debug builds, a silently wrapped element count in release builds).
// The eight parameters are the op's public signature; bundling them would break callers.
#[allow(clippy::too_many_arguments)]
pub fn gqa_attention(
    q: &str,
    k: &str,
    v_buf: &str,
    output: &str,
    n_q_heads: u32,
    n_kv_heads: u32,
    seq_len: u32,
    head_dim: u32,
) -> Result<Program, String> {
    if n_q_heads == 0 || n_kv_heads == 0 || seq_len == 0 || head_dim == 0 {
        return Err("Fix: gqa_attention requires non-zero dims".into());
    }
    if n_q_heads % n_kv_heads != 0 {
        return Err("Fix: n_q_heads must be multiple of n_kv_heads".into());
    }
    let group_size = n_q_heads / n_kv_heads;

    // Element counts are u32 in the IR; reject shapes whose counts would wrap.
    let overflow =
        || String::from("Fix: gqa_attention n_q_heads * seq_len * head_dim must fit in u32");
    let per_head = seq_len.checked_mul(head_dim).ok_or_else(overflow)?;
    let q_total = n_q_heads.checked_mul(per_head).ok_or_else(overflow)?;
    // n_kv_heads divides n_q_heads, so n_kv_heads <= n_q_heads and this cannot exceed q_total.
    let kv_total = n_kv_heads * per_head;

    let scale = 1.0f32 / (head_dim as f32).sqrt();

    // Decompose the flat invocation index into (q_head, row, col) using integer div/mul/sub.
    let flat = Expr::var("flat");
    let q_head = Expr::BinOp {
        op: BinOp::Div,
        left: Box::new(flat.clone()),
        right: Box::new(Expr::u32(per_head)),
    };
    let pos_in_head = Expr::sub(flat.clone(), Expr::mul(q_head.clone(), Expr::u32(per_head)));
    let row = Expr::BinOp {
        op: BinOp::Div,
        left: Box::new(pos_in_head.clone()),
        right: Box::new(Expr::u32(head_dim)),
    };
    let col = Expr::sub(pos_in_head, Expr::mul(row.clone(), Expr::u32(head_dim)));

    // Query heads map onto KV heads in contiguous groups of `group_size`.
    let kv_head = Expr::BinOp {
        op: BinOp::Div,
        left: Box::new(q_head.clone()),
        right: Box::new(Expr::u32(group_size)),
    };
    let kv_base = Expr::mul(kv_head, Expr::u32(per_head));
    let q_row_base = Expr::add(
        Expr::mul(q_head, Expr::u32(per_head)),
        Expr::mul(row, Expr::u32(head_dim)),
    );

    // Emits `dot_var = sum_lane q[q_row_base + lane] * k[kv_base + j * head_dim + lane]`.
    // Expects a loop variable `j` to be in scope where the returned nodes are placed.
    let make_dot_loop = |dot_var: &str| -> Vec<Node> {
        // One accumulation step; `lane` is a constant when unrolled or the loop variable `d`.
        let fma_step = |lane: Expr| -> Node {
            Node::assign(
                dot_var,
                Expr::fma(
                    Expr::load(q, Expr::add(q_row_base.clone(), lane.clone())),
                    Expr::load(
                        k,
                        Expr::add(
                            Expr::add(
                                kv_base.clone(),
                                Expr::mul(Expr::var("j"), Expr::u32(head_dim)),
                            ),
                            lane,
                        ),
                    ),
                    Expr::var(dot_var),
                ),
            )
        };
        let mut body = vec![Node::let_bind(dot_var, Expr::f32(0.0))];
        if head_dim <= 8 {
            // Small heads: fully unroll the dot product at build time.
            body.extend((0..head_dim).map(|lane| fma_step(Expr::u32(lane))));
        } else {
            body.push(Node::loop_for(
                "d",
                Expr::u32(0),
                Expr::u32(head_dim),
                vec![fma_step(Expr::var("d"))],
            ));
        }
        body
    };

    // exp(bounded_exp_arg(dot * scale - max_score)); the subtraction of the running maximum
    // is the usual softmax stabilisation. `bounded_exp_arg` is defined in vyre_primitives;
    // presumably it clamps the argument - confirm there.
    let exp_expr = |dot_var: &str| -> Expr {
        Expr::UnOp {
            op: UnOp::Exp,
            operand: Box::new(bounded_exp_arg(Expr::sub(
                Expr::mul(Expr::var(dot_var), Expr::f32(scale)),
                Expr::var("max_score"),
            ))),
        }
    };

    let parent = GeneratorRef {
        name: OP_ID.to_string(),
    };

    // Pass 1: max_score = max_j bounded_score(dot * scale).
    let max_pass = {
        let mut steps = make_dot_loop("dot");
        steps.push(Node::let_bind(
            "score",
            bounded_score(Expr::mul(Expr::var("dot"), Expr::f32(scale))),
        ));
        // A NaN score is selected ahead of the comparison (assuming `select(cond, a, b)`
        // yields `a` when `cond` holds - confirm), so NaN inputs are not masked by `gt`.
        steps.push(Node::assign(
            "max_score",
            Expr::select(
                Expr::is_nan(Expr::var("score")),
                Expr::var("score"),
                Expr::select(
                    Expr::gt(Expr::var("score"), Expr::var("max_score")),
                    Expr::var("score"),
                    Expr::var("max_score"),
                ),
            ),
        ));
        vec![Node::loop_for(
            "j",
            Expr::u32(0),
            Expr::u32(seq_len),
            steps,
        )]
    };

    // Pass 2: sum_exp = sum_j exp(score_j - max_score).
    let sum_pass = {
        let mut steps = make_dot_loop("dot2");
        steps.push(Node::assign(
            "sum_exp",
            Expr::add(Expr::var("sum_exp"), exp_expr("dot2")),
        ));
        vec![Node::loop_for(
            "j",
            Expr::u32(0),
            Expr::u32(seq_len),
            steps,
        )]
    };

    // Pass 3: val = sum_j (exp(score_j - max_score) / denom) * v[kv_base + j * head_dim + col],
    // then store flush_tiny(val) at output[flat].
    let write_pass = {
        let mut steps = make_dot_loop("dot3");
        steps.push(Node::let_bind(
            "w",
            Expr::div(exp_expr("dot3"), Expr::var("denom")),
        ));
        steps.push(Node::assign(
            "val",
            Expr::fma(
                Expr::var("w"),
                Expr::load(
                    v_buf,
                    Expr::add(
                        Expr::add(
                            kv_base.clone(),
                            Expr::mul(Expr::var("j"), Expr::u32(head_dim)),
                        ),
                        col.clone(),
                    ),
                ),
                Expr::var("val"),
            ),
        ));
        vec![
            Node::let_bind("val", Expr::f32(0.0)),
            Node::loop_for("j", Expr::u32(0), Expr::u32(seq_len), steps),
            Node::Store {
                buffer: output.into(),
                index: flat.clone(),
                value: flush_tiny(Expr::var("val")),
            },
        ]
    };

    let body = vec![
        Node::let_bind("flat", Expr::InvocationId { axis: 0 }),
        // Invocations at or beyond q_total have no output element and do nothing.
        Node::if_then(
            Expr::lt(flat.clone(), Expr::u32(q_total)),
            vec![
                Node::let_bind("max_score", Expr::f32(f32::MIN)),
                Node::let_bind("sum_exp", Expr::f32(0.0)),
                wrap_child(ATTENTION_MAX_PASS_OP_ID, parent.clone(), max_pass),
                wrap_child(ATTENTION_SUM_PASS_OP_ID, parent.clone(), sum_pass),
                Node::let_bind("denom", positive_denominator(Expr::var("sum_exp"))),
                wrap_child(ATTENTION_WRITE_PASS_OP_ID, parent, write_pass),
            ],
        ),
    ];

    Ok(Program::wrapped(
        vec![
            BufferDecl::storage(q, 0, BufferAccess::ReadOnly, DataType::F32).with_count(q_total),
            BufferDecl::storage(k, 1, BufferAccess::ReadOnly, DataType::F32).with_count(kv_total),
            BufferDecl::storage(v_buf, 2, BufferAccess::ReadOnly, DataType::F32)
                .with_count(kv_total),
            BufferDecl::output(output, 3, DataType::F32).with_count(q_total),
        ],
        // NOTE(review): presumably the workgroup size - confirm against `Program::wrapped`.
        [64, 1, 1],
        vec![wrap_anonymous(OP_ID, body)],
    ))
}
// Harness registration for the op, with a tiny fixture: 2 query heads sharing 1 KV head,
// seq_len = 2, head_dim = 2.
inventory::submit! {
    crate::harness::OpEntry {
        id: OP_ID,
        build: || {
            match gqa_attention("q", "k", "v", "out", 2, 1, 2, 2) {
                Ok(program) => program,
                Err(error) => crate::invalid_program(
                    OP_ID,
                    format!("Fix: gqa_attention fixture must build: {error}"),
                ),
            }
        },
        test_inputs: Some(|| {
            let pack = vyre_primitives::wire::pack_f32_slice;
            // Query: two heads, each a 2x2 block of rows.
            let q_bytes = pack(&[1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0]);
            // Key and value: one shared head, 2x2 each.
            let k_bytes = pack(&[1.0, 0.0, 0.0, 1.0]);
            let v_bytes = pack(&[10.0, 20.0, 30.0, 40.0]);
            // Output placeholder: 8 f32 elements, zero-filled.
            let out_bytes = vec![0u8; 32];
            vec![vec![q_bytes, k_bytes, v_bytes, out_bytes]]
        }),
        expected_output: Some(|| {
            // One 4-byte group per output element; read as little-endian f32 the values are
            // approximately the ones noted on each line.
            vec![vec![vec![
                145, 214, 132, 65, // head 0, row 0, col 0: ~16.605
                146, 214, 212, 65, // head 0, row 0, col 1: ~26.605
                111, 41, 187, 65, // head 0, row 1, col 0: ~23.395
                183, 148, 5, 66, // head 0, row 1, col 1: ~33.395
                111, 41, 187, 65, // head 1, row 0, col 0: ~23.395
                183, 148, 5, 66, // head 1, row 0, col 1: ~33.395
                145, 214, 132, 65, // head 1, row 1, col 0: ~16.605
                146, 214, 212, 65, // head 1, row 1, col 1: ~26.605
            ]]]
        }),
        category: Some("nn"),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_support::byte_pack::decode_f32;
    use crate::test_support::byte_pack::f32_bytes;
    use vyre_reference::value::Value;

    /// Builds the attention program for `(n_q, n_kv, s, d)`, runs it through the reference
    /// evaluator on `q`/`k`/`v` plus a zero-filled output buffer, and returns the first
    /// output decoded as `f32`s. `exec_msg` is the panic message used if evaluation fails.
    fn run_reference(
        (n_q, n_kv, s, d): (u32, u32, u32, u32),
        q: &[f32],
        k: &[f32],
        v: &[f32],
        exec_msg: &str,
    ) -> Vec<f32> {
        let prog = gqa_attention("q", "k", "v", "out", n_q, n_kv, s, d).expect("Fix: build");
        let zeroed_output = vec![0u8; (n_q * s * d) as usize * 4];
        let inputs = [
            Value::from(f32_bytes(q)),
            Value::from(f32_bytes(k)),
            Value::from(f32_bytes(v)),
            Value::from(zeroed_output),
        ];
        let outputs = vyre_reference::reference_eval(&prog, &inputs).expect(exec_msg);
        decode_f32(&outputs[0].to_bytes())
    }

    /// A zero `seq_len` must be rejected at build time.
    #[test]
    fn gqa_attention_zero_sequence_length_rejected() {
        let outcome = gqa_attention("q", "k", "v", "out", 2, 1, 0, 4);
        let err = outcome.expect_err("zero seq_len must error");
        assert!(err.contains("seq_len=0") || err.contains("non-zero"));
    }

    /// With one key/value row the softmax weight is 1, so every query row returns `v`.
    #[test]
    fn gqa_attention_single_token() {
        let out = run_reference(
            (2, 1, 1, 2),
            &[1.0, 0.0, 0.0, 1.0],
            &[1.0, 0.0],
            &[10.0, 20.0],
            "Fix: gqa_attention single token must execute",
        );
        for (i, &v) in out.iter().enumerate() {
            let expected = match i % 2 {
                0 => 10.0,
                _ => 20.0,
            };
            assert!(
                (v - expected).abs() <= 1.0e-4,
                "gqa_attention single token mismatch at {i}: {v} != {expected}"
            );
        }
    }

    /// Huge Q/K magnitudes must still produce finite outputs.
    #[test]
    fn gqa_attention_very_large_qk_values_stay_finite() {
        let out = run_reference(
            (1, 1, 2, 2),
            &[1e20f32; 4],
            &[1e20f32; 4],
            &[1.0f32; 4],
            "Fix: gqa_attention must not panic on large QK values",
        );
        for (i, &v) in out.iter().enumerate() {
            assert!(
                v.is_finite(),
                "gqa_attention output at {i} must be finite for large QK values, got {v}"
            );
        }
    }

    /// A NaN in the query must surface as NaN somewhere in the output.
    #[test]
    fn gqa_attention_nan_in_q_k_v_propagates() {
        let out = run_reference(
            (1, 1, 1, 2),
            &[f32::NAN, 0.0],
            &[0.0, 0.0],
            &[1.0, 2.0],
            "Fix: gqa_attention must not panic on NaN input",
        );
        let saw_nan = out.iter().copied().any(f32::is_nan);
        assert!(
            saw_nan,
            "gqa_attention must propagate NaN in Q/K/V instead of silently producing finite output {:?}",
            out
        );
    }
}