use crate::region::wrap_anonymous;
use vyre::ir::{BufferAccess, BufferDecl, DataType, Expr, Node, Program};
/// Identifier for this op: passed to `wrap_anonymous` when the program body is
/// built and used as the `id` of the `OpEntry` registered in this file.
const OP_ID: &str = "vyre-libs::nn::attention::turboquant";
#[must_use]
pub fn turboquant_attention(
q: &str,
k_packed: &str,
v_packed: &str,
out: &str,
seq_len: u32,
d_head: u32,
) -> Program {
let total_vals = seq_len.saturating_mul(d_head);
let packed_words = total_vals.div_ceil(10);
let unpack_3bit = |buf: &str, flat: Expr| {
let word = Expr::load(buf, Expr::div(flat.clone(), Expr::u32(10)));
let shift = Expr::mul(Expr::rem(flat, Expr::u32(10)), Expr::u32(3));
let nib = Expr::bitand(Expr::shr(word, shift), Expr::u32(0x7));
Expr::select(
Expr::eq(nib.clone(), Expr::u32(0)),
Expr::f32(0.0),
Expr::select(
Expr::eq(nib.clone(), Expr::u32(1)),
Expr::f32(1.0),
Expr::select(
Expr::eq(nib.clone(), Expr::u32(2)),
Expr::f32(2.0),
Expr::select(
Expr::eq(nib.clone(), Expr::u32(3)),
Expr::f32(3.0),
Expr::select(
Expr::eq(nib.clone(), Expr::u32(4)),
Expr::f32(4.0),
Expr::select(
Expr::eq(nib.clone(), Expr::u32(5)),
Expr::f32(5.0),
Expr::select(
Expr::eq(nib, Expr::u32(6)),
Expr::f32(6.0),
Expr::f32(7.0),
),
),
),
),
),
),
)
};
if seq_len <= 8 && d_head <= 16 {
let mut stores = Vec::with_capacity(d_head as usize);
for dim in 0..d_head {
let mut acc = Expr::f32(0.0);
for i in 0..seq_len {
let mut score = Expr::f32(0.0);
for e in 0..d_head {
score = Expr::add(
score,
Expr::mul(
Expr::load(q, Expr::u32(e)),
unpack_3bit(k_packed, Expr::u32(i * d_head + e)),
),
);
}
acc = Expr::add(
acc,
Expr::mul(score, unpack_3bit(v_packed, Expr::u32(i * d_head + dim))),
);
}
stores.push(Node::store(out, Expr::u32(dim), acc));
}
return Program::wrapped(
vec![
BufferDecl::storage(q, 0, BufferAccess::ReadOnly, DataType::F32).with_count(d_head),
BufferDecl::storage(k_packed, 1, BufferAccess::ReadOnly, DataType::U32)
.with_count(packed_words),
BufferDecl::storage(v_packed, 2, BufferAccess::ReadOnly, DataType::U32)
.with_count(packed_words),
BufferDecl::storage(out, 3, BufferAccess::ReadWrite, DataType::F32)
.with_count(d_head),
],
[1, 1, 1],
vec![wrap_anonymous(
OP_ID,
vec![Node::if_then(
Expr::eq(Expr::InvocationId { axis: 0 }, Expr::u32(0)),
stores,
)],
)],
);
}
let t = Expr::InvocationId { axis: 0 };
let inner_body = vec![
Node::let_bind("d", t.clone()),
Node::let_bind("acc", Expr::f32(0.0)),
Node::loop_for(
"i",
Expr::u32(0),
Expr::u32(seq_len),
vec![
Node::let_bind("score", Expr::f32(0.0)),
Node::loop_for(
"e",
Expr::u32(0),
Expr::u32(d_head),
vec![Node::assign(
"score",
Expr::add(
Expr::var("score"),
Expr::mul(
Expr::load(q, Expr::var("e")),
unpack_3bit(
k_packed,
Expr::add(
Expr::mul(Expr::var("i"), Expr::u32(d_head)),
Expr::var("e"),
),
),
),
),
)],
),
Node::assign(
"acc",
Expr::add(
Expr::var("acc"),
Expr::mul(
Expr::var("score"),
unpack_3bit(
v_packed,
Expr::add(
Expr::mul(Expr::var("i"), Expr::u32(d_head)),
Expr::var("d"),
),
),
),
),
),
],
),
Node::store(out, Expr::var("d"), Expr::var("acc")),
];
Program::wrapped(
vec![
BufferDecl::storage(q, 0, BufferAccess::ReadOnly, DataType::F32).with_count(d_head),
BufferDecl::storage(k_packed, 1, BufferAccess::ReadOnly, DataType::U32)
.with_count(packed_words),
BufferDecl::storage(v_packed, 2, BufferAccess::ReadOnly, DataType::U32)
.with_count(packed_words),
BufferDecl::storage(out, 3, BufferAccess::ReadWrite, DataType::F32).with_count(d_head),
],
[64, 1, 1],
vec![wrap_anonymous(
OP_ID,
vec![Node::if_then(
Expr::lt(t.clone(), Expr::u32(d_head)),
inner_body,
)],
)],
)
}
// Registers the op with the crate harness, together with one fixture for a
// 2x2 problem (seq_len = 2, d_head = 2).
inventory::submit! {
    crate::harness::OpEntry {
        id: OP_ID,
        build: || turboquant_attention("q", "kp", "vp", "out", 2, 2),
        test_inputs: Some(|| {
            let query: &[f32] = &[1.0, 1.0];
            // 3-bit codes, lowest bits first: 0x8D1 -> 1, 2, 3, 4.
            let key_words = [0x8D1u32];
            // 3-bit codes, lowest bits first: 0x201 -> 1, 0, 0, 1.
            let value_words = [0x201u32];
            vec![vec![
                vyre_primitives::wire::pack_f32_slice(query),
                crate::test_support::byte_pack::u32_bytes(&key_words),
                crate::test_support::byte_pack::u32_bytes(&value_words),
                // Zero-filled output buffer: two f32 values, 8 bytes.
                vec![0u8; 8],
            ]]
        }),
        expected_output: Some(|| {
            // Row scores are 1 + 2 = 3 and 3 + 4 = 7; the decoded value rows
            // [1, 0] and [0, 1] route them to out[0] and out[1] respectively.
            let expected: &[f32] = &[3.0, 7.0];
            vec![vec![vyre_primitives::wire::pack_f32_slice(expected)]]
        }),
        category: Some("nn"),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_support::byte_pack::{decode_f32, f32_bytes, u32_bytes};
    use vyre_reference::value::Value;

    /// Assembles the four buffers handed to the reference evaluator: the query,
    /// one packed key word, one packed value word and a zeroed 8-byte output.
    fn inputs(q: [f32; 2], k_word: u32, v_word: u32) -> [Value; 4] {
        [
            Value::from(f32_bytes(&q)),
            Value::from(u32_bytes(&[k_word])),
            Value::from(u32_bytes(&[v_word])),
            Value::from(vec![0u8; 8]),
        ]
    }

    /// A NaN in the query must reach every output element.
    #[test]
    fn turboquant_nan_in_q_propagates_to_output() {
        let program = turboquant_attention("q", "kp", "vp", "out", 2, 2);
        let outputs = vyre_reference::reference_eval(&program, &inputs([f32::NAN, 1.0], 0, 0))
            .expect("Fix: turboquant must not panic on NaN q");

        let decoded = decode_f32(&outputs[0].to_bytes());
        assert!(
            decoded.iter().all(|v| v.is_nan()),
            "turboquant NaN in q must produce NaN output, got {:?}",
            decoded
        );
    }

    /// With no rows to accumulate, the output stays at zero.
    #[test]
    fn turboquant_zero_seq_len() {
        let program = turboquant_attention("q", "kp", "vp", "out", 0, 2);
        let outputs = vyre_reference::reference_eval(&program, &inputs([1.0, 1.0], 0, 0))
            .expect("Fix: turboquant seq_len=0 must not panic");

        let decoded = decode_f32(&outputs[0].to_bytes());
        assert_eq!(
            decoded,
            vec![0.0, 0.0],
            "turboquant zero seq_len must produce zeros"
        );
    }

    /// Word 9 packs the codes 1, 1, so the single row scores 2 and both
    /// outputs equal 2 * 1.
    #[test]
    fn turboquant_single_token() {
        let program = turboquant_attention("q", "kp", "vp", "out", 1, 2);
        let outputs = vyre_reference::reference_eval(&program, &inputs([1.0, 1.0], 9, 9))
            .expect("Fix: turboquant single token must execute");

        let decoded = decode_f32(&outputs[0].to_bytes());
        assert_eq!(decoded, vec![2.0, 2.0]);
    }
}